Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ defmodule EXLA.Defn do
%T{
data: %Expr{
args: [
%Nx.Block.QR{},
%Nx.Block.LinAlg.QR{},
[tensor],
{%{type: {type_kind, _}} = q_expr, r_expr},
_callback
Expand Down Expand Up @@ -636,7 +636,7 @@ defmodule EXLA.Defn do
%T{
data: %Expr{
args: [
%Nx.Block.Eigh{},
%Nx.Block.LinAlg.Eigh{},
[tensor],
{%{type: {evec_type_kind, _}} = eigenvals_expr,
%{type: {eval_type_kind, _}} = eigenvecs_expr},
Expand Down
14 changes: 7 additions & 7 deletions nx/lib/nx/block.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,31 @@ defmodule Nx.Block.Phase do
defstruct []
end

defmodule Nx.Block.Cholesky do
defmodule Nx.Block.LinAlg.Cholesky do
defstruct []
end

defmodule Nx.Block.Solve do
defmodule Nx.Block.LinAlg.Solve do
defstruct []
end

defmodule Nx.Block.QR do
defmodule Nx.Block.LinAlg.QR do
defstruct eps: 1.0e-10, mode: :reduced
end

defmodule Nx.Block.Eigh do
defmodule Nx.Block.LinAlg.Eigh do
defstruct max_iter: 1000, eps: 1.0e-4
end

defmodule Nx.Block.SVD do
defmodule Nx.Block.LinAlg.SVD do
defstruct max_iter: 100, full_matrices?: true
end

defmodule Nx.Block.LU do
defmodule Nx.Block.LinAlg.LU do
defstruct eps: 1.0e-10
end

defmodule Nx.Block.Determinant do
defmodule Nx.Block.LinAlg.Determinant do
defstruct []
end

Expand Down
37 changes: 22 additions & 15 deletions nx/lib/nx/lin_alg.ex
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ defmodule Nx.LinAlg do

out = %{tensor | type: output_type, shape: output_shape, names: output_names}

Nx.block(%Nx.Block.Cholesky{}, [tensor], out, fn %Nx.Block.Cholesky{}, t ->
Nx.block(%Nx.Block.LinAlg.Cholesky{}, [tensor], out, fn %Nx.Block.LinAlg.Cholesky{}, t ->
Nx.LinAlg.Cholesky.cholesky(t)
end)
|> Nx.vectorize(vectorized_axes)
Expand Down Expand Up @@ -715,7 +715,7 @@ defmodule Nx.LinAlg do
output = Nx.template(output_shape, output_type)

result =
Nx.block(%Nx.Block.Solve{}, [a, b], output, fn %Nx.Block.Solve{}, a, b ->
Nx.block(%Nx.Block.LinAlg.Solve{}, [a, b], output, fn %Nx.Block.LinAlg.Solve{}, a, b ->
# Since we have triangular solve, which accepts upper
# triangular matrices with the `lower: false` option,
# we can solve a system as follows:
Expand Down Expand Up @@ -1155,7 +1155,8 @@ defmodule Nx.LinAlg do
names: List.duplicate(nil, tuple_size(r_shape))
}}

Nx.block(struct!(Nx.Block.QR, opts), [tensor], output, fn %Nx.Block.QR{} = s, t ->
Nx.block(struct!(Nx.Block.LinAlg.QR, opts), [tensor], output, fn %Nx.Block.LinAlg.QR{} = s,
t ->
opts = s |> Map.from_struct() |> Map.to_list()
Nx.LinAlg.QR.qr(t, opts)
end)
Expand Down Expand Up @@ -1402,7 +1403,8 @@ defmodule Nx.LinAlg do
{%{tensor | names: eigenvals_name, type: output_type, shape: eigenvals_shape},
%{tensor | names: eigenvecs_name, type: output_type, shape: eigenvecs_shape}}

Nx.block(struct!(Nx.Block.Eigh, opts), [tensor], output, fn %Nx.Block.Eigh{}, t ->
Nx.block(struct!(Nx.Block.LinAlg.Eigh, opts), [tensor], output, fn %Nx.Block.LinAlg.Eigh{},
t ->
Nx.LinAlg.BlockEigh.eigh(t, opts)
end)
|> Nx.vectorize(vectorized_axes)
Expand Down Expand Up @@ -1523,7 +1525,7 @@ defmodule Nx.LinAlg do
%{tensor | names: List.duplicate(nil, rank - 1), type: output_type, shape: s_shape},
%{tensor | names: List.duplicate(nil, rank), type: output_type, shape: v_shape}}

Nx.block(struct!(Nx.Block.SVD, opts), [tensor], output, fn %Nx.Block.SVD{}, t ->
Nx.block(struct!(Nx.Block.LinAlg.SVD, opts), [tensor], output, fn %Nx.Block.LinAlg.SVD{}, t ->
Nx.LinAlg.SVD.svd(t, opts)
end)
|> Nx.vectorize(vectorized_axes)
Expand Down Expand Up @@ -1748,7 +1750,7 @@ defmodule Nx.LinAlg do
%{tensor | type: output_type, shape: l_shape, names: names},
%{tensor | type: output_type, shape: u_shape, names: names}}

Nx.block(%Nx.Block.LU{}, [tensor], output, fn %Nx.Block.LU{}, t ->
Nx.block(%Nx.Block.LinAlg.LU{}, [tensor], output, fn %Nx.Block.LinAlg.LU{}, t ->
Nx.LinAlg.LU.lu(t)
end)
|> Nx.vectorize(vectorized_axes)
Expand Down Expand Up @@ -2001,18 +2003,23 @@ defmodule Nx.LinAlg do
"determinant/1 expects a square tensor, got tensor with shape: #{inspect(shape)}"
end

Nx.block(%Nx.Block.Determinant{}, [tensor], output, fn %Nx.Block.Determinant{}, t ->
case matrix_shape do
[2, 2] ->
determinant_2by2(t)
Nx.block(
%Nx.Block.LinAlg.Determinant{},
[tensor],
output,
fn %Nx.Block.LinAlg.Determinant{}, t ->
case matrix_shape do
[2, 2] ->
determinant_2by2(t)

[3, 3] ->
determinant_3by3(t)
[3, 3] ->
determinant_3by3(t)

[n, n] ->
determinant_NbyN(t, batch_shape_n: List.to_tuple(batch_shape ++ [n]))
[n, n] ->
determinant_NbyN(t, batch_shape_n: List.to_tuple(batch_shape ++ [n]))
end
end
end)
)
end)
end

Expand Down
12 changes: 6 additions & 6 deletions nx/test/nx/defn/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,9 @@ defmodule Nx.Defn.ExprTest do
f32[2][2]
\s\s
Nx.Defn.Expr
parameter a:0 s32[2][2]
b = block %Nx.Block.QR{eps: 1.0e-10, mode: :reduced}, a tuple2
c = elem b, 0 f32[2][2]
parameter a:0 s32[2][2]
b = block %Nx.Block.LinAlg.QR{eps: 1.0e-10, mode: :reduced}, a tuple2
c = elem b, 0 f32[2][2]
>\
"""

Expand All @@ -337,9 +337,9 @@ defmodule Nx.Defn.ExprTest do
f32[2][2]
\s\s
Nx.Defn.Expr
parameter a:0 s32[2][2]
b = block %Nx.Block.QR{eps: 1.0e-10, mode: :reduced}, a tuple2
c = elem b, 1 f32[2][2]
parameter a:0 s32[2][2]
b = block %Nx.Block.LinAlg.QR{eps: 1.0e-10, mode: :reduced}, a tuple2
c = elem b, 1 f32[2][2]
>\
"""
end
Expand Down
12 changes: 6 additions & 6 deletions nx/test/nx/optional_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ defmodule Nx.OptionalTest do
|> Nx.backend_transfer(__MODULE__)
end

def block(%Nx.Block.Solve{}, out, [a, b], _fun), do: solve(out, a, b)
def block(%Nx.Block.Determinant{}, out, [t], _fun), do: determinant(out, t)
def block(%Nx.Block.LinAlg.Solve{}, out, [a, b], _fun), do: solve(out, a, b)
def block(%Nx.Block.LinAlg.Determinant{}, out, [t], _fun), do: determinant(out, t)

def block(struct, _output, args, fun) do
apply(fun, [struct | args])
Expand Down Expand Up @@ -254,8 +254,8 @@ defmodule Nx.OptionalTest do
f32
\s\s
Nx.Defn.Expr
parameter a:0 s32[3][3]
b = block %Nx.Block.Determinant{}, a f32
parameter a:0 s32[3][3]
b = block %Nx.Block.LinAlg.Determinant{}, a f32
>
"""
end
Expand All @@ -270,8 +270,8 @@ defmodule Nx.OptionalTest do
f32
\s\s
Nx.Defn.Expr
parameter a:0 s32[3][3]
b = block %Nx.Block.Determinant{}, a f32
parameter a:0 s32[3][3]
b = block %Nx.Block.LinAlg.Determinant{}, a f32
>
"""
end
Expand Down
24 changes: 12 additions & 12 deletions torchx/lib/torchx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ defmodule Torchx.Backend do
# For MPS device, some linear algebra operations are not supported.
# Delegate to default implementation which will fall back to elementary Nx operations.
mps_unsupported = [
Nx.Block.LU,
Nx.Block.Eigh,
Nx.Block.Solve,
Nx.Block.Determinant,
Nx.Block.Cholesky
Nx.Block.LinAlg.LU,
Nx.Block.LinAlg.Eigh,
Nx.Block.LinAlg.Solve,
Nx.Block.LinAlg.Determinant,
Nx.Block.LinAlg.Cholesky
]

device =
Expand All @@ -70,25 +70,25 @@ defmodule Torchx.Backend do
apply(fun, [struct | args])
else
case {block_name, args} do
{Nx.Block.QR, [t]} ->
{Nx.Block.LinAlg.QR, [t]} ->
qr_impl(t, mode: struct.mode, eps: struct.eps)

{Nx.Block.LU, [t]} ->
{Nx.Block.LinAlg.LU, [t]} ->
lu_impl(t)

{Nx.Block.Eigh, [t]} ->
{Nx.Block.LinAlg.Eigh, [t]} ->
eigh_impl(t, max_iter: struct.max_iter, eps: struct.eps)

{Nx.Block.Solve, [a, b]} ->
{Nx.Block.LinAlg.Solve, [a, b]} ->
solve_impl(a, b)

{Nx.Block.Cholesky, [t]} ->
{Nx.Block.LinAlg.Cholesky, [t]} ->
cholesky_impl(t)

{Nx.Block.SVD, [t]} ->
{Nx.Block.LinAlg.SVD, [t]} ->
svd_impl(t, max_iter: struct.max_iter, full_matrices?: struct.full_matrices?)

{Nx.Block.Determinant, [t]} ->
{Nx.Block.LinAlg.Determinant, [t]} ->
determinant_impl(t)

{Nx.Block.TakeAlongAxis, [tensor, indices]} ->
Expand Down
Loading