diff --git a/pyproject.toml b/pyproject.toml index aa22f1b1b9..6514bbd093 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,7 +159,6 @@ ignore = [ "E402", # https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file/ "F405", # https://docs.astral.sh/ruff/rules/undefined-local-with-import-star-usage/ "E712", # https://docs.astral.sh/ruff/rules/true-false-comparison/ - "E721", # https://docs.astral.sh/ruff/rules/type-comparison/ "E722", # https://docs.astral.sh/ruff/rules/bare-except/ ] diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py index 4639c5df4e..efc0adc661 100644 --- a/thunder/benchmarks/benchmark_inference.py +++ b/thunder/benchmarks/benchmark_inference.py @@ -251,16 +251,16 @@ def __init__(self, config: InferenceBenchmarkConfig): # Sanity check if not self.config.disable_moe_replacement: - assert type(model.model.layers[1].feed_forward.shared_experts.gate_proj.weight) == DTensor - assert type(model.model.layers[1].feed_forward.shared_experts.up_proj.weight) == DTensor - assert type(model.model.layers[1].feed_forward.shared_experts.down_proj.weight) == DTensor - assert type(model.model.layers[1].feed_forward.routed_experts.gate_proj.weight) == DTensor - assert type(model.model.layers[1].feed_forward.routed_experts.up_proj.weight) == DTensor - assert type(model.model.layers[1].feed_forward.routed_experts.down_proj.weight) == DTensor + assert isinstance(model.model.layers[1].feed_forward.shared_experts.gate_proj.weight, DTensor) + assert isinstance(model.model.layers[1].feed_forward.shared_experts.up_proj.weight, DTensor) + assert isinstance(model.model.layers[1].feed_forward.shared_experts.down_proj.weight, DTensor) + assert isinstance(model.model.layers[1].feed_forward.routed_experts.gate_proj.weight, DTensor) + assert isinstance(model.model.layers[1].feed_forward.routed_experts.up_proj.weight, DTensor) + assert isinstance(model.model.layers[1].feed_forward.routed_experts.down_proj.weight, DTensor) else: - assert type(model.model.layers[1].feed_forward.shared_expert.gate_proj.weight) == DTensor - assert type(model.model.layers[1].feed_forward.shared_expert.up_proj.weight) == DTensor - assert type(model.model.layers[1].feed_forward.shared_expert.down_proj.weight) == DTensor + assert isinstance(model.model.layers[1].feed_forward.shared_expert.gate_proj.weight, DTensor) + assert isinstance(model.model.layers[1].feed_forward.shared_expert.up_proj.weight, DTensor) + assert isinstance(model.model.layers[1].feed_forward.shared_expert.down_proj.weight, DTensor) # Materialize the model on the device (after Llama4MoE replacement and sharding) model.to_empty(device=DEVICE) diff --git a/thunder/core/dtypes.py b/thunder/core/dtypes.py index a8eecbb401..a9d9de3c4a 100644 --- a/thunder/core/dtypes.py +++ b/thunder/core/dtypes.py @@ -462,7 +462,7 @@ def is_inexact_dtype(dtype): # TODO: we could consider a more general notion of number defined by issubclass(typ, Number) def is_numbertype(x): # Note: the first argument to issubclass must be a type - if not type(x) == type: + if type(x) is not type: return False return issubclass(x, tuple(all_numbertypes)) diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index a2e9e5a59a..938c4aef9e 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -1653,7 +1653,7 @@ def lookup_descriptor_field(field_name): # been manually assigned) Python appears to reinterpret it as a simple dict for the purpose of # attribute resolution. # we avoid interpreting into dict.get if obj_dict is a plain dict to avoid creating a wrapper for it. - if type(uobj_dict) == dict: + if type(uobj_dict) is dict: instance_value = uobj_dict.get(name, null) else: instance_value = _interpret_call_with_unwrapping(dict.get, obj_dict, name, null) @@ -5254,13 +5254,13 @@ def _make_function_handler( if flag & 0x02: kwdefaults = stack.pop() - assert type(kwdefaults) == dict + assert type(kwdefaults) is dict else: kwdefaults = None if flag & 0x01: argdefs = stack.pop() - assert type(argdefs) == tuple + assert type(argdefs) is tuple else: argdefs = None @@ -5348,14 +5348,14 @@ def _set_function_attribute_handler( if flag == 0x02: kwdefaults = stack.pop() - assert type(kwdefaults) == dict + assert type(kwdefaults) is dict fn.__kwdefaults__ = kwdefaults stack.append(fn) return if flag == 0x01: argdefs = stack.pop() - assert type(argdefs) == tuple + assert type(argdefs) is tuple fn.__defaults__ = argdefs stack.append(fn) return diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index b19e0ea1fc..6dcfc5cd29 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -127,7 +127,7 @@ def replace(self, **changes): r"""Return a copy of the Proxy object with new values for the specified fields as given to the constructor as arguments. Valid keyword arguments are ``name``, ``history``. Note that the copy will use the current (environment) tracectx.""" - if type(self) != Proxy: + if type(self) is not Proxy: raise NotImplementedError(f"replace is not implemented for {type(self)}") kwargs = dict( name=self.name, diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index cdfd2e9ab3..66e8dda3ae 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1555,7 +1555,7 @@ def python_callable(*args, **kwargs): gradtrc = dce(gradtrc) grad_output = gradtrc.output pro_to_epi = prologue_trc.output[1] - if type(grad_output) == dict: + if type(grad_output) is dict: grad_output = grad_output["output"] def new_epilogue(*args): diff --git a/thunder/executors/pythonex.py b/thunder/executors/pythonex.py index 0083853f59..4467aa5474 100644 --- a/thunder/executors/pythonex.py +++ b/thunder/executors/pythonex.py @@ -97,7 +97,7 @@ def _check_instance_impl(x: Any, types: tuple[type], /) -> None: def _check_number_type_and_value_impl(n: Number, v: Number) -> None: utils.check( - type(n) == type(v) and (n == v or (n != n and v != v)), + type(n) is type(v) and (n == v or (n != n and v != v)), lambda: f"Expected {n} to be equal to and have the type of {v}", ) diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index d218a8c654..0fe1cd7bed 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -722,8 +722,8 @@ def raise_from_external(): jit(raise_from_external)() e = excinfo.value - assert type(e) == ValueError - assert type(e.__cause__) == IndexError and msg in str(e.__cause__), excinfo.value + assert type(e) is ValueError + assert type(e.__cause__) is IndexError and msg in str(e.__cause__), excinfo.value def test_nested_try_except(jit):