diff --git a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp index 750912148218..d55c38ed467f 100644 --- a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp @@ -352,6 +352,140 @@ static APInt getMaxSignedValue(unsigned bitWidth) { return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt(); } +// NOLINTNEXTLINE(misc-no-recursion) +static Value coerceSource(PatternRewriter &rewriter, Location &loc, + FIRRTLBaseType targetType, FIRRTLBaseType sourceType, + Value source) { + if (sourceType == targetType) + return source; + + auto srcType = sourceType.getAnonymousType(); + auto tgtType = targetType.getAnonymousType(); + if (srcType == tgtType) + return source; + + auto srcBundleType = dyn_cast(srcType); + auto tgtBundleType = dyn_cast(tgtType); + if (srcBundleType && tgtBundleType) { + auto n = tgtBundleType.getNumElements(); + SmallVector elems; + elems.reserve(n); + for (unsigned i = 0; i < n; ++i) { + auto srcElemType = srcBundleType.getElementType(i); + auto tgtElemType = tgtBundleType.getElementType(i); + auto srcElem = rewriter.create(loc, source, i); + auto elem = + coerceSource(rewriter, loc, tgtElemType, srcElemType, srcElem); + elems.push_back(elem); + } + return rewriter.create(loc, tgtBundleType, elems); + } + + auto srcVectorType = dyn_cast(srcType); + auto tgtVectorType = dyn_cast(tgtType); + if (srcVectorType && tgtVectorType) { + auto srcElemType = srcVectorType.getElementType(); + auto tgtElemType = tgtVectorType.getElementType(); + auto n = tgtVectorType.getNumElements(); + SmallVector elems; + elems.reserve(n); + for (unsigned i = 0; i < n; ++i) { + auto srcElem = rewriter.create(loc, source, i); + auto elem = + coerceSource(rewriter, loc, tgtElemType, srcElemType, srcElem); + elems.push_back(elem); + } + return rewriter.create(loc, tgtVectorType, elems); + } + + auto srcIntType = dyn_cast(srcType); + auto tgtIntType = dyn_cast(tgtType); + if (srcIntType && tgtIntType) { + auto srcWidth = srcIntType.getBitWidthOrSentinel(); + auto tgtWidth = tgtIntType.getBitWidthOrSentinel(); + if (tgtWidth < srcWidth) { + auto delta = srcWidth - tgtWidth; + Value value = rewriter.create(loc, source, delta); + if (tgtIntType.isSigned()) + value = rewriter.create(loc, value); + return value; + } + + if (tgtWidth > srcWidth) + source = rewriter.create(loc, source, tgtWidth); + if (tgtIntType.isSigned() && !srcIntType.isSigned()) + return rewriter.create(loc, source); + if (!tgtIntType.isSigned() && srcIntType.isSigned()) + return rewriter.create(loc, source); + return source; + } + + return nullptr; +} + +/// Emit a coercion from a value to a target type. Returns nullptr if the +/// coercion is not possible. The resulting value is a non-aliasing source +/// value. As such, we can only emit coercions for passive types. +static Value coerceSource(PatternRewriter &rewriter, Location loc, + Type targetType, Value source) { + Type sourceType = source.getType(); + + // If the types are syntactically equal, no action is needed. + if (sourceType == targetType) + return source; + + // If either of the types are not FIRRTL base types, we cannot coerce. + auto sourceFType = type_cast(sourceType); + auto targetFType = type_cast(targetType); + if (!sourceFType || !targetFType) + return nullptr; + + // After type_cast resolves type-aliases, the underlying types may be the + // same. If they are, no action is needed. + if (sourceFType == targetFType) + return source; + + // One last shot at avoiding coercion: recursively unfold type-aliases and + // check again for syntactic equality. If they are, no action is needed. + if (sourceFType.getAnonymousType() == targetFType.getAnonymousType()) + return source; + + // OK, some coercion is necessary. Check if it's possible. + + // Give up if either side contains const. Eventually, const will be removed + // from the compiler. + if (sourceFType.containsConst() || targetFType.containsConst()) + return nullptr; + + // We can only coerce when all the involved widths are known. We can usually + // truncate or extend the source value to match the destination, but if either + // src or dst has an uninferred width, we don't know which way to go. + if (sourceFType.hasUninferredWidth() || targetFType.hasUninferredWidth()) + return nullptr; + + // Similar story for resets... + if (sourceFType.hasUninferredReset() || targetFType.hasUninferredReset()) + return nullptr; + + // Give up if the target is not passive. If we have to coerce the source + // value, the coercion ops will produce a nonaliasing source value, which + // prevents us from properly coercing to a correct non-passive value. + if (!targetFType.isPassive() || targetFType.containsAnalog()) + return nullptr; + + // After the earlier recursive checks, we can defer to equivalence checking. + if (!areTypesEquivalent(targetFType, sourceFType)) + return nullptr; + + auto result = coerceSource(rewriter, loc, targetFType, sourceFType, source); + + // Final sanity check: ensure the result will make matchingconnect happy. + if (result) + assert(areAnonymousTypesEquivalent(targetType, result.getType())); + + return result; +} + //===----------------------------------------------------------------------===// // Fold Hooks //===----------------------------------------------------------------------===// @@ -2269,14 +2403,15 @@ canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) { if (!isDefinedByOneConstantOp(reg.getResetSignal())) return failure(); - auto resetValue = reg.getResetValue(); - if (reg.getType(0) != resetValue.getType()) + auto value = + coerceSource(rewriter, reg.getLoc(), reg.getType(0), reg.getResetValue()); + if (!value) return failure(); // Ignore 'passthrough'. (void)dropWrite(rewriter, reg->getResult(0), {}); replaceOpWithNewOpAndCopyName( - rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(), + rewriter, reg, value, reg.getNameAttr(), reg.getNameKind(), reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable()); return success(); } diff --git a/test/Dialect/FIRRTL/canonicalization.mlir b/test/Dialect/FIRRTL/canonicalization.mlir index 78070aa96e5a..cee844471d40 100644 --- a/test/Dialect/FIRRTL/canonicalization.mlir +++ b/test/Dialect/FIRRTL/canonicalization.mlir @@ -2294,17 +2294,65 @@ firrtl.module @ForceableRegResetToNode(in %clock: !firrtl.clock, in %dummy : !fi } // https://github.com/llvm/circt/issues/8348 -// CHECK-LABEL: firrtl.module @RegResetInvalidResetValueType -// We cannot replace a regreset with its reset value, when the reset value's type does not match. -firrtl.module @RegResetInvalidResetValueType(in %c : !firrtl.clock, out %out : !firrtl.uint<2>) { - %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> - %c0_ui2 = firrtl.constant 0 : !firrtl.uint<2> +// CHECK-LABEL: firrtl.module @RegResetCoerceIntResetValue +// When we replace a regreset with its reset value, we must ensure the reset-value is the correct type. +firrtl.module @RegResetCoerceIntResetValue(in %c : !firrtl.clock, + out %out_si1 : !firrtl.sint<1>, out %out_si2 : !firrtl.sint<2>, + out %out_ui1 : !firrtl.uint<1>, out %out_ui2 : !firrtl.uint<2> +) { + %c1_asyncreset = firrtl.specialconstant 1 : !firrtl.asyncreset + + %c1_si1 = firrtl.constant 1 : !firrtl.sint<1> + %c1_si2 = firrtl.constant 1 : !firrtl.sint<2> + + %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> + %c1_ui2 = firrtl.constant 1 : !firrtl.uint<2> + + // SInt Extension. + // CHECK: firrtl.matchingconnect %out_si2, %c-1_si2 : !firrtl.sint<2> + %reg_si2 = firrtl.regreset %c, %c1_asyncreset, %c1_si1 : !firrtl.clock, !firrtl.asyncreset, !firrtl.sint<1>, !firrtl.sint<2> + firrtl.matchingconnect %out_si2, %reg_si2 : !firrtl.sint<2> + + // SInt Truncation. + // CHECK: firrtl.matchingconnect %out_si1, %c-1_si1 : !firrtl.sint<1> + %reg_si1 = firrtl.regreset %c, %c1_asyncreset, %c1_si2 : !firrtl.clock, !firrtl.asyncreset, !firrtl.sint<2>, !firrtl.sint<1> + firrtl.matchingconnect %out_si1, %reg_si1 : !firrtl.sint<1> + + // UInt Extension. + // CHECK: firrtl.matchingconnect %out_ui2, %c1_ui2 : !firrtl.uint<2> + %reg_ui2 = firrtl.regreset %c, %c1_asyncreset, %c1_ui1 : !firrtl.clock, !firrtl.asyncreset, !firrtl.uint<1>, !firrtl.uint<2> + firrtl.matchingconnect %out_ui2, %reg_ui2 : !firrtl.uint<2> + + // UInt Truncation. + // CHECK: firrtl.matchingconnect %out_ui1, %c1_ui1 : !firrtl.uint<1> + %reg_ui1 = firrtl.regreset %c, %c1_asyncreset, %c1_ui2 : !firrtl.clock, !firrtl.asyncreset, !firrtl.uint<2>, !firrtl.uint<1> + firrtl.matchingconnect %out_ui1, %reg_ui1 : !firrtl.uint<1> +} + +// CHECK-LABEL: firrtl.module @RegResetCoerceBundleResetValue +firrtl.module @RegResetCoerceBundleResetValue(in %c : !firrtl.clock, out %out : !firrtl.bundle, b: sint<1>>) { + // CHECK: %0 = firrtl.aggregateconstant [-1 : si2, -1 : si1] : !firrtl.bundle, b: sint<1>> + // CHECK: firrtl.matchingconnect %out, %0 : !firrtl.bundle, b: sint<1>> + %c1_asyncreset = firrtl.specialconstant 1 : !firrtl.asyncreset + %v = firrtl.aggregateconstant [-1 : si1, 1 : si2] : !firrtl.bundle, b: sint<2>> + %r = firrtl.regreset %c, %c1_asyncreset, %v : + !firrtl.clock, !firrtl.asyncreset, + !firrtl.bundle, b: sint<2>>, + !firrtl.bundle, b: sint<1>> + firrtl.matchingconnect %out, %r : !firrtl.bundle, b: sint<1>> +} + +// CHECK-LABEL: firrtl.module @RegResetCoerceVectorResetValue +firrtl.module @RegResetCoerceVectorResetValue(in %c : !firrtl.clock, out %out : !firrtl.vector, 1>) { + // CHECK: %0 = firrtl.aggregateconstant [-1 : si2] : !firrtl.vector, 1> + // CHECK: firrtl.matchingconnect %out, %0 : !firrtl.vector, 1> %c1_asyncreset = firrtl.specialconstant 1 : !firrtl.asyncreset - // CHECK: %reg = firrtl.regreset - %reg = firrtl.regreset %c, %c1_asyncreset, %c0_ui1 : !firrtl.clock, !firrtl.asyncreset, !firrtl.uint<1>, !firrtl.uint<2> - // CHECK: firrtl.matchingconnect %out, %reg : !firrtl.uint<2> - firrtl.matchingconnect %out, %reg : !firrtl.uint<2> - firrtl.matchingconnect %reg, %c0_ui2 : !firrtl.uint<2> + %v = firrtl.aggregateconstant [-1 : si1] : !firrtl.vector, 1> + %r = firrtl.regreset %c, %c1_asyncreset, %v : + !firrtl.clock, !firrtl.asyncreset, + !firrtl.vector, 1>, + !firrtl.vector, 1> + firrtl.matchingconnect %out, %r : !firrtl.vector, 1> } // https://github.com/llvm/circt/issues/929