# limitations under the License.
# ==============================================================================
"""Tests for dgi."""
+ from absl.testing import parameterized
import tensorflow as tf
+ import tensorflow.__internal__.distribute as tfdistribute
+ import tensorflow.__internal__.test as tftest
import tensorflow_gnn as tfgnn

from tensorflow_gnn.runner import orchestration
""" % tfgnn.HIDDEN_STATE


- class DeepGraphInfomaxTest(tf.test.TestCase):
-
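+ # Distribution strategy combinations for parameterized tests; only the
+ # MirroredStrategy variants are enabled below.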
+ def _all_eager_distributed_strategy_combinations():
+   strategies = [
+       # MirroredStrategy
+       tfdistribute.combinations.mirrored_strategy_with_gpu_and_cpu,
+       tfdistribute.combinations.mirrored_strategy_with_one_cpu,
+       tfdistribute.combinations.mirrored_strategy_with_one_gpu,
+       # MultiWorkerMirroredStrategy
+       # tfdistribute.combinations.multi_worker_mirrored_2x1_cpu,
+       # tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
+       # TPUStrategy
+       # tfdistribute.combinations.tpu_strategy,
+       # tfdistribute.combinations.tpu_strategy_one_core,
+       # tfdistribute.combinations.tpu_strategy_packed_var,
+       # ParameterServerStrategy
+       # tfdistribute.combinations.parameter_server_strategy_3worker_2ps_cpu,
+       # tfdistribute.combinations.parameter_server_strategy_3worker_2ps_1gpu,
+       # tfdistribute.combinations.parameter_server_strategy_1worker_2ps_cpu,
+       # tfdistribute.combinations.parameter_server_strategy_1worker_2ps_1gpu,
+   ]
+   return tftest.combinations.combine(distribution=strategies)
+
+
+ class DeepGraphInfomaxTest(tf.test.TestCase, parameterized.TestCase):
+
+   global_batch_size = 2
  gtspec = tfgnn.create_graph_spec_from_schema_pb(tfgnn.parse_schema(SCHEMA))
-   task = dgi.DeepGraphInfomax("node", seed=8191)
+   task = dgi.DeepGraphInfomax(
+       "node", global_batch_size=global_batch_size, seed=8191)
+
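+   # Builds a small deterministic GraphTensor: one "node" set with 3 nodes
+   # and one "edge" set with 2 edges.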
+   def get_graph_tensor(self):
+     gt = tfgnn.GraphTensor.from_pieces(
+         node_sets={
+             "node":
+                 tfgnn.NodeSet.from_fields(
+                     features={
+                         tfgnn.HIDDEN_STATE:
+                             tf.convert_to_tensor([[1., 2., 3., 4.],
+                                                   [11., 11., 11., 11.],
+                                                   [19., 19., 19., 19.]])
+                     },
+                     sizes=tf.convert_to_tensor([3])),
+         },
+         edge_sets={
+             "edge":
+                 tfgnn.EdgeSet.from_fields(
+                     sizes=tf.convert_to_tensor([2]),
+                     adjacency=tfgnn.Adjacency.from_indices(
+                         ("node", tf.convert_to_tensor([0, 1], dtype=tf.int32)),
+                         ("node", tf.convert_to_tensor([2, 0], dtype=tf.int32)),
+                     )),
+         })
+     return gt

  def build_model(self):
    graph = inputs = tf.keras.layers.Input(type_spec=self.gtspec)
@@ -87,12 +138,12 @@ def test_adapt(self):
        feature_name=tfgnn.HIDDEN_STATE)(model(gt))
    actual = adapted(gt)

-     self.assertAllClose(actual, expected)
+     self.assertAllClose(actual, expected, rtol=1e-04, atol=1e-04)

  def test_fit(self):
-     gt = tfgnn.random_graph_tensor(self.gtspec)
-     ds = tf.data.Dataset.from_tensors(gt).repeat(8)
-     ds = ds.batch(2).map(tfgnn.GraphTensor.merge_batch_to_components)
+     ds = tf.data.Dataset.from_tensors(self.get_graph_tensor()).repeat(8)
+     ds = ds.batch(self.global_batch_size).map(
+         tfgnn.GraphTensor.merge_batch_to_components)

    model = self.task.adapt(self.build_model())
    model.compile()
@@ -105,12 +156,47 @@ def get_loss():
    model.fit(ds)
    after = get_loss()

-     self.assertAllClose(before, 250.42036, rtol=1e-04, atol=1e-04)
-     self.assertAllClose(after, 13.18533, rtol=1e-04, atol=1e-04)
+     self.assertAllClose(before, 92.92909, rtol=1e-04, atol=1e-04)
+     self.assertAllClose(after, 4.05084, rtol=1e-04, atol=1e-04)
+
+   @tfdistribute.combinations.generate(
+       tftest.combinations.combine(distribution=[
+           tfdistribute.combinations.mirrored_strategy_with_one_gpu,
+           tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
+       ]))
+   def test_distributed(self, distribution):
+     gt = self.get_graph_tensor()
+
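+     # dataset_fn batches per replica when the strategy supplies an
+     # InputContext, and with the global batch size otherwise.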
+     def dataset_fn(input_context=None, gt=gt):
+       ds = tf.data.Dataset.from_tensors(gt).repeat(8)
+       if input_context:
+         batch_size = input_context.get_per_replica_batch_size(
+             self.global_batch_size)
+       else:
+         batch_size = self.global_batch_size
+       ds = ds.batch(batch_size).map(tfgnn.GraphTensor.merge_batch_to_components)
+       return ds
+
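+     # Create and compile the model under the strategy scope so its variables
+     # are placed and replicated according to the distribution strategy.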
+     with distribution.scope():
+       model = self.task.adapt(self.build_model())
+       model.compile()
+
+     def get_loss():
+       values = model.evaluate(
+           distribution.distribute_datasets_from_function(dataset_fn), steps=4)
+       return dict(zip(model.metrics_names, values))["loss"]
+
+     before = get_loss()
+     model.fit(
+         distribution.distribute_datasets_from_function(dataset_fn),
+         steps_per_epoch=4)
+     after = get_loss()
+     self.assertAllClose(before, 92.92909, rtol=2, atol=1)
+     self.assertAllClose(after, 4.05084, rtol=2, atol=1)

  def test_protocol(self):
    self.assertIsInstance(dgi.DeepGraphInfomax, orchestration.Task)


if __name__ == "__main__":
-   tf.test.main()
+   tfdistribute.multi_process_runner.test_main()