diff --git a/packages/typegpu/src/data/dataTypes.ts b/packages/typegpu/src/data/dataTypes.ts index 66a8f4c19f..77537051d2 100644 --- a/packages/typegpu/src/data/dataTypes.ts +++ b/packages/typegpu/src/data/dataTypes.ts @@ -30,7 +30,6 @@ import type { Snippet } from './snippet.ts'; import type { PackedData } from './vertexFormatData.ts'; import * as wgsl from './wgslTypes.ts'; import type { WgslComparisonSampler, WgslSampler } from './sampler.ts'; -import type { ResolutionCtx } from '../types.ts'; import type { BaseData } from './wgslTypes.ts'; /** @@ -237,22 +236,6 @@ export type AnyConcreteData = Exclude< export const UnknownData = Symbol('UNKNOWN'); export type UnknownData = typeof UnknownData; -export class InfixDispatch { - readonly name: string; - readonly lhs: Snippet; - readonly operator: (ctx: ResolutionCtx, args: [lhs: Snippet, rhs: Snippet]) => Snippet; - - constructor( - name: string, - lhs: Snippet, - operator: (ctx: ResolutionCtx, args: [lhs: Snippet, rhs: Snippet]) => Snippet, - ) { - this.name = name; - this.lhs = lhs; - this.operator = operator; - } -} - export class MatrixColumnsAccess { readonly matrix: Snippet; diff --git a/packages/typegpu/src/data/index.ts b/packages/typegpu/src/data/index.ts index 5eebe77415..13aa6e1f5c 100644 --- a/packages/typegpu/src/data/index.ts +++ b/packages/typegpu/src/data/index.ts @@ -5,48 +5,39 @@ // NOTE: This is a barrel file, internal files should not import things from this file import { Operator } from 'tsover-runtime'; -import { type InfixOperator, infixOperators } from '../tgsl/accessProp.ts'; +import { type InfixOperatorName, infixOperators } from '../tgsl/accessProp.ts'; import { MatBase } from './matrix.ts'; import { VecBase } from './vectorImpl.ts'; +import { infixDispatch } from '../tgsl/infixDispatch.ts'; function assignInfixOperator( - object: T, - operator: InfixOperator, + base: T, + operator: InfixOperatorName, operatorSymbol: symbol, ) { - // oxlint-disable-next-line typescript/no-explicit-any -- anything is possible - const proto = object.prototype as any; - const opImpl = infixOperators[operator] as (lhs: unknown, rhs: unknown) => unknown; + const opImpl = infixOperators[operator]; - proto[operator] = function (this: unknown, other: unknown): unknown { - return opImpl(this, other); - }; + Object.defineProperty(base.prototype, operatorSymbol, { + value: opImpl, + }); - proto[operatorSymbol] = (lhs: unknown, rhs: unknown): unknown => { - return opImpl(lhs, rhs); - }; + Object.defineProperty(base.prototype, operator, { + get() { + return infixDispatch(this, opImpl); + }, + }); } assignInfixOperator(VecBase, 'add', Operator.plus); +assignInfixOperator(MatBase, 'add', Operator.plus); assignInfixOperator(VecBase, 'sub', Operator.minus); +assignInfixOperator(MatBase, 'sub', Operator.minus); assignInfixOperator(VecBase, 'mul', Operator.star); +assignInfixOperator(MatBase, 'mul', Operator.star); assignInfixOperator(VecBase, 'div', Operator.slash); assignInfixOperator(VecBase, 'mod', Operator.percent); -assignInfixOperator(MatBase, 'add', Operator.plus); -assignInfixOperator(MatBase, 'sub', Operator.minus); -assignInfixOperator(MatBase, 'mul', Operator.star); - -// bitShift does not yet have tsover operator symbol -{ - // oxlint-disable-next-line typescript/no-explicit-any -- anything is possible - const proto = VecBase.prototype as any; - proto.bitShiftLeft = function (this: unknown, other: unknown) { - return (infixOperators.bitShiftLeft as (a: unknown, b: unknown) => unknown)(this, other); - }; - proto.bitShiftRight = function (this: unknown, other: unknown) { - return (infixOperators.bitShiftRight as (a: unknown, b: unknown) => unknown)(this, other); - }; -} +assignInfixOperator(VecBase, 'bitShiftLeft', Symbol()); // bitShift does not yet have tsover operator symbol +assignInfixOperator(VecBase, 'bitShiftRight', Symbol()); // bitShift does not yet have tsover operator symbol export { bool, f16, f32, i32, u16, u32 } from './numeric.ts'; export { diff --git a/packages/typegpu/src/tgsl/accessProp.ts b/packages/typegpu/src/tgsl/accessProp.ts index 3410d3e383..8c88550c63 100644 --- a/packages/typegpu/src/tgsl/accessProp.ts +++ b/packages/typegpu/src/tgsl/accessProp.ts @@ -1,13 +1,7 @@ import { stitch } from '../core/resolve/stitch.ts'; import { AutoStruct } from '../data/autoStruct.ts'; import { EntryInputRouter } from '../core/function/entryInputRouter.ts'; -import { - InfixDispatch, - isUnstruct, - MatrixColumnsAccess, - undecorate, - UnknownData, -} from '../data/dataTypes.ts'; +import { isUnstruct, MatrixColumnsAccess, undecorate, UnknownData } from '../data/dataTypes.ts'; import { abstractInt, bool, f16, f32, i32, u32 } from '../data/numeric.ts'; import { derefSnippet } from '../data/ref.ts'; import { isEphemeralSnippet, isSnippet, snip, type Snippet } from '../data/snippet.ts'; @@ -37,10 +31,10 @@ import { isWgslArray, isWgslStruct, } from '../data/wgslTypes.ts'; -import { $gpuCallable } from '../shared/symbols.ts'; import { add, bitShiftLeft, bitShiftRight, div, mod, mul, sub } from '../std/operators.ts'; import { isKnownAtComptime } from '../types.ts'; import { coerceToSnippet } from './generationHelpers.ts'; +import { infixDispatch } from './infixDispatch.ts'; const infixKinds = [ 'vec2f', @@ -70,7 +64,8 @@ export const infixOperators = { bitShiftRight, } as const; -export type InfixOperator = keyof typeof infixOperators; +export type InfixOperatorName = keyof typeof infixOperators; +export type InfixOperator = (typeof infixOperators)[InfixOperatorName]; type SwizzleableType = 'f' | 'h' | 'i' | 'u' | 'b'; type SwizzleLength = 1 | 2 | 3 | 4; @@ -110,12 +105,8 @@ const swizzleLenToType: Record> export function accessProp(target: Snippet, propName: string): Snippet | undefined { if (infixKinds.includes((target.dataType as BaseData).type) && propName in infixOperators) { - const operator = infixOperators[propName as InfixOperator]; - return snip( - new InfixDispatch(propName, target, operator[$gpuCallable].call.bind(operator)), - UnknownData, - /* origin */ target.origin, - ); + const operator = infixOperators[propName as InfixOperatorName]; + return snip(infixDispatch(target, operator), UnknownData, /* origin */ target.origin); } if (isWgslArray(target.dataType) && propName === 'length') { @@ -191,15 +182,12 @@ export function accessProp(target: Snippet, propName: string): Snippet | undefin } const propLength = propName.length; - if (isVec(target.dataType) && propLength >= 1 && propLength <= 4) { - const isXYZW = /^[xyzw]+$/.test(propName); - const isRGBA = /^[rgba]+$/.test(propName); - - if (!isXYZW && !isRGBA) { - // Not a valid swizzle - return undefined; - } - + if ( + isVec(target.dataType) && + propLength >= 1 && + propLength <= 4 && + /^[xyzw]+$|^[rgba]+$/.test(propName) + ) { const swizzleTypeChar = target.dataType.type.includes('bool') ? 'b' : (target.dataType.type[4] as SwizzleableType); diff --git a/packages/typegpu/src/tgsl/infixDispatch.ts b/packages/typegpu/src/tgsl/infixDispatch.ts new file mode 100644 index 0000000000..ee846f23ee --- /dev/null +++ b/packages/typegpu/src/tgsl/infixDispatch.ts @@ -0,0 +1,52 @@ +import { isSnippet, type Snippet } from '../data/snippet.ts'; +import type { AnyMatInstance, AnyNumericVecInstance } from '../data/wgslTypes.ts'; +import { $internal, isMarkedInternal } from '../shared/symbols.ts'; +import type { InfixOperator } from './accessProp.ts'; + +type Numeric = number | AnyNumericVecInstance | AnyMatInstance; + +/** + * In wgslGenerator, the lhs may either be Numeric or Snippet, + * and InfixDispatch is recognized by the $internal symbol. + * InfixDispatch is not called in wgslGenerator. + * @example + * const dispatch = d.vec2u(1).mul; + * const fn = () => { + * 'use gpu'; + * dispatch(2); // lhs is Numeric + * d.vec2u(1).mul(2) // lhs is a snippet + * } + * + * In JS, the lhs is always numeric, and InfixDispatch is callable. + * @example + * const dispatch = d.vec2u(1).mul; + * dispatch(2); + */ +export interface InfixDispatch { + [$internal]: true; + type: 'infix-dispatch'; + readonly lhs: Snippet | Numeric; + readonly operator: InfixOperator; + (other: Numeric): Numeric; +} + +export function infixDispatch(lhs: Snippet | Numeric, operator: InfixOperator): InfixDispatch { + const callable = (other: Numeric | Snippet) => { + if (isSnippet(lhs)) { + throw new Error('Unexpected snippet lhs in JS infix operator.'); + } + // operator will perform all necessary type checks + return operator(lhs as never, other as never); + }; + const infix = Object.assign(callable, { + [$internal]: true as const, + type: 'infix-dispatch' as const, + lhs, + operator, + }); + return infix; +} + +export function isInfixDispatch(o: unknown): o is InfixDispatch { + return isMarkedInternal(o) && (o as InfixDispatch)?.type === 'infix-dispatch'; +} diff --git a/packages/typegpu/src/tgsl/wgslGenerator.ts b/packages/typegpu/src/tgsl/wgslGenerator.ts index 83cc68d0fe..bdd6fdab59 100644 --- a/packages/typegpu/src/tgsl/wgslGenerator.ts +++ b/packages/typegpu/src/tgsl/wgslGenerator.ts @@ -1,7 +1,7 @@ import * as tinyest from 'tinyest'; import { stitch } from '../core/resolve/stitch.ts'; import { arrayOf } from '../data/array.ts'; -import { type AnyData, InfixDispatch, isLooseData, UnknownData, unptr } from '../data/dataTypes.ts'; +import { type AnyData, isLooseData, UnknownData, unptr } from '../data/dataTypes.ts'; import { bool, i32, u32 } from '../data/numeric.ts'; import { vec2u, vec3u, vec4u } from '../data/vector.ts'; import { @@ -46,6 +46,7 @@ import * as forOfUtils from './forOfUtils.ts'; import { isTgpuRange } from '../std/range.ts'; import type { FunctionDefinitionOptions } from './shaderGenerator_members.ts'; import { getAttributesString } from '../data/attributes.ts'; +import { isInfixDispatch } from './infixDispatch.ts'; const { NodeTypeCatalog: NODE } = tinyest; @@ -645,15 +646,16 @@ ${this.ctx.pre}}`; ); } - if (callee.value instanceof InfixDispatch) { - // Infix operator dispatch. + if (isInfixDispatch(callee.value)) { if (!argNodes[0]) { throw new WgslTypeError( - `An infix operator '${callee.value.name}' was called without any arguments`, + `An infix operator '${getName(callee.value.operator)}' was called without any arguments`, ); } + const lhs = coerceToSnippet(callee.value.lhs); const rhs = this._expression(argNodes[0]); - return callee.value.operator(this.ctx, [callee.value.lhs, rhs]); + const callable = callee.value.operator[$gpuCallable]; + return callable.call(this.ctx, [lhs, rhs]); } if (isGPUCallable(callee.value)) { diff --git a/packages/typegpu/tests/swizzleMixedValidation.test.ts b/packages/typegpu/tests/swizzleMixedValidation.test.ts index 82e5069db0..599cde085f 100644 --- a/packages/typegpu/tests/swizzleMixedValidation.test.ts +++ b/packages/typegpu/tests/swizzleMixedValidation.test.ts @@ -71,7 +71,7 @@ describe('Mixed swizzle validation', () => { return mixed; }; - // The resolution should fail because accessProp returns undefined for mixed swizzles + // The resolution should fail because accessProp won't match any prop and will return undefined expect(() => tgpu.resolve([main])).toThrowErrorMatchingInlineSnapshot(` [Error: Resolution of the following tree failed: - diff --git a/packages/typegpu/tests/tgsl/infixOperators.test.ts b/packages/typegpu/tests/tgsl/infixOperators.test.ts index 2e18ed3331..e337024711 100644 --- a/packages/typegpu/tests/tgsl/infixOperators.test.ts +++ b/packages/typegpu/tests/tgsl/infixOperators.test.ts @@ -3,7 +3,7 @@ import tgpu, { d } from '../../src/index.js'; import { it } from 'typegpu-testing-utility'; describe('wgslGenerator', () => { - it('resolves add infix operator', () => { + it('resolves add infix operator in comptime', () => { const testFn = tgpu.fn([])(() => { const v1 = d.vec4f().add(1); const v2 = d.vec3f(2).add(d.vec3f(1, 2, 3)); @@ -23,7 +23,7 @@ describe('wgslGenerator', () => { `); }); - it('resolves sub infix operator', () => { + it('resolves sub infix operator in comptime', () => { const testFn = tgpu.fn([])(() => { const v1 = d.vec4f().sub(1); const v2 = d.vec3f().sub(d.vec3f(1, 2, 3)); @@ -43,7 +43,7 @@ describe('wgslGenerator', () => { `); }); - it('resolves mul infix operator', () => { + it('resolves mul infix operator in comptime', () => { const testFn = tgpu.fn([])(() => { const v1 = d.vec2f(2).mul(3); const v2 = d.vec3f(2).mul(d.vec3f(2, 3, 4)); @@ -70,27 +70,71 @@ describe('wgslGenerator', () => { `); }); + it('resolves div infix operator in comptime', () => { + const testFn = tgpu.fn([])(() => { + const v1 = d.vec4f(1).div(2); + const v2 = d.vec3f(6).div(d.vec3f(1, 2, 3)); + const v3 = d.vec2f(1).div(d.vec2f(2)).div(2); + }); + + expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` + "fn testFn() { + var v1 = vec4f(0.5); + var v2 = vec3f(6, 3, 2); + var v3 = vec2f(0.25); + }" + `); + }); + + it('resolves mod infix operator in comptime', () => { + const testFn = tgpu.fn([])(() => { + const v1 = d.vec4f(11).mod(2); + const v2 = d.vec3f(13.5).mod(d.vec3f(1, 2, 10)); + }); + + expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` + "fn testFn() { + var v1 = vec4f(1); + var v2 = vec3f(0.5, 1.5, 3.5); + }" + `); + }); + + it('resolves mul infix operator on a runtime variable', () => { + const testFn = () => { + 'use gpu'; + const v1 = d.vec2f(1, 2); + return v1.mul(2).mul(3); + }; + + expect(testFn()).toStrictEqual(d.vec2f(6, 12)); + expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` + "fn testFn() -> vec2f { + var v1 = vec2f(1, 2); + return ((v1 * 2f) * 3f); + }" + `); + }); + it('resolves mul infix operator on a function return value', () => { - const getVec = tgpu.fn( - [], - d.vec3f, - )(() => { + const getVec = () => { 'use gpu'; return d.vec3f(1, 2, 3); - }); + }; - const testFn = tgpu.fn([])(() => { + const testFn = () => { 'use gpu'; - const v1 = getVec().mul(getVec()); - }); + return getVec().mul(getVec()).mul(2); + }; + expect(testFn()).toStrictEqual(d.vec3f(2, 8, 18)); expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` "fn getVec() -> vec3f { return vec3f(1, 2, 3); } - fn testFn() { - var v1 = (getVec() * getVec()); + fn testFn() -> vec3f { + return ((getVec() * getVec()) * 2f); }" `); }); @@ -98,87 +142,119 @@ describe('wgslGenerator', () => { it('resolves mul infix operator on a struct property', () => { const Struct = d.struct({ vec: d.vec3f }); - const testFn = tgpu.fn([])(() => { + const testFn = () => { 'use gpu'; - const s = Struct({ vec: d.vec3f() }); - const v1 = s.vec.mul(s.vec); - }); + const s = Struct({ vec: d.vec3f(2) }); + return s.vec.mul(s.vec).mul(2); + }; + expect(testFn()).toStrictEqual(d.vec3f(8)); expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` "struct Struct { vec: vec3f, } - fn testFn() { - var s = Struct(vec3f()); - var v1 = (s.vec * s.vec); + fn testFn() -> vec3f { + var s = Struct(vec3f(2)); + return ((s.vec * s.vec) * 2f); }" `); }); - it('resolves div infix operator', () => { + it('resolves mul infix operator on uniform vector', ({ root }) => { + const fooUniform = root.createUniform(d.vec3f); + const testFn = tgpu.fn([])(() => { - const v1 = d.vec4f(1).div(2); - const v2 = d.vec3f(6).div(d.vec3f(1, 2, 3)); - const v3 = d.vec2f(1).div(d.vec2f(2)).div(2); + const v1 = fooUniform.$.mul(2); // lhs + const v2 = d.vec3f(1, 2, 3).mul(fooUniform.$); // rhs + const v3 = fooUniform.$.mul(fooUniform.$).mul(2); }); expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` - "fn testFn() { - var v1 = vec4f(0.5); - var v2 = vec3f(6, 3, 2); - var v3 = vec2f(0.25); + "@group(0) @binding(0) var fooUniform: vec3f; + + fn testFn() { + var v1 = (fooUniform * 2f); + var v2 = (vec3f(1, 2, 3) * fooUniform); + var v3 = ((fooUniform * fooUniform) * 2f); }" `); }); - it('resolves mod infix operator', () => { - const testFn = tgpu.fn([])(() => { - const v1 = d.vec4f(11).mod(2); - const v2 = d.vec3f(13.5).mod(d.vec3f(1, 2, 10)); - }); + it('resolves mul infix operator on external', () => { + const v = d.vec3f(2); + const testFn = () => { + 'use gpu'; + const v1 = v.mul(2); // lhs + const v2 = d.vec3f(3).mul(v); // rhs + const v3 = v.mul(v).mul(4); + return v1.add(v2).add(v3); + }; + + expect(testFn()).toStrictEqual(d.vec3f(26)); expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` - "fn testFn() { - var v1 = vec4f(1); - var v2 = vec3f(0.5, 1.5, 3.5); + "fn testFn() -> vec3f { + var v1 = vec3f(4); + var v2 = vec3f(6); + var v3 = vec3f(16); + return ((v1 + v2) + v3); }" `); }); - it('resolves add infix operator on uniform vector', ({ root }) => { - const fooUniform = root.createUniform(d.vec3f); - const barUniform = root.createUniform(d.vec3f); + it('resolves mul infix operator on accessors', () => { + const vAccess = tgpu.accessor(d.vec2i, d.vec2i(1, 2)); - const testFn = tgpu.fn([])(() => { - const v1 = fooUniform.$.add(2); // lhs - const v2 = d.vec3f(1, 2, 3).add(barUniform.$); // rhs - const v3 = fooUniform.$.add(barUniform.$); - }); + const main = () => { + 'use gpu'; + return vAccess.$.mul(2).mul(3); + }; - expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` - "@group(0) @binding(0) var fooUniform: vec3f; + // expect(main()).toMatchInlineSnapshot(d.vec2i(6, 12)); + expect(tgpu.resolve([main])).toMatchInlineSnapshot(` + "fn main() -> vec2i { + return vec2i(6, 12); + }" + `); + }); - @group(0) @binding(1) var barUniform: vec3f; + it('correctly casts types', () => { + const main = () => { + 'use gpu'; + const a = d.u32(1); + const b = d.vec3f(2); + return b.mul(a); + }; - fn testFn() { - var v1 = (fooUniform + 2f); - var v2 = (vec3f(1, 2, 3) + barUniform); - var v3 = (fooUniform + barUniform); + expect(tgpu.resolve([main])).toMatchInlineSnapshot(` + "fn main() -> vec3f { + const a = 1u; + var b = vec3f(2); + return (b * f32(a)); }" `); }); - it('precomputes adds on known values', () => { - const testFn = tgpu.fn([])(() => { - const v1 = d.vec3f(1, 2, 3).add(5); - const v2 = d.vec3f(1, 2, 3).add(d.vec3f(3, 2, 1)); - }); + it('works when application is deferred', ({ root }) => { + const v = d.vec2f(1, 2).mul; + const u = tgpu.comptime(() => d.vec2f(3, 4).mul); + const buf = tgpu.comptime(() => root.createUniform(d.vec2f, [5, 6]).$.add); - expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` - "fn testFn() { - var v1 = vec3f(6, 7, 8); - var v2 = vec3f(4); + const fn = () => { + 'use gpu'; + const a = v(7); + const b = u()(8); + const c = buf()(9); + }; + + expect(tgpu.resolve([fn])).toMatchInlineSnapshot(` + "@group(0) @binding(0) var item: vec2f; + + fn fn_1() { + var a = vec2f(7, 14); + var b = vec2f(24, 32); + var c = (item + 9f); }" `); });