[go: nahoru, domu]

Skip to content

Commit

Permalink
Add target clipping and update init point for platt transform (#272)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiong-zhang committed Dec 18, 2023
1 parent e4c824d commit 458d69c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 23 deletions.
24 changes: 18 additions & 6 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2058,25 +2058,37 @@ def link_calibrator_methods(self):
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
)

def fit_platt_transform(self, logits, tgt_prob):
def fit_platt_transform(self, logits, targets, clip_tgt_prob=True):
"""Python to C/C++ interface for platt transfrom fit.
Ref: https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf
Args:
logits (ndarray): 1-d array of logits with length N.
tgt_prob (ndarray): 1-d array of target probability scores within [0, 1] with length N.
targets (ndarray): 1-d array of target probability scores within [0, 1] with length N.
clip_tgt_prob (bool): whether to clip the target probability to
[1/(prior0 + 2), 1 - 1/(prior1 + 2)]
where prior1 = sum(targets), prior0 = N - prior1
Returns:
A, B: coefficients for Platt's scale.
"""
assert isinstance(logits, np.ndarray)
assert isinstance(tgt_prob, np.ndarray)
assert len(logits) == len(tgt_prob)
assert logits.dtype == tgt_prob.dtype
assert isinstance(targets, np.ndarray)
assert len(logits) == len(targets)
assert logits.dtype == targets.dtype

if tgt_prob.min() < 0 or tgt_prob.max() > 1.0:
if targets.min() < 0 or targets.max() > 1.0:
raise ValueError("Target probability out of bound!")

min_prob, max_prob = 0.0, 1.0
if clip_tgt_prob:
prior1 = np.sum(targets)
prior0 = len(targets) - prior1
min_prob = 1.0 / (prior0 + 2.0)
max_prob = (prior1 + 1.0) / (prior1 + 2.0)

tgt_prob = np.clip(targets, min_prob, max_prob)

AB = np.array([0, 0], dtype=np.float64)

if tgt_prob.dtype == np.float32:
Expand Down
37 changes: 22 additions & 15 deletions pecos/core/utils/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,23 +280,29 @@ namespace pecos {

template <typename value_type>
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
// define the return code
enum {
SUCCESS=0,
LINE_SEARCH_FAIL=1,
MAX_ITER_REACHED=2,
};
// define the return code
enum {
SUCCESS=0,
LINE_SEARCH_FAIL=1,
MAX_ITER_REACHED=2,
};

// hyper parameters
int max_iter = 100; // Maximal number of iterations
double min_step = 1e-10; // Minimal step taken in line search
double sigma = 1e-12; // For numerically strict PD of Hessian
double eps = 1e-6;
double eps = 1e-5;

// calculate prior of B
double prior1 = 0;
for (size_t i = 0; i < num_samples; i++) {
prior1 += tgt_probs[i];
}
double prior0 = double(num_samples) - prior1;

int iter;

// Initial Point and Initial Fun Value
A = 0.0; B = 1.0;
A = 0.0; B = log((prior0 + 1.0) / (prior1 + 1.0));
double fval = 0.0;

for (size_t i = 0; i < num_samples; i++) {
Expand All @@ -307,17 +313,18 @@ namespace pecos {
fval += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB));
}
}
int iter;
for (iter = 0; iter < max_iter; iter++) {
// Update Gradient and Hessian (use H' = H + sigma I)
double h11 = sigma;
double h22 = sigma; // numerically ensures strict PD
double h21 = 0.0;
double g1 = 0.0;
double g2 = 0.0;
double g1 = A * sigma;
double g2 = B * sigma;

for (size_t i = 0; i < num_samples; i++) {
double fApB = logits[i] * A + B;
double p = 0, q = 0;
double p = 0, q = 0;
if (fApB >= 0) {
p = exp(-fApB) / (1.0 + exp(-fApB));
q = 1.0 / (1.0 + exp(-fApB));
Expand Down Expand Up @@ -376,15 +383,15 @@ namespace pecos {

if (stepsize < min_step) {
printf("WARNING: fit_platt_transform: Line search fails\n");
return LINE_SEARCH_FAIL;
return LINE_SEARCH_FAIL;
}
}

if (iter >= max_iter) {
printf("WARNING: fit_platt_transform: Reaching maximal iterations\n");
return MAX_ITER_REACHED;
return MAX_ITER_REACHED;
}
return SUCCESS;
return SUCCESS;
}
} // namespace pecos
#endif
4 changes: 2 additions & 2 deletions test/pecos/core/test_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ def test_platt_scale():

orig = np.arange(-15, 15, 1, dtype=np.float32)
tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float32)
At, Bt = clib.fit_platt_transform(orig, tgt)
At, Bt = clib.fit_platt_transform(orig, tgt, clip_tgt_prob=False)
assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}"
assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"

orig = np.arange(-15, 15, 1, dtype=np.float64)
tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float64)
At, Bt = clib.fit_platt_transform(orig, tgt)
At, Bt = clib.fit_platt_transform(orig, tgt, clip_tgt_prob=False)
assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}"
assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"

0 comments on commit 458d69c

Please sign in to comment.