@@ -166,11 +166,6 @@ namespace torch_xla {
166166namespace tensor_methods {
167167namespace {
168168
// NOTE(review): this struct is deleted by this patch. It only bundled the two
// IR values produced by GetMinMaxValues(); the replacement clamp() builds the
// min/max IR values inline, so the pair type is no longer needed.
169- struct MinMaxValues {
170- torch::lazy::Value min;
171- torch::lazy::Value max;
172- };
173-
174169torch::lazy::Value MaybeExpand (const torch::lazy::Value& input,
175170 const xla::Shape& target_shape) {
176171 if (GetXlaShape (input).dimensions () == target_shape.dimensions ()) {
@@ -180,22 +175,6 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
180175 input, torch::lazy::ToVector<int64_t >(target_shape.dimensions ()));
181176}
182177
// NOTE(review): this helper is deleted by this patch. Its behavior — default
// an absent `min`/`max` bound to the dtype's extreme representable value and
// lift both to scalar IR values on the tensor's device — is now inlined in
// clamp(). The crashing XLA_CHECK precondition is replaced by
// CheckClampMinOrMax(), which reports the violation as an absl::Status
// instead of aborting.
183- MinMaxValues GetMinMaxValues (const XLATensorPtr& tensor,
184- const std::optional<at::Scalar>& min,
185- const std::optional<at::Scalar>& max) {
186- XLA_CHECK (min || max)
187- << " At least one of \' min\' or \' max\' must not be None" ;
188- xla::PrimitiveType raw_element_type = XlaTypeFromTorchType (tensor->dtype ());
189- XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues (raw_element_type);
190- auto shape = tensor->shape ();
191- return {XLAGraphExecutor::Get ()->GetIrValueForScalar (
192- min ? *min : min_max.min , shape.get ().element_type (),
193- tensor->GetDevice ()),
194- XLAGraphExecutor::Get ()->GetIrValueForScalar (
195- max ? *max : min_max.max , shape.get ().element_type (),
196- tensor->GetDevice ())};
197- }
198-
199178void CheckRank (const XLATensorPtr& t, int64_t expected_rank,
200179 const std::string& tag, const std::string& arg_name,
201180 int arg_number) {
@@ -581,6 +560,16 @@ absl::Status CheckStackAtLeastOneTensor(
581560 return absl::OkStatus ();
582561}
583562
563+ absl::Status CheckClampMinOrMax (const std::optional<at::Scalar>& min,
564+ const std::optional<at::Scalar>& max) {
565+ if (!min.has_value () && !max.has_value ()) {
566+ return XLA_ERROR_WITH_LOCATION (
567+ absl::InvalidArgumentError (" clamp(): expected at least one of `min` or "
568+ " `max` arguments to be specified." ));
569+ }
570+ return absl::OkStatus ();
571+ }
572+
584573} // namespace
585574
586575// ////////////////////////////////////////////////////////////////////////////
@@ -1432,12 +1421,23 @@ void celu_(XLATensorPtr& input, const at::Scalar& alpha) {
14321421 input->SetInPlaceIrValue (Celu (input->GetIrValue (), alpha));
14331422}
14341423
1435- XLATensorPtr clamp (const XLATensorPtr& input,
1436- const std::optional<at::Scalar>& min,
1437- const std::optional<at::Scalar>& max) {
1438- MinMaxValues min_max = GetMinMaxValues (input, min, max);
1439- return input->CreateFrom (
1440- Clamp (input->GetIrValue (), min_max.min , min_max.max ));
1424+ absl::StatusOr<absl_nonnull XLATensorPtr> clamp (
1425+ const XLATensorPtr& input, const std::optional<at::Scalar>& min,
1426+ const std::optional<at::Scalar>& max) {
1427+ XLA_RETURN_IF_ERROR (CheckClampMinOrMax (min, max));
1428+
1429+ xla::Shape shape = input->shape ();
1430+ const torch::lazy::BackendDevice& device = input->GetDevice ();
1431+
1432+ xla::PrimitiveType raw_element_type = XlaTypeFromTorchType (input->dtype ());
1433+ XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues (raw_element_type);
1434+
1435+ torch::lazy::Value min_value = XLAGraphExecutor::Get ()->GetIrValueForScalar (
1436+ min.value_or (min_max.min ), shape.element_type (), device);
1437+ torch::lazy::Value max_value = XLAGraphExecutor::Get ()->GetIrValueForScalar (
1438+ max.value_or (min_max.max ), shape.element_type (), device);
1439+
1440+ return input->CreateFrom (Clamp (input->GetIrValue (), min_value, max_value));
14411441}
14421442
14431443XLATensorPtr clone (const XLATensorPtr& input) {
0 commit comments