Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions exla/test/exla/defn/sharding_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -63,35 +63,37 @@ defmodule EXLA.Defn.ShardingTest do
# y[[100],[101],[102],[103]] broadcasts to [[100,100],[101,101],[102,102],[103,103]]
# x+y = [[100,100],[102,102],[104,104],[106,106]], device 0 gets col 0
assert_equal(result1_d0, Nx.tensor([[100], [102], [104], [106]]))
# Second output (y*2): y broadcasts to {4,2}, multiply by 2, result is {4,2}
# [[100],[101],[102],[103]] * 2 = [[200],[202],[204],[206]], broadcasts to [[200,200],[202,202],[204,204],[206,206]]
assert_equal(result2_d0, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]]))
# Second output (y*2): each partition has y={4,1}, so y*2 is {4,1} per partition.
# y_dev0=[[100],[101],[102],[103]] * 2 = [[200],[202],[204],[206]]
assert_equal(result2_d0, Nx.tensor([[200], [202], [204], [206]]))

# Device 1: x={4,1} rows[0-3]col[1], y={4,1} rows[0-3]
{result1_d1, result2_d1} = Enum.at(results, 1)
# First output: x[[10],[11],[12],[13]] + y[[100],[101],[102],[103]]
# broadcasts to [[10,10],[11,11],[12,12],[13,13]] + [[100,100],[101,101],[102,102],[103,103]]
# = [[110,110],[112,112],[114,114],[116,116]], device 1 gets col 1
assert_equal(result1_d1, Nx.tensor([[110], [112], [114], [116]]))
# Second output: same as device 0 (y is replicated across axis 1)
assert_equal(result2_d1, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]]))

# Second output: same as device 0 (y is replicated across axis 1, so d1 sees the same y_dev shard).
assert_equal(result2_d1, Nx.tensor([[200], [202], [204], [206]]))

# Device 2: x={4,1} rows[4-7]col[0], y={4,1} rows[4-7]
{result1_d2, result2_d2} = Enum.at(results, 2)
# First output: x[[4],[5],[6],[7]] + y[[104],[105],[106],[107]]
# = [[108,108],[110,110],[112,112],[114,114]], device 2 gets col 0
assert_equal(result1_d2, Nx.tensor([[108], [110], [112], [114]]))
# Second output: [[104],[105],[106],[107]] * 2, broadcasts to [[208,208]...]
assert_equal(result2_d2, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]]))
# Second output: y_dev2=[[104],[105],[106],[107]] * 2 = [[208],[210],[212],[214]]
assert_equal(result2_d2, Nx.tensor([[208], [210], [212], [214]]))

# Device 3: x={4,1} rows[4-7]col[1], y={4,1} rows[4-7]
{result1_d3, result2_d3} = Enum.at(results, 3)
# First output: x[[14],[15],[16],[17]] + y[[104],[105],[106],[107]]
# broadcasts to [[14,14],[15,15],[16,16],[17,17]] + [[104,104],[105,105],[106,106],[107,107]]
# = [[118,118],[120,120],[122,122],[124,124]], device 3 gets col 1
assert_equal(result1_d3, Nx.tensor([[118], [120], [122], [124]]))
# Second output: same as device 2 (y is replicated across axis 1)
assert_equal(result2_d3, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]]))

# Second output: same as device 2 (y is replicated across axis 1, so d3 sees the same y_dev shard).
assert_equal(result2_d3, Nx.tensor([[208], [210], [212], [214]]))
end

@moduletag :multi_device
Expand Down Expand Up @@ -176,13 +178,15 @@ defmodule EXLA.Defn.ShardingTest do
assert [result0, result1, result2, result3] =
EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args)

assert_equal(result0, Nx.tensor([[100, 100], [102, 102], [104, 104], [106, 106]]))
# Each partition's output is {4,1}: x_shard + y_shard = {4,1} + {4,1}
# (y input is {8,1}, sharded on axis 0 only → {4,1} per partition).
assert_equal(result0, Nx.tensor([[100], [102], [104], [106]]))
assert result0.data.buffer.device_id == 0
assert_equal(result1, Nx.tensor([[110, 110], [112, 112], [114, 114], [116, 116]]))
assert_equal(result1, Nx.tensor([[110], [112], [114], [116]]))
assert result1.data.buffer.device_id == 1
assert_equal(result2, Nx.tensor([[108, 108], [110, 110], [112, 112], [114, 114]]))
assert_equal(result2, Nx.tensor([[108], [110], [112], [114]]))
assert result2.data.buffer.device_id == 2
assert_equal(result3, Nx.tensor([[118, 118], [120, 120], [122, 122], [124, 124]]))
assert_equal(result3, Nx.tensor([[118], [120], [122], [124]]))
assert result3.data.buffer.device_id == 3
end

Expand Down
148 changes: 132 additions & 16 deletions nx/lib/nx/testing.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ defmodule Nx.Testing do
Testing functions for Nx tensor assertions.

This module provides functions for asserting tensor equality and
approximate equality within specified tolerances.
approximate equality within specified tolerances. Both helpers handle
vectorized tensors and produce a numeric diagnostic (max absolute /
relative difference) on failure so that bit-level disagreements
hidden by truncated `inspect` output are still diagnosable.
"""

import ExUnit.Assertions
Expand All @@ -13,6 +16,8 @@ defmodule Nx.Testing do
Asserts that two tensors are exactly equal.

This handles NaN values correctly by considering NaN == NaN as true.
Works with vectorized tensors — two tensors must share the same
vectorized axes to be considered equal.
"""
def assert_equal(left, right) when not is_tensor(left) or not is_tensor(right) do
if not Nx.Defn.Composite.compatible?(left, right, &tensor_equal?/2) do
Expand All @@ -28,26 +33,67 @@ defmodule Nx.Testing do
if not tensor_equal?(left, right) do
flunk("""
Tensor assertion failed.
left: #{inspect(left)}
right: #{inspect(right)}

left:

#{inspect(left)}

right:

#{inspect(right)}

#{diagnose_difference(left, right)}
""")
end
end

# Returns true when the two arguments are element-wise equal, treating
# NaN == NaN as true. Structural mismatches (different vectorized axes,
# or genuinely incompatible shapes) short-circuit to false instead of
# raising inside `Nx.equal/2`.
#
# Note: the previous (pre-diff) body computed the NaN/equality pipeline
# unconditionally; the stale copy of those lines that the diff left
# interleaved here has been removed so only the guarded version runs.
defp tensor_equal?(left, right) do
  left = to_tensor(left)
  right = to_tensor(right)

  cond do
    left.vectorized_axes != right.vectorized_axes ->
      false

    shapes_incompatible?(left, right) ->
      false

    true ->
      # IEEE semantics make NaN != NaN, so mark positions where both
      # sides are NaN as equal before reducing with `Nx.all/1`.
      both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right))

      left
      |> Nx.equal(right)
      |> Nx.logical_or(both_nan)
      |> Nx.all()
      |> Nx.to_flat_list()
      |> Enum.all?(&(&1 == 1))
  end
end

# Coerces the argument to an `Nx.Tensor` so the struct-field access
# (`.vectorized_axes`) and `Nx.shape/1` used by the callers never crash
# on raw scalars/lists. An existing tensor passes through untouched.
defp to_tensor(value) do
  case value do
    %Nx.Tensor{} -> value
    other -> Nx.tensor(other)
  end
end

# True only for a genuine shape mismatch. A scalar (shape `{}`) is
# always allowed to compare against a tensor of any shape — that is the
# intentional "assert every element equals this scalar" pattern, and
# rejecting it would break the many existing tests that relied on
# `Nx.equal`'s broadcasting.
defp shapes_incompatible?(left, right) do
  case {Nx.shape(left), Nx.shape(right)} do
    # Identical shapes are trivially compatible.
    {shape, shape} -> false
    # Either side being a scalar broadcasts cleanly.
    {{}, _} -> false
    {_, {}} -> false
    # Anything else is a real mismatch.
    _ -> true
  end
end

@doc """
Asserts that two tensors are approximately equal within the given tolerances.

Works with vectorized tensors — the comparison is per vectorized
instance, then aggregated. Two tensors must share the same vectorized
axes to be comparable.

See also:

* `Nx.all_close/2` - The underlying function that performs the comparison.
Expand All @@ -61,12 +107,17 @@ defmodule Nx.Testing do
atol = opts[:atol] || 1.0e-4
rtol = opts[:rtol] || 1.0e-4

left_t = to_tensor(left)
right_t = to_tensor(right)

equals =
left
|> Nx.all_close(right, atol: atol, rtol: rtol)
|> Nx.backend_transfer(Nx.BinaryBackend)
|> Nx.to_flat_list()
|> Enum.all?(&(&1 == 1))
left_t.vectorized_axes == right_t.vectorized_axes and
not shapes_incompatible?(left_t, right_t) and
left_t
|> Nx.all_close(right_t, atol: atol, rtol: rtol)
|> Nx.backend_transfer(Nx.BinaryBackend)
|> Nx.to_flat_list()
|> Enum.all?(&(&1 == 1))

if !equals do
flunk("""
Expand All @@ -77,7 +128,72 @@ defmodule Nx.Testing do
to be within tolerance of

#{inspect(right)}

(atol: #{atol}, rtol: #{rtol})

#{diagnose_difference(left, right)}
""")
end
end

# Builds the human-readable diagnostic appended to failure messages.
# Structural mismatches (vectorized axes, shapes) produce a short
# textual note; otherwise the max absolute/relative difference is
# computed so bit-level disagreements hidden by truncated `inspect`
# output are still visible. Any exception during diagnosis yields ""
# so the assertion message itself can never crash.
defp diagnose_difference(left, right) do
  l = to_tensor(left)
  r = to_tensor(right)

  cond do
    l.vectorized_axes != r.vectorized_axes ->
      "vectorized_axes differ: left #{inspect(l.vectorized_axes)}, " <>
        "right #{inspect(r.vectorized_axes)}"

    shapes_incompatible?(l, r) ->
      "shapes differ: left #{inspect(Nx.shape(l))}, " <>
        "right #{inspect(Nx.shape(r))}"

    true ->
      numeric_diagnostic(l, r)
  end
rescue
  # Best-effort helper: swallow everything rather than mask the real
  # assertion failure with a diagnostic crash.
  _ -> ""
end

# Computes the "max absolute difference" / "max relative difference"
# lines for the failure message. Assumes the caller already verified
# that both tensors share vectorized axes and compatible shapes.
defp numeric_diagnostic(left, right) do
  # Devectorize so reductions collapse across vec axes too, and so
  # `Nx.to_number` on the final scalar doesn't hit a vectorized tensor.
  left = if left.vectorized_axes == [], do: left, else: Nx.devectorize(left, keep_names: false)

  right =
    if right.vectorized_axes == [], do: right, else: Nx.devectorize(right, keep_names: false)

  # Promote to a common numeric type so subtraction works for int/float mixes.
  # Complex (`{:c, _}`) wins over the other side; two floats are left as-is
  # (NOTE(review): mixed f32/f64 presumably relies on Nx's own promotion in
  # `Nx.subtract/2` — confirm); everything else is forced to f32.
  {left_f, right_f} =
    case {Nx.type(left), Nx.type(right)} do
      {{:f, _}, {:f, _}} -> {left, right}
      {{:c, _}, _} -> {left, Nx.as_type(right, Nx.type(left))}
      {_, {:c, _}} -> {Nx.as_type(left, Nx.type(right)), right}
      _ -> {Nx.as_type(left, {:f, 32}), Nx.as_type(right, {:f, 32})}
    end

  diff = Nx.subtract(left_f, right_f) |> Nx.abs()
  max_abs = diff |> Nx.reduce_max() |> Nx.to_number()

  # Relative diff: |a - b| / max(|a|, |b|, tiny) to avoid divide by zero.
  denom =
    Nx.max(Nx.abs(left_f), Nx.abs(right_f))
    |> Nx.max(Nx.tensor(1.0e-30))

  max_rel = Nx.divide(diff, denom) |> Nx.reduce_max() |> Nx.to_number()

  "max absolute difference: #{inspect(max_abs)}\n" <>
    "max relative difference: #{inspect(max_rel)}"
rescue
  # If the diff computation itself fails (mixed complex/real, NaN propagation,
  # unusual types, etc.), fall back silently — the inspect output above is
  # still shown.
  _ -> ""
end
end
Loading
Loading