Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 9a57ce9

Browse files
authored
Linkpred bug fix (#816)
* exclude self from neighbors * running on same nodes should return 0 * adding tests to find common neighbors * more tests around common neighbors finder
1 parent 4a2e9b6 commit 9a57ce9

File tree

6 files changed

+324
-32
lines changed

6 files changed

+324
-32
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package org.neo4j.graphalgo.linkprediction;
2+
3+
import org.neo4j.graphdb.Direction;
4+
import org.neo4j.graphdb.Node;
5+
import org.neo4j.graphdb.Relationship;
6+
import org.neo4j.graphdb.RelationshipType;
7+
import org.neo4j.kernel.internal.GraphDatabaseAPI;
8+
9+
import java.util.Collections;
10+
import java.util.HashSet;
11+
import java.util.Set;
12+
13+
import static org.neo4j.graphdb.Direction.*;
14+
15+
public class CommonNeighborsFinder {
16+
17+
private GraphDatabaseAPI api;
18+
19+
public CommonNeighborsFinder(GraphDatabaseAPI api) {
20+
this.api = api;
21+
}
22+
23+
public Set<Node> findCommonNeighbors(Node node1, Node node2, RelationshipType relationshipType, Direction direction) {
24+
if(node1.equals(node2)) {
25+
return Collections.emptySet();
26+
}
27+
28+
Set<Node> neighbors = findPotentialNeighbors(node1, relationshipType, direction);
29+
neighbors.removeIf(node -> noCommonNeighbors(node, relationshipType, flipDirection(direction), node2));
30+
return neighbors;
31+
}
32+
33+
private Direction flipDirection(Direction direction) {
34+
switch(direction) {
35+
case OUTGOING:
36+
return INCOMING;
37+
case INCOMING:
38+
return OUTGOING;
39+
default:
40+
return BOTH;
41+
}
42+
}
43+
44+
private Set<Node> findPotentialNeighbors(Node node, RelationshipType relationshipType, Direction direction) {
45+
Set<Node> neighbors = new HashSet<>();
46+
47+
for (Relationship rel : loadRelationships(node, relationshipType, direction)) {
48+
Node endNode = rel.getOtherNode(node);
49+
50+
if (!endNode.equals(node)) {
51+
neighbors.add(endNode);
52+
}
53+
}
54+
return neighbors;
55+
}
56+
57+
private boolean noCommonNeighbors(Node node, RelationshipType relationshipType, Direction direction, Node node2) {
58+
for (Relationship rel : loadRelationships(node, relationshipType, direction)) {
59+
if (rel.getOtherNode(node).equals(node2)) {
60+
return false;
61+
}
62+
}
63+
return true;
64+
}
65+
66+
private Iterable<Relationship> loadRelationships(Node node, RelationshipType relationshipType, Direction direction) {
67+
return relationshipType == null ? node.getRelationships(direction) : node.getRelationships(relationshipType, direction);
68+
}
69+
70+
}

algo/src/main/java/org/neo4j/graphalgo/linkprediction/LinkPrediction.java

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,17 @@
2323
import org.neo4j.graphdb.Node;
2424
import org.neo4j.graphdb.Relationship;
2525
import org.neo4j.graphdb.RelationshipType;
26+
import org.neo4j.kernel.internal.GraphDatabaseAPI;
27+
import org.neo4j.procedure.Context;
2628
import org.neo4j.procedure.Description;
2729
import org.neo4j.procedure.Name;
2830
import org.neo4j.procedure.UserFunction;
2931

3032
import java.util.*;
3133

3234
public class LinkPrediction {
35+
@Context
36+
public GraphDatabaseAPI api;
3337

3438
@UserFunction("algo.linkprediction.adamicAdar")
3539
@Description("algo.linkprediction.adamicAdar(node1:Node, node2:Node, {relationshipQuery:'relationshipName', direction:'BOTH'}) " +
@@ -46,8 +50,7 @@ public double adamicAdarSimilarity(@Name("node1") Node node1, @Name("node2") Nod
4650
RelationshipType relationshipType = configuration.getRelationship();
4751
Direction direction = configuration.getDirection(Direction.BOTH);
4852

49-
Set<Node> neighbors = findPotentialNeighbors(node1, relationshipType, direction);
50-
neighbors.removeIf(node -> noCommonNeighbors(node, relationshipType, direction, node2));
53+
Set<Node> neighbors = new CommonNeighborsFinder(api).findCommonNeighbors(node1, node2, relationshipType, direction);
5154
return neighbors.stream().mapToDouble(nb -> 1.0 / Math.log(degree(relationshipType, direction, nb))).sum();
5255
}
5356

@@ -66,8 +69,7 @@ public double resourceAllocationSimilarity(@Name("node1") Node node1, @Name("nod
6669
RelationshipType relationshipType = configuration.getRelationship();
6770
Direction direction = configuration.getDirection(Direction.BOTH);
6871

69-
Set<Node> neighbors = findPotentialNeighbors(node1, relationshipType, direction);
70-
neighbors.removeIf(node -> noCommonNeighbors(node, relationshipType, direction, node2));
72+
Set<Node> neighbors = new CommonNeighborsFinder(api).findCommonNeighbors(node1, node2, relationshipType, direction);
7173
return neighbors.stream().mapToDouble(nb -> 1.0 / degree(relationshipType, direction, nb)).sum();
7274
}
7375

@@ -84,37 +86,13 @@ public double commonNeighbors(@Name("node1") Node node1, @Name("node2") Node nod
8486
RelationshipType relationshipType = configuration.getRelationship();
8587
Direction direction = configuration.getDirection(Direction.BOTH);
8688

87-
Set<Node> neighbors = findPotentialNeighbors(node1, relationshipType, direction);
88-
neighbors.removeIf(node -> noCommonNeighbors(node, relationshipType, direction, node2));
89+
Set<Node> neighbors = new CommonNeighborsFinder(api).findCommonNeighbors(node1, node2, relationshipType, direction);
8990
return neighbors.size();
9091
}
9192

92-
private Set<Node> findPotentialNeighbors(@Name("node1") Node node1, RelationshipType relationshipType, Direction direction) {
93-
Set<Node> neighbors = new HashSet<>();
94-
95-
for (Relationship rel : loadRelationships(node1, relationshipType, direction)) {
96-
Node endNode = rel.getEndNode();
97-
neighbors.add(endNode);
98-
}
99-
return neighbors;
100-
}
10193

10294
private int degree(RelationshipType relationshipType, Direction direction, Node node) {
10395
return relationshipType == null ? node.getDegree(direction) : node.getDegree(relationshipType, direction);
10496
}
10597

106-
private Iterable<Relationship> loadRelationships(Node node, RelationshipType relationshipType, Direction direction) {
107-
return relationshipType == null ? node.getRelationships(direction) : node.getRelationships(relationshipType, direction);
108-
}
109-
110-
private boolean noCommonNeighbors(Node node, RelationshipType relationshipType, Direction direction, Node node2) {
111-
for (Relationship rel : loadRelationships(node, relationshipType, direction)) {
112-
if (rel.getOtherNode(node).equals(node2)) {
113-
return false;
114-
}
115-
}
116-
return true;
117-
}
118-
119-
12098
}

tests/src/test/java/org/neo4j/graphalgo/algo/linkprediction/AdamicAdarProcIntegrationTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,19 @@ public void noNeighbors() throws Exception {
136136
}
137137
}
138138

139+
@Test
140+
public void bothNodesTheSame() throws Exception {
141+
String controlQuery =
142+
"MATCH (p1:Person {name: 'Praveena'})\n" +
143+
"MATCH (p2:Person {name: 'Praveena'})\n" +
144+
"RETURN algo.linkprediction.adamicAdar(p1, p2) AS score, " +
145+
" 0.0 AS cypherScore";
146+
147+
try (Transaction tx = db.beginTx()) {
148+
Result result = db.execute(controlQuery);
149+
Map<String, Object> node = result.next();
150+
assertEquals((Double) node.get("cypherScore"), (double) node.get("score"), 0.01);
151+
}
152+
}
153+
139154
}
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package org.neo4j.graphalgo.algo.linkprediction;
2+
3+
import org.junit.Before;
4+
import org.junit.Rule;
5+
import org.junit.Test;
6+
import org.neo4j.graphalgo.linkprediction.CommonNeighborsFinder;
7+
import org.neo4j.graphdb.Direction;
8+
import org.neo4j.graphdb.Node;
9+
import org.neo4j.graphdb.RelationshipType;
10+
import org.neo4j.graphdb.Transaction;
11+
import org.neo4j.kernel.internal.GraphDatabaseAPI;
12+
import org.neo4j.test.rule.ImpermanentDatabaseRule;
13+
14+
import java.util.Set;
15+
16+
import static org.junit.Assert.assertEquals;
17+
18+
public class CommonNeighborsFinderTest {
19+
20+
@Rule
21+
public final ImpermanentDatabaseRule DB = new ImpermanentDatabaseRule();
22+
23+
private GraphDatabaseAPI api;
24+
public static final RelationshipType FRIEND = RelationshipType.withName("FRIEND");
25+
public static final RelationshipType COLLEAGUE = RelationshipType.withName("COLLEAGUE");
26+
public static final RelationshipType FOLLOWS = RelationshipType.withName("FOLLOWS");
27+
28+
@Before
29+
public void setup() {
30+
api = DB.getGraphDatabaseAPI();
31+
}
32+
33+
@Test
34+
public void excludeDirectRelationships() throws Throwable {
35+
try (Transaction tx = api.beginTx()) {
36+
Node node1 = api.createNode();
37+
Node node2 = api.createNode();
38+
node1.createRelationshipTo(node2, FRIEND);
39+
tx.success();
40+
}
41+
42+
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
43+
44+
try (Transaction tx = api.beginTx()) {
45+
Node node1 = api.getNodeById(0);
46+
Node node2 = api.getNodeById(1);
47+
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, null, Direction.BOTH);
48+
49+
assertEquals(0, neighbors.size());
50+
}
51+
}
52+
53+
@Test
54+
public void sameNodeHasNoCommonNeighbors() throws Throwable {
55+
try (Transaction tx = api.beginTx()) {
56+
Node node1 = api.createNode();
57+
Node node2 = api.createNode();
58+
node1.createRelationshipTo(node2, FRIEND);
59+
tx.success();
60+
}
61+
62+
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
63+
64+
try (Transaction tx = api.beginTx()) {
65+
Node node1 = api.getNodeById(0);
66+
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node1, null, Direction.BOTH);
67+
68+
assertEquals(0, neighbors.size());
69+
}
70+
}
71+
72+
@Test
73+
public void findNeighborsExcludingDirection() throws Throwable {
74+
75+
try (Transaction tx = api.beginTx()) {
76+
Node node1 = api.createNode();
77+
Node node2 = api.createNode();
78+
Node node3 = api.createNode();
79+
Node node4 = api.createNode();
80+
81+
node1.createRelationshipTo(node3, FRIEND);
82+
node2.createRelationshipTo(node3, FRIEND);
83+
node1.createRelationshipTo(node4, COLLEAGUE);
84+
node2.createRelationshipTo(node4, COLLEAGUE);
85+
86+
tx.success();
87+
}
88+
89+
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
90+
91+
try (Transaction tx = api.beginTx()) {
92+
Node node1 = api.getNodeById(0);
93+
Node node2 = api.getNodeById(1);
94+
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, null, Direction.BOTH);
95+
96+
assertEquals(2, neighbors.size());
97+
}
98+
}
99+
100+
@Test
101+
public void findOutgoingNeighbors() throws Throwable {
102+
103+
try (Transaction tx = api.beginTx()) {
104+
Node node1 = api.createNode();
105+
Node node2 = api.createNode();
106+
Node node3 = api.createNode();
107+
108+
node1.createRelationshipTo(node3, FOLLOWS);
109+
node2.createRelationshipTo(node3, FOLLOWS);
110+
111+
tx.success();
112+
}
113+
114+
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
115+
116+
try (Transaction tx = api.beginTx()) {
117+
Node node1 = api.getNodeById(0);
118+
Node node2 = api.getNodeById(1);
119+
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, FOLLOWS, Direction.OUTGOING);
120+
121+
assertEquals(1, neighbors.size());
122+
}
123+
}
124+
125+
@Test
126+
public void findIncomingNeighbors() throws Throwable {
127+
128+
try (Transaction tx = api.beginTx()) {
129+
Node node1 = api.createNode();
130+
Node node2 = api.createNode();
131+
Node node3 = api.createNode();
132+
133+
node3.createRelationshipTo(node1, FOLLOWS);
134+
node3.createRelationshipTo(node2, FOLLOWS);
135+
136+
tx.success();
137+
}
138+
139+
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
140+
141+
try (Transaction tx = api.beginTx()) {
142+
Node node1 = api.getNodeById(0);
143+
Node node2 = api.getNodeById(1);
144+
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, FOLLOWS, Direction.INCOMING);
145+
146+
assertEquals(1, neighbors.size());
147+
}
148+
}
149+
150+
@Test
151+
public void findNeighborsOfSpecificRelationshipType() throws Throwable {
152+
153+
try (Transaction tx = api.beginTx()) {
154+
Node node1 = api.createNode();
155+
Node node2 = api.createNode();
156+
Node node3 = api.createNode();
157+
Node node4 = api.createNode();
158+
159+
node1.createRelationshipTo(node3, FRIEND);
160+
node2.createRelationshipTo(node3, FRIEND);
161+
node1.createRelationshipTo(node4, COLLEAGUE);
162+
node2.createRelationshipTo(node4, COLLEAGUE);
163+
164+
tx.success();
165+
}
166+
167+
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
168+
169+
try (Transaction tx = api.beginTx()) {
170+
Node node1 = api.getNodeById(0);
171+
Node node2 = api.getNodeById(1);
172+
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, COLLEAGUE, Direction.BOTH);
173+
174+
assertEquals(1, neighbors.size());
175+
}
176+
}
177+
178+
179+
180+
}
181+

0 commit comments

Comments
 (0)