Description
Hello, I'm a beginner with GAT, and I've been studying your GATv2 code lately. While going through the code, I ran into a question about
`labml_nn/graphs/gatv2/__init__.py`.
When calculating `g_sum`:
`g_sum = g_l_repeat + g_r_repeat_interleave`
The comment there says: "Now we add the two tensors to get"
However, in the preceding code, `g_l_repeat`
gets
and `g_r_repeat_interleave`
gets
So I think the result of adding the two tensors should be
I'm not sure whether I've overlooked something crucial or whether the comments and the code don't match. I'd greatly appreciate it if you could help clear up my confusion. Thank you.
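Here is a small pure-Python sketch of the ordering in question. It assumes, as the variable names suggest, that `g_l_repeat` is built with `Tensor.repeat` and `g_r_repeat_interleave` with `repeat_interleave`, both along dimension 0. The helpers `repeat_dim0` and `repeat_interleave_dim0` are hypothetical list stand-ins for those calls. Each node embedding is replaced by a string label so the pairing is easy to read.

```python
# Hypothetical pure-Python stand-ins for the two PyTorch calls (dim 0 only).
# They are illustrations, not part of labml_nn.

def repeat_dim0(items, n):
    """Like tensor.repeat(n, 1, 1): tile the whole sequence n times."""
    return items * n

def repeat_interleave_dim0(items, n):
    """Like tensor.repeat_interleave(n, dim=0): repeat each element n times in place."""
    return [x for x in items for _ in range(n)]

n_nodes = 3
g_l = [f"gl{i}" for i in range(1, n_nodes + 1)]  # labels standing in for rows of g_l
g_r = [f"gr{j}" for j in range(1, n_nodes + 1)]  # labels standing in for rows of g_r

g_l_repeat = repeat_dim0(g_l, n_nodes)                        # gl1, gl2, gl3, gl1, gl2, gl3, ...
g_r_repeat_interleave = repeat_interleave_dim0(g_r, n_nodes)  # gr1, gr1, gr1, gr2, gr2, gr2, ...

# Elementwise "addition", recorded as a label pair.
g_sum = [f"{l}+{r}" for l, r in zip(g_l_repeat, g_r_repeat_interleave)]

print(g_sum)
# ['gl1+gr1', 'gl2+gr1', 'gl3+gr1', 'gl1+gr2', 'gl2+gr2', 'gl3+gr2', 'gl1+gr3', 'gl2+gr3', 'gl3+gr3']
```

Under these assumptions, flat position `k = a * n_nodes + b` (0-based) pairs `g_l[b]` with `g_r[a]`. The left index varies fastest and the right index varies slowest.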