Commit 87e631a
clamp: improve error handling and error messages. (#9642)
This PR refactors the `clamp` operation implementation by improving its error message and returning a status type value.

**Key Changes:**

- Make `tensor_methods::clamp` return `StatusOr<XLATensorPtr>`
- Improve error handling
- Inline the `GetMinMaxValues()` function
- Move the check into a new `CheckClampMinOrMax()` function
1 parent 0fa6e31 commit 87e631a
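For context, here is a minimal repro of the new user-visible behavior, mirroring the test added below (assumes an XLA device is available):

```python
import torch
import torch_xla

a = torch.rand(2, 5, device=torch_xla.device())

# Calling the explicit `clamp` overload with neither bound specified now
# raises: RuntimeError: clamp(): expected at least one of `min` or `max`
# arguments to be specified.
torch.ops.aten.clamp.default(a)
```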

File tree: 4 files changed (+66, -44 lines)


test/test_ops_error_message.py

Lines changed: 16 additions & 0 deletions
@@ -235,6 +235,22 @@ def test():
         expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor)."""
     )
 
+  def test_clamp_scalar_raises_error_on_no_min_and_max(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 5, device=device)
+
+    def test():
+      # Dispatch to `clamp()` overload explicitly.
+      # Otherwise, it's dispatched to `clamp.Tensor()`, which doesn't have
+      # this check.
+      return torch.ops.aten.clamp.default(a)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""clamp(): expected at least one of `min` or `max` arguments to be specified."""
+    )
+
 
 if __name__ == "__main__":
   unittest.main()
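As the test's inline comment notes, the repro must call `torch.ops.aten.clamp.default` directly: a plain `torch.clamp(a)` may dispatch to the `clamp.Tensor` overload, which does not go through this check.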

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 20 additions & 11 deletions
@@ -1373,24 +1373,31 @@ at::Tensor XLANativeFunctions::clamp(const at::Tensor& self,
                                      const std::optional<at::Scalar>& min,
                                      const std::optional<at::Scalar>& max) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::clamp(xla_self, min, max));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min, max));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self,
                                          const at::Scalar& max) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, std::nullopt, max));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, std::nullopt, max));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
                                          const at::Scalar& min) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, min, std::nullopt));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min, std::nullopt));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clone(
@@ -1950,9 +1957,11 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
                                         const at::Scalar& min_val,
                                         const at::Scalar& max_val) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, min_val, max_val));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min_val, max_val));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output,
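At each of these call sites the pattern is the same: `tensor_methods::clamp` now yields a `StatusOr`, and the `XLA_ASSIGN_OR_THROW` macro unwraps it (throwing, as its name suggests, when the status is an error) before the result is converted back to an `at::Tensor`.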

torch_xla/csrc/tensor_methods.cpp

Lines changed: 27 additions & 27 deletions
@@ -166,11 +166,6 @@ namespace torch_xla {
 namespace tensor_methods {
 namespace {
 
-struct MinMaxValues {
-  torch::lazy::Value min;
-  torch::lazy::Value max;
-};
-
 torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
                                const xla::Shape& target_shape) {
   if (GetXlaShape(input).dimensions() == target_shape.dimensions()) {
@@ -180,22 +175,6 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
       input, torch::lazy::ToVector<int64_t>(target_shape.dimensions()));
 }
 
-MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor,
-                             const std::optional<at::Scalar>& min,
-                             const std::optional<at::Scalar>& max) {
-  XLA_CHECK(min || max)
-      << "At least one of \'min\' or \'max\' must not be None";
-  xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(tensor->dtype());
-  XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type);
-  auto shape = tensor->shape();
-  return {XLAGraphExecutor::Get()->GetIrValueForScalar(
-              min ? *min : min_max.min, shape.get().element_type(),
-              tensor->GetDevice()),
-          XLAGraphExecutor::Get()->GetIrValueForScalar(
-              max ? *max : min_max.max, shape.get().element_type(),
-              tensor->GetDevice())};
-}
-
 void CheckRank(const XLATensorPtr& t, int64_t expected_rank,
                const std::string& tag, const std::string& arg_name,
                int arg_number) {
@@ -581,6 +560,16 @@ absl::Status CheckStackAtLeastOneTensor(
   return absl::OkStatus();
 }
 
+absl::Status CheckClampMinOrMax(const std::optional<at::Scalar>& min,
+                                const std::optional<at::Scalar>& max) {
+  if (!min.has_value() && !max.has_value()) {
+    return XLA_ERROR_WITH_LOCATION(
+        absl::InvalidArgumentError("clamp(): expected at least one of `min` or "
+                                   "`max` arguments to be specified."));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1432,12 +1421,23 @@ void celu_(XLATensorPtr& input, const at::Scalar& alpha) {
   input->SetInPlaceIrValue(Celu(input->GetIrValue(), alpha));
 }
 
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Scalar>& min,
-                   const std::optional<at::Scalar>& max) {
-  MinMaxValues min_max = GetMinMaxValues(input, min, max);
-  return input->CreateFrom(
-      Clamp(input->GetIrValue(), min_max.min, min_max.max));
+absl::StatusOr<absl_nonnull XLATensorPtr> clamp(
+    const XLATensorPtr& input, const std::optional<at::Scalar>& min,
+    const std::optional<at::Scalar>& max) {
+  XLA_RETURN_IF_ERROR(CheckClampMinOrMax(min, max));
+
+  xla::Shape shape = input->shape();
+  const torch::lazy::BackendDevice& device = input->GetDevice();
+
+  xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(input->dtype());
+  XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type);
+
+  torch::lazy::Value min_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
+      min.value_or(min_max.min), shape.element_type(), device);
+  torch::lazy::Value max_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
+      max.value_or(min_max.max), shape.element_type(), device);
+
+  return input->CreateFrom(Clamp(input->GetIrValue(), min_value, max_value));
 }
 
 XLATensorPtr clone(const XLATensorPtr& input) {
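One design point worth noting in the inlined logic: when only one bound is supplied, the other falls back to the dtype's extremum via `XlaHelpers::MinMaxValues` (the `value_or` calls above), so single-bound calls are unaffected by the new check. A quick sketch, reusing the tensor `a` from the repro earlier:

```python
# Only `min` is given; `max` silently defaults to the largest value
# representable by `a`'s dtype, so no error is raised.
b = torch.clamp(a, min=0.5)
```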

torch_xla/csrc/tensor_methods.h

Lines changed: 3 additions & 6 deletions
@@ -316,12 +316,9 @@ XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor);
 XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha);
 void celu_(XLATensorPtr& input, const at::Scalar& alpha);
 
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Scalar>& min,
-                   const std::optional<at::Scalar>& max);
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Tensor>& min,
-                   const std::optional<at::Tensor>& max);
+absl::StatusOr<absl_nonnull XLATensorPtr> clamp(
+    const XLATensorPtr& input, const std::optional<at::Scalar>& min,
+    const std::optional<at::Scalar>& max);
 
 XLATensorPtr clone(const XLATensorPtr& input);
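Note that the header change also drops the declaration of the `std::optional<at::Tensor>` overload of `clamp`, leaving the `at::Scalar` version as the single declared entry point here.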
