51 changes: 51 additions & 0 deletions src/frontends/pytorch/src/op/quantized_linear_relu.cpp
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/relu.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/add.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_quantized_linear_relu(const NodeContext& context) {
    // quantized::linear_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack,
    //                        float Y_scale_i, int Y_zero_point_i) -> Tensor Y
    // Adapted from quantized_linear.cpp; the only difference is the Relu applied before requantization.
    num_inputs_check(context, 4, 4);
    auto x = context.get_input(0);
    auto packed_params_node = ov::as_type_ptr<ov::op::util::FrameworkNode>(context.get_input(1).get_node_shared_ptr());
    PYTORCH_OP_CONVERSION_CHECK(packed_params_node, "Packed params input node type is required to be FrameworkNode.");
    const auto& attrs = packed_params_node->get_attrs();
    PYTORCH_OP_CONVERSION_CHECK((attrs.find(PtFrameworkNode::op_type_key) != attrs.end()),
                                "Packed params input node does not contain information about op type.");
    PYTORCH_OP_CONVERSION_CHECK((attrs.at(PtFrameworkNode::op_type_key) == "prim::GetAttr"),
                                "Incorrect packed params input node operator type, expected prim::GetAttr.");

    auto packed_params = packed_params_node->inputs();
    PYTORCH_OP_CONVERSION_CHECK(packed_params.size() == 2,
                                "Packed parameters for quantized linear should contain 2 items.");
    // Packed params: weight, bias
    auto weights = packed_params[0].get_source_output();
    auto bias = packed_params[1].get_source_output();

    // Linear with transposed weights, followed by bias add and Relu, computed in dequantized form.
    auto linear = context.mark_node(std::make_shared<v0::MatMul>(x, weights, false, true));
    linear = context.mark_node(std::make_shared<v1::Add>(linear, bias));
    auto relu = context.mark_node(std::make_shared<v0::Relu>(linear));

    // Requantize the result with the requested output scale and zero point.
    auto scale = context.get_input(2);
    auto zero_point = context.get_input(3);
    return {quantize(context, relu, scale, zero_point, x)};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
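
For reference, the subgraph this converter emits computes `Relu(X @ W^T + b)` in float and then requantizes with the requested output scale and zero point. A minimal numeric sketch of that semantics, assuming per-tensor affine quint8 quantization (the reference function below is illustrative only, not part of the PR):

```python
import numpy as np

def quantized_linear_relu_ref(x, w, b, scale, zero_point):
    # MatMul with transposed weights + bias add + Relu, in float.
    y = np.maximum(x @ w.T + b, 0.0)
    # Requantize with the output scale/zero point, clamping to the quint8 range.
    q = np.clip(np.round(y / scale) + zero_point, 0, 255)
    # Dequantized view of the quint8 result, i.e. what torch.dequantize returns.
    return (q - zero_point) * scale
```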
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -289,6 +289,7 @@ OP_CONVERTER(translate_quantized_cat);
OP_CONVERTER(translate_quantized_convnd);
OP_CONVERTER(translate_quantized_convnd_relu);
OP_CONVERTER(translate_quantized_linear);
OP_CONVERTER(translate_quantized_linear_relu);
OP_CONVERTER(translate_xor);
// Torch FX Translations
OP_CONVERTER(translate_adaptive_max_pool1d_fx);
@@ -811,6 +812,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"quantized::conv2d_relu", op::translate_quantized_convnd_relu},
{"quantized::hardswish", op::translate_quantized_hardswish},
{"quantized::linear", op::translate_quantized_linear},
{"quantized::linear_relu", op::translate_quantized_linear_relu},
{"quantized::mul", op::translate_quantized_mul},
{"torchvision::deform_conv2d", op::translate_deform_conv},
{"torchvision::nms", op::translate_nms},
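
The new table entry keys on the TorchScript op name. A quick way to confirm where that name comes from, assuming a recent PyTorch with the `torch.ao` quantized modules (snippet is illustrative, not part of the PR):

```python
import torch

# Tracing a fused quantized LinearReLU module with a quantized input
# produces a quantized::linear_relu node in the TorchScript graph.
m = torch.ao.nn.intrinsic.quantized.LinearReLU(9, 10)
x = torch.quantize_per_tensor(torch.rand(3, 9), 1.0, 0, torch.quint8)
traced = torch.jit.trace(m, x)
assert "quantized::linear_relu" in str(traced.graph)
```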
64 changes: 64 additions & 0 deletions tests/layer_tests/pytorch_tests/test_quantized_linear_relu.py
@@ -0,0 +1,64 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import numpy as np
from pytorch_layer_test_class import PytorchLayerTest


class QuantizedLinearReLU(torch.nn.Module):
    def __init__(self, weight_shape, bias, scale, zero_point):
        super().__init__()
        # Fused quantized Linear + ReLU module; scale/zero_point configure its
        # output quantization parameters.
        self.linear_relu = torch.ao.nn.intrinsic.quantized.LinearReLU(
            weight_shape[-1], weight_shape[0], bias
        )
        if bias:
            torch.nn.init.normal_(self.linear_relu.bias())
        self.linear_relu.scale = float(scale)
        self.linear_relu.zero_point = int(zero_point)

    def forward(self, inp):
        # Quantize the float input, run the fused op, and return a float result.
        inp_q = torch.quantize_per_tensor(inp, self.linear_relu.scale, self.linear_relu.zero_point, torch.quint8)
        return torch.dequantize(self.linear_relu(inp_q))


class TestQuantizedLinearReLU(PytorchLayerTest):
rng = np.random.default_rng(seed=123)

def _prepare_input(self, input_shape=(2, 2)):
return (np.round(self.rng.random(input_shape, dtype=np.float32), 4),)

@pytest.mark.parametrize("params", [
{'input_shape': [3, 9], 'weight_shape': [10, 9]},

{'input_shape': [2, 3, 9], 'weight_shape': [10, 9]},
{'input_shape': [3, 9], 'weight_shape': [9], "bias": True},
{'input_shape': [2, 3, 9], 'weight_shape': [10, 9], "bias": True},
])
@pytest.mark.parametrize("scale", [1., 0.3, 1.3])
@pytest.mark.parametrize("zero_point", [0, 1])
@pytest.mark.parametrize("trace", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_quantized_linear_relu(self, params, scale, zero_point, trace, ie_device, precision, ir_version):
input_shape = params.get("input_shape")
weight_shape = params.get("weight_shape")
bias = params.get("bias", False)

model = QuantizedLinearReLU(weight_shape, bias, scale, zero_point)
ref_net = None

self._test(
model,
ref_net,
["quantized::linear_relu"],
ie_device,
precision,
ir_version,
kwargs_to_prepare_input={"input_shape": input_shape},
trace_model=trace,
freeze_model=False,
quantized_ops=True,
quant_size=scale
)
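
Outside the layer-test harness, the same conversion path can be exercised directly. A hedged sketch, assuming an `openvino` Python package whose PyTorch frontend includes this PR (names and the exact conversion flow are illustrative):

```python
import openvino as ov
import torch

# Trace the float-in/float-out wrapper defined above and convert it;
# conversion only succeeds once quantized::linear_relu is a supported op.
model = QuantizedLinearReLU([10, 9], True, 0.3, 1)
traced = torch.jit.trace(model, torch.rand(3, 9))
ov_model = ov.convert_model(traced)
```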