Skip to content

Commit b043a4e

Browse files
authored
Set batch size 1 for binary AF matrix tests (#4604)
Fixes #4603
1 parent 0ea9bce commit b043a4e

6 files changed

Lines changed: 22 additions & 12 deletions

File tree

src/webgpu/shader/execution/expression/binary/af_matrix_addition.spec.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { Type } from '../../../../util/conversion.js';
88
import { onlyConstInputSource, run } from '../expression.js';
99

1010
import { d } from './af_matrix_addition.cache.js';
11-
import { abstractFloatBinary } from './binary.js';
11+
import { abstractFloatBinary, kAbstractFloatMatrixBinaryOpBatchSize } from './binary.js';
1212

1313
export const g = makeTestGroup(AllFeaturesMaxLimitsGPUTest);
1414

@@ -36,6 +36,7 @@ Accuracy: Correctly rounded
3636
[Type.mat(cols, rows, Type.abstractFloat), Type.mat(cols, rows, Type.abstractFloat)],
3737
Type.mat(cols, rows, Type.abstractFloat),
3838
t.params,
39-
cases
39+
cases,
40+
kAbstractFloatMatrixBinaryOpBatchSize
4041
);
4142
});

src/webgpu/shader/execution/expression/binary/af_matrix_matrix_multiplication.spec.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { Type } from '../../../../util/conversion.js';
88
import { onlyConstInputSource, run } from '../expression.js';
99

1010
import { d } from './af_matrix_matrix_multiplication.cache.js';
11-
import { abstractFloatBinary } from './binary.js';
11+
import { abstractFloatBinary, kAbstractFloatMatrixBinaryOpBatchSize } from './binary.js';
1212

1313
export const g = makeTestGroup(AllFeaturesMaxLimitsGPUTest);
1414

@@ -40,6 +40,7 @@ Accuracy: Correctly rounded
4040
[Type.mat(x_cols, x_rows, Type.abstractFloat), Type.mat(y_cols, y_rows, Type.abstractFloat)],
4141
Type.mat(y_cols, x_rows, Type.abstractFloat),
4242
t.params,
43-
cases
43+
cases,
44+
kAbstractFloatMatrixBinaryOpBatchSize
4445
);
4546
});

src/webgpu/shader/execution/expression/binary/af_matrix_scalar_multiplication.spec.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { Type } from '../../../../util/conversion.js';
88
import { onlyConstInputSource, run } from '../expression.js';
99

1010
import { d } from './af_matrix_scalar_multiplication.cache.js';
11-
import { abstractFloatBinary } from './binary.js';
11+
import { abstractFloatBinary, kAbstractFloatMatrixBinaryOpBatchSize } from './binary.js';
1212

1313
export const g = makeTestGroup(AllFeaturesMaxLimitsGPUTest);
1414

@@ -36,7 +36,8 @@ Accuracy: Correctly rounded
3636
[Type.mat(cols, rows, Type.abstractFloat), Type.abstractFloat],
3737
Type.mat(cols, rows, Type.abstractFloat),
3838
t.params,
39-
cases
39+
cases,
40+
kAbstractFloatMatrixBinaryOpBatchSize
4041
);
4142
});
4243

@@ -64,6 +65,7 @@ Accuracy: Correctly rounded
6465
[Type.abstractFloat, Type.mat(cols, rows, Type.abstractFloat)],
6566
Type.mat(cols, rows, Type.abstractFloat),
6667
t.params,
67-
cases
68+
cases,
69+
kAbstractFloatMatrixBinaryOpBatchSize
6870
);
6971
});

src/webgpu/shader/execution/expression/binary/af_matrix_subtraction.spec.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { Type } from '../../../../util/conversion.js';
88
import { onlyConstInputSource, run } from '../expression.js';
99

1010
import { d } from './af_matrix_subtraction.cache.js';
11-
import { abstractFloatBinary } from './binary.js';
11+
import { abstractFloatBinary, kAbstractFloatMatrixBinaryOpBatchSize } from './binary.js';
1212

1313
export const g = makeTestGroup(AllFeaturesMaxLimitsGPUTest);
1414

@@ -36,6 +36,7 @@ Accuracy: Correctly rounded
3636
[Type.mat(cols, rows, Type.abstractFloat), Type.mat(cols, rows, Type.abstractFloat)],
3737
Type.mat(cols, rows, Type.abstractFloat),
3838
t.params,
39-
cases
39+
cases,
40+
kAbstractFloatMatrixBinaryOpBatchSize
4041
);
4142
});

src/webgpu/shader/execution/expression/binary/af_matrix_vector_multiplication.spec.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { Type } from '../../../../util/conversion.js';
88
import { onlyConstInputSource, run } from '../expression.js';
99

1010
import { d } from './af_matrix_vector_multiplication.cache.js';
11-
import { abstractFloatBinary } from './binary.js';
11+
import { abstractFloatBinary, kAbstractFloatMatrixBinaryOpBatchSize } from './binary.js';
1212

1313
export const g = makeTestGroup(AllFeaturesMaxLimitsGPUTest);
1414

@@ -36,7 +36,8 @@ Accuracy: Correctly rounded
3636
[Type.mat(cols, rows, Type.abstractFloat), Type.vec(cols, Type.abstractFloat)],
3737
Type.vec(rows, Type.abstractFloat),
3838
t.params,
39-
cases
39+
cases,
40+
kAbstractFloatMatrixBinaryOpBatchSize
4041
);
4142
});
4243

@@ -64,6 +65,7 @@ Accuracy: Correctly rounded
6465
[Type.vec(rows, Type.abstractFloat), Type.mat(cols, rows, Type.abstractFloat)],
6566
Type.vec(cols, Type.abstractFloat),
6667
t.params,
67-
cases
68+
cases,
69+
kAbstractFloatMatrixBinaryOpBatchSize
6870
);
6971
});

src/webgpu/shader/execution/expression/binary/binary.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ export function abstractFloatBinary(op: string): ShaderBuilder {
2525
export function abstractIntBinary(op: string): ShaderBuilder {
2626
return abstractIntShaderBuilder(values => `(${values.map(v => `(${v})`).join(op)})`);
2727
}
28+
29+
// See issue #4603 for why using 1 instead of the default
30+
export const kAbstractFloatMatrixBinaryOpBatchSize = 1;

0 commit comments

Comments
 (0)