diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index 9fa29f62..26876636 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -134,6 +134,19 @@ def unwrap_from_ddp(self) -> "LinkPredictionGNN": return LinkPredictionGNN(encoder=encoder, decoder=decoder) +def _get_feature_key(node_type: Union[str, NodeType]) -> str: + """ + Get the feature key for a node type's embedding table. + + Args: + node_type: Node type as string or NodeType object. + + Returns: + str: Feature key in format "{node_type}_id" + """ + return f"{node_type}_id" + + # TODO(swong3): Move specific models to gigl.nn.models whenever we restructure model placement. # TODO(swong3): Abstract TorchRec functionality, and make this LightGCN specific # TODO(swong3): Remove device context from LightGCN module (use meta, but will have to figure out how to handle buffer transfer) @@ -187,18 +200,29 @@ def __init__( ) # Build TorchRec EBC (one table per node type) - # feature key naming convention: f"{node_type}_id" + # Sort node types for deterministic ordering across machines self._feature_keys: list[str] = [ - f"{node_type}_id" for node_type in self._node_type_to_num_nodes.keys() + _get_feature_key(node_type) for node_type in sorted(self._node_type_to_num_nodes.keys()) ] + + # Validate model configuration: restrict to homogeneous or bipartite graphs + num_node_types = len(self._feature_keys) + if num_node_types not in [1, 2]: + # TODO(kmonte, swong3): We should loosen this restriction and allow fully heterogeneous graphs in the future.
+ raise ValueError( + f"LightGCN only supports homogeneous (1 node type) or bipartite (2 node types) graphs; " + f"got {num_node_types} node types: {self._feature_keys}" + ) + tables: list[EmbeddingBagConfig] = [] - for node_type, num_nodes in self._node_type_to_num_nodes.items(): + # Sort node types for deterministic ordering across machines + for node_type, num_nodes in sorted(self._node_type_to_num_nodes.items()): tables.append( EmbeddingBagConfig( name=f"node_embedding_{node_type}", embedding_dim=embedding_dim, num_embeddings=num_nodes, - feature_names=[f"{node_type}_id"], + feature_names=[_get_feature_key(node_type)], ) ) @@ -215,32 +239,44 @@ def forward( self, data: Union[Data, HeteroData], device: torch.device, - output_node_types: Optional[list[NodeType]] = None, - anchor_node_ids: Optional[torch.Tensor] = None, + anchor_node_ids: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None, ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: """ Forward pass of the LightGCN model. Args: - data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous). + data (Union[Data, HeteroData]): Graph data. + - For homogeneous: Data object with edge_index and node field + - For heterogeneous: HeteroData with node types and edge_index_dict device (torch.device): Device to run the computation on. - output_node_types (Optional[List[NodeType]]): List of node types to return - embeddings for. Required for heterogeneous graphs. Default: None. - anchor_node_ids (Optional[torch.Tensor]): Local node indices to return - embeddings for. If None, returns embeddings for all nodes. Default: None. + anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]): + Local node indices to return embeddings for. + - For homogeneous: torch.Tensor of shape [num_anchors] + - For heterogeneous: dict mapping node types to anchor tensors + If None, returns embeddings for all nodes. Default: None. 
Returns: Union[torch.Tensor, Dict[NodeType, torch.Tensor]]: Node embeddings. - For homogeneous graphs, returns tensor of shape [num_nodes, embedding_dim]. - For heterogeneous graphs, returns dict mapping node types to embeddings. + - For homogeneous: tensor of shape [num_nodes, embedding_dim] + - For heterogeneous: dict mapping node types to embeddings """ - if isinstance(data, HeteroData): - raise NotImplementedError("HeteroData is not yet supported for LightGCN") - output_node_types = output_node_types or list(data.node_types) - return self._forward_heterogeneous( - data, device, output_node_types, anchor_node_ids - ) + is_heterogeneous = isinstance(data, HeteroData) + + if is_heterogeneous: + # For heterogeneous graphs, anchor_node_ids must be a dict, not a Tensor + if anchor_node_ids is not None and not isinstance(anchor_node_ids, dict): + raise TypeError( + f"For heterogeneous graphs, anchor_node_ids must be a dict or None, " + f"got {type(anchor_node_ids)}" + ) + return self._forward_heterogeneous(data, device, anchor_node_ids) else: + # For homogeneous graphs, anchor_node_ids must be a Tensor, not a dict + if anchor_node_ids is not None and not isinstance(anchor_node_ids, torch.Tensor): + raise TypeError( + f"For homogeneous graphs, anchor_node_ids must be a Tensor or None, " + f"got {type(anchor_node_ids)}" + ) return self._forward_homogeneous(data, device, anchor_node_ids) def _forward_homogeneous( @@ -323,6 +359,134 @@ def _forward_homogeneous( final_embeddings # shape [N_sub, D], embeddings for all nodes in subgraph ) + def _forward_heterogeneous( + self, + data: HeteroData, + device: torch.device, + anchor_node_ids: Optional[dict[NodeType, torch.Tensor]] = None, + ) -> dict[NodeType, torch.Tensor]: + """ + Forward pass for heterogeneous graphs using LightGCN propagation. + + For heterogeneous graphs (e.g., user-item), we have + multiple node types. 
LightGCN propagates embeddings across + all node types by creating a unified node space, running propagation, then splitting + back into per-type embeddings. + + Note: All node types in the graph are processed during message passing, as this is + required for correct GNN computation. Use anchor_node_ids to filter which node types + and specific nodes are returned in the output. + + Args: + data (HeteroData): PyG HeteroData object with node types. + device (torch.device): Device to run computation on. + anchor_node_ids (Optional[Dict[NodeType, torch.Tensor]]): Dict mapping node types + to local anchor indices. If None, returns all nodes for all types. + If provided, only returns embeddings for the specified node types and indices. + + Returns: + Dict[NodeType, torch.Tensor]: Dict mapping node types to their embeddings, + each of shape [num_nodes_of_type, embedding_dim] (or [num_anchors, embedding_dim] + if anchor_node_ids is provided for that type). + """ + # Process all node types - this is required for correct message passing in GNNs + # Sort node types for deterministic ordering across machines + all_node_types_in_data = [NodeType(nt) for nt in sorted(data.node_types)] + + # Lookup initial embeddings e^(0) for each node type + node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {} + + for node_type in all_node_types_in_data: + node_type_str = str(node_type) + key = _get_feature_key(node_type_str) + + assert hasattr(data[node_type_str], "node"), ( + f"Subgraph must include .node field for node type {node_type_str}" + ) + + global_ids = data[node_type_str].node.to(device).long() # shape [N_type] + + embeddings = self._lookup_embeddings_for_single_node_type( + key, global_ids + ) # shape [N_type, D] + + # Handle DMP Awaitable + if isinstance(embeddings, Awaitable): + embeddings = embeddings.wait() + + node_type_to_embeddings_0[node_type] = embeddings + + # For heterogeneous graphs, we need to create a unified edge representation + # Collect all edges and map 
node indices to a combined space + # E.g., node type 0 gets indices [0, num_type_0), node type 1 gets [num_type_0, num_type_0 + num_type_1) + node_type_to_offset: dict[NodeType, int] = {} + offset = 0 + for node_type in all_node_types_in_data: + node_type_to_offset[node_type] = offset + node_type_str = str(node_type) + offset += data[node_type_str].num_nodes + + # Combine all embeddings into a single tensor + combined_embeddings_0 = torch.cat( + [node_type_to_embeddings_0[nt] for nt in all_node_types_in_data], dim=0 + ) # shape [total_nodes, D] + + # Combine all edges into a single edge_index + # Sort edge types for deterministic ordering across machines + combined_edge_list: list[torch.Tensor] = [] + for edge_type_tuple in sorted(data.edge_types): + src_nt_str, _, dst_nt_str = edge_type_tuple + src_node_type = NodeType(src_nt_str) + dst_node_type = NodeType(dst_nt_str) + + edge_index = data[edge_type_tuple].edge_index.to(device) # shape [2, E] + + # Offset the indices to the combined node space + src_offset = node_type_to_offset[src_node_type] + dst_offset = node_type_to_offset[dst_node_type] + + offset_edge_index = edge_index.clone() + offset_edge_index[0] += src_offset + offset_edge_index[1] += dst_offset + + combined_edge_list.append(offset_edge_index) + + combined_edge_index = torch.cat(combined_edge_list, dim=1) # shape [2, total_edges] + + # Track all layer embeddings + all_layer_embeddings: list[torch.Tensor] = [combined_embeddings_0] + current_embeddings = combined_embeddings_0 + + # Perform K layers of propagation + for conv in self._convs: + current_embeddings = conv(current_embeddings, combined_edge_index) # shape [total_nodes, D] + all_layer_embeddings.append(current_embeddings) + + # Weighted sum across layers + combined_final_embeddings = self._weighted_layer_sum(all_layer_embeddings) # shape [total_nodes, D] + + # Split back into per-node-type embeddings + final_embeddings: dict[NodeType, torch.Tensor] = {} + for node_type in all_node_types_in_data: 
+ start_idx = node_type_to_offset[node_type] + node_type_str = str(node_type) + num_nodes = data[node_type_str].num_nodes + end_idx = start_idx + num_nodes + + final_embeddings[node_type] = combined_final_embeddings[start_idx:end_idx] # shape [num_nodes, D] + + # Extract anchor nodes if specified + if anchor_node_ids is not None: + # Only return embeddings for node types specified in anchor_node_ids + filtered_embeddings: dict[NodeType, torch.Tensor] = {} + for node_type in all_node_types_in_data: + if node_type in anchor_node_ids: + anchors = anchor_node_ids[node_type].to(device).long() + filtered_embeddings[node_type] = final_embeddings[node_type][anchors] + return filtered_embeddings + + return final_embeddings + def _lookup_embeddings_for_single_node_type( self, node_type: str, ids: torch.Tensor ) -> torch.Tensor: diff --git a/python/tests/unit/nn/models_test.py b/python/tests/unit/nn/models_test.py index b201876d..82a75119 100644 --- a/python/tests/unit/nn/models_test.py +++ b/python/tests/unit/nn/models_test.py @@ -326,6 +326,198 @@ def test_dmp_multiprocess(self): nprocs=world_size, ) + def test_compare_bipartite_with_math(self): + """Test that bipartite implementation matches the mathematical formulation of LightGCN. + + This test converts the homogeneous 4-node graph into a bipartite graph to verify + that the bipartite implementation produces identical results. The same initial + embeddings and edge structure are used, just split by node type. + + Graph structure: + - Homogeneous: nodes [0, 1, 2, 3] with edges 0->2, 0->3, 1->3, 2->0, 3->0, 3->1 + - Bipartite: users [0, 1] and items [0, 1] with equivalent cross-type edges + + Expected behavior: Bipartite embeddings should match the homogeneous embeddings + for the corresponding nodes (user 0 = node 0, user 1 = node 1, etc.) 
+ """ + # Create bipartite graph + num_users = 2 + num_items = 2 + + node_type_to_num_nodes = { + NodeType("user"): num_users, + NodeType("item"): num_items, + } + + model = self._create_lightgcn_model(node_type_to_num_nodes) + + # Use same embeddings as homogeneous test, split by node type + user_embeddings = torch.tensor( + [ + [0.2, 0.5, 0.1, 0.4], # User 0 (was Node 0) + [0.6, 0.1, 0.2, 0.5], # User 1 (was Node 1) + ], + dtype=torch.float32, + ) + + item_embeddings = torch.tensor( + [ + [0.9, 0.4, 0.1, 0.4], # Item 0 (was Node 2) + [0.3, 0.8, 0.3, 0.6], # Item 1 (was Node 3) + ], + dtype=torch.float32, + ) + + with torch.no_grad(): + user_table = model._embedding_bag_collection.embedding_bags[ + "node_embedding_user" + ] + user_table.weight[:] = user_embeddings + + item_table = model._embedding_bag_collection.embedding_bags[ + "node_embedding_item" + ] + item_table.weight[:] = item_embeddings + + data = HeteroData() + + # User nodes (local IDs 0, 1 map to global IDs 0, 1) + data["user"].node = torch.tensor([0, 1], dtype=torch.long) + data["user"].num_nodes = num_users + + # Item nodes (local IDs 0, 1 map to global IDs 0, 1) + data["item"].node = torch.tensor([0, 1], dtype=torch.long) + data["item"].num_nodes = num_items + + # User to item edges (converting from original homogeneous edges) + # Original: 0->2, 0->3, 1->3 becomes user 0->item 0, user 0->item 1, user 1->item 1 + data["user", "to", "item"].edge_index = torch.tensor( + [[0, 0, 1], [0, 1, 1]], dtype=torch.long + ) + + # Item to user edges (reverse direction) + # Original: 2->0, 3->0, 3->1 becomes item 0->user 0, item 1->user 0, item 1->user 1 + data["item", "to", "user"].edge_index = torch.tensor( + [[0, 1, 1], [0, 0, 1]], dtype=torch.long + ) + + # Forward pass - will return both user and item embeddings + output = model( + data, + self.device, + ) + + expected_user_embeddings = torch.tensor( + [ + [0.4495, 0.5311, 0.1555, 0.4865], # User 0 + [0.3943, 0.2975, 0.1825, 0.4386], # User 1 + ], + 
dtype=torch.float32, + ) + + expected_item_embeddings = torch.tensor( + [ + [0.5325, 0.4121, 0.1089, 0.3650], # Item 0 + [0.4558, 0.6207, 0.2506, 0.5817], # Item 1 + ], + dtype=torch.float32, + ) + + # Check that bipartite output matches expected + self.assertTrue( + torch.allclose(output[NodeType("user")], expected_user_embeddings, atol=1e-4, rtol=1e-4) + ) + self.assertTrue( + torch.allclose(output[NodeType("item")], expected_item_embeddings, atol=1e-4, rtol=1e-4) + ) + + def test_bipartite_with_anchor_nodes(self): + """Test anchor node selection in bipartite/heterogeneous graphs.""" + # Create bipartite graph + num_users = 2 + num_items = 2 + + node_type_to_num_nodes = { + NodeType("user"): num_users, + NodeType("item"): num_items, + } + + model = self._create_lightgcn_model(node_type_to_num_nodes) + + # Set embeddings + user_embeddings = torch.tensor( + [ + [0.2, 0.5, 0.1, 0.4], # User 0 + [0.6, 0.1, 0.2, 0.5], # User 1 + ], + dtype=torch.float32, + ) + + item_embeddings = torch.tensor( + [ + [0.9, 0.4, 0.1, 0.4], # Item 0 + [0.3, 0.8, 0.3, 0.6], # Item 1 + ], + dtype=torch.float32, + ) + + with torch.no_grad(): + user_table = model._embedding_bag_collection.embedding_bags[ + "node_embedding_user" + ] + user_table.weight[:] = user_embeddings + + item_table = model._embedding_bag_collection.embedding_bags[ + "node_embedding_item" + ] + item_table.weight[:] = item_embeddings + + data = HeteroData() + + # Set up nodes + data["user"].node = torch.tensor([0, 1], dtype=torch.long) + data["user"].num_nodes = num_users + data["item"].node = torch.tensor([0, 1], dtype=torch.long) + data["item"].num_nodes = num_items + + # Set up edges + data["user", "to", "item"].edge_index = torch.tensor( + [[0, 0, 1], [0, 1, 1]], dtype=torch.long + ) + data["item", "to", "user"].edge_index = torch.tensor( + [[0, 1, 1], [0, 0, 1]], dtype=torch.long + ) + + # First get full output to compare against (will return all node types) + full_output = model( + data, + self.device, + ) + + # 
Test with anchor nodes - select specific nodes from specific types + # By only including "user" in anchor_node_ids, we'll only get user embeddings back + anchor_node_ids = { + NodeType("user"): torch.tensor([0], dtype=torch.long), # Select user 0 + } + + output_with_anchors = model( + data, + self.device, + anchor_node_ids=anchor_node_ids, + ) + + # Check that only user embeddings are returned + self.assertEqual(output_with_anchors.keys(), set([NodeType("user")])) + + # Check values - should match the corresponding rows from full output + self.assertTrue( + torch.allclose( + output_with_anchors[NodeType("user")], + full_output[NodeType("user")][0:1], # User 0 + atol=1e-4, + rtol=1e-4, + ) + ) def _run_dmp_multiprocess_test( rank: int,