131 changes: 131 additions & 0 deletions src/frontends/pytorch/src/op/rand.cpp
@@ -6,15 +6,30 @@
#include "openvino/frontend/common/random_normal_helper.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/cos.hpp"
#include "openvino/op/cum_sum.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/log.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/power.hpp"
#include "openvino/op/random_uniform.hpp"
#include "openvino/op/reduce_logical_or.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/subtract.hpp"
#include "pt_framework_node.hpp"
#include "transformations/rt_info/disable_fp16_compression.hpp"
#include "utils.hpp"
@@ -27,6 +42,8 @@ namespace op {
using namespace ov::op;

namespace {
constexpr int64_t standard_gamma_trials = 16;
constexpr float min_uniform_value = 1e-7f;
OutputVector make_random_normal(const NodeContext& context,
const Output<Node>& sizes,
element::Type target_type,
@@ -267,6 +284,120 @@ OutputVector translate_randint(const NodeContext& context) {
return {res};
};

OutputVector translate_standard_gamma(const NodeContext& context) {
// aten::_standard_gamma(Tensor self, *, Generator? generator=None) -> Tensor
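// Lowered with the Marsaglia-Tsang squeeze method: draw x ~ N(0, 1), form
// v = (1 + c * x)^3 with d = alpha - 1/3 and c = 1 / sqrt(9 * d), and accept
// the sample d * v when v > 0 and log(u) < 0.5 * x^2 + d - d * v + d * log(v).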
num_inputs_check(context, 1, 2);
if (context.get_input_size() == 2) {
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(1),
"aten::_standard_gamma conversion with generator is not supported");
}

auto input = context.get_input(0);
auto output_type = input.get_element_type();
auto concentration = input;
if (output_type != element::f32) {
concentration = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
}

auto shape_i32 = context.mark_node(std::make_shared<v3::ShapeOf>(concentration, element::i32));
auto shape = context.mark_node(std::make_shared<v0::Convert>(shape_i32, element::i64));
auto trials =
context.mark_node(v0::Constant::create(element::i64, Shape{1}, {standard_gamma_trials}));
auto expanded_shape =
context.mark_node(std::make_shared<v0::Concat>(OutputVector{trials, shape}, 0));
auto axis_zero_i64 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
auto axis_zero_i32 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));

auto zero = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0.f}));
auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1.f}));
auto half = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0.5f}));
auto third = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1.f / 3.f}));
auto nine = context.mark_node(v0::Constant::create(element::f32, Shape{}, {9.f}));
auto min_uniform =
context.mark_node(v0::Constant::create(element::f32, Shape{}, {min_uniform_value}));

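// Marsaglia-Tsang needs alpha >= 1, so concentrations below one are shifted
// to alpha + 1 here and rescaled by u^(1/alpha) at the end.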
auto lt_one_mask = context.mark_node(std::make_shared<v1::Less>(concentration, one));
auto conc_plus_one = context.mark_node(std::make_shared<v1::Add>(concentration, one));
auto conc_ge_one = context.mark_node(std::make_shared<v1::Select>(lt_one_mask, conc_plus_one, concentration));

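// Squeeze constants: d = alpha - 1/3 and c = 1 / sqrt(9 * d).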
auto d = context.mark_node(std::make_shared<v1::Subtract>(conc_ge_one, third));
auto nine_d = context.mark_node(std::make_shared<v1::Multiply>(d, nine));
auto sqrt_term = context.mark_node(std::make_shared<v0::Sqrt>(nine_d));
auto c = context.mark_node(std::make_shared<v1::Divide>(one, sqrt_term));

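// The graph cannot loop until acceptance, so the rejection sampler is
// unrolled: every candidate batch is drawn up front along a leading trials axis.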
auto scale = one;
auto mean = zero;
auto normals = make_random_normal(context, expanded_shape, element::f32, scale, mean)[0];
auto uniform_accept =
context.mark_node(std::make_shared<v8::RandomUniform>(expanded_shape, min_uniform, one, element::f32));

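// Broadcast the scalars and the per-element d and c to the [trials, *shape]
// layout so the masks and elementwise arithmetic below line up.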
auto zero_bc = context.mark_node(std::make_shared<v3::Broadcast>(zero, expanded_shape));
auto one_bc = context.mark_node(std::make_shared<v3::Broadcast>(one, expanded_shape));
auto min_uniform_bc =
context.mark_node(std::make_shared<v3::Broadcast>(min_uniform, expanded_shape));
auto d_bc = context.mark_node(std::make_shared<v3::Broadcast>(d, expanded_shape));
auto c_bc = context.mark_node(std::make_shared<v3::Broadcast>(c, expanded_shape));

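// v = (1 + c * x)^3; clamp it away from zero before taking the logarithm.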
auto cx = context.mark_node(std::make_shared<v1::Multiply>(c_bc, normals));
auto one_plus_cx = context.mark_node(std::make_shared<v1::Add>(one_bc, cx));
auto v = context.mark_node(std::make_shared<v1::Multiply>(one_plus_cx, one_plus_cx));
v = context.mark_node(std::make_shared<v1::Multiply>(v, one_plus_cx));
auto safe_v = context.mark_node(std::make_shared<v1::Maximum>(v, min_uniform_bc));

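// Squeeze acceptance test: accept when v > 0 and
// log(u) < 0.5 * x^2 + d - d * v + d * log(v).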
auto log_v = context.mark_node(std::make_shared<v0::Log>(safe_v));
auto log_u = context.mark_node(std::make_shared<v0::Log>(uniform_accept));
auto x_sq = context.mark_node(std::make_shared<v1::Multiply>(normals, normals));
auto x_sq_half = context.mark_node(std::make_shared<v1::Multiply>(x_sq, half));

auto d_times_v = context.mark_node(std::make_shared<v1::Multiply>(d_bc, v));
auto d_minus_dv = context.mark_node(std::make_shared<v1::Subtract>(d_bc, d_times_v));
auto d_log_v = context.mark_node(std::make_shared<v1::Multiply>(d_bc, log_v));
auto rhs = context.mark_node(std::make_shared<v1::Add>(x_sq_half, d_minus_dv));
rhs = context.mark_node(std::make_shared<v1::Add>(rhs, d_log_v));

auto positive_mask = context.mark_node(std::make_shared<v1::Greater>(v, zero_bc));
auto compare_mask = context.mark_node(std::make_shared<v1::Less>(log_u, rhs));
auto accept_mask = context.mark_node(std::make_shared<v1::LogicalAnd>(positive_mask, compare_mask));

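// Keep the first accepted candidate per element: the cumulative sum of the
// acceptance mask along the trials axis equals one exactly at the first hit.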
auto candidate = d_times_v;
auto accept_i32 = context.mark_node(std::make_shared<v0::Convert>(accept_mask, element::i32));
auto prefix = context.mark_node(std::make_shared<v0::CumSum>(accept_i32, axis_zero_i32, false, false));
auto one_i32 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto one_i32_bc = context.mark_node(std::make_shared<v3::Broadcast>(one_i32, expanded_shape));
auto first_accept = context.mark_node(
std::make_shared<v1::LogicalAnd>(accept_mask,
context.mark_node(std::make_shared<v1::Equal>(prefix, one_i32_bc))));

auto first_accept_f = context.mark_node(std::make_shared<v0::Convert>(first_accept, element::f32));
auto selected = context.mark_node(std::make_shared<v1::Multiply>(candidate, first_accept_f));
auto gamma_candidates =
context.mark_node(std::make_shared<v1::ReduceSum>(selected, axis_zero_i64, false));
auto any_accept =
context.mark_node(std::make_shared<v1::ReduceLogicalOr>(accept_mask, axis_zero_i64, false));

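// If no trial is accepted (vanishingly unlikely over 16 trials, since each
// typically accepts with probability above 0.95), fall back to the last candidate.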
auto last_index =
context.mark_node(v0::Constant::create(element::i64, Shape{}, {standard_gamma_trials - 1}));
auto last_candidate =
context.mark_node(std::make_shared<v8::Gather>(candidate, last_index, axis_zero_i64));
auto gamma_base =
context.mark_node(std::make_shared<v1::Select>(any_accept, gamma_candidates, last_candidate));

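// Boost for alpha < 1: Gamma(alpha) = Gamma(alpha + 1) * u^(1/alpha) with
// u ~ Uniform(0, 1); alpha is clamped before the reciprocal to avoid
// division by zero.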
auto frac_uniform =
context.mark_node(std::make_shared<v8::RandomUniform>(shape, min_uniform, one, element::f32));
auto safe_alpha = context.mark_node(std::make_shared<v1::Maximum>(concentration, min_uniform));
auto alpha_for_inv = context.mark_node(std::make_shared<v1::Select>(lt_one_mask, safe_alpha, one));
auto inv_alpha = context.mark_node(std::make_shared<v1::Divide>(one, alpha_for_inv));
auto pow_term = context.mark_node(std::make_shared<v1::Power>(frac_uniform, inv_alpha));
auto adjusted = context.mark_node(std::make_shared<v1::Multiply>(gamma_base, pow_term));
auto gamma_fp32 = context.mark_node(std::make_shared<v1::Select>(lt_one_mask, adjusted, gamma_base));

Output<Node> result = gamma_fp32;
if (output_type != element::f32) {
result = context.mark_node(std::make_shared<v1::ConvertLike>(result, input));
}
return {result};
};

OutputVector translate_normal_(const NodeContext& context) {
// aten::normal_(Tensor(a!) self, float mean=0., float std=1., *, Generator? generator=None) -> Tensor(a!)
num_inputs_check(context, 3, 4);
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -252,6 +252,7 @@ OP_CONVERTER(translate_split_with_sizes);
OP_CONVERTER(translate_square);
OP_CONVERTER(translate_squeeze);
OP_CONVERTER(translate_std);
OP_CONVERTER(translate_standard_gamma);
OP_CONVERTER(translate_std_mean);
OP_CONVERTER(translate_stft);
OP_CONVERTER(translate_sub);
@@ -385,6 +386,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::_pad_packed_sequence", op::translate_pad_packed_sequence},
{"aten::_set_item", op::translate_set_item},
{"aten::_shape_as_tensor", op::translate_shape_as_tensor},
{"aten::_standard_gamma", op::translate_standard_gamma},
{"aten::_unique2", op::translate_unique2},
{"aten::_upsample_bicubic2d_aa", op::translate_upsample_bicubic2d_aa},
{"aten::_upsample_bilinear2d_aa", op::translate_upsample_bilinear2d_aa},
107 changes: 107 additions & 0 deletions tests/layer_tests/pytorch_tests/test_standard_gamma.py
@@ -0,0 +1,107 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch
import openvino as ov


class TestStandardGammaStatistics:
class AtenStandardGamma(torch.nn.Module):
def forward(self, alpha):
return torch._standard_gamma(alpha)

def _run_gamma_stat_test(
self,
alpha_value,
shape,
mean_rtol,
mean_atol,
var_rtol,
var_atol,
ie_device,
precision,
):
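# Fixed seeds keep both frameworks' sample statistics reproducible.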
torch.manual_seed(0)
np.random.seed(0)

model = self.AtenStandardGamma()
alpha_np = np.full(shape, alpha_value, dtype=np.float32)
alpha_tensor = torch.from_numpy(alpha_np)

ov_model = ov.convert_model(input_model=model, example_input=(alpha_tensor,))
config = (
{"INFERENCE_PRECISION_HINT": "f32"}
if ie_device == "GPU" and precision == "FP32"
else {}
)
compiled_model = ov.Core().compile_model(ov_model, ie_device, config)

with torch.no_grad():
fw_samples = model(alpha_tensor).detach().cpu().numpy().reshape(-1)

infer_request = compiled_model.create_infer_request()
infer_request.infer({compiled_model.input(0): alpha_np})
ov_samples = infer_request.get_output_tensor(0).data.reshape(-1)

assert np.isfinite(fw_samples).all(), "PyTorch gamma samples contain non-finite values"
assert np.isfinite(ov_samples).all(), "OpenVINO gamma samples contain non-finite values"

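# A Gamma(alpha, 1) distribution has mean alpha and variance alpha.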
expected_mean = alpha_value
expected_var = alpha_value

np.testing.assert_allclose(
fw_samples.mean(),
expected_mean,
rtol=mean_rtol,
atol=mean_atol,
)
np.testing.assert_allclose(
fw_samples.var(),
expected_var,
rtol=var_rtol,
atol=var_atol,
)
np.testing.assert_allclose(
ov_samples.mean(),
expected_mean,
rtol=mean_rtol,
atol=mean_atol,
)
np.testing.assert_allclose(
ov_samples.var(),
expected_var,
rtol=var_rtol,
atol=var_atol,
)

@pytest.mark.precommit
@pytest.mark.parametrize(
"alpha_value,shape,mean_rtol,mean_atol,var_rtol,var_atol",
[
(0.25, (10_000,), 2e-2, 2e-2, 2e-1, 2e-2),
(1.0, (10_000,), 2e-2, 2e-2, 2e-1, 2e-2),
],
)
def test_standard_gamma_statistics_precommit(
self, alpha_value, shape, mean_rtol, mean_atol, var_rtol, var_atol, ie_device, precision
):
self._run_gamma_stat_test(
alpha_value, shape, mean_rtol, mean_atol, var_rtol, var_atol, ie_device, precision
)

@pytest.mark.nightly
@pytest.mark.parametrize(
"alpha_value,shape,mean_rtol,mean_atol,var_rtol,var_atol",
[
(0.25, (200_000,), 5e-3, 5e-3, 1e-1, 2e-2),
(7.5, (50_000,), 1e-2, 1e-2, 1e-1, 2e-2),
],
)
def test_standard_gamma_statistics_nightly(
self, alpha_value, shape, mean_rtol, mean_atol, var_rtol, var_atol, ie_device, precision
):
self._run_gamma_stat_test(
alpha_value, shape, mean_rtol, mean_atol, var_rtol, var_atol, ie_device, precision
)