Skip to content

Commit 9b7cdf3

Browse files
committed
Fix a lot of bugs
1 parent 0804fa2 commit 9b7cdf3

File tree

6 files changed

+94
-10
lines changed

6 files changed

+94
-10
lines changed

code_graph/__init__.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
from .graph import CodeGraph
44

55
from .pylang import pylang_analyses
6-
from .javalang import javalang_analyses
6+
from .javalang import javalang_analyses, java_preprocess
77

88

99
def codegraph(source_code, lang = "guess", analyses = None, **kwargs):
10-
tokens = ctok.tokenize(source_code, lang = lang, **kwargs)
11-
root_node = _root_node(tokens)
10+
root_node, tokens = preprocess_code(source_code, lang, **kwargs)
1211

1312
graph_analyses = load_lang_analyses(tokens[0].config.lang)
1413

@@ -32,14 +31,31 @@ def load_lang_analyses(lang):
3231
if lang == 'java' : return javalang_analyses()
3332

3433
raise NotImplementedError("Language %s is not supported" % lang)
35-
34+
35+
36+
def preprocess_code(source_code, lang, **kwargs):
    """Run the preprocessor that matches ``lang`` on ``source_code``.

    Java input is handed to ``java_preprocess``; every other language value
    goes through ``default_preprocess``. Extra keyword arguments are passed
    along untouched.

    Returns:
        The ``(root_node, tokens)`` pair produced by the chosen preprocessor.
    """
    if lang != "java":
        return default_preprocess(source_code, lang, **kwargs)

    return java_preprocess(source_code, **kwargs)
41+
42+
43+
def default_preprocess(source_code, lang, **kwargs):
    """Tokenize ``source_code`` and locate the root of its AST.

    Keyword arguments are forwarded to ``ctok.tokenize``.

    Returns:
        A ``(root_node, tokens)`` pair, where ``root_node`` is computed from
        the token sequence by ``_root_node``.
    """
    token_seq = ctok.tokenize(source_code, lang = lang, **kwargs)
    return _root_node(token_seq), token_seq
3647

3748
# Helper methods --------------------------------
3849

3950
def _root_node(tokens):
4051
if len(tokens) == 0: raise ValueError("Empty program has no root node")
4152

42-
base_token = tokens[0]
53+
i = 0
54+
while not hasattr(tokens[i], "ast_node"):
55+
i += 1
56+
57+
base_token = tokens[i]
58+
4359
current_ast = base_token.ast_node
4460

4561
while current_ast.parent is not None:

code_graph/ast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ def visit(self, ast_node):
1515

1616
prev_sibling = ast_node.prev_sibling
1717
if prev_sibling is not None:
18-
graph.add_relation(prev_sibling, ast_node, "sibling")
18+
graph.add_relation(prev_sibling, ast_node, "sibling", no_create=True)

code_graph/graph.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@ def __init__(self, root_node, tokens, lang = "python"):
1616
self._ast_nodes = {} # Nodes indexed by an AST node
1717
self._anonymous_nodes = [] # Unindexed nodes, can only be indexed by traversal
1818

19+
self.token_nodes = []
20+
1921
prev_token = self._add_token(tokens[0])
22+
self.token_nodes.append(prev_token)
23+
2024
for token in tokens[1:]:
2125
token_node = self._add_token(token)
2226
prev_token.add_successor(token_node, "next_token")
2327
prev_token = token_node
28+
self.token_nodes.append(token_node)
2429

2530
self.root_node = self.add_or_get_node(root_node)
2631

@@ -70,7 +75,12 @@ def add_or_get_node(self, node):
7075
except Exception:
7176
raise ValueError("Cannot add or get node %s. Only AST nodes can be indexed." % str(node))
7277

73-
def add_relation(self, source_node, target_node, relation = "ast"):
78+
def add_relation(self, source_node, target_node, relation = "ast", no_create = False):
79+
80+
if no_create:
81+
if not self.has_node(source_node): return
82+
if not self.has_node(target_node): return
83+
7484
source_node = self.add_or_get_node(source_node)
7585
target_node = self.add_or_get_node(target_node)
7686
source_node.add_successor(target_node, relation)
@@ -207,7 +217,8 @@ def __hash__(self):
207217
class TokenNode(SyntaxNode):
208218

209219
def __init__(self, token):
    """Wrap ``token`` in a graph node.

    The parent initializer receives the token's ``ast_node`` when the token
    has that attribute, and ``None`` otherwise.
    """
    # Not every token carries an `ast_node` attribute; fall back to None.
    super().__init__(getattr(token, 'ast_node', None))
    self.token = token
212223

213224
def node_name(self):
@@ -216,7 +227,7 @@ def node_name(self):
216227
def __hash__(self):
    """Hash by AST node key when available, otherwise by token text."""
    ast_node = self.ast_node

    if ast_node is None:
        # No AST node attached: identify the token by its text.
        return hash(self.token.text)

    return hash(node_key(ast_node))
220231

221232

222233

code_graph/javalang/__init__.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,59 @@
33
from .cfg import ControlFlowVisitor
44
from .dataflow import DataFlowVisitor
55

6+
import code_tokenize as ctok
7+
8+
from code_tokenize.tokens import TokenSequence
9+
10+
611
def javalang_analyses():
    """Return the Java analyses as a mapping from name to visitor class."""
    return dict(
        ast = ASTRelationVisitor,
        cfg = ControlFlowVisitor,
        dataflow = DataFlowVisitor,
    )
17+
18+
19+
# Preprocessor ------------------------------------------------
20+
21+
def _try_tokenize_or_wrap(source_code, **kwargs):
    """Tokenize Java source, wrapping it in a dummy class if it does not parse.

    The first attempt always runs with ``syntax_error = "raise"`` (any
    caller-supplied ``syntax_error`` is ignored for this attempt). If that
    raises ``SyntaxError``, the source is embedded in ``public class Test {...}``
    and tokenized again, this time with the caller's original keyword
    arguments.

    Returns:
        A ``(tokens, has_wrapped)`` pair; ``has_wrapped`` is ``True`` only when
        the tokens come from the wrapped source.
    """
    strict_args = {k: v for k, v in kwargs.items() if k != "syntax_error"}

    try:
        return ctok.tokenize(
            source_code, lang = "java", syntax_error = "raise", **strict_args
        ), False
    except SyntaxError:
        # The closing brace goes on its own line: if the snippet ends with a
        # `//` line comment, a brace appended to the same line would be
        # swallowed by the comment and the wrapped code would not parse.
        # The snippet itself stays right after `{`, so its token positions
        # are the same as before.
        wrapped_code = "public class Test {%s\n}" % (source_code,)
        return ctok.tokenize(wrapped_code, lang = "java", **kwargs), True
30+
31+
32+
def java_preprocess(source_code, **kwargs):
    """Tokenize Java source and pick the AST node to use as graph root.

    When the source tokenizes as-is, the root is computed by ``_root_node``.
    When it had to be wrapped, the first four and the last token are dropped
    and the root is the nearest ``method_declaration`` ancestor of the first
    remaining token (or the topmost ancestor when there is none).

    Returns:
        A ``(root_node, tokens)`` pair.
    """
    tokens, has_wrapped = _try_tokenize_or_wrap(source_code, **kwargs)

    if has_wrapped:
        # Presumably the four leading tokens are `public class Test {` and the
        # trailing one is the closing `}` added by the wrapper — TODO confirm.
        inner_tokens = TokenSequence(tokens[4:-1])

        node = inner_tokens[0].ast_node
        ancestor = node.parent
        while ancestor is not None:
            node = ancestor
            if node.type == "method_declaration":
                break
            ancestor = node.parent

        return node, inner_tokens

    return _root_node(tokens), tokens
46+
47+
48+
def _root_node(tokens):
    """Return the root AST node reachable from a token sequence.

    Starts from the first token that carries a usable ``ast_node`` (tokens
    without the attribute, or with it set to ``None``, are skipped), climbs
    the ``parent`` links to the top, and, if that top node has exactly one
    child, returns the child instead.

    Raises:
        ValueError: if ``tokens`` is empty or no token has an AST node.
    """
    if len(tokens) == 0: raise ValueError("Empty program has no root node")

    # Not every token is guaranteed to be attached to an AST node, so look
    # for the first one that is instead of assuming it is tokens[0].
    for token in tokens:
        current_ast = getattr(token, "ast_node", None)
        if current_ast is not None:
            break
    else:
        raise ValueError("No token is attached to an AST node")

    while current_ast.parent is not None:
        current_ast = current_ast.parent

    # If root only has one child skip to child
    if current_ast.child_count == 1:
        current_ast = current_ast.children[0]

    return current_ast

code_graph/pylang/dataflow.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def record_write(self, node):
8888

8989
def record_read(self, node):
9090
node = self.graph.add_or_get_node(node)
91+
92+
assert hasattr(node, "token"), "Expected to read from a token, but got: %s" % node
93+
9194
qname = self.qualname(node.token.text)
9295

9396
for last_read in self._last_reads[qname]:
@@ -385,6 +388,9 @@ def visit_lambda(self, node):
385388

386389
self._restore_rw_context(rw_context)
387390
return False
391+
392+
def visit_string(self, node):
    """Handle a string node; always returns ``False``.

    NOTE(review): presumably ``False`` tells the visitor not to descend into
    the string's children — confirm against the visitor base class.
    """
    # Currently we do not support f-strings
    return False
388394

389395

390396
# Helper --------------------------------------------------------

code_graph/visitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def __call__(self, root_node):
156156
# Helper --------------------------------
157157

158158
def node_equal(n1, n2):
159+
if n1 == n2: return True
159160
try:
160161
return (n1.type == n2.type
161162
and n1.start_point == n2.start_point

0 commit comments

Comments
 (0)