@@ -53,20 +53,16 @@ namespace dynamatic {
5353// [END Boilerplate code for the MLIR pass]
5454
5555namespace {
56- // / Extension type. When backtracing through extension operations, serves to
56+ // / Extension type. When backtracking through extension operations, serves to
5757// / remember the type of any extension we may have encountered along the
58- // / way.getMinimalValue Then, when modifying a value's bitwidth, serves to guide
58+ // / way. Then, when modifying a value's bitwidth, serves to guide
5959// / the determination of which extension operation to use.
60- // / - UNKNOWN when no extension has been encountered / when a value's signedness
61- // / should determine its extension type.
62- // / - LOGICAL when only logical extensions have been encountered / when a value
63- // / should be logically extended.
64- // / - ARITHMETIC when only arithmaric extensions have been encountered / when a
65- // / value should be arithmetically extended.
66- // / - CONFLICT when both logical and arithmetic extensions have been encountered
67- // / when it's not possible to accurately determine what type of extension to
68- // / use for a value.
69- enum class ExtType { UNKNOWN, LOGICAL, ARITHMETIC, CONFLICT };
60+ // / - NONE when no extension has been encountered.
61+ // / - ZEXT when a zero extension has been encountered / when a value
62+ // / should be zero extended.
63+ // / - SEXT when a sign extension has been encountered / when a
64+ // / value should be sign extended.
65+ enum class ExtType { NONE, ZEXT, SEXT };
7066
7167// / A channel-typed value.
7268using ChannelVal = TypedValue<handshake::ChannelType>;
@@ -106,35 +102,23 @@ static ChannelVal getMinimalValue(ChannelVal val, ExtType *ext = nullptr) {
106102 if (!asTypedIfLegal (val))
107103 return val;
108104
109- // Only backtrack through values that were produced by extension operations
110- while (Operation *defOp = val.getDefiningOp ()) {
111- if (!isa<handshake::ExtSIOp, handshake::ExtUIOp>(defOp))
112- return val;
105+ if (auto op = val.getDefiningOp <handshake::ExtSIOp>()) {
106+ if (ext)
107+ *ext = ExtType::SEXT;
113108
114- // Update the extension type using the nature of the current extension
115- // operation and the current type
116- if (ext) {
117- switch (*ext) {
118- case ExtType::UNKNOWN:
119- *ext = isa<handshake::ExtSIOp>(defOp) ? ExtType::ARITHMETIC
120- : ExtType::LOGICAL;
121- break ;
122- case ExtType::LOGICAL:
123- if (isa<handshake::ExtSIOp>(defOp))
124- *ext = ExtType::CONFLICT;
125- break ;
126- case ExtType::ARITHMETIC:
127- if (isa<handshake::ExtUIOp>(defOp))
128- *ext = ExtType::CONFLICT;
129- break ;
130- default :
131- break ;
132- }
133- }
134- // Backtrack through the extension operation
135- val = cast<ChannelVal>(defOp->getOperand (0 ));
109+ return op.getIn ();
110+ }
111+
112+ if (auto op = val.getDefiningOp <handshake::ExtUIOp>()) {
113+ if (ext)
114+ *ext = ExtType::ZEXT;
115+
116+ return op.getIn ();
136117 }
137118
119+ if (ext)
120+ *ext = ExtType::NONE;
121+
138122 return val;
139123}
140124
@@ -215,23 +199,12 @@ static ChannelVal modBitWidth(ExtValue extVal, unsigned targetWidth,
215199 Type dstChannelType = val.getType ().withDataType (newDataType);
216200 rewriter.setInsertionPointAfterValue (val);
217201 if (width < targetWidth) {
218- if (ext == ExtType::CONFLICT) {
219- // If the extension type is conflicting, just emit a warning and hope for
220- // the best
221- Operation *defOp = val.getDefiningOp ();
222- std::string origin;
223- if (defOp)
224- origin = " operation result" ;
225- else {
226- defOp = val.getParentBlock ()->getParentOp ();
227- origin = " function argument" ;
228- }
229- defOp->emitWarning ()
230- << " Conflicting extension type given for " << origin
231- << " , optimization result may change circuit semantics." ;
232- }
233- if (ext == ExtType::LOGICAL ||
234- (ext == ExtType::UNKNOWN &&
202+ // TODO: The code here is wrong for NONE as our IR no-longer carries
203+ // signedness information in the integer type.
204+ // All code should be migrated to pass LOGICAL or ARITHMETIC if it performs
205+ // bitwidth changes.
206+ if (ext == ExtType::ZEXT ||
207+ (ext == ExtType::NONE &&
235208 val.getType ().getDataType ().isUnsignedInteger ())) {
236209 newOp = rewriter.create <handshake::ExtUIOp>(loc, dstChannelType, val);
237210 } else {
@@ -382,21 +355,21 @@ static void canonicalizeCommutativeExtensionType(ExtWidth &lhs, ExtWidth &rhs) {
382355// / Transfer function for add/sub operations or alike.
383356static ExtWidth addWidth (ExtWidth lhs, ExtWidth rhs) {
384357 canonicalizeCommutativeExtensionType (lhs, rhs);
385- if (rhs.extType <= ExtType::LOGICAL )
386- return {ExtType::LOGICAL , std::max (lhs.bitWidth , rhs.bitWidth ) + 1 };
358+ if (rhs.extType <= ExtType::ZEXT )
359+ return {ExtType::ZEXT , std::max (lhs.bitWidth , rhs.bitWidth ) + 1 };
387360
388- return {ExtType::ARITHMETIC , std::max (lhs.bitWidth , rhs.bitWidth ) + 1 };
361+ return {ExtType::SEXT , std::max (lhs.bitWidth , rhs.bitWidth ) + 1 };
389362}
390363
391364// / Transfer function for mul operations or alike.
392365static ExtWidth mulWidth (ExtWidth lhs, ExtWidth rhs) {
393- return {ExtType::UNKNOWN , lhs.bitWidth + rhs.bitWidth };
366+ return {ExtType::NONE , lhs.bitWidth + rhs.bitWidth };
394367}
395368
396369// / Transfer function for div/rem operations or alike.
397370template <bool zeroExtend>
398371static ExtWidth divWidth (ExtWidth lhs, ExtWidth _) {
399- return {zeroExtend ? ExtType::LOGICAL : ExtType::UNKNOWN , lhs.bitWidth + 1 };
372+ return {zeroExtend ? ExtType::ZEXT : ExtType::NONE , lhs.bitWidth + 1 };
400373}
401374
402375// / Transfer function for and operations or alike.
@@ -411,8 +384,8 @@ static ExtWidth andWidth(ExtWidth lhs, ExtWidth rhs) {
411384 // From our example:
412385 // Extending 'a' to 00001 and 'b' to 00101 yields the same result as if ANDing
413386 // "a = 01, b = 01" and zero-extending the result.
414- if (rhs.extType <= ExtType::LOGICAL )
415- return {ExtType::LOGICAL , std::min (lhs.bitWidth , rhs.bitWidth )};
387+ if (rhs.extType <= ExtType::ZEXT )
388+ return {ExtType::ZEXT , std::min (lhs.bitWidth , rhs.bitWidth )};
416389
417390 // Sign-extension might fill with 1-bits, meaning all bits of the larger
418391 // operand are part of the effective result bitwidth.
@@ -424,17 +397,14 @@ static ExtWidth andWidth(ExtWidth lhs, ExtWidth rhs) {
424397 // For bits the bits inbetween |a| and |b|, sign-extension of the smaller
425398 // operand is still required as the corresponding result bits are dependent
426399 // on the sign of the smaller operand.
427- //
428- // TODO: CONFLICT might be able to be optimized better but needs further
429- // investigation. This case is conservatively correct.
430- return {ExtType::ARITHMETIC, std::max (lhs.bitWidth , rhs.bitWidth )};
400+ return {ExtType::SEXT, std::max (lhs.bitWidth , rhs.bitWidth )};
431401}
432402
433403// / Transfer function for or/xor operations or alike.
434404static ExtWidth orWidth (ExtWidth lhs, ExtWidth rhs) {
435405 canonicalizeCommutativeExtensionType (lhs, rhs);
436- if (rhs.extType <= ExtType::LOGICAL )
437- return {ExtType::LOGICAL , std::max (lhs.bitWidth , rhs.bitWidth )};
406+ if (rhs.extType <= ExtType::ZEXT )
407+ return {ExtType::ZEXT , std::max (lhs.bitWidth , rhs.bitWidth )};
438408 // rhs guaranteed to be at least arithmetic from here on.
439409
440410 // Since rhs was sign-extended the result to continue extending with 1s in the
@@ -447,10 +417,10 @@ static ExtWidth orWidth(ExtWidth lhs, ExtWidth rhs) {
447417 // with 3 bits would be wrong however, since sext(OR 101, sext(01) to i3)
448418 // would extend with 1s, merely due to the bitwidth reduction.
449419 // The extra bit prevents this behavior.
450- if (lhs.extType == ExtType::LOGICAL && lhs.bitWidth > rhs.bitWidth )
451- return {ExtType::ARITHMETIC , 1 + lhs.bitWidth };
420+ if (lhs.extType == ExtType::ZEXT && lhs.bitWidth > rhs.bitWidth )
421+ return {ExtType::SEXT , 1 + lhs.bitWidth };
452422
453- return {ExtType::ARITHMETIC , std::max (lhs.bitWidth , rhs.bitWidth )};
423+ return {ExtType::SEXT , std::max (lhs.bitWidth , rhs.bitWidth )};
454424}
455425
456426// ===----------------------------------------------------------------------===//
@@ -681,7 +651,7 @@ struct HandshakeOptData : public OpRewritePattern<Op> {
681651
682652 // Get the operation's data operands actual widths
683653 SmallVector<ChannelVal> minDataOperands;
684- ExtType ext = ExtType::UNKNOWN ;
654+ ExtType ext = ExtType::NONE ;
685655 llvm::transform (dataOperands, std::back_inserter (minDataOperands),
686656 [&](Value val) {
687657 return getMinimalValue (cast<ChannelVal>(val), &ext);
@@ -755,7 +725,7 @@ struct HandshakeMuxSelect : public OpRewritePattern<handshake::MuxOp> {
755725 // Create a new mux whose select operand is optimized
756726 SmallVector<Value, 3 > newOperands;
757727 newOperands.push_back (
758- modBitWidth ({selectOperand, ExtType::LOGICAL }, optWidth, rewriter));
728+ modBitWidth ({selectOperand, ExtType::ZEXT }, optWidth, rewriter));
759729 auto dataOprds = muxOp.getDataOperands ();
760730 newOperands.append (dataOprds.begin (), dataOprds.end ());
761731 auto newMuxOp = rewriter.create <handshake::MuxOp>(
@@ -804,7 +774,7 @@ struct HandshakeCMergeIndex
804774 cmergeOp.getLoc (), newResultTypes, cmergeOp.getDataOperands (),
805775 cmergeOp->getAttrs ());
806776 namer.replaceOp (cmergeOp, newCmergeOp);
807- Value modIndex = modBitWidth ({newCmergeOp.getIndex (), ExtType::LOGICAL },
777+ Value modIndex = modBitWidth ({newCmergeOp.getIndex (), ExtType::ZEXT },
808778 indexWidth, rewriter);
809779 rewriter.replaceOp (cmergeOp, {newCmergeOp.getResult (), modIndex});
810780 return success ();
@@ -847,7 +817,7 @@ struct MemInterfaceAddrOpt
847817 // by inputIdx, and increment inputIdx before returning the optimized value
848818 auto getOptAddrInput = [&](unsigned inputIdx) {
849819 return modBitWidth ({getMinimalValue (cast<ChannelVal>(operands[inputIdx])),
850- ExtType::LOGICAL },
820+ ExtType::ZEXT },
851821 optWidth, rewriter);
852822 };
853823
@@ -908,7 +878,7 @@ struct MemInterfaceAddrOpt
908878 SmallVector<Value> replacementValues (newMemOp->getResults ());
909879 for (unsigned resIdx : addrResultIndices) {
910880 replacementValues[resIdx] = modBitWidth (
911- {cast<ChannelVal>(replacementValues[resIdx]), ExtType::LOGICAL },
881+ {cast<ChannelVal>(replacementValues[resIdx]), ExtType::ZEXT },
912882 ports.addrWidth , rewriter);
913883 }
914884 inheritBB (memOp, newMemOp);
@@ -944,9 +914,9 @@ struct MemPortAddrOpt
944914 return failure ();
945915
946916 // Derive new operands and result types with the narrrower address type
947- Value newAddr = modBitWidth (
948- {getMinimalValue (portOp.getAddressInput ()), ExtType::LOGICAL}, optWidth ,
949- rewriter);
917+ Value newAddr =
918+ modBitWidth ( {getMinimalValue (portOp.getAddressInput ()), ExtType::ZEXT} ,
919+ optWidth, rewriter);
950920 Value dataIn = portOp.getDataInput ();
951921 SmallVector<Value, 2 > newOperands{newAddr, dataIn};
952922 SmallVector<Type, 2 > newResultTypes{newAddr.getType (), dataIn.getType ()};
@@ -960,7 +930,7 @@ struct MemPortAddrOpt
960930 namer.replaceOp (portOp, newPortOp);
961931 inheritBB (portOp, newPortOp);
962932 Value newAddrRes = modBitWidth (
963- {newPortOp.getAddressOutput (), ExtType::LOGICAL }, addrWidth, rewriter);
933+ {newPortOp.getAddressOutput (), ExtType::ZEXT }, addrWidth, rewriter);
964934 rewriter.replaceOp (portOp, {newAddrRes, newPortOp.getDataOutput ()});
965935 return success ();
966936 }
@@ -1025,7 +995,7 @@ struct ForwardCycleOpt : public OpRewritePattern<Op> {
1025995
1026996 // Determine the achievable optimized width for operands inside the cycle
1027997 unsigned optWidth = 0 ;
1028- ExtType ext = ExtType::UNKNOWN ;
998+ ExtType ext = ExtType::NONE ;
1029999 for (ChannelVal mergedVal : allMergedValues) {
10301000 optWidth = std::max (
10311001 optWidth,
@@ -1113,7 +1083,7 @@ struct ArithSingleType : public OpRewritePattern<Op> {
11131083 return failure ();
11141084
11151085 // Check whether we can reduce the bitwidth of the operation
1116- ExtType extLhs = ExtType::UNKNOWN , extRhs = ExtType::UNKNOWN ;
1086+ ExtType extLhs = ExtType::NONE , extRhs = ExtType::NONE ;
11171087 ChannelVal minLhs = getMinimalValue (op.getLhs (), &extLhs);
11181088 ChannelVal minRhs = getMinimalValue (op.getRhs (), &extRhs);
11191089 ExtWidth optWidth;
@@ -1124,7 +1094,7 @@ struct ArithSingleType : public OpRewritePattern<Op> {
11241094 // It does not matter whether we use sign- or zero-extension in this case
11251095 // since the bits added by the extension are unused by definition.
11261096 // We use zero-extension as it is cheaper and easier to optimize.
1127- optWidth = {ExtType::LOGICAL , getUsefulResultWidth (op.getResult ())};
1097+ optWidth = {ExtType::ZEXT , getUsefulResultWidth (op.getResult ())};
11281098 }
11291099 unsigned resWidth = channelVal.getType ().getDataBitWidth ();
11301100 if (optWidth.bitWidth >= resWidth)
@@ -1164,7 +1134,7 @@ struct ArithSelect : public OpRewritePattern<handshake::SelectOp> {
11641134 return failure ();
11651135
11661136 // Check whether we can reduce the bitwidth of the operation
1167- ExtType extLhs = ExtType::UNKNOWN , extRhs = ExtType::UNKNOWN ;
1137+ ExtType extLhs = ExtType::NONE , extRhs = ExtType::NONE ;
11681138 ChannelVal minLhs = getMinimalValue (selectOp.getTrueValue (), &extLhs);
11691139 ChannelVal minRhs = getMinimalValue (selectOp.getFalseValue (), &extRhs);
11701140 unsigned optWidth;
@@ -1179,8 +1149,8 @@ struct ArithSelect : public OpRewritePattern<handshake::SelectOp> {
11791149
11801150 // Different operand extension types mean that we don't know how to extend
11811151 // the operation's result, so it cannot be optimized
1182- if ((extLhs == ExtType::LOGICAL && extRhs == ExtType::ARITHMETIC ) ||
1183- (extLhs == ExtType::ARITHMETIC && extRhs == ExtType::LOGICAL ))
1152+ if ((extLhs == ExtType::ZEXT && extRhs == ExtType::SEXT ) ||
1153+ (extLhs == ExtType::SEXT && extRhs == ExtType::ZEXT ))
11841154 return failure ();
11851155
11861156 // Create a new operation as well as appropriate bitwidth modification
@@ -1225,7 +1195,7 @@ struct ArithShift : public OpRewritePattern<Op> {
12251195 PatternRewriter &rewriter) const override {
12261196 ChannelVal toShift = op.getLhs ();
12271197 ChannelVal shiftBy = op.getRhs ();
1228- ExtType extToShift = ExtType::UNKNOWN ;
1198+ ExtType extToShift = ExtType::NONE ;
12291199 ChannelVal minToShift = getMinimalValue (toShift, &extToShift);
12301200 ChannelVal minShiftBy = backtrackToMinimalValue (shiftBy);
12311201 bool isRightShift =
@@ -1258,7 +1228,7 @@ struct ArithShift : public OpRewritePattern<Op> {
12581228 Value newToShift =
12591229 modBitWidth ({minToShift, extToShift}, optWidth, rewriter);
12601230 Value newShifyBy =
1261- modBitWidth ({minShiftBy, ExtType::LOGICAL }, optWidth, rewriter);
1231+ modBitWidth ({minShiftBy, ExtType::ZEXT }, optWidth, rewriter);
12621232 rewriter.setInsertionPoint (op);
12631233 auto newOp = rewriter.create <Op>(op.getLoc (), newToShift.getType (),
12641234 newToShift, newShifyBy);
@@ -1285,7 +1255,7 @@ struct ArithShift : public OpRewritePattern<Op> {
12851255 modToShift = modBitWidth ({minToShift, extToShift}, requiredToShiftWidth,
12861256 rewriter);
12871257 }
1288- modArithOp (op, {modToShift, extToShift}, {minShiftBy, ExtType::LOGICAL },
1258+ modArithOp (op, {modToShift, extToShift}, {minShiftBy, ExtType::ZEXT },
12891259 optWidth, extToShift, rewriter, namer);
12901260 }
12911261 return success ();
@@ -1311,7 +1281,7 @@ struct ArithCmpFW : public OpRewritePattern<handshake::CmpIOp> {
13111281 LogicalResult matchAndRewrite (handshake::CmpIOp cmpOp,
13121282 PatternRewriter &rewriter) const override {
13131283 // Check whether we can reduce the bitwidth of the operation
1314- ExtType extLhs = ExtType::UNKNOWN , extRhs = ExtType::UNKNOWN ;
1284+ ExtType extLhs = ExtType::NONE , extRhs = ExtType::NONE ;
13151285 ChannelVal minLhs = getMinimalValue (cmpOp.getLhs (), &extLhs);
13161286 ChannelVal minRhs = getMinimalValue (cmpOp.getRhs (), &extRhs);
13171287 unsigned optWidth = std::max (minLhs.getType ().getDataBitWidth (),
@@ -1352,9 +1322,9 @@ struct ArithExtToTruncOpt : public OpRewritePattern<handshake::TruncIOp> {
13521322 LogicalResult matchAndRewrite (handshake::TruncIOp truncOp,
13531323 PatternRewriter &rewriter) const override {
13541324 // Operand must be produced by an extension operation
1355- ExtType extType = ExtType::UNKNOWN ;
1325+ ExtType extType = ExtType::NONE ;
13561326 ChannelVal minVal = getMinimalValue (truncOp.getIn (), &extType);
1357- if (extType == ExtType::UNKNOWN || extType == ExtType::CONFLICT )
1327+ if (extType == ExtType::NONE )
13581328 return failure ();
13591329
13601330 unsigned finalWidth = truncOp.getResult ().getType ().getDataBitWidth ();
@@ -1403,7 +1373,7 @@ struct ArithBoundOpt : public OpRewritePattern<handshake::ConditionalBranchOp> {
14031373 falseRes = cast<ChannelVal>(condOp.getFalseResult ());
14041374 std::optional<std::pair<unsigned , ExtType>> trueBranch, falseBranch;
14051375 for (handshake::CmpIOp cmpOp : getCmpOps (condOp.getConditionOperand ())) {
1406- ExtType extLhs = ExtType::UNKNOWN , extRhs = ExtType::UNKNOWN ;
1376+ ExtType extLhs = ExtType::NONE , extRhs = ExtType::NONE ;
14071377 ChannelVal minLhs = backtrackToMinimalValue (cmpOp.getLhs (), &extLhs);
14081378 ChannelVal minRhs = backtrackToMinimalValue (cmpOp.getRhs (), &extRhs);
14091379
@@ -1704,6 +1674,8 @@ void HandshakeOptimizeBitwidthsPass::addArithPatterns(
17041674 ArithSelect>(forward, ctx, getAnalysis<NameAnalysis>());
17051675
17061676 patterns.add <ArithExtToTruncOpt>(ctx, getAnalysis<NameAnalysis>());
1677+ handshake::ExtSIOp::getCanonicalizationPatterns (patterns, ctx);
1678+ handshake::ExtUIOp::getCanonicalizationPatterns (patterns, ctx);
17071679}
17081680
17091681void HandshakeOptimizeBitwidthsPass::addHandshakeDataPatterns (
0 commit comments