diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 3f6ba7def1..602b4972c6 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -605,7 +605,7 @@ defmodule EXLA.Defn do %T{ data: %Expr{ args: [ - %Nx.Block.QR{}, + %Nx.Block.LinAlg.QR{}, [tensor], {%{type: {type_kind, _}} = q_expr, r_expr}, _callback @@ -636,7 +636,7 @@ defmodule EXLA.Defn do %T{ data: %Expr{ args: [ - %Nx.Block.Eigh{}, + %Nx.Block.LinAlg.Eigh{}, [tensor], {%{type: {evec_type_kind, _}} = eigenvals_expr, %{type: {eval_type_kind, _}} = eigenvecs_expr}, diff --git a/nx/lib/nx/block.ex b/nx/lib/nx/block.ex index 944955ca70..d7f1660cc3 100644 --- a/nx/lib/nx/block.ex +++ b/nx/lib/nx/block.ex @@ -6,31 +6,31 @@ defmodule Nx.Block.Phase do defstruct [] end -defmodule Nx.Block.Cholesky do +defmodule Nx.Block.LinAlg.Cholesky do defstruct [] end -defmodule Nx.Block.Solve do +defmodule Nx.Block.LinAlg.Solve do defstruct [] end -defmodule Nx.Block.QR do +defmodule Nx.Block.LinAlg.QR do defstruct eps: 1.0e-10, mode: :reduced end -defmodule Nx.Block.Eigh do +defmodule Nx.Block.LinAlg.Eigh do defstruct max_iter: 1000, eps: 1.0e-4 end -defmodule Nx.Block.SVD do +defmodule Nx.Block.LinAlg.SVD do defstruct max_iter: 100, full_matrices?: true end -defmodule Nx.Block.LU do +defmodule Nx.Block.LinAlg.LU do defstruct eps: 1.0e-10 end -defmodule Nx.Block.Determinant do +defmodule Nx.Block.LinAlg.Determinant do defstruct [] end diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index 519d64d623..9b8bf61b89 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -149,7 +149,7 @@ defmodule Nx.LinAlg do out = %{tensor | type: output_type, shape: output_shape, names: output_names} - Nx.block(%Nx.Block.Cholesky{}, [tensor], out, fn %Nx.Block.Cholesky{}, t -> + Nx.block(%Nx.Block.LinAlg.Cholesky{}, [tensor], out, fn %Nx.Block.LinAlg.Cholesky{}, t -> Nx.LinAlg.Cholesky.cholesky(t) end) |> Nx.vectorize(vectorized_axes) @@ -715,7 +715,7 @@ defmodule Nx.LinAlg do output = Nx.template(output_shape, output_type) result = - Nx.block(%Nx.Block.Solve{}, [a, b], output, fn %Nx.Block.Solve{}, a, b -> + Nx.block(%Nx.Block.LinAlg.Solve{}, [a, b], output, fn %Nx.Block.LinAlg.Solve{}, a, b -> # Since we have triangular solve, which accepts upper # triangular matrices with the `lower: false` option, # we can solve a system as follows: @@ -1155,7 +1155,8 @@ defmodule Nx.LinAlg do names: List.duplicate(nil, tuple_size(r_shape)) }} - Nx.block(struct!(Nx.Block.QR, opts), [tensor], output, fn %Nx.Block.QR{} = s, t -> + Nx.block(struct!(Nx.Block.LinAlg.QR, opts), [tensor], output, fn %Nx.Block.LinAlg.QR{} = s, + t -> opts = s |> Map.from_struct() |> Map.to_list() Nx.LinAlg.QR.qr(t, opts) end) @@ -1402,7 +1403,8 @@ defmodule Nx.LinAlg do {%{tensor | names: eigenvals_name, type: output_type, shape: eigenvals_shape}, %{tensor | names: eigenvecs_name, type: output_type, shape: eigenvecs_shape}} - Nx.block(struct!(Nx.Block.Eigh, opts), [tensor], output, fn %Nx.Block.Eigh{}, t -> + Nx.block(struct!(Nx.Block.LinAlg.Eigh, opts), [tensor], output, fn %Nx.Block.LinAlg.Eigh{}, + t -> Nx.LinAlg.BlockEigh.eigh(t, opts) end) |> Nx.vectorize(vectorized_axes) @@ -1523,7 +1525,7 @@ defmodule Nx.LinAlg do %{tensor | names: List.duplicate(nil, rank - 1), type: output_type, shape: s_shape}, %{tensor | names: List.duplicate(nil, rank), type: output_type, shape: v_shape}} - Nx.block(struct!(Nx.Block.SVD, opts), [tensor], output, fn %Nx.Block.SVD{}, t -> + Nx.block(struct!(Nx.Block.LinAlg.SVD, opts), [tensor], output, fn %Nx.Block.LinAlg.SVD{}, t -> Nx.LinAlg.SVD.svd(t, opts) end) |> Nx.vectorize(vectorized_axes) @@ -1748,7 +1750,7 @@ defmodule Nx.LinAlg do %{tensor | type: output_type, shape: l_shape, names: names}, %{tensor | type: output_type, shape: u_shape, names: names}} - Nx.block(%Nx.Block.LU{}, [tensor], output, fn %Nx.Block.LU{}, t -> + Nx.block(%Nx.Block.LinAlg.LU{}, [tensor], output, fn %Nx.Block.LinAlg.LU{}, t -> Nx.LinAlg.LU.lu(t) end) |> Nx.vectorize(vectorized_axes) @@ -2001,18 +2003,23 @@ defmodule Nx.LinAlg do "determinant/1 expects a square tensor, got tensor with shape: #{inspect(shape)}" end - Nx.block(%Nx.Block.Determinant{}, [tensor], output, fn %Nx.Block.Determinant{}, t -> - case matrix_shape do - [2, 2] -> - determinant_2by2(t) + Nx.block( + %Nx.Block.LinAlg.Determinant{}, + [tensor], + output, + fn %Nx.Block.LinAlg.Determinant{}, t -> + case matrix_shape do + [2, 2] -> + determinant_2by2(t) - [3, 3] -> - determinant_3by3(t) + [3, 3] -> + determinant_3by3(t) - [n, n] -> - determinant_NbyN(t, batch_shape_n: List.to_tuple(batch_shape ++ [n])) + [n, n] -> + determinant_NbyN(t, batch_shape_n: List.to_tuple(batch_shape ++ [n])) + end end - end) + ) end) end diff --git a/nx/test/nx/defn/expr_test.exs b/nx/test/nx/defn/expr_test.exs index 0815dae25d..e767b19ced 100644 --- a/nx/test/nx/defn/expr_test.exs +++ b/nx/test/nx/defn/expr_test.exs @@ -326,9 +326,9 @@ defmodule Nx.Defn.ExprTest do f32[2][2] \s\s Nx.Defn.Expr - parameter a:0 s32[2][2] - b = block %Nx.Block.QR{eps: 1.0e-10, mode: :reduced}, a tuple2 - c = elem b, 0 f32[2][2] + parameter a:0 s32[2][2] + b = block %Nx.Block.LinAlg.QR{eps: 1.0e-10, mode: :reduced}, a tuple2 + c = elem b, 0 f32[2][2] >\ """ @@ -337,9 +337,9 @@ defmodule Nx.Defn.ExprTest do f32[2][2] \s\s Nx.Defn.Expr - parameter a:0 s32[2][2] - b = block %Nx.Block.QR{eps: 1.0e-10, mode: :reduced}, a tuple2 - c = elem b, 1 f32[2][2] + parameter a:0 s32[2][2] + b = block %Nx.Block.LinAlg.QR{eps: 1.0e-10, mode: :reduced}, a tuple2 + c = elem b, 1 f32[2][2] >\ """ end diff --git a/nx/test/nx/optional_test.exs b/nx/test/nx/optional_test.exs index 608ddfb00d..5cc1fec932 100644 --- a/nx/test/nx/optional_test.exs +++ b/nx/test/nx/optional_test.exs @@ -27,8 +27,8 @@ defmodule Nx.OptionalTest do |> Nx.backend_transfer(__MODULE__) end - def block(%Nx.Block.Solve{}, out, [a, b], _fun), do: solve(out, a, b) - def block(%Nx.Block.Determinant{}, out, [t], _fun), do: determinant(out, t) + def block(%Nx.Block.LinAlg.Solve{}, out, [a, b], _fun), do: solve(out, a, b) + def block(%Nx.Block.LinAlg.Determinant{}, out, [t], _fun), do: determinant(out, t) def block(struct, _output, args, fun) do apply(fun, [struct | args]) @@ -254,8 +254,8 @@ defmodule Nx.OptionalTest do f32 \s\s Nx.Defn.Expr - parameter a:0 s32[3][3] - b = block %Nx.Block.Determinant{}, a f32 + parameter a:0 s32[3][3] + b = block %Nx.Block.LinAlg.Determinant{}, a f32 > """ end @@ -270,8 +270,8 @@ defmodule Nx.OptionalTest do f32 \s\s Nx.Defn.Expr - parameter a:0 s32[3][3] - b = block %Nx.Block.Determinant{}, a f32 + parameter a:0 s32[3][3] + b = block %Nx.Block.LinAlg.Determinant{}, a f32 > """ end diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index ea9abb93cd..3bbc195b0e 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -53,11 +53,11 @@ defmodule Torchx.Backend do # For MPS device, some linear algebra operations are not supported. # Delegate to default implementation which will fall back to elementary Nx operations. mps_unsupported = [ - Nx.Block.LU, - Nx.Block.Eigh, - Nx.Block.Solve, - Nx.Block.Determinant, - Nx.Block.Cholesky + Nx.Block.LinAlg.LU, + Nx.Block.LinAlg.Eigh, + Nx.Block.LinAlg.Solve, + Nx.Block.LinAlg.Determinant, + Nx.Block.LinAlg.Cholesky ] device = @@ -70,25 +70,25 @@ defmodule Torchx.Backend do apply(fun, [struct | args]) else case {block_name, args} do - {Nx.Block.QR, [t]} -> + {Nx.Block.LinAlg.QR, [t]} -> qr_impl(t, mode: struct.mode, eps: struct.eps) - {Nx.Block.LU, [t]} -> + {Nx.Block.LinAlg.LU, [t]} -> lu_impl(t) - {Nx.Block.Eigh, [t]} -> + {Nx.Block.LinAlg.Eigh, [t]} -> eigh_impl(t, max_iter: struct.max_iter, eps: struct.eps) - {Nx.Block.Solve, [a, b]} -> + {Nx.Block.LinAlg.Solve, [a, b]} -> solve_impl(a, b) - {Nx.Block.Cholesky, [t]} -> + {Nx.Block.LinAlg.Cholesky, [t]} -> cholesky_impl(t) - {Nx.Block.SVD, [t]} -> + {Nx.Block.LinAlg.SVD, [t]} -> svd_impl(t, max_iter: struct.max_iter, full_matrices?: struct.full_matrices?) - {Nx.Block.Determinant, [t]} -> + {Nx.Block.LinAlg.Determinant, [t]} -> determinant_impl(t) {Nx.Block.TakeAlongAxis, [tensor, indices]} ->