Skip to content

Commit cc2ca02

Browse files
authored
Fix regularized incomplete beta function (#440)
1 parent e13f6f6 commit cc2ca02

File tree

2 files changed

+57
-32
lines changed

2 files changed

+57
-32
lines changed

lib/model/tightest_perceptron.js

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,9 @@ const logGamma = z => {
3030
return x
3131
}
3232

33-
const beta = (p, q) => {
33+
const logBeta = (p, q) => {
3434
// https://www2.math.kyushu-u.ac.jp/~snii/AdvancedCalculus/7-1.pdf
35-
// return gamma(p) * gamma(q) / gamma(p + q)
36-
return Math.exp(logGamma(p) + logGamma(q) - logGamma(p + q))
35+
return logGamma(p) + logGamma(q) - logGamma(p + q)
3736
}
3837

3938
const hypergeometric = (a, b, c, z) => {
@@ -52,34 +51,37 @@ const hypergeometric = (a, b, c, z) => {
5251
return f
5352
}
5453

55-
const incompleteBeta = (z, a, b) => {
54+
const logIncompleteBeta = (z, a, b) => {
5655
// https://ja.wikipedia.org/wiki/%E4%B8%8D%E5%AE%8C%E5%85%A8%E3%83%99%E3%83%BC%E3%82%BF%E9%96%A2%E6%95%B0
5756
// https://math-functions-1.watson.jp/sub1_spec_050.html#section030
5857
// https://qiita.com/moriokumura/items/e35025d4ade312b0a017
5958
if (b === 1) {
60-
return z ** a / a
59+
return Math.log(z) * a - Math.log(a)
6160
} else if (a === 1) {
62-
return (1 - (1 - z) ** b) / b
61+
return Math.log(1 - (1 - z) ** b) - Math.log(b)
6362
} else if (a === 0.5 && b === 0) {
64-
return 2 * Math.atanh(Math.sqrt(z))
63+
return Math.log(2 * Math.atanh(Math.sqrt(z)))
6564
} else if (a === 0.5 && b === 0.5) {
66-
return 2 * Math.asin(Math.sqrt(z))
65+
return Math.log(2 * Math.asin(Math.sqrt(z)))
6766
} else if (Number.isInteger(b)) {
6867
const za = z ** a
6968
let ib = za / a
7069
for (let i = 1; i < b; i++) {
7170
ib = (i * ib + za * (1 - z) ** i) / (a + i)
7271
}
73-
return ib
72+
return Math.log(ib)
7473
} else if (Number.isInteger(a)) {
7574
const zb = (1 - z) ** b
7675
let ib = (1 - zb) / b
7776
for (let i = 1; i < a; i++) {
7877
ib = (i * ib - z ** i * zb) / (i + b)
7978
}
80-
return ib
79+
return Math.log(ib)
8180
}
82-
return (z ** a / a) * hypergeometric(a, 1 - b, a + 1, z)
81+
if (a < b) {
82+
return Math.log(Math.exp(logBeta(a, b)) - Math.exp(logIncompleteBeta(1 - z, b, a)))
83+
}
84+
return Math.log(z) * a - Math.log(a) + b * Math.log(1 - z) + Math.log(hypergeometric(a + b, 1, a + 1, z))
8385
}
8486

8587
const regularizedIncompleteBeta = (z, a, b) => {
@@ -93,7 +95,10 @@ const regularizedIncompleteBeta = (z, a, b) => {
9395
} else if (a === 1) {
9496
return 1 - (1 - z) ** b
9597
}
96-
return incompleteBeta(z, a, b) / beta(a, b)
98+
if (a < b) {
99+
return 1 - regularizedIncompleteBeta(1 - z, b, a)
100+
}
101+
return Math.exp(logIncompleteBeta(z, a, b) - logBeta(a, b))
97102
}
98103

99104
/**

tests/lib/model/tightest_perceptron.test.js

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import { accuracy } from '../../../lib/evaluate/classification.js'
99
describe('classification', () => {
1010
describe.each([undefined, 'zero_one', 'hinge'])('accuracyLoss %s', accuracyLoss => {
1111
test.each([undefined, 'gaussian'])('kernel %s', kernel => {
12-
const model = new TightestPerceptron(10, kernel)
12+
const model = new TightestPerceptron(10, kernel, accuracyLoss)
1313
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, 5, 0.2)).toArray()
1414
x[50] = [0.1, 0.1]
1515
const t = []
@@ -25,26 +25,11 @@ describe('classification', () => {
2525
expect(acc).toBeGreaterThan(0.9)
2626
})
2727

28-
test.each(['polynomial'])('kernel %s', kernel => {
29-
const model = new TightestPerceptron(10, kernel)
30-
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, 5, 0.2)).toArray()
31-
x[50] = [0.1, 0.1]
32-
const t = []
33-
for (let i = 0; i < x.length; i++) {
34-
t[i] = Math.floor(i / 50) * 2 - 1
35-
}
36-
model.init(x, t)
37-
for (let i = 0; i < 10; i++) {
38-
model.fit()
39-
}
40-
const y = model.predict(x)
41-
const acc = accuracy(y, t)
42-
expect(acc).toBeGreaterThan(0.7)
43-
})
44-
4528
test('custom kernel', () => {
46-
const model = new TightestPerceptron(10, (a, b) =>
47-
Math.exp(-(a.reduce((s, v, i) => s + (v - b[i]) ** 2, 0) ** 2) / 0.01)
29+
const model = new TightestPerceptron(
30+
10,
31+
(a, b) => Math.exp(-(a.reduce((s, v, i) => s + (v - b[i]) ** 2, 0) ** 2) / 0.01),
32+
accuracyLoss
4833
)
4934
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, 5, 0.2)).toArray()
5035
const t = []
@@ -60,4 +45,39 @@ describe('classification', () => {
6045
expect(acc).toBeGreaterThan(0.9)
6146
})
6247
})
48+
49+
test.each(['polynomial'])('kernel %s', kernel => {
50+
const model = new TightestPerceptron(10, kernel)
51+
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, 5, 0.2)).toArray()
52+
x[50] = [0.1, 0.1]
53+
const t = []
54+
for (let i = 0; i < x.length; i++) {
55+
t[i] = Math.floor(i / 50) * 2 - 1
56+
}
57+
model.init(x, t)
58+
for (let i = 0; i < 10; i++) {
59+
model.fit()
60+
}
61+
const y = model.predict(x)
62+
const acc = accuracy(y, t)
63+
expect(acc).toBeGreaterThan(0.7)
64+
})
65+
66+
test('regularizedIncompleteBeta', () => {
67+
const model = new TightestPerceptron(10)
68+
model._ap = 1.1
69+
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, 5, 0.2)).toArray()
70+
x[50] = [0.1, 0.1]
71+
const t = []
72+
for (let i = 0; i < x.length; i++) {
73+
t[i] = Math.floor(i / 50) * 2 - 1
74+
}
75+
model.init(x, t)
76+
for (let i = 0; i < 10; i++) {
77+
model.fit()
78+
}
79+
const y = model.predict(x)
80+
const acc = accuracy(y, t)
81+
expect(acc).toBeGreaterThan(0.7)
82+
})
6383
})

0 commit comments

Comments
 (0)