diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses.md
new file mode 100644
index 00000000..33de654c
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses.md
@@ -0,0 +1,81 @@
+
+
+# Module: contrastive_losses
+
+[TOC]
+
+
+
+
for more details.
+
+
+
+
+
+Args |
+
+
+
+`fns`
+ |
+
+a mapping from a metric name to a `Callable` that accepts
+representations as well as the result of their SVD decomposition.
+Currently only singular values are passed.
+ |
+
+
+`name`
+ |
+
+Name for the metric class, used for Keras bookkeeping.
+ |
+
+
+
+## Methods
+
+merge_state
+
+
+merge_state(
+ metrics
+)
+
+
+Merges the state from one or more metrics.
+
+This method can be used by distributed systems to merge the state computed by
+different metric instances. Typically the state will be stored in the form of
+the metric's weights. For example, a tf.keras.metrics.Mean metric contains a
+list of two weight values: a total and a count. If there were two instances of a
+tf.keras.metrics.Accuracy that each independently aggregated partial state for
+an overall accuracy calculation, these two metric's states could be combined as
+follows:
+
+ >>> m1 = tf.keras.metrics.Accuracy()
+ >>> _ = m1.update_state([[1], [2]], [[0], [2]])
+
+ >>> m2 = tf.keras.metrics.Accuracy()
+ >>> _ = m2.update_state([[3], [4]], [[3], [4]])
+
+ >>> m2.merge_state([m1])
+ >>> m2.result().numpy()
+ 0.75
+
+
+
+
+
+Args |
+
+
+
+`metrics`
+ |
+
+an iterable of metrics. The metrics must have compatible
+state.
+ |
+
+
+
+
+
+
+
+Raises |
+
+
+
+`ValueError`
+ |
+
+If the provided iterable does not contain metrics matching
+the metric's required specifications.
+ |
+
+
+
+reset_state
+
+View
+source
+
+
+reset_state() -> None
+
+
+Resets all of the metric state variables.
+
+This function is called between epochs/steps, when a metric is evaluated during
+training.
+
+result
+
+View
+source
+
+
+result() -> Mapping[str, tf.Tensor]
+
+
+Computes and returns the scalar metric value tensor or a dict of scalars.
+
+Result computation is an idempotent operation that simply calculates the metric
+value using the state variables.
+
+
+
+
+
+Returns |
+
+
+A scalar tensor, or a dictionary of scalar tensors.
+ |
+
+
+
+
+update_state
+
+View
+source
+
+
+update_state(
+ _, y_pred: tf.Tensor, sample_weight=None
+) -> None
+
+
+Accumulates statistics for the metric.
+
+Note: This function is executed as a graph function in graph mode. This means:
+a) Operations on the same resource are executed in textual order. This should
+make it easier to do things like add the updated value of a variable to another,
+for example. b) You don't need to worry about collecting the update ops to
+execute. All update ops added to the graph by this function will be executed. As
+a result, code should generally work the same way with graph or eager execution.
+
+
+
+
+
+Args |
+
+
+
+`*args`
+ |
+
+
+ |
+
+
+`**kwargs`
+ |
+
+A mini-batch of inputs to the Metric.
+ |
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/BarlowTwinsTask.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/BarlowTwinsTask.md
new file mode 100644
index 00000000..0a746592
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/BarlowTwinsTask.md
@@ -0,0 +1,178 @@
+
+
+# contrastive_losses.BarlowTwinsTask
+
+[TOC]
+
+
+
+
+
+A Barlow Twins (BT) Task.
+
+Inherits From:
+[`ContrastiveLossTask`](../contrastive_losses/ContrastiveLossTask.md)
+
+
+contrastive_losses.BarlowTwinsTask(
+ *args,
+ lambda_: Optional[Union[tf.Tensor, float]] = None,
+ normalize_batch: bool = True,
+ **kwargs
+)
+
+
+
+
+
+
+
+
+Args |
+
+
+
+`node_set_name`
+ |
+
+Name of the node set for readout.
+ |
+
+
+`feature_name`
+ |
+
+Feature name for readout.
+ |
+
+
+`representations_layer_name`
+ |
+
+Layer name for uncorrupted representations.
+ |
+
+
+`corruptor`
+ |
+
+`Corruptor` instance for creating negative samples. If not
+specified, we use `ShuffleFeaturesGlobally` by default.
+ |
+
+
+`projector_units`
+ |
+
+`Sequence` of layer sizes for projector network.
+Projectors prevent dimensional collapse, but can hinder training for
+easy corruptions. For more details, see
+https://arxiv.org/abs/2304.12210.
+ |
+
+
+`seed`
+ |
+
+Random seed for the default corruptor (`ShuffleFeaturesGlobally`).
+ |
+
+
+
+## Methods
+
+losses
+
+View
+source
+
+
+losses() -> runner.Losses
+
+
+Returns arbitrary task specific losses.
+
+make_contrastive_layer
+
+View
+source
+
+
+make_contrastive_layer() -> tf.keras.layers.Layer
+
+
+Returns the layer contrasting clean outputs with the correupted ones.
+
+metrics
+
+View
+source
+
+
+metrics() -> runner.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ *args
+) -> runner.Predictions
+
+
+Apply a readout head for use with various contrastive losses.
+
+
+
+
+
+Args |
+
+
+
+`*args`
+ |
+
+A tuple of (clean, corrupted) `tfgnn.GraphTensor`s.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The logits for some contrastive loss as produced by the implementing
+subclass.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[Sequence[GraphTensor], runner.Predictions]
+
+
+Applies a `Corruptor` and returns empty pseudo-labels.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/ContrastiveLossTask.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/ContrastiveLossTask.md
new file mode 100644
index 00000000..4c44af3f
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/ContrastiveLossTask.md
@@ -0,0 +1,195 @@
+
+
+# contrastive_losses.ContrastiveLossTask
+
+[TOC]
+
+
+
+
+
+Base class for unsupervised contrastive representation learning tasks.
+
+
+contrastive_losses.ContrastiveLossTask(
+ node_set_name: str,
+ *,
+ feature_name: str = tfgnn.HIDDEN_STATE,
+ representations_layer_name: Optional[str] = None,
+ corruptor: Optional[layers.Corruptor] = None,
+ projector_units: Optional[Sequence[int]] = None,
+ seed: Optional[int] = None
+)
+
+
+
+
+The process is separated into preprocessing and contrastive parts, with the
+focus on reusability of individual components. The `preprocess` produces input
+GraphTensors to be used with the `predict` as well as labels for the task. The
+default `predict` method implementation expects a pair of positive and negative
+GraphTensors. There are multiple ways proposed in the literature to learn
+representations based on the activations - we achieve that by using custom
+losses.
+
+Any subclass must implement `make_contrastive_layer` method, which produces the
+final prediction outputs.
+
+If the loss involves labels for each example, subclasses should leverage
+`losses` and `metrics` methods to specify task's losses. When the loss only
+involves model outputs, `make_contrastive_layer` should output both positive and
+perturb examples, and the `losses` should use pseudolabels.
+
+Any model-specific preprocessing should be implemented in the `preprocess`.
+
+
+
+
+
+Args |
+
+
+
+`node_set_name`
+ |
+
+Name of the node set for readout.
+ |
+
+
+`feature_name`
+ |
+
+Feature name for readout.
+ |
+
+
+`representations_layer_name`
+ |
+
+Layer name for uncorrupted representations.
+ |
+
+
+`corruptor`
+ |
+
+`Corruptor` instance for creating negative samples. If not
+specified, we use `ShuffleFeaturesGlobally` by default.
+ |
+
+
+`projector_units`
+ |
+
+`Sequence` of layer sizes for projector network.
+Projectors prevent dimensional collapse, but can hinder training for
+easy corruptions. For more details, see
+https://arxiv.org/abs/2304.12210.
+ |
+
+
+`seed`
+ |
+
+Random seed for the default corruptor (`ShuffleFeaturesGlobally`).
+ |
+
+
+
+## Methods
+
+losses
+
+
+@abc.abstractmethod
+losses() -> Losses
+
+
+Returns arbitrary task specific losses.
+
+make_contrastive_layer
+
+View
+source
+
+
+@abc.abstractmethod
+make_contrastive_layer() -> tf.keras.layers.Layer
+
+
+Returns the layer contrasting clean outputs with the correupted ones.
+
+metrics
+
+View
+source
+
+
+metrics() -> runner.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ *args
+) -> runner.Predictions
+
+
+Apply a readout head for use with various contrastive losses.
+
+
+
+
+
+Args |
+
+
+
+`*args`
+ |
+
+A tuple of (clean, corrupted) `tfgnn.GraphTensor`s.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The logits for some contrastive loss as produced by the implementing
+subclass.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[Sequence[GraphTensor], runner.Predictions]
+
+
+Applies a `Corruptor` and returns empty pseudo-labels.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/CorruptionSpec.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/CorruptionSpec.md
new file mode 100644
index 00000000..3e75522e
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/CorruptionSpec.md
@@ -0,0 +1,96 @@
+
+
+# contrastive_losses.CorruptionSpec
+
+[TOC]
+
+
+
+
+
+Class for defining corruption specification.
+
+
+contrastive_losses.CorruptionSpec(
+ node_set_corruption: NodeCorruptionSpec = dataclasses.field(default_factory=dict),
+ edge_set_corruption: EdgeCorruptionSpec = dataclasses.field(default_factory=dict),
+ context_corruption: ContextCorruptionSpec = dataclasses.field(default_factory=dict)
+)
+
+
+
+
+This has three fields for specifying the corruption behavior of node-, edge-,
+and context sets.
+
+A value of the key "\*" is a wildcard value that is used for either all features
+or all node/edge sets.
+
+#### Some example usages:
+
+Want: corrupt everything with parameter 1.0. Solution: either set default to 1.0
+or set all corruption specs to `{"*": 1.}`.
+
+Want: corrupt all context features with parameter 1.0 except for "feat", which
+should not be corrupted. Solution: set `context_corruption` to `{"feat": 0.,
+"*": 1.}`
+
+
+
+
+
+Attributes |
+
+
+
+`node_set_corruption`
+ |
+
+Dataclass field
+ |
+
+
+`edge_set_corruption`
+ |
+
+Dataclass field
+ |
+
+
+`context_corruption`
+ |
+
+Dataclass field
+ |
+
+
+
+## Methods
+
+with_default
+
+View
+source
+
+
+with_default(
+ default: T
+)
+
+
+__eq__
+
+
+__eq__(
+ other
+)
+
+
+Return self==value.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/Corruptor.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/Corruptor.md
new file mode 100644
index 00000000..4b17fb24
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/Corruptor.md
@@ -0,0 +1,68 @@
+
+
+# contrastive_losses.Corruptor
+
+[TOC]
+
+
+
+
+
+Base class for graph corruptor.
+
+
+contrastive_losses.Corruptor(
+ corruption_spec: Optional[CorruptionSpec[T]] = None,
+ *,
+ corruption_fn: Callable[[tfgnn.Field, T], tfgnn.Field],
+ default: Optional[T] = None,
+ **kwargs
+)
+
+
+
+
+
+
+
+
+Args |
+
+
+
+`corruption_spec`
+ |
+
+A spec for corruption application.
+ |
+
+
+`corruption_fn`
+ |
+
+Corruption function.
+ |
+
+
+`default`
+ |
+
+Global application default of the corruptor. This is only used
+when `corruption_spec` is None.
+ |
+
+
+`**kwargs`
+ |
+
+Additional keyword arguments.
+ |
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DeepGraphInfomaxLogits.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DeepGraphInfomaxLogits.md
new file mode 100644
index 00000000..3b791460
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DeepGraphInfomaxLogits.md
@@ -0,0 +1,26 @@
+
+
+# contrastive_losses.DeepGraphInfomaxLogits
+
+[TOC]
+
+
+
+
+
+Computes clean and corrupted logits for Deep Graph Infomax (DGI).
+
+
+contrastive_losses.DeepGraphInfomaxLogits(
+ trainable=True, name=None, dtype=None, dynamic=False, **kwargs
+)
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DeepGraphInfomaxTask.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DeepGraphInfomaxTask.md
new file mode 100644
index 00000000..17a24e16
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DeepGraphInfomaxTask.md
@@ -0,0 +1,175 @@
+
+
+# contrastive_losses.DeepGraphInfomaxTask
+
+[TOC]
+
+
+
+
+
+A Deep Graph Infomax (DGI) Task.
+
+Inherits From:
+[`ContrastiveLossTask`](../contrastive_losses/ContrastiveLossTask.md)
+
+
+contrastive_losses.DeepGraphInfomaxTask(
+ *args, **kwargs
+)
+
+
+
+
+
+
+
+
+Args |
+
+
+
+`node_set_name`
+ |
+
+Name of the node set for readout.
+ |
+
+
+`feature_name`
+ |
+
+Feature name for readout.
+ |
+
+
+`representations_layer_name`
+ |
+
+Layer name for uncorrupted representations.
+ |
+
+
+`corruptor`
+ |
+
+`Corruptor` instance for creating negative samples. If not
+specified, we use `ShuffleFeaturesGlobally` by default.
+ |
+
+
+`projector_units`
+ |
+
+`Sequence` of layer sizes for projector network.
+Projectors prevent dimensional collapse, but can hinder training for
+easy corruptions. For more details, see
+https://arxiv.org/abs/2304.12210.
+ |
+
+
+`seed`
+ |
+
+Random seed for the default corruptor (`ShuffleFeaturesGlobally`).
+ |
+
+
+
+## Methods
+
+losses
+
+View
+source
+
+
+losses() -> runner.Losses
+
+
+Returns arbitrary task specific losses.
+
+make_contrastive_layer
+
+View
+source
+
+
+make_contrastive_layer() -> tf.keras.layers.Layer
+
+
+Returns the layer contrasting clean outputs with the correupted ones.
+
+metrics
+
+View
+source
+
+
+metrics() -> runner.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ *args
+) -> runner.Predictions
+
+
+Apply a readout head for use with various contrastive losses.
+
+
+
+
+
+Args |
+
+
+
+`*args`
+ |
+
+A tuple of (clean, corrupted) `tfgnn.GraphTensor`s.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The logits for some contrastive loss as produced by the implementing
+subclass.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[Sequence[GraphTensor], Mapping[str, Field]]
+
+
+Creates labels--i.e., (positive, negative)--for Deep Graph Infomax.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DropoutFeatures.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DropoutFeatures.md
new file mode 100644
index 00000000..0a2f69e0
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DropoutFeatures.md
@@ -0,0 +1,66 @@
+
+
+# contrastive_losses.DropoutFeatures
+
+[TOC]
+
+
+
+
+
+Base class for graph corruptor.
+
+Inherits From: [`Corruptor`](../contrastive_losses/Corruptor.md)
+
+
+contrastive_losses.DropoutFeatures(
+ *args, seed: Optional[float] = None, **kwargs
+)
+
+
+
+
+
+
+
+
+Args |
+
+
+
+`corruption_spec`
+ |
+
+A spec for corruption application.
+ |
+
+
+`corruption_fn`
+ |
+
+Corruption function.
+ |
+
+
+`default`
+ |
+
+Global application default of the corruptor. This is only used
+when `corruption_spec` is None.
+ |
+
+
+`**kwargs`
+ |
+
+Additional keyword arguments.
+ |
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/ShuffleFeaturesGlobally.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/ShuffleFeaturesGlobally.md
new file mode 100644
index 00000000..a2cc3b62
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/ShuffleFeaturesGlobally.md
@@ -0,0 +1,69 @@
+
+
+# contrastive_losses.ShuffleFeaturesGlobally
+
+[TOC]
+
+
+
+
+
+A corruptor that shuffles features.
+
+Inherits From: [`Corruptor`](../contrastive_losses/Corruptor.md)
+
+
+contrastive_losses.ShuffleFeaturesGlobally(
+ *args, seed: Optional[float] = None, **kwargs
+)
+
+
+
+
+NOTE: this function does not currently support TPUs. Consider using other
+corruptor functions if executing on TPUs. See b/269249455 for reference.
+
+
+
+
+
+Args |
+
+
+
+`corruption_spec`
+ |
+
+A spec for corruption application.
+ |
+
+
+`corruption_fn`
+ |
+
+Corruption function.
+ |
+
+
+`default`
+ |
+
+Global application default of the corruptor. This is only used
+when `corruption_spec` is None.
+ |
+
+
+`**kwargs`
+ |
+
+Additional keyword arguments.
+ |
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/TripletEmbeddingSquaredDistances.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/TripletEmbeddingSquaredDistances.md
new file mode 100644
index 00000000..8e0b2949
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/TripletEmbeddingSquaredDistances.md
@@ -0,0 +1,26 @@
+
+
+# contrastive_losses.TripletEmbeddingSquaredDistances
+
+[TOC]
+
+
+
+
+
+Computes embeddings distance between positive and negative pairs.
+
+
+contrastive_losses.TripletEmbeddingSquaredDistances(
+ trainable=True, name=None, dtype=None, dynamic=False, **kwargs
+)
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/TripletLossTask.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/TripletLossTask.md
new file mode 100644
index 00000000..1e6cd2d0
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/TripletLossTask.md
@@ -0,0 +1,210 @@
+
+
+# contrastive_losses.TripletLossTask
+
+[TOC]
+
+
+
+
+
+The triplet loss task.
+
+Inherits From:
+[`ContrastiveLossTask`](../contrastive_losses/ContrastiveLossTask.md)
+
+
+contrastive_losses.TripletLossTask(
+ *args, margin: float = 1.0, **kwargs
+)
+
+
+
+
+
+
+
+
+Args |
+
+
+
+`node_set_name`
+ |
+
+Name of the node set for readout.
+ |
+
+
+`feature_name`
+ |
+
+Feature name for readout.
+ |
+
+
+`representations_layer_name`
+ |
+
+Layer name for uncorrupted representations.
+ |
+
+
+`corruptor`
+ |
+
+`Corruptor` instance for creating negative samples. If not
+specified, we use `ShuffleFeaturesGlobally` by default.
+ |
+
+
+`projector_units`
+ |
+
+`Sequence` of layer sizes for projector network.
+Projectors prevent dimensional collapse, but can hinder training for
+easy corruptions. For more details, see
+https://arxiv.org/abs/2304.12210.
+ |
+
+
+`seed`
+ |
+
+Random seed for the default corruptor (`ShuffleFeaturesGlobally`).
+ |
+
+
+
+## Methods
+
+losses
+
+View
+source
+
+
+losses() -> runner.Losses
+
+
+Returns arbitrary task specific losses.
+
+make_contrastive_layer
+
+View
+source
+
+
+make_contrastive_layer() -> tf.keras.layers.Layer
+
+
+Returns the layer contrasting clean outputs with the correupted ones.
+
+metrics
+
+View
+source
+
+
+metrics() -> runner.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ *args
+) -> runner.Predictions
+
+
+Apply a readout head for use with triplet contrastive loss.
+
+
+
+
+
+Args |
+
+
+
+`*args`
+ |
+
+A tuple of (anchor, positive_sample, negative_sample)
+`tfgnn.GraphTensor`s.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The positive and negative distance embeddings for triplet loss as produced
+by the implementing subclass.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[Sequence[GraphTensor], tfgnn.Field]
+
+
+Creates unused pseudo-labels.
+
+The input tensor should have the anchor and positive sample stacked along the
+first dimension for each feature within each node set. The corruptor is applied
+on the positive sample.
+
+
+
+
+
+Args |
+
+
+
+`inputs`
+ |
+
+The anchor and positive sample stack along the first axis.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Sequence of three graph tensors (anchor, positive_sample,
+corrupted_sample) and unused pseudo-labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/VicRegTask.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/VicRegTask.md
new file mode 100644
index 00000000..12a02fba
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/VicRegTask.md
@@ -0,0 +1,179 @@
+
+
+# contrastive_losses.VicRegTask
+
+[TOC]
+
+
+
+
+
+A VICReg Task.
+
+Inherits From:
+[`ContrastiveLossTask`](../contrastive_losses/ContrastiveLossTask.md)
+
+
+contrastive_losses.VicRegTask(
+ *args,
+ sim_weight: Union[tf.Tensor, float] = 25.0,
+ var_weight: Union[tf.Tensor, float] = 25.0,
+ cov_weight: Union[tf.Tensor, float] = 1.0,
+ **kwargs
+)
+
+
+
+
+
+
+
+
+Args |
+
+
+
+`node_set_name`
+ |
+
+Name of the node set for readout.
+ |
+
+
+`feature_name`
+ |
+
+Feature name for readout.
+ |
+
+
+`representations_layer_name`
+ |
+
+Layer name for uncorrupted representations.
+ |
+
+
+`corruptor`
+ |
+
+`Corruptor` instance for creating negative samples. If not
+specified, we use `ShuffleFeaturesGlobally` by default.
+ |
+
+
+`projector_units`
+ |
+
+`Sequence` of layer sizes for projector network.
+Projectors prevent dimensional collapse, but can hinder training for
+easy corruptions. For more details, see
+https://arxiv.org/abs/2304.12210.
+ |
+
+
+`seed`
+ |
+
+Random seed for the default corruptor (`ShuffleFeaturesGlobally`).
+ |
+
+
+
+## Methods
+
+losses
+
+View
+source
+
+
+losses() -> runner.Losses
+
+
+Returns arbitrary task specific losses.
+
+make_contrastive_layer
+
+View
+source
+
+
+make_contrastive_layer() -> tf.keras.layers.Layer
+
+
+Returns the layer contrasting clean outputs with the correupted ones.
+
+metrics
+
+View
+source
+
+
+metrics() -> runner.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ *args
+) -> runner.Predictions
+
+
+Apply a readout head for use with various contrastive losses.
+
+
+
+
+
+Args |
+
+
+
+`*args`
+ |
+
+A tuple of (clean, corrupted) `tfgnn.GraphTensor`s.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The logits for some contrastive loss as produced by the implementing
+subclass.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[Sequence[GraphTensor], runner.Predictions]
+
+
+Applies a `Corruptor` and returns empty pseudo-labels.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/all_symbols.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/all_symbols.md
new file mode 100644
index 00000000..72b37858
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/all_symbols.md
@@ -0,0 +1,24 @@
+# All symbols in TensorFlow GNN Models: fContrastiveLosses
+
+
+
+## Primary symbols
+
+* contrastive_losses
+* contrastive_losses.AllSvdMetrics
+* contrastive_losses.BarlowTwinsTask
+* contrastive_losses.ContrastiveLossTask
+* contrastive_losses.CorruptionSpec
+* contrastive_losses.Corruptor
+* contrastive_losses.DeepGraphInfomaxLogits
+* contrastive_losses.DeepGraphInfomaxTask
+* contrastive_losses.DropoutFeatures
+* contrastive_losses.ShuffleFeaturesGlobally
+* contrastive_losses.TripletEmbeddingSquaredDistances
+* contrastive_losses.TripletLossTask
+* contrastive_losses.VicRegTask
+* contrastive_losses.coherence
+* contrastive_losses.numerical_rank
+* contrastive_losses.pseudo_condition_number
+* contrastive_losses.rankme
+* contrastive_losses.self_clustering
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/coherence.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/coherence.md
new file mode 100644
index 00000000..9cb8ccce
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/coherence.md
@@ -0,0 +1,78 @@
+
+
+# contrastive_losses.coherence
+
+[TOC]
+
+
+
+
+
+Coherence metric implementation.
+
+
+@tf.function
+contrastive_losses.coherence(
+ representations: tf.Tensor,
+ *,
+ sigma: Optional[tf.Tensor] = None,
+ u: Optional[tf.Tensor] = None
+) -> tf.Tensor
+
+
+
+
+Coherence measures how easy it is to construct a linear classifier on top of
+data without knowing downstream labels. Refer to
+ for more details.
+
+
+
+
+
+Args |
+
+
+
+`representations`
+ |
+
+Input representations, a rank-2 tensor.
+ |
+
+
+`sigma`
+ |
+
+Unused.
+ |
+
+
+`u`
+ |
+
+An optional tensor with left singular vectors of representations. If not
+present, computes a SVD of representations.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Metric value as scalar `tf.Tensor`.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/numerical_rank.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/numerical_rank.md
new file mode 100644
index 00000000..54dad607
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/numerical_rank.md
@@ -0,0 +1,79 @@
+
+
+# contrastive_losses.numerical_rank
+
+[TOC]
+
+
+
+
+
+Numerical rank implementation.
+
+
+@tf.function
+contrastive_losses.numerical_rank(
+ representations: tf.Tensor,
+ *,
+ sigma: Optional[tf.Tensor] = None,
+ u: Optional[tf.Tensor] = None
+) -> tf.Tensor
+
+
+
+
+Computes a metric that estimates the numerical column rank of a matrix. Rank is
+estimated as a matrix trace divided by the largest eigenvalue. When our matrix
+is a covariance matrix, we can compute both the trace and the largest eigenvalue
+efficiently via SVD as the largest singular value squared.
+
+
+
+
+
+Args |
+
+
+
+`representations`
+ |
+
+Input representations. We expect rank 2 input.
+ |
+
+
+`sigma`
+ |
+
+An optional tensor with singular values of representations. If not
+present, computes SVD (singular values only) of representations.
+ |
+
+
+`u`
+ |
+
+Unused.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Metric value as scalar `tf.Tensor`.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/pseudo_condition_number.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/pseudo_condition_number.md
new file mode 100644
index 00000000..dcefc3de
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/pseudo_condition_number.md
@@ -0,0 +1,78 @@
+
+
+# contrastive_losses.pseudo_condition_number
+
+[TOC]
+
+
+
+
+
+Pseudo-condition number metric implementation.
+
+
+@tf.function
+contrastive_losses.pseudo_condition_number(
+ representations: tf.Tensor,
+ *,
+ sigma: Optional[tf.Tensor] = None,
+ u: Optional[tf.Tensor] = None
+) -> tf.Tensor
+
+
+
+
+Computes a metric that measures the decay rate of the singular values. NOTE: Can
+be unstable in practice, when using small batch sizes, leading to numerical
+instabilities.
+
+
+
+
+
+Args |
+
+
+
+`representations`
+ |
+
+Input representations. We expect rank 2 input.
+ |
+
+
+`sigma`
+ |
+
+An optional tensor with singular values of representations. If not
+present, computes SVD (singular values only) of representations.
+ |
+
+
+`u`
+ |
+
+Unused.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Metric value as scalar `tf.Tensor`.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/rankme.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/rankme.md
new file mode 100644
index 00000000..f5d67eef
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/rankme.md
@@ -0,0 +1,86 @@
+
+
+# contrastive_losses.rankme
+
+[TOC]
+
+
+
+
+
+RankMe metric implementation.
+
+
+@tf.function
+contrastive_losses.rankme(
+ representations: tf.Tensor,
+ *,
+ sigma: Optional[tf.Tensor] = None,
+ u: Optional[tf.Tensor] = None,
+ epsilon: float = 1e-12,
+ **_
+) -> tf.Tensor
+
+
+
+
+Computes a metric that measures the decay rate of the singular values. For the
+paper, see .
+
+
+
+
+
+Args |
+
+
+
+`representations`
+ |
+
+Input representations as rank-2 tensor.
+ |
+
+
+`sigma`
+ |
+
+An optional tensor with singular values of representations. If not
+present, computes SVD (singular values only) of representations.
+ |
+
+
+`u`
+ |
+
+Unused.
+ |
+
+
+`epsilon`
+ |
+
+Epsilon for numerican stability.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Metric value as scalar `tf.Tensor`.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/self_clustering.md b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/self_clustering.md
new file mode 100644
index 00000000..65a014c7
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/self_clustering.md
@@ -0,0 +1,72 @@
+
+
+# contrastive_losses.self_clustering
+
+[TOC]
+
+
+
+
+
+Self-clustering metric implementation.
+
+
+@tf.function
+contrastive_losses.self_clustering(
+ representations: tf.Tensor, *, subtract_mean: bool = False, **_
+) -> tf.Tensor
+
+
+
+
+Computes a metric that measures how well distributed representations are, if
+projected on the unit sphere. If `subtract_mean` is True, we additionally remove
+the mean from representations. The metric has a range of (-0.5, 1\]. It achieves
+its maximum of 1 if representations collapse to a single point, and it is
+approximately 0 if representations are distributed randomly on the sphere. In
+theory, it can achieve negative values if the points are maximally equiangular,
+although this is very rare in practice. Refer to
+ for more details.
+
+
+
+
+
+Args |
+
+
+
+`representations`
+ |
+
+Input representations.
+ |
+
+
+`subtract_mean`
+ |
+
+Whether to subtract the mean from representations.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Metric value as scalar `tf.Tensor`.
+ |
+
+
+