Skip to content

Commit d6623b3

Browse files
committed
Improve math.js
1 parent 49daf7a commit d6623b3

File tree

8 files changed

+29
-19
lines changed

8 files changed

+29
-19
lines changed

lib/model/label_propagation.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ export default class LabelPropagation {
4646
}
4747

4848
if (this._affinity === 'rbf') {
49-
return distances.copyMap((v, i) => (con.at(...i) > 0 ? Math.exp(-(v ** 2) / this._sigma ** 2) : 0))
49+
return distances.copyMap((v, i) => (con.at(i) > 0 ? Math.exp(-(v ** 2) / this._sigma ** 2) : 0))
5050
} else if (this._affinity === 'knn') {
5151
return con.copyMap(v => (v > 0 ? 1 : 0))
5252
}

lib/model/label_spreading.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ export default class LabelSpreading {
4747
}
4848

4949
if (this._affinity === 'rbf') {
50-
return distances.copyMap((v, i) => (con.at(...i) > 0 ? Math.exp(-(v ** 2) / this._sigma ** 2) : 0))
50+
return distances.copyMap((v, i) => (con.at(i) > 0 ? Math.exp(-(v ** 2) / this._sigma ** 2) : 0))
5151
} else if (this._affinity === 'knn') {
5252
return con.copyMap(v => (v > 0 ? 1 : 0))
5353
}

lib/model/laplacian_eigenmaps.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ export class LaplacianEigenmaps {
5454

5555
let W
5656
if (this._affinity === 'rbf') {
57-
W = distances.copyMap((v, i) => (con.at(...i) > 0 ? Math.exp(-(v ** 2) / this._sigma ** 2) : 0))
57+
W = distances.copyMap((v, i) => (con.at(i) > 0 ? Math.exp(-(v ** 2) / this._sigma ** 2) : 0))
5858
} else if (this._affinity === 'knn') {
5959
W = con.copyMap(v => (v > 0 ? 1 : 0))
6060
}

lib/model/layer/base.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,17 @@ export default class Layer {
5757

5858
/**
5959
* Returns calculated values.
60-
* @param {Matrix | Tensor} x
60+
* @param {...(Matrix | Tensor)} x
6161
*/
62-
calc(x) {
62+
calc(...x) {
6363
throw new NeuralnetworkException('Not impleneted', this)
6464
}
6565

6666
/**
6767
* Returns gradient values.
68-
* @param {Matrix | Tensor} bo
68+
* @param {...(Matrix | Tensor)} bo
6969
*/
70-
grad(bo) {
70+
grad(...bo) {
7171
throw new NeuralnetworkException('Not impleneted', this)
7272
}
7373

lib/model/layer/cond.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ export default class CondLayer extends Layer {
77
const t = x[1]
88
const f = x[2]
99
this._o = new Matrix(this._cond.rows, this._cond.cols)
10-
this._o.map((v, i) => (this._cond.at(...i) ? t.at(...i) : f.at(...i)))
10+
this._o.map((v, i) => (this._cond.at(i) ? t.at(i) : f.at(i)))
1111
return this._o
1212
}
1313

1414
grad(bo) {
1515
const bi = [null, bo.copy(), bo.copy()]
16-
this._cond.forEach((v, i) => (v ? bi[2].set(...i, 0) : bi[1].set(...i, 0)))
16+
this._cond.forEach((v, i) => (v ? bi[2].set(i, 0) : bi[1].set(i, 0)))
1717
return bi
1818
}
1919
}

lib/model/layer/huber.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ export default class HuberLayer extends LossLayer {
1515
err.cols,
1616
err.value.map(v => v < 1.0)
1717
)
18-
err.map((v, i) => (this._cond.at(...i) ? 0.5 * v * v : v - 0.5))
18+
err.map((v, i) => (this._cond.at(i) ? 0.5 * v * v : v - 0.5))
1919
return new Matrix(1, 1, err.sum())
2020
}
2121

2222
grad() {
2323
this._bi = this._cond.copy()
2424
this._bi.map((c, i) =>
25-
c ? this._i.at(...i) - this._t.at(...i) : Math.sign(this._i.at(...i) - this._t.at(...i))
25+
c ? this._i.at(i) - this._t.at(i) : Math.sign(this._i.at(i) - this._t.at(i))
2626
)
2727
return this._bi
2828
}

lib/model/polynomial_histogram.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ export default class PolynomialHistogram {
7777
})
7878
this._a[0] = b.copy()
7979
this._a[0].map((v, i) => {
80-
const a = (4 + 5 * this._d) / 4 - (15 / this._h ** 2) * s2.at(...i).trace()
80+
const a = (4 + 5 * this._d) / 4 - (15 / this._h ** 2) * s2.at(i).trace()
8181
return (a * v) / x.length / this._h ** this._d
8282
})
8383
this._a[1] = b.copy()
@@ -101,7 +101,7 @@ export default class PolynomialHistogram {
101101
}
102102
}
103103
}
104-
v.mult(b.at(...i) / x.length / this._h ** (this._d + 2))
104+
v.mult(b.at(i) / x.length / this._h ** (this._d + 2))
105105
return v
106106
})
107107
}
@@ -130,7 +130,7 @@ export default class PolynomialHistogram {
130130
p.push(0)
131131
continue
132132
}
133-
const a = this._a.map(v => v.at(...idx))
133+
const a = this._a.map(v => v.at(idx))
134134
const xi = Matrix.fromArray(x[i])
135135
const m = Matrix.fromArray(this._ranges.map((r, k) => (r[idx[k] + 1] + r[idx[k]]) / 2))
136136
xi.sub(m)

lib/util/math.js

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,9 @@ export class Tensor {
528528
* @returns {number | Tensor}
529529
*/
530530
at(...i) {
531+
if (Array.isArray(i[0])) {
532+
i = i[0]
533+
}
531534
if (i.length === this.dimension) {
532535
return this._value[this._to_position(...i)]
533536
}
@@ -986,23 +989,30 @@ export class Matrix {
986989

987990
/**
988991
* Returns a value at the position.
989-
* @param {number} r
990-
* @param {number} c
992+
* @param {number | [number, number]} r
993+
* @param {?number} c
991994
* @returns {number}
992995
*/
993996
at(r, c) {
997+
if (Array.isArray(r)) {
998+
;[r, c] = r
999+
}
9941000
if (r < 0 || this.rows <= r || c < 0 || this.cols <= c) throw new MatrixException('Index out of bounds.')
9951001
return this._value[r * this.cols + c]
9961002
}
9971003

9981004
/**
9991005
* Set a value at the position.
1000-
* @param {number} r
1001-
* @param {number} c
1002-
* @param {number | Matrix} value
1006+
* @param {number | [number, number]} r If this value is an array, the next argument should be the value to be set
1007+
* @param {number | Matrix} c
1008+
* @param {?(number | Matrix)} value
10031009
* @returns {?number} Old value
10041010
*/
10051011
set(r, c, value) {
1012+
if (Array.isArray(r)) {
1013+
value = c
1014+
;[r, c] = r
1015+
}
10061016
if (value instanceof Matrix) {
10071017
if (r < 0 || this.rows <= r + value.rows - 1 || c < 0 || this.cols <= c + value.cols - 1)
10081018
throw new MatrixException('Index out of bounds.')

0 commit comments

Comments
 (0)