
Commit a191e2c

Neslihans authored and tensorflower-gardener committed

Update add_loss calls in the DGI task to reduce them by dividing by the global_batch_size before passing to Keras.
PiperOrigin-RevId: 487897250
1 parent ff5e05e commit a191e2c
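
For context, here is a minimal sketch of the pattern this commit adopts; it is not part of the diff, and the layer and variable names are illustrative. The idea is to compute a per-example loss with Reduction.NONE and average it with tf.nn.compute_average_loss over the global batch size, so that a loss added inside a layer via add_loss scales correctly under tf.distribute.

import tensorflow as tf


class AveragedLossLayer(tf.keras.layers.Layer):
  """Illustrative layer that reports a loss averaged over the global batch."""

  def __init__(self, global_batch_size: int):
    super().__init__()
    self._global_batch_size = global_batch_size
    # Reduction.NONE keeps one loss value per example instead of a scalar mean.
    self._bce = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

  def call(self, logits: tf.Tensor) -> tf.Tensor:
    per_example_loss = self._bce(tf.ones_like(logits), logits)
    # Dividing by the global (not per-replica) batch size means the values
    # summed across replicas by tf.distribute recover the true mean loss.
    self.add_loss(tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=self._global_batch_size))
    return logits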

File tree

tensorflow_gnn/runner/tasks/BUILD
tensorflow_gnn/runner/tasks/dgi.py
tensorflow_gnn/runner/tasks/dgi_test.py

3 files changed: +126 -23 lines changed

tensorflow_gnn/runner/tasks/BUILD

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,6 @@
 load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
 load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "py_strict_test")
+load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "distribute_py_test")
 
 licenses(["notice"])
 
@@ -40,13 +41,16 @@ pytype_strict_library(
     ],
 )
 
-py_strict_test(
+distribute_py_test(
     name = "dgi_test",
     srcs = ["dgi_test.py"],
     srcs_version = "PY3",
+    xla_enable_strict_auto_jit = False,
     deps = [
         ":dgi",
+        "//:expect_absl_installed",
         "//:expect_tensorflow_installed",
+        "//:expect_tensorflow_installed:tensorflow_no_contrib",
         "//tensorflow_gnn",
         "//tensorflow_gnn/runner:orchestration",
     ],

tensorflow_gnn/runner/tasks/dgi.py

Lines changed: 25 additions & 12 deletions
@@ -23,14 +23,16 @@
 class AddLossDeepGraphInfomax(tf.keras.layers.Layer):
   """"A bilinear layer with losses and metrics for Deep Graph Infomax."""
 
-  def __init__(self, units: int):
+  def __init__(self, units: int, global_batch_size: int):
     """Builds the bilinear layer weights.
 
     Args:
       units: Units for the bilinear layer.
+      global_batch_size: Global batch size to compute the average loss.
     """
     super().__init__()
     self._bilinear = tf.keras.layers.Dense(units, use_bias=False)
+    self._global_batch_size = global_batch_size
 
   def get_config(self) -> Mapping[Any, Any]:
     """Returns the config of the layer.
@@ -60,11 +62,14 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
     summary = tf.math.reduce_mean(y_clean, axis=0, keepdims=True)
     # Clean losses and metrics
     logits_clean = tf.matmul(y_clean, self._bilinear(summary), transpose_b=True)
-    self.add_loss(tf.keras.losses.BinaryCrossentropy(
+    loss_clean = tf.keras.losses.BinaryCrossentropy(
         from_logits=True,
-        name="binary_crossentropy_clean")(
-            tf.ones_like(logits_clean),
-            logits_clean))
+        name="binary_crossentropy_clean",
+        reduction=tf.keras.losses.Reduction.NONE)
+    self.add_loss(
+        tf.nn.compute_average_loss(
+            loss_clean(tf.ones_like(logits_clean), logits_clean),
+            global_batch_size=self._global_batch_size))
     self.add_metric(
         tf.keras.metrics.binary_crossentropy(
             tf.ones_like(logits_clean),
@@ -81,11 +86,15 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
         y_corrupted,
         self._bilinear(summary),
         transpose_b=True)
-    self.add_loss(tf.keras.losses.BinaryCrossentropy(
+    loss_object_corrupted = tf.keras.losses.BinaryCrossentropy(
         from_logits=True,
-        name="binary_crossentropy_corrupted")(
-            tf.zeros_like(logits_corrupted),
-            logits_corrupted))
+        name="binary_crossentropy_corrupted",
+        reduction=tf.keras.losses.Reduction.NONE)
+    self.add_loss(
+        tf.nn.compute_average_loss(
+            loss_object_corrupted(
+                tf.zeros_like(logits_corrupted), logits_corrupted),
+            global_batch_size=self._global_batch_size))
     self.add_metric(
         tf.keras.metrics.binary_crossentropy(
             tf.zeros_like(logits_corrupted),
@@ -125,18 +134,21 @@ class DeepGraphInfomax:
   def __init__(self,
                node_set_name: str,
                *,
+               global_batch_size: int,
                state_name: str = tfgnn.HIDDEN_STATE,
                seed: Optional[int] = None):
     """Captures arguments for the task.
 
     Args:
       node_set_name: The node set for activations.
+      global_batch_size: Global batch size (not per-replica) for the training.
       state_name: The state name of any activations.
       seed: A seed for corrupted representations.
     """
     self._state_name = state_name
     self._node_set_name = node_set_name
     self._seed = seed
+    self._global_batch_size = global_batch_size
 
   def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
     """Adapt a `tf.keras.Model` for Deep Graph Infomax.
@@ -164,15 +176,16 @@ def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
         feature_name=self._state_name)(model.output)
 
     # Corrupted representations: shuffling, model application and readout
-    shuffled = tfgnn.shuffle_features_globally(model.input)
+    shuffled = tfgnn.shuffle_features_globally(model.input, seed=self._seed)
     y_corrupted = tfgnn.keras.layers.ReadoutFirstNode(
         node_set_name=self._node_set_name,
         feature_name=self._state_name)(model(shuffled))
 
     return tf.keras.Model(
         model.input,
-        AddLossDeepGraphInfomax(
-            y_clean.get_shape()[-1])((y_clean, y_corrupted)))
+        AddLossDeepGraphInfomax(y_clean.get_shape()[-1],
+                                self._global_batch_size)(
+                                    (y_clean, y_corrupted)))
 
   def preprocess(self, gt: tfgnn.GraphTensor) -> tfgnn.GraphTensor:
     """Returns the input GraphTensor."""

tensorflow_gnn/runner/tasks/dgi_test.py

Lines changed: 96 additions & 10 deletions
@@ -13,7 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for dgi."""
+from absl.testing import parameterized
 import tensorflow as tf
+import tensorflow.__internal__.distribute as tfdistribute
+import tensorflow.__internal__.test as tftest
 import tensorflow_gnn as tfgnn
 
 from tensorflow_gnn.runner import orchestration
@@ -42,10 +45,58 @@
 """ % tfgnn.HIDDEN_STATE
 
 
-class DeepGraphInfomaxTest(tf.test.TestCase):
-
+def _all_eager_distributed_strategy_combinations():
+  strategies = [
+      # MirroredStrategy
+      tfdistribute.combinations.mirrored_strategy_with_gpu_and_cpu,
+      tfdistribute.combinations.mirrored_strategy_with_one_cpu,
+      tfdistribute.combinations.mirrored_strategy_with_one_gpu,
+      """ # MultiWorkerMirroredStrategy
+      tfdistribute.combinations.multi_worker_mirrored_2x1_cpu,
+      tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
+      # TPUStrategy
+      tfdistribute.combinations.tpu_strategy,
+      tfdistribute.combinations.tpu_strategy_one_core,
+      tfdistribute.combinations.tpu_strategy_packed_var,
+      # ParameterServerStrategy
+      tfdistribute.combinations.parameter_server_strategy_3worker_2ps_cpu,
+      tfdistribute.combinations.parameter_server_strategy_3worker_2ps_1gpu,
+      tfdistribute.combinations.parameter_server_strategy_1worker_2ps_cpu,
+      tfdistribute.combinations.parameter_server_strategy_1worker_2ps_1gpu, """
+  ]
+  return tftest.combinations.combine(distribution=strategies)
+
+
+class DeepGraphInfomaxTest(tf.test.TestCase, parameterized.TestCase):
+
+  global_batch_size = 2
   gtspec = tfgnn.create_graph_spec_from_schema_pb(tfgnn.parse_schema(SCHEMA))
-  task = dgi.DeepGraphInfomax("node", seed=8191)
+  task = dgi.DeepGraphInfomax(
+      "node", global_batch_size=global_batch_size, seed=8191)
+
+  def get_graph_tensor(self):
+    gt = tfgnn.GraphTensor.from_pieces(
+        node_sets={
+            "node":
+                tfgnn.NodeSet.from_fields(
+                    features={
+                        tfgnn.HIDDEN_STATE:
+                            tf.convert_to_tensor([[1., 2., 3., 4.],
+                                                  [11., 11., 11., 11.],
+                                                  [19., 19., 19., 19.]])
+                    },
+                    sizes=tf.convert_to_tensor([3])),
+        },
+        edge_sets={
+            "edge":
+                tfgnn.EdgeSet.from_fields(
+                    sizes=tf.convert_to_tensor([2]),
+                    adjacency=tfgnn.Adjacency.from_indices(
+                        ("node", tf.convert_to_tensor([0, 1], dtype=tf.int32)),
+                        ("node", tf.convert_to_tensor([2, 0], dtype=tf.int32)),
+                    )),
+        })
+    return gt
 
   def build_model(self):
     graph = inputs = tf.keras.layers.Input(type_spec=self.gtspec)
@@ -87,12 +138,12 @@ def test_adapt(self):
         feature_name=tfgnn.HIDDEN_STATE)(model(gt))
     actual = adapted(gt)
 
-    self.assertAllClose(actual, expected)
+    self.assertAllClose(actual, expected, rtol=1e-04, atol=1e-04)
 
   def test_fit(self):
-    gt = tfgnn.random_graph_tensor(self.gtspec)
-    ds = tf.data.Dataset.from_tensors(gt).repeat(8)
-    ds = ds.batch(2).map(tfgnn.GraphTensor.merge_batch_to_components)
+    ds = tf.data.Dataset.from_tensors(self.get_graph_tensor()).repeat(8)
+    ds = ds.batch(self.global_batch_size).map(
+        tfgnn.GraphTensor.merge_batch_to_components)
 
     model = self.task.adapt(self.build_model())
     model.compile()
@@ -105,12 +156,47 @@ def get_loss():
     model.fit(ds)
     after = get_loss()
 
-    self.assertAllClose(before, 250.42036, rtol=1e-04, atol=1e-04)
-    self.assertAllClose(after, 13.18533, rtol=1e-04, atol=1e-04)
+    self.assertAllClose(before, 92.92909, rtol=1e-04, atol=1e-04)
+    self.assertAllClose(after, 4.05084, rtol=1e-04, atol=1e-04)
+
+  @tfdistribute.combinations.generate(
+      tftest.combinations.combine(distribution=[
+          tfdistribute.combinations.mirrored_strategy_with_one_gpu,
+          tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
+      ]))
+  def test_distributed(self, distribution):
+    gt = self.get_graph_tensor()
+
+    def dataset_fn(input_context=None, gt=gt):
+      ds = tf.data.Dataset.from_tensors(gt).repeat(8)
+      if input_context:
+        batch_size = input_context.get_per_replica_batch_size(
+            self.global_batch_size)
+      else:
+        batch_size = self.global_batch_size
+      ds = ds.batch(batch_size).map(tfgnn.GraphTensor.merge_batch_to_components)
+      return ds
+
+    with distribution.scope():
+      model = self.task.adapt(self.build_model())
+      model.compile()
+
+    def get_loss():
+      values = model.evaluate(
+          distribution.distribute_datasets_from_function(dataset_fn), steps=4)
+      return dict(zip(model.metrics_names, values))["loss"]
+
+    before = get_loss()
+    model.fit(
+        distribution.distribute_datasets_from_function(dataset_fn),
+        steps_per_epoch=4)
+    after = get_loss()
+    self.assertAllClose(before, 92.92909, rtol=2, atol=1)
+    self.assertAllClose(after, 4.05084, rtol=2, atol=1)
 
   def test_protocol(self):
     self.assertIsInstance(dgi.DeepGraphInfomax, orchestration.Task)
 
 
 if __name__ == "__main__":
-  tf.test.main()
+  tfdistribute.multi_process_runner.test_main()
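
A quick arithmetic check of the scaling these tests rely on (the values below are illustrative, not taken from the tests): tf.nn.compute_average_loss sums the per-example losses and divides by the global batch size, so summing the per-replica results reproduces the mean over the whole global batch.

import tensorflow as tf

# Suppose one replica sees 2 of the 4 examples in the global batch.
per_example_loss = tf.constant([1.0, 3.0])
replica_value = tf.nn.compute_average_loss(per_example_loss, global_batch_size=4)
print(replica_value.numpy())  # (1.0 + 3.0) / 4 = 1.0
# Summing replica_value over both replicas yields the mean over all 4 examples.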
