From c1aba18e181015dcb22b018ccb1d8e04bc9120fa Mon Sep 17 00:00:00 2001 From: Bradley Lewis Fargo Date: Mon, 13 Apr 2026 22:23:34 -0500 Subject: [PATCH 1/6] Nx.Testing: vectorized tensor support + diff diagnostics on failure `Nx.Testing.assert_equal/2` and `assert_all_close/3` now handle vectorized tensors cleanly and produce a numeric diagnostic when an assertion fails, so bit-level disagreements hidden by truncated `inspect` output are still diagnosable. Previously, when two tensors differed by a tiny amount (e.g. 1 ULP) that both inspect to the same string, the failure message showed the same tensor text twice with no way to see what actually differed. This was particularly confusing when comparing computed gradients against hand-computed reference values in the vectorized grad test suite. Changes: * `tensor_equal?/2` now explicitly guards vec-axes and shape mismatches before running the element-wise comparison. This means two tensors with different vectorized axes or shapes always return false (instead of raising or relying on broadcast semantics). * `assert_all_close/3` gains the same vec-axes/shape guards at the top so mismatched inputs produce a structural error message instead of an opaque `Nx.all_close` failure. * Both helpers now call a new `diagnose_difference/2` in the flunk block. The diagnostic: - Describes structural mismatches directly (vec axes differ, shapes differ). - For tensors that DO share structure, computes the max absolute and max relative difference across all elements, including vec axes (via devectorize-before-reduce). Output looks like: max absolute difference: 0.5 max relative difference: 0.2 * `assert_all_close/3` failure messages now also include the atol and rtol values that were being checked against. * The diagnostic is defensive: if the diff computation raises (mixed complex/real types, unusual promotions, etc.) it falls back silently so the baseline inspect output is still shown. 
New test coverage in `nx/test/nx/testing_test.exs`: * Both helpers pass on bit-identical tensors and within-tolerance tensors (plain and vectorized). * Both helpers' error messages include the max absolute difference on a genuine disagreement. * Both helpers' error messages include a clear diagnostic when vec axes differ. * `assert_all_close` error message includes max-diff for vectorized tensors with close-but-not-equal values (the 1-ULP case). * `assert_equal` handles NaN-equality correctly. Co-Authored-By: Paulo Valente <16843419+polvalente@users.noreply.github.com> Co-Authored-By: Claude Opus 4.6 (1M context) --- nx/lib/nx/testing.ex | 123 +++++++++++++++++++++++++++++++----- nx/test/nx/testing_test.exs | 115 +++++++++++++++++++++++++++++++++ 2 files changed, 222 insertions(+), 16 deletions(-) create mode 100644 nx/test/nx/testing_test.exs diff --git a/nx/lib/nx/testing.ex b/nx/lib/nx/testing.ex index cc4e7ebefa..7fe3e24a44 100644 --- a/nx/lib/nx/testing.ex +++ b/nx/lib/nx/testing.ex @@ -3,7 +3,10 @@ defmodule Nx.Testing do Testing functions for Nx tensor assertions. This module provides functions for asserting tensor equality and - approximate equality within specified tolerances. + approximate equality within specified tolerances. Both helpers handle + vectorized tensors and produce a numeric diagnostic (max absolute / + relative difference) on failure so that bit-level disagreements + hidden by truncated `inspect` output are still diagnosable. """ import ExUnit.Assertions @@ -13,6 +16,8 @@ defmodule Nx.Testing do Asserts that two tensors are exactly equal. This handles NaN values correctly by considering NaN == NaN as true. + Works with vectorized tensors — two tensors must share the same + vectorized axes to be considered equal. 
""" def assert_equal(left, right) when not is_tensor(left) or not is_tensor(right) do if not Nx.Defn.Composite.compatible?(left, right, &tensor_equal?/2) do @@ -28,26 +33,50 @@ defmodule Nx.Testing do if not tensor_equal?(left, right) do flunk(""" Tensor assertion failed. - left: #{inspect(left)} - right: #{inspect(right)} + + left: + + #{inspect(left)} + + right: + + #{inspect(right)} + + #{diagnose_difference(left, right)} """) end end defp tensor_equal?(left, right) do - both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) - - left - |> Nx.equal(right) - |> Nx.logical_or(both_nan) - |> Nx.all() - |> Nx.to_flat_list() - |> Enum.all?(&(&1 == 1)) + cond do + not is_tensor(left) or not is_tensor(right) -> + false + + left.vectorized_axes != right.vectorized_axes -> + false + + Nx.shape(left) != Nx.shape(right) -> + false + + true -> + both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) + + left + |> Nx.equal(right) + |> Nx.logical_or(both_nan) + |> Nx.all() + |> Nx.to_flat_list() + |> Enum.all?(&(&1 == 1)) + end end @doc """ Asserts that two tensors are approximately equal within the given tolerances. + Works with vectorized tensors — the comparison is per vectorized + instance, then aggregated. Two tensors must share the same vectorized + axes to be comparable. + See also: * `Nx.all_close/2` - The underlying function that performs the comparison. 
@@ -62,11 +91,13 @@ defmodule Nx.Testing do rtol = opts[:rtol] || 1.0e-4 equals = - left - |> Nx.all_close(right, atol: atol, rtol: rtol) - |> Nx.backend_transfer(Nx.BinaryBackend) - |> Nx.to_flat_list() - |> Enum.all?(&(&1 == 1)) + left.vectorized_axes == right.vectorized_axes and + Nx.shape(left) == Nx.shape(right) and + left + |> Nx.all_close(right, atol: atol, rtol: rtol) + |> Nx.backend_transfer(Nx.BinaryBackend) + |> Nx.to_flat_list() + |> Enum.all?(&(&1 == 1)) if !equals do flunk(""" @@ -77,7 +108,67 @@ defmodule Nx.Testing do to be within tolerance of #{inspect(right)} + + (atol: #{atol}, rtol: #{rtol}) + + #{diagnose_difference(left, right)} """) end end + + # Produces a human-readable diagnostic describing how two tensors differ. + # If vectorized axes or shapes don't line up, returns a structural message. + # Otherwise computes max absolute and max relative difference across all + # elements (including vec axes) so bit-level disagreements hidden by + # truncated `inspect` output are still visible in the failure message. + defp diagnose_difference(left, right) when is_tensor(left) and is_tensor(right) do + cond do + left.vectorized_axes != right.vectorized_axes -> + "vectorized_axes differ: left #{inspect(left.vectorized_axes)}, " <> + "right #{inspect(right.vectorized_axes)}" + + Nx.shape(left) != Nx.shape(right) -> + "shapes differ: left #{inspect(Nx.shape(left))}, " <> + "right #{inspect(Nx.shape(right))}" + + true -> + numeric_diagnostic(left, right) + end + end + + defp diagnose_difference(_, _), do: "" + + defp numeric_diagnostic(left, right) do + # Devectorize so reductions collapse across vec axes too, and so + # `Nx.to_number` on the final scalar doesn't hit a vectorized tensor. 
+ left = if left.vectorized_axes == [], do: left, else: Nx.devectorize(left, keep_names: false) + right = if right.vectorized_axes == [], do: right, else: Nx.devectorize(right, keep_names: false) + + # Promote to a common numeric type so subtraction works for int/float mixes. + {left_f, right_f} = + case {Nx.type(left), Nx.type(right)} do + {{:f, _}, {:f, _}} -> {left, right} + {{:c, _}, _} -> {left, Nx.as_type(right, Nx.type(left))} + {_, {:c, _}} -> {Nx.as_type(left, Nx.type(right)), right} + _ -> {Nx.as_type(left, {:f, 32}), Nx.as_type(right, {:f, 32})} + end + + diff = Nx.subtract(left_f, right_f) |> Nx.abs() + max_abs = diff |> Nx.reduce_max() |> Nx.to_number() + + # Relative diff: |a - b| / max(|a|, |b|, tiny) to avoid divide by zero. + denom = + Nx.max(Nx.abs(left_f), Nx.abs(right_f)) + |> Nx.max(Nx.tensor(1.0e-30)) + + max_rel = Nx.divide(diff, denom) |> Nx.reduce_max() |> Nx.to_number() + + "max absolute difference: #{inspect(max_abs)}\n" <> + "max relative difference: #{inspect(max_rel)}" + rescue + # If the diff computation itself fails (mixed complex/real, NaN propagation, + # unusual types, etc.), fall back silently — the inspect output above is + # still shown. 
+ _ -> "" + end end diff --git a/nx/test/nx/testing_test.exs b/nx/test/nx/testing_test.exs new file mode 100644 index 0000000000..309b632649 --- /dev/null +++ b/nx/test/nx/testing_test.exs @@ -0,0 +1,115 @@ +defmodule Nx.TestingTest do + use ExUnit.Case, async: true + + import Nx.Testing + + describe "assert_all_close/3" do + test "passes on bit-identical tensors" do + a = Nx.tensor([1.0, 2.0, 3.0]) + assert_all_close(a, a) + end + + test "passes on tensors within tolerance" do + a = Nx.tensor([1.0, 2.0, 3.0]) + b = Nx.tensor([1.00001, 2.00001, 3.00001]) + assert_all_close(a, b, atol: 1.0e-4) + end + + test "passes on bit-identical vectorized tensors" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + assert_all_close(a, b) + end + + test "passes on vectorized tensors within tolerance" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.00001, 2.00001]) |> Nx.vectorize(:foo) + assert_all_close(a, b, atol: 1.0e-4) + end + + test "error message includes max absolute difference" do + a = Nx.tensor([1.0, 2.0, 3.0]) + b = Nx.tensor([1.0, 2.5, 3.0]) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_all_close(a, b, atol: 1.0e-4) + end + + assert error.message =~ "max absolute difference" + assert error.message =~ "0.5" + end + + test "error message includes max-diff diagnostic for vectorized tensors" do + # Exercises the case where both tensors have matching structure but + # differ in values by more than the tolerance — the diagnostic should + # include the max absolute difference so failures are still readable + # even when the inspect output truncates to look identical. 
+ a = Nx.tensor([1.0, 2.0], type: :f32) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0001], type: :f32) |> Nx.vectorize(:foo) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_all_close(a, b, atol: 0.0, rtol: 0.0) + end + + assert error.message =~ "max absolute difference" + assert error.message =~ "max relative difference" + end + + test "error message shows a clear diagnostic for mismatched vec axes" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:bar) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_all_close(a, b) + end + + assert error.message =~ "vectorized_axes" + end + end + + describe "assert_equal/2" do + test "passes on bit-identical tensors" do + a = Nx.tensor([1, 2, 3]) + assert_equal(a, a) + end + + test "passes on bit-identical vectorized tensors" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + assert_equal(a, b) + end + + test "error message includes max absolute difference" do + a = Nx.tensor([1.0, 2.0, 3.0]) + b = Nx.tensor([1.0, 2.5, 3.0]) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_equal(a, b) + end + + assert error.message =~ "max absolute difference" + assert error.message =~ "0.5" + end + + test "error message shows a clear diagnostic for mismatched vec axes" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:bar) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_equal(a, b) + end + + assert error.message =~ "vectorized_axes" + end + + test "handles NaN equality correctly" do + a = Nx.tensor([:nan, 2.0, :nan]) + assert_equal(a, a) + end + end +end From c2716c1e7951c6dcc42016472bcd7d00088ba9da Mon Sep 17 00:00:00 2001 From: Bradley Lewis Fargo Date: Tue, 14 Apr 2026 01:38:31 -0500 Subject: [PATCH 2/6] Preserve scalar-broadcast semantics in assert_equal / assert_all_close MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit The initial commit on this branch introduced strict shape guards that rejected any shape mismatch between the two compared values. That broke two legitimate patterns in existing tests: 1. Raw scalars passed as expected values: `assert_equal(a, 11)` — the old `Nx.equal`-based implementation auto-converted the scalar and broadcast it. The new code accessed `.vectorized_axes` directly on the raw integer, crashing with `KeyError: key :vectorized_axes not found in: 11`. 2. Scalar-broadcast comparisons: `assert_equal(Nx.tensor([1, 1, 1]), 1)` — the caller means "every element of the tensor equals 1", not "the tensor has shape {}". Strict shape matching rejects this. Fix: * Introduce `to_tensor/1` helper that wraps raw scalars/lists in tensors, leaving existing tensors untouched. Applied at the top of `tensor_equal?/2`, `assert_all_close/3`, and `diagnose_difference/2`. * Introduce `shapes_incompatible?/2` helper that rejects genuine shape mismatches but allows either side to be a scalar (shape `{}`). This preserves the old `Nx.equal` broadcasting semantics for the common "assert every element equals N" pattern while still catching real shape bugs like the `{4,1}` vs `{4,2}` case that exposed a latent error in EXLA's sharding_test (whose `result0` output has half the columns the test expects — caught by this PR's stricter comparison, to be fixed in a separate commit on the EXLA side). * `diagnose_difference/2` is now a single clause guarded by a `rescue` fallback so any unusual promotion still falls back silently to the plain inspect output. New regression tests in `testing_test.exs`: * `assert_equal` accepts a raw scalar against a scalar tensor. * `assert_equal` accepts a raw scalar broadcast against a non-scalar tensor (and rejects when values don't match). * `assert_equal` rejects a genuine shape mismatch between two non-scalar tensors (the sharding_test case). * Parallel `assert_all_close` tests for each of the above. 
Verified locally: full nx suite passes (1353 doctests, 1246 tests, 0 failures, 1 skipped). --- nx/lib/nx/testing.ex | 51 ++++++++++++++++++++++++--------- nx/test/nx/testing_test.exs | 57 +++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 13 deletions(-) diff --git a/nx/lib/nx/testing.ex b/nx/lib/nx/testing.ex index 7fe3e24a44..5878eed0a8 100644 --- a/nx/lib/nx/testing.ex +++ b/nx/lib/nx/testing.ex @@ -48,14 +48,14 @@ defmodule Nx.Testing do end defp tensor_equal?(left, right) do - cond do - not is_tensor(left) or not is_tensor(right) -> - false + left = to_tensor(left) + right = to_tensor(right) + cond do left.vectorized_axes != right.vectorized_axes -> false - Nx.shape(left) != Nx.shape(right) -> + shapes_incompatible?(left, right) -> false true -> @@ -70,6 +70,23 @@ defmodule Nx.Testing do end end + # Wrap raw scalars/lists in tensors so the struct-field accesses + # (`.vectorized_axes`) and `Nx.shape/1` below don't crash. Tensors + # pass through unchanged. + defp to_tensor(%Nx.Tensor{} = t), do: t + defp to_tensor(other), do: Nx.tensor(other) + + # Genuine shape mismatches are rejected, but we still allow a scalar + # (shape `{}`) to compare against a tensor of any shape — that's the + # intentional "assert every element equals this scalar" pattern, and + # rejecting it would break a large number of existing tests that + # relied on `Nx.equal`'s broadcasting. + defp shapes_incompatible?(left, right) do + ls = Nx.shape(left) + rs = Nx.shape(right) + ls != rs and ls != {} and rs != {} + end + @doc """ Asserts that two tensors are approximately equal within the given tolerances. 
@@ -90,11 +107,14 @@ defmodule Nx.Testing do atol = opts[:atol] || 1.0e-4 rtol = opts[:rtol] || 1.0e-4 + left_t = to_tensor(left) + right_t = to_tensor(right) + equals = - left.vectorized_axes == right.vectorized_axes and - Nx.shape(left) == Nx.shape(right) and - left - |> Nx.all_close(right, atol: atol, rtol: rtol) + left_t.vectorized_axes == right_t.vectorized_axes and + not shapes_incompatible?(left_t, right_t) and + left_t + |> Nx.all_close(right_t, atol: atol, rtol: rtol) |> Nx.backend_transfer(Nx.BinaryBackend) |> Nx.to_flat_list() |> Enum.all?(&(&1 == 1)) @@ -121,28 +141,33 @@ defmodule Nx.Testing do # Otherwise computes max absolute and max relative difference across all # elements (including vec axes) so bit-level disagreements hidden by # truncated `inspect` output are still visible in the failure message. - defp diagnose_difference(left, right) when is_tensor(left) and is_tensor(right) do + defp diagnose_difference(left, right) do + left = to_tensor(left) + right = to_tensor(right) + cond do left.vectorized_axes != right.vectorized_axes -> "vectorized_axes differ: left #{inspect(left.vectorized_axes)}, " <> "right #{inspect(right.vectorized_axes)}" - Nx.shape(left) != Nx.shape(right) -> + shapes_incompatible?(left, right) -> "shapes differ: left #{inspect(Nx.shape(left))}, " <> "right #{inspect(Nx.shape(right))}" true -> numeric_diagnostic(left, right) end + rescue + _ -> "" end - defp diagnose_difference(_, _), do: "" - defp numeric_diagnostic(left, right) do # Devectorize so reductions collapse across vec axes too, and so # `Nx.to_number` on the final scalar doesn't hit a vectorized tensor. 
left = if left.vectorized_axes == [], do: left, else: Nx.devectorize(left, keep_names: false) - right = if right.vectorized_axes == [], do: right, else: Nx.devectorize(right, keep_names: false) + + right = + if right.vectorized_axes == [], do: right, else: Nx.devectorize(right, keep_names: false) # Promote to a common numeric type so subtraction works for int/float mixes. {left_f, right_f} = diff --git a/nx/test/nx/testing_test.exs b/nx/test/nx/testing_test.exs index 309b632649..c46d868dd5 100644 --- a/nx/test/nx/testing_test.exs +++ b/nx/test/nx/testing_test.exs @@ -111,5 +111,62 @@ defmodule Nx.TestingTest do a = Nx.tensor([:nan, 2.0, :nan]) assert_equal(a, a) end + + test "accepts a raw scalar against a scalar tensor" do + assert_equal(Nx.tensor(11), 11) + assert_equal(Nx.tensor(1.5), 1.5) + end + + test "accepts a raw scalar broadcast against a non-scalar tensor" do + # Matches the old Nx.equal broadcasting semantics that existing + # test suites (e.g. EXLA expr_test) rely on. `assert_equal(t, 1)` + # means "every element of t equals 1", not "t has shape {}". + assert_equal(Nx.tensor([1, 1, 1]), 1) + assert_equal(Nx.tensor([[1, 1], [1, 1]]), 1) + end + + test "rejects a scalar broadcast when values don't match" do + a = Nx.tensor([1, 1, 2]) + + assert_raise ExUnit.AssertionError, fn -> + assert_equal(a, 1) + end + end + + test "rejects a genuine shape mismatch between two non-scalar tensors" do + # The sharding test bug: result had shape {4,1} but expectation was + # {4,2}. Old code silently broadcast; we now reject. 
+ a = Nx.tensor([[100], [102], [104], [106]]) + b = Nx.tensor([[100, 100], [102, 102], [104, 104], [106, 106]]) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_equal(a, b) + end + + assert error.message =~ "shapes differ" + end + end + + describe "assert_all_close/3 with scalar arguments" do + test "accepts a raw scalar against a scalar tensor" do + assert_all_close(Nx.tensor(1.5), 1.5) + end + + test "accepts a raw scalar broadcast against a non-scalar tensor" do + assert_all_close(Nx.tensor([1.0, 1.0001, 0.9999]), 1.0, atol: 1.0e-3) + end + + test "rejects a genuine shape mismatch" do + a = Nx.tensor([[1.0], [2.0]]) + b = Nx.tensor([[1.0, 1.0], [2.0, 2.0]]) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_all_close(a, b) + end + + assert error.message =~ "shapes differ" + end end end From 78db34f9adee503a6f71f7558e9759535e6cdf52 Mon Sep 17 00:00:00 2001 From: Bradley Lewis Fargo Date: Tue, 21 Apr 2026 01:08:43 -0500 Subject: [PATCH 3/6] Fix sharding_test expected shapes surfaced by this PR's strict checks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The two 2D-mesh tests in exla/test/exla/defn/sharding_test.exs had wrong expected-value literals: `y * 2` assertions in "output sharding with tuple outputs" and every assertion in "generates correct MLIR with simple 2D mesh and sharding" expected {4,2} tensors, but the actual per-partition outputs are {4,1} (the inputs themselves are {4,1} per partition — the test comments document this correctly; only the assertion literals were wrong). The old `Nx.Testing.assert_equal` masked the mismatch via `Nx.equal`'s broadcasting: comparing a {4,1} result against a {4,2} expectation broadcast the result up to {4,2} and matched exactly. The stricter shape guard added in this PR correctly rejects those comparisons, which exposed the bugs. Bundling per maintainer request (#1734 discussion). 
--- exla/test/exla/defn/sharding_test.exs | 28 ++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index 28d01ae74d..11506f0c35 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -63,9 +63,9 @@ defmodule EXLA.Defn.ShardingTest do # y[[100],[101],[102],[103]] broadcasts to [[100,100],[101,101],[102,102],[103,103]] # x+y = [[100,100],[102,102],[104,104],[106,106]], device 0 gets col 0 assert_equal(result1_d0, Nx.tensor([[100], [102], [104], [106]])) - # Second output (y*2): y broadcasts to {4,2}, multiply by 2, result is {4,2} - # [[100],[101],[102],[103]] * 2 = [[200],[202],[204],[206]], broadcasts to [[200,200],[202,202],[204,204],[206,206]] - assert_equal(result2_d0, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]])) + # Second output (y*2): each partition has y={4,1}, so y*2 is {4,1} per partition. + # y_dev0=[[100],[101],[102],[103]] * 2 = [[200],[202],[204],[206]] + assert_equal(result2_d0, Nx.tensor([[200], [202], [204], [206]])) # Device 1: x={4,1} rows[0-3]col[1], y={4,1} rows[0-3] {result1_d1, result2_d1} = Enum.at(results, 1) @@ -73,16 +73,16 @@ defmodule EXLA.Defn.ShardingTest do # broadcasts to [[10,10],[11,11],[12,12],[13,13]] + [[100,100],[101,101],[102,102],[103,103]] # = [[110,110],[112,112],[114,114],[116,116]], device 1 gets col 1 assert_equal(result1_d1, Nx.tensor([[110], [112], [114], [116]])) - # Second output: same as device 0 (y is replicated across axis 1) - assert_equal(result2_d1, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]])) + # Second output: same as device 0 (y is replicated across axis 1, so d1 sees the same y_dev shard). 
+ assert_equal(result2_d1, Nx.tensor([[200], [202], [204], [206]])) # Device 2: x={4,1} rows[4-7]col[0], y={4,1} rows[4-7] {result1_d2, result2_d2} = Enum.at(results, 2) # First output: x[[4],[5],[6],[7]] + y[[104],[105],[106],[107]] # = [[108,108],[110,110],[112,112],[114,114]], device 2 gets col 0 assert_equal(result1_d2, Nx.tensor([[108], [110], [112], [114]])) - # Second output: [[104],[105],[106],[107]] * 2, broadcasts to [[208,208]...] - assert_equal(result2_d2, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]])) + # Second output: y_dev2=[[104],[105],[106],[107]] * 2 = [[208],[210],[212],[214]] + assert_equal(result2_d2, Nx.tensor([[208], [210], [212], [214]])) # Device 3: x={4,1} rows[4-7]col[1], y={4,1} rows[4-7] {result1_d3, result2_d3} = Enum.at(results, 3) @@ -90,8 +90,8 @@ defmodule EXLA.Defn.ShardingTest do # broadcasts to [[14,14],[15,15],[16,16],[17,17]] + [[104,104],[105,105],[106,106],[107,107]] # = [[118,118],[120,120],[122,122],[124,124]], device 3 gets col 1 assert_equal(result1_d3, Nx.tensor([[118], [120], [122], [124]])) - # Second output: same as device 2 (y is replicated across axis 1) - assert_equal(result2_d3, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]])) + # Second output: same as device 2 (y is replicated across axis 1, so d3 sees the same y_dev shard). + assert_equal(result2_d3, Nx.tensor([[208], [210], [212], [214]])) end @moduletag :multi_device @@ -176,13 +176,15 @@ defmodule EXLA.Defn.ShardingTest do assert [result0, result1, result2, result3] = EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) - assert_equal(result0, Nx.tensor([[100, 100], [102, 102], [104, 104], [106, 106]])) + # Each partition's output is {4,1}: x_shard + y_shard = {4,1} + {4,1} + # (y input is {8,1}, sharded on axis 0 only → {4,1} per partition). 
+ assert_equal(result0, Nx.tensor([[100], [102], [104], [106]])) assert result0.data.buffer.device_id == 0 - assert_equal(result1, Nx.tensor([[110, 110], [112, 112], [114, 114], [116, 116]])) + assert_equal(result1, Nx.tensor([[110], [112], [114], [116]])) assert result1.data.buffer.device_id == 1 - assert_equal(result2, Nx.tensor([[108, 108], [110, 110], [112, 112], [114, 114]])) + assert_equal(result2, Nx.tensor([[108], [110], [112], [114]])) assert result2.data.buffer.device_id == 2 - assert_equal(result3, Nx.tensor([[118, 118], [120, 120], [122, 122], [124, 124]])) + assert_equal(result3, Nx.tensor([[118], [120], [122], [124]])) assert result3.data.buffer.device_id == 3 end From 1c51a6b780c116cfa5861e20fb44f473b3f40f4a Mon Sep 17 00:00:00 2001 From: Bradley Lewis Fargo Date: Tue, 21 Apr 2026 01:17:02 -0500 Subject: [PATCH 4/6] Fix torchx conv tests exposed by strict assertions; apply mix format MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit test/torchx/nx_test.exs:959 (output_permutation) and :990 (input_dilation) had expected-value literals of lower rank than the actual conv result; the old assert_all_close / assert_equal broadcast the lower-rank literal up to match. Fix: wrap the literal in the missing leading [...] layers to match the shape already asserted explicitly via \`result.shape == ...\` on the same test. Also applies mix format to the sharding_test.exs changes from the previous commit — formatter wants blank lines between assert_equal calls and the subsequent comments. Elixir 1.18 formatter rules caught this; 1.17 local run didn't. 
--- exla/test/exla/defn/sharding_test.exs | 2 ++ torchx/test/torchx/nx_test.exs | 30 ++++++++++++++------------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index 11506f0c35..64f20797d6 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -73,6 +73,7 @@ defmodule EXLA.Defn.ShardingTest do # broadcasts to [[10,10],[11,11],[12,12],[13,13]] + [[100,100],[101,101],[102,102],[103,103]] # = [[110,110],[112,112],[114,114],[116,116]], device 1 gets col 1 assert_equal(result1_d1, Nx.tensor([[110], [112], [114], [116]])) + # Second output: same as device 0 (y is replicated across axis 1, so d1 sees the same y_dev shard). assert_equal(result2_d1, Nx.tensor([[200], [202], [204], [206]])) @@ -90,6 +91,7 @@ defmodule EXLA.Defn.ShardingTest do # broadcasts to [[14,14],[15,15],[16,16],[17,17]] + [[104,104],[105,105],[106,106],[107,107]] # = [[118,118],[120,120],[122,122],[124,124]], device 3 gets col 1 assert_equal(result1_d3, Nx.tensor([[118], [120], [122], [124]])) + # Second output: same as device 2 (y is replicated across axis 1, so d3 sees the same y_dev shard). 
assert_equal(result2_d3, Nx.tensor([[208], [210], [212], [214]])) end diff --git a/torchx/test/torchx/nx_test.exs b/torchx/test/torchx/nx_test.exs index 997dd8100e..2b551657d0 100644 --- a/torchx/test/torchx/nx_test.exs +++ b/torchx/test/torchx/nx_test.exs @@ -969,19 +969,21 @@ defmodule Torchx.NxTest do result, Nx.tensor([ [ - [15.0, 15.0], - [51.0, 51.0], - [87.0, 87.0] - ], - [ - [123.0, 123.0], - [159.0, 159.0], - [195.0, 195.0] - ], - [ - [231.0, 231.0], - [267.0, 267.0], - [303.0, 303.0] + [ + [15.0, 15.0], + [51.0, 51.0], + [87.0, 87.0] + ], + [ + [123.0, 123.0], + [159.0, 159.0], + [195.0, 195.0] + ], + [ + [231.0, 231.0], + [267.0, 267.0], + [303.0, 303.0] + ] ] ]) ) @@ -997,7 +999,7 @@ defmodule Torchx.NxTest do assert_equal( result, - Nx.tensor([[0, 0, -1, 1, 0, -2], [-3, 0, 4, -4, 0, 5]]) + Nx.tensor([[[[[0, 0, -1, 1, 0, -2], [-3, 0, 4, -4, 0, 5]]]]]) ) end end From 8c3686ea7615bd284d1a6458f96d1e4b2ac5f696 Mon Sep 17 00:00:00 2001 From: Bradley Lewis Fargo Date: Tue, 21 Apr 2026 01:19:01 -0500 Subject: [PATCH 5/6] Revert torchx conv test edits (move to separate PR for clean scope) The previous commit (1c51a6b7) included edits to torchx/test/torchx/nx_test.exs. Those fixes address the same class of latent shape bug as the sharding_test fix in this PR, but Pol's ask on #1734 was scoped specifically to the sharding cases. The torchx tests are in a different sub-project entirely, and the right default when expanding past a maintainer's explicit scope is to ask rather than assume. Moving those fixes to a separate PR. The mix-format changes to sharding_test.exs from 1c51a6b7 are retained (Elixir 1.18 formatter needs blank lines between assert_equal calls and trailing comments). 
--- torchx/test/torchx/nx_test.exs | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/torchx/test/torchx/nx_test.exs b/torchx/test/torchx/nx_test.exs index 2b551657d0..997dd8100e 100644 --- a/torchx/test/torchx/nx_test.exs +++ b/torchx/test/torchx/nx_test.exs @@ -969,21 +969,19 @@ defmodule Torchx.NxTest do result, Nx.tensor([ [ - [ - [15.0, 15.0], - [51.0, 51.0], - [87.0, 87.0] - ], - [ - [123.0, 123.0], - [159.0, 159.0], - [195.0, 195.0] - ], - [ - [231.0, 231.0], - [267.0, 267.0], - [303.0, 303.0] - ] + [15.0, 15.0], + [51.0, 51.0], + [87.0, 87.0] + ], + [ + [123.0, 123.0], + [159.0, 159.0], + [195.0, 195.0] + ], + [ + [231.0, 231.0], + [267.0, 267.0], + [303.0, 303.0] ] ]) ) @@ -999,7 +997,7 @@ defmodule Torchx.NxTest do assert_equal( result, - Nx.tensor([[[[[0, 0, -1, 1, 0, -2], [-3, 0, 4, -4, 0, 5]]]]]) + Nx.tensor([[0, 0, -1, 1, 0, -2], [-3, 0, 4, -4, 0, 5]]) ) end end From 163f7da0d248326c05ccf2e9328f868aa5bd8466 Mon Sep 17 00:00:00 2001 From: Bradley Lewis Fargo Date: Tue, 21 Apr 2026 01:31:39 -0500 Subject: [PATCH 6/6] Fix torchx conv tests exposed by strict assertions (re-applied) Restores the torchx/test/torchx/nx_test.exs edits that were reverted in 8c3686ea after a brief scope question. These are the same class of latent shape bug the sharding_test fix addresses: test author explicitly asserted \`result.shape == {1, 3, 3, 2}\` (and \`{1, 1, 1, 2, 6}\` for input_dilation) on one line, then passed a value literal of lower rank to assert_all_close / assert_equal on the next. The old broadcast-tolerant assertion hid the contradiction; the new strict shape check in this PR correctly rejects it. Fix: wrap each literal in the missing leading [...] brackets to match the shape the test is already explicitly asserting. 
--- torchx/test/torchx/nx_test.exs | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/torchx/test/torchx/nx_test.exs b/torchx/test/torchx/nx_test.exs index 997dd8100e..2b551657d0 100644 --- a/torchx/test/torchx/nx_test.exs +++ b/torchx/test/torchx/nx_test.exs @@ -969,19 +969,21 @@ defmodule Torchx.NxTest do result, Nx.tensor([ [ - [15.0, 15.0], - [51.0, 51.0], - [87.0, 87.0] - ], - [ - [123.0, 123.0], - [159.0, 159.0], - [195.0, 195.0] - ], - [ - [231.0, 231.0], - [267.0, 267.0], - [303.0, 303.0] + [ + [15.0, 15.0], + [51.0, 51.0], + [87.0, 87.0] + ], + [ + [123.0, 123.0], + [159.0, 159.0], + [195.0, 195.0] + ], + [ + [231.0, 231.0], + [267.0, 267.0], + [303.0, 303.0] + ] ] ]) ) @@ -997,7 +999,7 @@ defmodule Torchx.NxTest do assert_equal( result, - Nx.tensor([[0, 0, -1, 1, 0, -2], [-3, 0, 4, -4, 0, 5]]) + Nx.tensor([[[[[0, 0, -1, 1, 0, -2], [-3, 0, 4, -4, 0, 5]]]]]) ) end end