Skip to content

Commit 6b27a46

Browse files
authored
[HandshakeOptimizeBitwidths] Redefine ExtType to remove CONFLICT (#784)
This PR redefines the `ExtType` enum in the bitwidth optimization pass to 1) remove the pessimization that is the `CONFLICT` value and 2) rename `UNKNOWN` to `NONE`. The `CONFLICT` type previously served as a pessimization that simply stated that some minimum value was extended with both sign and zero extension. For the purpose of optimizations it is rather useless as the precise extension kind is basically always required. The new logic no longer looks through a chain of extensions but rather only the immediate extension and performs optimizations accordingly. Repeated pattern applications plus canonicalization patterns that fold successive extensions lead to a better optimized output instead. The `UNKNOWN` enum value was simply renamed to `NONE`. My rationale for this is that `UNKNOWN` subjectively sounds like a pessimistic value when it reality, it precisely states that the operand was not extended. The new name reflects those semantics. New tests were also added to make sure successive extensions are handled properly
1 parent a119b15 commit 6b27a46

2 files changed

Lines changed: 118 additions & 93 deletions

File tree

lib/Transforms/HandshakeOptimizeBitwidths.cpp

Lines changed: 65 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,16 @@ namespace dynamatic {
5353
// [END Boilerplate code for the MLIR pass]
5454

5555
namespace {
56-
/// Extension type. When backtracing through extension operations, serves to
56+
/// Extension type. When backtracking through extension operations, serves to
5757
/// remember the type of any extension we may have encountered along the
58-
/// way.getMinimalValue Then, when modifying a value's bitwidth, serves to guide
58+
/// way. Then, when modifying a value's bitwidth, serves to guide
5959
/// the determination of which extension operation to use.
60-
/// - UNKNOWN when no extension has been encountered / when a value's signedness
61-
/// should determine its extension type.
62-
/// - LOGICAL when only logical extensions have been encountered / when a value
63-
/// should be logically extended.
64-
/// - ARITHMETIC when only arithmaric extensions have been encountered / when a
65-
/// value should be arithmetically extended.
66-
/// - CONFLICT when both logical and arithmetic extensions have been encountered
67-
/// when it's not possible to accurately determine what type of extension to
68-
/// use for a value.
69-
enum class ExtType { UNKNOWN, LOGICAL, ARITHMETIC, CONFLICT };
60+
/// - NONE when no extension has been encountered.
61+
/// - ZEXT when a zero extension has been encountered / when a value
62+
/// should be zero extended.
63+
/// - SEXT when a sign extension has been encountered / when a
64+
/// value should be sign extended.
65+
enum class ExtType { NONE, ZEXT, SEXT };
7066

7167
/// A channel-typed value.
7268
using ChannelVal = TypedValue<handshake::ChannelType>;
@@ -106,35 +102,23 @@ static ChannelVal getMinimalValue(ChannelVal val, ExtType *ext = nullptr) {
106102
if (!asTypedIfLegal(val))
107103
return val;
108104

109-
// Only backtrack through values that were produced by extension operations
110-
while (Operation *defOp = val.getDefiningOp()) {
111-
if (!isa<handshake::ExtSIOp, handshake::ExtUIOp>(defOp))
112-
return val;
105+
if (auto op = val.getDefiningOp<handshake::ExtSIOp>()) {
106+
if (ext)
107+
*ext = ExtType::SEXT;
113108

114-
// Update the extension type using the nature of the current extension
115-
// operation and the current type
116-
if (ext) {
117-
switch (*ext) {
118-
case ExtType::UNKNOWN:
119-
*ext = isa<handshake::ExtSIOp>(defOp) ? ExtType::ARITHMETIC
120-
: ExtType::LOGICAL;
121-
break;
122-
case ExtType::LOGICAL:
123-
if (isa<handshake::ExtSIOp>(defOp))
124-
*ext = ExtType::CONFLICT;
125-
break;
126-
case ExtType::ARITHMETIC:
127-
if (isa<handshake::ExtUIOp>(defOp))
128-
*ext = ExtType::CONFLICT;
129-
break;
130-
default:
131-
break;
132-
}
133-
}
134-
// Backtrack through the extension operation
135-
val = cast<ChannelVal>(defOp->getOperand(0));
109+
return op.getIn();
110+
}
111+
112+
if (auto op = val.getDefiningOp<handshake::ExtUIOp>()) {
113+
if (ext)
114+
*ext = ExtType::ZEXT;
115+
116+
return op.getIn();
136117
}
137118

119+
if (ext)
120+
*ext = ExtType::NONE;
121+
138122
return val;
139123
}
140124

@@ -215,23 +199,12 @@ static ChannelVal modBitWidth(ExtValue extVal, unsigned targetWidth,
215199
Type dstChannelType = val.getType().withDataType(newDataType);
216200
rewriter.setInsertionPointAfterValue(val);
217201
if (width < targetWidth) {
218-
if (ext == ExtType::CONFLICT) {
219-
// If the extension type is conflicting, just emit a warning and hope for
220-
// the best
221-
Operation *defOp = val.getDefiningOp();
222-
std::string origin;
223-
if (defOp)
224-
origin = "operation result";
225-
else {
226-
defOp = val.getParentBlock()->getParentOp();
227-
origin = "function argument";
228-
}
229-
defOp->emitWarning()
230-
<< "Conflicting extension type given for " << origin
231-
<< ", optimization result may change circuit semantics.";
232-
}
233-
if (ext == ExtType::LOGICAL ||
234-
(ext == ExtType::UNKNOWN &&
202+
// TODO: The code here is wrong for NONE as our IR no-longer carries
203+
// signedness information in the integer type.
204+
// All code should be migrated to pass LOGICAL or ARITHMETIC if it performs
205+
// bitwidth changes.
206+
if (ext == ExtType::ZEXT ||
207+
(ext == ExtType::NONE &&
235208
val.getType().getDataType().isUnsignedInteger())) {
236209
newOp = rewriter.create<handshake::ExtUIOp>(loc, dstChannelType, val);
237210
} else {
@@ -382,21 +355,21 @@ static void canonicalizeCommutativeExtensionType(ExtWidth &lhs, ExtWidth &rhs) {
382355
/// Transfer function for add/sub operations or alike.
383356
static ExtWidth addWidth(ExtWidth lhs, ExtWidth rhs) {
384357
canonicalizeCommutativeExtensionType(lhs, rhs);
385-
if (rhs.extType <= ExtType::LOGICAL)
386-
return {ExtType::LOGICAL, std::max(lhs.bitWidth, rhs.bitWidth) + 1};
358+
if (rhs.extType <= ExtType::ZEXT)
359+
return {ExtType::ZEXT, std::max(lhs.bitWidth, rhs.bitWidth) + 1};
387360

388-
return {ExtType::ARITHMETIC, std::max(lhs.bitWidth, rhs.bitWidth) + 1};
361+
return {ExtType::SEXT, std::max(lhs.bitWidth, rhs.bitWidth) + 1};
389362
}
390363

391364
/// Transfer function for mul operations or alike.
392365
static ExtWidth mulWidth(ExtWidth lhs, ExtWidth rhs) {
393-
return {ExtType::UNKNOWN, lhs.bitWidth + rhs.bitWidth};
366+
return {ExtType::NONE, lhs.bitWidth + rhs.bitWidth};
394367
}
395368

396369
/// Transfer function for div/rem operations or alike.
397370
template <bool zeroExtend>
398371
static ExtWidth divWidth(ExtWidth lhs, ExtWidth _) {
399-
return {zeroExtend ? ExtType::LOGICAL : ExtType::UNKNOWN, lhs.bitWidth + 1};
372+
return {zeroExtend ? ExtType::ZEXT : ExtType::NONE, lhs.bitWidth + 1};
400373
}
401374

402375
/// Transfer function for and operations or alike.
@@ -411,8 +384,8 @@ static ExtWidth andWidth(ExtWidth lhs, ExtWidth rhs) {
411384
// From our example:
412385
// Extending 'a' to 00001 and 'b' to 00101 yields the same result as if ANDing
413386
// "a = 01, b = 01" and zero-extending the result.
414-
if (rhs.extType <= ExtType::LOGICAL)
415-
return {ExtType::LOGICAL, std::min(lhs.bitWidth, rhs.bitWidth)};
387+
if (rhs.extType <= ExtType::ZEXT)
388+
return {ExtType::ZEXT, std::min(lhs.bitWidth, rhs.bitWidth)};
416389

417390
// Sign-extension might fill with 1-bits, meaning all bits of the larger
418391
// operand are part of the effective result bitwidth.
@@ -424,17 +397,14 @@ static ExtWidth andWidth(ExtWidth lhs, ExtWidth rhs) {
424397
// For bits the bits inbetween |a| and |b|, sign-extension of the smaller
425398
// operand is still required as the corresponding result bits are dependent
426399
// on the sign of the smaller operand.
427-
//
428-
// TODO: CONFLICT might be able to be optimized better but needs further
429-
// investigation. This case is conservatively correct.
430-
return {ExtType::ARITHMETIC, std::max(lhs.bitWidth, rhs.bitWidth)};
400+
return {ExtType::SEXT, std::max(lhs.bitWidth, rhs.bitWidth)};
431401
}
432402

433403
/// Transfer function for or/xor operations or alike.
434404
static ExtWidth orWidth(ExtWidth lhs, ExtWidth rhs) {
435405
canonicalizeCommutativeExtensionType(lhs, rhs);
436-
if (rhs.extType <= ExtType::LOGICAL)
437-
return {ExtType::LOGICAL, std::max(lhs.bitWidth, rhs.bitWidth)};
406+
if (rhs.extType <= ExtType::ZEXT)
407+
return {ExtType::ZEXT, std::max(lhs.bitWidth, rhs.bitWidth)};
438408
// rhs guaranteed to be at least arithmetic from here on.
439409

440410
// Since rhs was sign-extended the result to continue extending with 1s in the
@@ -447,10 +417,10 @@ static ExtWidth orWidth(ExtWidth lhs, ExtWidth rhs) {
447417
// with 3 bits would be wrong however, since sext(OR 101, sext(01) to i3)
448418
// would extend with 1s, merely due to the bitwidth reduction.
449419
// The extra bit prevents this behavior.
450-
if (lhs.extType == ExtType::LOGICAL && lhs.bitWidth > rhs.bitWidth)
451-
return {ExtType::ARITHMETIC, 1 + lhs.bitWidth};
420+
if (lhs.extType == ExtType::ZEXT && lhs.bitWidth > rhs.bitWidth)
421+
return {ExtType::SEXT, 1 + lhs.bitWidth};
452422

453-
return {ExtType::ARITHMETIC, std::max(lhs.bitWidth, rhs.bitWidth)};
423+
return {ExtType::SEXT, std::max(lhs.bitWidth, rhs.bitWidth)};
454424
}
455425

456426
//===----------------------------------------------------------------------===//
@@ -681,7 +651,7 @@ struct HandshakeOptData : public OpRewritePattern<Op> {
681651

682652
// Get the operation's data operands actual widths
683653
SmallVector<ChannelVal> minDataOperands;
684-
ExtType ext = ExtType::UNKNOWN;
654+
ExtType ext = ExtType::NONE;
685655
llvm::transform(dataOperands, std::back_inserter(minDataOperands),
686656
[&](Value val) {
687657
return getMinimalValue(cast<ChannelVal>(val), &ext);
@@ -755,7 +725,7 @@ struct HandshakeMuxSelect : public OpRewritePattern<handshake::MuxOp> {
755725
// Create a new mux whose select operand is optimized
756726
SmallVector<Value, 3> newOperands;
757727
newOperands.push_back(
758-
modBitWidth({selectOperand, ExtType::LOGICAL}, optWidth, rewriter));
728+
modBitWidth({selectOperand, ExtType::ZEXT}, optWidth, rewriter));
759729
auto dataOprds = muxOp.getDataOperands();
760730
newOperands.append(dataOprds.begin(), dataOprds.end());
761731
auto newMuxOp = rewriter.create<handshake::MuxOp>(
@@ -804,7 +774,7 @@ struct HandshakeCMergeIndex
804774
cmergeOp.getLoc(), newResultTypes, cmergeOp.getDataOperands(),
805775
cmergeOp->getAttrs());
806776
namer.replaceOp(cmergeOp, newCmergeOp);
807-
Value modIndex = modBitWidth({newCmergeOp.getIndex(), ExtType::LOGICAL},
777+
Value modIndex = modBitWidth({newCmergeOp.getIndex(), ExtType::ZEXT},
808778
indexWidth, rewriter);
809779
rewriter.replaceOp(cmergeOp, {newCmergeOp.getResult(), modIndex});
810780
return success();
@@ -847,7 +817,7 @@ struct MemInterfaceAddrOpt
847817
// by inputIdx, and increment inputIdx before returning the optimized value
848818
auto getOptAddrInput = [&](unsigned inputIdx) {
849819
return modBitWidth({getMinimalValue(cast<ChannelVal>(operands[inputIdx])),
850-
ExtType::LOGICAL},
820+
ExtType::ZEXT},
851821
optWidth, rewriter);
852822
};
853823

@@ -908,7 +878,7 @@ struct MemInterfaceAddrOpt
908878
SmallVector<Value> replacementValues(newMemOp->getResults());
909879
for (unsigned resIdx : addrResultIndices) {
910880
replacementValues[resIdx] = modBitWidth(
911-
{cast<ChannelVal>(replacementValues[resIdx]), ExtType::LOGICAL},
881+
{cast<ChannelVal>(replacementValues[resIdx]), ExtType::ZEXT},
912882
ports.addrWidth, rewriter);
913883
}
914884
inheritBB(memOp, newMemOp);
@@ -944,9 +914,9 @@ struct MemPortAddrOpt
944914
return failure();
945915

946916
// Derive new operands and result types with the narrrower address type
947-
Value newAddr = modBitWidth(
948-
{getMinimalValue(portOp.getAddressInput()), ExtType::LOGICAL}, optWidth,
949-
rewriter);
917+
Value newAddr =
918+
modBitWidth({getMinimalValue(portOp.getAddressInput()), ExtType::ZEXT},
919+
optWidth, rewriter);
950920
Value dataIn = portOp.getDataInput();
951921
SmallVector<Value, 2> newOperands{newAddr, dataIn};
952922
SmallVector<Type, 2> newResultTypes{newAddr.getType(), dataIn.getType()};
@@ -960,7 +930,7 @@ struct MemPortAddrOpt
960930
namer.replaceOp(portOp, newPortOp);
961931
inheritBB(portOp, newPortOp);
962932
Value newAddrRes = modBitWidth(
963-
{newPortOp.getAddressOutput(), ExtType::LOGICAL}, addrWidth, rewriter);
933+
{newPortOp.getAddressOutput(), ExtType::ZEXT}, addrWidth, rewriter);
964934
rewriter.replaceOp(portOp, {newAddrRes, newPortOp.getDataOutput()});
965935
return success();
966936
}
@@ -1025,7 +995,7 @@ struct ForwardCycleOpt : public OpRewritePattern<Op> {
1025995

1026996
// Determine the achievable optimized width for operands inside the cycle
1027997
unsigned optWidth = 0;
1028-
ExtType ext = ExtType::UNKNOWN;
998+
ExtType ext = ExtType::NONE;
1029999
for (ChannelVal mergedVal : allMergedValues) {
10301000
optWidth = std::max(
10311001
optWidth,
@@ -1113,7 +1083,7 @@ struct ArithSingleType : public OpRewritePattern<Op> {
11131083
return failure();
11141084

11151085
// Check whether we can reduce the bitwidth of the operation
1116-
ExtType extLhs = ExtType::UNKNOWN, extRhs = ExtType::UNKNOWN;
1086+
ExtType extLhs = ExtType::NONE, extRhs = ExtType::NONE;
11171087
ChannelVal minLhs = getMinimalValue(op.getLhs(), &extLhs);
11181088
ChannelVal minRhs = getMinimalValue(op.getRhs(), &extRhs);
11191089
ExtWidth optWidth;
@@ -1124,7 +1094,7 @@ struct ArithSingleType : public OpRewritePattern<Op> {
11241094
// It does not matter whether we use sign- or zero-extension in this case
11251095
// since the bits added by the extension are unused by definition.
11261096
// We use zero-extension as it is cheaper and easier to optimize.
1127-
optWidth = {ExtType::LOGICAL, getUsefulResultWidth(op.getResult())};
1097+
optWidth = {ExtType::ZEXT, getUsefulResultWidth(op.getResult())};
11281098
}
11291099
unsigned resWidth = channelVal.getType().getDataBitWidth();
11301100
if (optWidth.bitWidth >= resWidth)
@@ -1164,7 +1134,7 @@ struct ArithSelect : public OpRewritePattern<handshake::SelectOp> {
11641134
return failure();
11651135

11661136
// Check whether we can reduce the bitwidth of the operation
1167-
ExtType extLhs = ExtType::UNKNOWN, extRhs = ExtType::UNKNOWN;
1137+
ExtType extLhs = ExtType::NONE, extRhs = ExtType::NONE;
11681138
ChannelVal minLhs = getMinimalValue(selectOp.getTrueValue(), &extLhs);
11691139
ChannelVal minRhs = getMinimalValue(selectOp.getFalseValue(), &extRhs);
11701140
unsigned optWidth;
@@ -1179,8 +1149,8 @@ struct ArithSelect : public OpRewritePattern<handshake::SelectOp> {
11791149

11801150
// Different operand extension types mean that we don't know how to extend
11811151
// the operation's result, so it cannot be optimized
1182-
if ((extLhs == ExtType::LOGICAL && extRhs == ExtType::ARITHMETIC) ||
1183-
(extLhs == ExtType::ARITHMETIC && extRhs == ExtType::LOGICAL))
1152+
if ((extLhs == ExtType::ZEXT && extRhs == ExtType::SEXT) ||
1153+
(extLhs == ExtType::SEXT && extRhs == ExtType::ZEXT))
11841154
return failure();
11851155

11861156
// Create a new operation as well as appropriate bitwidth modification
@@ -1225,7 +1195,7 @@ struct ArithShift : public OpRewritePattern<Op> {
12251195
PatternRewriter &rewriter) const override {
12261196
ChannelVal toShift = op.getLhs();
12271197
ChannelVal shiftBy = op.getRhs();
1228-
ExtType extToShift = ExtType::UNKNOWN;
1198+
ExtType extToShift = ExtType::NONE;
12291199
ChannelVal minToShift = getMinimalValue(toShift, &extToShift);
12301200
ChannelVal minShiftBy = backtrackToMinimalValue(shiftBy);
12311201
bool isRightShift =
@@ -1258,7 +1228,7 @@ struct ArithShift : public OpRewritePattern<Op> {
12581228
Value newToShift =
12591229
modBitWidth({minToShift, extToShift}, optWidth, rewriter);
12601230
Value newShifyBy =
1261-
modBitWidth({minShiftBy, ExtType::LOGICAL}, optWidth, rewriter);
1231+
modBitWidth({minShiftBy, ExtType::ZEXT}, optWidth, rewriter);
12621232
rewriter.setInsertionPoint(op);
12631233
auto newOp = rewriter.create<Op>(op.getLoc(), newToShift.getType(),
12641234
newToShift, newShifyBy);
@@ -1285,7 +1255,7 @@ struct ArithShift : public OpRewritePattern<Op> {
12851255
modToShift = modBitWidth({minToShift, extToShift}, requiredToShiftWidth,
12861256
rewriter);
12871257
}
1288-
modArithOp(op, {modToShift, extToShift}, {minShiftBy, ExtType::LOGICAL},
1258+
modArithOp(op, {modToShift, extToShift}, {minShiftBy, ExtType::ZEXT},
12891259
optWidth, extToShift, rewriter, namer);
12901260
}
12911261
return success();
@@ -1311,7 +1281,7 @@ struct ArithCmpFW : public OpRewritePattern<handshake::CmpIOp> {
13111281
LogicalResult matchAndRewrite(handshake::CmpIOp cmpOp,
13121282
PatternRewriter &rewriter) const override {
13131283
// Check whether we can reduce the bitwidth of the operation
1314-
ExtType extLhs = ExtType::UNKNOWN, extRhs = ExtType::UNKNOWN;
1284+
ExtType extLhs = ExtType::NONE, extRhs = ExtType::NONE;
13151285
ChannelVal minLhs = getMinimalValue(cmpOp.getLhs(), &extLhs);
13161286
ChannelVal minRhs = getMinimalValue(cmpOp.getRhs(), &extRhs);
13171287
unsigned optWidth = std::max(minLhs.getType().getDataBitWidth(),
@@ -1352,9 +1322,9 @@ struct ArithExtToTruncOpt : public OpRewritePattern<handshake::TruncIOp> {
13521322
LogicalResult matchAndRewrite(handshake::TruncIOp truncOp,
13531323
PatternRewriter &rewriter) const override {
13541324
// Operand must be produced by an extension operation
1355-
ExtType extType = ExtType::UNKNOWN;
1325+
ExtType extType = ExtType::NONE;
13561326
ChannelVal minVal = getMinimalValue(truncOp.getIn(), &extType);
1357-
if (extType == ExtType::UNKNOWN || extType == ExtType::CONFLICT)
1327+
if (extType == ExtType::NONE)
13581328
return failure();
13591329

13601330
unsigned finalWidth = truncOp.getResult().getType().getDataBitWidth();
@@ -1403,7 +1373,7 @@ struct ArithBoundOpt : public OpRewritePattern<handshake::ConditionalBranchOp> {
14031373
falseRes = cast<ChannelVal>(condOp.getFalseResult());
14041374
std::optional<std::pair<unsigned, ExtType>> trueBranch, falseBranch;
14051375
for (handshake::CmpIOp cmpOp : getCmpOps(condOp.getConditionOperand())) {
1406-
ExtType extLhs = ExtType::UNKNOWN, extRhs = ExtType::UNKNOWN;
1376+
ExtType extLhs = ExtType::NONE, extRhs = ExtType::NONE;
14071377
ChannelVal minLhs = backtrackToMinimalValue(cmpOp.getLhs(), &extLhs);
14081378
ChannelVal minRhs = backtrackToMinimalValue(cmpOp.getRhs(), &extRhs);
14091379

@@ -1704,6 +1674,8 @@ void HandshakeOptimizeBitwidthsPass::addArithPatterns(
17041674
ArithSelect>(forward, ctx, getAnalysis<NameAnalysis>());
17051675

17061676
patterns.add<ArithExtToTruncOpt>(ctx, getAnalysis<NameAnalysis>());
1677+
handshake::ExtSIOp::getCanonicalizationPatterns(patterns, ctx);
1678+
handshake::ExtUIOp::getCanonicalizationPatterns(patterns, ctx);
17071679
}
17081680

17091681
void HandshakeOptimizeBitwidthsPass::addHandshakeDataPatterns(

0 commit comments

Comments
 (0)