[go: nahoru, domu]

Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
luokn committed May 22, 2022
1 parent a1ab0a4 commit c8fe317
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/adaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def __init__(self, lr: float):

def fit(self, X: np.ndarray, y: np.ndarray, weights: np.ndarray):
min_error = float("inf") # 最小带权误差
for f, x in enumerate(X.T):
for feature, x in enumerate(X.T):
for threshold in np.arange(np.min(x) - self.lr, np.max(x) + self.lr, self.lr):
for sign in [1, -1]:
# 取分类错误的样本权重求和
error = np.sum(weights[np.where(x > threshold, sign, -sign) != y])
if error < min_error:
self.feature, self.threshold, self.sign, min_error = f, threshold, sign, error
# 取分类错误的样本权重求和
pos_error = np.sum(weights[np.where(x > threshold, 1, -1) != y])
if pos_error < min_error:
min_error, self.feature, self.threshold, self.sign = pos_error, feature, threshold, 1
neg_error = 1 - pos_error
if neg_error < min_error:
min_error, self.feature, self.threshold, self.sign = neg_error, feature, threshold, -1
return min_error

def __call__(self, X: np.ndarray) -> np.ndarray:
Expand All @@ -79,7 +81,7 @@ def load_data(n_samples_per_class=500):
)
y = np.array([1] * n_samples_per_class + [-1] * n_samples_per_class)

training_set, test_set = np.split(np.random.permutation(len(X)), [int(len(X) * 0.6)])
training_set, test_set = np.split(np.random.permutation(len(X)), [int(len(X) * 0.8)])
return X, y, training_set, test_set


Expand Down

0 comments on commit c8fe317

Please sign in to comment.