diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index 2f9209445d9..d6cdfacb612 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -54,6 +54,9 @@ def __init__(self, exported_program): exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.pow.Tensor_Tensor, exir_ops.edge.aten.where.self, + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, ] def _match_op_rank(self, graph_module, node, arg, max_rank): diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index 185ce941247..249eb9ffd41 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -34,6 +34,9 @@ exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.le.Scalar: exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.ne.Scalar: exir_ops.edge.aten.ne.Tensor, + exir_ops.edge.aten.bitwise_and.Scalar: exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Scalar: exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.bitwise_xor.Tensor, torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor, torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor, torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor, @@ -46,6 +49,9 @@ torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor, torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor, torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor, + torch.ops.aten.bitwise_and.Scalar: torch.ops.aten.bitwise_and.Tensor, + torch.ops.aten.bitwise_or.Scalar: torch.ops.aten.bitwise_or.Tensor, + torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor, } diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index 1ae8ab1ace2..a1b5de85d08 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -125,6 +125,9 @@ class EthosU55NotSupported(OperatorSupportBase): exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_or.Tensor, exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.bitwise_and.Scalar, + exir_ops.edge.aten.bitwise_or.Scalar, + exir_ops.edge.aten.bitwise_xor.Scalar, exir_ops.edge.aten.bitwise_not, exir_ops.edge.aten.logical_and.default, exir_ops.edge.aten.logical_or.default, diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 0a0430b7906..29ef36aa658 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -164,6 +164,9 @@ def is_node_supported( exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_or.Tensor, exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.bitwise_and.Scalar, + exir_ops.edge.aten.bitwise_or.Scalar, + exir_ops.edge.aten.bitwise_xor.Scalar, exir_ops.edge.aten.expand_copy.default, exir_ops.edge.aten.cat.default, exir_ops.edge.aten.ceil.default, diff --git a/backends/arm/test/ops/test_bitwise.py b/backends/arm/test/ops/test_bitwise.py index 032639b8607..d29ea7c91f2 100644 --- a/backends/arm/test/ops/test_bitwise.py +++ b/backends/arm/test/ops/test_bitwise.py @@ -56,6 +56,27 @@ class BitwiseBinary(torch.nn.Module): } +class BitwiseBinaryScalar(torch.nn.Module): + test_data = { + "zeros": lambda: (torch.zeros(1, 10, 10, 10, dtype=torch.int32), 0), + "ones_int8": lambda: (torch.ones(10, 10, 10, dtype=torch.int8), 1), + "pattern_int8": lambda: (0xAA * torch.ones(1, 2, 2, 2, dtype=torch.int8), 0x77), + "pattern_int16": lambda: ( + 0xAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int16), + 0x7777, + ), + "pattern_int32": lambda: ( + 0xAAAAAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int32), + 0x77777777, + ), + "rand_rank2": lambda: (torch.randint(-128, 127, (10, 10), dtype=torch.int8), 5), + "rand_rank4": lambda: ( + torch.randint(-128, 127, (1, 10, 10, 10), dtype=torch.int8), + -7, + ), + } + + class And(BitwiseBinary): aten_op = "torch.ops.aten.bitwise_and.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor" @@ -80,6 +101,36 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor): return tensor1.bitwise_or(tensor2) +class AndScalar(BitwiseBinaryScalar): + aten_op = "torch.ops.aten.bitwise_and.Scalar" + # Tensor because it gets converted from Scalar -> Tensor in lowering + exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor" + + def forward(self, tensor: torch.Tensor, scalar: int): + return tensor.bitwise_and(scalar) + + +class XorScalar(BitwiseBinaryScalar): + aten_op = "torch.ops.aten.bitwise_xor.Scalar" + # Tensor because it gets converted from Scalar -> Tensor in lowering + exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_xor_Tensor" + + def forward(self, tensor: torch.Tensor, scalar: int): + return tensor.bitwise_xor(scalar) + + +class OrScalar(BitwiseBinaryScalar): + aten_op = "torch.ops.aten.bitwise_or.Scalar" + # Tensor because it gets converted from Scalar -> Tensor in lowering + exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_or_Tensor" + + def forward(self, tensor: torch.Tensor, scalar: int): + return tensor.bitwise_or(scalar) + + +# Bitwise AND + + @common.parametrize("test_data", And().test_data) def test_bitwise_and_tensor_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( @@ -94,6 +145,20 @@ def test_bitwise_and_tensor_tosa_MI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", AndScalar.test_data) +def test_bitwise_and_scalar_tosa_MI(test_data: input_t2): + pipeline = TosaPipelineMI[input_t2]( + AndScalar(), + test_data(), + AndScalar.aten_op, + AndScalar.exir_op, + atol=0, + rtol=0, + qtol=0, + ) + pipeline.run() + + @common.parametrize("test_data", And().test_data) def test_bitwise_and_tensor_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( @@ -110,6 +175,22 @@ def test_bitwise_and_tensor_tosa_BI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", AndScalar.test_data) +def test_bitwise_and_scalar_tosa_BI(test_data: input_t2): + pipeline = TosaPipelineBI[input_t2]( + AndScalar(), + test_data(), + AndScalar.aten_op, + AndScalar.exir_op, + atol=0, + rtol=0, + qtol=0, + ) + pipeline.pop_stage("quantize") + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + @common.parametrize("test_data", And().test_data) def test_bitwise_and_tensor_u55_BI(test_data: input_t2): # Tests that we don't delegate these ops since they are not supported on U55. @@ -123,6 +204,43 @@ def test_bitwise_and_tensor_u55_BI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", AndScalar.test_data) +def test_bitwise_and_scalar_u55_BI(test_data: input_t2): + # There will be one full op which will be delegated. + num_delegates = 1 + num_exir = 0 + pipeline = OpNotSupportedPipeline[input_t2]( + AndScalar(), + test_data(), + { + AndScalar.exir_op: 1, + "executorch_exir_dialects_edge__ops_aten_full_default": num_exir, + }, + num_delegates, + quantize=True, + u55_subset=True, + ) + pipeline.run() + + +@common.parametrize("test_data", AndScalar.test_data) +@common.XfailIfNoCorstone320 +def test_bitwise_and_scalar_u85_BI(test_data: input_t2): + pipeline = EthosU85PipelineBI[input_t2]( + AndScalar(), + test_data(), + AndScalar.aten_op, + AndScalar.exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, + ) + pipeline.pop_stage("quantize") + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + @common.parametrize("test_data", And().test_data) @common.XfailIfNoCorstone320 def test_bitwise_and_tensor_u85_BI(test_data: input_t2): @@ -155,6 +273,20 @@ def test_bitwise_xor_tensor_tosa_MI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", XorScalar.test_data) +def test_bitwise_xor_scalar_tosa_MI(test_data: input_t2): + pipeline = TosaPipelineMI[input_t2]( + XorScalar(), + test_data(), + XorScalar.aten_op, + XorScalar.exir_op, + atol=0, + rtol=0, + qtol=0, + ) + pipeline.run() + + @common.parametrize("test_data", Xor().test_data) def test_bitwise_xor_tensor_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( @@ -171,6 +303,22 @@ def test_bitwise_xor_tensor_tosa_BI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", XorScalar.test_data) +def test_bitwise_xor_scalar_tosa_BI(test_data: input_t2): + pipeline = TosaPipelineBI[input_t2]( + XorScalar(), + test_data(), + XorScalar.aten_op, + XorScalar.exir_op, + atol=0, + rtol=0, + qtol=0, + ) + pipeline.pop_stage("quantize") + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + @common.parametrize("test_data", Xor().test_data) def test_bitwise_xor_tensor_u55_BI(test_data: input_t2): # Tests that we don't delegate these ops since they are not supported on U55. @@ -184,6 +332,25 @@ def test_bitwise_xor_tensor_u55_BI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", XorScalar.test_data) +def test_bitwise_xor_scalar_u55_BI(test_data: input_t2): + # There will be one full op which will be delegated. + num_delegates = 1 + num_exir = 0 + pipeline = OpNotSupportedPipeline[input_t2]( + XorScalar(), + test_data(), + { + XorScalar.exir_op: 1, + "executorch_exir_dialects_edge__ops_aten_full_default": num_exir, + }, + num_delegates, + quantize=True, + u55_subset=True, + ) + pipeline.run() + + @common.parametrize("test_data", Xor().test_data) @common.XfailIfNoCorstone320 def test_bitwise_xor_tensor_u85_BI(test_data: input_t2): @@ -202,6 +369,24 @@ def test_bitwise_xor_tensor_u85_BI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", XorScalar.test_data) +@common.XfailIfNoCorstone320 +def test_bitwise_xor_scalar_u85_BI(test_data: input_t2): + pipeline = EthosU85PipelineBI[input_t2]( + XorScalar(), + test_data(), + XorScalar.aten_op, + XorScalar.exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, + ) + pipeline.pop_stage("quantize") + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + @common.parametrize("test_data", Or().test_data) def test_bitwise_or_tensor_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( @@ -216,6 +401,20 @@ def test_bitwise_or_tensor_tosa_MI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", OrScalar.test_data) +def test_bitwise_or_scalar_tosa_MI(test_data: input_t2): + pipeline = TosaPipelineMI[input_t2]( + OrScalar(), + test_data(), + OrScalar.aten_op, + OrScalar.exir_op, + atol=0, + rtol=0, + qtol=0, + ) + pipeline.run() + + @common.parametrize("test_data", Or().test_data) def test_bitwise_or_tensor_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( @@ -232,6 +431,22 @@ def test_bitwise_or_tensor_tosa_BI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", OrScalar.test_data) +def test_bitwise_or_scalar_tosa_BI(test_data: input_t2): + pipeline = TosaPipelineBI[input_t2]( + OrScalar(), + test_data(), + OrScalar.aten_op, + OrScalar.exir_op, + atol=0, + rtol=0, + qtol=0, + ) + pipeline.pop_stage("quantize") + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + @common.parametrize("test_data", Or().test_data) def test_bitwise_or_tensor_u55_BI(test_data: input_t2): # Tests that we don't delegate these ops since they are not supported on U55. @@ -245,6 +460,25 @@ def test_bitwise_or_tensor_u55_BI(test_data: input_t2): pipeline.run() +@common.parametrize("test_data", OrScalar.test_data) +def test_bitwise_or_scalar_u55_BI(test_data: input_t2): + # There will be one full op which will be delegated. + num_delegates = 1 + num_exir = 0 + pipeline = OpNotSupportedPipeline[input_t2]( + OrScalar(), + test_data(), + { + OrScalar.exir_op: 1, + "executorch_exir_dialects_edge__ops_aten_full_default": num_exir, + }, + num_delegates, + quantize=True, + u55_subset=True, + ) + pipeline.run() + + @common.parametrize("test_data", Or().test_data) @common.XfailIfNoCorstone320 def test_bitwise_or_tensor_u85_BI(test_data: input_t2): @@ -261,3 +495,21 @@ def test_bitwise_or_tensor_u85_BI(test_data: input_t2): pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") pipeline.run() + + +@common.parametrize("test_data", OrScalar.test_data) +@common.XfailIfNoCorstone320 +def test_bitwise_or_scalar_u85_BI(test_data: input_t2): + pipeline = EthosU85PipelineBI[input_t2]( + OrScalar(), + test_data(), + OrScalar.aten_op, + OrScalar.exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, + ) + pipeline.pop_stage("quantize") + pipeline.pop_stage("check.quant_nodes") + pipeline.run()