
Commit ac4657a

Upcast gradually when computing variance (#4283)

Going all the way to f64 is undesirable, especially for low-precision tensors in bf16 or f8 variants. Upcast only to the next type, e.g., bf16 -> f32 or f8 -> bf16. This is consistent with what PyTorch seems to be doing internally.

Signed-off-by: Alex Zinenko <git@ozinenko.com>

1 parent: e03f7c6

File tree: 1 file changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 9 additions & 3 deletions
```diff
@@ -9465,9 +9465,15 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
         op, "support floating-point type input only");
   }
 
-  // Upcasting the input tensor to `F64` dtype for higher precision during the
-  // computation of the result.
-  if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
+  // Upcasting the input tensor to a double-bitwidth dtype for higher precision
+  // during the computation of the result.
+  unsigned bitwidth = inputTensorTy.getDtype().getIntOrFloatBitWidth();
+  if (bitwidth != 64) {
+    Type targetTy = rewriter.getF64Type();
+    if (bitwidth == 8)
+      targetTy = rewriter.getBF16Type();
+    else if (bitwidth == 16)
+      targetTy = rewriter.getF32Type();
     self = convertTensorToDtype(rewriter, loc, self, targetTy);
     inputTensorTy = cast<BaseTensorType>(self.getType());
   }
```
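Read as a whole, the hunk replaces the unconditional jump to f64 with a ladder: 8-bit float inputs upcast to bf16, 16-bit inputs (f16 or bf16) to f32, and other sub-64-bit inputs to f64, with the chosen `targetTy` feeding `convertTensorToDtype`. For illustration, here is a minimal standalone sketch of the same selection policy; `getNextWiderFloatType` is a hypothetical name, not a torch-mlir API, and the sketch assumes the dtype has already passed the floating-point check done earlier in `calculateVariance`:

```cpp
#include "mlir/IR/Builders.h"

// Hypothetical helper (not part of torch-mlir): pick the next wider float
// type instead of jumping straight to f64, mirroring the commit's policy.
// Assumes `dtype` is already known to be a floating-point type.
static mlir::Type getNextWiderFloatType(mlir::Builder &builder,
                                        mlir::Type dtype) {
  switch (dtype.getIntOrFloatBitWidth()) {
  case 8:  // f8 variants -> bf16
    return builder.getBF16Type();
  case 16: // f16 / bf16 -> f32
    return builder.getF32Type();
  case 32: // f32 -> f64
    return builder.getF64Type();
  default: // already 64-bit: nothing wider to upcast to
    return dtype;
  }
}
```

Returning the dtype unchanged in the 64-bit case plays the role of the `bitwidth != 64` guard in the patch, which skips the conversion entirely.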
