Skip to content

Commit 381f8eb

Browse files
lingvo-botcopybara-github
authored andcommitted
Fix pylint warnings.
PiperOrigin-RevId: 591069522
1 parent e1c1c67 commit 381f8eb

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

lingvo/core/py_utils.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6308,24 +6308,29 @@ def ComputationShape(split_size, topology=None) -> List[int]:
63086308
A 4-element list that describes the computation shape.
63096309
"""
63106310
if topology:
6311-
if isinstance(topology, tf.tpu.experimental.Topology):
6312-
topology_info = topology
6313-
else:
6314-
topology_info = tf_topology.Topology(serialized=topology)
6315-
if topology and functools.reduce(lambda a, b: a * b,
6316-
topology_info.mesh_shape) == split_size:
6317-
computation_shape = topology_info.mesh_shape
6311+
if not isinstance(topology, tf.tpu.experimental.Topology):
6312+
topology = tf_topology.Topology(serialized=topology)
6313+
if (
6314+
topology
6315+
and functools.reduce(lambda a, b: a * b, topology.mesh_shape)
6316+
== split_size
6317+
):
6318+
computation_shape = topology.mesh_shape
63186319
elif split_size == 1:
63196320
computation_shape = [1, 1, 1, 1]
6320-
elif topology and topology_info.mesh_shape[
6321-
-1] == 1 and split_size in topology_info.mesh_shape:
6321+
elif (
6322+
topology
6323+
and topology.mesh_shape[-1] == 1
6324+
and split_size in topology.mesh_shape
6325+
):
63226326
# For Megacore, if we find exact match on mesh shape, map split_size to it
63236327
computation_shape = [1, 1, 1, 1]
6324-
computation_shape[topology_info.mesh_shape.tolist().index(
6325-
split_size)] = split_size
6328+
computation_shape[topology.mesh_shape.tolist().index(split_size)] = (
6329+
split_size
6330+
)
63266331
else:
63276332
if topology:
6328-
cores_per_chip = topology_info.mesh_shape[-1]
6333+
cores_per_chip = topology.mesh_shape[-1]
63296334
else:
63306335
cores_per_chip = 2
63316336
assert split_size % cores_per_chip == 0
@@ -6345,7 +6350,7 @@ def ComputationShape(split_size, topology=None) -> List[int]:
63456350
elif split_chips == 24:
63466351
computation_shape = [1, 2, 12, cores_per_chip]
63476352
elif split_chips == 32:
6348-
if topology and topology_info.mesh_shape[1] == 32:
6353+
if topology and topology.mesh_shape[1] == 32:
63496354
# Fwd within-replica all-reduces is performed along column;
63506355
# Bwd gradient cross-replica all-reduces is performed along row.
63516356
# This currently has better performance than the strided patten.

0 commit comments

Comments
 (0)