Skip to content

Commit d528827

Browse files
Chapamanpolvalente
andauthored
Basic Nx.block implementation (#1709)
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
1 parent 1a1717d commit d528827

29 files changed

Lines changed: 617 additions & 326 deletions

exla/lib/exla/backend.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,12 +318,12 @@ defmodule EXLA.Backend do
318318
end
319319

320320
@impl true
321-
def optional(name, args, fun) do
321+
def block(struct, _output, args, fun) do
322322
# Here we take the leading tensor arguments and pass them as JIT arguments
323323
{tensors, rest} = Enum.split_while(args, &is_struct(&1, Nx.Tensor))
324324

325325
wrapper_fun = fn tensors ->
326-
Nx.Defn.Expr.optional(name, Tuple.to_list(tensors) ++ rest, fun)
326+
Nx.Defn.Expr.block(struct, nil, Tuple.to_list(tensors) ++ rest, fun)
327327
end
328328

329329
jit([], wrapper_fun, tensors, [List.to_tuple(tensors)])

exla/lib/exla/defn.ex

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -601,11 +601,12 @@ defmodule EXLA.Defn do
601601
end
602602

603603
defp cached_recur_operator(
604-
:optional,
604+
:block,
605605
%T{
606606
data: %Expr{
607607
args: [
608-
%{data: %{op: :qr, args: [tensor, _opts]}},
608+
%Nx.Block.QR{},
609+
[tensor],
609610
{%{type: {type_kind, _}} = q_expr, r_expr},
610611
_callback
611612
]
@@ -631,11 +632,12 @@ defmodule EXLA.Defn do
631632
end
632633

633634
defp cached_recur_operator(
634-
:optional,
635+
:block,
635636
%T{
636637
data: %Expr{
637638
args: [
638-
%{data: %{op: :eigh, args: [tensor, _opts]}},
639+
%Nx.Block.Eigh{},
640+
[tensor],
639641
{%{type: {evec_type_kind, _}} = eigenvals_expr,
640642
%{type: {eval_type_kind, _}} = eigenvecs_expr},
641643
_callback
@@ -672,16 +674,15 @@ defmodule EXLA.Defn do
672674
end
673675

674676
defp cached_recur_operator(
675-
:optional,
677+
:block,
676678
%T{
677679
data: %Expr{
678-
args: [%{data: %{op: :take, args: [tensor, indices, opts]}}, expr, _callback]
680+
args: [%Nx.Block.Take{axis: axis}, [tensor, indices], expr, _callback]
679681
}
680682
},
681683
state,
682684
cache
683685
) do
684-
axis = opts[:axis]
685686
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
686687
{indices, cache} = recur_operator(indices, state, cache) |> unwrap_single_tensor!()
687688

@@ -714,56 +715,80 @@ defmodule EXLA.Defn do
714715
end
715716

716717
defp cached_recur_operator(
717-
:optional,
718-
%T{data: %Expr{args: [%{data: %{op: :top_k, args: [tensor, opts]}}, expr, _callback]}},
718+
:block,
719+
%T{data: %Expr{args: [%Nx.Block.TopK{k: k}, [tensor], expr, _callback]}},
719720
state,
720721
cache
721722
) do
722723
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
723724
{values, idx} = expr
724725
typespecs = [expr_to_typespec(values), expr_to_typespec(idx)]
725-
results = Value.top_k(tensor, opts[:k], typespecs)
726+
results = Value.top_k(tensor, k, typespecs)
726727
{results, cache}
727728
end
728729

729730
defp cached_recur_operator(
730-
:optional,
731-
%T{data: %Expr{args: [%{data: %{op: :fft2, args: [tensor, opts]}}, expr, _callback]}},
731+
:block,
732+
%T{data: %Expr{args: [%Nx.Block.FFT2{} = fft2_struct, [tensor], expr, _callback]}},
732733
state,
733734
cache
734735
) do
735736
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
736737

738+
opts = [lengths: fft2_struct.lengths, axes: fft2_struct.axes]
739+
740+
opts =
741+
if eps = fft2_struct.eps do
742+
Keyword.put(opts, :eps, eps)
743+
else
744+
opts
745+
end
746+
737747
{fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], expr, state), cache}
738748
end
739749

740750
defp cached_recur_operator(
741-
:optional,
742-
%T{data: %Expr{args: [%{data: %{op: :ifft2, args: [tensor, opts]}}, expr, _callback]}},
751+
:block,
752+
%T{data: %Expr{args: [%Nx.Block.IFFT2{} = ifft2_struct, [tensor], expr, _callback]}},
743753
state,
744754
cache
745755
) do
746756
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
747757

758+
opts = [lengths: ifft2_struct.lengths, axes: ifft2_struct.axes]
759+
760+
opts =
761+
if eps = ifft2_struct.eps do
762+
Keyword.put(opts, :eps, eps)
763+
else
764+
opts
765+
end
766+
748767
{fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache}
749768
end
750769

751-
defp cached_recur_operator(:optional, %T{data: %Expr{args: args}}, state, cache) do
752-
[call, expr, _callback] = args
753-
%{data: %{args: in_args, op: op}} = call
754-
755-
{args, opts} = Enum.split_while(in_args, &(not is_list(&1)))
770+
defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do
771+
[struct, in_args, expr, _callback] = args
772+
%module{} = struct
756773

757-
{call_args, cache} = Enum.map_reduce(args, cache, &recur_operator(&1, state, &2))
758-
key = computation_key(op, call_args ++ opts)
774+
{call_args, cache} = Enum.map_reduce(in_args, cache, &recur_operator(&1, state, &2))
775+
key = computation_key(module, [struct | call_args])
759776

760777
{call_body, cache} =
761778
case cache do
762779
%{^key => computation} ->
763780
{computation, cache}
764781

765782
%{} ->
766-
{computation, cache} = optional_computation("optional", call_args, expr, state, cache)
783+
{computation, cache} =
784+
block_computation(
785+
block_subfunction_description(struct),
786+
call_args,
787+
expr,
788+
state,
789+
cache
790+
)
791+
767792
{computation, Map.put(cache, key, computation)}
768793
end
769794

@@ -1818,8 +1843,14 @@ defmodule EXLA.Defn do
18181843
{region, merge_outfeed(cache, comp_cache)}
18191844
end
18201845

1821-
defp optional_computation(name, args, expr, %{builder: %Function{}} = state, cache) do
1822-
%Function{module: module, name: name} = subbuilder(state.builder, name)
1846+
defp block_subfunction_description(%module{} = _) do
1847+
module
1848+
|> Atom.to_string()
1849+
|> String.replace(".", "_")
1850+
end
1851+
1852+
defp block_computation(description, args, expr, %{builder: %Function{}} = state, cache) do
1853+
%Function{module: module, name: name} = subbuilder(state.builder, description)
18231854

18241855
arg_typespecs = Enum.map(args, &Value.get_typespec/1)
18251856
out_typespecs = container_to_typespecs(expr)

exla/mix.lock

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
"makeup_erlang": {:hex, :makeup_erlang, "1.0.3", "4252d5d4098da7415c390e847c814bad3764c94a814a0b4245176215615e1035", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "953297c02582a33411ac6208f2c6e55f0e870df7f80da724ed613f10e6706afd"},
1212
"nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"},
1313
"nimble_pool": {:hex, :nimble_pool, "1.0.0", "5eb82705d138f4dd4423f69ceb19ac667b3b492ae570c9f5c900bb3d2f50a847", [:mix], [], "hexpm", "80be3b882d2d351882256087078e1b1952a28bf98d0a287be87e4a24a710b67a"},
14-
"nx": {:hex, :nx, "0.11.0", "d37723dbd6cfa274a5def6d6664f5680c32e2eb8a1ce25ec6d91751967fa0abf", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "36157b21239aeb251d6cbac23eb0eb3495a5e1e0cbc2e6df16afd2ede1575205"},
1514
"statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"},
1615
"telemetry": {:hex, :telemetry, "1.4.1", "ab6de178e2b29b58e8256b92b382ea3f590a47152ca3651ea857a6cae05ac423", [:rebar3], [], "hexpm", "2172e05a27531d3d31dd9782841065c50dd5c3c7699d95266b2edd54c2dafa1c"},
1716
"xla": {:hex, :xla, "0.10.0", "41121e9f011456242d3a79b9289910ce43419be0b0e7ebe67cc1292c6b3f232f", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f57d91aea6e661b52bf12239316c598679e9170628122bbd941235f040122bc6"},

exla/test/exla/nx_linalg_doctest_test.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ defmodule EXLA.NxLinAlgDoctestTest do
2525
least_squares: 3,
2626
determinant: 1,
2727
matrix_power: 2,
28-
lu: 2,
28+
lu: 1,
2929
qr: 2
3030
]
3131

0 commit comments

Comments
 (0)