Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 7888407

Browse files
authored
handle arrays of values - this is what we get if we store properties (#707)
* handle arrays of values - this is what we get if we store properties * handle embeddings for jaccard
1 parent 0fede3f commit 7888407

File tree

4 files changed

+111
-5
lines changed

4 files changed

+111
-5
lines changed

algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,13 @@ CategoricalInput[] prepareCategories(List<Map<String, Object>> data, long degree
193193
CategoricalInput[] ids = new CategoricalInput[data.size()];
194194
int idx = 0;
195195
for (Map<String, Object> row : data) {
196-
List<Long> targetIds = (List<Long>) row.get("categories");
196+
List<Number> targetIds = extractValues(row.get("categories"));
197197
int size = targetIds.size();
198198
if ( size > degreeCutoff) {
199199
long[] targets = new long[size];
200200
int i=0;
201-
for (Long id : targetIds) {
202-
targets[i++]=id;
201+
for (Number id : targetIds) {
202+
targets[i++]=id.longValue();
203203
}
204204
Arrays.sort(targets);
205205
ids[idx++] = new CategoricalInput((Long) row.get("item"), targets);
@@ -214,7 +214,9 @@ WeightedInput[] prepareWeights(List<Map<String, Object>> data, long degreeCutoff
214214
WeightedInput[] inputs = new WeightedInput[data.size()];
215215
int idx = 0;
216216
for (Map<String, Object> row : data) {
217-
List<Number> weightList = (List<Number>) row.get("weights");
217+
218+
List<Number> weightList = extractValues(row.get("weights"));
219+
218220
int size = weightList.size();
219221
if ( size > degreeCutoff) {
220222
double[] weights = new double[size];
@@ -230,6 +232,28 @@ WeightedInput[] prepareWeights(List<Map<String, Object>> data, long degreeCutoff
230232
return inputs;
231233
}
232234

235+
private List<Number> extractValues(Object rawValues) {
236+
if(rawValues == null) {
237+
return Collections.emptyList();
238+
}
239+
240+
List<Number> valueList = new ArrayList<>();
241+
if (rawValues instanceof long[]) {
242+
long[] values = (long[]) rawValues;
243+
for (long value : values) {
244+
valueList.add(value);
245+
}
246+
} else if (rawValues instanceof double[]) {
247+
double[] values = (double[]) rawValues;
248+
for (double value : values) {
249+
valueList.add(value);
250+
}
251+
} else {
252+
valueList = (List<Number>) rawValues;
253+
}
254+
return valueList;
255+
}
256+
233257
protected int getTopK(ProcedureConfiguration configuration) {
234258
return configuration.getInt("topK", 0);
235259
}

tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ public class CosineTest {
5353
"yield p25, p50, p75, p90, p95, p99, p999, p100, nodes, similarityPairs " +
5454
"RETURN *";
5555

56+
public static final String STORE_EMBEDDING_STATEMENT = "MATCH (i:Item) WITH i ORDER BY id(i) MATCH (p:Person) OPTIONAL MATCH (p)-[r:LIKES]->(i)\n" +
57+
"WITH p, collect(coalesce(r.stars,0)) as userData\n" +
58+
"SET p.embedding = userData";
59+
60+
public static final String EMBEDDING_STATEMENT = "MATCH (p:Person)\n" +
61+
"WITH {item:id(p), weights: p.embedding} as userData\n" +
62+
"WITH collect(userData) as data\n" +
63+
64+
"CALL algo.similarity.cosine(data, $config) " +
65+
"yield p25, p50, p75, p90, p95, p99, p999, p100, nodes, similarityPairs " +
66+
"RETURN *";
67+
5668
@BeforeClass
5769
public static void beforeClass() throws KernelException {
5870
db = TestDatabaseCreator.createTestDatabase();
@@ -267,6 +279,22 @@ public void simpleCosineTest() {
267279
assertEquals((double) row.get("p100"), 0.91, 0.01);
268280
}
269281

282+
@Test
283+
public void simpleCosineFromEmbeddingTest() {
284+
db.execute(STORE_EMBEDDING_STATEMENT);
285+
286+
Map<String, Object> params = map("config", map());
287+
288+
Map<String, Object> row = db.execute(EMBEDDING_STATEMENT,params).next();
289+
assertEquals((double) row.get("p25"), 0.0, 0.01);
290+
assertEquals((double) row.get("p50"), 0, 0.01);
291+
assertEquals((double) row.get("p75"), 0.40, 0.01);
292+
assertEquals((double) row.get("p90"), 0.40, 0.01);
293+
assertEquals((double) row.get("p95"), 0.91, 0.01);
294+
assertEquals((double) row.get("p99"), 0.91, 0.01);
295+
assertEquals((double) row.get("p100"), 0.91, 0.01);
296+
}
297+
270298
@Test
271299
public void simpleCosineWriteTest() {
272300
Map<String, Object> params = map("config", map( "write",true, "similarityCutoff", 0.1));

tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import static org.junit.Assert.*;
3535
import static org.neo4j.helpers.collection.MapUtil.map;
3636

37-
public class EuclideanTest {
37+
public class EuclideanTest {
3838

3939
private static GraphDatabaseAPI db;
4040
private Transaction tx;
@@ -53,6 +53,18 @@ public class EuclideanTest {
5353
"yield p25, p50, p75, p90, p95, p99, p999, p100, nodes, similarityPairs " +
5454
"RETURN *";
5555

56+
public static final String STORE_EMBEDDING_STATEMENT = "MATCH (i:Item) WITH i ORDER BY id(i) MATCH (p:Person) OPTIONAL MATCH (p)-[r:LIKES]->(i)\n" +
57+
"WITH p, collect(coalesce(r.stars,0)) as userData\n" +
58+
"SET p.embedding = userData";
59+
60+
public static final String EMBEDDING_STATEMENT = "MATCH (p:Person)\n" +
61+
"WITH {item:id(p), weights: p.embedding} as userData\n" +
62+
"WITH collect(userData) as data\n" +
63+
64+
"CALL algo.similarity.euclidean(data, $config) " +
65+
"yield p25, p50, p75, p90, p95, p99, p999, p100, nodes, similarityPairs " +
66+
"RETURN *";
67+
5668
@BeforeClass
5769
public static void beforeClass() throws KernelException {
5870
db = TestDatabaseCreator.createTestDatabase();
@@ -265,6 +277,21 @@ public void simpleEuclideanTest() {
265277
assertEquals((double) row.get("p100"), 5.48, 0.01);
266278
}
267279

280+
@Test
281+
public void simpleEuclideanFromEmbeddingTest() {
282+
db.execute(STORE_EMBEDDING_STATEMENT);
283+
284+
Map<String, Object> params = map("config", map());
285+
286+
Map<String, Object> row = db.execute(EMBEDDING_STATEMENT,params).next();
287+
assertEquals((double) row.get("p25"), 3.16, 0.01);
288+
assertEquals((double) row.get("p50"), 4.00, 0.01);
289+
assertEquals((double) row.get("p75"), 5.10, 0.01);
290+
assertEquals((double) row.get("p95"), 5.48, 0.01);
291+
assertEquals((double) row.get("p99"), 5.48, 0.01);
292+
assertEquals((double) row.get("p100"), 5.48, 0.01);
293+
}
294+
268295
@Test
269296
public void simpleEuclideanWriteTest() {
270297
Map<String, Object> params = map("config", map( "write",true, "similarityCutoff", 4.0));

tests/src/test/java/org/neo4j/graphalgo/algo/similarity/JaccardTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ public class JaccardTest {
5252
"yield p25, p50, p75, p90, p95, p99, p999, p100, nodes, similarityPairs " +
5353
"RETURN *";
5454

55+
public static final String STORE_EMBEDDING_STATEMENT = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" +
56+
"WITH p, collect(distinct id(i)) as userData\n" +
57+
"SET p.embedding = userData";
58+
59+
public static final String EMBEDDING_STATEMENT = "MATCH (p:Person) \n" +
60+
"WITH {item:id(p), categories: p.embedding} as userData\n" +
61+
"WITH collect(userData) as data\n" +
62+
"CALL algo.similarity.jaccard(data, $config) " +
63+
"yield p25, p50, p75, p90, p95, p99, p999, p100, nodes, similarityPairs " +
64+
"RETURN *";
65+
5566
@BeforeClass
5667
public static void beforeClass() throws KernelException {
5768
db = TestDatabaseCreator.createTestDatabase();
@@ -240,6 +251,22 @@ public void simpleJaccardTest() {
240251
assertEquals((double) row.get("p100"), 0.66, 0.01);
241252
}
242253

254+
@Test
255+
public void simpleJaccardFromEmbeddingTest() {
256+
db.execute(STORE_EMBEDDING_STATEMENT);
257+
258+
Map<String, Object> params = map("config", map("similarityCutoff", 0.0));
259+
260+
Map<String, Object> row = db.execute(EMBEDDING_STATEMENT,params).next();
261+
assertEquals((double) row.get("p25"), 0.33, 0.01);
262+
assertEquals((double) row.get("p50"), 0.33, 0.01);
263+
assertEquals((double) row.get("p75"), 0.66, 0.01);
264+
assertEquals((double) row.get("p95"), 0.66, 0.01);
265+
assertEquals((double) row.get("p99"), 0.66, 0.01);
266+
assertEquals((double) row.get("p100"), 0.66, 0.01);
267+
}
268+
269+
243270
@Test
244271
public void simpleJaccardWriteTest() {
245272
Map<String, Object> params = map("config", map( "write",true, "similarityCutoff", 0.1));

0 commit comments

Comments
 (0)