From 0f06a240678b725edbf776544f718000a76e1cf6 Mon Sep 17 00:00:00 2001
From: nshah
Date: Sat, 27 Sep 2025 01:38:09 +0000
Subject: [PATCH 1/3] add evaluation utils

---
 .../lib/evaluation.py | 208 ++++++++++++++++++
 1 file changed, 208 insertions(+)
 create mode 100644 python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py

diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py b/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py
new file mode 100644
index 000000000..cdecc2ab6
--- /dev/null
+++ b/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py
@@ -0,0 +1,208 @@
+from collections import defaultdict
+from typing import Dict, Iterator, List, Tuple
+
+import torch
+import torch.distributed as dist
+from torchrec.distributed import TrainPipelineSparseDist
+
+from gigl.common.logger import Logger
+from gigl.experimental.knowledge_graph_embedding.lib.config import EvaluationPhaseConfig
+from gigl.experimental.knowledge_graph_embedding.lib.model.loss_utils import (
+    average_pos_neg_scores,
+    hit_rate_at_k,
+    mean_reciprocal_rank,
+)
+from gigl.experimental.knowledge_graph_embedding.lib.model.types import ModelPhase
+from gigl.src.common.types.graph_data import CondensedEdgeType
+from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
+
+logger = Logger()
+
+# Type aliases for better readability
+EdgeTypeMetrics = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+EvaluationResult = Tuple[torch.Tensor, Dict[CondensedEdgeType, EdgeTypeMetrics]]
+MetricAccumulators = Tuple[List[float], Dict, Dict, Dict, Dict]
+
+
+def _compute_mean(values: List[float]) -> float:
+    """Compute mean of values, returning 0.0 for empty lists."""
+    return sum(values) / len(values) if values else 0.0
+
+
+def _accumulate_metrics_for_edge_type(
+    condensed_edge_type: CondensedEdgeType,
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    condensed_edge_types: torch.Tensor,
+    evaluation_config: EvaluationPhaseConfig,
+    accumulators: MetricAccumulators,
+) -> None:
+    """Accumulate evaluation metrics for a specific edge type."""
+    losses, pos_logits, neg_logits, mrrs, hit_rates = accumulators
+
+    mask = condensed_edge_types == condensed_edge_type
+    if not mask.any():
+        return
+
+    # Compute metrics for this edge type
+    avg_pos_score, avg_neg_score = average_pos_neg_scores(logits[mask], labels[mask])
+    mrr = mean_reciprocal_rank(scores=logits[mask], labels=labels[mask])
+    hr_at_k = hit_rate_at_k(
+        scores=logits[mask],
+        labels=labels[mask],
+        ks=evaluation_config.hit_rates_at_k,
+    )
+
+    # Store metrics for this edge type
+    pos_logits[condensed_edge_type].append(avg_pos_score.item())
+    neg_logits[condensed_edge_type].append(avg_neg_score.item())
+    mrrs[condensed_edge_type].append(mrr.item())
+    hit_rates[condensed_edge_type].append(hr_at_k)
+
+
+def _aggregate_metrics(
+    accumulators: MetricAccumulators,
+    unique_edge_types: List[CondensedEdgeType],
+    device: torch.device,
+) -> List[torch.Tensor]:
+    """Aggregate accumulated metrics into tensors for distributed reduction."""
+    losses, pos_logits, neg_logits, mrrs, hit_rates = accumulators
+
+    # Calculate per-edge-type averages
+    avg_loss = _compute_mean(losses)
+    pos_logits_by_cet = {
+        cet: _compute_mean(pos_logits[cet]) for cet in unique_edge_types
+    }
+    neg_logits_by_cet = {
+        cet: _compute_mean(neg_logits[cet]) for cet in unique_edge_types
+    }
+    mrrs_by_cet = {cet: _compute_mean(mrrs[cet]) for cet in unique_edge_types}
+    hit_rates_by_cet = {
+        cet: torch.stack(hit_rates[cet]).mean(dim=0) for cet in unique_edge_types
+    }
+
+    # Convert to tensors and move to device for distributed reduction
+    metrics = [
+        torch.tensor(avg_loss, device=device),  # overall loss
+        torch.tensor(
+            [pos_logits_by_cet[cet] for cet in unique_edge_types], device=device
+        ),  # positive logits by edge type
+        torch.tensor(
+            [neg_logits_by_cet[cet] for cet in unique_edge_types], device=device
+        ),  # negative logits by edge type
+        torch.tensor(
+            [mrrs_by_cet[cet] for cet in unique_edge_types], device=device
+        ),  # mean reciprocal ranks by edge type
+        torch.stack([hit_rates_by_cet[cet] for cet in unique_edge_types], dim=0).to(
+            device
+        ),  # hit rates by edge type
+    ]
+
+    return metrics
+
+
+def _format_output_metrics(
+    metrics: List[torch.Tensor], unique_edge_types: List[CondensedEdgeType]
+) -> EvaluationResult:
+    """Format metrics into the expected return structure."""
+    metrics_by_edge_type = {}
+    for i, edge_type in enumerate(unique_edge_types):
+        metrics_by_edge_type[edge_type] = (
+            metrics[1][i],  # avg_pos_score
+            metrics[2][i],  # avg_neg_score
+            metrics[3][i],  # avg_mrr
+            metrics[4][i],  # avg_hit_rate
+        )
+
+    return metrics[0], metrics_by_edge_type
+
+
+def evaluate(
+    pipeline: TrainPipelineSparseDist,
+    val_iter: Iterator,
+    phase: ModelPhase,
+    evaluation_phase_config: EvaluationPhaseConfig,
+    graph_metadata: GraphMetadataPbWrapper,
+) -> EvaluationResult:
+    """Evaluate a knowledge graph embedding model on validation data.
+
+    This function runs the model in evaluation mode, processes validation batches,
+    computes various metrics (loss, MRR, hit rates) per edge type, and aggregates
+    results across distributed workers.
+
+    Args:
+        pipeline: Distributed training pipeline containing the model
+        val_iter: Iterator over validation data batches
+        phase: Model phase to set during evaluation (e.g., VALIDATION, TEST)
+        evaluation_phase_config: Configuration specifying evaluation parameters
+            like hit_rates_at_k values
+        graph_metadata: Metadata containing information about condensed edge types
+            in the graph
+
+    Returns:
+        A tuple containing:
+        - overall_loss: Average loss across all batches and edge types
+        - metrics_by_edge_type: Dictionary mapping each CondensedEdgeType to
+          a tuple of (avg_pos_score, avg_neg_score, avg_mrr, avg_hit_rates)
+          where avg_hit_rates is a tensor with hit rates at different k values
+
+    Note:
+        This function temporarily switches the model to evaluation mode and the
+        specified phase, then restores the original state. Results are averaged
+        across distributed workers using all_reduce operations.
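+
+    Example:
+        Illustrative sketch only; it assumes the caller has already built the
+        pipeline, config, and metadata objects and initialized
+        ``torch.distributed``::
+
+            avg_loss, metrics_by_edge_type = evaluate(
+                pipeline=pipeline,
+                val_iter=iter(val_dataloader),
+                phase=ModelPhase.VALIDATION,
+                evaluation_phase_config=evaluation_phase_config,
+                graph_metadata=graph_metadata,
+            )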
+ """ + # Set model to evaluation mode and save original state + pipeline._model.eval() + original_phase = pipeline._model.module.phase + pipeline._model.module.set_phase(phase) + device = pipeline._device + + # Initialize metric accumulators + accumulators = ( + [], # losses + defaultdict(list), # pos_logits by edge type + defaultdict(list), # neg_logits by edge type + defaultdict(list), # mrrs by edge type + defaultdict(list), # hit_rates by edge type + ) + losses, _, _, _, _ = accumulators + + unique_edge_types = sorted(graph_metadata.condensed_edge_types) + step_count = 0 + + # Process validation batches + while True: + try: + batch_loss, logits, labels, edge_types = pipeline.progress(val_iter) + losses.append(batch_loss.item()) + + # Accumulate metrics for each edge type in this batch + for edge_type in unique_edge_types: + _accumulate_metrics_for_edge_type( + edge_type, + logits, + labels, + edge_types, + evaluation_phase_config, + accumulators, + ) + + step_count += 1 + except StopIteration: + break + + logger.info(f"Completed {phase} evaluation over {step_count} steps.") + + # Aggregate metrics and prepare for distributed reduction + aggregated_metrics = _aggregate_metrics(accumulators, unique_edge_types, device) + + # Perform distributed reduction to average across all workers + for metric in aggregated_metrics: + dist.all_reduce(metric, op=dist.ReduceOp.AVG) + + # Format results and restore original model state + result = _format_output_metrics(aggregated_metrics, unique_edge_types) + pipeline._model.module.set_phase(original_phase) + pipeline._model.train() + + return result From faf4c618cb17e1dc98049412a3738c2991e2a2f3 Mon Sep 17 00:00:00 2001 From: nshah Date: Thu, 2 Oct 2025 02:03:13 +0000 Subject: [PATCH 2/3] comments --- .../lib/evaluation.py | 331 +++++++++++------- 1 file changed, 207 insertions(+), 124 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py b/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py index cdecc2ab6..58685849a 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py @@ -1,4 +1,4 @@ -from collections import defaultdict +from dataclasses import dataclass from typing import Dict, Iterator, List, Tuple import torch @@ -18,103 +18,194 @@ logger = Logger() + +@dataclass(frozen=True) +class EdgeTypeMetrics: + """Container for evaluation metrics for a specific edge type. 
+
+    Attributes:
+        avg_pos_score: Average positive score for this edge type
+        avg_neg_score: Average negative score for this edge type
+        avg_mrr: Average Mean Reciprocal Rank for this edge type
+        avg_hit_rates: Average hit rates at different k values for this edge type
+    """
+
+    avg_pos_score: torch.Tensor
+    avg_neg_score: torch.Tensor
+    avg_mrr: torch.Tensor
+    avg_hit_rates: torch.Tensor
+
+
 # Type aliases for better readability
-EdgeTypeMetrics = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
 EvaluationResult = Tuple[torch.Tensor, Dict[CondensedEdgeType, EdgeTypeMetrics]]
-MetricAccumulators = Tuple[List[float], Dict, Dict, Dict, Dict]
-
-
-def _compute_mean(values: List[float]) -> float:
-    """Compute mean of values, returning 0.0 for empty lists."""
-    return sum(values) / len(values) if values else 0.0
-
-
-def _accumulate_metrics_for_edge_type(
-    condensed_edge_type: CondensedEdgeType,
-    logits: torch.Tensor,
-    labels: torch.Tensor,
-    condensed_edge_types: torch.Tensor,
-    evaluation_config: EvaluationPhaseConfig,
-    accumulators: MetricAccumulators,
-) -> None:
-    """Accumulate evaluation metrics for a specific edge type."""
-    losses, pos_logits, neg_logits, mrrs, hit_rates = accumulators
-
-    mask = condensed_edge_types == condensed_edge_type
-    if not mask.any():
-        return
-
-    # Compute metrics for this edge type
-    avg_pos_score, avg_neg_score = average_pos_neg_scores(logits[mask], labels[mask])
-    mrr = mean_reciprocal_rank(scores=logits[mask], labels=labels[mask])
-    hr_at_k = hit_rate_at_k(
-        scores=logits[mask],
-        labels=labels[mask],
-        ks=evaluation_config.hit_rates_at_k,
-    )
-
-    # Store metrics for this edge type
-    pos_logits[condensed_edge_type].append(avg_pos_score.item())
-    neg_logits[condensed_edge_type].append(avg_neg_score.item())
-    mrrs[condensed_edge_type].append(mrr.item())
-    hit_rates[condensed_edge_type].append(hr_at_k)
-
-
-def _aggregate_metrics(
-    accumulators: MetricAccumulators,
-    unique_edge_types: List[CondensedEdgeType],
-    device: torch.device,
-) -> List[torch.Tensor]:
-    """Aggregate accumulated metrics into tensors for distributed reduction."""
-    losses, pos_logits, neg_logits, mrrs, hit_rates = accumulators
-
-    # Calculate per-edge-type averages
-    avg_loss = _compute_mean(losses)
-    pos_logits_by_cet = {
-        cet: _compute_mean(pos_logits[cet]) for cet in unique_edge_types
-    }
-    neg_logits_by_cet = {
-        cet: _compute_mean(neg_logits[cet]) for cet in unique_edge_types
-    }
-    mrrs_by_cet = {cet: _compute_mean(mrrs[cet]) for cet in unique_edge_types}
-    hit_rates_by_cet = {
-        cet: torch.stack(hit_rates[cet]).mean(dim=0) for cet in unique_edge_types
-    }
-
-    # Convert to tensors and move to device for distributed reduction
-    metrics = [
-        torch.tensor(avg_loss, device=device),  # overall loss
-        torch.tensor(
-            [pos_logits_by_cet[cet] for cet in unique_edge_types], device=device
-        ),  # positive logits by edge type
-        torch.tensor(
-            [neg_logits_by_cet[cet] for cet in unique_edge_types], device=device
-        ),  # negative logits by edge type
-        torch.tensor(
-            [mrrs_by_cet[cet] for cet in unique_edge_types], device=device
-        ),  # mean reciprocal ranks by edge type
-        torch.stack([hit_rates_by_cet[cet] for cet in unique_edge_types], dim=0).to(
-            device
-        ),  # hit rates by edge type
-    ]
-
-    return metrics
-
-
-def _format_output_metrics(
-    metrics: List[torch.Tensor], unique_edge_types: List[CondensedEdgeType]
-) -> EvaluationResult:
-    """Format metrics into the expected return structure."""
-    metrics_by_edge_type = {}
-    for i, edge_type in enumerate(unique_edge_types):
-        metrics_by_edge_type[edge_type] = (
-            metrics[1][i],  # avg_pos_score
-            metrics[2][i],  # avg_neg_score
-            metrics[3][i],  # avg_mrr
-            metrics[4][i],  # avg_hit_rate
-        )
-
-    return metrics[0], metrics_by_edge_type
+
+class EvaluationMetricsAccumulator:
+    """Maintains tensors of evaluation metrics for all edge types.
+
+    This class uses tensors throughout to efficiently accumulate and reduce metrics
+    across batches and distributed workers. Each edge type corresponds to an index
+    in the metric tensors.
+
+    Attributes:
+        total_loss: Scalar tensor tracking total loss across all batches
+        total_batches: Scalar tensor tracking total number of batches
+        sample_counts: Tensor of sample counts per edge type [num_edge_types]
+        pos_scores: Tensor of accumulated positive scores per edge type [num_edge_types]
+        neg_scores: Tensor of accumulated negative scores per edge type [num_edge_types]
+        mrrs: Tensor of accumulated MRRs per edge type [num_edge_types]
+        hit_rates: Tensor of accumulated hit rates per edge type [num_edge_types, num_k_values]
+        edge_type_to_idx: Mapping from CondensedEdgeType to tensor index
+        evaluation_config: Configuration containing evaluation parameters
+    """
+
+    def __init__(
+        self,
+        unique_edge_types: List[CondensedEdgeType],
+        evaluation_config: EvaluationPhaseConfig,
+        device: torch.device,
+    ):
+        """Initialize the accumulator with zero tensors.
+
+        Args:
+            unique_edge_types: Sorted list of unique edge types in the graph
+            evaluation_config: Configuration containing hit rate k values and other evaluation parameters
+            device: Device to place tensors on
+        """
+        self.evaluation_config = evaluation_config
+        self.edge_type_to_idx = {et: i for i, et in enumerate(unique_edge_types)}
+
+        num_edge_types = len(unique_edge_types)
+        num_k_values = len(evaluation_config.hit_rates_at_k)
+
+        self.total_loss = torch.tensor(0.0, device=device)
+        self.total_batches = torch.tensor(0, dtype=torch.long, device=device)
+        self.sample_counts = torch.zeros(
+            num_edge_types, dtype=torch.long, device=device
+        )
+        self.pos_scores = torch.zeros(num_edge_types, device=device)
+        self.neg_scores = torch.zeros(num_edge_types, device=device)
+        self.mrrs = torch.zeros(num_edge_types, device=device)
+        self.hit_rates = torch.zeros(num_edge_types, num_k_values, device=device)
+
+    def accumulate(
+        self,
+        batch_loss: torch.Tensor,
+        logits: torch.Tensor,
+        labels: torch.Tensor,
+        condensed_edge_types: torch.Tensor,
+    ) -> None:
+        """Accumulate metrics from a batch for all edge types.
+
+        Args:
+            batch_loss: Loss value for this batch
+            logits: Model logits for this batch
+            labels: Ground truth labels for this batch
+            condensed_edge_types: Edge type indices for each sample in the batch
+        """
+        # Accumulate batch-level metrics
+        self.total_loss += batch_loss
+        self.total_batches += 1
+
+        # Process each edge type
+        for edge_type, idx in self.edge_type_to_idx.items():
+            edge_type_mask = condensed_edge_types == edge_type
+            if not edge_type_mask.any():
+                continue
+
+            # Compute metrics for this edge type
+            avg_pos_score, avg_neg_score = average_pos_neg_scores(
+                logits[edge_type_mask], labels[edge_type_mask]
+            )
+            mrr = mean_reciprocal_rank(
+                scores=logits[edge_type_mask], labels=labels[edge_type_mask]
+            )
+            hr_at_k = hit_rate_at_k(
+                scores=logits[edge_type_mask],
+                labels=labels[edge_type_mask],
+                ks=self.evaluation_config.hit_rates_at_k,
+            )
+
+            # Accumulate weighted totals for this edge type
+            edge_type_sample_count_in_batch = edge_type_mask.sum()
+            self.sample_counts[idx] += edge_type_sample_count_in_batch
+            self.pos_scores[idx] += avg_pos_score * edge_type_sample_count_in_batch
+            self.neg_scores[idx] += avg_neg_score * edge_type_sample_count_in_batch
+            self.mrrs[idx] += mrr * edge_type_sample_count_in_batch
+            self.hit_rates[idx] += hr_at_k * edge_type_sample_count_in_batch
+
+    def reduce_all(self) -> None:
+        """Perform distributed reduction (sum) on all metric tensors."""
+        dist.all_reduce(self.total_loss, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self.total_batches, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self.sample_counts, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self.pos_scores, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self.neg_scores, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self.mrrs, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self.hit_rates, op=dist.ReduceOp.SUM)
+
+    def compute_final_metrics(
+        self,
+        unique_edge_types: List[CondensedEdgeType],
+    ) -> EvaluationResult:
+        """Compute final averaged metrics and format them into the return structure.
+
+        Args:
+            unique_edge_types: Sorted list of unique edge types in the graph
+
+        Returns:
+            Tuple of (average_loss, metrics_by_edge_type)
+        """
+        # Compute average loss
+        avg_loss = self.total_loss / self.total_batches
+
+        # Create mask for valid edge types (those with samples)
+        mask = self.sample_counts > 0
+
+        # Initialize metric tensors with NaN for missing edge types
+        avg_pos_scores = torch.full_like(self.pos_scores, float("nan"))
+        avg_neg_scores = torch.full_like(self.neg_scores, float("nan"))
+        avg_mrrs = torch.full_like(self.mrrs, float("nan"))
+        avg_hit_rates = torch.full_like(self.hit_rates, float("nan"))
+
+        # Compute averages for edge types with samples
+        if mask.any():
+            avg_pos_scores[mask] = (
+                self.pos_scores[mask] / self.sample_counts[mask].float()
+            )
+            avg_neg_scores[mask] = (
+                self.neg_scores[mask] / self.sample_counts[mask].float()
+            )
+            avg_mrrs[mask] = self.mrrs[mask] / self.sample_counts[mask].float()
+            # Broadcast sample counts for hit rates division
+            avg_hit_rates[mask] = self.hit_rates[mask] / self.sample_counts[
+                mask
+            ].float().unsqueeze(-1)
+
+        # Format into expected return structure
+        metrics_by_edge_type = {}
+        for edge_type in unique_edge_types:
+            idx = self.edge_type_to_idx[edge_type]
+            metrics_by_edge_type[edge_type] = EdgeTypeMetrics(
+                avg_pos_score=avg_pos_scores[idx],
+                avg_neg_score=avg_neg_scores[idx],
+                avg_mrr=avg_mrrs[idx],
+                avg_hit_rates=avg_hit_rates[idx],
+            )
+
+        # Log edge types with undefined metrics
+        missing_edge_types = [
+            et
+            for et in unique_edge_types
+            if self.sample_counts[self.edge_type_to_idx[et]] == 0
+        ]
+        if missing_edge_types:
+            logger.warning(
+                f"Edge types {missing_edge_types} have no samples across all ranks. "
+                f"Setting metrics to NaN (undefined)."
+            )
+
+        return avg_loss, metrics_by_edge_type
 
 
 def evaluate(
@@ -143,8 +234,8 @@ def evaluate(
         A tuple containing:
         - overall_loss: Average loss across all batches and edge types
         - metrics_by_edge_type: Dictionary mapping each CondensedEdgeType to
-          a tuple of (avg_pos_score, avg_neg_score, avg_mrr, avg_hit_rates)
-          where avg_hit_rates is a tensor with hit rates at different k values
+          an EdgeTypeMetrics object containing avg_pos_score, avg_neg_score,
+          avg_mrr, and avg_hit_rates (tensor with hit rates at different k values)
 
     Note:
         This function temporarily switches the model to evaluation mode and the
@@ -157,35 +248,29 @@ def evaluate(
     pipeline._model.module.set_phase(phase)
     device = pipeline._device
 
-    # Initialize metric accumulators
-    accumulators = (
-        [],  # losses
-        defaultdict(list),  # pos_logits by edge type
-        defaultdict(list),  # neg_logits by edge type
-        defaultdict(list),  # mrrs by edge type
-        defaultdict(list),  # hit_rates by edge type
+    # Initialize tensor-based metric accumulator
+    unique_edge_types = sorted(graph_metadata.condensed_edge_types)
+
+    accumulator = EvaluationMetricsAccumulator(
+        unique_edge_types=unique_edge_types,
+        evaluation_config=evaluation_phase_config,
+        device=device,
    )
-    losses, _, _, _, _ = accumulators
 
-    unique_edge_types = sorted(graph_metadata.condensed_edge_types)
     step_count = 0
 
     # Process validation batches
     while True:
         try:
             batch_loss, logits, labels, edge_types = pipeline.progress(val_iter)
-            losses.append(batch_loss.item())
-
-            # Accumulate metrics for each edge type in this batch
-            for edge_type in unique_edge_types:
-                _accumulate_metrics_for_edge_type(
-                    edge_type,
-                    logits,
-                    labels,
-                    edge_types,
-                    evaluation_phase_config,
-                    accumulators,
-                )
+
+            # Accumulate metrics for all edge types in this batch
+            accumulator.accumulate(
+                batch_loss=batch_loss,
+                logits=logits,
+                labels=labels,
+                condensed_edge_types=edge_types,
+            )
 
             step_count += 1
         except StopIteration:
@@ -193,15 +278,13 @@ def evaluate(
 
     logger.info(f"Completed {phase} evaluation over {step_count} steps.")
 
-    # Aggregate metrics and prepare for distributed reduction
-    aggregated_metrics = _aggregate_metrics(accumulators, unique_edge_types, device)
+    # Perform distributed reduction on all metric tensors
+    accumulator.reduce_all()
 
-    # Perform distributed reduction to average across all workers
-    for metric in aggregated_metrics:
-        dist.all_reduce(metric, op=dist.ReduceOp.AVG)
+    # Compute final averaged metrics and format results
+    result = accumulator.compute_final_metrics(unique_edge_types)
 
-    # Format results and restore original model state
-    result = _format_output_metrics(aggregated_metrics, unique_edge_types)
+    # Restore original model state
     pipeline._model.module.set_phase(original_phase)
     pipeline._model.train()

From 23edaca0cde5a11eb4d3ffddfad073ee9b8f9dd0 Mon Sep 17 00:00:00 2001
From: nshah
Date: Thu, 2 Oct 2025 02:10:15 +0000
Subject: [PATCH 3/3] fmt

---
 .../lib/evaluation.py | 84 +++++++++----------
 1 file changed, 42 insertions(+), 42 deletions(-)

diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py b/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py
index 58685849a..b582dec1d 100644
--- a/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py
+++ b/python/gigl/experimental/knowledge_graph_embedding/lib/evaluation.py
@@ -72,23 +72,23 @@ def __init__(
             evaluation_config: Configuration containing hit rate k values and other evaluation parameters
             device: Device to place tensors on
         """
-        self.evaluation_config = evaluation_config
-        self.edge_type_to_idx = {et: i for i, et in enumerate(unique_edge_types)}
+        self._evaluation_config = evaluation_config
+        self._edge_type_to_idx = {et: i for i, et in enumerate(unique_edge_types)}
 
         num_edge_types = len(unique_edge_types)
         num_k_values = len(evaluation_config.hit_rates_at_k)
 
-        self.total_loss = torch.tensor(0.0, device=device)
-        self.total_batches = torch.tensor(0, dtype=torch.long, device=device)
-        self.sample_counts = torch.zeros(
+        self._total_loss = torch.tensor(0.0, device=device)
+        self._total_batches = torch.tensor(0, dtype=torch.long, device=device)
+        self._sample_counts = torch.zeros(
             num_edge_types, dtype=torch.long, device=device
         )
-        self.pos_scores = torch.zeros(num_edge_types, device=device)
-        self.neg_scores = torch.zeros(num_edge_types, device=device)
-        self.mrrs = torch.zeros(num_edge_types, device=device)
-        self.hit_rates = torch.zeros(num_edge_types, num_k_values, device=device)
+        self._pos_scores = torch.zeros(num_edge_types, device=device)
+        self._neg_scores = torch.zeros(num_edge_types, device=device)
+        self._mrrs = torch.zeros(num_edge_types, device=device)
+        self._hit_rates = torch.zeros(num_edge_types, num_k_values, device=device)
 
-    def accumulate(
+    def accumulate_batch(
         self,
         batch_loss: torch.Tensor,
         logits: torch.Tensor,
@@ -104,11 +104,11 @@ def accumulate(
             condensed_edge_types: Edge type indices for each sample in the batch
         """
         # Accumulate batch-level metrics
-        self.total_loss += batch_loss
-        self.total_batches += 1
+        self._total_loss += batch_loss
+        self._total_batches += 1
 
         # Process each edge type
-        for edge_type, idx in self.edge_type_to_idx.items():
+        for edge_type, idx in self._edge_type_to_idx.items():
             edge_type_mask = condensed_edge_types == edge_type
             if not edge_type_mask.any():
                 continue
@@ -123,26 +123,26 @@ def accumulate(
             hr_at_k = hit_rate_at_k(
                 scores=logits[edge_type_mask],
                 labels=labels[edge_type_mask],
-                ks=self.evaluation_config.hit_rates_at_k,
+                ks=self._evaluation_config.hit_rates_at_k,
             )
 
             # Accumulate weighted totals for this edge type
             edge_type_sample_count_in_batch = edge_type_mask.sum()
-            self.sample_counts[idx] += edge_type_sample_count_in_batch
-            self.pos_scores[idx] += avg_pos_score * edge_type_sample_count_in_batch
-            self.neg_scores[idx] += avg_neg_score * edge_type_sample_count_in_batch
-            self.mrrs[idx] += mrr * edge_type_sample_count_in_batch
-            self.hit_rates[idx] += hr_at_k * edge_type_sample_count_in_batch
+            self._sample_counts[idx] += edge_type_sample_count_in_batch
+            self._pos_scores[idx] += avg_pos_score * edge_type_sample_count_in_batch
+            self._neg_scores[idx] += avg_neg_score * edge_type_sample_count_in_batch
+            self._mrrs[idx] += mrr * edge_type_sample_count_in_batch
+            self._hit_rates[idx] += hr_at_k * edge_type_sample_count_in_batch
 
-    def reduce_all(self) -> None:
+    def sum_metrics_over_ranks(self) -> None:
         """Perform distributed reduction (sum) on all metric tensors."""
-        dist.all_reduce(self.total_loss, op=dist.ReduceOp.SUM)
-        dist.all_reduce(self.total_batches, op=dist.ReduceOp.SUM)
-        dist.all_reduce(self.sample_counts, op=dist.ReduceOp.SUM)
-        dist.all_reduce(self.pos_scores, op=dist.ReduceOp.SUM)
-        dist.all_reduce(self.neg_scores, op=dist.ReduceOp.SUM)
-        dist.all_reduce(self.mrrs, op=dist.ReduceOp.SUM)
-        dist.all_reduce(self.hit_rates, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self._total_loss, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self._total_batches, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self._sample_counts, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self._pos_scores, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self._neg_scores, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self._mrrs, op=dist.ReduceOp.SUM)
+        dist.all_reduce(self._hit_rates, op=dist.ReduceOp.SUM)
 
     def compute_final_metrics(
         self,
@@ -157,35 +157,35 @@ def compute_final_metrics(
         Tuple of (average_loss, metrics_by_edge_type)
         """
         # Compute average loss
-        avg_loss = self.total_loss / self.total_batches
+        avg_loss = self._total_loss / self._total_batches
 
         # Create mask for valid edge types (those with samples)
-        mask = self.sample_counts > 0
+        mask = self._sample_counts > 0
 
         # Initialize metric tensors with NaN for missing edge types
-        avg_pos_scores = torch.full_like(self.pos_scores, float("nan"))
-        avg_neg_scores = torch.full_like(self.neg_scores, float("nan"))
-        avg_mrrs = torch.full_like(self.mrrs, float("nan"))
-        avg_hit_rates = torch.full_like(self.hit_rates, float("nan"))
+        avg_pos_scores = torch.full_like(self._pos_scores, float("nan"))
+        avg_neg_scores = torch.full_like(self._neg_scores, float("nan"))
+        avg_mrrs = torch.full_like(self._mrrs, float("nan"))
+        avg_hit_rates = torch.full_like(self._hit_rates, float("nan"))
 
         # Compute averages for edge types with samples
         if mask.any():
             avg_pos_scores[mask] = (
-                self.pos_scores[mask] / self.sample_counts[mask].float()
+                self._pos_scores[mask] / self._sample_counts[mask].float()
             )
             avg_neg_scores[mask] = (
-                self.neg_scores[mask] / self.sample_counts[mask].float()
+                self._neg_scores[mask] / self._sample_counts[mask].float()
             )
-            avg_mrrs[mask] = self.mrrs[mask] / self.sample_counts[mask].float()
+            avg_mrrs[mask] = self._mrrs[mask] / self._sample_counts[mask].float()
             # Broadcast sample counts for hit rates division
-            avg_hit_rates[mask] = self.hit_rates[mask] / self.sample_counts[
+            avg_hit_rates[mask] = self._hit_rates[mask] / self._sample_counts[
                 mask
             ].float().unsqueeze(-1)
 
         # Format into expected return structure
         metrics_by_edge_type = {}
         for edge_type in unique_edge_types:
-            idx = self.edge_type_to_idx[edge_type]
+            idx = self._edge_type_to_idx[edge_type]
             metrics_by_edge_type[edge_type] = EdgeTypeMetrics(
                 avg_pos_score=avg_pos_scores[idx],
                 avg_neg_score=avg_neg_scores[idx],
@@ -197,7 +197,7 @@ def compute_final_metrics(
         missing_edge_types = [
             et
             for et in unique_edge_types
-            if self.sample_counts[self.edge_type_to_idx[et]] == 0
+            if self._sample_counts[self._edge_type_to_idx[et]] == 0
         ]
         if missing_edge_types:
             logger.warning(
@@ -265,7 +265,7 @@ def evaluate(
             batch_loss, logits, labels, edge_types = pipeline.progress(val_iter)
 
             # Accumulate metrics for all edge types in this batch
-            accumulator.accumulate(
+            accumulator.accumulate_batch(
                 batch_loss=batch_loss,
                 logits=logits,
                 labels=labels,
@@ -279,10 +279,10 @@ def evaluate(
     logger.info(f"Completed {phase} evaluation over {step_count} steps.")
 
     # Perform distributed reduction on all metric tensors
-    accumulator.reduce_all()
+    accumulator.sum_metrics_over_ranks()
 
     # Compute final averaged metrics and format results
-    result = accumulator.compute_final_metrics(unique_edge_types)
+    result = accumulator.compute_final_metrics(unique_edge_types=unique_edge_types)
 
     # Restore original model state
    pipeline._model.module.set_phase(original_phase)
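
Usage sketch: the snippet below illustrates how `evaluate` from this series
might be wired into a distributed validation loop. It is a minimal sketch, not
part of the patches: `run_validation` and its arguments are hypothetical
stand-ins for whatever the surrounding trainer provides, and it assumes
`torch.distributed` is already initialized (e.g. via torchrun) and that
`ModelPhase` exposes a `VALIDATION` member, as the docstring suggests.

    import torch.distributed as dist

    from gigl.experimental.knowledge_graph_embedding.lib.evaluation import evaluate
    from gigl.experimental.knowledge_graph_embedding.lib.model.types import ModelPhase


    def run_validation(pipeline, val_dataloader, evaluation_phase_config, graph_metadata):
        # evaluate() drains the iterator until StopIteration and participates in
        # collective all_reduce calls, so every rank must call it with a fresh
        # iterator over its own shard of the validation data.
        avg_loss, metrics_by_edge_type = evaluate(
            pipeline=pipeline,
            val_iter=iter(val_dataloader),
            phase=ModelPhase.VALIDATION,
            evaluation_phase_config=evaluation_phase_config,
            graph_metadata=graph_metadata,
        )

        # After the internal reduction every rank holds identical values, so
        # logging from rank 0 alone is sufficient. Edge types with no samples
        # on any rank come back as NaN (see compute_final_metrics).
        if dist.get_rank() == 0:
            for edge_type, metrics in metrics_by_edge_type.items():
                print(
                    f"edge_type={edge_type}: "
                    f"mrr={metrics.avg_mrr.item():.4f}, "
                    f"hit_rates={metrics.avg_hit_rates.tolist()}"
                )
        return avg_loss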