@@ -1366,19 +1366,88 @@ struct ArithShrUIFW : OpRewritePattern<handshake::ShRUIOp> {
13661366 NameAnalysis &namer;
13671367};
13681368
1369+ // / Optimizes signed right-shifts with a constant as a forward pass.
1370+ struct ArithShrSIFW : OpRewritePattern<handshake::ShRSIOp> {
1371+ ArithShrSIFW (Pass::Statistic &bitwidthReduced, MLIRContext *ctx,
1372+ NameAnalysis &namer)
1373+ : OpRewritePattern(ctx), bitwidthReduced(bitwidthReduced), namer(namer) {}
1374+
1375+ LogicalResult matchAndRewrite (handshake::ShRSIOp op,
1376+ PatternRewriter &rewriter) const override {
1377+ auto [lhs, lhsExt] = getMinimalValueWithExtType (op.getLhs ());
1378+ unsigned inputBitwidth = lhs.getType ().getDataBitWidth ();
1379+ unsigned currentBitwidth = op.getType ().getDataBitWidth ();
1380+ if (inputBitwidth >= currentBitwidth)
1381+ return failure ();
1382+
1383+ assert (lhsExt != ExtType::NONE && " expected an extension" );
1384+ APInt value;
1385+ Value constantControl;
1386+ {
1387+ auto constantOp = op.getRhs ().getDefiningOp <handshake::ConstantOp>();
1388+ if (!constantOp)
1389+ return failure ();
1390+ value = cast<IntegerAttr>(constantOp.getValue ()).getValue ();
1391+ constantControl = constantOp.getCtrl ();
1392+ }
1393+
1394+ // Other pattern (such as canonicalization pattern) should fold this case
1395+ // to a useful constant instead
1396+ if (value.uge (currentBitwidth))
1397+ return failure ();
1398+
1399+ if (lhsExt == ExtType::ZEXT) {
1400+ // We use a generic canonicalization pattern that should fold this into
1401+ // an unsigned shift-right instead.
1402+ return failure ();
1403+ }
1404+
1405+ // SEXT case.
1406+ if (value.ult (inputBitwidth)) {
1407+ // c is less than the input bitwidth, meaning other bits from the input
1408+ // besides the sign-bit are preserved in the output.
1409+
1410+ modArithOp (op, {lhs, lhsExt}, {op.getRhs (), ExtType::NONE}, inputBitwidth,
1411+ ExtType::SEXT, rewriter, namer);
1412+ ++bitwidthReduced;
1413+ return success ();
1414+ }
1415+
1416+ // Our shift amount is larger than the input bitwidth but the input
1417+ // bitwidth is sign-extended. The only thing that remains from the input
1418+ // is the sign-bit.
1419+ Value inputBWM1 = rewriter.create <handshake::ConstantOp>(
1420+ op.getLoc (),
1421+ rewriter.getIntegerAttr (lhs.getType ().getDataType (), inputBitwidth - 1 ),
1422+ constantControl);
1423+ // Shift away all values of lhs other than the sign-bit.
1424+ ChannelVal signBit =
1425+ rewriter.create <handshake::ShRSIOp>(op.getLoc (), lhs, inputBWM1);
1426+ // Fill remaining sign-bit copies.
1427+ rewriter.replaceOpWithNewOp <handshake::ExtSIOp>(op, op.getType (), signBit);
1428+ ++bitwidthReduced;
1429+ return success ();
1430+ }
1431+
1432+ private:
1433+ Pass::Statistic &bitwidthReduced;
1434+ // / A reference to the pass's name analysis.
1435+ NameAnalysis &namer;
1436+ };
1437+
13691438// / Optimizes the bitwidth of shift-type operations. The first template
13701439// / parameter is meant to be either handshake::ShLIOp, handshake::ShRSIOp, or
13711440// / handshake::ShRUIOp. In both modes (forward and backward), the matched
13721441// / operation's bitwidth may only be reduced when the data operand is shifted by
13731442// / a known constant amount.
13741443template <typename Op>
1375- struct ArithShift : public OpRewritePattern <Op> {
1444+ struct ArithShiftBW : public OpRewritePattern <Op> {
13761445 using OpRewritePattern<Op>::OpRewritePattern;
13771446
1378- ArithShift (Pass::Statistic &bitwidthReduced, bool forward , MLIRContext *ctx,
1379- NameAnalysis &namer)
1447+ ArithShiftBW (Pass::Statistic &bitwidthReduced, MLIRContext *ctx,
1448+ NameAnalysis &namer)
13801449 : OpRewritePattern<Op>(ctx), bitwidthReduced(bitwidthReduced),
1381- namer (namer), forward(forward) {}
1450+ namer (namer) {}
13821451
13831452 LogicalResult matchAndRewrite (Op op,
13841453 PatternRewriter &rewriter) const override {
@@ -1396,56 +1465,25 @@ struct ArithShift : public OpRewritePattern<Op> {
13961465 if (Operation *defOp = minShiftBy.getDefiningOp ())
13971466 if (auto cstOp = dyn_cast<handshake::ConstantOp>(defOp)) {
13981467 cstVal = (unsigned )cast<IntegerAttr>(cstOp.getValue ()).getInt ();
1399- if (forward) {
1400- optWidth = minToShift.getType ().getDataBitWidth ();
1401- if (!isRightShift)
1402- optWidth += cstVal;
1403- } else {
1404- optWidth = getUsefulResultWidth (op.getResult ());
1405- if (isRightShift)
1406- optWidth += cstVal;
1407- }
1468+ optWidth = getUsefulResultWidth (op.getResult ());
1469+ if (isRightShift)
1470+ optWidth += cstVal;
14081471 }
14091472
14101473 if (optWidth >= resWidth)
14111474 return failure ();
14121475
1413- if (forward) {
1414- // Create a new operation as well as appropriate bitwidth modification
1415- // operations to keep the IR valid
1416- Value newToShift =
1417- modBitWidth ({minToShift, extToShift}, optWidth, rewriter);
1418- Value newShifyBy =
1419- modBitWidth ({minShiftBy, ExtType::ZEXT}, optWidth, rewriter);
1420- rewriter.setInsertionPoint (op);
1421- auto newOp = rewriter.create <Op>(op.getLoc (), newToShift.getType (),
1422- newToShift, newShifyBy);
1423- ChannelVal newRes = newOp.getResult ();
1424- if (isRightShift)
1425- // In the case of a right shift, we first truncate the result of the
1426- // newly inserted shift operation to discard high-significance bits that
1427- // we know are 0s, then extend the result back to satisfy the users of
1428- // the original operation's result
1429- newRes = modBitWidth ({newRes, extToShift}, optWidth - cstVal, rewriter);
1430- Value modRes = modBitWidth ({newRes, extToShift}, resWidth, rewriter);
1431- inheritBB (op, newOp);
1432-
1433- // Replace uses of the original operation's result with the result of the
1434- // optimized operation we just created
1435- rewriter.replaceOp (op, modRes);
1436- } else {
1437- ChannelVal modToShift = minToShift;
1438- if (!isRightShift) {
1439- // In the case of a left shift, we first truncate the shifted integer to
1440- // discard high-significance bits that were discarded in the result,
1441- // then extend back to satisfy the users of the original integer
1442- unsigned requiredToShiftWidth = optWidth - std::min (cstVal, optWidth);
1443- modToShift = modBitWidth ({minToShift, extToShift}, requiredToShiftWidth,
1444- rewriter);
1445- }
1446- modArithOp (op, {modToShift, extToShift}, {minShiftBy, ExtType::ZEXT},
1447- optWidth, extToShift, rewriter, namer);
1476+ ChannelVal modToShift = minToShift;
1477+ if (!isRightShift) {
1478+ // In the case of a left shift, we first truncate the shifted integer to
1479+ // discard high-significance bits that were discarded in the result,
1480+ // then extend back to satisfy the users of the original integer
1481+ unsigned requiredToShiftWidth = optWidth - std::min (cstVal, optWidth);
1482+ modToShift =
1483+ modBitWidth ({minToShift, extToShift}, requiredToShiftWidth, rewriter);
14481484 }
1485+ modArithOp (op, {modToShift, extToShift}, {minShiftBy, ExtType::ZEXT},
1486+ optWidth, extToShift, rewriter, namer);
14491487 ++bitwidthReduced;
14501488 return success ();
14511489 }
@@ -1872,19 +1910,20 @@ void HandshakeOptimizeBitwidthsPass::addArithPatterns(
18721910 // is dangerous if the shift is used as multiplication.
18731911 // Therefore, removing "ArithShift<handshake::ShLIOp>" from the patterns for
18741912 // now
1875- patterns.add <ArithShift<handshake::ShRSIOp>, ArithSelect>(
1876- bitwidthReduced, forward, ctx, getAnalysis<NameAnalysis>());
1877- if (! forward)
1878- patterns.add <ArithShift<handshake::ShRUIOp>> (bitwidthReduced, forward , ctx,
1879- getAnalysis<NameAnalysis>());
1913+ patterns.add <ArithSelect>(bitwidthReduced, forward, ctx,
1914+ getAnalysis<NameAnalysis>());
1915+ if (forward)
1916+ patterns.add <ArithShrUIFW, ArithShrSIFW> (bitwidthReduced, ctx,
1917+ getAnalysis<NameAnalysis>());
18801918 else
1881- patterns.add <ArithShrUIFW>(bitwidthReduced, ctx,
1882- getAnalysis<NameAnalysis>());
1919+ patterns.add <ArithShiftBW<handshake::ShRSIOp>,
1920+ ArithShiftBW<handshake::ShRUIOp>>(bitwidthReduced, ctx,
1921+ getAnalysis<NameAnalysis>());
18831922
18841923 patterns.add <ArithExtToTruncOpt>(bitwidthReduced, ctx,
18851924 getAnalysis<NameAnalysis>());
1886- handshake::ExtSIOp::getCanonicalizationPatterns (patterns, ctx);
1887- handshake::ExtUIOp:: getCanonicalizationPatterns (patterns, ctx );
1925+ ctx-> getLoadedDialect < handshake::HandshakeDialect>()
1926+ -> getCanonicalizationPatterns (patterns);
18881927}
18891928
18901929void HandshakeOptimizeBitwidthsPass::addHandshakeDataPatterns (
0 commit comments