Skip to content

Commit f103287

Browse files
committed
Improve KNN
1 parent 66a3cfd commit f103287

File tree

3 files changed

+101
-53
lines changed

3 files changed

+101
-53
lines changed

js/view/knearestneighbor.js

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
import { KNN, KNNRegression, KNNAnomaly, KNNDensityEstimation, SemiSupervisedKNN } from '../../lib/model/knearestneighbor.js'
1+
import {
2+
KNN,
3+
KNNRegression,
4+
KNNAnomaly,
5+
KNNDensityEstimation,
6+
SemiSupervisedKNN,
7+
} from '../../lib/model/knearestneighbor.js'
28

39
var dispKNN = function (elm, platform) {
410
const mode = platform.task
@@ -17,7 +23,7 @@ var dispKNN = function (elm, platform) {
1723
ty.map(v => v[0])
1824
)
1925
platform.predict((px, pred_cb) => {
20-
const pred = px.map(p => model.predict(p))
26+
const pred = model.predict(px)
2127
pred_cb(pred)
2228
}, 4)
2329
})
@@ -32,7 +38,7 @@ var dispKNN = function (elm, platform) {
3238

3339
platform.predict(
3440
(px, pred_cb) => {
35-
let p = px.map(p => model.predict(p))
41+
let p = model.predict(px)
3642

3743
pred_cb(p)
3844
},
@@ -45,7 +51,7 @@ var dispKNN = function (elm, platform) {
4551
model.fit(tx)
4652

4753
const threshold = +elm.select('[name=threshold]').property('value')
48-
const outliers = tx.map(p => model.predict(p) > threshold)
54+
const outliers = model.predict(tx).map(p => p > threshold)
4955
cb(outliers)
5056
})
5157
} else if (mode === 'DE') {
@@ -54,7 +60,7 @@ var dispKNN = function (elm, platform) {
5460
model.fit(tx)
5561

5662
platform.predict((px, cb) => {
57-
const pred = px.map(p => model.predict(p))
63+
const pred = model.predict(px)
5864
const min = Math.min(...pred)
5965
const max = Math.max(...pred)
6066
cb(pred.map(v => specialCategory.density((v - min) / (max - min))))
@@ -68,7 +74,7 @@ var dispKNN = function (elm, platform) {
6874
model.fit(data)
6975

7076
const threshold = +elm.select('[name=threshold]').property('value')
71-
const pred = data.map(p => model.predict(p))
77+
const pred = model.predict(data)
7278
for (let i = 0; i < d / 2; i++) {
7379
pred.unshift(0)
7480
}
@@ -92,7 +98,7 @@ var dispKNN = function (elm, platform) {
9298
)
9399

94100
platform.predict((px, pred_cb) => {
95-
let p = px.map(p => model.predict(p))
101+
let p = model.predict(px)
96102

97103
pred_cb(p)
98104
}, 1)

lib/model/knearestneighbor.js

Lines changed: 54 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -91,36 +91,38 @@ export class KNN {
9191
}
9292

9393
/**
94-
* Returns predicted category.
95-
* @param {number[]} data
96-
* @returns {number}
94+
* Returns predicted categories.
95+
* @param {Array<Array<number>>} datas
96+
* @returns {number[]}
9797
*/
98-
predict(data) {
99-
const ps = this._near_points(data)
100-
const clss = {}
101-
ps.forEach(p => {
102-
let cat = p.category
103-
if (!clss[cat]) {
104-
clss[cat] = {
105-
count: 1,
106-
min_d: p.d,
98+
predict(datas) {
99+
return datas.map(data => {
100+
const ps = this._near_points(data)
101+
const clss = {}
102+
ps.forEach(p => {
103+
let cat = p.category
104+
if (!clss[cat]) {
105+
clss[cat] = {
106+
count: 1,
107+
min_d: p.d,
108+
}
109+
} else {
110+
clss[cat].count += 1
111+
clss[cat].min_d = Math.min(clss[cat].min_d, p.d)
112+
}
113+
})
114+
let max_count = 0
115+
let min_dist = -1
116+
let target_cat = -1
117+
for (let k of Object.keys(clss)) {
118+
if (max_count < clss[k].count || (max_count === clss[k].count && clss[k].min_d < min_dist)) {
119+
max_count = clss[k].count
120+
min_dist = clss[k].min_d
121+
target_cat = +k
107122
}
108-
} else {
109-
clss[cat].count += 1
110-
clss[cat].min_d = Math.min(clss[cat].min_d, p.d)
111123
}
124+
return target_cat
112125
})
113-
let max_count = 0
114-
let min_dist = -1
115-
let target_cat = -1
116-
for (let k of Object.keys(clss)) {
117-
if (max_count < clss[k].count || (max_count === clss[k].count && clss[k].min_d < min_dist)) {
118-
max_count = clss[k].count
119-
min_dist = clss[k].min_d
120-
target_cat = +k
121-
}
122-
}
123-
return target_cat
124126
}
125127
}
126128

@@ -137,13 +139,15 @@ export class KNNRegression extends KNN {
137139
}
138140

139141
/**
140-
* Returns predicted value.
141-
* @param {number[]} data
142-
* @returns {number}
142+
* Returns predicted values.
143+
* @param {Array<Array<number>>} datas
144+
* @returns {number[]}
143145
*/
144-
predict(data) {
145-
const ps = this._near_points(data)
146-
return ps.reduce((acc, v) => acc + v.category, 0) / ps.length
146+
predict(datas) {
147+
return datas.map(data => {
148+
const ps = this._near_points(data)
149+
return ps.reduce((acc, v) => acc + v.category, 0) / ps.length
150+
})
147151
}
148152
}
149153

@@ -161,12 +165,14 @@ export class KNNAnomaly extends KNN {
161165

162166
/**
163167
* Returns anomaly degrees.
164-
* @param {number[]} data
165-
* @returns {number}
168+
* @param {Array<Array<number>>} datas
169+
* @returns {number[]}
166170
*/
167-
predict(data) {
168-
const ps = this._near_points(data)
169-
return ps[ps.length - 1].d
171+
predict(datas) {
172+
return datas.map(data => {
173+
const ps = this._near_points(data)
174+
return ps[ps.length - 1].d
175+
})
170176
}
171177
}
172178

@@ -201,16 +207,18 @@ export class KNNDensityEstimation extends KNN {
201207
}
202208

203209
/**
204-
* Returns predicted value.
205-
* @param {number[]} data
206-
* @returns {number}
210+
* Returns predicted values.
211+
* @param {Array<Array<number>>} datas
212+
* @returns {number[]}
207213
*/
208-
predict(data) {
209-
const ps = this._near_points(data)
210-
const r = ps[ps.length - 1].d
211-
const d = data.length
212-
const ilogv = this._logGamma(d / 2 + 1) - (d / 2) * Math.log(Math.PI) - d * Math.log(r)
213-
return (Math.exp(ilogv) * this.k) / this._p.length
214+
predict(datas) {
215+
return datas.map(data => {
216+
const ps = this._near_points(data)
217+
const r = ps[ps.length - 1].d
218+
const d = data.length
219+
const ilogv = this._logGamma(d / 2 + 1) - (d / 2) * Math.log(Math.PI) - d * Math.log(r)
220+
return (Math.exp(ilogv) * this.k) / this._p.length
221+
})
214222
}
215223
}
216224

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import { KNN, KNNRegression } from '../../../lib/model/knearestneighbor.js'
2+
import { Matrix } from '../../../lib/util/math.js'
3+
4+
test.each(['euclid', 'manhattan', 'chebyshev'])('classifier %s', metric => {
5+
const model = new KNN(5, metric)
6+
const x = Matrix.randn(50, 2, 0, 0.2).concat(Matrix.randn(50, 2, 5, 0.2)).toArray()
7+
const t = []
8+
for (let i = 0; i < x.length; i++) {
9+
t[i] = Math.floor(i / 50) * 2 - 1
10+
}
11+
model.fit(x, t)
12+
const y = model.predict(x)
13+
let acc = 0
14+
for (let i = 0; i < t.length; i++) {
15+
if (y[i] === t[i]) {
16+
acc++
17+
}
18+
}
19+
expect(acc / y.length).toBeGreaterThan(0.95)
20+
})
21+
22+
test.each(['euclid', 'manhattan', 'chebyshev'])('regression %s', metric => {
23+
const model = new KNNRegression(1, metric)
24+
const x = Matrix.randn(50, 2, 0, 5).toArray()
25+
const t = []
26+
for (let i = 0; i < x.length; i++) {
27+
t[i] = x[i][0] + x[i][1] + (Math.random() - 0.5) / 10
28+
}
29+
model.fit(x, t)
30+
const y = model.predict(x)
31+
for (let i = 0; i < 4; i++) {
32+
expect(y[i]).toBe(t[i])
33+
}
34+
})

0 commit comments

Comments
 (0)