diff --git a/lib/Dialect/FIRRTL/FIRRTLOps.cpp b/lib/Dialect/FIRRTL/FIRRTLOps.cpp index 2921ec21a04d..53e5372bb444 100644 --- a/lib/Dialect/FIRRTL/FIRRTLOps.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLOps.cpp @@ -5730,25 +5730,43 @@ LogicalResult HWStructCastOp::verify() { LogicalResult BitCastOp::verify() { auto inTypeBits = getBitWidth(getInput().getType(), /*ignoreFlip=*/true); auto resTypeBits = getBitWidth(getType()); - if (inTypeBits.has_value() && resTypeBits.has_value()) { - // Bitwidths must match for valid bit - if (*inTypeBits == *resTypeBits) { - // non-'const' cannot be casted to 'const' - if (containsConst(getType()) && !isConst(getOperand().getType())) - return emitError("cannot cast non-'const' input type ") - << getOperand().getType() << " to 'const' result type " - << getType(); - return success(); - } + if (!inTypeBits.has_value()) + return emitError("bitwidth cannot be determined for input operand type ") + << getInput().getType(); + if (!resTypeBits.has_value()) + return emitError("bitwidth cannot be determined for result type ") + << getType(); + // Bitwidths must match for valid bit + if (*inTypeBits != *resTypeBits) return emitError("the bitwidth of input (") << *inTypeBits << ") and result (" << *resTypeBits << ") don't match"; + + // non-'const' cannot be casted to 'const' + if (containsConst(getType()) && !isConst(getOperand().getType())) + return emitError("cannot cast non-'const' input type ") + << getOperand().getType() << " to 'const' result type " << getType(); + + if (auto bundleType = dyn_cast(getInput().getType())) { + auto elts = bundleType.getElements(); + if (any_of(elts, [&](BundleType::BundleElement elt) { + return elt.isFlip != elts.front().isFlip; + })) + return emitError("cannot cast input bundle type with elements in " + "different directions ") + << getInput().getType(); } - if (!inTypeBits.has_value()) - return emitError("bitwidth cannot be determined for input operand type ") - << getInput().getType(); - return emitError("bitwidth cannot be determined for result type ") - << getType(); + if (auto openBundleType = dyn_cast(getInput().getType())) { + auto elts = openBundleType.getElements(); + if (any_of(elts, [&](OpenBundleType::BundleElement elt) { + return elt.isFlip != elts.front().isFlip; + })) + return emitError("cannot cast input open bundle type with elements in " + "different directions ") + << getInput().getType(); + } + + return success(); } //===----------------------------------------------------------------------===// diff --git a/test/Dialect/FIRRTL/bitcast-bundle-different-diretions.mlir b/test/Dialect/FIRRTL/bitcast-bundle-different-diretions.mlir new file mode 100644 index 000000000000..d9cc36c9ef08 --- /dev/null +++ b/test/Dialect/FIRRTL/bitcast-bundle-different-diretions.mlir @@ -0,0 +1,17 @@ +// RUN: circt-opt %s --lower-firrtl-to-hw --verify-diagnostics + +firrtl.circuit "BitcastBundle" { + firrtl.module @BitcastBundle(in %input_0: !firrtl.uint<8>, out %output_0: !firrtl.uint<8>) { + %io = firrtl.wire : !firrtl.bundle, output_0: uint<8>> + %0 = firrtl.subfield %io[input_0] : !firrtl.bundle, output_0: uint<8>> + firrtl.connect %0, %input_0 : !firrtl.uint<8> + %1 = firrtl.subfield %io[output_0] : !firrtl.bundle, output_0: uint<8>> + firrtl.connect %output_0, %1 : !firrtl.uint<8> + %2 = firrtl.subfield %io[output_0] : !firrtl.bundle, output_0: uint<8>> + // expected-error @below {{cannot cast input bundle type with elements in different directions '!firrtl.bundle, output_0: uint<8>>'}} + %3 = firrtl.bitcast %io : (!firrtl.bundle, output_0: uint<8>>) -> !firrtl.uint<16> + %4 = firrtl.bits %3 7 to 0 : (!firrtl.uint<16>) -> !firrtl.uint<8> + %_GEN_0 = firrtl.node interesting_name %4 : !firrtl.uint<8> + firrtl.connect %2, %_GEN_0 : !firrtl.uint<8> + } +}