🐛 Bug
Applying the `"symbolic values"` cache, which treats tensor shapes symbolically, explodes the subsymbol sections of Thunder traces with many calls that do not contribute to the result of the larger symbols they belong to. These dead subsymbols should be removed.
To Reproduce
```python
import torch, thunder
from functools import partial

a = torch.randn(1, 1024, device="cuda", dtype=torch.float32)
b = torch.randn(2, 1024, device="cuda", dtype=torch.float32)

@partial(thunder.jit, cache="symbolic values")
def f(a, b):
    return a + b

f(a, b)
print(f._lc_cs.last_traces[0])
```

Printed trace:
```python
def computation(a, b):
  # a: "cuda:0 f32[[IntegerProxy name=i0, value=1, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  # b: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i3, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  # /tmp/ipython-input-2728223152.py:9: return a + b
  t25 = ltorch.add(a, b, alpha=1) # t25: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
    # (i0, i1) = prims.shape(a)
    # (i0, i1) = prims.shape(a)
    # (i2, i3) = prims.shape(b)
    # (i2, i3) = prims.shape(b)
    # i4 = prims.eq(i1, 1) # i4: "bool False"
    # i5 = prims.eq(i1, i1) # i5: "bool True"
    # i6 = prims.eq(i0, 1) # i6: "bool True"
    # i7 = prims.eq(i1, 1) # i7: "bool False"
    # i8 = prims.eq(i3, 1) # i8: "bool False"
    # i9 = prims.eq(i1, i3) # i9: "bool True"
    # i10 = prims.eq(i0, 1) # i10: "bool True"
    # i11 = prims.eq(i2, 1) # i11: "bool False"
    # i12 = prims.eq(i2, i2) # i12: "bool True"
    # (i0, i1) = prims.shape(a)
    # (i0, i1) = prims.shape(a)
    # i13 = prims.eq(i0, i2) # i13: "bool False"
    # (i0, i1) = prims.shape(a)
    # (i0, i1) = prims.shape(a)
    # (i0, i1) = prims.shape(a)
    # i14 = prims.eq(i2, i0) # i14: "bool False"
    # i15 = prims.eq(i0, 1) # i15: "bool True"
    # i16 = prims.ne(i2, -1) # i16: "bool True"
    # i17 = prims.eq(i1, i1) # i17: "bool True"
    # i18 = prims.ne(i1, -1) # i18: "bool True"
    # (i0, i1) = prims.shape(a)
    # t22 = prims.broadcast_in_dim(a, (i2, i1), (0, 1)) # t22: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
    # (i2, i3) = prims.shape(b)
    # (i2, i3) = prims.shape(b)
    # i23 = prims.eq(i3, i1) # i23: "bool True"
    # t25 = prims.add(t22, b) # t25: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  return (t25,)
```

DCE does not remove them either, leaving the trace unmodified:
```python
from thunder.core.transform_common import dce

print(dce(f._lc_cs.last_traces[0]))
```

The trace should simply be:
```python
def computation(a, b):
  # a: "cuda:0 f32[[IntegerProxy name=i0, value=1, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  # b: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i3, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  # /tmp/ipython-input-2728223152.py:9: return a + b
  t25 = ltorch.add(a, b, alpha=1) # t25: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
    # (i0, i1) = prims.shape(a)
    # (i2, i3) = prims.shape(b)
    # t22 = prims.broadcast_in_dim(a, (i2, i1), (0, 1)) # t22: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
    # t25 = prims.add(t22, b) # t25: "cuda:0 f32[[IntegerProxy name=i2, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=1024, static=CONSTRAINT.CONSTRAINABLE]]"
  return (t25,)
```
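For reference, a plain backward liveness sweep over the subsymbols would already produce the expected trace above. The sketch below is generic Python over hand-written `(outputs, inputs, text)` triples mirroring a subset of the subsymbols of `ltorch.add` from the first trace; it is not thunder's actual DCE, and removing a kept statement's outputs from the live set (which is what eliminates the duplicates) is only sound because repeated definitions such as `prims.shape(a)` always produce the same values.

```python
# Sketch only: a backward liveness sweep over (outputs, inputs, text) triples.
# A statement is kept iff it defines a still-live proxy; kept statements make
# their inputs live. Dropping a kept statement's outputs from the live set
# also removes earlier duplicate definitions of the same proxies.
def dce_sweep(stmts, live_outputs):
    live = set(live_outputs)
    kept = []
    for outs, ins, text in reversed(stmts):
        if live & outs:
            kept.append(text)
            live -= outs
            live |= ins
    return list(reversed(kept))

# A subset of the subsymbols of ltorch.add from the trace above:
stmts = [
    ({"i0", "i1"}, {"a"}, "(i0, i1) = prims.shape(a)"),
    ({"i0", "i1"}, {"a"}, "(i0, i1) = prims.shape(a)"),  # duplicate
    ({"i4"}, {"i1"}, "i4 = prims.eq(i1, 1)"),            # result never used
    ({"i0", "i1"}, {"a"}, "(i0, i1) = prims.shape(a)"),  # duplicate
    ({"i2", "i3"}, {"b"}, "(i2, i3) = prims.shape(b)"),
    ({"i8"}, {"i3"}, "i8 = prims.eq(i3, 1)"),            # result never used
    ({"t22"}, {"a", "i2", "i1"}, "t22 = prims.broadcast_in_dim(a, (i2, i1), (0, 1))"),
    ({"t25"}, {"t22", "b"}, "t25 = prims.add(t22, b)"),
]

for line in dce_sweep(stmts, {"t25"}):
    print(line)
```

Running this prints exactly the four subsymbols of the expected trace, in order.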