Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ ignore = [
"E402", # https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file/
"F405", # https://docs.astral.sh/ruff/rules/undefined-local-with-import-star-usage/
"E712", # https://docs.astral.sh/ruff/rules/true-false-comparison/
"E721", # https://docs.astral.sh/ruff/rules/type-comparison/
"E722", # https://docs.astral.sh/ruff/rules/bare-except/
]

Expand Down
18 changes: 9 additions & 9 deletions thunder/benchmarks/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,16 +251,16 @@ def __init__(self, config: InferenceBenchmarkConfig):

# Sanity check
if not self.config.disable_moe_replacement:
assert type(model.model.layers[1].feed_forward.shared_experts.gate_proj.weight) == DTensor
assert type(model.model.layers[1].feed_forward.shared_experts.up_proj.weight) == DTensor
assert type(model.model.layers[1].feed_forward.shared_experts.down_proj.weight) == DTensor
assert type(model.model.layers[1].feed_forward.routed_experts.gate_proj.weight) == DTensor
assert type(model.model.layers[1].feed_forward.routed_experts.up_proj.weight) == DTensor
assert type(model.model.layers[1].feed_forward.routed_experts.down_proj.weight) == DTensor
assert isinstance(model.model.layers[1].feed_forward.shared_experts.gate_proj.weight, DTensor)
assert isinstance(model.model.layers[1].feed_forward.shared_experts.up_proj.weight, DTensor)
assert isinstance(model.model.layers[1].feed_forward.shared_experts.down_proj.weight, DTensor)
assert isinstance(model.model.layers[1].feed_forward.routed_experts.gate_proj.weight, DTensor)
assert isinstance(model.model.layers[1].feed_forward.routed_experts.up_proj.weight, DTensor)
assert isinstance(model.model.layers[1].feed_forward.routed_experts.down_proj.weight, DTensor)
else:
assert type(model.model.layers[1].feed_forward.shared_expert.gate_proj.weight) == DTensor
assert type(model.model.layers[1].feed_forward.shared_expert.up_proj.weight) == DTensor
assert type(model.model.layers[1].feed_forward.shared_expert.down_proj.weight) == DTensor
assert isinstance(model.model.layers[1].feed_forward.shared_expert.gate_proj.weight, DTensor)
assert isinstance(model.model.layers[1].feed_forward.shared_expert.up_proj.weight, DTensor)
assert isinstance(model.model.layers[1].feed_forward.shared_expert.down_proj.weight, DTensor)

# Materialize the model on the device (after Llama4MoE replacement and sharding)
model.to_empty(device=DEVICE)
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def is_inexact_dtype(dtype):
# TODO: we could consider a more general notion of number defined by issubclass(typ, Number)
def is_numbertype(x):
# Note: the first argument to issubclass must be a type
if not type(x) == type:
if type(x) is not type:
return False
return issubclass(x, tuple(all_numbertypes))

Expand Down
10 changes: 5 additions & 5 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,7 @@ def lookup_descriptor_field(field_name):
# been manually assigned) Python appears to reinterpret it as a simple dict for the purpose of
# attribute resolution.
# we avoid interpreting into dict.get if obj_dict is a plain dict to avoid creating a wrapper for it.
if type(uobj_dict) == dict:
if type(uobj_dict) is dict:
instance_value = uobj_dict.get(name, null)
else:
instance_value = _interpret_call_with_unwrapping(dict.get, obj_dict, name, null)
Expand Down Expand Up @@ -5254,13 +5254,13 @@ def _make_function_handler(

if flag & 0x02:
kwdefaults = stack.pop()
assert type(kwdefaults) == dict
assert type(kwdefaults) is dict
else:
kwdefaults = None

if flag & 0x01:
argdefs = stack.pop()
assert type(argdefs) == tuple
assert type(argdefs) is tuple
else:
argdefs = None

Expand Down Expand Up @@ -5348,14 +5348,14 @@ def _set_function_attribute_handler(

if flag == 0x02:
kwdefaults = stack.pop()
assert type(kwdefaults) == dict
assert type(kwdefaults) is dict
fn.__kwdefaults__ = kwdefaults
stack.append(fn)
return

if flag == 0x01:
argdefs = stack.pop()
assert type(argdefs) == tuple
assert type(argdefs) is tuple
fn.__defaults__ = argdefs
stack.append(fn)
return
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def replace(self, **changes):
r"""Return a copy of the Proxy object with new values for the specified fields as given to the constructor as arguments.
Valid keyword arguments are ``name``, ``history``.
Note that the copy will use the current (environment) tracectx."""
if type(self) != Proxy:
if type(self) is not Proxy:
raise NotImplementedError(f"replace is not implemented for {type(self)}")
kwargs = dict(
name=self.name,
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,7 +1555,7 @@ def python_callable(*args, **kwargs):
gradtrc = dce(gradtrc)
grad_output = gradtrc.output
pro_to_epi = prologue_trc.output[1]
if type(grad_output) == dict:
if type(grad_output) is dict:
grad_output = grad_output["output"]

def new_epilogue(*args):
Expand Down
2 changes: 1 addition & 1 deletion thunder/executors/pythonex.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _check_instance_impl(x: Any, types: tuple[type], /) -> None:

def _check_number_type_and_value_impl(n: Number, v: Number) -> None:
utils.check(
type(n) == type(v) and (n == v or (n != n and v != v)),
type(n) is type(v) and (n == v or (n != n and v != v)),
lambda: f"Expected {n} to be equal to and have the type of {v}",
)

Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,8 @@ def raise_from_external():
jit(raise_from_external)()

e = excinfo.value
assert type(e) == ValueError
assert type(e.__cause__) == IndexError and msg in str(e.__cause__), excinfo.value
assert type(e) is ValueError
assert type(e.__cause__) is IndexError and msg in str(e.__cause__), excinfo.value


def test_nested_try_except(jit):
Expand Down
Loading