[go: nahoru, domu]

Skip to content

Commit

Permalink
简化 (Simplify)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kun Luo committed Jul 21, 2021
1 parent cab2970 commit c1c39bd
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions AdaBoost/adaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @Author: Luokun
# @Email : olooook@outlook.com


import numpy as np
from matplotlib import pyplot as plt

Expand Down Expand Up @@ -42,18 +43,17 @@ def __call__(self, X: np.ndarray):
class WeakEstimator:
    """Weak classifier: a one-level decision tree (decision stump).

    The stump splits on a single feature at a single threshold. It predicts
    ``sign`` for samples whose feature value is strictly greater than the
    threshold and ``-sign`` otherwise.
    """

    def __init__(self, lr: float):
        """Create an unfitted stump.

        Args:
            lr: Step between the candidate thresholds searched by ``fit``.
                It must be positive. ``np.arange`` cannot use a zero step, and
                a negative step produces no candidates at all, which would
                leave the stump unfitted.

        Raises:
            ValueError: If ``lr`` is not positive (this includes NaN).
        """
        if not lr > 0:  # written this way so NaN is rejected too
            raise ValueError(f"lr must be positive, got {lr!r}")
        self.lr = lr
        # Split feature index, split threshold, and output sign in {-1, 1};
        # all three are set by fit().
        self.feature, self.threshold, self.sign = None, None, None

    def fit(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray):
        """Choose the split with the smallest weighted misclassification error.

        Every feature (column of ``X``) is tried with every output sign. For
        each feature, thresholds start at ``min - lr`` and step by
        ``self.lr`` up to, but excluding, ``max + lr`` of that column. The
        first candidate that reaches the smallest error wins, so ties keep
        the earlier candidate.

        Args:
            X: Samples, one row per sample and one column per feature.
            Y: Labels compared against predictions in {-sign, sign};
                presumably in {-1, 1} -- TODO confirm against caller.
            weights: Per-sample weights, indexed like the rows of ``X``.

        Returns:
            The minimum weighted error, i.e. the sum of the weights of the
            misclassified samples. Returns ``inf`` if ``X`` has no columns.
        """
        min_error = float('inf')  # smallest weighted error found so far
        for feature, x in enumerate(X.T):
            for threshold in np.arange(np.min(x) - self.lr, np.max(x) + self.lr, self.lr):
                for sign in (1, -1):
                    predictions = np.where(x > threshold, sign, -sign)
                    # Sum the weights of the misclassified samples.
                    error = np.sum(weights[predictions != Y])
                    if error < min_error:
                        self.feature, self.threshold, self.sign, min_error = feature, threshold, sign, error
        return min_error

    def __call__(self, X: np.ndarray):
        """Predict ``sign`` where the chosen feature exceeds the threshold, else ``-sign``.

        ``fit`` must have been called first. Before that, ``feature``,
        ``threshold`` and ``sign`` are ``None`` and no valid prediction can
        be made.
        """
        return np.where(X[:, self.feature] > self.threshold, self.sign, -self.sign)
Expand Down

0 comments on commit c1c39bd

Please sign in to comment.