@@ -6308,24 +6308,29 @@ def ComputationShape(split_size, topology=None) -> List[int]:
63086308 A 4-element list that describes the computation shape.
63096309 """
63106310 if topology :
6311- if isinstance (topology , tf .tpu .experimental .Topology ):
6312- topology_info = topology
6313- else :
6314- topology_info = tf_topology .Topology (serialized = topology )
6315- if topology and functools .reduce (lambda a , b : a * b ,
6316- topology_info .mesh_shape ) == split_size :
6317- computation_shape = topology_info .mesh_shape
6311+ if not isinstance (topology , tf .tpu .experimental .Topology ):
6312+ topology = tf_topology .Topology (serialized = topology )
6313+ if (
6314+ topology
6315+ and functools .reduce (lambda a , b : a * b , topology .mesh_shape )
6316+ == split_size
6317+ ):
6318+ computation_shape = topology .mesh_shape
63186319 elif split_size == 1 :
63196320 computation_shape = [1 , 1 , 1 , 1 ]
6320- elif topology and topology_info .mesh_shape [
6321- - 1 ] == 1 and split_size in topology_info .mesh_shape :
6321+ elif (
6322+ topology
6323+ and topology .mesh_shape [- 1 ] == 1
6324+ and split_size in topology .mesh_shape
6325+ ):
63226326 # For Megacore, if we find exact match on mesh shape, map split_size to it
63236327 computation_shape = [1 , 1 , 1 , 1 ]
6324- computation_shape [topology_info .mesh_shape .tolist ().index (
6325- split_size )] = split_size
6328+ computation_shape [topology .mesh_shape .tolist ().index (split_size )] = (
6329+ split_size
6330+ )
63266331 else :
63276332 if topology :
6328- cores_per_chip = topology_info .mesh_shape [- 1 ]
6333+ cores_per_chip = topology .mesh_shape [- 1 ]
63296334 else :
63306335 cores_per_chip = 2
63316336 assert split_size % cores_per_chip == 0
@@ -6345,7 +6350,7 @@ def ComputationShape(split_size, topology=None) -> List[int]:
63456350 elif split_chips == 24 :
63466351 computation_shape = [1 , 2 , 12 , cores_per_chip ]
63476352 elif split_chips == 32 :
6348- if topology and topology_info .mesh_shape [1 ] == 32 :
6353+ if topology and topology .mesh_shape [1 ] == 32 :
63496354 # Fwd within-replica all-reduces is performed along column;
63506355 # Bwd gradient cross-replica all-reduces is performed along row.
63516356 # This currently has better performance than the strided pattern.
0 commit comments