Skip to content

Commit a53bbae

Browse files
authored
chore: rename linalg blocks (#1738)
1 parent cccc4f7 commit a53bbae

6 files changed

Lines changed: 55 additions & 48 deletions

File tree

exla/lib/exla/defn.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ defmodule EXLA.Defn do
605605
%T{
606606
data: %Expr{
607607
args: [
608-
%Nx.Block.QR{},
608+
%Nx.Block.LinAlg.QR{},
609609
[tensor],
610610
{%{type: {type_kind, _}} = q_expr, r_expr},
611611
_callback
@@ -636,7 +636,7 @@ defmodule EXLA.Defn do
636636
%T{
637637
data: %Expr{
638638
args: [
639-
%Nx.Block.Eigh{},
639+
%Nx.Block.LinAlg.Eigh{},
640640
[tensor],
641641
{%{type: {evec_type_kind, _}} = eigenvals_expr,
642642
%{type: {eval_type_kind, _}} = eigenvecs_expr},

nx/lib/nx/block.ex

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,31 @@ defmodule Nx.Block.Phase do
66
defstruct []
77
end
88

9-
defmodule Nx.Block.Cholesky do
9+
defmodule Nx.Block.LinAlg.Cholesky do
1010
defstruct []
1111
end
1212

13-
defmodule Nx.Block.Solve do
13+
defmodule Nx.Block.LinAlg.Solve do
1414
defstruct []
1515
end
1616

17-
defmodule Nx.Block.QR do
17+
defmodule Nx.Block.LinAlg.QR do
1818
defstruct eps: 1.0e-10, mode: :reduced
1919
end
2020

21-
defmodule Nx.Block.Eigh do
21+
defmodule Nx.Block.LinAlg.Eigh do
2222
defstruct max_iter: 1000, eps: 1.0e-4
2323
end
2424

25-
defmodule Nx.Block.SVD do
25+
defmodule Nx.Block.LinAlg.SVD do
2626
defstruct max_iter: 100, full_matrices?: true
2727
end
2828

29-
defmodule Nx.Block.LU do
29+
defmodule Nx.Block.LinAlg.LU do
3030
defstruct eps: 1.0e-10
3131
end
3232

33-
defmodule Nx.Block.Determinant do
33+
defmodule Nx.Block.LinAlg.Determinant do
3434
defstruct []
3535
end
3636

nx/lib/nx/lin_alg.ex

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ defmodule Nx.LinAlg do
149149

150150
out = %{tensor | type: output_type, shape: output_shape, names: output_names}
151151

152-
Nx.block(%Nx.Block.Cholesky{}, [tensor], out, fn %Nx.Block.Cholesky{}, t ->
152+
Nx.block(%Nx.Block.LinAlg.Cholesky{}, [tensor], out, fn %Nx.Block.LinAlg.Cholesky{}, t ->
153153
Nx.LinAlg.Cholesky.cholesky(t)
154154
end)
155155
|> Nx.vectorize(vectorized_axes)
@@ -715,7 +715,7 @@ defmodule Nx.LinAlg do
715715
output = Nx.template(output_shape, output_type)
716716

717717
result =
718-
Nx.block(%Nx.Block.Solve{}, [a, b], output, fn %Nx.Block.Solve{}, a, b ->
718+
Nx.block(%Nx.Block.LinAlg.Solve{}, [a, b], output, fn %Nx.Block.LinAlg.Solve{}, a, b ->
719719
# Since we have triangular solve, which accepts upper
720720
# triangular matrices with the `lower: false` option,
721721
# we can solve a system as follows:
@@ -1155,7 +1155,8 @@ defmodule Nx.LinAlg do
11551155
names: List.duplicate(nil, tuple_size(r_shape))
11561156
}}
11571157

1158-
Nx.block(struct!(Nx.Block.QR, opts), [tensor], output, fn %Nx.Block.QR{} = s, t ->
1158+
Nx.block(struct!(Nx.Block.LinAlg.QR, opts), [tensor], output, fn %Nx.Block.LinAlg.QR{} = s,
1159+
t ->
11591160
opts = s |> Map.from_struct() |> Map.to_list()
11601161
Nx.LinAlg.QR.qr(t, opts)
11611162
end)
@@ -1402,7 +1403,8 @@ defmodule Nx.LinAlg do
14021403
{%{tensor | names: eigenvals_name, type: output_type, shape: eigenvals_shape},
14031404
%{tensor | names: eigenvecs_name, type: output_type, shape: eigenvecs_shape}}
14041405

1405-
Nx.block(struct!(Nx.Block.Eigh, opts), [tensor], output, fn %Nx.Block.Eigh{}, t ->
1406+
Nx.block(struct!(Nx.Block.LinAlg.Eigh, opts), [tensor], output, fn %Nx.Block.LinAlg.Eigh{},
1407+
t ->
14061408
Nx.LinAlg.BlockEigh.eigh(t, opts)
14071409
end)
14081410
|> Nx.vectorize(vectorized_axes)
@@ -1523,7 +1525,7 @@ defmodule Nx.LinAlg do
15231525
%{tensor | names: List.duplicate(nil, rank - 1), type: output_type, shape: s_shape},
15241526
%{tensor | names: List.duplicate(nil, rank), type: output_type, shape: v_shape}}
15251527

1526-
Nx.block(struct!(Nx.Block.SVD, opts), [tensor], output, fn %Nx.Block.SVD{}, t ->
1528+
Nx.block(struct!(Nx.Block.LinAlg.SVD, opts), [tensor], output, fn %Nx.Block.LinAlg.SVD{}, t ->
15271529
Nx.LinAlg.SVD.svd(t, opts)
15281530
end)
15291531
|> Nx.vectorize(vectorized_axes)
@@ -1748,7 +1750,7 @@ defmodule Nx.LinAlg do
17481750
%{tensor | type: output_type, shape: l_shape, names: names},
17491751
%{tensor | type: output_type, shape: u_shape, names: names}}
17501752

1751-
Nx.block(%Nx.Block.LU{}, [tensor], output, fn %Nx.Block.LU{}, t ->
1753+
Nx.block(%Nx.Block.LinAlg.LU{}, [tensor], output, fn %Nx.Block.LinAlg.LU{}, t ->
17521754
Nx.LinAlg.LU.lu(t)
17531755
end)
17541756
|> Nx.vectorize(vectorized_axes)
@@ -2001,18 +2003,23 @@ defmodule Nx.LinAlg do
20012003
"determinant/1 expects a square tensor, got tensor with shape: #{inspect(shape)}"
20022004
end
20032005

2004-
Nx.block(%Nx.Block.Determinant{}, [tensor], output, fn %Nx.Block.Determinant{}, t ->
2005-
case matrix_shape do
2006-
[2, 2] ->
2007-
determinant_2by2(t)
2006+
Nx.block(
2007+
%Nx.Block.LinAlg.Determinant{},
2008+
[tensor],
2009+
output,
2010+
fn %Nx.Block.LinAlg.Determinant{}, t ->
2011+
case matrix_shape do
2012+
[2, 2] ->
2013+
determinant_2by2(t)
20082014

2009-
[3, 3] ->
2010-
determinant_3by3(t)
2015+
[3, 3] ->
2016+
determinant_3by3(t)
20112017

2012-
[n, n] ->
2013-
determinant_NbyN(t, batch_shape_n: List.to_tuple(batch_shape ++ [n]))
2018+
[n, n] ->
2019+
determinant_NbyN(t, batch_shape_n: List.to_tuple(batch_shape ++ [n]))
2020+
end
20142021
end
2015-
end)
2022+
)
20162023
end)
20172024
end
20182025

nx/test/nx/defn/expr_test.exs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,9 @@ defmodule Nx.Defn.ExprTest do
326326
f32[2][2]
327327
\s\s
328328
Nx.Defn.Expr
329-
parameter a:0 s32[2][2]
330-
b = block %Nx.Block.QR{eps: 1.0e-10, mode: :reduced}, a tuple2
331-
c = elem b, 0 f32[2][2]
329+
parameter a:0 s32[2][2]
330+
b = block %Nx.Block.LinAlg.QR{eps: 1.0e-10, mode: :reduced}, a tuple2
331+
c = elem b, 0 f32[2][2]
332332
>\
333333
"""
334334

@@ -337,9 +337,9 @@ defmodule Nx.Defn.ExprTest do
337337
f32[2][2]
338338
\s\s
339339
Nx.Defn.Expr
340-
parameter a:0 s32[2][2]
341-
b = block %Nx.Block.QR{eps: 1.0e-10, mode: :reduced}, a tuple2
342-
c = elem b, 1 f32[2][2]
340+
parameter a:0 s32[2][2]
341+
b = block %Nx.Block.LinAlg.QR{eps: 1.0e-10, mode: :reduced}, a tuple2
342+
c = elem b, 1 f32[2][2]
343343
>\
344344
"""
345345
end

nx/test/nx/optional_test.exs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ defmodule Nx.OptionalTest do
2727
|> Nx.backend_transfer(__MODULE__)
2828
end
2929

30-
def block(%Nx.Block.Solve{}, out, [a, b], _fun), do: solve(out, a, b)
31-
def block(%Nx.Block.Determinant{}, out, [t], _fun), do: determinant(out, t)
30+
def block(%Nx.Block.LinAlg.Solve{}, out, [a, b], _fun), do: solve(out, a, b)
31+
def block(%Nx.Block.LinAlg.Determinant{}, out, [t], _fun), do: determinant(out, t)
3232

3333
def block(struct, _output, args, fun) do
3434
apply(fun, [struct | args])
@@ -254,8 +254,8 @@ defmodule Nx.OptionalTest do
254254
f32
255255
\s\s
256256
Nx.Defn.Expr
257-
parameter a:0 s32[3][3]
258-
b = block %Nx.Block.Determinant{}, a f32
257+
parameter a:0 s32[3][3]
258+
b = block %Nx.Block.LinAlg.Determinant{}, a f32
259259
>
260260
"""
261261
end
@@ -270,8 +270,8 @@ defmodule Nx.OptionalTest do
270270
f32
271271
\s\s
272272
Nx.Defn.Expr
273-
parameter a:0 s32[3][3]
274-
b = block %Nx.Block.Determinant{}, a f32
273+
parameter a:0 s32[3][3]
274+
b = block %Nx.Block.LinAlg.Determinant{}, a f32
275275
>
276276
"""
277277
end

torchx/lib/torchx/backend.ex

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ defmodule Torchx.Backend do
5353
# For MPS device, some linear algebra operations are not supported.
5454
# Delegate to default implementation which will fall back to elementary Nx operations.
5555
mps_unsupported = [
56-
Nx.Block.LU,
57-
Nx.Block.Eigh,
58-
Nx.Block.Solve,
59-
Nx.Block.Determinant,
60-
Nx.Block.Cholesky
56+
Nx.Block.LinAlg.LU,
57+
Nx.Block.LinAlg.Eigh,
58+
Nx.Block.LinAlg.Solve,
59+
Nx.Block.LinAlg.Determinant,
60+
Nx.Block.LinAlg.Cholesky
6161
]
6262

6363
device =
@@ -70,25 +70,25 @@ defmodule Torchx.Backend do
7070
apply(fun, [struct | args])
7171
else
7272
case {block_name, args} do
73-
{Nx.Block.QR, [t]} ->
73+
{Nx.Block.LinAlg.QR, [t]} ->
7474
qr_impl(t, mode: struct.mode, eps: struct.eps)
7575

76-
{Nx.Block.LU, [t]} ->
76+
{Nx.Block.LinAlg.LU, [t]} ->
7777
lu_impl(t)
7878

79-
{Nx.Block.Eigh, [t]} ->
79+
{Nx.Block.LinAlg.Eigh, [t]} ->
8080
eigh_impl(t, max_iter: struct.max_iter, eps: struct.eps)
8181

82-
{Nx.Block.Solve, [a, b]} ->
82+
{Nx.Block.LinAlg.Solve, [a, b]} ->
8383
solve_impl(a, b)
8484

85-
{Nx.Block.Cholesky, [t]} ->
85+
{Nx.Block.LinAlg.Cholesky, [t]} ->
8686
cholesky_impl(t)
8787

88-
{Nx.Block.SVD, [t]} ->
88+
{Nx.Block.LinAlg.SVD, [t]} ->
8989
svd_impl(t, max_iter: struct.max_iter, full_matrices?: struct.full_matrices?)
9090

91-
{Nx.Block.Determinant, [t]} ->
91+
{Nx.Block.LinAlg.Determinant, [t]} ->
9292
determinant_impl(t)
9393

9494
{Nx.Block.TakeAlongAxis, [tensor, indices]} ->

0 commit comments

Comments
 (0)