Skip to content

Commit 2d7f638

Browse files
committed
Correct SparseInteger to avoid int multiplication overflow
1 parent 0ef385f commit 2d7f638

File tree

1 file changed

+92
-56
lines changed

1 file changed

+92
-56
lines changed

src/main/java/info/debatty/java/utils/SparseIntegerVector.java

Lines changed: 92 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,39 @@
3030
import java.util.TreeSet;
3131

3232
/**
33-
* Sparse vector of int, implemented using two arrays
33+
* Sparse vector of int, implemented using two arrays.
3434
* @author Thibault Debatty
3535
*/
3636
public class SparseIntegerVector implements Serializable {
37-
38-
protected int[] keys;
39-
protected int[] values;
40-
protected int size = 0;
4137

42-
public SparseIntegerVector(int size) {
38+
private int[] keys;
39+
private int[] values;
40+
private int size = 0;
41+
42+
private static final int DEFAULT_SIZE = 20;
43+
44+
/**
45+
* Sparse vector of int, implemented using two arrays.
46+
* @param size number of non-zero elements in the vector
47+
*/
48+
public SparseIntegerVector(final int size) {
4349
keys = new int[size];
4450
values = new int[size];
4551
}
46-
52+
53+
/**
54+
* Sparse vector of int, implemented using two arrays.
55+
* Default size is 20.
56+
*/
4757
public SparseIntegerVector() {
48-
this(20);
58+
this(DEFAULT_SIZE);
4959
}
50-
51-
public SparseIntegerVector(HashMap<Integer, Integer> hashmap) {
60+
61+
/**
62+
* Sparse vector of int, implemented using two arrays.
63+
* @param hashmap
64+
*/
65+
public SparseIntegerVector(final HashMap<Integer, Integer> hashmap) {
5266
this(hashmap.size());
5367
SortedSet<Integer> sorted_keys = new TreeSet<Integer>(hashmap.keySet());
5468
for (int key : sorted_keys) {
@@ -59,17 +73,17 @@ public SparseIntegerVector(HashMap<Integer, Integer> hashmap) {
5973
}
6074

6175
/**
62-
*
63-
* @param array
76+
* Sparse vector of int, implemented using two arrays.
77+
* @param array
6478
*/
65-
public SparseIntegerVector(int[] array) {
66-
79+
public SparseIntegerVector(final int[] array) {
80+
6781
for (int i = 0; i < array.length; i++) {
6882
if (array[i] != 0) {
6983
size++;
7084
}
7185
}
72-
86+
7387
keys = new int[size];
7488
values = new int[size];
7589
int j = 0;
@@ -81,8 +95,14 @@ public SparseIntegerVector(int[] array) {
8195
}
8296
}
8397
}
84-
85-
public double cosineSimilarity(SparseIntegerVector other) {
98+
99+
/**
100+
* Compute and return the cosine similarity (cosine of angle between both
101+
* vectors).
102+
* @param other
103+
* @return
104+
*/
105+
public final double cosineSimilarity(final SparseIntegerVector other) {
86106
double den = this.norm() * other.norm();
87107
double agg = 0;
88108
int i = 0;
@@ -92,7 +112,7 @@ public double cosineSimilarity(SparseIntegerVector other) {
92112
int k2 = other.keys[j];
93113

94114
if (k1 == k2) {
95-
agg += this.values[i] * other.values[j] / den;
115+
agg += 1.0 * this.values[i] * other.values[j] / den;
96116
i++;
97117
j++;
98118

@@ -104,13 +124,13 @@ public double cosineSimilarity(SparseIntegerVector other) {
104124
}
105125
return agg;
106126
}
107-
127+
108128
/**
109-
*
129+
* Compute and return the dot product.
110130
* @param other
111-
* @return
131+
* @return
112132
*/
113-
public double dotProduct(SparseIntegerVector other) {
133+
public final double dotProduct(final SparseIntegerVector other) {
114134
double agg = 0;
115135
int i = 0;
116136
int j = 0;
@@ -119,7 +139,7 @@ public double dotProduct(SparseIntegerVector other) {
119139
int k2 = other.keys[j];
120140

121141
if (k1 == k2) {
122-
agg += this.values[i] * other.values[j];
142+
agg += 1.0 * this.values[i] * other.values[j];
123143
i++;
124144
j++;
125145

@@ -131,55 +151,61 @@ public double dotProduct(SparseIntegerVector other) {
131151
}
132152
return agg;
133153
}
134-
135-
public double dotProduct(double[] other) {
154+
155+
/**
156+
* Compute and return the dot product.
157+
* @param other
158+
* @return
159+
*/
160+
public final double dotProduct(final double[] other) {
136161
double agg = 0;
137162
for (int i = 0; i < keys.length; i++) {
138-
agg += other[keys[i]] * values[i];
163+
agg += 1.0 * other[keys[i]] * values[i];
139164
}
140165
return agg;
141166
}
142-
167+
143168
/**
144-
* Compute and return the L2 norm of the vector
145-
* @return
169+
* Compute and return the L2 norm of the vector.
170+
* @return
146171
*/
147-
public double norm() {
172+
public final double norm() {
148173
double agg = 0;
149174
for (int i = 0; i < values.length; i++) {
150-
agg += values[i] * values[i];
175+
agg += 1.0 * values[i] * values[i];
151176
}
152177
return Math.sqrt(agg);
153178
}
154-
179+
155180
/**
156181
* Computes and returns the Jaccard index with the other SparseIntegerVector.
157182
* |A inter B| / |A union B|
158183
* It is actually computed as |A inter B| / (|A| +|B| - | A inter B|)
159184
* using a single loop over A and B
160185
* @param other
161-
* @return
186+
* @return
162187
*/
163-
public double jaccard(SparseIntegerVector other) {
188+
public final double jaccard(final SparseIntegerVector other) {
164189
int intersection = this.intersection(other);
165190
return (double) intersection / (this.size + other.size - intersection);
166191
}
167-
192+
168193
/**
169-
*
194+
* Compute the size of the union of these two vectors.
170195
* @param other
171-
* @return
196+
* @return
172197
*/
173-
public int union(SparseIntegerVector other) {
198+
public final int union(final SparseIntegerVector other) {
174199
return this.size + other.size - this.intersection(other);
175200
}
176-
201+
177202
/**
178-
*
203+
* Compute the number of values that are present in both vectors (used to
204+
* compute jaccard index).
179205
* @param other
180-
* @return
206+
* @return
181207
*/
182-
public int intersection(SparseIntegerVector other) {
208+
public final int intersection(final SparseIntegerVector other) {
183209
int agg = 0;
184210
int i = 0;
185211
int j = 0;
@@ -194,35 +220,35 @@ public int intersection(SparseIntegerVector other) {
194220

195221
} else if (k1 < k2) {
196222
i++;
197-
223+
198224
} else {
199225
j++;
200226
}
201227
}
202228
return agg;
203229
}
204-
230+
205231
@Override
206-
public String toString() {
232+
public final String toString() {
207233
String r = "";
208234
for (int i = 0; i < size; i++) {
209235
r += keys[i] + ":" + values[i] + " ";
210236
}
211-
237+
212238
return r;
213239
}
214240

215241
/**
216242
* Compute and return the qgram distance between this vector and the other vector.
217243
* Sum(|a_i - b_i|)
218244
* @param other
219-
* @return
245+
* @return
220246
*/
221-
public double qgram(SparseIntegerVector other) {
247+
public final double qgram(final SparseIntegerVector other) {
222248
double agg = 0;
223249
int i = 0, j = 0;
224250
int k1, k2;
225-
251+
226252
while (i < this.keys.length && j < other.keys.length) {
227253
k1 = this.keys[i];
228254
k2 = other.keys[j];
@@ -235,19 +261,19 @@ public double qgram(SparseIntegerVector other) {
235261
} else if (k1 < k2) {
236262
agg += Math.abs(this.values[i]);
237263
i++;
238-
264+
239265
} else {
240266
agg += Math.abs(other.values[j]);
241267
j++;
242268
}
243269
}
244-
270+
245271
// Maybe one of the two vectors was not completely walked...
246272
while (i < this.keys.length) {
247273
agg += Math.abs(this.values[i]);
248274
i++;
249275
}
250-
276+
251277
while (j < other.keys.length) {
252278
agg += Math.abs(other.values[j]);
253279
j++;
@@ -257,17 +283,27 @@ public double qgram(SparseIntegerVector other) {
257283

258284
/**
259285
* Return the number of (non-zero) elements in this vector.
260-
* @return
286+
* @return
261287
*/
262-
public int size() {
288+
public final int size() {
263289
return this.size;
264290
}
265291

266-
public int getKey(int i) {
292+
/**
293+
* Get the key at position i.
294+
* @param i
295+
* @return
296+
*/
297+
public final int getKey(final int i) {
267298
return this.keys[i];
268299
}
269300

270-
public int getValue(int i) {
301+
/**
302+
* Get the value at position i.
303+
* @param i
304+
* @return
305+
*/
306+
public final int getValue(final int i) {
271307
return this.values[i];
272308
}
273309
}

0 commit comments

Comments
 (0)