[go: nahoru, domu]

Skip to content

Commit

Permalink
Update computation of group rbf kernel
Browse files (browse the repository at this point in the history)
  • Loading branch information
wotzlaff committed Oct 18, 2021
1 parent 6f2aa37 commit e597c20
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
28 changes: 18 additions & 10 deletions optimized_lssvr/_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ def _prepare_kernel(self):
elif self.kernel == 'multi_rbf':
self.dsqr_ = (X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2
elif self.kernel == 'group_multi_rbf':
self.dsqr_ = np.stack([
((X[:, np.newaxis, g] - X[np.newaxis, :, g]) ** 2).sum(axis=2)
for g in self.feature_groups
], axis=2)
dsqr_parts = []
for g in self.feature_groups:
xsqr = (X[:, g] * X[:, g]).sum(axis=1)
dsqr_tmp = -2.0 * \
X[:, g].dot(X[:, g].T) + xsqr[np.newaxis, :] + \
xsqr[:, np.newaxis]
dsqr_parts.append(dsqr_tmp)
self.dsqr_ = np.stack(dsqr_parts, axis=2)
else:
raise ValueError(f"unknown kernel '{self.kernel}'")

Expand All @@ -44,18 +48,22 @@ def _compute_kernel(self, params, Xother=None):
else:
dsqr = (self.X_[:, np.newaxis, :] -
Xother[np.newaxis, :, :]) ** 2
return np.exp(-np.tensordot(dsqr, params['gamma'], axes=(2, 0)))
elif self.kernel == 'group_multi_rbf':
if Xother is None:
if not hasattr(self, 'dsqr_'):
self._prepare_kernel()
dsqr = self.dsqr_
else:
dsqr = np.stack([
((
self.X_[:, np.newaxis, g] - Xother[np.newaxis, :, g]
) ** 2).sum(axis=2)
for g in self.feature_groups
], axis=2)
dsqr_parts = []
for g in self.feature_groups:
xsqr0 = (self.X_[:, g] * self.X_[:, g]).sum(axis=1)
xsqr1 = (Xother[:, g] * Xother[:, g]).sum(axis=1)
dsqr_tmp = -2.0 * \
self.X_[:, g].dot(
Xother[:, g].T) + xsqr1[np.newaxis, :] + xsqr0[:, np.newaxis]
dsqr_parts.append(dsqr_tmp)
dsqr = np.stack(dsqr_parts, axis=2)
return np.exp(-np.tensordot(dsqr, params['gamma'], axes=(2, 0)))
else:
raise ValueError(f"unknown kernel '{self.kernel}'")
Expand Down
3 changes: 2 additions & 1 deletion optimized_lssvr/_simple_lssvr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ class LSSVR(BaseEstimator, RegressorMixin, LSSVRBase):
LSSVR()
"""

def __init__(self, kernel='rbf', gamma=1.0, lmbda=1e-3):
def __init__(self, kernel='rbf', gamma=1.0, lmbda=1e-3, feature_groups=None):
self.kernel = kernel
self.gamma = gamma
self.lmbda = lmbda
self.feature_groups = feature_groups

def fit(self, X, y):
self.params_ = dict(gamma=self.gamma, lmbda=self.lmbda)
Expand Down

0 comments on commit e597c20

Please sign in to comment.