[go: nahoru, domu]

Skip to content

Commit

Permalink
修改参数
Browse files (browse the repository at this point in the history)
  • Loading branch information
LuoKun committed Jan 5, 2022
1 parent f2f9591 commit 2e20f5f
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/em.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ class EM: # 三硬币模型
Expectation-maximization algorithm(期望最大算法)
"""

def __init__(self, prob: list, iterations=100):
self.prob, self.iterations = np.array(prob), iterations
def __init__(self, prob: list):
self.prob = np.array(prob)

def fit(self, X: np.ndarray):
for _ in range(self.iterations):
def fit(self, X: np.ndarray, iterations=100):
for _ in range(iterations):
M = self._expect(X) # E步
self._maximize(X, M) # M步

Expand All @@ -38,10 +38,10 @@ def _maximize(self, X: np.ndarray, M: np.ndarray): # M步
if __name__ == "__main__":
x = np.array([1, 1, 0, 1, 0, 0, 1, 0, 1, 1])

em = EM([0.5, 0.5, 0.5], 100)
em = EM([0.5, 0.5, 0.5])
em.fit(x)
print(em.prob) # [0.5, 0.6, 0.6]

em = EM([0.4, 0.6, 0.7], 100)
em = EM([0.4, 0.6, 0.7])
em.fit(x)
print(em.prob) # [0.4064, 0.5368, 0.6432]

0 comments on commit 2e20f5f

Please sign in to comment.