Simplify trace representation with shape symbolic values enabled #2728

@IvanYashchuk

🐛 Bug

Using the "symbolic values" cache, which treats tensor shapes symbolically, bloats the subsymbols of Thunder trace symbols with many calls (repeated prims.shape unpackings and prims.eq/prims.ne checks) that do not contribute to the result of the enclosing symbol. These calls should be removed.
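To make "do not contribute to the result" concrete, here is a minimal sketch that flags calls whose outputs are never read by another call and are not returned. `Call` and `unused_calls` are hypothetical helpers over a toy flat representation, not Thunder APIs, and the sample list is a hand-copied subset of the `ltorch.add` subsymbols printed by the reproducer below (literal operands omitted).

```python
from typing import NamedTuple


class Call(NamedTuple):
    """Hypothetical flat stand-in for one subsymbol line of a trace."""
    outputs: tuple[str, ...]
    op: str
    inputs: tuple[str, ...]


def unused_calls(calls: list[Call], results: set[str]) -> list[Call]:
    """Return calls none of whose outputs is read by another call or returned."""
    consumed = set(results)
    for call in calls:
        consumed.update(call.inputs)
    return [c for c in calls if not consumed.intersection(c.outputs)]


# Subset of the ltorch.add subsymbols from the trace below (literals omitted).
subsymbols = [
    Call(("i0", "i1"), "prims.shape", ("a",)),
    Call(("i2", "i3"), "prims.shape", ("b",)),
    Call(("i4",), "prims.eq", ("i1",)),
    Call(("i13",), "prims.eq", ("i0", "i2")),
    Call(("i16",), "prims.ne", ("i2",)),
    Call(("t22",), "prims.broadcast_in_dim", ("a", "i2", "i1")),
    Call(("t25",), "prims.add", ("t22", "b")),
]

for call in unused_calls(subsymbols, results={"t25"}):
    print(call.outputs, "=", call.op, call.inputs)
```

A single pass like this only catches calls whose results nobody reads; a call that is read only by other dead calls needs repeated passes or a backward walk.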

To Reproduce

import torch, thunder

from functools import partial

a = torch.randn(1, 1024, device="cuda", dtype=torch.float32)
b = torch.randn(2, 1024, device="cuda", dtype=torch.float32)

@partial(thunder.jit, cache="symbolic values")
def f(a, b):
    return a + b

f(a, b)

print(f._lc_cs.last_traces[0])
Printed trace:
def computation(a, b):
  # a: "cuda:0 f32[[IntegerProxy name=i0, value=1, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  # b: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i3, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"

  # /tmp/ipython-input-2728223152.py:9: 	    return a + b
  t25 = ltorch.add(a, b, alpha=1)  # t25: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
    # (i0, i1) = prims.shape(a)
    # (i0, i1) = prims.shape(a)
    # (i2, i3) = prims.shape(b)
    # (i2, i3) = prims.shape(b)
    # i4 = prims.eq(i1, 1)  # i4: "bool False"
    # i5 = prims.eq(i1, i1)  # i5: "bool True"
    # i6 = prims.eq(i0, 1)  # i6: "bool True"
    # i7 = prims.eq(i1, 1)  # i7: "bool False"
    # i8 = prims.eq(i3, 1)  # i8: "bool False"
    # i9 = prims.eq(i1, i3)  # i9: "bool True"
    # i10 = prims.eq(i0, 1)  # i10: "bool True"
    # i11 = prims.eq(i2, 1)  # i11: "bool False"
    # i12 = prims.eq(i2, i2)  # i12: "bool True"
    # (i0, i1) = prims.shape(a)
    # (i0, i1) = prims.shape(a)
    # i13 = prims.eq(i0, i2)  # i13: "bool False"
    # (i0, i1) = prims.shape(a)
    # (i0, i1) = prims.shape(a)
    # (i0, i1) = prims.shape(a)
    # i14 = prims.eq(i2, i0)  # i14: "bool False"
    # i15 = prims.eq(i0, 1)  # i15: "bool True"
    # i16 = prims.ne(i2, -1)  # i16: "bool True"
    # i17 = prims.eq(i1, i1)  # i17: "bool True"
    # i18 = prims.ne(i1, -1)  # i18: "bool True"
    # (i0, i1) = prims.shape(a)
    # t22 = prims.broadcast_in_dim(a, (i2, i1), (0, 1))  # t22: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
    # (i2, i3) = prims.shape(b)
    # (i2, i3) = prims.shape(b)
    # i23 = prims.eq(i3, i1)  # i23: "bool True"
    # t25 = prims.add(t22, b)  # t25: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  return (t25,)

Running DCE does not help; it leaves the trace unmodified:

from thunder.core.transform_common import dce

print(dce(f._lc_cs.last_traces[0]))
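One reading of this result, not confirmed against the dce source, is that the pass does not prune inside a symbol's subsymbols. The minimal sketch below shows what recursive pruning could look like: each symbol's subsymbols are filtered by backward liveness rooted at that symbol's own outputs, and the same rule is applied to nested subsymbols. `BSym` and `prune` are hypothetical toy types, not Thunder's `BoundSymbol` or `dce`, and the nested `broadcast_helper` symbol is invented purely to exercise the recursion.

```python
from dataclasses import dataclass, field


@dataclass
class BSym:
    """Hypothetical stand-in for a bound symbol (not Thunder's BoundSymbol)."""
    op: str
    outputs: tuple[str, ...]
    inputs: tuple[str, ...]
    subsymbols: list["BSym"] = field(default_factory=list)


def prune(bsym: BSym) -> BSym:
    """Return a copy of bsym keeping only subsymbols that feed bsym.outputs."""
    live = set(bsym.outputs)
    kept = []
    # Walk backwards: keep a subsymbol only if a later kept subsymbol
    # (or the parent symbol itself) still needs one of its outputs.
    for sub in reversed(bsym.subsymbols):
        if live.intersection(sub.outputs):
            kept.append(prune(sub))  # apply the same rule one level down
            live.difference_update(sub.outputs)
            live.update(sub.inputs)
    kept.reverse()
    return BSym(bsym.op, bsym.outputs, bsym.inputs, kept)


# Invented two-level nesting loosely modeled on the ltorch.add trace above.
add_bsym = BSym("ltorch.add", ("t25",), ("a", "b"), [
    BSym("prims.shape", ("i0", "i1"), ("a",)),
    BSym("prims.shape", ("i2", "i3"), ("b",)),
    BSym("prims.eq", ("i4",), ("i1",)),
    BSym("broadcast_helper", ("t22",), ("a", "i2", "i1"), [  # hypothetical
        BSym("prims.ne", ("i16",), ("i2",)),
        BSym("prims.broadcast_in_dim", ("t22",), ("a", "i2", "i1")),
    ]),
    BSym("prims.add", ("t25",), ("t22", "b")),
])

pruned = prune(add_bsym)
for sub in pruned.subsymbols:
    print(sub.op, sub.outputs, [s.op for s in sub.subsymbols])
```

Returning a new `BSym` instead of mutating in place keeps the original trace available for comparison, much as `dce` returns a trace rather than editing its argument.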

The trace should simply be:

def computation(a, b):
  # a: "cuda:0 f32[[IntegerProxy name=i0, value=1, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  # b: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i3, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"

  # /tmp/ipython-input-2728223152.py:9: 	    return a + b
  t25 = ltorch.add(a, b, alpha=1)  # t25: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
    # (i0, i1) = prims.shape(a)
    # (i2, i3) = prims.shape(b)
    # t22 = prims.broadcast_in_dim(a, (i2, i1), (0, 1))  # t22: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
    # t25 = prims.add(t22, b)  # t25: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  return (t25,)