Skip to content

Commit a119b15

Browse files
authored
[HandshakeOptimizeBitwidths] Fix bitwidth reduction of sign-extended or (#778)
The transfer function of `ori` currently always uses zero-extension. This is incorrect when bitwidth reduction occurs and the operands were originally sign-extended as the upper bits that might have previously been ones, would now be 0 through the zero-extension. This PR fixes that by adjusting the transfer function to sign-extend when one of its inputs is sign-extended. In the case that one operand is zero-extended, the computation additionally needs to be performed with an extra bit to not quasi sign-extend both operands. Alive 2 proofs: https://alive2.llvm.org/ce/z/Rze2eL https://alive2.llvm.org/ce/z/NJAWFy https://alive2.llvm.org/ce/z/amxV5R Fixes #765 Depends on #766
1 parent 0bf7f50 commit a119b15

2 files changed

Lines changed: 86 additions & 4 deletions

File tree

lib/Transforms/HandshakeOptimizeBitwidths.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,22 @@ struct ExtWidth {
366366
/// If needed, swaps the operands such that 'rhs' contains the larger extension
367367
/// type.
368368
/// This allows eliminating symmetrical cases in commutative operations.
369-
static void ignoreCommutativity(ExtWidth &lhs, ExtWidth &rhs) {
369+
/// Users are allowed to assume that 'rhs.extType >= lhs.extType'.
370+
/// The cases that need to be handled are then only:
371+
/// * NONE, NONE
372+
/// * NONE, LOGICAL
373+
/// * NONE, ARITHMETIC
374+
/// * LOGICAL, LOGICAL
375+
/// * LOGICAL, ARITHMETIC
376+
/// * ARITHMETIC, ARITHMETIC
377+
static void canonicalizeCommutativeExtensionType(ExtWidth &lhs, ExtWidth &rhs) {
370378
if (lhs.extType > rhs.extType)
371379
std::swap(lhs, rhs);
372380
}
373381

374382
/// Transfer function for add/sub operations or alike.
375383
static ExtWidth addWidth(ExtWidth lhs, ExtWidth rhs) {
376-
ignoreCommutativity(lhs, rhs);
384+
canonicalizeCommutativeExtensionType(lhs, rhs);
377385
if (rhs.extType <= ExtType::LOGICAL)
378386
return {ExtType::LOGICAL, std::max(lhs.bitWidth, rhs.bitWidth) + 1};
379387

@@ -393,7 +401,7 @@ static ExtWidth divWidth(ExtWidth lhs, ExtWidth _) {
393401

394402
/// Transfer function for and operations or alike.
395403
static ExtWidth andWidth(ExtWidth lhs, ExtWidth rhs) {
396-
ignoreCommutativity(lhs, rhs);
404+
canonicalizeCommutativeExtensionType(lhs, rhs);
397405
// Given two operands such as "a = 01, b = 101":
398406
// If both operands are zero-extended or not extended at all, then the
399407
// effective bitwidth is whichever is smaller since 1) any bits beyond
@@ -424,7 +432,25 @@ static ExtWidth andWidth(ExtWidth lhs, ExtWidth rhs) {
424432

425433
/// Transfer function for or/xor operations or alike.
426434
static ExtWidth orWidth(ExtWidth lhs, ExtWidth rhs) {
427-
return {ExtType::LOGICAL, std::max(lhs.bitWidth, rhs.bitWidth)};
435+
canonicalizeCommutativeExtensionType(lhs, rhs);
436+
if (rhs.extType <= ExtType::LOGICAL)
437+
return {ExtType::LOGICAL, std::max(lhs.bitWidth, rhs.bitWidth)};
438+
// rhs guaranteed to be at least arithmetic from here on.
439+
440+
// Since rhs was sign-extended the result to continue extending with 1s in the
441+
// case rhs extends with 1s.
442+
// However, if lhs is zero-extended and the larger bitwidth, we need to add
443+
// one extra bit to the result such the result isn't sign-extended using the
444+
// sign-bit of lhs.
445+
// Example: lhs = zext(101) to i32, rhs = sext(01) to i32.
446+
// We must perform the OR using 3 bits at the very least. Performing it
447+
// with 3 bits would be wrong however, since sext(OR 101, sext(01) to i3)
448+
// would extend with 1s, merely due to the bitwidth reduction.
449+
// The extra bit prevents this behavior.
450+
if (lhs.extType == ExtType::LOGICAL && lhs.bitWidth > rhs.bitWidth)
451+
return {ExtType::ARITHMETIC, 1 + lhs.bitWidth};
452+
453+
return {ExtType::ARITHMETIC, std::max(lhs.bitWidth, rhs.bitWidth)};
428454
}
429455

430456
//===----------------------------------------------------------------------===//
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
2+
// RUN: dynamatic-opt --handshake-optimize-bitwidths --remove-operation-names %s --split-input-file | FileCheck %s
3+
4+
// CHECK-LABEL: handshake.func @test_and_sext_sext(
5+
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i1>,
6+
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel<i16>,
7+
// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> (!handshake.channel<i32>, !handshake.control<>) attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0", "end"]} {
8+
// CHECK: %[[VAL_3:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : <i1> to <i16>
9+
// CHECK: %[[VAL_4:.*]] = ori %[[VAL_3]], %[[VAL_1]] {handshake.bb = 0 : ui32} : <i16>
10+
// CHECK: %[[VAL_5:.*]] = extsi %[[VAL_4]] : <i16> to <i32>
11+
// CHECK: end {handshake.bb = 0 : ui32} %[[VAL_5]], %[[VAL_2]] : <i32>, <>
12+
// CHECK: }
13+
handshake.func @test_and_sext_sext(%arg0: !handshake.channel<i1>, %arg1: !handshake.channel<i16>, %arg2: !handshake.control<>, ...) -> (!handshake.channel<i32>, !handshake.control<>) attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0", "end"]} {
14+
%0 = extsi %arg0 {handshake.bb = 0 : ui32, handshake.name = "extsi0"} : <i1> to <i32>
15+
%1 = extsi %arg1 {handshake.bb = 0 : ui32, handshake.name = "extsi1"} : <i16> to <i32>
16+
%2 = ori %0, %1 {handshake.bb = 0 : ui32, handshake.name = "andi0"} : <i32>
17+
end {handshake.bb = 0 : ui32, handshake.name = "end0"} %2, %arg2 : <i32>, <>
18+
}
19+
20+
// -----
21+
22+
// CHECK-LABEL: handshake.func @test_and_sext_zext(
23+
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i1>,
24+
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel<i16>,
25+
// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> (!handshake.channel<i32>, !handshake.control<>) attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0", "end"]} {
26+
// CHECK: %[[VAL_3:.*]] = extui %[[VAL_1]] {handshake.bb = 0 : ui32} : <i16> to <i17>
27+
// CHECK: %[[VAL_4:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : <i1> to <i17>
28+
// CHECK: %[[VAL_5:.*]] = ori %[[VAL_4]], %[[VAL_3]] {handshake.bb = 0 : ui32} : <i17>
29+
// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : <i17> to <i32>
30+
// CHECK: end {handshake.bb = 0 : ui32} %[[VAL_6]], %[[VAL_2]] : <i32>, <>
31+
// CHECK: }
32+
handshake.func @test_and_sext_zext(%arg0: !handshake.channel<i1>, %arg1: !handshake.channel<i16>, %arg2: !handshake.control<>, ...) -> (!handshake.channel<i32>, !handshake.control<>) attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0", "end"]} {
33+
%0 = extsi %arg0 {handshake.bb = 0 : ui32, handshake.name = "extsi0"} : <i1> to <i32>
34+
%1 = extui %arg1 {handshake.bb = 0 : ui32, handshake.name = "extsi1"} : <i16> to <i32>
35+
%2 = ori %0, %1 {handshake.bb = 0 : ui32, handshake.name = "andi0"} : <i32>
36+
end {handshake.bb = 0 : ui32, handshake.name = "end0"} %2, %arg2 : <i32>, <>
37+
}
38+
39+
40+
// -----
41+
42+
// CHECK-LABEL: handshake.func @test_and_zext_sext(
43+
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i1>,
44+
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel<i16>,
45+
// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> (!handshake.channel<i32>, !handshake.control<>) attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0", "end"]} {
46+
// CHECK: %[[VAL_3:.*]] = extui %[[VAL_0]] {handshake.bb = 0 : ui32} : <i1> to <i16>
47+
// CHECK: %[[VAL_4:.*]] = ori %[[VAL_3]], %[[VAL_1]] {handshake.bb = 0 : ui32} : <i16>
48+
// CHECK: %[[VAL_5:.*]] = extsi %[[VAL_4]] : <i16> to <i32>
49+
// CHECK: end {handshake.bb = 0 : ui32} %[[VAL_5]], %[[VAL_2]] : <i32>, <>
50+
// CHECK: }
51+
handshake.func @test_and_zext_sext(%arg0: !handshake.channel<i1>, %arg1: !handshake.channel<i16>, %arg2: !handshake.control<>, ...) -> (!handshake.channel<i32>, !handshake.control<>) attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0", "end"]} {
52+
%0 = extui %arg0 {handshake.bb = 0 : ui32, handshake.name = "extsi0"} : <i1> to <i32>
53+
%1 = extsi %arg1 {handshake.bb = 0 : ui32, handshake.name = "extsi1"} : <i16> to <i32>
54+
%2 = ori %0, %1 {handshake.bb = 0 : ui32, handshake.name = "andi0"} : <i32>
55+
end {handshake.bb = 0 : ui32, handshake.name = "end0"} %2, %arg2 : <i32>, <>
56+
}

0 commit comments

Comments
 (0)