[go: nahoru, domu]

Skip to content

Commit

Permalink
添加了输出概率的方法: predict_prob
Browse files Browse the repository at this point in the history
  • Loading branch information
Kun Luo committed Dec 28, 2021
1 parent c97cdaa commit 3fc8605
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/naive_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,19 @@ def __call__(self, X: np.ndarray):
Y = np.zeros([len(X)], dtype=int)
for i, x in enumerate(X):
prob = np.log(self.prior_prob) + np.array(
[np.sum(np.log(cond_prob[range(len(x)), x])) for cond_prob in self.cond_prob]
[np.log(cond_prob[range(len(x)), x]).sum() for cond_prob in self.cond_prob]
) # 先验概率的对数,加上条件概率的对数
Y[i] = np.argmax(prob)
return Y

def predict_prob(self, X: np.ndarray):
    """Return, for every sample and class, P(Y=c) * prod_j P(X_j = x_j | Y=c).

    The values are the *unnormalised* joint probabilities: they are not divided
    by their per-row sum, so a row generally does not add up to 1.

    Args:
        X: Array-like of integer-encoded samples (a nested list also works).
            Row ``i`` is one sample; its ``j``-th entry is used as a column
            index into each class's conditional-probability table.

    Returns:
        A float array of shape ``(len(X), len(self.cond_prob))``.
    """
    n_classes = len(self.cond_prob)
    prob = np.zeros([len(X), n_classes])
    for i, x in enumerate(X):
        # Row j of each table is paired with feature j's observed value, so
        # cond_prob[c][feature_idx, x] picks P(X_j = x_j | Y=c) for every j.
        # NOTE(review): assumes each self.cond_prob[c] is indexed
        # [feature, value] -- this matches how __call__ indexes it.
        feature_idx = range(len(x))
        for c, (prior_prob, cond_prob) in enumerate(zip(self.prior_prob, self.cond_prob)):
            # The debug print that used to be here was removed: it spammed
            # stdout once per (sample, class) pair.
            prob[i, c] = prior_prob * np.prod(cond_prob[feature_idx, x])
    return prob

@staticmethod
def _estimate_prob(x: np.ndarray, n: int):
return (np.bincount(x, minlength=n) + 1) / (len(x) + n) # 使用贝叶斯估计
Expand Down Expand Up @@ -81,3 +89,6 @@ def load_data():
# [2/12, 5/12, 5/12]]
acc = np.sum(pred == y) / len(pred)
print(f"Accuracy = {100 * acc:.2f}%")
print()
print(naive_bayes.predict_prob([[1, 0]])) # 输出 [[1, 0]]的概率
# [[0.061, 0.0327]]

0 comments on commit 3fc8605

Please sign in to comment.