Conversation

@swong3-sc (Collaborator) commented Oct 31, 2025:

Scope of work done

  • Added heterogeneous functionality for LightGCN, using HeteroData input to forward

Where is the documentation for this feature?: N/A

Did you add automated tests or write a test plan?

  • Added unit test to check final embeddings for bipartite user-item graph
  • Added unit test to check anchor-node functionality for a bipartite user-item graph

Updated Changelog.md? NO

Ready for code review?: YES

- data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous).
+ data (Union[Data, HeteroData]): Graph data.
+     - For homogeneous: Data object with edge_index and node field
+     - For bipartite: HeteroData with 2 node types and edge_index_dict
Collaborator:

Why are we restricting this to bipartite? Does it get easier if we do?

Collaborator (Author):

Yes, LightGCN is not defined for cases beyond bipartite. We decided for now to only worry about bipartite.

Collaborator:

Did this happen when I was gone? I really don't see the point in doing so; I don't think the implementation for a "heterogeneous" graph would be any different.

And if we go with this approach, then in the future, will we have a third _forward_heterogeneous implementation? Is that not just more complicated?

It seems to me like the current _forward_bipartite implementation has no restrictions on being bipartite or not, right? I don't see any asserts/etc. in the code that would restrict us.

Why not keep the current requirement that users provide two node types, but rename this function to _forward_heterogeneous?

Or am I missing something, and the code here is in fact only correct for bipartite?

Collaborator (Author):

If you're asking whether the code can handle n different node types, then yes, I don't see why not. The problem I have with "heterogeneous" is that it implies handling multiple node types and multiple edge types, and supporting multiple edge types is much more complicated. For now, I renamed the function to _forward_heterogeneous and generalized the comments to heterogeneous.

src_offset = node_type_to_offset[src_node_type]
dst_offset = node_type_to_offset[dst_node_type]

offset_edge_index = edge_index.clone()
Collaborator:

why do we clone here?

Collaborator:

Hmmm, is this because we mutate this? Does .to() already create a copy?

Collaborator (Author):

I don't believe it necessarily creates a copy. If data[edge_type_tuple].edge_index is already on the target device, .to(device) returns the same tensor, so the in-place mutations below would modify the original graph data, which we don't want.
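
A quick standalone illustration of that aliasing behavior (a sketch, not the PR's code):

import torch

edge_index = torch.tensor([[0, 1], [1, 0]])
device = edge_index.device

# .to(device) is a no-op when the tensor is already on `device`,
# so `alias` shares storage with `edge_index`.
alias = edge_index.to(device)
alias[0] += 10  # this in-place edit also changes edge_index

# .clone() always allocates new storage, so the original stays intact.
safe = edge_index.clone()
safe[1] += 10  # edge_index is unaffected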


for node_type in output_node_types:
    node_type_str = str(node_type)
    key = f"{node_type_str}_id"
Collaborator:

We should probably parameterise this somehow, either with something like NODE_ID_FMT = "{node_type}_id" and NODE_ID_FMT.format(node_type=node_type_str), or a helper like def get_nt_key().

Collaborator (Author):

Yeah, I agree. I added a static method to centralize the key naming.
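
For reference, a minimal sketch of what such a centralized helper could look like, matching the f"{node_type}_id" convention and the _get_feature_key signature that appears later in this PR (placement and exact body are assumptions):

from typing import Union

def _get_feature_key(node_type: Union[str, NodeType]) -> str:
    # Single place that defines the embedding-table key for a node type.
    # NodeType is GiGL's node-type wrapper, as used elsewhere in this PR.
    return f"{node_type}_id"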

# LightGCN propagation across node types
all_node_types = list(node_type_to_embeddings_0.keys())

# For bipartite, we need to create a unified edge representation
Collaborator:

why do we need to do this? Don't we have separate embedding tables for the different nodes?

Or does the PyG convolution require all nodes to be in the same space?

Collaborator (Author):

LGConv, unlike other convolutions usable with HeteroConv, doesn't support heterogeneous data, so we'd have to re-implement it. Thus, all nodes need to be in a single unified index space. Alternatively, we could run a separate forward pass for each type, but that seems inefficient, as we'd have to aggregate the results anyway.
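
Condensing the unified-space idea into a toy sketch (variable names follow the snippets in this PR; the shapes and data are illustrative):

import torch
from torch_geometric.nn.conv import LGConv

conv = LGConv()

# Toy example: 3 "user" nodes + 2 "item" nodes stacked into one tensor.
combined_embeddings_0 = torch.randn(5, 8)  # [total_nodes, D]
# user->item edges, with item indices offset by num_users (3).
combined_edge_index = torch.tensor([[0, 1, 2], [3, 4, 3]])

# One homogeneous LGConv call covers every edge at once.
combined_embeddings_1 = conv(combined_embeddings_0, combined_edge_index)  # [5, 8]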

Comment on lines 454 to 458

if anchor_node_ids is not None:
    for node_type in all_node_types:
        if isinstance(anchor_node_ids, dict) and node_type in anchor_node_ids:
            anchors = anchor_node_ids[node_type].to(device).long()
            final_embeddings[node_type] = final_embeddings[node_type][anchors]
Collaborator:

If we have all node types a, b, c and anchor node ids {a: [1, 2], b: [3, 4]}, then the final embeddings would still contain all of the embeddings for c, right?

Should we fix this? (and also add some test for it?)

Collaborator (Author):

Not exactly sure what you're asking, but I added some anchor node checks in the test.

Collaborator:

What I'm saying is let's say a user provides anchor_node_ids = {a: [10, 20]}, and the graph has node types {a, b}

The returned final_embeddings will be {a: [10, 20], b: all_nodes}. Right?

Collaborator (Author):

Ahhh, I see. Yes, this is the expected behavior. I take it you're saying we shouldn't be returning any embeddings for b here?

Collaborator:

Hmmm, to me, providing anchor_node_ids is some sort of sign that the user wants to filter down the output (to save memory, for instance).

In that case, it seems "more correct" to not return any embeddings for unspecified anchor nodes (in this case, b).

WDYT?
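
A sketch of that stricter behavior, where only node types named in anchor_node_ids appear in the output (hypothetical; names mirror the snippet quoted above):

if anchor_node_ids is not None and isinstance(anchor_node_ids, dict):
    # Keep only the node types the caller asked for; unspecified types
    # (like b above) are dropped instead of returned unfiltered.
    final_embeddings = {
        node_type: final_embeddings[node_type][
            anchor_node_ids[node_type].to(device).long()
        ]
        for node_type in anchor_node_ids
        if node_type in final_embeddings
    }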

@swong3-sc force-pushed the swong3/add_dmp_tests branch from af3d278 to 371c85a on November 6, 2025 00:09
Base automatically changed from swong3/add_dmp_tests to main on November 8, 2025 01:05
@swong3-sc force-pushed the swong3/add_heterogenous_lightgcn branch from b02fb53 to 50e1f47 on November 8, 2025 06:50
@swong3-sc marked this pull request as ready for review on November 10, 2025 06:34
Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None.
"""

@staticmethod
Collaborator:

let's not make this a static method?

why not just a _private free function?

Collaborator (Author):

Ok, that makes sense, changed to a free function outside LightGCN class.

@swong3-sc (Collaborator, Author):

Note the answers to some of your comments @kmontemayor2-sc are above my review for some reason.

output_with_anchors = model(
    data,
    self.device,
    output_node_types=[NodeType("user"), NodeType("item")],
Collaborator:

Let's just do the one output node type and ensure the output dict has length 1; I think the bug is still here :) Related to #370 (comment)

Suggested change:
- output_node_types=[NodeType("user"), NodeType("item")],
+ output_node_types=[NodeType("user")],

Collaborator (Author):

Oops, yeah, I forgot to fix this bug. It is fixed now in the model code, so we don't have to change output_node_types.

Collaborator:

Any reason to not have one output type here?

It seems like we'd have use cases where we really do only want the one output type?

Collaborator (Author):

I found an issue where if you pass just one output node type, the code errors because it needs to know all the different node types. So I think we should just pass anchor nodes to determine the type of output we get. LMK your thoughts, but I don't see a need to specify the output node types anymore.
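
A sketch of deriving the output from anchor_node_ids instead of a separate parameter (hypothetical naming, following the PR's snippets):

if isinstance(anchor_node_ids, dict):
    # The dict keys already say which node types the caller wants back.
    output_node_types = list(anchor_node_ids.keys())
else:
    # No per-type anchors given: return embeddings for every node type.
    output_node_types = list(all_node_types_in_data)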

return LinkPredictionGNN(encoder=encoder, decoder=decoder)


def _get_feature_key(node_type: Union[str, NodeType]) -> str:
Collaborator:

I'd suggest a clearer name, e.g. _build_node_embedding_table_name_by_node_type. I think we might actually want something like get_feature_key for other generic functionality inside GiGL.

+ # Sort node types for deterministic ordering across machines
  self._feature_keys: list[str] = [
-     f"{node_type}_id" for node_type in self._node_type_to_num_nodes.keys()
+     _get_feature_key(node_type) for node_type in sorted(self._node_type_to_num_nodes.keys())
Collaborator:

Should this use e.g. the SortedDict functionality introduced in #391? I assume it was introduced for such use cases.

  tables: list[EmbeddingBagConfig] = []
- for node_type, num_nodes in self._node_type_to_num_nodes.items():
+ # Sort node types for deterministic ordering across machines
+ for node_type, num_nodes in sorted(self._node_type_to_num_nodes.items()):
Collaborator:

flagging in case sorteddict should be used here

- anchor_node_ids (Optional[torch.Tensor]): Local node indices to return
-     embeddings for. If None, returns embeddings for all nodes. Default: None.
+ anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]):
+     Local node indices to return embeddings for.
Collaborator:

just clarifying -- "local" here refers specifically to the local index in the object, which is different from the global node ID, right?

node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {}

for node_type in all_node_types_in_data:
    node_type_str = str(node_type)
Collaborator:

nit: it might be less confusing to just make this variable implicit and use _get_feature_key(str(node_type)).

Managing two string variables that each reflect the node type slightly differently is somehow more confusing than using one and applying a conversion function where you need it, IMO.


for node_type in all_node_types_in_data:
    node_type_str = str(node_type)
    key = _get_feature_key(node_type_str)
Collaborator:

nit: use a more descriptive name, e.g. embedding_table_key or something

global_ids = data[node_type_str].node.to(device).long()  # shape [N_type]

embeddings = self._lookup_embeddings_for_single_node_type(
    key, global_ids
Collaborator:

nit: kwargs

Comment on lines +419 to +454

# For heterogeneous graphs, we need to create a unified edge representation
# Collect all edges and map node indices to a combined space
# E.g., node type 0 gets indices [0, num_type_0), node type 1 gets [num_type_0, num_type_0 + num_type_1)
node_type_to_offset: dict[NodeType, int] = {}
offset = 0
for node_type in all_node_types_in_data:
    node_type_to_offset[node_type] = offset
    node_type_str = str(node_type)
    offset += data[node_type_str].num_nodes

# Combine all embeddings into a single tensor
combined_embeddings_0 = torch.cat(
    [node_type_to_embeddings_0[nt] for nt in all_node_types_in_data], dim=0
)  # shape [total_nodes, D]

# Combine all edges into a single edge_index
# Sort edge types for deterministic ordering across machines
combined_edge_list: list[torch.Tensor] = []
for edge_type_tuple in sorted(data.edge_types):
    src_nt_str, _, dst_nt_str = edge_type_tuple
    src_node_type = NodeType(src_nt_str)
    dst_node_type = NodeType(dst_nt_str)

    edge_index = data[edge_type_tuple].edge_index.to(device)  # shape [2, E]

    # Offset the indices to the combined node space
    src_offset = node_type_to_offset[src_node_type]
    dst_offset = node_type_to_offset[dst_node_type]

    offset_edge_index = edge_index.clone()
    offset_edge_index[0] += src_offset
    offset_edge_index[1] += dst_offset

    combined_edge_list.append(offset_edge_index)

combined_edge_index = torch.cat(combined_edge_list, dim=1)  # shape [2, total_edges]
@nshah-sc (Collaborator) commented Nov 24, 2025:

I'm confused why we need to unify all these edge tensors and node tensors with this offset logic.

Why can't we use native PyG operations like message or propagate on the HeteroData or Data objects using e.g. torch_geometric.nn.conv.LGConv?

I think the important part of your current implementation is that it needs to be able to fetch embeddings at large scale. I am not sure you need to hand-roll your own convolution logic for messaging and aggregation too. The advantage of PyG is that it saves you from having to do that, right?

Collaborator:

Also, in general, if you have K different types of edges between two types of nodes (e.g. user and item), it feels reasonable that applying LightGCN in this setting would mean a LightGCN convolution applied to each of these edge types, plus some pooling (sum, mean, max, etc.) applied to the K resulting user embeddings or K resulting item embeddings. Basic LightGCN can be seen as the special case that works on one type of edge and skips the pooling (an identity function).

What you're doing here, merging all the lists of edges, actually feels like a specific sub-case of LightGCN, and I'm not sure it is worth implementing so specifically?

Collaborator:

I feel like our goal in implementing LightGCN in this exercise is to "plug and play" with TorchRec + PyG. I think we are currently doing something like TorchRec + hand-rolling, which feels clunky, if that makes sense.

After all, the end goal of all this work (whether LightGCN or otherwise) is to connect the TorchRec elements to PyG elements more generally for GiGL models, so following that pattern seems logical here.

Collaborator:

> I'm confused why we need to unify all these edge tensors and node tensors with this offset logic.

I may be a pyg noob here, but my understanding is that [part of] the highlighted code is to build node_type_to_offset so that we can do the one lookup into the TorchRec tables and then later reference the embeddings by node type?

Maybe I'm a bit confused as to what you'd be suggesting here? Is it to do the conv by edge index? E.g., roughly: for edge_index in data.edge_indices(): embeddings.append(self.conv(edge_index))

> Also, in general, if you have K different types of edges between two types of nodes (e.g. user and item), it feels reasonable that applying LightGCN in this setting would mean a LightGCN convolution applied to each of these edge types, plus some pooling (sum, mean, max, etc.) applied to the K resulting user embeddings or K resulting item embeddings. Basic LightGCN can be seen as the special case that works on one type of edge and skips the pooling (an identity function).

This makes sense actually! Again, a bit of a pyg noob, so I'm not sure of the "idiomatic" way to solve this problem? Would it be to have some base class and then have subclasses override a pool method on it?

Collaborator:

> I may be a pyg noob here, but my understanding is that [part of] the highlighted code is to build node_type_to_offset so that we can do the one lookup into the TorchRec tables and then later reference the embeddings by node type?

The block from lines 399-417 already gets all the node embeddings for each node_type by making len(all_node_types_in_data) calls to the TorchRec tables, IIUC.

I think we already went through the bother of making these separate calls. You are right that perhaps they could be batched and fetched in one lookup if we were using a different implementation, but @swong3-sc is relying on _lookup_embeddings_for_single_node_type being called multiple times, so that isn't happening here.

> This makes sense actually! Again, a bit of a pyg noob, so I'm not sure of the "idiomatic" way to solve this problem? Would it be to have some base class and then have subclasses override a pool method on it?

I don't totally know the "most idiomatic" way here, as others have likely read more PyG code than me. But just looking at LGConv's implementation (https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/lg_conv.html#LGConv), it operates on a single tensor of node features x (which contains multiple types of nodes in a single (num_nodes, F) tensor) and a single edge_index tensor containing the edges between src and dst nodes. Presumably this edge_index is defined on a single edge type; we could just apply it over num_edge_types in a for loop or something and accumulate the embeddings for each one using some pooling function. I would say our best bet when building these "torchrec-enhanced modules" is to rely on a pattern of "doing the TorchRec munging to get in PyG format" + "calling PyG", so we don't have to reinvent parts of the wheel that are already built and tested.

Collaborator:

I think @swong3-sc is currently creating one big tensor of node features and one big edge_index that is a union of all edges in the individual edge-typed edge_index tensors. The issue with this (I think) is that these should really be treated as separate edge_indices and convolutions, rather than one: the LGConv operation normalizes messages based on node degree, which is specific to each edge type. Message passing should (I believe) not be mixed across all edge types, but should be specific to each edge type, so that we end up with multiple embeddings for each node adjacent to multiple edge types.
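
For concreteness, a sketch of the per-edge-type alternative being discussed: one LGConv pass per typed edge_index, then pooling across the results (the helper's name and the mean-pooling choice are assumptions, not a settled design):

import torch
from torch_geometric.nn.conv import LGConv

conv = LGConv()

def propagate_per_edge_type(
    x: torch.Tensor, edge_indices: list[torch.Tensor]
) -> torch.Tensor:
    # One convolution per edge type, so degree normalization stays
    # specific to each edge type rather than being mixed across all edges.
    per_type = [conv(x, edge_index) for edge_index in edge_indices]
    # Pool the K per-edge-type results (mean here; sum/max also plausible).
    return torch.stack(per_type, dim=0).mean(dim=0)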
