Skip to content

Commit e8a51c0

Browse files
committed
[HandshakeOptimizeBitwidths] Fix forward logic for shrsi
The previous forward-pass logic for `shrsi` often crashed in edge cases, such as when the shift amount was larger than the bitwidth. This PR rewrites the forward logic for `shrsi` into a dedicated pattern. When the input of `shrsi` is zero-extended, we simply canonicalize it to a `shrui` and reuse the existing optimization logic there. Fixes #792
1 parent a91b53a commit e8a51c0

5 files changed

Lines changed: 111 additions & 75 deletions

File tree

include/dynamatic/Dialect/Handshake/Handshake.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def Handshake_Dialect : Dialect {
4141
let useDefaultTypePrinterParser = 1;
4242
let useDefaultAttributePrinterParser = 1;
4343
let usePropertiesForAttributes = 0;
44+
let hasCanonicalizer = 1;
4445
}
4546

4647
include "dynamatic/Dialect/Handshake/HandshakeAttributes.td"

include/dynamatic/Dialect/Handshake/HandshakeArithOps.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,12 +310,10 @@ def Handshake_MinUIOp : Handshake_Arith_IntBinaryOp<"minui"> {
310310

311311
def Handshake_ExtSIOp : Handshake_Arith_IToICastOp<"extsi"> {
312312
let summary = "Integer signed width extension.";
313-
let hasCanonicalizer = 1;
314313
}
315314

316315
def Handshake_ExtUIOp : Handshake_Arith_IToICastOp<"extui"> {
317316
let summary = "Integer unsigned width extension.";
318-
let hasCanonicalizer = 1;
319317
}
320318

321319
def Handshake_MaximumFOp : Handshake_Arith_FloatBinaryOp<"maximumf", [
@@ -410,7 +408,6 @@ def Handshake_SubIOp : Handshake_Arith_IntBinaryOp<"subi"> {
410408

411409
def Handshake_TruncIOp : Handshake_Arith_IToICastOp<"trunci"> {
412410
let summary = "Integer truncation.";
413-
let hasCanonicalizer = 1;
414411
}
415412

416413
def Handshake_TruncFOp : Handshake_Arith_FToFCastOp<"truncf", [

lib/Dialect/Handshake/HandshakeCanonicalization.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,13 @@ def TruncIExtUIToExtUI : Pat<
6060
[(ValueWiderThan $ext, $tr), (ValueWiderThan $tr, $x)]
6161
>;
6262

63+
//===----------------------------------------------------------------------===//
64+
// ShrSIOp
65+
//===----------------------------------------------------------------------===//
66+
67+
def ShrSIOpOfExtUI : Pat<
68+
(Handshake_ShRSIOp (Handshake_ExtUIOp:$ext $_), $shift),
69+
(Handshake_ShRUIOp $ext, $shift)
70+
>;
71+
6372
#endif // DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_CANONICALIZATION_TD

lib/Dialect/Handshake/HandshakeOps.cpp

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@ namespace {
206206
#include "lib/Dialect/Handshake/HandshakeCanonicalization.inc"
207207
} // namespace
208208

209+
void handshake::HandshakeDialect::getCanonicalizationPatterns(
210+
RewritePatternSet &set) const {
211+
populateWithGenerated(set);
212+
}
213+
209214
//===----------------------------------------------------------------------===//
210215
// MergeOp
211216
//===----------------------------------------------------------------------===//
@@ -1896,11 +1901,6 @@ static OpFoldResult foldExtOp(Op op) {
18961901
return nullptr;
18971902
}
18981903

1899-
void ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &results,
1900-
MLIRContext *context) {
1901-
results.add<ExtSIOfExtUI, ExtSIOfConst>(context);
1902-
}
1903-
19041904
OpFoldResult ExtSIOp::fold(FoldAdaptor adaptor) { return foldExtOp(*this); }
19051905

19061906
/// Extension operations can only extend to a channel with a wider data type and
@@ -1927,22 +1927,12 @@ LogicalResult ExtSIOp::verify() { return verifyExtOp(*this); }
19271927

19281928
OpFoldResult ExtUIOp::fold(FoldAdaptor adaptor) { return foldExtOp(*this); }
19291929

1930-
void ExtUIOp::getCanonicalizationPatterns(RewritePatternSet &results,
1931-
MLIRContext *context) {
1932-
results.add<ExtUIOfConst>(context);
1933-
}
1934-
19351930
LogicalResult ExtUIOp::verify() { return verifyExtOp(*this); }
19361931

19371932
//===----------------------------------------------------------------------===//
19381933
// TruncIOp
19391934
//===----------------------------------------------------------------------===//
19401935

1941-
void TruncIOp::getCanonicalizationPatterns(RewritePatternSet &results,
1942-
MLIRContext *context) {
1943-
results.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI>(context);
1944-
}
1945-
19461936
OpFoldResult TruncIOp::fold(FoldAdaptor adaptor) {
19471937
if (auto defTruncOp = getIn().getDefiningOp<TruncIOp>()) {
19481938
// Bypass the preceding truncation operation

lib/Transforms/HandshakeOptimizeBitwidths.cpp

Lines changed: 96 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,19 +1366,88 @@ struct ArithShrUIFW : OpRewritePattern<handshake::ShRUIOp> {
13661366
NameAnalysis &namer;
13671367
};
13681368

1369+
/// Optimizes signed right-shifts with a constant as a forward pass.
1370+
struct ArithShrSIFW : OpRewritePattern<handshake::ShRSIOp> {
1371+
ArithShrSIFW(Pass::Statistic &bitwidthReduced, MLIRContext *ctx,
1372+
NameAnalysis &namer)
1373+
: OpRewritePattern(ctx), bitwidthReduced(bitwidthReduced), namer(namer) {}
1374+
1375+
LogicalResult matchAndRewrite(handshake::ShRSIOp op,
1376+
PatternRewriter &rewriter) const override {
1377+
auto [lhs, lhsExt] = getMinimalValueWithExtType(op.getLhs());
1378+
unsigned inputBitwidth = lhs.getType().getDataBitWidth();
1379+
unsigned currentBitwidth = op.getType().getDataBitWidth();
1380+
if (inputBitwidth >= currentBitwidth)
1381+
return failure();
1382+
1383+
assert(lhsExt != ExtType::NONE && "expected an extension");
1384+
APInt value;
1385+
Value constantControl;
1386+
{
1387+
auto constantOp = op.getRhs().getDefiningOp<handshake::ConstantOp>();
1388+
if (!constantOp)
1389+
return failure();
1390+
value = cast<IntegerAttr>(constantOp.getValue()).getValue();
1391+
constantControl = constantOp.getCtrl();
1392+
}
1393+
1394+
// Other patterns (such as canonicalization patterns) should fold this case
1395+
// to a useful constant instead
1396+
if (value.uge(currentBitwidth))
1397+
return failure();
1398+
1399+
if (lhsExt == ExtType::ZEXT) {
1400+
// We use a generic canonicalization pattern that should fold this into
1401+
// an unsigned shift-right instead.
1402+
return failure();
1403+
}
1404+
1405+
// SEXT case.
1406+
if (value.ult(inputBitwidth)) {
1407+
// The shift amount is less than the input bitwidth, meaning other bits from the input
1408+
// besides the sign-bit are preserved in the output.
1409+
1410+
modArithOp(op, {lhs, lhsExt}, {op.getRhs(), ExtType::NONE}, inputBitwidth,
1411+
ExtType::SEXT, rewriter, namer);
1412+
++bitwidthReduced;
1413+
return success();
1414+
}
1415+
1416+
// Our shift amount is at least the input bitwidth and the input is
1417+
// sign-extended, so the only thing that remains from the input
1418+
// is the sign-bit.
1419+
Value inputBWM1 = rewriter.create<handshake::ConstantOp>(
1420+
op.getLoc(),
1421+
rewriter.getIntegerAttr(lhs.getType().getDataType(), inputBitwidth - 1),
1422+
constantControl);
1423+
// Shift away all values of lhs other than the sign-bit.
1424+
ChannelVal signBit =
1425+
rewriter.create<handshake::ShRSIOp>(op.getLoc(), lhs, inputBWM1);
1426+
// Fill remaining sign-bit copies.
1427+
rewriter.replaceOpWithNewOp<handshake::ExtSIOp>(op, op.getType(), signBit);
1428+
++bitwidthReduced;
1429+
return success();
1430+
}
1431+
1432+
private:
1433+
Pass::Statistic &bitwidthReduced;
1434+
/// A reference to the pass's name analysis.
1435+
NameAnalysis &namer;
1436+
};
1437+
13691438
/// Optimizes the bitwidth of shift-type operations. The first template
13701439
/// parameter is meant to be either handshake::ShLIOp, handshake::ShRSIOp, or
13711440
/// handshake::ShRUIOp. This pattern is only used in the backward pass; the matched
13721441
/// operation's bitwidth may only be reduced when the data operand is shifted by
13731442
/// a known constant amount.
13741443
template <typename Op>
1375-
struct ArithShift : public OpRewritePattern<Op> {
1444+
struct ArithShiftBW : public OpRewritePattern<Op> {
13761445
using OpRewritePattern<Op>::OpRewritePattern;
13771446

1378-
ArithShift(Pass::Statistic &bitwidthReduced, bool forward, MLIRContext *ctx,
1379-
NameAnalysis &namer)
1447+
ArithShiftBW(Pass::Statistic &bitwidthReduced, MLIRContext *ctx,
1448+
NameAnalysis &namer)
13801449
: OpRewritePattern<Op>(ctx), bitwidthReduced(bitwidthReduced),
1381-
namer(namer), forward(forward) {}
1450+
namer(namer) {}
13821451

13831452
LogicalResult matchAndRewrite(Op op,
13841453
PatternRewriter &rewriter) const override {
@@ -1396,56 +1465,25 @@ struct ArithShift : public OpRewritePattern<Op> {
13961465
if (Operation *defOp = minShiftBy.getDefiningOp())
13971466
if (auto cstOp = dyn_cast<handshake::ConstantOp>(defOp)) {
13981467
cstVal = (unsigned)cast<IntegerAttr>(cstOp.getValue()).getInt();
1399-
if (forward) {
1400-
optWidth = minToShift.getType().getDataBitWidth();
1401-
if (!isRightShift)
1402-
optWidth += cstVal;
1403-
} else {
1404-
optWidth = getUsefulResultWidth(op.getResult());
1405-
if (isRightShift)
1406-
optWidth += cstVal;
1407-
}
1468+
optWidth = getUsefulResultWidth(op.getResult());
1469+
if (isRightShift)
1470+
optWidth += cstVal;
14081471
}
14091472

14101473
if (optWidth >= resWidth)
14111474
return failure();
14121475

1413-
if (forward) {
1414-
// Create a new operation as well as appropriate bitwidth modification
1415-
// operations to keep the IR valid
1416-
Value newToShift =
1417-
modBitWidth({minToShift, extToShift}, optWidth, rewriter);
1418-
Value newShifyBy =
1419-
modBitWidth({minShiftBy, ExtType::ZEXT}, optWidth, rewriter);
1420-
rewriter.setInsertionPoint(op);
1421-
auto newOp = rewriter.create<Op>(op.getLoc(), newToShift.getType(),
1422-
newToShift, newShifyBy);
1423-
ChannelVal newRes = newOp.getResult();
1424-
if (isRightShift)
1425-
// In the case of a right shift, we first truncate the result of the
1426-
// newly inserted shift operation to discard high-significance bits that
1427-
// we know are 0s, then extend the result back to satisfy the users of
1428-
// the original operation's result
1429-
newRes = modBitWidth({newRes, extToShift}, optWidth - cstVal, rewriter);
1430-
Value modRes = modBitWidth({newRes, extToShift}, resWidth, rewriter);
1431-
inheritBB(op, newOp);
1432-
1433-
// Replace uses of the original operation's result with the result of the
1434-
// optimized operation we just created
1435-
rewriter.replaceOp(op, modRes);
1436-
} else {
1437-
ChannelVal modToShift = minToShift;
1438-
if (!isRightShift) {
1439-
// In the case of a left shift, we first truncate the shifted integer to
1440-
// discard high-significance bits that were discarded in the result,
1441-
// then extend back to satisfy the users of the original integer
1442-
unsigned requiredToShiftWidth = optWidth - std::min(cstVal, optWidth);
1443-
modToShift = modBitWidth({minToShift, extToShift}, requiredToShiftWidth,
1444-
rewriter);
1445-
}
1446-
modArithOp(op, {modToShift, extToShift}, {minShiftBy, ExtType::ZEXT},
1447-
optWidth, extToShift, rewriter, namer);
1476+
ChannelVal modToShift = minToShift;
1477+
if (!isRightShift) {
1478+
// In the case of a left shift, we first truncate the shifted integer to
1479+
// discard high-significance bits that were discarded in the result,
1480+
// then extend back to satisfy the users of the original integer
1481+
unsigned requiredToShiftWidth = optWidth - std::min(cstVal, optWidth);
1482+
modToShift =
1483+
modBitWidth({minToShift, extToShift}, requiredToShiftWidth, rewriter);
14481484
}
1485+
modArithOp(op, {modToShift, extToShift}, {minShiftBy, ExtType::ZEXT},
1486+
optWidth, extToShift, rewriter, namer);
14491487
++bitwidthReduced;
14501488
return success();
14511489
}
@@ -1872,19 +1910,20 @@ void HandshakeOptimizeBitwidthsPass::addArithPatterns(
18721910
// is dangerous if the shift is used as multiplication.
18731911
// Therefore, removing "ArithShiftBW<handshake::ShLIOp>" from the patterns for
18741912
// now
1875-
patterns.add<ArithShift<handshake::ShRSIOp>, ArithSelect>(
1876-
bitwidthReduced, forward, ctx, getAnalysis<NameAnalysis>());
1877-
if (!forward)
1878-
patterns.add<ArithShift<handshake::ShRUIOp>>(bitwidthReduced, forward, ctx,
1879-
getAnalysis<NameAnalysis>());
1913+
patterns.add<ArithSelect>(bitwidthReduced, forward, ctx,
1914+
getAnalysis<NameAnalysis>());
1915+
if (forward)
1916+
patterns.add<ArithShrUIFW, ArithShrSIFW>(bitwidthReduced, ctx,
1917+
getAnalysis<NameAnalysis>());
18801918
else
1881-
patterns.add<ArithShrUIFW>(bitwidthReduced, ctx,
1882-
getAnalysis<NameAnalysis>());
1919+
patterns.add<ArithShiftBW<handshake::ShRSIOp>,
1920+
ArithShiftBW<handshake::ShRUIOp>>(bitwidthReduced, ctx,
1921+
getAnalysis<NameAnalysis>());
18831922

18841923
patterns.add<ArithExtToTruncOpt>(bitwidthReduced, ctx,
18851924
getAnalysis<NameAnalysis>());
1886-
handshake::ExtSIOp::getCanonicalizationPatterns(patterns, ctx);
1887-
handshake::ExtUIOp::getCanonicalizationPatterns(patterns, ctx);
1925+
ctx->getLoadedDialect<handshake::HandshakeDialect>()
1926+
->getCanonicalizationPatterns(patterns);
18881927
}
18891928

18901929
void HandshakeOptimizeBitwidthsPass::addHandshakeDataPatterns(

0 commit comments

Comments
 (0)