diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index 28d01ae74d..64f20797d6 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -63,9 +63,9 @@ defmodule EXLA.Defn.ShardingTest do # y[[100],[101],[102],[103]] broadcasts to [[100,100],[101,101],[102,102],[103,103]] # x+y = [[100,100],[102,102],[104,104],[106,106]], device 0 gets col 0 assert_equal(result1_d0, Nx.tensor([[100], [102], [104], [106]])) - # Second output (y*2): y broadcasts to {4,2}, multiply by 2, result is {4,2} - # [[100],[101],[102],[103]] * 2 = [[200],[202],[204],[206]], broadcasts to [[200,200],[202,202],[204,204],[206,206]] - assert_equal(result2_d0, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]])) + # Second output (y*2): each partition has y={4,1}, so y*2 is {4,1} per partition. + # y_dev0=[[100],[101],[102],[103]] * 2 = [[200],[202],[204],[206]] + assert_equal(result2_d0, Nx.tensor([[200], [202], [204], [206]])) # Device 1: x={4,1} rows[0-3]col[1], y={4,1} rows[0-3] {result1_d1, result2_d1} = Enum.at(results, 1) @@ -73,16 +73,17 @@ defmodule EXLA.Defn.ShardingTest do # broadcasts to [[10,10],[11,11],[12,12],[13,13]] + [[100,100],[101,101],[102,102],[103,103]] # = [[110,110],[112,112],[114,114],[116,116]], device 1 gets col 1 assert_equal(result1_d1, Nx.tensor([[110], [112], [114], [116]])) - # Second output: same as device 0 (y is replicated across axis 1) - assert_equal(result2_d1, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]])) + + # Second output: same as device 0 (y is replicated across axis 1, so d1 sees the same y_dev shard). + assert_equal(result2_d1, Nx.tensor([[200], [202], [204], [206]])) # Device 2: x={4,1} rows[4-7]col[0], y={4,1} rows[4-7] {result1_d2, result2_d2} = Enum.at(results, 2) # First output: x[[4],[5],[6],[7]] + y[[104],[105],[106],[107]] # = [[108,108],[110,110],[112,112],[114,114]], device 2 gets col 0 assert_equal(result1_d2, Nx.tensor([[108], [110], [112], [114]])) - # Second output: [[104],[105],[106],[107]] * 2, broadcasts to [[208,208]...] - assert_equal(result2_d2, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]])) + # Second output: y_dev2=[[104],[105],[106],[107]] * 2 = [[208],[210],[212],[214]] + assert_equal(result2_d2, Nx.tensor([[208], [210], [212], [214]])) # Device 3: x={4,1} rows[4-7]col[1], y={4,1} rows[4-7] {result1_d3, result2_d3} = Enum.at(results, 3) @@ -90,8 +91,9 @@ defmodule EXLA.Defn.ShardingTest do # broadcasts to [[14,14],[15,15],[16,16],[17,17]] + [[104,104],[105,105],[106,106],[107,107]] # = [[118,118],[120,120],[122,122],[124,124]], device 3 gets col 1 assert_equal(result1_d3, Nx.tensor([[118], [120], [122], [124]])) - # Second output: same as device 2 (y is replicated across axis 1) - assert_equal(result2_d3, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]])) + + # Second output: same as device 2 (y is replicated across axis 1, so d3 sees the same y_dev shard). + assert_equal(result2_d3, Nx.tensor([[208], [210], [212], [214]])) end @moduletag :multi_device @@ -176,13 +178,15 @@ defmodule EXLA.Defn.ShardingTest do assert [result0, result1, result2, result3] = EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) - assert_equal(result0, Nx.tensor([[100, 100], [102, 102], [104, 104], [106, 106]])) + # Each partition's output is {4,1}: x_shard + y_shard = {4,1} + {4,1} + # (y input is {8,1}, sharded on axis 0 only → {4,1} per partition). + assert_equal(result0, Nx.tensor([[100], [102], [104], [106]])) assert result0.data.buffer.device_id == 0 - assert_equal(result1, Nx.tensor([[110, 110], [112, 112], [114, 114], [116, 116]])) + assert_equal(result1, Nx.tensor([[110], [112], [114], [116]])) assert result1.data.buffer.device_id == 1 - assert_equal(result2, Nx.tensor([[108, 108], [110, 110], [112, 112], [114, 114]])) + assert_equal(result2, Nx.tensor([[108], [110], [112], [114]])) assert result2.data.buffer.device_id == 2 - assert_equal(result3, Nx.tensor([[118, 118], [120, 120], [122, 122], [124, 124]])) + assert_equal(result3, Nx.tensor([[118], [120], [122], [124]])) assert result3.data.buffer.device_id == 3 end diff --git a/nx/lib/nx/testing.ex b/nx/lib/nx/testing.ex index cc4e7ebefa..5878eed0a8 100644 --- a/nx/lib/nx/testing.ex +++ b/nx/lib/nx/testing.ex @@ -3,7 +3,10 @@ defmodule Nx.Testing do Testing functions for Nx tensor assertions. This module provides functions for asserting tensor equality and - approximate equality within specified tolerances. + approximate equality within specified tolerances. Both helpers handle + vectorized tensors and produce a numeric diagnostic (max absolute / + relative difference) on failure so that bit-level disagreements + hidden by truncated `inspect` output are still diagnosable. """ import ExUnit.Assertions @@ -13,6 +16,8 @@ defmodule Nx.Testing do Asserts that two tensors are exactly equal. This handles NaN values correctly by considering NaN == NaN as true. + Works with vectorized tensors — two tensors must share the same + vectorized axes to be considered equal. """ def assert_equal(left, right) when not is_tensor(left) or not is_tensor(right) do if not Nx.Defn.Composite.compatible?(left, right, &tensor_equal?/2) do @@ -28,26 +33,67 @@ defmodule Nx.Testing do if not tensor_equal?(left, right) do flunk(""" Tensor assertion failed. - left: #{inspect(left)} - right: #{inspect(right)} + + left: + + #{inspect(left)} + + right: + + #{inspect(right)} + + #{diagnose_difference(left, right)} """) end end defp tensor_equal?(left, right) do - both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) - - left - |> Nx.equal(right) - |> Nx.logical_or(both_nan) - |> Nx.all() - |> Nx.to_flat_list() - |> Enum.all?(&(&1 == 1)) + left = to_tensor(left) + right = to_tensor(right) + + cond do + left.vectorized_axes != right.vectorized_axes -> + false + + shapes_incompatible?(left, right) -> + false + + true -> + both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right)) + + left + |> Nx.equal(right) + |> Nx.logical_or(both_nan) + |> Nx.all() + |> Nx.to_flat_list() + |> Enum.all?(&(&1 == 1)) + end + end + + # Wrap raw scalars/lists in tensors so the struct-field accesses + # (`.vectorized_axes`) and `Nx.shape/1` below don't crash. Tensors + # pass through unchanged. + defp to_tensor(%Nx.Tensor{} = t), do: t + defp to_tensor(other), do: Nx.tensor(other) + + # Genuine shape mismatches are rejected, but we still allow a scalar + # (shape `{}`) to compare against a tensor of any shape — that's the + # intentional "assert every element equals this scalar" pattern, and + # rejecting it would break a large number of existing tests that + # relied on `Nx.equal`'s broadcasting. + defp shapes_incompatible?(left, right) do + ls = Nx.shape(left) + rs = Nx.shape(right) + ls != rs and ls != {} and rs != {} end @doc """ Asserts that two tensors are approximately equal within the given tolerances. + Works with vectorized tensors — the comparison is per vectorized + instance, then aggregated. Two tensors must share the same vectorized + axes to be comparable. + See also: * `Nx.all_close/2` - The underlying function that performs the comparison. @@ -61,12 +107,17 @@ defmodule Nx.Testing do atol = opts[:atol] || 1.0e-4 rtol = opts[:rtol] || 1.0e-4 + left_t = to_tensor(left) + right_t = to_tensor(right) + equals = - left - |> Nx.all_close(right, atol: atol, rtol: rtol) - |> Nx.backend_transfer(Nx.BinaryBackend) - |> Nx.to_flat_list() - |> Enum.all?(&(&1 == 1)) + left_t.vectorized_axes == right_t.vectorized_axes and + not shapes_incompatible?(left_t, right_t) and + left_t + |> Nx.all_close(right_t, atol: atol, rtol: rtol) + |> Nx.backend_transfer(Nx.BinaryBackend) + |> Nx.to_flat_list() + |> Enum.all?(&(&1 == 1)) if !equals do flunk(""" @@ -77,7 +128,72 @@ defmodule Nx.Testing do to be within tolerance of #{inspect(right)} + + (atol: #{atol}, rtol: #{rtol}) + + #{diagnose_difference(left, right)} """) end end + + # Produces a human-readable diagnostic describing how two tensors differ. + # If vectorized axes or shapes don't line up, returns a structural message. + # Otherwise computes max absolute and max relative difference across all + # elements (including vec axes) so bit-level disagreements hidden by + # truncated `inspect` output are still visible in the failure message. + defp diagnose_difference(left, right) do + left = to_tensor(left) + right = to_tensor(right) + + cond do + left.vectorized_axes != right.vectorized_axes -> + "vectorized_axes differ: left #{inspect(left.vectorized_axes)}, " <> + "right #{inspect(right.vectorized_axes)}" + + shapes_incompatible?(left, right) -> + "shapes differ: left #{inspect(Nx.shape(left))}, " <> + "right #{inspect(Nx.shape(right))}" + + true -> + numeric_diagnostic(left, right) + end + rescue + _ -> "" + end + + defp numeric_diagnostic(left, right) do + # Devectorize so reductions collapse across vec axes too, and so + # `Nx.to_number` on the final scalar doesn't hit a vectorized tensor. + left = if left.vectorized_axes == [], do: left, else: Nx.devectorize(left, keep_names: false) + + right = + if right.vectorized_axes == [], do: right, else: Nx.devectorize(right, keep_names: false) + + # Promote to a common numeric type so subtraction works for int/float mixes. + {left_f, right_f} = + case {Nx.type(left), Nx.type(right)} do + {{:f, _}, {:f, _}} -> {left, right} + {{:c, _}, _} -> {left, Nx.as_type(right, Nx.type(left))} + {_, {:c, _}} -> {Nx.as_type(left, Nx.type(right)), right} + _ -> {Nx.as_type(left, {:f, 32}), Nx.as_type(right, {:f, 32})} + end + + diff = Nx.subtract(left_f, right_f) |> Nx.abs() + max_abs = diff |> Nx.reduce_max() |> Nx.to_number() + + # Relative diff: |a - b| / max(|a|, |b|, tiny) to avoid divide by zero. + denom = + Nx.max(Nx.abs(left_f), Nx.abs(right_f)) + |> Nx.max(Nx.tensor(1.0e-30)) + + max_rel = Nx.divide(diff, denom) |> Nx.reduce_max() |> Nx.to_number() + + "max absolute difference: #{inspect(max_abs)}\n" <> + "max relative difference: #{inspect(max_rel)}" + rescue + # If the diff computation itself fails (mixed complex/real, NaN propagation, + # unusual types, etc.), fall back silently — the inspect output above is + # still shown. + _ -> "" + end end diff --git a/nx/test/nx/testing_test.exs b/nx/test/nx/testing_test.exs new file mode 100644 index 0000000000..c46d868dd5 --- /dev/null +++ b/nx/test/nx/testing_test.exs @@ -0,0 +1,172 @@ +defmodule Nx.TestingTest do + use ExUnit.Case, async: true + + import Nx.Testing + + describe "assert_all_close/3" do + test "passes on bit-identical tensors" do + a = Nx.tensor([1.0, 2.0, 3.0]) + assert_all_close(a, a) + end + + test "passes on tensors within tolerance" do + a = Nx.tensor([1.0, 2.0, 3.0]) + b = Nx.tensor([1.00001, 2.00001, 3.00001]) + assert_all_close(a, b, atol: 1.0e-4) + end + + test "passes on bit-identical vectorized tensors" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + assert_all_close(a, b) + end + + test "passes on vectorized tensors within tolerance" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.00001, 2.00001]) |> Nx.vectorize(:foo) + assert_all_close(a, b, atol: 1.0e-4) + end + + test "error message includes max absolute difference" do + a = Nx.tensor([1.0, 2.0, 3.0]) + b = Nx.tensor([1.0, 2.5, 3.0]) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_all_close(a, b, atol: 1.0e-4) + end + + assert error.message =~ "max absolute difference" + assert error.message =~ "0.5" + end + + test "error message includes max-diff diagnostic for vectorized tensors" do + # Exercises the case where both tensors have matching structure but + # differ in values by more than the tolerance — the diagnostic should + # include the max absolute difference so failures are still readable + # even when the inspect output truncates to look identical. + a = Nx.tensor([1.0, 2.0], type: :f32) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0001], type: :f32) |> Nx.vectorize(:foo) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_all_close(a, b, atol: 0.0, rtol: 0.0) + end + + assert error.message =~ "max absolute difference" + assert error.message =~ "max relative difference" + end + + test "error message shows a clear diagnostic for mismatched vec axes" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:bar) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_all_close(a, b) + end + + assert error.message =~ "vectorized_axes" + end + end + + describe "assert_equal/2" do + test "passes on bit-identical tensors" do + a = Nx.tensor([1, 2, 3]) + assert_equal(a, a) + end + + test "passes on bit-identical vectorized tensors" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + assert_equal(a, b) + end + + test "error message includes max absolute difference" do + a = Nx.tensor([1.0, 2.0, 3.0]) + b = Nx.tensor([1.0, 2.5, 3.0]) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_equal(a, b) + end + + assert error.message =~ "max absolute difference" + assert error.message =~ "0.5" + end + + test "error message shows a clear diagnostic for mismatched vec axes" do + a = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:foo) + b = Nx.tensor([1.0, 2.0]) |> Nx.vectorize(:bar) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_equal(a, b) + end + + assert error.message =~ "vectorized_axes" + end + + test "handles NaN equality correctly" do + a = Nx.tensor([:nan, 2.0, :nan]) + assert_equal(a, a) + end + + test "accepts a raw scalar against a scalar tensor" do + assert_equal(Nx.tensor(11), 11) + assert_equal(Nx.tensor(1.5), 1.5) + end + + test "accepts a raw scalar broadcast against a non-scalar tensor" do + # Matches the old Nx.equal broadcasting semantics that existing + # test suites (e.g. EXLA expr_test) rely on. `assert_equal(t, 1)` + # means "every element of t equals 1", not "t has shape {}". + assert_equal(Nx.tensor([1, 1, 1]), 1) + assert_equal(Nx.tensor([[1, 1], [1, 1]]), 1) + end + + test "rejects a scalar broadcast when values don't match" do + a = Nx.tensor([1, 1, 2]) + + assert_raise ExUnit.AssertionError, fn -> + assert_equal(a, 1) + end + end + + test "rejects a genuine shape mismatch between two non-scalar tensors" do + # The sharding test bug: result had shape {4,1} but expectation was + # {4,2}. Old code silently broadcast; we now reject. + a = Nx.tensor([[100], [102], [104], [106]]) + b = Nx.tensor([[100, 100], [102, 102], [104, 104], [106, 106]]) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_equal(a, b) + end + + assert error.message =~ "shapes differ" + end + end + + describe "assert_all_close/3 with scalar arguments" do + test "accepts a raw scalar against a scalar tensor" do + assert_all_close(Nx.tensor(1.5), 1.5) + end + + test "accepts a raw scalar broadcast against a non-scalar tensor" do + assert_all_close(Nx.tensor([1.0, 1.0001, 0.9999]), 1.0, atol: 1.0e-3) + end + + test "rejects a genuine shape mismatch" do + a = Nx.tensor([[1.0], [2.0]]) + b = Nx.tensor([[1.0, 1.0], [2.0, 2.0]]) + + error = + assert_raise ExUnit.AssertionError, fn -> + assert_all_close(a, b) + end + + assert error.message =~ "shapes differ" + end + end +end diff --git a/torchx/test/torchx/nx_test.exs b/torchx/test/torchx/nx_test.exs index 997dd8100e..2b551657d0 100644 --- a/torchx/test/torchx/nx_test.exs +++ b/torchx/test/torchx/nx_test.exs @@ -969,19 +969,21 @@ defmodule Torchx.NxTest do result, Nx.tensor([ [ - [15.0, 15.0], - [51.0, 51.0], - [87.0, 87.0] - ], - [ - [123.0, 123.0], - [159.0, 159.0], - [195.0, 195.0] - ], - [ - [231.0, 231.0], - [267.0, 267.0], - [303.0, 303.0] + [ + [15.0, 15.0], + [51.0, 51.0], + [87.0, 87.0] + ], + [ + [123.0, 123.0], + [159.0, 159.0], + [195.0, 195.0] + ], + [ + [231.0, 231.0], + [267.0, 267.0], + [303.0, 303.0] + ] ] ]) ) @@ -997,7 +999,7 @@ defmodule Torchx.NxTest do assert_equal( result, - Nx.tensor([[0, 0, -1, 1, 0, -2], [-3, 0, 4, -4, 0, 5]]) + Nx.tensor([[[[[0, 0, -1, 1, 0, -2], [-3, 0, 4, -4, 0, 5]]]]]) ) end end