Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions python/gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from gigl.types.graph import (
DEFAULT_HOMOGENEOUS_EDGE_TYPE,
DEFAULT_HOMOGENEOUS_NODE_TYPE,
reverse_edge_type,
select_label_edge_types,
)
from gigl.utils.data_splitters import get_labels_for_anchor_nodes
Expand Down Expand Up @@ -243,15 +242,18 @@ def __init__(
)
self._is_input_heterogeneous = True
anchor_node_type, anchor_node_ids = input_nodes
# TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if
# this assumption is no longer valid and/or is too opinionated
assert (
supervision_edge_type[0] == anchor_node_type
), f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \
got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}"
supervision_node_type = supervision_edge_type[2]
if dataset.edge_dir == "in":
supervision_edge_type = reverse_edge_type(supervision_edge_type)
supervision_node_type = supervision_edge_type[0]
if supervision_edge_type[2] != anchor_node_type:
raise ValueError(
f"Found anchor node type {anchor_node_type} but expected {supervision_edge_type[2]}"
)
else:
supervision_node_type = supervision_edge_type[2]
if supervision_edge_type[0] != anchor_node_type:
raise ValueError(
f"Found anchor node type {anchor_node_type} but expected {supervision_edge_type[0]}"
)

elif isinstance(input_nodes, torch.Tensor):
if supervision_edge_type is not None:
Expand Down
17 changes: 12 additions & 5 deletions python/gigl/distributed/dist_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,12 +1053,19 @@ def _partition_label_edge_index(
"""
start_time = time.time()

if edge_type.src_node_type not in node_partition_book:
if self._should_assign_edges_by_src_node:
target_node_type = edge_type.src_node_type
target_edge_src_dst_index = 0
else:
target_node_type = edge_type.dst_node_type
target_edge_src_dst_index = 1

if target_node_type not in node_partition_book:
raise ValueError(
f"Edge type {edge_type} source node type {edge_type.src_node_type} not found in the node partition book node keys: {node_partition_book.keys()}"
f"Edge type {edge_type} source node type {target_node_type} not found in the node partition book node keys: {node_partition_book.keys()}"
)

target_node_partition_book = node_partition_book[edge_type.src_node_type]
target_node_partition_book = node_partition_book[target_node_type]
if is_positive:
assert (
self._positive_label_edge_index is not None
Expand All @@ -1084,9 +1091,9 @@ def _label_pfn(source_node_ids, _):
),
# 'partition_fn' takes 'val_indices' as input, uses it as keys for partition,
# and returns the partition index.
rank_indices=label_edge_index[0],
rank_indices=label_edge_index[target_edge_src_dst_index],
partition_function=_label_pfn,
total_val_size=label_edge_index[0].size(0),
total_val_size=label_edge_index[target_edge_src_dst_index].size(0),
generate_pb=False,
)

Expand Down
68 changes: 4 additions & 64 deletions python/gigl/types/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,30 +109,6 @@ class FeatureInfo:
dtype: torch.dtype


def _get_label_edges(
    labeled_edge_index: torch.Tensor,
    edge_dir: Literal["in", "out"],
    labeled_edge_type: EdgeType,
) -> tuple[EdgeType, torch.Tensor]:
    """
    Align a labeled edge type and its edge index with the graph's edge direction.

    When `edge_dir` is anything other than "in", the inputs are handed back untouched.
    When `edge_dir` is "in", the edge type is reversed via `reverse_edge_type` and the
    edge index is flipped along dimension 0, so the labels point the same way as the
    rest of the edges.

    Args:
        labeled_edge_index (torch.Tensor): Edge index containing positive or negative labels for supervision
        edge_dir (Literal["in", "out"]): Direction of edges in the graph
        labeled_edge_type (EdgeType): Edge type used for the positive or negative labeled edges
    Returns:
        EdgeType: Labeled edge type, reversed if edge_dir = "in"
        torch.Tensor: Labeled edge index, flipped along dimension 0 if edge_dir = "in"
    """
    # Guard clause: nothing to do unless the graph is stored with inward edges.
    if edge_dir != "in":
        return labeled_edge_type, labeled_edge_index

    # Reverse the type first, then flip the index, so both describe the same direction.
    return reverse_edge_type(labeled_edge_type), labeled_edge_index.flip(0)


# This dataclass should not be frozen, as we are expected to delete its members once they have been registered inside of the partitioner
# in order to save memory.
@dataclass
Expand All @@ -153,8 +129,6 @@ class LoadedGraphTensors:
def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None:
"""
Convert positive and negative labels to edges. Converts this object in-place to a "heterogeneous" representation.
If the edge direction is "in", we must reverse the supervision edge type. This is because we assume that provided labels are directed
outwards in form (`anchor_node_type`, `relation`, `supervision_node_type`), and all edges in the edge index must be in the same direction.

This function requires the following conditions and will throw if they are not met:
1. The positive_label is not None
Expand Down Expand Up @@ -184,12 +158,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None:
"Detected multiple edge types in provided edge_index, but no edge types specified for provided positive label."
)
positive_label_edge_type = message_passing_to_positive_label(main_edge_type)
labeled_edge_type, edge_index = _get_label_edges(
labeled_edge_index=self.positive_label,
edge_dir=edge_dir,
labeled_edge_type=positive_label_edge_type,
)
edge_index_with_labels[labeled_edge_type] = edge_index
edge_index_with_labels[positive_label_edge_type] = self.positive_label
logger.info(
f"Treating homogeneous positive labels as edge type {positive_label_edge_type}."
)
Expand All @@ -202,12 +171,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None:
positive_label_edge_type = message_passing_to_positive_label(
positive_label_type
)
labeled_edge_type, edge_index = _get_label_edges(
labeled_edge_index=positive_label_tensor,
edge_dir=edge_dir,
labeled_edge_type=positive_label_edge_type,
)
edge_index_with_labels[labeled_edge_type] = edge_index
edge_index_with_labels[positive_label_edge_type] = positive_label_tensor
logger.info(
f"Treating heterogeneous positive labels {positive_label_type} as edge type {positive_label_edge_type}."
)
Expand All @@ -218,12 +182,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None:
"Detected multiple edge types in provided edge_index, but no edge types specified for provided negative label."
)
negative_label_edge_type = message_passing_to_negative_label(main_edge_type)
labeled_edge_type, edge_index = _get_label_edges(
labeled_edge_index=self.negative_label,
edge_dir=edge_dir,
labeled_edge_type=negative_label_edge_type,
)
edge_index_with_labels[labeled_edge_type] = edge_index
edge_index_with_labels[negative_label_edge_type] = self.negative_label
logger.info(
f"Treating homogeneous negative labels as edge type {negative_label_edge_type}."
)
Expand All @@ -235,12 +194,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None:
negative_label_edge_type = message_passing_to_negative_label(
negative_label_type
)
labeled_edge_type, edge_index = _get_label_edges(
labeled_edge_index=negative_label_tensor,
edge_dir=edge_dir,
labeled_edge_type=negative_label_edge_type,
)
edge_index_with_labels[labeled_edge_type] = edge_index
edge_index_with_labels[negative_label_edge_type] = negative_label_tensor
logger.info(
f"Treating heterogeneous negative labels {negative_label_type} as edge type {negative_label_edge_type}."
)
Expand Down Expand Up @@ -490,17 +444,3 @@ def to_homogeneous(
n = next(iter(x.values()))
return n
return x


def reverse_edge_type(edge_type: _EdgeType) -> _EdgeType:
    """
    Swap the source and destination node types of an edge type, keeping the relation.

    Args:
        edge_type (EdgeType): The edge type whose source and destination node types should be swapped
    Returns:
        EdgeType: The reversed edge type. An `EdgeType` input yields an `EdgeType`;
            any other input yields a plain 3-tuple.
    """
    # Elements are read positionally: index 0 and index 2 trade places, index 1 stays.
    new_src, relation, new_dst = edge_type[2], edge_type[1], edge_type[0]
    if not isinstance(edge_type, EdgeType):
        return (new_src, relation, new_dst)
    return EdgeType(new_src, relation, new_dst)
11 changes: 1 addition & 10 deletions python/gigl/utils/data_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
DEFAULT_HOMOGENEOUS_NODE_TYPE,
message_passing_to_negative_label,
message_passing_to_positive_label,
reverse_edge_type,
)

logger = Logger()
Expand Down Expand Up @@ -242,15 +241,7 @@ def __init__(
message_passing_to_negative_label(supervision_edge_type)
for supervision_edge_type in supervision_edge_types
]
# If the edge direction is "in", we must reverse the labeled edge type, since separately provided labels are expected to be initially outgoing, and all edges
# in the graph must have the same edge direction.
if sampling_direction == "in":
self._labeled_edge_types = [
reverse_edge_type(labeled_edge_type)
for labeled_edge_type in labeled_edge_types
]
else:
self._labeled_edge_types = labeled_edge_types
self._labeled_edge_types = labeled_edge_types
else:
self._labeled_edge_types = supervision_edge_types

Expand Down
73 changes: 54 additions & 19 deletions python/tests/unit/distributed/distributed_neighborloader_test.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, the DBLP dataset contains both (paper, to, author) and (author, to, paper), right? [1]

I have a concern about the case where the graph contains only (paper, to, author) and the supervision edge type is (author, to, paper): this previously worked because we reversed the edge type, but it would now fail.

Does this seem correct to you? If so, could we add a test for a supervision edge type that either doesn't have a corresponding message passing edge type, or whose message passing edge type isn't reciprocal?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can aim to increase the test coverage here, thanks!

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from collections.abc import Mapping
from typing import Optional, Union
from typing import Literal, Optional, Union

import torch
import torch.multiprocessing as mp
Expand Down Expand Up @@ -334,8 +334,13 @@ def _run_dblp_supervised(
len(supervision_edge_types) == 1
), "TODO (mkolodner-sc): Support multiple supervision edge types in dataloading"
supervision_edge_type = supervision_edge_types[0]
anchor_node_type = supervision_edge_type.src_node_type
supervision_node_type = supervision_edge_type.dst_node_type
sampling_edge_direction = dataset.edge_dir
if sampling_edge_direction == "in":
anchor_node_type = supervision_edge_type.dst_node_type
supervision_node_type = supervision_edge_type.src_node_type
else:
anchor_node_type = supervision_edge_type.src_node_type
supervision_node_type = supervision_edge_type.dst_node_type
assert isinstance(dataset.train_node_ids, dict)
assert isinstance(dataset.graph, dict)
fanout = [2, 2]
Expand Down Expand Up @@ -374,20 +379,28 @@ def _run_toy_heterogeneous_ablp(
supervision_edge_types: list[EdgeType],
fanout: Union[list[int], dict[EdgeType, list[int]]],
):
anchor_node_type = NodeType("user")
supervision_node_type = NodeType("story")
assert (
len(supervision_edge_types) == 1
), "TODO (mkolodner-sc): Support multiple supervision edge types in dataloading"
supervision_edge_type = supervision_edge_types[0]
labeled_edge_type = message_passing_to_positive_label(supervision_edge_type)
sampling_edge_direction = dataset.edge_dir

assert isinstance(dataset.train_node_ids, dict)
assert isinstance(dataset.graph, dict)
labeled_edge_type = EdgeType(
supervision_node_type, Relation("to_gigl_positive"), anchor_node_type
)
all_positive_supervision_nodes, all_anchor_nodes, _, _ = dataset.graph[
labeled_edge_type
].topo.to_coo()
if sampling_edge_direction == "in":
anchor_node_type = supervision_edge_type.dst_node_type
supervision_node_type = supervision_edge_type.src_node_type
all_positive_supervision_nodes, all_anchor_nodes, _, _ = dataset.graph[
labeled_edge_type
].topo.to_coo()
else:
anchor_node_type = supervision_edge_type.src_node_type
supervision_node_type = supervision_edge_type.dst_node_type
all_anchor_nodes, all_positive_supervision_nodes, _, _ = dataset.graph[
labeled_edge_type
].topo.to_coo()

loader = DistABLPLoader(
dataset=dataset,
num_neighbors=fanout,
Expand Down Expand Up @@ -818,9 +831,21 @@ def test_multiple_neighbor_loader(self):
args=(dataset, self._context, expected_data_count),
)

@parameterized.expand(
[
param(
"Inward edge direction",
sampling_edge_direction="in",
),
param(
"Outward edge direction",
sampling_edge_direction="out",
),
]
)
# TODO: (mkolodner-sc) - Figure out why this test is failing on Google Cloud Build
@unittest.skip("Failing on Google Cloud Build - skiping for now")
def test_dblp_supervised(self):
def test_dblp_supervised(self, _, sampling_edge_direction: Literal["in", "out"]):
dblp_supervised_info = get_mocked_dataset_artifact_metadata()[
DBLP_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO.name
]
Expand All @@ -842,15 +867,15 @@ def test_dblp_supervised(self):
)

splitter = HashedNodeAnchorLinkSplitter(
sampling_direction="in",
sampling_direction=sampling_edge_direction,
supervision_edge_types=supervision_edge_types,
should_convert_labels_to_edges=True,
)

dataset = build_dataset(
serialized_graph_metadata=serialized_graph_metadata,
distributed_context=self._context,
sample_edge_direction="in",
sample_edge_direction=sampling_edge_direction,
_ssl_positive_label_percentage=0.1,
splitter=splitter,
)
Expand All @@ -863,17 +888,19 @@ def test_dblp_supervised(self):
@parameterized.expand(
[
param(
"Tensor-based partitioning, list fanout",
"Tensor-based partitioning, list fanout, inward edge direction",
partitioner_class=DistPartitioner,
fanout=[2, 2],
sampling_edge_direction="in",
),
param(
"Range-based partitioning, list fanout",
"Range-based partitioning, list fanout, inward edge direction",
partitioner_class=DistRangePartitioner,
fanout=[2, 2],
sampling_edge_direction="in",
),
param(
"Range-based partitioning, dict fanout",
"Range-based partitioning, dict fanout, inward edge direction",
partitioner_class=DistRangePartitioner,
fanout={
EdgeType(NodeType("user"), Relation("to"), NodeType("story")): [
Expand All @@ -885,6 +912,13 @@ def test_dblp_supervised(self):
2,
],
},
sampling_edge_direction="in",
),
param(
"Range-based partitioning, list fanout, outward edge direction",
partitioner_class=DistRangePartitioner,
fanout=[2, 2],
sampling_edge_direction="out",
),
]
)
Expand All @@ -893,6 +927,7 @@ def test_toy_heterogeneous_ablp(
_,
partitioner_class: type[DistPartitioner],
fanout: Union[list[int], dict[EdgeType, list[int]]],
sampling_edge_direction: Literal["in", "out"],
):
toy_heterogeneous_supervised_info = get_mocked_dataset_artifact_metadata()[
HETEROGENEOUS_TOY_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO.name
Expand All @@ -915,15 +950,15 @@ def test_toy_heterogeneous_ablp(
)

splitter = HashedNodeAnchorLinkSplitter(
sampling_direction="in",
sampling_direction=sampling_edge_direction,
supervision_edge_types=supervision_edge_types,
should_convert_labels_to_edges=True,
)

dataset = build_dataset(
serialized_graph_metadata=serialized_graph_metadata,
distributed_context=self._context,
sample_edge_direction="in",
sample_edge_direction=sampling_edge_direction,
_ssl_positive_label_percentage=0.1,
splitter=splitter,
partitioner_class=partitioner_class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def test_partitioning_correctness(
rank=rank,
is_heterogeneous=is_heterogeneous,
output_node_partition_book=partition_output.node_partition_book,
should_assign_edges_by_src_node=True,
should_assign_edges_by_src_node=should_assign_edges_by_src_node,
output_labeled_edge_index=partition_output.partitioned_positive_labels,
expected_edge_types=MOCKED_HETEROGENEOUS_EDGE_TYPES,
expected_pb_dtype=expected_pb_dtype,
Expand All @@ -811,7 +811,7 @@ def test_partitioning_correctness(
rank=rank,
is_heterogeneous=is_heterogeneous,
output_node_partition_book=partition_output.node_partition_book,
should_assign_edges_by_src_node=True,
should_assign_edges_by_src_node=should_assign_edges_by_src_node,
output_labeled_edge_index=partition_output.partitioned_negative_labels,
expected_edge_types=MOCKED_HETEROGENEOUS_EDGE_TYPES,
expected_pb_dtype=expected_pb_dtype,
Expand Down