diff --git a/include/dynamatic/Dialect/Handshake/HandshakeOps.h b/include/dynamatic/Dialect/Handshake/HandshakeOps.h index 849b63a89b..a862f80c38 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeOps.h +++ b/include/dynamatic/Dialect/Handshake/HandshakeOps.h @@ -625,4 +625,17 @@ class HasParentInterface { #define GET_OP_CLASSES #include "dynamatic/Dialect/Handshake/Handshake.h.inc" +template <> +struct llvm::PointerLikeTypeTraits { + static void *getAsVoidPointer(dynamatic::handshake::ConstantOp val) { + return const_cast(val.getAsOpaquePointer()); + } + + static dynamatic::handshake::ConstantOp getFromVoidPointer(void *p) { + return dynamatic::handshake::ConstantOp::getFromOpaquePointer(p); + } + static constexpr int NumLowBitsAvailable = + PointerLikeTypeTraits::NumLowBitsAvailable; +}; + #endif // DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_OPS_H diff --git a/include/dynamatic/Transforms/HandshakeMinimizeCstWidth.h b/include/dynamatic/Transforms/HandshakeMinimizeCstWidth.h deleted file mode 100644 index 8f7611e1b4..0000000000 --- a/include/dynamatic/Transforms/HandshakeMinimizeCstWidth.h +++ /dev/null @@ -1,29 +0,0 @@ -//===- HandshakeMinimizeCstWidth.h - Min. constants bitwidth ----*- C++ -*-===// -// -// Dynamatic is under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares the --handshake-minimize-cst-width pass. -// -//===----------------------------------------------------------------------===// - -#ifndef DYNAMATIC_TRANSFORMS_HANDSHAKEMINIMIZECSTWIDTH_H -#define DYNAMATIC_TRANSFORMS_HANDSHAKEMINIMIZECSTWIDTH_H - -#include "dynamatic/Dialect/Handshake/HandshakeOps.h" -#include "dynamatic/Support/DynamaticPass.h" -#include "dynamatic/Support/LLVM.h" - -namespace dynamatic { -/// Attempts to find an equivalent constant to the one provided in the circuit, -/// that is a constant with the same control and same value attribute. Returns -/// such an equivalent constant if it finds one, nullptr otherwise. -handshake::ConstantOp findEquivalentCst(handshake::ConstantOp cstOp); -/// Computes the minimum required bitwidth needed to store the provided integer. -unsigned computeRequiredBitwidth(APInt val); -} // namespace dynamatic - -#endif // DYNAMATIC_TRANSFORMS_HANDSHAKEMINIMIZECSTWIDTH_H diff --git a/include/dynamatic/Transforms/Passes.td b/include/dynamatic/Transforms/Passes.td index c70840c65a..8bae4ad5a5 100644 --- a/include/dynamatic/Transforms/Passes.td +++ b/include/dynamatic/Transforms/Passes.td @@ -188,21 +188,6 @@ def HandshakeOptimizeBitwidths : DynamaticPass< "handshake-optimize-bitwidths", ]; } -def HandshakeMinimizeCstWidth : DynamaticPass<"handshake-minimize-cst-width", - ["mlir::arith::ArithDialect"]> { - let summary = "Minimizes the bitwidth of all Handshake constants."; - let description = [{ - Rewrites constant operations with the minimum required bitwidth to support - the constants' values. The pass inserts extension operations as needed to - ensure consistency with users of constant operations. The pass also pays - attention to not create duplicated constants indirectly due to the - minimization process. - }]; - let options = - [Option<"optNegatives", "opt-negatives", "bool", "false", - "If true, allows bitwidth optimization of negative values.">]; -} - def HandshakeReplaceMemoryInterfaces : DynamaticPass< "handshake-replace-memory-interfaces" > { diff --git a/integration-test/if_convert/buffer.json b/integration-test/if_convert/buffer.json index 32eba852a8..2071359a7c 100644 --- a/integration-test/if_convert/buffer.json +++ b/integration-test/if_convert/buffer.json @@ -28,7 +28,7 @@ "comment": "To achieve better II" }, { - "pred": "constant11", + "pred": "constant1", "outid": 0, "slots": 4, "type": "fifo_break_none", diff --git a/integration-test/nested_loop/buffer.json b/integration-test/nested_loop/buffer.json index ee608b8038..303f0ae1f1 100644 --- a/integration-test/nested_loop/buffer.json +++ b/integration-test/nested_loop/buffer.json @@ -35,7 +35,7 @@ "comment": "Buffer non-spec token to prevent II=2 locking" }, { - "pred": "constant18", + "pred": "constant1", "outid": 0, "slots": 4, "type": "fifo_break_none", diff --git a/integration-test/single_loop/buffer.json b/integration-test/single_loop/buffer.json index 1e183b8cd1..9f1cafa1a4 100644 --- a/integration-test/single_loop/buffer.json +++ b/integration-test/single_loop/buffer.json @@ -14,7 +14,7 @@ "comment": "To absorb latency for spec_commit4 (data)" }, { - "pred": "constant9", + "pred": "constant2", "outid": 0, "slots": 5, "type": "fifo_break_none", diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index 929d84786d..73aa3d4740 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -9,7 +9,6 @@ add_dynamatic_library(DynamaticTransforms HandshakeCanonicalize.cpp HandshakeHoistExtInstances.cpp HandshakeMaterialize.cpp - HandshakeMinimizeCstWidth.cpp HandshakeOptimizeBitwidths.cpp HandshakeInferBasicBlocks.cpp HandshakeReplaceMemoryInterfaces.cpp diff --git a/lib/Transforms/HandshakeMinimizeCstWidth.cpp b/lib/Transforms/HandshakeMinimizeCstWidth.cpp deleted file mode 100644 index 91eabe4182..0000000000 --- a/lib/Transforms/HandshakeMinimizeCstWidth.cpp +++ /dev/null @@ -1,208 +0,0 @@ -//===- HandshakeMinimizeCstWidth.cpp - Min. constants bitwidth --*- C++ -*-===// -// -// Dynamatic is under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Implements the handshake-minimize-cst-width pass, which minimizes the -// bitwidth of all constants. The pass matches on all Handshake constants in the -// IR, determines the minimum bitwidth necessary to hold their value, and -// updates their result/attribute type match to this minimal value. -// -//===----------------------------------------------------------------------===// - -#include "dynamatic/Transforms/HandshakeMinimizeCstWidth.h" -#include "dynamatic/Dialect/Handshake/HandshakeOps.h" -#include "dynamatic/Dialect/Handshake/HandshakeTypes.h" -#include "dynamatic/Support/CFG.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "BITWIDTH" - -// [START Boilerplate code for the MLIR pass] -#include "dynamatic/Transforms/Passes.h" // IWYU pragma: keep -namespace dynamatic { -#define GEN_PASS_DEF_HANDSHAKEMINIMIZECSTWIDTH -#include "dynamatic/Transforms/Passes.h.inc" -} // namespace dynamatic -// [END Boilerplate code for the MLIR pass] - -STATISTIC(savedBits, "Number of saved bits"); - -using namespace mlir; -using namespace dynamatic; - -/// Determines whether the control value of two constants can be considered -/// equivalent. -static bool areCstCtrlEquivalent(Value ctrl, Value otherCtrl) { - if (ctrl == otherCtrl) - return true; - - // Both controls are equivalent if they originate from sources in the same - // block - Operation *defOp = ctrl.getDefiningOp(); - if (!defOp || !isa(defOp)) - return false; - Operation *otherDefOp = otherCtrl.getDefiningOp(); - if (!otherDefOp || !isa(otherDefOp)) - return false; - std::optional block = getLogicBB(defOp); - std::optional otherBlock = getLogicBB(otherDefOp); - return block.has_value() && otherBlock.has_value() && - block.value() == otherBlock.value(); -} - -handshake::ConstantOp static findEquivalentCst(TypedAttr valueAttr, - Value ctrl) { - auto funcOp = cast(ctrl.getParentBlock()->getParentOp()); - for (auto cstOp : funcOp.getOps()) { - // The constant operation needs to have the same value attribute and the - // same control - auto cstAttr = cstOp.getValue(); - if (cstAttr == valueAttr && areCstCtrlEquivalent(ctrl, cstOp.getCtrl())) - return cstOp; - } - return nullptr; -} - -handshake::ConstantOp -dynamatic::findEquivalentCst(handshake::ConstantOp cstOp) { - auto cstAttr = cstOp.getValue(); - auto funcOp = cstOp->getParentOfType(); - assert(funcOp && "constant should have parent function"); - - for (auto otherCstOp : funcOp.getOps()) { - // Don't match ourself - if (cstOp == otherCstOp) - continue; - - // The constant operation needs to have the same value attribute and the - // same control - auto otherCstAttr = otherCstOp.getValue(); - if (otherCstAttr == cstAttr && - areCstCtrlEquivalent(cstOp.getCtrl(), otherCstOp.getCtrl())) - return otherCstOp; - } - - return nullptr; -} - -/// Inserts an extension op after the constant op that extends the constant's -/// integer result to a provided destination type. The function assumes that it -/// makes sense to extend the former type into the latter type. -static handshake::ExtSIOp insertExtOp(handshake::ConstantOp toExtend, - handshake::ConstantOp toReplace, - PatternRewriter &rewriter) { - rewriter.setInsertionPointAfter(toExtend); - auto extOp = rewriter.create( - toExtend.getLoc(), toReplace.getResult().getType(), toExtend.getResult()); - inheritBB(toExtend, extOp); - return extOp; -} - -unsigned dynamatic::computeRequiredBitwidth(APInt val) { - bool isNegative = false; - if (val.isNegative()) { - isNegative = true; - int64_t negVal = val.getSExtValue(); - if (negVal - 1 == 0) - // The value is the minimum number representable on 64 bits - return APInt::APINT_BITS_PER_WORD; - - // Flip the sign to make it positive - val = APInt(APInt::APINT_BITS_PER_WORD, -negVal); - } - - unsigned log = val.logBase2(); - return val.isPowerOf2() && isNegative ? log + 1 : log + 2; -} - -namespace { - -/// Minimizes the bitwidth used by Handshake constants as much as possible. If -/// the bitwidth is reduced, inserts an extension operation after the constant -/// so that users of the constant result can keep using the same value type. -struct MinimizeConstantBitwidth - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - MinimizeConstantBitwidth(bool optNegatives, MLIRContext *ctx) - : OpRewritePattern(ctx), - optNegatives(optNegatives) {}; - - LogicalResult matchAndRewrite(handshake::ConstantOp cstOp, - PatternRewriter &rewriter) const override { - // Only consider integer attributes - auto intAttr = dyn_cast(cstOp.getValue()); - if (!intAttr) - return failure(); - handshake::ChannelType channelType = cstOp.getResult().getType(); - IntegerType dataType = cast(channelType.getDataType()); - - // We only support reducing signless values that fit on 64 bits or less - APInt val = intAttr.getValue(); - if (dataType.getSignedness() != - IntegerType::SignednessSemantics::Signless || - !val.isSingleWord()) - return failure(); - - // Do not optimize negative values - if (val.isNegative() && !optNegatives) - return failure(); - - // Check if we can reduce the bitwidth - unsigned newWidth = computeRequiredBitwidth(val); - if (newWidth >= dataType.getWidth()) - return failure(); - - // Create the new constant value - IntegerAttr newAttr = IntegerAttr::get( - IntegerType::get(getContext(), newWidth, dataType.getSignedness()), - val.trunc(newWidth)); - - if (auto otherCstOp = findEquivalentCst(newAttr, cstOp.getCtrl())) { - // Use the other constant's result and simply erase the matched constant - rewriter.replaceOp(cstOp, insertExtOp(otherCstOp, cstOp, rewriter)); - return success(); - } - - // Create a new constant to replace the matched one with - auto newCstOp = rewriter.create( - cstOp->getLoc(), newAttr, cstOp.getCtrl()); - rewriter.replaceOp(cstOp, insertExtOp(newCstOp, cstOp, rewriter)); - return success(); - } - -private: - /// Whether to allow optimization of negative values. - bool optNegatives; -}; - -/// Driver for the constant bitwidth reduction pass. A greedy pattern rewriter -/// matches on all Handshake constants and minimizes their bitwidth. -struct HandshakeMinimizeCstWidthPass - : public dynamatic::impl::HandshakeMinimizeCstWidthBase< - HandshakeMinimizeCstWidthPass> { - - using HandshakeMinimizeCstWidthBase::HandshakeMinimizeCstWidthBase; - - void runDynamaticPass() override { - auto *ctx = &getContext(); - mlir::ModuleOp mod = getOperation(); - - mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; - RewritePatternSet patterns{ctx}; - patterns.add(optNegatives, ctx); - if (failed(applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) - return signalPassFailure(); - - LLVM_DEBUG(llvm::dbgs() << "Number of saved bits is " << savedBits << "\n"); - }; -}; - -} // namespace diff --git a/lib/Transforms/HandshakeOptimizeBitwidths.cpp b/lib/Transforms/HandshakeOptimizeBitwidths.cpp index 93a114d6b1..0a5f02c297 100644 --- a/lib/Transforms/HandshakeOptimizeBitwidths.cpp +++ b/lib/Transforms/HandshakeOptimizeBitwidths.cpp @@ -31,13 +31,12 @@ #include "dynamatic/Dialect/Handshake/HandshakeOps.h" #include "dynamatic/Dialect/Handshake/HandshakeTypes.h" #include "dynamatic/Support/CFG.h" -#include "dynamatic/Transforms/HandshakeMinimizeCstWidth.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/ADT/TypeSwitch.h" #include #include @@ -52,6 +51,16 @@ namespace dynamatic { } // namespace dynamatic // [END Boilerplate code for the MLIR pass] +static unsigned computeRequiredBitwidth(const APInt &value) { + unsigned bitWidth; + if (value.isNegative()) { + bitWidth = value.getSignificantBits(); + } else { + bitWidth = value.getActiveBits(); + } + return std::max(bitWidth, 1u); +} + namespace { /// Extension type. When backtracking through extension operations, serves to /// remember the type of any extension we may have encountered along the @@ -67,8 +76,104 @@ enum class ExtType { NONE, ZEXT, SEXT }; /// A channel-typed value. using ChannelVal = TypedValue; -/// Shortcut for a value accompanied by its corresponding extension type. -using ExtValue = std::pair; +/// Class representing a minimal value. +/// For a given value, a minimal value is the same value with possibly fewer +/// bits. +/// If it has fewer bits, an extension type specifies how to extend the minimal +/// value to the original value. +class MinimalValue { +public: + MinimalValue() = default; + + /// Constructs a minimal value from a value that is the bitwidth reduced form + /// and the used extension type. + MinimalValue(ChannelVal value, ExtType extType) + : extType(extType), minimalType(value.getType()), repr(value) {} + + /// Constructs a minimal value from a constant integer op. + /// The number of bits and extension required is based on the integer value. + explicit MinimalValue(handshake::ConstantOp op) + : extType(ExtType::NONE), minimalType(op.getType()), repr(op) { + APInt value = cast(op.getValue()).getValue(); + unsigned bitWidth = computeRequiredBitwidth(value); + if (bitWidth >= value.getBitWidth()) + return; + + if (value.isNegative()) { + extType = ExtType::SEXT; + } else { + extType = ExtType::ZEXT; + } + minimalType = + minimalType.withDataType(IntegerType::get(op.getContext(), bitWidth)); + } + + /// Materializes and returns the minimal value as an MLIR value. + ChannelVal materializeValue(RewriterBase &rewriter) const { + return TypeSwitch(repr) + .Case([](ChannelVal value) { return value; }) + .Case([&](handshake::ConstantOp op) { + auto constantOp = rewriter.create( + op.getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(getDataBitWidth()), + *getConstantOrNone()), + op.getCtrl()); + + inheritBB(op, constantOp); + return constantOp; + }); + } + + /// Returns the data bitwidth of the minimal value. + unsigned getDataBitWidth() const { return minimalType.getDataBitWidth(); } + + /// Returns the datatype the minimal value. + Type getDataType() const { return minimalType.getDataType(); } + + /// Returns the extension required to go from the minimal value to the + /// original value. + ExtType getExtType() const { return extType; } + + /// Returns the corresponding channel value for the minimal value or null + /// if this minimal value was created from a constant and therefore no + /// materialized channel value exists. + ChannelVal getChannelValOrNull() const { return dyn_cast(repr); } + + /// Returns the constant op if this minimal value was initialized from a + /// constant op or null otherwise. + handshake::ConstantOp getConstantOpOrNull() const { + return dyn_cast(repr); + } + + /// Returns the constant value if this minimal value was initialized from a + /// constant op or null otherwise. + std::optional getConstantOrNone() const { + handshake::ConstantOp op = getConstantOpOrNull(); + if (!op) + return std::nullopt; + return cast(op.getValue()).getValue().trunc(getDataBitWidth()); + } + + /// Returns the op that defines this minimal value. + /// For constant ops, this is the constant op itself. + Operation *getDefiningOp() const { + return TypeSwitch(repr) + .Case([](ChannelVal value) { return value.getDefiningOp(); }) + .Case([&](handshake::ConstantOp op) { return op; }); + } + + friend bool operator==(const MinimalValue &lhs, const MinimalValue &rhs) { + return std::tie(lhs.extType, lhs.minimalType, lhs.repr) == + std::tie(rhs.extType, rhs.minimalType, rhs.repr); + } + +private: + ExtType extType; + handshake::ChannelType minimalType; + using Union = PointerUnion; + Union repr; +}; /// Holds a set of operations that were already visisted during backtracking. using VisitedOps = SmallPtrSet; @@ -97,7 +202,7 @@ static ChannelVal asTypedIfLegal(Value val) { /// original value can be safely discarded. If an extension type is provided and /// the function is able to backtrack through any extension operation, updates /// the extension type with respect to the latter. -static ExtValue getMinimalValueWithExtType(ChannelVal val) { +static MinimalValue getMinimalValue(ChannelVal val) { // Ignore values whose type isn't optimizable if (!asTypedIfLegal(val)) return {val, ExtType::NONE}; @@ -108,6 +213,9 @@ static ExtValue getMinimalValueWithExtType(ChannelVal val) { if (auto op = val.getDefiningOp()) return {op.getIn(), ExtType::ZEXT}; + if (auto op = val.getDefiningOp()) + return MinimalValue(op); + return {val, ExtType::NONE}; } @@ -138,8 +246,8 @@ static ChannelVal backtrack(ChannelVal val) { return val; } -static ExtValue backtrackToMinimalValue(ChannelVal val) { - return getMinimalValueWithExtType(backtrack(val)); +static MinimalValue backtrackToMinimalValue(ChannelVal val) { + return getMinimalValue(backtrack(val)); } /// Returns the maximum number of bits that are used by any of the value's @@ -167,18 +275,17 @@ static unsigned getUsefulResultWidth(ChannelVal val) { /// which type of extension operation is inserted. If the extension type is /// unknown, the value's signedness determines whether the extension should be /// logical or arithmetic. -static ChannelVal modBitWidth(ExtValue extVal, unsigned targetWidth, +static ChannelVal modBitWidth(const MinimalValue &extVal, unsigned targetWidth, PatternRewriter &rewriter) { - auto &[val, ext] = extVal; - // Return the original value when it already has the target width - unsigned width = val.getType().getDataBitWidth(); + unsigned width = extVal.getDataBitWidth(); if (width == targetWidth) - return val; + return extVal.materializeValue(rewriter); // Otherwise, insert a bitwidth modification operation to create a value of // the target width Operation *newOp = nullptr; + ChannelVal val = extVal.materializeValue(rewriter); Location loc = val.getLoc(); Type newDataType = rewriter.getIntegerType(targetWidth); Type dstChannelType = val.getType().withDataType(newDataType); @@ -188,8 +295,8 @@ static ChannelVal modBitWidth(ExtValue extVal, unsigned targetWidth, // signedness information in the integer type. // All code should be migrated to pass LOGICAL or ARITHMETIC if it performs // bitwidth changes. - if (ext == ExtType::ZEXT || - (ext == ExtType::NONE && + if (extVal.getExtType() == ExtType::ZEXT || + (extVal.getExtType() == ExtType::NONE && val.getType().getDataType().isUnsignedInteger())) { newOp = rewriter.create(loc, dstChannelType, val); } else { @@ -286,9 +393,9 @@ static bool isOperandInCycle(Value val, Value res, /// bitwidth. Extension and truncation operations are inserted as necessary to /// satisfy the IR and bitwidth constraints. template -static void modArithOp(Op op, ExtValue lhs, ExtValue rhs, unsigned optWidth, - ExtType extRes, PatternRewriter &rewriter, - NameAnalysis &namer) { +static void modArithOp(Op op, MinimalValue lhs, MinimalValue rhs, + unsigned optWidth, ExtType extRes, + PatternRewriter &rewriter, NameAnalysis &namer) { ChannelVal channelVal = asTypedIfLegal(op->getResult(0)); assert(channelVal && "result must have valid type"); unsigned resWidth = channelVal.getType().getDataBitWidth(); @@ -445,14 +552,12 @@ class OptDataConfig { /// list of minimal data operands of the original operation. The vector given /// as last argument is filled with the new operands. virtual void getNewOperands(unsigned optWidth, - ArrayRef minDataOperands, + ArrayRef minDataOperands, PatternRewriter &rewriter, SmallVector &newOperands) { - llvm::transform(minDataOperands, std::back_inserter(newOperands), - [&](auto &&pair) { - auto &&[val, ext] = pair; - return modBitWidth({val, ext}, optWidth, rewriter); - }); + llvm::transform( + minDataOperands, std::back_inserter(newOperands), + [&](auto &&pair) { return modBitWidth(pair, optWidth, rewriter); }); } /// Determines the list of result types that will be given to the builder of @@ -529,7 +634,7 @@ class MuxDataConfig : public OptDataConfig { SmallVector getDataOperands() override { return op.getDataOperands(); } - void getNewOperands(unsigned optWidth, ArrayRef minDataOperands, + void getNewOperands(unsigned optWidth, ArrayRef minDataOperands, PatternRewriter &rewriter, SmallVector &newOperands) override { newOperands.push_back(op.getSelectOperand()); @@ -548,7 +653,7 @@ class CBranchDataConfig : public OptDataConfig { return SmallVector{op.getDataOperand()}; } - void getNewOperands(unsigned optWidth, ArrayRef minDataOperands, + void getNewOperands(unsigned optWidth, ArrayRef minDataOperands, PatternRewriter &rewriter, SmallVector &newOperands) override { newOperands.push_back(op.getConditionOperand()); @@ -588,10 +693,12 @@ class BufferDataConfig : public OptDataConfig { /// type when reducing the bitwidth to 'optWidth'. /// This operation may increase 'optWidth' if it is impossible to preserve /// semantics under the given bitwidth. -static ExtType computeDataForwardResult(ArrayRef operands, +static ExtType computeDataForwardResult(ArrayRef operands, unsigned &optWidth) { assert(!operands.empty() && "expected non empty operands"); - auto exts = llvm::make_second_range(operands); + auto exts = llvm::map_range(operands, [](const MinimalValue &extValue) { + return extValue.getExtType(); + }); // If all operands have the same extension, then we can simply use the same // extension for the result. if (llvm::all_equal(exts)) @@ -600,14 +707,14 @@ static ExtType computeDataForwardResult(ArrayRef operands, // In all other cases, we must sign-extend the output such that if an operand // was originally sign-extended, it remains fully sign-extended after // forwarding. - for (auto [operand, extType] : operands) { + for (const MinimalValue &extValue : operands) { // Special case: If the operand with the largest bitwidth uses // zero-extension, we must increase the bitwidth by one. // Otherwise, when the zero-extended operand is forwarded it could get // sign-extended accidentally. Increasing the bitwidth by one ensures that // the top bit remains unset for the zero-extended operand. - if (extType == ExtType::ZEXT) - if (operand.getType().getDataBitWidth() == optWidth) { + if (extValue.getExtType() == ExtType::ZEXT) + if (extValue.getDataBitWidth() == optWidth) { optWidth++; break; } @@ -618,11 +725,14 @@ static ExtType computeDataForwardResult(ArrayRef operands, /// For simple data forwarding operations that forward one of 'operands' to its /// result, computes the resulting bitwidth and extension type. -static ExtWidth computeDataForwardResult(ArrayRef operands) { +static ExtWidth computeDataForwardResult(ArrayRef operands) { assert(!operands.empty() && "expected non empty operands"); unsigned optWidth = 0; - for (ChannelVal oprd : llvm::make_first_range(operands)) - optWidth = std::max(optWidth, oprd.getType().getDataBitWidth()); + for (unsigned dataBitWidth : + llvm::map_range(operands, [](const MinimalValue &extValue) { + return extValue.getDataBitWidth(); + })) + optWidth = std::max(optWidth, dataBitWidth); ExtType type = computeDataForwardResult(operands, optWidth); return {type, optWidth}; @@ -671,11 +781,10 @@ struct HandshakeOptData : public OpRewritePattern { return failure(); // Get the operation's data operands actual widths - SmallVector minDataOperands; - llvm::transform(dataOperands, std::back_inserter(minDataOperands), - [&](Value val) { - return getMinimalValueWithExtType(cast(val)); - }); + SmallVector minDataOperands; + llvm::transform( + dataOperands, std::back_inserter(minDataOperands), + [&](Value val) { return getMinimalValue(cast(val)); }); // Check whether we can reduce the bitwidth of the operation ExtWidth resultWidth = {ExtType::ZEXT, 0}; @@ -847,9 +956,8 @@ struct MemInterfaceAddrOpt // Optimizes the bitwidth of the address channel currently being pointed to // by inputIdx, and increment inputIdx before returning the optimized value auto getOptAddrInput = [&](unsigned inputIdx) { - return modBitWidth( - getMinimalValueWithExtType(cast(operands[inputIdx])), - optWidth, rewriter); + return modBitWidth(getMinimalValue(cast(operands[inputIdx])), + optWidth, rewriter); }; // Replace new operands and result types with the narrrower address type by @@ -948,9 +1056,8 @@ struct MemPortAddrOpt return failure(); // Derive new operands and result types with the narrrower address type - Value newAddr = - modBitWidth(getMinimalValueWithExtType(portOp.getAddressInput()), - optWidth, rewriter); + Value newAddr = modBitWidth(getMinimalValue(portOp.getAddressInput()), + optWidth, rewriter); Value dataIn = portOp.getDataInput(); SmallVector newOperands{newAddr, dataIn}; SmallVector newResultTypes{newAddr.getType(), dataIn.getType()}; @@ -1038,10 +1145,9 @@ struct ForwardCycleOpt : public OpRewritePattern { })); // Get the minimal value of all data operands - SmallVector minDataOperands; + SmallVector minDataOperands; for (Value oprd : dataOperands) - minDataOperands.push_back( - getMinimalValueWithExtType(cast(oprd))); + minDataOperands.push_back(getMinimalValue(cast(oprd))); // Check whether we managed to optimize anything unsigned dataWidth = channelVal.getType().getDataBitWidth(); @@ -1122,13 +1228,12 @@ struct ArithSingleType : public OpRewritePattern { return failure(); // Check whether we can reduce the bitwidth of the operation - ExtValue minLhs = getMinimalValueWithExtType(op.getLhs()); - ExtValue minRhs = getMinimalValueWithExtType(op.getRhs()); + MinimalValue minLhs = getMinimalValue(op.getLhs()); + MinimalValue minRhs = getMinimalValue(op.getRhs()); ExtWidth optWidth; if (forward) - optWidth = - fTransfer({minLhs.second, minLhs.first.getType().getDataBitWidth()}, - {minRhs.second, minRhs.first.getType().getDataBitWidth()}); + optWidth = fTransfer({minLhs.getExtType(), minLhs.getDataBitWidth()}, + {minRhs.getExtType(), minRhs.getDataBitWidth()}); else { // It does not matter whether we use sign- or zero-extension in this case // since the bits added by the extension are unused by definition. @@ -1176,14 +1281,12 @@ struct ArithSelect : public OpRewritePattern { return failure(); // Check whether we can reduce the bitwidth of the operation - ExtValue lhsExtValue = getMinimalValueWithExtType(selectOp.getTrueValue()); - ExtValue rhsExtValue = getMinimalValueWithExtType(selectOp.getFalseValue()); - auto [minLhs, extLhs] = lhsExtValue; - auto [minRhs, extRhs] = rhsExtValue; + MinimalValue lhsExtValue = getMinimalValue(selectOp.getTrueValue()); + MinimalValue rhsExtValue = getMinimalValue(selectOp.getFalseValue()); unsigned optWidth; if (forward) - optWidth = std::max(minLhs.getType().getDataBitWidth(), - minRhs.getType().getDataBitWidth()); + optWidth = std::max(lhsExtValue.getDataBitWidth(), + rhsExtValue.getDataBitWidth()); else optWidth = getUsefulResultWidth(selectOp.getResult()); unsigned resWidth = channelVal.getType().getDataBitWidth(); @@ -1192,8 +1295,10 @@ struct ArithSelect : public OpRewritePattern { // Different operand extension types mean that we don't know how to extend // the operation's result, so it cannot be optimized - if ((extLhs == ExtType::ZEXT && extRhs == ExtType::SEXT) || - (extLhs == ExtType::SEXT && extRhs == ExtType::ZEXT)) + if ((lhsExtValue.getExtType() == ExtType::ZEXT && + rhsExtValue.getExtType() == ExtType::SEXT) || + (lhsExtValue.getExtType() == ExtType::SEXT && + rhsExtValue.getExtType() == ExtType::ZEXT)) return failure(); // Create a new operation as well as appropriate bitwidth modification @@ -1203,7 +1308,8 @@ struct ArithSelect : public OpRewritePattern { rewriter.setInsertionPoint(selectOp); auto newOp = rewriter.create( selectOp.getLoc(), selectOp.getCondition(), newLhs, newRhs); - Value newRes = modBitWidth({newOp.getResult(), extLhs}, resWidth, rewriter); + Value newRes = modBitWidth({newOp.getResult(), lhsExtValue.getExtType()}, + resWidth, rewriter); inheritBB(selectOp, newOp); namer.replaceOp(selectOp, newOp); @@ -1232,13 +1338,13 @@ struct ArithShrUIFW : OpRewritePattern { LogicalResult matchAndRewrite(handshake::ShRUIOp op, PatternRewriter &rewriter) const override { - auto [lhs, lhsExt] = getMinimalValueWithExtType(op.getLhs()); - unsigned inputBitwidth = lhs.getType().getDataBitWidth(); + MinimalValue lhs = getMinimalValue(op.getLhs()); + unsigned inputBitwidth = lhs.getDataBitWidth(); unsigned currentBitwidth = op.getType().getDataBitWidth(); if (inputBitwidth >= currentBitwidth) return failure(); - assert(lhsExt != ExtType::NONE && "expected an extension"); + assert(lhs.getExtType() != ExtType::NONE && "expected an extension"); APInt numOfShiftPositions; Value constantControl; { @@ -1260,7 +1366,7 @@ struct ArithShrUIFW : OpRewritePattern { // with 'c' many 0s leading 0s and copies of the sign bit otherwise. // * Otherwise: We can perform the shift at the (lower) input bitwidth // enabling other ops to be optimized in the forward pass. - if (lhsExt == ExtType::ZEXT) { + if (lhs.getExtType() == ExtType::ZEXT) { if (numOfShiftPositions.uge(inputBitwidth)) { // The entire input is shifted away and only 0 bits from the extension // remain. @@ -1273,7 +1379,7 @@ struct ArithShrUIFW : OpRewritePattern { // Otherwise, we can perform the entire shift at the input bitwidth and // zero-extend back to the original bitwidth. - modArithOp(op, {lhs, lhsExt}, {op.getRhs(), ExtType::NONE}, inputBitwidth, + modArithOp(op, lhs, {op.getRhs(), ExtType::NONE}, inputBitwidth, ExtType::ZEXT, rewriter, namer); ++bitwidthReduced; return success(); @@ -1306,7 +1412,8 @@ struct ArithShrUIFW : OpRewritePattern { // Perform the shift on the bitwidth of lhs. Value newRhs = modBitWidth({op.getRhs(), ExtType::NONE}, inputBitwidth, rewriter); - result = rewriter.create(op.getLoc(), lhs, newRhs); + result = rewriter.create( + op.getLoc(), lhs.materializeValue(rewriter), newRhs); // Now truncate the result to make the sign-bit after shifting the // top-bit again. @@ -1324,12 +1431,11 @@ struct ArithShrUIFW : OpRewritePattern { // - c]. Value inputBWM1 = rewriter.create( op.getLoc(), - rewriter.getIntegerAttr(lhs.getType().getDataType(), - inputBitwidth - 1), + rewriter.getIntegerAttr(lhs.getDataType(), inputBitwidth - 1), constantControl); // Shift away all values of lhs other than the sign-bit. - ChannelVal signBit = - rewriter.create(op.getLoc(), lhs, inputBWM1); + ChannelVal signBit = rewriter.create( + op.getLoc(), lhs.materializeValue(rewriter), inputBWM1); // Truncate down to just the sign-bit. // Note that this even works when 'c' is greater than the difference // between the input and current bit width. @@ -1368,13 +1474,13 @@ struct ArithShrSIFW : OpRewritePattern { LogicalResult matchAndRewrite(handshake::ShRSIOp op, PatternRewriter &rewriter) const override { - auto [lhs, lhsExt] = getMinimalValueWithExtType(op.getLhs()); - unsigned inputBitwidth = lhs.getType().getDataBitWidth(); + MinimalValue lhs = getMinimalValue(op.getLhs()); + unsigned inputBitwidth = lhs.getDataBitWidth(); unsigned currentBitwidth = op.getType().getDataBitWidth(); if (inputBitwidth >= currentBitwidth) return failure(); - assert(lhsExt != ExtType::NONE && "expected an extension"); + assert(lhs.getExtType() != ExtType::NONE && "expected an extension"); APInt numOfShiftPositions; Value constantControl; { @@ -1390,7 +1496,7 @@ struct ArithShrSIFW : OpRewritePattern { if (numOfShiftPositions.uge(currentBitwidth)) return failure(); - if (lhsExt == ExtType::ZEXT) { + if (lhs.getExtType() == ExtType::ZEXT) { // We use a generic canonicalization pattern that should fold this into // an unsigned shift-right instead. return failure(); @@ -1405,7 +1511,7 @@ struct ArithShrSIFW : OpRewritePattern { // c is less than the input bitwidth, meaning other bits from the input // besides the sign-bit are preserved in the output. - modArithOp(op, {lhs, lhsExt}, {op.getRhs(), ExtType::NONE}, inputBitwidth, + modArithOp(op, lhs, {op.getRhs(), ExtType::NONE}, inputBitwidth, ExtType::SEXT, rewriter, namer); ++bitwidthReduced; return success(); @@ -1416,11 +1522,11 @@ struct ArithShrSIFW : OpRewritePattern { // is the sign-bit. Value inputBWM1 = rewriter.create( op.getLoc(), - rewriter.getIntegerAttr(lhs.getType().getDataType(), inputBitwidth - 1), + rewriter.getIntegerAttr(lhs.getDataType(), inputBitwidth - 1), constantControl); // Shift away all values of lhs other than the sign-bit. - ChannelVal signBit = - rewriter.create(op.getLoc(), lhs, inputBWM1); + ChannelVal signBit = rewriter.create( + op.getLoc(), lhs.materializeValue(rewriter), inputBWM1); // Fill remaining sign-bit copies. rewriter.replaceOpWithNewOp(op, op.getType(), signBit); ++bitwidthReduced; @@ -1449,39 +1555,36 @@ struct ArithShiftBW : public OpRewritePattern { LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { - ChannelVal toShift = op.getLhs(); - ChannelVal shiftBy = op.getRhs(); - auto [minToShift, extToShift] = getMinimalValueWithExtType(toShift); - auto [minShiftBy, minShiftByExt] = backtrackToMinimalValue(shiftBy); + MinimalValue toShift = getMinimalValue(op.getLhs()); + MinimalValue shiftBy = backtrackToMinimalValue(op.getRhs()); bool isRightShift = isa((Operation *)op); // Check whether we can reduce the bitwidth of the operation unsigned resWidth = op.getResult().getType().getDataBitWidth(); unsigned optWidth = resWidth; - unsigned cstVal = 0; - if (Operation *defOp = minShiftBy.getDefiningOp()) - if (auto cstOp = dyn_cast(defOp)) { - cstVal = (unsigned)cast(cstOp.getValue()).getInt(); - optWidth = getUsefulResultWidth(op.getResult()); - if (isRightShift) - optWidth += cstVal; - } + APInt cstVal; + if (std::optional maybeConstant = shiftBy.getConstantOrNone()) { + cstVal = std::move(*maybeConstant); + optWidth = getUsefulResultWidth(op.getResult()); + if (isRightShift) + optWidth += cstVal.getZExtValue(); + } if (optWidth >= resWidth) return failure(); - ChannelVal modToShift = minToShift; + ChannelVal modToShift = toShift.materializeValue(rewriter); if (!isRightShift) { // In the case of a left shift, we first truncate the shifted integer to // discard high-significance bits that were discarded in the result, // then extend back to satisfy the users of the original integer - unsigned requiredToShiftWidth = optWidth - std::min(cstVal, optWidth); - modToShift = - modBitWidth({minToShift, extToShift}, requiredToShiftWidth, rewriter); + unsigned requiredToShiftWidth = + optWidth - std::min(cstVal.getZExtValue(), optWidth); + modToShift = modBitWidth(toShift, requiredToShiftWidth, rewriter); } - modArithOp(op, {modToShift, extToShift}, {minShiftBy, ExtType::ZEXT}, - optWidth, extToShift, rewriter, namer); + modArithOp(op, {modToShift, toShift.getExtType()}, shiftBy, optWidth, + toShift.getExtType(), rewriter, namer); ++bitwidthReduced; return success(); } @@ -1509,12 +1612,10 @@ struct ArithCmpFW : public OpRewritePattern { LogicalResult matchAndRewrite(handshake::CmpIOp cmpOp, PatternRewriter &rewriter) const override { // Check whether we can reduce the bitwidth of the operation - ExtValue lhsExtValue = getMinimalValueWithExtType(cmpOp.getLhs()); - ExtValue rhsExtValue = getMinimalValueWithExtType(cmpOp.getRhs()); - auto [minLhs, extLhs] = lhsExtValue; - auto [minRhs, extRhs] = rhsExtValue; - unsigned optWidth = std::max(minLhs.getType().getDataBitWidth(), - minRhs.getType().getDataBitWidth()); + MinimalValue lhsExtValue = getMinimalValue(cmpOp.getLhs()); + MinimalValue rhsExtValue = getMinimalValue(cmpOp.getRhs()); + unsigned optWidth = + std::max(lhsExtValue.getDataBitWidth(), rhsExtValue.getDataBitWidth()); unsigned actualWidth = cmpOp.getLhs().getType().getDataBitWidth(); // An extra bit is required to account for bits added by sign-extension. @@ -1528,12 +1629,12 @@ struct ArithCmpFW : public OpRewritePattern { // sign-extension of a negative number will insert a 1-bit upfront which // changes the result. // Example: cmpi uge zext(110), sext(10) must be done using 4, not 3 bits. - if ((extLhs == ExtType::ZEXT && extRhs == ExtType::SEXT && - minLhs.getType().getDataBitWidth() >= - minRhs.getType().getDataBitWidth()) || - (extRhs == ExtType::ZEXT && extLhs == ExtType::SEXT && - minRhs.getType().getDataBitWidth() >= - minLhs.getType().getDataBitWidth())) + if ((lhsExtValue.getExtType() == ExtType::ZEXT && + rhsExtValue.getExtType() == ExtType::SEXT && + lhsExtValue.getDataBitWidth() >= rhsExtValue.getDataBitWidth()) || + (rhsExtValue.getExtType() == ExtType::ZEXT && + lhsExtValue.getExtType() == ExtType::SEXT && + rhsExtValue.getDataBitWidth() >= lhsExtValue.getDataBitWidth())) optWidth += 1; if (optWidth >= actualWidth) @@ -1541,8 +1642,8 @@ struct ArithCmpFW : public OpRewritePattern { // Create a new operation as well as appropriate bitwidth modification // operations to keep the IR valid - Value newLhs = modBitWidth({minLhs, extLhs}, optWidth, rewriter); - Value newRhs = modBitWidth({minRhs, extRhs}, optWidth, rewriter); + Value newLhs = modBitWidth(lhsExtValue, optWidth, rewriter); + Value newRhs = modBitWidth(rhsExtValue, optWidth, rewriter); rewriter.setInsertionPoint(cmpOp); auto newOp = rewriter.create( cmpOp.getLoc(), cmpOp.getPredicate(), newLhs, newRhs); @@ -1575,12 +1676,12 @@ struct ArithExtToTruncOpt : public OpRewritePattern { LogicalResult matchAndRewrite(handshake::TruncIOp truncOp, PatternRewriter &rewriter) const override { // Operand must be produced by an extension operation - ExtValue minVal = getMinimalValueWithExtType(truncOp.getIn()); - if (minVal.second == ExtType::NONE) + MinimalValue minVal = getMinimalValue(truncOp.getIn()); + if (minVal.getExtType() == ExtType::NONE) return failure(); unsigned finalWidth = truncOp.getResult().getType().getDataBitWidth(); - if (finalWidth == minVal.first.getType().getDataBitWidth()) + if (finalWidth == minVal.getDataBitWidth()) return failure(); // Bypass all extensions and truncation operation and replace it with a @@ -1617,7 +1718,7 @@ struct ArithBoundOpt : public OpRewritePattern { ChannelVal channelVal = asTypedIfLegal(condOp.getDataOperand()); if (!channelVal) return failure(); - ExtValue dataOperand = backtrackToMinimalValue(channelVal); + MinimalValue dataOperand = backtrackToMinimalValue(channelVal); // Find all comparison operations whose result is used in a logical and to // determine the condition operand and which have the data operand as one of @@ -1627,8 +1728,8 @@ struct ArithBoundOpt : public OpRewritePattern { falseRes = cast(condOp.getFalseResult()); std::optional> trueBranch, falseBranch; for (handshake::CmpIOp cmpOp : getCmpOps(condOp.getConditionOperand())) { - ExtValue minLhs = backtrackToMinimalValue(cmpOp.getLhs()); - ExtValue minRhs = backtrackToMinimalValue(cmpOp.getRhs()); + MinimalValue minLhs = backtrackToMinimalValue(cmpOp.getLhs()); + MinimalValue minRhs = backtrackToMinimalValue(cmpOp.getRhs()); // One of the two comparison operands must be the data input unsigned width; @@ -1638,13 +1739,13 @@ struct ArithBoundOpt : public OpRewritePattern { // Used in the case of equality-comparisons. ExtType boundExt; if (dataOperand == minLhs) { - width = minRhs.first.getType().getDataBitWidth(); + width = minRhs.getDataBitWidth(); isDataLhs = true; - boundExt = minRhs.second; + boundExt = minRhs.getExtType(); } else if (dataOperand == minRhs) { - width = minLhs.first.getType().getDataBitWidth(); + width = minLhs.getDataBitWidth(); isDataLhs = false; - boundExt = minLhs.second; + boundExt = minLhs.getExtType(); } else continue; @@ -1653,7 +1754,8 @@ struct ArithBoundOpt : public OpRewritePattern { getBranchToOptimize(condOp, cmpOp, isDataLhs); if (!branchOutputToOptimize) continue; - if (isBoundTight(isDataLhs ? minRhs.first : minLhs.first)) + if (isBoundTight(isDataLhs ? minRhs.getConstantOpOrNull() + : minLhs.getConstantOpOrNull())) width = getRealOptWidth(cmpOp, width, isDataLhs); // Perform result extension based on the comparison operator. @@ -1729,7 +1831,7 @@ struct ArithBoundOpt : public OpRewritePattern { /// Determines whether the bound that the data operand is compared with is /// tight, i.e. whether being strictly closer to 0 than it means we can /// represent the number using one less bit than the bound itself. - bool isBoundTight(Value bound) const; + bool isBoundTight(handshake::ConstantOp maybeConstant) const; /// Determines which branch may be optimized based on the nature of the /// comparison and the side of the data operand to the conditional branch. @@ -1755,10 +1857,10 @@ struct ArithBoundOpt : public OpRewritePattern { SmallVector ArithBoundOpt::getCmpOps(ChannelVal condVal) const { - ExtValue minVal = backtrackToMinimalValue(condVal); + MinimalValue minVal = backtrackToMinimalValue(condVal); // Stop when reaching function arguments - Operation *defOp = minVal.first.getDefiningOp(); + Operation *defOp = minVal.getDefiningOp(); if (!defOp) return {}; @@ -1779,15 +1881,13 @@ ArithBoundOpt::getCmpOps(ChannelVal condVal) const { return {}; } -bool ArithBoundOpt::isBoundTight(Value bound) const { +bool ArithBoundOpt::isBoundTight(handshake::ConstantOp maybeConstant) const { // Bound must be a constant - auto cstOp = - dyn_cast_if_present(bound.getDefiningOp()); - if (!cstOp) + if (!maybeConstant) return false; // Constant must have an integer attribute - auto intAttr = cast(cstOp.getValue()); + auto intAttr = cast(maybeConstant.getValue()); if (!intAttr) return false; diff --git a/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir b/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir index 1aba14b621..86116dab4f 100644 --- a/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir +++ b/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir @@ -149,15 +149,15 @@ handshake.func @shliFW(%arg0: !handshake.channel, %start: !handshake.contro // ----- -// CHECK-LABEL: handshake.func @shrsiFW( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, -// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { +// CHECK-LABEL: handshake.func @shrsi_small( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { // CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i16} : <>, // CHECK: %[[VAL_3:.*]] = shrsi %[[VAL_0]], %[[VAL_2]] : // CHECK: %[[VAL_4:.*]] = extsi %[[VAL_3]] : to // CHECK: end %[[VAL_4]] : // CHECK: } -handshake.func @shrsiFW(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { +handshake.func @shrsi_small(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { %cst = handshake.constant %start {value = 4 : i4} : <>, %ext0 = extsi %arg0 : to %extCst = extsi %cst : to @@ -167,6 +167,60 @@ handshake.func @shrsiFW(%arg0: !handshake.channel, %start: !handshake.contr // ----- +// CHECK-LABEL: handshake.func @shrsi_oob( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_0]] : to +// CHECK: %[[VAL_3:.*]] = constant %[[VAL_1]] {value = 32 : i32} : <>, +// CHECK: %[[VAL_4:.*]] = shrsi %[[VAL_2]], %[[VAL_3]] : +// CHECK: end %[[VAL_4]] : +// CHECK: } +handshake.func @shrsi_oob(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %cst = handshake.constant %start {value = 32 : i8} : <>, + %ext0 = extsi %arg0 : to + %extCst = extsi %cst : to + %res = shrsi %ext0, %extCst : + end %res : +} + +// ----- + +// CHECK-LABEL: handshake.func @shrsi_zext( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i16} : <>, +// CHECK: %[[VAL_3:.*]] = shrui %[[VAL_0]], %[[VAL_2]] : +// CHECK: %[[VAL_4:.*]] = extui %[[VAL_3]] : to +// CHECK: end %[[VAL_4]] : +// CHECK: } +handshake.func @shrsi_zext(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %cst = handshake.constant %start {value = 4 : i4} : <>, + %ext0 = extui %arg0 : to + %extCst = extsi %cst : to + %res = shrsi %ext0, %extCst : + end %res : +} + +// ----- + +// CHECK-LABEL: handshake.func @shrsi_large( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 15 : i16} : <>, +// CHECK: %[[VAL_3:.*]] = shrsi %[[VAL_0]], %[[VAL_2]] : +// CHECK: %[[VAL_4:.*]] = extsi %[[VAL_3]] : to +// CHECK: end %[[VAL_4]] : +// CHECK: } +handshake.func @shrsi_large(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %cst = handshake.constant %start {value = 18 : i8} : <>, + %ext0 = extsi %arg0 : to + %extCst = extsi %cst : to + %res = shrsi %ext0, %extCst : + end %res : +} + +// ----- + // CHECK-LABEL: handshake.func @shruiFW( // CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, // CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { diff --git a/test/Transforms/HandshakeOptimizeBitwidths/bound-opti.mlir b/test/Transforms/HandshakeOptimizeBitwidths/bound-opti.mlir index 7ba5e9d417..f031c3b284 100644 --- a/test/Transforms/HandshakeOptimizeBitwidths/bound-opti.mlir +++ b/test/Transforms/HandshakeOptimizeBitwidths/bound-opti.mlir @@ -1,10 +1,3 @@ -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// The script is designed to make adding checks to -// a test case fast, it is *not* designed to be authoritative -// about what constitutes a good test! The CHECK should be -// minimized and named to reflect the test intent. - // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py // RUN: dynamatic-opt --handshake-optimize-bitwidths --remove-operation-names %s --split-input-file | FileCheck %s @@ -51,11 +44,11 @@ handshake.func @boundUleCst(%arg0: !handshake.channel, %start: !handshake.c // CHECK-LABEL: handshake.func @boundUleCstFlip( // CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, // CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { -// CHECK: %[[VAL_2:.*]] = trunci %[[VAL_0]] {handshake.bb = 0 : ui32} : to +// CHECK: %[[VAL_2:.*]] = trunci %[[VAL_0]] {handshake.bb = 0 : ui32} : to // CHECK: %[[VAL_3:.*]] = constant %[[VAL_1]] {value = 16 : i32} : <>, // CHECK: %[[VAL_4:.*]] = cmpi ule, %[[VAL_3]], %[[VAL_0]] : -// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = cond_br %[[VAL_4]], %[[VAL_2]] : , -// CHECK: %[[VAL_7:.*]] = extui %[[VAL_6]] : to +// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = cond_br %[[VAL_4]], %[[VAL_2]] : , +// CHECK: %[[VAL_7:.*]] = extui %[[VAL_6]] : to // CHECK: end %[[VAL_7]] : // CHECK: } handshake.func @boundUleCstFlip(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { @@ -109,19 +102,21 @@ handshake.func @argUleArg(%arg0: !handshake.channel, %bound: !handshake.cha // CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, // CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel, // CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "bound", "start"], resNames = ["out0"]} { -// CHECK: %[[VAL_3:.*]] = constant %[[VAL_2]] {value = 0 : i32} : <>, -// CHECK: %[[VAL_4:.*]] = constant %[[VAL_2]] {value = 50 : i32} : <>, -// CHECK: %[[VAL_5:.*]] = constant %[[VAL_2]] {value = 100 : i32} : <>, -// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_1]] : to -// CHECK: %[[VAL_7:.*]] = cmpi uge, %[[VAL_0]], %[[VAL_3]] : -// CHECK: %[[VAL_8:.*]] = cmpi ult, %[[VAL_0]], %[[VAL_5]] : -// CHECK: %[[VAL_9:.*]] = cmpi ne, %[[VAL_0]], %[[VAL_4]] : -// CHECK: %[[VAL_10:.*]] = cmpi ult, %[[VAL_0]], %[[VAL_6]] : -// CHECK: %[[VAL_11:.*]] = andi %[[VAL_7]], %[[VAL_8]] : -// CHECK: %[[VAL_12:.*]] = andi %[[VAL_9]], %[[VAL_10]] : -// CHECK: %[[VAL_13:.*]] = andi %[[VAL_11]], %[[VAL_12]] : -// CHECK: %[[VAL_14:.*]], %[[VAL_15:.*]] = cond_br %[[VAL_13]], %[[VAL_0]] : , -// CHECK: end %[[VAL_14]] : +// CHECK: %[[VAL_3:.*]] = trunci %[[VAL_0]] {handshake.bb = 0 : ui32} : to +// CHECK: %[[VAL_4:.*]] = constant %[[VAL_2]] {value = 0 : i32} : <>, +// CHECK: %[[VAL_5:.*]] = constant %[[VAL_2]] {value = 50 : i32} : <>, +// CHECK: %[[VAL_6:.*]] = constant %[[VAL_2]] {value = 100 : i32} : <>, +// CHECK: %[[VAL_7:.*]] = extsi %[[VAL_1]] : to +// CHECK: %[[VAL_8:.*]] = cmpi uge, %[[VAL_0]], %[[VAL_4]] : +// CHECK: %[[VAL_9:.*]] = cmpi ult, %[[VAL_0]], %[[VAL_6]] : +// CHECK: %[[VAL_10:.*]] = cmpi ne, %[[VAL_0]], %[[VAL_5]] : +// CHECK: %[[VAL_11:.*]] = cmpi ult, %[[VAL_0]], %[[VAL_7]] : +// CHECK: %[[VAL_12:.*]] = andi %[[VAL_8]], %[[VAL_9]] : +// CHECK: %[[VAL_13:.*]] = andi %[[VAL_10]], %[[VAL_11]] : +// CHECK: %[[VAL_14:.*]] = andi %[[VAL_12]], %[[VAL_13]] : +// CHECK: %[[VAL_15:.*]], %[[VAL_16:.*]] = cond_br %[[VAL_14]], %[[VAL_3]] : , +// CHECK: %[[VAL_17:.*]] = extui %[[VAL_15]] : to +// CHECK: end %[[VAL_17]] : // CHECK: } handshake.func @mulCmps(%arg0: !handshake.channel, %bound: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { %0 = constant %start {value = 0 : i1} : <>, @@ -147,16 +142,17 @@ handshake.func @mulCmps(%arg0: !handshake.channel, %bound: !handshake.chann // CHECK-LABEL: handshake.func @simpleLoop( // CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { // CHECK: %[[VAL_1:.*]] = source : <> -// CHECK: %[[VAL_2:.*]] = constant %[[VAL_0]] {value = 0 : i32} : <>, -// CHECK: %[[VAL_3:.*]] = constant %[[VAL_1]] {value = 16 : i32} : <>, -// CHECK: %[[VAL_4:.*]] = constant %[[VAL_1]] {value = 1 : i32} : <>, -// CHECK: %[[VAL_5:.*]] = merge %[[VAL_2]], %[[VAL_6:.*]] : -// CHECK: %[[VAL_7:.*]] = addi %[[VAL_5]], %[[VAL_4]] : -// CHECK: %[[VAL_8:.*]] = cmpi ult, %[[VAL_7]], %[[VAL_3]] : -// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = cond_br %[[VAL_8]], %[[VAL_7]] : , -// CHECK: %[[VAL_11:.*]] = trunci %[[VAL_9]] : to -// CHECK: %[[VAL_6]] = extui %[[VAL_11]] : to -// CHECK: end %[[VAL_10]] : +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_0]] {value = 0 : i4} : <>, +// CHECK: %[[VAL_3:.*]] = merge %[[VAL_2]], %[[VAL_4:.*]] : +// CHECK: %[[VAL_5:.*]] = extui %[[VAL_3]] : to +// CHECK: %[[VAL_6:.*]] = constant %[[VAL_1]] {value = 1 : i5} : <>, +// CHECK: %[[VAL_7:.*]] = addi %[[VAL_5]], %[[VAL_6]] : +// CHECK: %[[VAL_8:.*]] = constant %[[VAL_1]] {value = -16 : i5} : <>, +// CHECK: %[[VAL_9:.*]] = cmpi ult, %[[VAL_7]], %[[VAL_8]] : +// CHECK: %[[VAL_10:.*]], %[[VAL_11:.*]] = cond_br %[[VAL_9]], %[[VAL_7]] : , +// CHECK: %[[VAL_12:.*]] = extui %[[VAL_11]] : to +// CHECK: %[[VAL_4]] = trunci %[[VAL_10]] : to +// CHECK: end %[[VAL_12]] : // CHECK: } handshake.func @simpleLoop(%start: !handshake.control<>) -> !handshake.channel { %source = source : <> @@ -179,42 +175,43 @@ handshake.func @simpleLoop(%start: !handshake.control<>) -> !handshake.channel, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { // CHECK: %[[VAL_1:.*]] = source {handshake.bb = 0 : ui32} : <> -// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 16 : i32} : <>, -// CHECK: %[[VAL_3:.*]] = constant %[[VAL_1]] {value = 0 : i32} : <>, -// CHECK: %[[VAL_4:.*]] = constant %[[VAL_1]] {value = 1 : i32} : <>, -// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = control_merge {{\[}}%[[VAL_0]], %[[VAL_7:.*]]] : [<>, <>] to <>, -// CHECK: %[[VAL_8:.*]] = mux %[[VAL_6]] {{\[}}%[[VAL_3]], %[[VAL_9:.*]]] : , [, ] to -// CHECK: %[[VAL_10:.*]] = mux %[[VAL_6]] {{\[}}%[[VAL_3]], %[[VAL_11:.*]]] : , [, ] to -// CHECK: %[[VAL_12:.*]] = addi %[[VAL_10]], %[[VAL_4]] : -// CHECK: %[[VAL_13:.*]] = trunci %[[VAL_12]] : to -// CHECK: %[[VAL_14:.*]] = cmpi ult, %[[VAL_12]], %[[VAL_2]] : -// CHECK: %[[VAL_15:.*]], %[[VAL_16:.*]] = cond_br %[[VAL_14]], %[[VAL_13]] : , -// CHECK: %[[VAL_11]] = extui %[[VAL_15]] : to -// CHECK: %[[VAL_17:.*]], %[[VAL_18:.*]] = cond_br %[[VAL_14]], %[[VAL_8]] : , -// CHECK: %[[VAL_19:.*]], %[[VAL_20:.*]] = cond_br %[[VAL_14]], %[[VAL_5]] : , <> -// CHECK: %[[VAL_21:.*]] = source : <> -// CHECK: %[[VAL_22:.*]] = constant %[[VAL_21]] {value = 32 : i32} : <>, -// CHECK: %[[VAL_23:.*]] = constant %[[VAL_21]] {value = 0 : i32} : <>, -// CHECK: %[[VAL_24:.*]] = constant %[[VAL_21]] {value = 1 : i32} : <>, -// CHECK: %[[VAL_25:.*]], %[[VAL_26:.*]] = control_merge {{\[}}%[[VAL_19]], %[[VAL_27:.*]]] : [<>, <>] to <>, -// CHECK: %[[VAL_28:.*]] = mux %[[VAL_26]] {{\[}}%[[VAL_17]], %[[VAL_29:.*]]] : , [, ] to -// CHECK: %[[VAL_30:.*]] = mux %[[VAL_26]] {{\[}}%[[VAL_23]], %[[VAL_31:.*]]] : , [, ] to -// CHECK: %[[VAL_32:.*]] = addi %[[VAL_30]], %[[VAL_24]] : -// CHECK: %[[VAL_33:.*]] = trunci %[[VAL_32]] : to -// CHECK: %[[VAL_34:.*]] = cmpi ult, %[[VAL_32]], %[[VAL_22]] : -// CHECK: %[[VAL_35:.*]], %[[VAL_36:.*]] = cond_br %[[VAL_34]], %[[VAL_33]] : , -// CHECK: %[[VAL_31]] = extui %[[VAL_35]] : to -// CHECK: %[[VAL_37:.*]], %[[VAL_9]] = cond_br %[[VAL_34]], %[[VAL_28]] : , -// CHECK: %[[VAL_38:.*]], %[[VAL_7]] = cond_br %[[VAL_34]], %[[VAL_25]] : , <> -// CHECK: %[[VAL_39:.*]] = source : <> -// CHECK: %[[VAL_40:.*]] = constant %[[VAL_39]] {value = 10 : i32} : <>, -// CHECK: %[[VAL_41:.*]], %[[VAL_42:.*]] = control_merge {{\[}}%[[VAL_38]]] : [<>] to <>, -// CHECK: %[[VAL_43:.*]] = merge %[[VAL_37]] : -// CHECK: %[[VAL_44:.*]] = addi %[[VAL_43]], %[[VAL_40]] : -// CHECK: %[[VAL_29]] = br %[[VAL_44]] : -// CHECK: %[[VAL_27]] = br %[[VAL_41]] : <> -// CHECK: %[[VAL_45:.*]] = merge %[[VAL_18]] : -// CHECK: end %[[VAL_45]] : +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 0 : i32} : <>, +// CHECK: %[[VAL_3:.*]], %[[VAL_4:.*]] = control_merge {{\[}}%[[VAL_0]], %[[VAL_5:.*]]] : [<>, <>] to <>, +// CHECK: %[[VAL_6:.*]] = mux %[[VAL_4]] {{\[}}%[[VAL_2]], %[[VAL_7:.*]]] : , [, ] to +// CHECK: %[[VAL_8:.*]] = constant %[[VAL_1]] {value = 0 : i4} : <>, +// CHECK: %[[VAL_9:.*]] = mux %[[VAL_4]] {{\[}}%[[VAL_8]], %[[VAL_10:.*]]] : , [, ] to +// CHECK: %[[VAL_11:.*]] = extui %[[VAL_9]] : to +// CHECK: %[[VAL_12:.*]] = constant %[[VAL_1]] {value = 1 : i5} : <>, +// CHECK: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_12]] : +// CHECK: %[[VAL_14:.*]] = trunci %[[VAL_13]] : to +// CHECK: %[[VAL_15:.*]] = constant %[[VAL_1]] {value = -16 : i5} : <>, +// CHECK: %[[VAL_16:.*]] = cmpi ult, %[[VAL_13]], %[[VAL_15]] : +// CHECK: %[[VAL_10]], %[[VAL_17:.*]] = cond_br %[[VAL_16]], %[[VAL_14]] : , +// CHECK: %[[VAL_18:.*]], %[[VAL_19:.*]] = cond_br %[[VAL_16]], %[[VAL_6]] : , +// CHECK: %[[VAL_20:.*]], %[[VAL_21:.*]] = cond_br %[[VAL_16]], %[[VAL_3]] : , <> +// CHECK: %[[VAL_22:.*]] = source : <> +// CHECK: %[[VAL_23:.*]], %[[VAL_24:.*]] = control_merge {{\[}}%[[VAL_20]], %[[VAL_25:.*]]] : [<>, <>] to <>, +// CHECK: %[[VAL_26:.*]] = mux %[[VAL_24]] {{\[}}%[[VAL_18]], %[[VAL_27:.*]]] : , [, ] to +// CHECK: %[[VAL_28:.*]] = constant %[[VAL_22]] {value = 0 : i5} : <>, +// CHECK: %[[VAL_29:.*]] = mux %[[VAL_24]] {{\[}}%[[VAL_28]], %[[VAL_30:.*]]] : , [, ] to +// CHECK: %[[VAL_31:.*]] = extui %[[VAL_29]] : to +// CHECK: %[[VAL_32:.*]] = constant %[[VAL_22]] {value = 1 : i6} : <>, +// CHECK: %[[VAL_33:.*]] = addi %[[VAL_31]], %[[VAL_32]] : +// CHECK: %[[VAL_34:.*]] = trunci %[[VAL_33]] : to +// CHECK: %[[VAL_35:.*]] = constant %[[VAL_22]] {value = -32 : i6} : <>, +// CHECK: %[[VAL_36:.*]] = cmpi ult, %[[VAL_33]], %[[VAL_35]] : +// CHECK: %[[VAL_30]], %[[VAL_37:.*]] = cond_br %[[VAL_36]], %[[VAL_34]] : , +// CHECK: %[[VAL_38:.*]], %[[VAL_7]] = cond_br %[[VAL_36]], %[[VAL_26]] : , +// CHECK: %[[VAL_39:.*]], %[[VAL_5]] = cond_br %[[VAL_36]], %[[VAL_23]] : , <> +// CHECK: %[[VAL_40:.*]] = source : <> +// CHECK: %[[VAL_41:.*]] = constant %[[VAL_40]] {value = 10 : i32} : <>, +// CHECK: %[[VAL_42:.*]], %[[VAL_43:.*]] = control_merge {{\[}}%[[VAL_39]]] : [<>] to <>, +// CHECK: %[[VAL_44:.*]] = merge %[[VAL_38]] : +// CHECK: %[[VAL_45:.*]] = addi %[[VAL_44]], %[[VAL_41]] : +// CHECK: %[[VAL_27]] = br %[[VAL_45]] : +// CHECK: %[[VAL_25]] = br %[[VAL_42]] : <> +// CHECK: %[[VAL_46:.*]] = merge %[[VAL_19]] : +// CHECK: end %[[VAL_46]] : // CHECK: } handshake.func @nestedLoop(%start: !handshake.control<>) -> !handshake.channel { // ^^entry outer loop: diff --git a/test/Transforms/HandshakeOptimizeBitwidths/gh764.mlir b/test/Transforms/HandshakeOptimizeBitwidths/gh764.mlir index 431b75be7f..540706e85b 100644 --- a/test/Transforms/HandshakeOptimizeBitwidths/gh764.mlir +++ b/test/Transforms/HandshakeOptimizeBitwidths/gh764.mlir @@ -36,13 +36,13 @@ handshake.func @test_and_sext_zext(%arg0: !handshake.channel, %arg1: !handsh // ----- // CHECK-LABEL: handshake.func @test_and_sext_sext( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, -// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> (!handshake.channel, !handshake.control<>) attributes {argNames = ["var2", "start"], resNames = ["out0", "end"]} { -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : to -// CHECK: %[[VAL_3:.*]] = source {handshake.bb = 0 : ui32} : <> -// CHECK: %[[VAL_4:.*]] = constant %[[VAL_3]] {handshake.bb = 0 : ui32, value = 71 : i8} : <>, -// CHECK: %[[VAL_5:.*]] = andi %[[VAL_2]], %[[VAL_4]] {handshake.bb = 0 : ui32} : -// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : to +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> (!handshake.channel, !handshake.control<>) attributes {argNames = ["var2", "start"], resNames = ["out0", "end"]} { +// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : to +// CHECK: %[[VAL_3:.*]] = constant %[[VAL_4:.*]] {handshake.bb = 0 : ui32, value = -57 : i7} : <>, +// CHECK: %[[VAL_4]] = source {handshake.bb = 0 : ui32} : <> +// CHECK: %[[VAL_5:.*]] = andi %[[VAL_2]], %[[VAL_3]] {handshake.bb = 0 : ui32} : +// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : to // CHECK: end {handshake.bb = 0 : ui32} %[[VAL_6]], %[[VAL_1]] : , <> // CHECK: } handshake.func @test_and_sext_sext(%arg0: !handshake.channel, %arg2: !handshake.control<>, ...) -> (!handshake.channel, !handshake.control<>) attributes {argNames = ["var2", "start"], resNames = ["out0", "end"]} { diff --git a/test/Transforms/HandshakeOptimizeBitwidths/handshake-special.mlir b/test/Transforms/HandshakeOptimizeBitwidths/handshake-special.mlir index 7208c28cdb..a580a4ea47 100644 --- a/test/Transforms/HandshakeOptimizeBitwidths/handshake-special.mlir +++ b/test/Transforms/HandshakeOptimizeBitwidths/handshake-special.mlir @@ -43,13 +43,13 @@ handshake.func @cmergeToMuxIndexOpt(%arg0: !handshake.channel, %arg1: !hand // CHECK: %[[VAL_3:.*]], %[[VAL_4:.*]] = mem_controller{{\[}}%[[VAL_0]] : memref<1000xi32>] %[[VAL_1]] (%[[VAL_5:.*]], %[[VAL_6:.*]], %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]], %[[VAL_10:.*]]) %[[VAL_2]] {connectedBlocks = [0 : i32]} : (!handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel) -> !handshake.channel // CHECK: %[[VAL_11:.*]] = constant %[[VAL_2]] {value = 0 : i10} : <>, // CHECK: %[[VAL_12:.*]] = constant %[[VAL_2]] {value = 500 : i10} : <>, -// CHECK: %[[VAL_13:.*]] = constant %[[VAL_2]] {value = -25 : i10} : <>, -// CHECK: %[[VAL_14:.*]] = constant %[[VAL_2]] {value = 42 : i32} : <>, +// CHECK: %[[VAL_13:.*]] = constant %[[VAL_2]] {value = 42 : i32} : <>, // CHECK: %[[VAL_5]] = constant %[[VAL_2]] {handshake.bb = 0 : ui32, value = 2 : i32} : <>, -// CHECK: %[[VAL_6]], %[[VAL_15:.*]] = load{{\[}}%[[VAL_11]]] %[[VAL_3]] {handshake.bb = 0 : ui32} : , , , -// CHECK: %[[VAL_7]], %[[VAL_8]] = store{{\[}}%[[VAL_12]]] %[[VAL_14]] {handshake.bb = 0 : ui32} : , , , -// CHECK: %[[VAL_9]], %[[VAL_10]] = store{{\[}}%[[VAL_13]]] %[[VAL_14]] {handshake.bb = 0 : ui32} : , , , -// CHECK: end %[[VAL_15]], %[[VAL_4]] : , <> +// CHECK: %[[VAL_6]], %[[VAL_14:.*]] = load{{\[}}%[[VAL_11]]] %[[VAL_3]] {handshake.bb = 0 : ui32} : , , , +// CHECK: %[[VAL_7]], %[[VAL_8]] = store{{\[}}%[[VAL_12]]] %[[VAL_13]] {handshake.bb = 0 : ui32} : , , , +// CHECK: %[[VAL_15:.*]] = constant %[[VAL_2]] {value = -25 : i10} : <>, +// CHECK: %[[VAL_9]], %[[VAL_10]] = store{{\[}}%[[VAL_15]]] %[[VAL_13]] {handshake.bb = 0 : ui32} : , , , +// CHECK: end %[[VAL_14]], %[[VAL_4]] : , <> // CHECK: } handshake.func @memAddrOpt(%mem: memref<1000xi32>, %mem_start: !handshake.control<>, %start: !handshake.control<>) -> (!handshake.channel, !handshake.control<>) { %ldData1, %done = mem_controller[%mem : memref<1000xi32>] %mem_start (%ctrl1, %ldAddr1, %stAddr1, %stData1, %stAddr2, %stData2) %start {connectedBlocks = [0 : i32]} : (!handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel) -> !handshake.channel @@ -76,13 +76,13 @@ handshake.func @memAddrOpt(%mem: memref<1000xi32>, %mem_start: !handshake.contro // CHECK: %[[VAL_8]]:4 = lsq[MC] (%[[VAL_2]], %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]], %[[VAL_3]]) {groupSizes = [2 : i32]} : (!handshake.control<>, !handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel) -> (!handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel) // CHECK: %[[VAL_12:.*]] = constant %[[VAL_2]] {value = 0 : i10} : <>, // CHECK: %[[VAL_13:.*]] = constant %[[VAL_2]] {value = 500 : i10} : <>, -// CHECK: %[[VAL_14:.*]] = constant %[[VAL_2]] {value = -25 : i10} : <>, -// CHECK: %[[VAL_15:.*]] = constant %[[VAL_2]] {value = 42 : i32} : <>, +// CHECK: %[[VAL_14:.*]] = constant %[[VAL_2]] {value = 42 : i32} : <>, // CHECK: %[[VAL_5]] = constant %[[VAL_2]] {handshake.bb = 0 : ui32, value = 2 : i32} : <>, -// CHECK: %[[VAL_9]], %[[VAL_16:.*]] = load{{\[}}%[[VAL_12]]] %[[VAL_8]]#0 {handshake.bb = 0 : ui32} : , , , -// CHECK: %[[VAL_10]], %[[VAL_11]] = store{{\[}}%[[VAL_13]]] %[[VAL_15]] {handshake.bb = 0 : ui32} : , , , -// CHECK: %[[VAL_6]], %[[VAL_7]] = store{{\[}}%[[VAL_14]]] %[[VAL_15]] {handshake.bb = 0 : ui32} : , , , -// CHECK: end %[[VAL_16]], %[[VAL_4]] : , <> +// CHECK: %[[VAL_9]], %[[VAL_15:.*]] = load{{\[}}%[[VAL_12]]] %[[VAL_8]]#0 {handshake.bb = 0 : ui32} : , , , +// CHECK: %[[VAL_10]], %[[VAL_11]] = store{{\[}}%[[VAL_13]]] %[[VAL_14]] {handshake.bb = 0 : ui32} : , , , +// CHECK: %[[VAL_16:.*]] = constant %[[VAL_2]] {value = -25 : i10} : <>, +// CHECK: %[[VAL_6]], %[[VAL_7]] = store{{\[}}%[[VAL_16]]] %[[VAL_14]] {handshake.bb = 0 : ui32} : , , , +// CHECK: end %[[VAL_15]], %[[VAL_4]] : , <> // CHECK: } handshake.func @memAddrOptMasterSlave(%mem: memref<1000xi32>, %mem_start: !handshake.control<>, %start: !handshake.control<>) -> (!handshake.channel, !handshake.control<>) { %ldDataToLSQ, %done = mem_controller[%mem : memref<1000xi32>] %mem_start (%ctrl1, %stAddr2, %stData2, %ldAddrToMC, %stAddrToMC, %stdDataToMC) %start {connectedBlocks = [0 : i32]} : (!handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel, !handshake.channel) -> !handshake.channel diff --git a/test/Transforms/minimize-cst-width.mlir b/test/Transforms/minimize-cst-width.mlir deleted file mode 100644 index f6838035a7..0000000000 --- a/test/Transforms/minimize-cst-width.mlir +++ /dev/null @@ -1,222 +0,0 @@ -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// RUN: dynamatic-opt --handshake-minimize-cst-width="opt-negatives" --remove-operation-names %s --split-input-file | FileCheck %s - -// CHECK-LABEL: handshake.func @doNothing( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 31 : i6} : <>, -// CHECK: end %[[VAL_1]] : -// CHECK: } -handshake.func @doNothing(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = 31 : i6} : <>, - end %cst : -} - -// ----- - -// CHECK-LABEL: handshake.func @zeroCst( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = false} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: end %[[VAL_2]] : -// CHECK: } -handshake.func @zeroCst(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = 0 : i32} : <>, - end %cst : -} - -// ----- - -// CHECK-LABEL: handshake.func @oneCst( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 1 : i2} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: end %[[VAL_2]] : -// CHECK: } -handshake.func @oneCst(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = 1 : i32} : <>, - end %cst : -} - -// ----- - -// CHECK-LABEL: handshake.func @powerOfTwoMinusOne( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 31 : i6} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: end %[[VAL_2]] : -// CHECK: } -handshake.func @powerOfTwoMinusOne(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = 31 : i32} : <>, - end %cst : -} - -// ----- - -// CHECK-LABEL: handshake.func @powerOfTwo( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 32 : i7} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: end %[[VAL_2]] : -// CHECK: } -handshake.func @powerOfTwo(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = 32 : i32} : <>, - end %cst : -} - -// ----- - -// CHECK-LABEL: handshake.func @maxPosVal( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 9223372036854775807 : i64} : <>, -// CHECK: end %[[VAL_1]] : -// CHECK: } -handshake.func @maxPosVal(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = 9223372036854775807 : i64} : <>, - end %cst : -} - -// ----- - -// CHECK-LABEL: handshake.func @negPowerOfMinusOne( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = -33 : i7} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: end %[[VAL_2]] : -// CHECK: } -handshake.func @negPowerOfMinusOne(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = -33 : i32} : <>, - end %cst : -} - -// ----- - -// CHECK-LABEL: handshake.func @negPowerOfTwo( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = -32 : i6} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: end %[[VAL_2]] : -// CHECK: } -handshake.func @negPowerOfTwo(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = -32 : i32} : <>, - end %cst : -} - -// ----- - -// CHECK-LABEL: handshake.func @minNegVal( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = -9223372036854775808 : i64} : <>, -// CHECK: end %[[VAL_1]] : -// CHECK: } -handshake.func @minNegVal(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = -9223372036854775808 : i64} : <>, - end %cst : -} - - -// ----- - -// CHECK-LABEL: handshake.func @multipleUsers( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 32 : i7} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: %[[VAL_3:.*]] = addi %[[VAL_2]], %[[VAL_2]] : -// CHECK: %[[VAL_4:.*]] = addi %[[VAL_3]], %[[VAL_2]] : -// CHECK: end %[[VAL_4]] : -// CHECK: } -handshake.func @multipleUsers(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = 32 : i32} : <>, - %add = addi %cst, %cst : - %add2 = addi %add, %cst : - end %add2 : -} - -// ----- - -// CHECK-LABEL: handshake.func @inheritBB( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 32 : i7} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: end %[[VAL_2]] : -// CHECK: } -handshake.func @inheritBB(%start: !handshake.control<>) -> !handshake.channel { - %cst = constant %start {value = 32 : i32, handshake.bb = 0 : ui32} : <>, - end %cst : -} - -// ----- - -// CHECK-LABEL: handshake.func @duplicateDoNothingDiff( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = merge %[[VAL_0]] : <> -// CHECK: %[[VAL_2:.*]] = constant %[[VAL_0]] {value = 3 : i3} : <>, -// CHECK: %[[VAL_3:.*]] = extsi %[[VAL_2]] : to -// CHECK: %[[VAL_4:.*]] = constant %[[VAL_1]] {value = 3 : i3} : <>, -// CHECK: %[[VAL_5:.*]] = extsi %[[VAL_4]] : to -// CHECK: %[[VAL_6:.*]] = constant %[[VAL_0]] {value = 2 : i3} : <>, -// CHECK: %[[VAL_7:.*]] = extsi %[[VAL_6]] : to -// CHECK: %[[VAL_8:.*]] = addi %[[VAL_3]], %[[VAL_5]] : -// CHECK: %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_7]] : -// CHECK: end %[[VAL_9]] : -// CHECK: } -handshake.func @duplicateDoNothingDiff(%start: !handshake.control<>) -> !handshake.channel { - %mergeStart = merge %start : <> - %cst1 = constant %start {value = 3 : i32} : <>, - %cst2 = constant %mergeStart {value = 3 : i32} : <>, - %cst3 = constant %start {value = 2 : i32} : <>, - %add1 = addi %cst1, %cst2 : - %add2 = addi %add1, %cst3 : - end %add2 : -} - -// ----- - -// CHECK-LABEL: handshake.func @duplicateDoNothingPrevious( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 32 : i7} : <>, -// CHECK: %[[VAL_2:.*]] = constant %[[VAL_0]] {value = 32 : i7} : <>, -// CHECK: %[[VAL_3:.*]] = addi %[[VAL_1]], %[[VAL_2]] : -// CHECK: end %[[VAL_3]] : -// CHECK: } -handshake.func @duplicateDoNothingPrevious(%start: !handshake.control<>) -> !handshake.channel { - %cst1 = constant %start {value = 32 : i7} : <>, - %cst2 = constant %start {value = 32 : i7} : <>, - %add = addi %cst1, %cst2 : - end %add : -} - -// ----- - -// CHECK-LABEL: handshake.func @deleteDuplicate( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 32 : i7} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: %[[VAL_3:.*]] = extsi %[[VAL_1]] : to -// CHECK: %[[VAL_4:.*]] = addi %[[VAL_3]], %[[VAL_2]] : -// CHECK: end %[[VAL_4]] : -// CHECK: } -handshake.func @deleteDuplicate(%start: !handshake.control<>) -> !handshake.channel { - %cst1 = constant %start {value = 32 : i32} : <>, - %cst2 = constant %start {value = 32 : i32} : <>, - %add = addi %cst1, %cst2 : - end %add : -} - - -// ----- - -// CHECK-LABEL: handshake.func @deleteDuplicateMatchExists( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["start"], resNames = ["out0"]} { -// CHECK: %[[VAL_1:.*]] = constant %[[VAL_0]] {value = 32 : i7} : <>, -// CHECK: %[[VAL_2:.*]] = extsi %[[VAL_1]] : to -// CHECK: %[[VAL_3:.*]] = extsi %[[VAL_1]] : to -// CHECK: %[[VAL_4:.*]] = addi %[[VAL_3]], %[[VAL_2]] : -// CHECK: end %[[VAL_4]] : -// CHECK: } -handshake.func @deleteDuplicateMatchExists(%start: !handshake.control<>) -> !handshake.channel { - %cst1 = constant %start {value = 32 : i7} : <>, - %cst2 = constant %start {value = 32 : i32} : <>, - %cst1ext = extsi %cst1 : to - %add = addi %cst1ext, %cst2 : - end %add : -} diff --git a/tools/dynamatic/scripts/compile.sh b/tools/dynamatic/scripts/compile.sh index 52e8520579..99075b354a 100755 --- a/tools/dynamatic/scripts/compile.sh +++ b/tools/dynamatic/scripts/compile.sh @@ -267,7 +267,7 @@ if [[ $STRAIGHT_TO_QUEUE -ne 0 ]]; then # handshake transformations "$DYNAMATIC_OPT_BIN" "$F_HANDSHAKE" \ --handshake-remove-unused-memrefs \ - --handshake-minimize-cst-width --handshake-optimize-bitwidths \ + --handshake-optimize-bitwidths \ --handshake-materialize="replicate-constant=true" --handshake-infer-basic-blocks \ > "$F_HANDSHAKE_TRANSFORMED" exit_on_fail "Failed to apply transformations to handshake" \ @@ -279,7 +279,7 @@ else "$DYNAMATIC_OPT_BIN" "$F_HANDSHAKE" \ --handshake-analyze-lsq-usage --handshake-replace-memory-interfaces \ --handshake-remove-unused-memrefs \ - --handshake-minimize-cst-width --handshake-optimize-bitwidths \ + --handshake-optimize-bitwidths \ --handshake-materialize --handshake-infer-basic-blocks \ > "$F_HANDSHAKE_TRANSFORMED" exit_on_fail "Failed to apply transformations to handshake" \ diff --git a/tools/integration/util.cpp b/tools/integration/util.cpp index 6dfbc4ce4e..07943e52fd 100644 --- a/tools/integration/util.cpp +++ b/tools/integration/util.cpp @@ -170,13 +170,13 @@ bool runSpecIntegrationTest(const std::string &name, int &outSimTime) { } fs::path handshakeTransformed = compOutDir / "handshakeTransformed.mlir"; - if (!runSubprocess( - {DYNAMATIC_OPT_BIN, handshake.string(), - "--handshake-analyze-lsq-usage", - "--handshake-replace-memory-interfaces", - "--handshake-minimize-cst-width", "--handshake-optimize-bitwidths", - "--handshake-materialize", "--handshake-infer-basic-blocks"}, - handshakeTransformed)) { + if (!runSubprocess({DYNAMATIC_OPT_BIN, handshake.string(), + "--handshake-analyze-lsq-usage", + "--handshake-replace-memory-interfaces", + "--handshake-optimize-bitwidths", + "--handshake-materialize", + "--handshake-infer-basic-blocks"}, + handshakeTransformed)) { std::cerr << "Failed to apply transformations to handshake\n"; return false; }