From 014eb649addc5eeb1173fabeaa2f4977d6accb34 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 20 Apr 2026 19:53:24 -0300 Subject: [PATCH 01/11] Refactor EXLA block lowering through EXLA.CustomCall protocol --- exla/lib/exla/custom_call.ex | 193 ++++++++++++++++++++ exla/lib/exla/defn.ex | 329 +++++++---------------------------- 2 files changed, 255 insertions(+), 267 deletions(-) create mode 100644 exla/lib/exla/custom_call.ex diff --git a/exla/lib/exla/custom_call.ex b/exla/lib/exla/custom_call.ex new file mode 100644 index 0000000000..997ad4d47b --- /dev/null +++ b/exla/lib/exla/custom_call.ex @@ -0,0 +1,193 @@ +defprotocol EXLA.CustomCall do + @moduledoc """ + Protocol used by `EXLA.Defn` to lower specific `Nx.block/4` tags natively + instead of compiling the fallback callback. + + Implementations receive the block tag struct, the output template (`out`), + the already-recursed MLIR `EXLA.MLIR.Value` arguments and the active + `EXLA.Client`. + """ + + @fallback_to_any true + + @doc """ + Returns `true` when EXLA should lower the block natively via `call/4`. + + When it returns `false`, `EXLA.Defn` falls back to compiling the block's + default callback implementation. + """ + def apply?(struct, out, args, client) + + @fallback_to_any true + + @doc """ + Lowers the block natively. + + Must return the list of `EXLA.MLIR.Value`s (or a single value) that + represents the block result, matching the shape of `out`. 
+ """ + def call(struct, out, args, client) +end + +defimpl EXLA.CustomCall, for: Any do + alias EXLA.MLIR.Value + alias EXLA.Defn + + # --- apply?/4 --- + + def apply?( + %Nx.Block.LinAlg.QR{}, + {%{type: {q_type_kind, _}}, _r}, + _args, + client + ) do + q_type_kind != :c and client.platform == :host + end + + def apply?( + %Nx.Block.LinAlg.Eigh{}, + {%{type: {eval_type_kind, _}}, %{type: {evec_type_kind, _}}}, + _args, + client + ) do + eval_type_kind != :c and evec_type_kind != :c and client.platform == :host + end + + def apply?(%Nx.Block.Take{}, _out, _args, _client), do: true + def apply?(%Nx.Block.TopK{}, _out, _args, _client), do: true + def apply?(%Nx.Block.FFT2{}, _out, _args, _client), do: true + def apply?(%Nx.Block.IFFT2{}, _out, _args, _client), do: true + def apply?(%Nx.Block.RFFT{}, _out, _args, _client), do: true + def apply?(%Nx.Block.IRFFT{}, _out, _args, _client), do: true + + def apply?(_, _, _, _), do: false + + # --- call/4 --- + + def call(%Nx.Block.LinAlg.QR{}, {q_expr, r_expr}, [tensor], _client) do + tensor = + if Defn.op_type(tensor) != q_expr.type do + Defn.to_type(tensor, q_expr.type) + else + tensor + end + + {q, r} = Value.qr(tensor, Defn.expr_to_typespec(q_expr), Defn.expr_to_typespec(r_expr)) + [q, r] + end + + def call( + %Nx.Block.LinAlg.Eigh{}, + {eigenvals_expr, eigenvecs_expr}, + [tensor], + _client + ) do + # Eigen only supports f32/f64, so promote to the smallest floating type + # wide enough to represent the requested output. 
+ out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32}) + + tensor = + if Defn.op_type(tensor) != out_type do + Defn.to_type(tensor, out_type) + else + tensor + end + + {eigenvals, eigenvecs} = + Value.eigh( + tensor, + Defn.expr_to_typespec(%{eigenvals_expr | type: out_type}), + Defn.expr_to_typespec(%{eigenvecs_expr | type: out_type}) + ) + + [ + Defn.to_type(eigenvals, eigenvals_expr.type), + Defn.to_type(eigenvecs, eigenvecs_expr.type) + ] + end + + def call(%Nx.Block.Take{axis: axis}, expr, [tensor, indices], _client) do + tensor_shape = Defn.op_shape(tensor) + tensor_rank = tuple_size(tensor_shape) + indices_rank = indices |> Defn.op_shape() |> tuple_size() + result_rank = tensor_rank - 1 + indices_rank + + index_vector_dim = indices_rank + slice_sizes = tensor_shape |> put_elem(axis, 1) |> Tuple.to_list() + + {left, right} = result_rank |> Defn.axes_for_rank() |> Enum.split(axis) + offset_dims = left ++ Enum.drop(right, indices_rank) + + collapsed_slice_dims = [axis] + start_index_map = [axis] + + Value.gather( + tensor, + indices, + index_vector_dim, + slice_sizes, + offset_dims, + collapsed_slice_dims, + start_index_map, + Defn.expr_to_typespec(expr) + ) + end + + def call(%Nx.Block.TopK{k: k}, {values, idx}, [tensor], _client) do + typespecs = [Defn.expr_to_typespec(values), Defn.expr_to_typespec(idx)] + Value.top_k(tensor, k, typespecs) + end + + def call(%Nx.Block.FFT2{} = struct, expr, [tensor], _client) do + Defn.fft2(&Value.fft(&1, :fft, &2, &3), [tensor, fft2_opts(struct)], expr) + end + + def call(%Nx.Block.IFFT2{} = struct, expr, [tensor], _client) do + Defn.fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, fft2_opts(struct)], expr) + end + + def call(%Nx.Block.RFFT{} = struct, expr, [tensor], _client) do + # expr.type is complex; input tensor is real. 
+ input_type = Nx.Type.to_real(expr.type) + + Defn.fft( + &Value.fft(&1, :rfft, &2, &3), + input_type, + expr.type, + [tensor, fft_opts(struct)], + expr + ) + end + + def call(%Nx.Block.IRFFT{} = struct, expr, [tensor], _client) do + # expr.type is real; input tensor is complex. The expected input length is + # div(n, 2) + 1 (pad_n) while the output length is n (fft_n). + n = struct.length + input_type = Nx.Type.to_complex(expr.type) + + Defn.fft( + &Value.fft(&1, :irfft, &2, &3), + input_type, + expr.type, + div(n, 2) + 1, + [tensor, fft_opts(struct)], + expr + ) + end + + def call(struct, _out, _args, _client) do + raise ArgumentError, + "EXLA.CustomCall.call/4 is not implemented for #{inspect(struct)}. " <> + "Did you forget to guard with EXLA.CustomCall.apply?/4?" + end + + defp fft_opts(%{length: length, axis: axis, eps: eps}) do + opts = [length: length, axis: axis] + if eps, do: Keyword.put(opts, :eps, eps), else: opts + end + + defp fft2_opts(%{lengths: lengths, axes: axes, eps: eps}) do + opts = [lengths: lengths, axes: axes] + if eps, do: Keyword.put(opts, :eps, eps), else: opts + end +end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 602b4972c6..aafbca1cd4 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -602,261 +602,16 @@ defmodule EXLA.Defn do defp cached_recur_operator( :block, - %T{ - data: %Expr{ - args: [ - %Nx.Block.LinAlg.QR{}, - [tensor], - {%{type: {type_kind, _}} = q_expr, r_expr}, - _callback - ] - } - }, - %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, - cache - ) - when type_kind != :c do - # We match only on platform: :host for MLIR, as we want to support - # QR-on-cpu as a custom call only in this case - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - tensor = - if op_type(tensor) != q_expr.type do - to_type(tensor, q_expr.type) - else - tensor - end - - {q, r} = Value.qr(tensor, expr_to_typespec(q_expr), expr_to_typespec(r_expr)) - {[q, 
r], cache} - end - - defp cached_recur_operator( - :block, - %T{ - data: %Expr{ - args: [ - %Nx.Block.LinAlg.Eigh{}, - [tensor], - {%{type: {evec_type_kind, _}} = eigenvals_expr, - %{type: {eval_type_kind, _}} = eigenvecs_expr}, - _callback - ] - } - }, - %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, - cache - ) - when evec_type_kind != :c and eval_type_kind != :c do - # We match only on platform: :host for MLIR, as we want to support - # eigh-on-cpu as a custom call only in this case - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - # convert to float and ensure that we're either using f32 or f64, because Eigen - # only supports f32 and f64 easily. - out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32}) - - tensor = - if op_type(tensor) != out_type do - to_type(tensor, out_type) - else - tensor - end - - {eigenvals, eigenvecs} = - Value.eigh( - tensor, - expr_to_typespec(%{eigenvals_expr | type: out_type}), - expr_to_typespec(%{eigenvecs_expr | type: out_type}) - ) - - {[to_type(eigenvals, eigenvals_expr.type), to_type(eigenvecs, eigenvecs_expr.type)], cache} - end - - defp cached_recur_operator( - :block, - %T{ - data: %Expr{ - args: [%Nx.Block.Take{axis: axis}, [tensor, indices], expr, _callback] - } - }, - state, - cache - ) do - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - {indices, cache} = recur_operator(indices, state, cache) |> unwrap_single_tensor!() - - tensor_rank = tensor |> op_shape() |> tuple_size() - indices_rank = indices |> op_shape() |> tuple_size() - result_rank = tensor_rank - 1 + indices_rank - - index_vector_dim = indices_rank - slice_sizes = tensor |> op_shape() |> put_elem(axis, 1) |> Tuple.to_list() - - {left, right} = result_rank |> axes_for_rank() |> Enum.split(axis) - offset_dims = left ++ Enum.drop(right, indices_rank) - - collapsed_slice_dims = [axis] - start_index_map = [axis] - - result = - Value.gather( - 
tensor, - indices, - index_vector_dim, - slice_sizes, - offset_dims, - collapsed_slice_dims, - start_index_map, - expr_to_typespec(expr) - ) - - {result, cache} - end - - defp cached_recur_operator( - :block, - %T{data: %Expr{args: [%Nx.Block.TopK{k: k}, [tensor], expr, _callback]}}, - state, - cache - ) do - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - {values, idx} = expr - typespecs = [expr_to_typespec(values), expr_to_typespec(idx)] - results = Value.top_k(tensor, k, typespecs) - {results, cache} - end - - defp cached_recur_operator( - :block, - %T{data: %Expr{args: [%Nx.Block.FFT2{} = fft2_struct, [tensor], expr, _callback]}}, - state, - cache - ) do - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - opts = [lengths: fft2_struct.lengths, axes: fft2_struct.axes] - - opts = - if eps = fft2_struct.eps do - Keyword.put(opts, :eps, eps) - else - opts - end - - {fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], expr, state), cache} - end - - defp cached_recur_operator( - :block, - %T{data: %Expr{args: [%Nx.Block.IFFT2{} = ifft2_struct, [tensor], expr, _callback]}}, - state, - cache - ) do - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - opts = [lengths: ifft2_struct.lengths, axes: ifft2_struct.axes] - - opts = - if eps = ifft2_struct.eps do - Keyword.put(opts, :eps, eps) - else - opts - end - - {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache} - end - - defp cached_recur_operator( - :block, - %T{data: %Expr{args: [%Nx.Block.RFFT{} = rfft_struct, [tensor], expr, _callback]}}, - state, + %T{data: %Expr{args: [struct, in_args, out, _callback]}}, + %{client: client, builder: %Function{}} = state, cache ) do - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - opts = [length: rfft_struct.length, axis: rfft_struct.axis] - - opts = - if eps = rfft_struct.eps do - Keyword.put(opts, :eps, eps) - else - 
opts - end - - # expr.type is complex; input tensor is real - input_type = Nx.Type.to_real(expr.type) - - {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr, state), - cache} - end - - defp cached_recur_operator( - :block, - %T{data: %Expr{args: [%Nx.Block.IRFFT{} = irfft_struct, [tensor], expr, _callback]}}, - state, - cache - ) do - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - opts = [length: irfft_struct.length, axis: irfft_struct.axis] - - opts = - if eps = irfft_struct.eps do - Keyword.put(opts, :eps, eps) - else - opts - end - - # expr.type is real; input tensor is complex. - # pad_n = div(n,2)+1 (the expected input size), while fft_n = n (the output length). - n = irfft_struct.length - input_type = Nx.Type.to_complex(expr.type) - - {fft( - &Value.fft(&1, :irfft, &2, &3), - input_type, - expr.type, - div(n, 2) + 1, - [tensor, opts], - expr, - state - ), cache} - end - - defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do - [struct, in_args, expr, _callback] = args - %module{} = struct - {call_args, cache} = Enum.map_reduce(in_args, cache, &recur_operator(&1, state, &2)) - key = computation_key(module, [struct | call_args]) - - {call_body, cache} = - case cache do - %{^key => computation} -> - {computation, cache} - - %{} -> - {computation, cache} = - block_computation( - block_subfunction_description(struct), - call_args, - expr, - state, - cache - ) - {computation, Map.put(cache, key, computation)} - end - - if token = get_token(cache) do - typespecs = [Typespec.token() | container_to_typespecs(expr)] - [token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs) - {wrap_tuple_result(result, expr), update_token(cache, token)} + if EXLA.CustomCall.apply?(struct, out, call_args, client) do + {EXLA.CustomCall.call(struct, out, call_args, client), cache} else - typespecs = container_to_typespecs(expr) - result = Value.call(state.builder, 
call_args, call_body, typespecs) - {wrap_tuple_result(result, expr), cache} + default_block_implementation(struct, call_args, out, state, cache) end end @@ -998,6 +753,39 @@ defmodule EXLA.Defn do {to_operator(op, args, expr, state), cache} end + defp default_block_implementation(struct, call_args, expr, state, cache) do + %module{} = struct + key = computation_key(module, [struct | call_args]) + + {call_body, cache} = + case cache do + %{^key => computation} -> + {computation, cache} + + %{} -> + {computation, cache} = + block_computation( + block_subfunction_description(struct), + call_args, + expr, + state, + cache + ) + + {computation, Map.put(cache, key, computation)} + end + + if token = get_token(cache) do + typespecs = [Typespec.token() | container_to_typespecs(expr)] + [token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs) + {wrap_tuple_result(result, expr), update_token(cache, token)} + else + typespecs = container_to_typespecs(expr) + result = Value.call(state.builder, call_args, call_body, typespecs) + {wrap_tuple_result(result, expr), cache} + end + end + ## to_operator creation defp to_operator(:constant, [constant], ans, state) do @@ -1289,11 +1077,11 @@ defmodule EXLA.Defn do apply(Value, op, [to_type(arg, type), expr_to_typespec(ans)]) end - defp to_operator(:fft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out, state) + defp to_operator(:fft, [%Value{} | _] = args, out, _state), + do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out) - defp to_operator(:ifft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out, state) + defp to_operator(:ifft, [%Value{} | _] = args, out, _state), + do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out) defp to_operator(:is_nan, [%Value{} = arg], out, _state), do: Value.is_nan(arg, expr_to_typespec(out)) @@ -1618,7 +1406,8 @@ defmodule 
EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end - defp fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans, state) do + @doc false + def fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans) do fft_n = opts[:length] pad_n = pad_n || fft_n axis = opts[:axis] @@ -1627,7 +1416,7 @@ defmodule EXLA.Defn do shape = op_shape(tensor) m = elem(shape, axis) - tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type, state) + tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type) last_axis = tuple_size(shape) - 1 @@ -1662,7 +1451,8 @@ defmodule EXLA.Defn do end end - defp fft2(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do + @doc false + def fft2(exla_op, [%Value{} = tensor, opts], %{type: type} = ans) do [l1, l2] = lengths = opts[:lengths] [ax1, ax2] = axes = opts[:axes] output_type = Nx.Type.to_complex(type) @@ -1672,8 +1462,8 @@ defmodule EXLA.Defn do m1 = elem(shape, ax1) m2 = elem(shape, ax2) - tensor = fft_pad_or_slice(tensor, m1, l1, ax1, shape, output_type, state) - tensor = fft_pad_or_slice(tensor, m2, l2, ax2, op_shape(tensor), output_type, state) + tensor = fft_pad_or_slice(tensor, m1, l1, ax1, shape, output_type) + tensor = fft_pad_or_slice(tensor, m2, l2, ax2, op_shape(tensor), output_type) last_axis = tuple_size(shape) - 1 penultimate_axis = last_axis - 1 @@ -1701,7 +1491,7 @@ defmodule EXLA.Defn do end end - defp fft_pad_or_slice(tensor, m, n, axis, shape, output_type, state) do + defp fft_pad_or_slice(%Value{function: builder} = tensor, m, n, axis, shape, output_type) do cond do m == n -> tensor @@ -1726,7 +1516,7 @@ defmodule EXLA.Defn do zero_value = if Nx.Type.complex?(output_type), do: Complex.new(0), else: 0 zero = - Value.constant(state.builder, [zero_value], Typespec.tensor(output_type, {})) + Value.constant(builder, [zero_value], Typespec.tensor(output_type, {})) padding_config = {0, 0, 0} @@ -2251,19 
+2041,23 @@ defmodule EXLA.Defn do defp count_up(0, _n), do: [] defp count_up(i, n), do: [n | count_up(i - 1, n + 1)] - defp axes_for_rank(0), do: [] + @doc false + def axes_for_rank(0), do: [] - defp axes_for_rank(rank) do + def axes_for_rank(rank) do Enum.to_list(0..(rank - 1)) end ## Op Helpers - defp op_type(%Value{} = op), do: Value.get_typespec(op).type + @doc false + def op_type(%Value{} = op), do: Value.get_typespec(op).type - defp op_shape(%Value{} = op), do: Value.get_typespec(op).shape + @doc false + def op_shape(%Value{} = op), do: Value.get_typespec(op).shape - defp to_type(%Value{} = op, type) do + @doc false + def to_type(%Value{} = op, type) do typespec = Value.get_typespec(op) if typespec.type == type do @@ -2357,7 +2151,8 @@ defmodule EXLA.Defn do |> Enum.reduce(shape, &Tuple.delete_at(&2, &1)) end - defp expr_to_typespec(expr) do + @doc false + def expr_to_typespec(expr) do Typespec.tensor(expr.type, expr.shape) end From 5db679f9efdc273d1062dc1d49cd2ac6b5ee40a3 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Fri, 24 Apr 2026 18:05:55 -0300 Subject: [PATCH 02/11] update EXLA.CustomCall to handle C-backed Nx.block tags (QR, Eigh) --- exla/lib/exla/custom_call.ex | 107 ++++++----------------------- exla/lib/exla/defn.ex | 126 +++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 88 deletions(-) diff --git a/exla/lib/exla/custom_call.ex b/exla/lib/exla/custom_call.ex index 997ad4d47b..3bca4285fa 100644 --- a/exla/lib/exla/custom_call.ex +++ b/exla/lib/exla/custom_call.ex @@ -1,11 +1,21 @@ defprotocol EXLA.CustomCall do @moduledoc """ - Protocol used by `EXLA.Defn` to lower specific `Nx.block/4` tags natively - instead of compiling the fallback callback. + Protocol used by `EXLA.Defn` to lower specific `Nx.block/4` tags that are + implemented as **XLA/StableHLO custom calls into native (C/C++) code** — + the same pipeline as `EXLA.MLIR.Value` helpers such as `qr/3` and `eigh/3`. 
+ + Other blocks (for example gather-based take or plain StableHLO FFT) stay + inlined in `EXLA.Defn` so this protocol stays focused on what those paths + share: `stablehlo.custom_call` plus registration of the callee. Implementations receive the block tag struct, the output template (`out`), the already-recursed MLIR `EXLA.MLIR.Value` arguments and the active `EXLA.Client`. + + Built-in lowerings for those tags live in a single `defimpl ..., for: Any` + module (see comment there). Applications and libraries can still supply a + **more specific** `defimpl EXLA.CustomCall, for: TheirStruct` — Elixir will + use that instead of the `Any` fallback when the block tag matches. """ @fallback_to_any true @@ -29,6 +39,13 @@ defprotocol EXLA.CustomCall do def call(struct, out, args, client) end +# Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live +# in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the +# protocol, applications and libraries can define their own +# `defimpl EXLA.CustomCall, for: SomeStruct` — protocol dispatch uses that +# implementation instead of this fallback when the block tag matches (you can +# also target a built-in struct such as `Nx.Block...` from your app if needed). 
+# defimpl EXLA.CustomCall, for: Any do alias EXLA.MLIR.Value alias EXLA.Defn @@ -53,13 +70,6 @@ defimpl EXLA.CustomCall, for: Any do eval_type_kind != :c and evec_type_kind != :c and client.platform == :host end - def apply?(%Nx.Block.Take{}, _out, _args, _client), do: true - def apply?(%Nx.Block.TopK{}, _out, _args, _client), do: true - def apply?(%Nx.Block.FFT2{}, _out, _args, _client), do: true - def apply?(%Nx.Block.IFFT2{}, _out, _args, _client), do: true - def apply?(%Nx.Block.RFFT{}, _out, _args, _client), do: true - def apply?(%Nx.Block.IRFFT{}, _out, _args, _client), do: true - def apply?(_, _, _, _), do: false # --- call/4 --- @@ -106,88 +116,9 @@ defimpl EXLA.CustomCall, for: Any do ] end - def call(%Nx.Block.Take{axis: axis}, expr, [tensor, indices], _client) do - tensor_shape = Defn.op_shape(tensor) - tensor_rank = tuple_size(tensor_shape) - indices_rank = indices |> Defn.op_shape() |> tuple_size() - result_rank = tensor_rank - 1 + indices_rank - - index_vector_dim = indices_rank - slice_sizes = tensor_shape |> put_elem(axis, 1) |> Tuple.to_list() - - {left, right} = result_rank |> Defn.axes_for_rank() |> Enum.split(axis) - offset_dims = left ++ Enum.drop(right, indices_rank) - - collapsed_slice_dims = [axis] - start_index_map = [axis] - - Value.gather( - tensor, - indices, - index_vector_dim, - slice_sizes, - offset_dims, - collapsed_slice_dims, - start_index_map, - Defn.expr_to_typespec(expr) - ) - end - - def call(%Nx.Block.TopK{k: k}, {values, idx}, [tensor], _client) do - typespecs = [Defn.expr_to_typespec(values), Defn.expr_to_typespec(idx)] - Value.top_k(tensor, k, typespecs) - end - - def call(%Nx.Block.FFT2{} = struct, expr, [tensor], _client) do - Defn.fft2(&Value.fft(&1, :fft, &2, &3), [tensor, fft2_opts(struct)], expr) - end - - def call(%Nx.Block.IFFT2{} = struct, expr, [tensor], _client) do - Defn.fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, fft2_opts(struct)], expr) - end - - def call(%Nx.Block.RFFT{} = struct, expr, [tensor], _client) 
do - # expr.type is complex; input tensor is real. - input_type = Nx.Type.to_real(expr.type) - - Defn.fft( - &Value.fft(&1, :rfft, &2, &3), - input_type, - expr.type, - [tensor, fft_opts(struct)], - expr - ) - end - - def call(%Nx.Block.IRFFT{} = struct, expr, [tensor], _client) do - # expr.type is real; input tensor is complex. The expected input length is - # div(n, 2) + 1 (pad_n) while the output length is n (fft_n). - n = struct.length - input_type = Nx.Type.to_complex(expr.type) - - Defn.fft( - &Value.fft(&1, :irfft, &2, &3), - input_type, - expr.type, - div(n, 2) + 1, - [tensor, fft_opts(struct)], - expr - ) - end - def call(struct, _out, _args, _client) do raise ArgumentError, "EXLA.CustomCall.call/4 is not implemented for #{inspect(struct)}. " <> "Did you forget to guard with EXLA.CustomCall.apply?/4?" end - - defp fft_opts(%{length: length, axis: axis, eps: eps}) do - opts = [length: length, axis: axis] - if eps, do: Keyword.put(opts, :eps, eps), else: opts - end - - defp fft2_opts(%{lengths: lengths, axes: axes, eps: eps}) do - opts = [lengths: lengths, axes: axes] - if eps, do: Keyword.put(opts, :eps, eps), else: opts - end end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index aafbca1cd4..05d7568675 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -600,6 +600,132 @@ defmodule EXLA.Defn do {fun_computation(args, expr, type, state), cache} end + # StableHLO-style lowering (gather, top_k, fft): not the C custom_call path; + # see `EXLA.CustomCall` for blocks that delegate to native CPU kernels. 
+ + defp cached_recur_operator( + :block, + %T{ + data: %Expr{ + args: [%Nx.Block.Take{axis: axis}, [tensor, indices], expr, _callback] + } + }, + state, + cache + ) do + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + {indices, cache} = recur_operator(indices, state, cache) |> unwrap_single_tensor!() + + tensor_rank = tensor |> op_shape() |> tuple_size() + indices_rank = indices |> op_shape() |> tuple_size() + result_rank = tensor_rank - 1 + indices_rank + + index_vector_dim = indices_rank + slice_sizes = tensor |> op_shape() |> put_elem(axis, 1) |> Tuple.to_list() + + {left, right} = result_rank |> axes_for_rank() |> Enum.split(axis) + offset_dims = left ++ Enum.drop(right, indices_rank) + + collapsed_slice_dims = [axis] + start_index_map = [axis] + + result = + Value.gather( + tensor, + indices, + index_vector_dim, + slice_sizes, + offset_dims, + collapsed_slice_dims, + start_index_map, + expr_to_typespec(expr) + ) + + {result, cache} + end + + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [%Nx.Block.TopK{k: k}, [tensor], expr, _callback]}}, + state, + cache + ) do + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + {values, idx} = expr + typespecs = [expr_to_typespec(values), expr_to_typespec(idx)] + results = Value.top_k(tensor, k, typespecs) + {results, cache} + end + + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [%Nx.Block.FFT2{} = fft2_struct, [tensor], expr, _callback]}}, + state, + cache + ) do + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + opts = [lengths: fft2_struct.lengths, axes: fft2_struct.axes] + opts = if eps = fft2_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts + + {fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], expr), cache} + end + + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [%Nx.Block.IFFT2{} = ifft2_struct, [tensor], expr, _callback]}}, + state, + cache + ) do + 
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + opts = [lengths: ifft2_struct.lengths, axes: ifft2_struct.axes] + opts = if eps = ifft2_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts + + {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr), cache} + end + + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [%Nx.Block.RFFT{} = rfft_struct, [tensor], expr, _callback]}}, + state, + cache + ) do + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + opts = [length: rfft_struct.length, axis: rfft_struct.axis] + opts = if eps = rfft_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts + + input_type = Nx.Type.to_real(expr.type) + + {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr), cache} + end + + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [%Nx.Block.IRFFT{} = irfft_struct, [tensor], expr, _callback]}}, + state, + cache + ) do + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + opts = [length: irfft_struct.length, axis: irfft_struct.axis] + opts = if eps = irfft_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts + + n = irfft_struct.length + input_type = Nx.Type.to_complex(expr.type) + + {fft( + &Value.fft(&1, :irfft, &2, &3), + input_type, + expr.type, + div(n, 2) + 1, + [tensor, opts], + expr + ), cache} + end + + # C-backed custom_call blocks (QR, Eigh, …): `EXLA.CustomCall`; else compile default callback. 
defp cached_recur_operator( :block, %T{data: %Expr{args: [struct, in_args, out, _callback]}}, From 4011521c06322999903e19cc5b7253c347aadfd2 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Fri, 24 Apr 2026 23:30:00 -0300 Subject: [PATCH 03/11] test(exla): add QR FFI alias plugin + dlopen NIF and MLIR/JIT coverage for custom_call targets --- exla/Makefile | 15 +++ exla/c_src/exla/exla.cc | 16 +++ .../exla_test_plugin/qr_alias_registration.cc | 16 +++ exla/lib/exla/mlir/value.ex | 17 ++++ exla/lib/exla/nif.ex | 1 + exla/mix.exs | 3 +- exla/test/exla/custom_call_alias_test.exs | 98 +++++++++++++++++++ exla/test/support/exla_test_qr_alias_block.ex | 37 +++++++ 8 files changed, 202 insertions(+), 1 deletion(-) create mode 100644 exla/c_src/exla_test_plugin/qr_alias_registration.cc create mode 100644 exla/test/exla/custom_call_alias_test.exs create mode 100644 exla/test/support/exla_test_qr_alias_block.ex diff --git a/exla/Makefile b/exla/Makefile index b875ef59f7..f6f43ed58e 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -86,6 +86,21 @@ else LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib' endif +# Optional test plugin: registers qr_cpu_custom_call_f32_exla_alias -> same +# handler as qr_cpu_custom_call_f32 (Mix sets BUILD_EXLA_TEST_PLUGIN=1 in :test). 
+TEST_PLUGIN_CC = c_src/exla_test_plugin/qr_alias_registration.cc +TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias_plugin.so + +ifneq ($(BUILD_EXLA_TEST_PLUGIN),) +ifneq ($(BUILD_EXLA_TEST_PLUGIN),0) +$(EXLA_SO): $(TEST_PLUGIN_SO) +endif +endif + +$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) + @ mkdir -p $(dir $@) + $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) + $(EXLA_SO): $(EXLA_CACHE_SO) @ mkdir -p $(PRIV_DIR) @ mkdir -p $(PRIV_DIR)/xla_extension diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index c4e9085833..ce68245aa8 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -535,6 +537,20 @@ fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type, FINE_NIF(load_pjrt_plugin, 0); +// Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run. +// Used from tests (e.g. alias custom-call registration plugin). +fine::Ok<> dlopen_test_plugin(ErlNifEnv *env, std::string path) { + void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + const char *err = dlerror(); + throw std::invalid_argument(err ? err : "dlopen failed"); + } + (void)handle; + return fine::Ok(); +} + +FINE_NIF(dlopen_test_plugin, 0); + int64_t get_device_count(ErlNifEnv *env, fine::ResourcePtr client) { return client->client()->device_count(); } diff --git a/exla/c_src/exla_test_plugin/qr_alias_registration.cc b/exla/c_src/exla_test_plugin/qr_alias_registration.cc new file mode 100644 index 0000000000..1e80045e54 --- /dev/null +++ b/exla/c_src/exla_test_plugin/qr_alias_registration.cc @@ -0,0 +1,16 @@ +// Test-only shared library: registers an alias FFI name that reuses the +// existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so. +// +// Built when BUILD_EXLA_TEST_PLUGIN=1 (Mix test). 
Load with RTLD_GLOBAL via +// EXLA.NIF.dlopen_test_plugin/1 before compiling or running graphs that emit +// the alias call_target_name. + +#include "xla/ffi/api/api.h" +#include "xla/ffi/ffi_api.h" + +namespace ffi = xla::ffi; + +extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias", + "Host", qr_cpu_custom_call_f32); diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 9d028ff6dd..a15d9f6bc7 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -785,6 +785,23 @@ defmodule EXLA.MLIR.Value do {q, r} end + @doc false + def qr_with_call_target(%Value{function: func} = value, q_typespec, r_typespec, call_target_name) + when is_binary(call_target_name) do + operands = [value] + result_types = typespecs_to_mlir_types([q_typespec, r_typespec]) + + attributes = [ + call_target_name: attr_string(call_target_name), + api_version: attr_i32(4) + ] + + [q, r] = + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + + {q, r} + end + def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do %{type: op_type} = get_typespec(value) diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 2a0a99f1ef..df0af731cf 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -79,6 +79,7 @@ defmodule EXLA.NIF do def get_tpu_client(), do: err!() def get_c_api_client(_device_type), do: err!() def load_pjrt_plugin(_device_type, _library_path), do: err!() + def dlopen_test_plugin(_path), do: err!() def get_device_count(_client), do: err!() def get_supported_platforms, do: err!() def run_cpu(_executable, _arguments, _device_id, _callback_server_pid), do: err!() diff --git a/exla/mix.exs b/exla/mix.exs index 22c8ee9460..404488aeef 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -33,7 +33,8 @@ defmodule EXLA.MixProject do "FINE_INCLUDE_DIR" => Fine.include_dir(), 
"MIX_BUILD_EMBEDDED" => "#{Mix.Project.config()[:build_embedded]}", "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv, - "EXLA_VERSION" => "#{@version}" + "EXLA_VERSION" => "#{@version}", + "BUILD_EXLA_TEST_PLUGIN" => if(Mix.env() == :test, do: "1", else: "0") } end, make_args: make_args diff --git a/exla/test/exla/custom_call_alias_test.exs b/exla/test/exla/custom_call_alias_test.exs new file mode 100644 index 0000000000..fdfe052084 --- /dev/null +++ b/exla/test/exla/custom_call_alias_test.exs @@ -0,0 +1,98 @@ +defmodule EXLA.CustomCallAliasTest do + use EXLA.Case, async: false + + import Nx.Defn + + alias EXLA.Test.QRAliasBlock + + defmodule BuiltinFun do + import Nx.Defn + + defn qr(t), do: Nx.LinAlg.qr(t) + end + + defmodule Fun do + import Nx.Defn + + alias EXLA.Test.QRAliasBlock + + defn qr_alias_fn(t) do + q_out = Nx.template({3, 3}, {:f, 32}) + r_out = Nx.template({3, 4}, {:f, 32}) + + Nx.block(%QRAliasBlock{}, [t], {q_out, r_out}, fn _, t2 -> + Nx.LinAlg.qr(t2, mode: :reduced) + end) + end + end + + @plugin_relative ~c"test/exla_qr_alias_plugin.so" + + defp plugin_path do + :filename.join(:code.priv_dir(:exla), @plugin_relative) + end + + defp mlir_via_jit_apply!(fun, args) when is_function(fun) and is_list(args) do + try do + Nx.Defn.jit_apply(fun, args, + compiler: EXLA, + module_compilation: :to_mlir + ) + catch + :throw, {:mlir_module, ref, used_inputs, output_container} -> + %{ + mlir_module: EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref}), + used_inputs: used_inputs, + output_container: output_container + } + end + end + + defp load_plugin! do + path = List.to_string(plugin_path()) + + unless File.exists?(path) do + flunk(""" + Missing #{path}. Build EXLA with MIX_ENV=test so the alias plugin is compiled \ + (see Makefile target exla_qr_alias_plugin.so). 
+ """) + end + + case EXLA.NIF.dlopen_test_plugin(path) do + :ok -> + :ok + + other -> + flunk("dlopen_test_plugin(#{path}) expected :ok, got: #{inspect(other)}") + end + end + + test "builtin QR lowering includes qr_cpu_custom_call_f32 in MLIR" do + arg = Nx.iota({3, 4}, type: {:f, 32}) + assert %{mlir_module: mlir} = mlir_via_jit_apply!(&BuiltinFun.qr/1, [arg]) + + assert mlir =~ "@qr_cpu_custom_call_f32(" + refute mlir =~ "qr_cpu_custom_call_f32_exla_alias" + end + + test "QR alias plugin: MLIR uses alias name and not the builtin target string" do + load_plugin!() + + arg = Nx.iota({3, 4}, type: {:f, 32}) + assert %{mlir_module: mlir} = mlir_via_jit_apply!(&Fun.qr_alias_fn/1, [arg]) + + assert mlir =~ "qr_cpu_custom_call_f32_exla_alias" + refute mlir =~ "@qr_cpu_custom_call_f32(" + end + + test "QR alias plugin: JIT result matches builtin QR" do + load_plugin!() + + t = Nx.iota({3, 4}, type: {:f, 32}) + exp = EXLA.jit(fn t -> Nx.LinAlg.qr(t) end).(t) + act = EXLA.jit(&Fun.qr_alias_fn/1).(t) + + assert Nx.all_close(elem(exp, 0), elem(act, 0), atol: 1.0e-4, rtol: 1.0e-4) + assert Nx.all_close(elem(exp, 1), elem(act, 1), atol: 1.0e-4, rtol: 1.0e-4) + end +end diff --git a/exla/test/support/exla_test_qr_alias_block.ex b/exla/test/support/exla_test_qr_alias_block.ex new file mode 100644 index 0000000000..c7b82a6aaf --- /dev/null +++ b/exla/test/support/exla_test_qr_alias_block.ex @@ -0,0 +1,37 @@ +# Test-only block tag + `EXLA.CustomCall` impl used to emit a StableHLO custom_call +# with `call_target_name` `qr_cpu_custom_call_f32_exla_alias` (registered by +# `priv/test/exla_qr_alias_plugin.so` when built with `MIX_ENV=test`). 
+defmodule EXLA.Test.QRAliasBlock do + @moduledoc false + defstruct [] +end + +defimpl EXLA.CustomCall, for: EXLA.Test.QRAliasBlock do + alias EXLA.MLIR.Value + alias EXLA.Defn + + def apply?(_, {%{type: {q_kind, _}}, _r}, _args, client) do + q_kind != :c and client.platform == :host + end + + def apply?(_, _, _, _), do: false + + def call(_, {q_expr, r_expr}, [tensor], _client) do + tensor = + if Defn.op_type(tensor) != q_expr.type do + Defn.to_type(tensor, q_expr.type) + else + tensor + end + + {q, r} = + Value.qr_with_call_target( + tensor, + Defn.expr_to_typespec(q_expr), + Defn.expr_to_typespec(r_expr), + "qr_cpu_custom_call_f32_exla_alias" + ) + + [q, r] + end +end From 31520034c10468c8919e1ccac6c81a5036925a1f Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 27 Apr 2026 17:22:39 -0300 Subject: [PATCH 04/11] fix formatting --- exla/lib/exla/mlir/value.ex | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index a15d9f6bc7..77afb4fa23 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -786,7 +786,12 @@ defmodule EXLA.MLIR.Value do end @doc false - def qr_with_call_target(%Value{function: func} = value, q_typespec, r_typespec, call_target_name) + def qr_with_call_target( + %Value{function: func} = value, + q_typespec, + r_typespec, + call_target_name + ) when is_binary(call_target_name) do operands = [value] result_types = typespecs_to_mlir_types([q_typespec, r_typespec]) From f2aa55876a736b4ae7a3d16bcc20b1fb0af32c34 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 27 Apr 2026 18:15:35 -0300 Subject: [PATCH 05/11] CustomCall now has only one callback + add documentation --- exla/lib/exla.ex | 7 + exla/lib/exla/custom_call.ex | 158 +++++++++++------- exla/lib/exla/defn.ex | 10 +- exla/test/support/exla_test_qr_alias_block.ex | 11 +- 4 files changed, 116 insertions(+), 
70 deletions(-) diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index b558b0fd83..9be5dc13db 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -75,6 +75,13 @@ defmodule EXLA do * `:highest` - Slowest but most accurate. Performs computations in float32 or float64 as applicable + ## Native custom calls (`EXLA.CustomCall`) + + Some `Nx.block/4` tags can be lowered to XLA **custom calls** (StableHLO plus + a registered native handler). Implement the `EXLA.CustomCall` protocol for + your block tag struct; see `EXLA.CustomCall` for the `call/4` contract, + including returning `:skip` to fall back to the block's default Elixir callback. + ## Clients The `EXLA` library uses a client for compiling and executing code. diff --git a/exla/lib/exla/custom_call.ex b/exla/lib/exla/custom_call.ex index 3bca4285fa..c92ba400e5 100644 --- a/exla/lib/exla/custom_call.ex +++ b/exla/lib/exla/custom_call.ex @@ -1,40 +1,101 @@ defprotocol EXLA.CustomCall do @moduledoc """ - Protocol used by `EXLA.Defn` to lower specific `Nx.block/4` tags that are - implemented as **XLA/StableHLO custom calls into native (C/C++) code** — - the same pipeline as `EXLA.MLIR.Value` helpers such as `qr/3` and `eigh/3`. - - Other blocks (for example gather-based take or plain StableHLO FFT) stay - inlined in `EXLA.Defn` so this protocol stays focused on what those paths - share: `stablehlo.custom_call` plus registration of the callee. - - Implementations receive the block tag struct, the output template (`out`), - the already-recursed MLIR `EXLA.MLIR.Value` arguments and the active - `EXLA.Client`. - - Built-in lowerings for those tags live in a single `defimpl ..., for: Any` - module (see comment there). Applications and libraries can still supply a - **more specific** `defimpl EXLA.CustomCall, for: TheirStruct` — Elixir will - use that instead of the `Any` fallback when the block tag matches. 
- """ + Extension point for lowering selected `Nx.block/4` tags to **XLA custom calls** + (`stablehlo.custom_call` in MLIR), the same style as helpers on + `EXLA.MLIR.Value` such as `qr/3` and `eigh/3`. - @fallback_to_any true + Other blocks (for example gather-based `take` or FFT) are lowered inline in + `EXLA.Defn` and do not use this protocol. - @doc """ - Returns `true` when EXLA should lower the block natively via `call/4`. + ## When `EXLA.Defn` calls it + + During compilation with `compiler: EXLA`, when the builder is an MLIR + `EXLA.MLIR.Function`, each `Nx.block(tag, inputs, outputs, fn ... end)` is + passed here: `EXLA.Defn` invokes `call(tag, outputs_template, lowered_inputs, client)`. + + If `call/4` returns `:skip`, EXLA compiles the block's **default callback** + (the anonymous function body) instead of emitting a custom call. + + ## `call/4` arguments + + * `struct` — the **tag** passed as the first argument to `Nx.block/4` + (your own `defstruct` or an existing tag such as `%Nx.Block.LinAlg.QR{}`). + + * `out` — the **output template** tuple passed to `Nx.block/4` (expression + metadata for shapes and types, not runtime tensors). + + * `args` — list of already-lowered **operands** as `EXLA.MLIR.Value`s, in + the same order as `inputs` in `Nx.block/4`. + + * `client` — the active `EXLA.Client` (use e.g. `client.platform` to gate + host-only lowerings). + + ## Return value + + * **Success** — return a list of `EXLA.MLIR.Value` (or a single value) that + matches the block result shape implied by `out`. + + * **`:skip`** — this implementation does not apply (unsupported type, + non-host platform, wrong arity, etc.). The default block implementation is + used instead. + + ## Dispatch + + The protocol uses `@fallback_to_any true`. Built-in lowerings for known tags + live in `defimpl EXLA.CustomCall, for: Any`. 
Your application or dependency can + add `defimpl EXLA.CustomCall, for: YourStruct`; that implementation is chosen + whenever the block tag is `%YourStruct{}`, instead of the `Any` fallback. + + ## Native handlers + + Emitting a custom call in MLIR is only half of the story: the **target name** + must be registered with XLA on the relevant platform (typically via a native + library loaded into the process). That registration is **not** configured + through `config :exla, ...`; you load or link the native code by the same + means you would for any other NIF-backed extension. + + ## Example + + defmodule MyApp.CustomQrTag do + defstruct [] + end + + defimpl EXLA.CustomCall, for: MyApp.CustomQrTag do + alias EXLA.Defn + alias EXLA.MLIR.Value - When it returns `false`, `EXLA.Defn` falls back to compiling the block's - default callback implementation. + def call(_tag, {q_expr, r_expr}, [tensor], %{platform: :host}) do + tensor = + if Defn.op_type(tensor) != q_expr.type do + Defn.to_type(tensor, q_expr.type) + else + tensor + end + + {q, r} = + Value.qr(tensor, Defn.expr_to_typespec(q_expr), Defn.expr_to_typespec(r_expr)) + + [q, r] + end + + def call(_, _, _, _), do: :skip + end + + Then use `Nx.block(%MyApp.CustomQrTag{}, ...)` inside a `defn` compiled with + `compiler: EXLA`. """ - def apply?(struct, out, args, client) @fallback_to_any true @doc """ - Lowers the block natively. + Attempts to lower the block natively. - Must return the list of `EXLA.MLIR.Value`s (or a single value) that - represents the block result, matching the shape of `out`. + Returns a list of `EXLA.MLIR.Value`s (or a single value) that represents the + block result, matching the shape of `out`. + + Returns `:skip` when this implementation does not apply (wrong types, + platform, arity, etc.). `EXLA.Defn` then compiles the block's default callback + instead. 
""" def call(struct, out, args, client) end @@ -47,34 +108,13 @@ end # also target a built-in struct such as `Nx.Block...` from your app if needed). # defimpl EXLA.CustomCall, for: Any do + @moduledoc false + alias EXLA.MLIR.Value alias EXLA.Defn - # --- apply?/4 --- - - def apply?( - %Nx.Block.LinAlg.QR{}, - {%{type: {q_type_kind, _}}, _r}, - _args, - client - ) do - q_type_kind != :c and client.platform == :host - end - - def apply?( - %Nx.Block.LinAlg.Eigh{}, - {%{type: {eval_type_kind, _}}, %{type: {evec_type_kind, _}}}, - _args, - client - ) do - eval_type_kind != :c and evec_type_kind != :c and client.platform == :host - end - - def apply?(_, _, _, _), do: false - - # --- call/4 --- - - def call(%Nx.Block.LinAlg.QR{}, {q_expr, r_expr}, [tensor], _client) do + def call(%Nx.Block.LinAlg.QR{}, {%{type: {q_type_kind, _}} = q_expr, r_expr}, [tensor], client) + when q_type_kind != :c and client.platform == :host do tensor = if Defn.op_type(tensor) != q_expr.type do Defn.to_type(tensor, q_expr.type) @@ -88,10 +128,12 @@ defimpl EXLA.CustomCall, for: Any do def call( %Nx.Block.LinAlg.Eigh{}, - {eigenvals_expr, eigenvecs_expr}, + {%{type: {eval_type_kind, _}} = eigenvals_expr, + %{type: {evec_type_kind, _}} = eigenvecs_expr}, [tensor], - _client - ) do + client + ) + when eval_type_kind != :c and evec_type_kind != :c and client.platform == :host do # Eigen only supports f32/f64, so promote to the smallest floating type # wide enough to represent the requested output. out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32}) @@ -116,9 +158,7 @@ defimpl EXLA.CustomCall, for: Any do ] end - def call(struct, _out, _args, _client) do - raise ArgumentError, - "EXLA.CustomCall.call/4 is not implemented for #{inspect(struct)}. " <> - "Did you forget to guard with EXLA.CustomCall.apply?/4?" 
- end + def call(%Nx.Block.LinAlg.QR{}, _, _, _), do: :skip + def call(%Nx.Block.LinAlg.Eigh{}, _, _, _), do: :skip + def call(_, _, _, _), do: :skip end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 05d7568675..6dce31e7e9 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -734,10 +734,12 @@ defmodule EXLA.Defn do ) do {call_args, cache} = Enum.map_reduce(in_args, cache, &recur_operator(&1, state, &2)) - if EXLA.CustomCall.apply?(struct, out, call_args, client) do - {EXLA.CustomCall.call(struct, out, call_args, client), cache} - else - default_block_implementation(struct, call_args, out, state, cache) + case EXLA.CustomCall.call(struct, out, call_args, client) do + :skip -> + default_block_implementation(struct, call_args, out, state, cache) + + lowered -> + {lowered, cache} end end diff --git a/exla/test/support/exla_test_qr_alias_block.ex b/exla/test/support/exla_test_qr_alias_block.ex index c7b82a6aaf..b371adcfe5 100644 --- a/exla/test/support/exla_test_qr_alias_block.ex +++ b/exla/test/support/exla_test_qr_alias_block.ex @@ -10,13 +10,8 @@ defimpl EXLA.CustomCall, for: EXLA.Test.QRAliasBlock do alias EXLA.MLIR.Value alias EXLA.Defn - def apply?(_, {%{type: {q_kind, _}}, _r}, _args, client) do - q_kind != :c and client.platform == :host - end - - def apply?(_, _, _, _), do: false - - def call(_, {q_expr, r_expr}, [tensor], _client) do + def call(_, {%{type: {q_kind, _}} = q_expr, r_expr}, [tensor], client) + when q_kind != :c and client.platform == :host do tensor = if Defn.op_type(tensor) != q_expr.type do Defn.to_type(tensor, q_expr.type) @@ -34,4 +29,6 @@ defimpl EXLA.CustomCall, for: EXLA.Test.QRAliasBlock do [q, r] end + + def call(_, _, _, _), do: :skip end From 2d8e9c233cd924c590785a98f5f61bcf75fc3506 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:11:45 -0300 Subject: [PATCH 06/11] update based on polvalente comments --- exla/Makefile | 23 +-- 
exla/c_src/exla/exla.cc | 5 +- .../custom_calls.cc} | 7 +- exla/lib/exla/custom_call.ex | 140 ++++++++---------- exla/lib/exla/defn.ex | 28 +++- exla/lib/exla/mlir/value.ex | 63 ++++++-- exla/lib/exla/nif.ex | 2 +- exla/test/exla/custom_call_alias_test.exs | 10 +- exla/test/support/exla_test_qr_alias_block.ex | 30 +--- 9 files changed, 169 insertions(+), 139 deletions(-) rename exla/c_src/{exla_test_plugin/qr_alias_registration.cc => exla_test/custom_calls.cc} (70%) diff --git a/exla/Makefile b/exla/Makefile index f6f43ed58e..0183a3ea0e 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -21,6 +21,8 @@ EXLA_LIB_DIR = $(PRIV_DIR)/xla_extension/lib XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LIB) EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO) +.DEFAULT_GOAL := $(EXLA_SO) + # Build flags # # Note that XLA requires c++17, Fine as well @@ -86,22 +88,21 @@ else LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib' endif -# Optional test plugin: registers qr_cpu_custom_call_f32_exla_alias -> same -# handler as qr_cpu_custom_call_f32 (Mix sets BUILD_EXLA_TEST_PLUGIN=1 in :test). -TEST_PLUGIN_CC = c_src/exla_test_plugin/qr_alias_registration.cc -TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias_plugin.so - -ifneq ($(BUILD_EXLA_TEST_PLUGIN),) -ifneq ($(BUILD_EXLA_TEST_PLUGIN),0) -$(EXLA_SO): $(TEST_PLUGIN_SO) -endif -endif +# Optional test dylib: registers qr_cpu_custom_call_f32_exla_alias -> same +# handler as qr_cpu_custom_call_f32. Built only when MIX_ENV=test. 
+TEST_PLUGIN_CC = c_src/exla_test/custom_calls.cc +TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias.so $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) @ mkdir -p $(dir $@) $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) -$(EXLA_SO): $(EXLA_CACHE_SO) +EXLA_SO_DEPS = $(EXLA_CACHE_SO) +ifeq ($(MIX_ENV),test) +EXLA_SO_DEPS += $(TEST_PLUGIN_SO) +endif + +$(EXLA_SO): $(EXLA_SO_DEPS) @ mkdir -p $(PRIV_DIR) @ mkdir -p $(PRIV_DIR)/xla_extension @ if [ "${MIX_BUILD_EMBEDDED}" = "true" ]; then \ diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index ce68245aa8..34816b3b7e 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -538,8 +538,7 @@ fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type, FINE_NIF(load_pjrt_plugin, 0); // Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run. -// Used from tests (e.g. alias custom-call registration plugin). -fine::Ok<> dlopen_test_plugin(ErlNifEnv *env, std::string path) { +fine::Ok<> load_dylib(ErlNifEnv *env, std::string path) { void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL); if (handle == nullptr) { const char *err = dlerror(); @@ -549,7 +548,7 @@ fine::Ok<> dlopen_test_plugin(ErlNifEnv *env, std::string path) { return fine::Ok(); } -FINE_NIF(dlopen_test_plugin, 0); +FINE_NIF(load_dylib, 0); int64_t get_device_count(ErlNifEnv *env, fine::ResourcePtr client) { return client->client()->device_count(); diff --git a/exla/c_src/exla_test_plugin/qr_alias_registration.cc b/exla/c_src/exla_test/custom_calls.cc similarity index 70% rename from exla/c_src/exla_test_plugin/qr_alias_registration.cc rename to exla/c_src/exla_test/custom_calls.cc index 1e80045e54..e54b095bf1 100644 --- a/exla/c_src/exla_test_plugin/qr_alias_registration.cc +++ b/exla/c_src/exla_test/custom_calls.cc @@ -1,9 +1,6 @@ // Test-only shared library: registers an alias FFI name that reuses the // existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so. 
-// -// Built when BUILD_EXLA_TEST_PLUGIN=1 (Mix test). Load with RTLD_GLOBAL via -// EXLA.NIF.dlopen_test_plugin/1 before compiling or running graphs that emit -// the alias call_target_name. +#ifndef EXLA_PROD #include "xla/ffi/api/api.h" #include "xla/ffi/ffi_api.h" @@ -14,3 +11,5 @@ extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame); XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias", "Host", qr_cpu_custom_call_f32); + +#endif diff --git a/exla/lib/exla/custom_call.ex b/exla/lib/exla/custom_call.ex index c92ba400e5..cb9e8eb671 100644 --- a/exla/lib/exla/custom_call.ex +++ b/exla/lib/exla/custom_call.ex @@ -11,12 +11,15 @@ defprotocol EXLA.CustomCall do During compilation with `compiler: EXLA`, when the builder is an MLIR `EXLA.MLIR.Function`, each `Nx.block(tag, inputs, outputs, fn ... end)` is - passed here: `EXLA.Defn` invokes `call(tag, outputs_template, lowered_inputs, client)`. + passed here. `EXLA.Defn` invokes: - If `call/4` returns `:skip`, EXLA compiles the block's **default callback** - (the anonymous function body) instead of emitting a custom call. + * `function_name(tag, outputs_template, input_templates, client)` + * `config(tag, outputs_template, input_templates, client)` - ## `call/4` arguments + If `function_name/4` returns `:skip`, EXLA compiles the block's **default + callback** (the anonymous function body) instead of emitting a custom call. + + ## `function_name/4` and `config/4` arguments * `struct` — the **tag** passed as the first argument to `Nx.block/4` (your own `defstruct` or an existing tag such as `%Nx.Block.LinAlg.QR{}`). @@ -24,20 +27,23 @@ defprotocol EXLA.CustomCall do * `out` — the **output template** tuple passed to `Nx.block/4` (expression metadata for shapes and types, not runtime tensors). - * `args` — list of already-lowered **operands** as `EXLA.MLIR.Value`s, in - the same order as `inputs` in `Nx.block/4`. 
+ * `args` — list of **input templates**, in the same order as `inputs` in + `Nx.block/4`. * `client` — the active `EXLA.Client` (use e.g. `client.platform` to gate host-only lowerings). - ## Return value + ## Return values - * **Success** — return a list of `EXLA.MLIR.Value` (or a single value) that - matches the block result shape implied by `out`. + * `function_name/4`: + * **Success** — return the native custom-call target name. + * **`:skip`** — this implementation does not apply (unsupported type, + non-host platform, wrong arity, etc.). The default block implementation + is used instead. - * **`:skip`** — this implementation does not apply (unsupported type, - non-host platform, wrong arity, etc.). The default block implementation is - used instead. + * `config/4`: + * Return a `map()` to be encoded as `backend_config`. + * Return `nil` to omit `backend_config`. ## Dispatch @@ -61,24 +67,14 @@ defprotocol EXLA.CustomCall do end defimpl EXLA.CustomCall, for: MyApp.CustomQrTag do - alias EXLA.Defn - alias EXLA.MLIR.Value - - def call(_tag, {q_expr, r_expr}, [tensor], %{platform: :host}) do - tensor = - if Defn.op_type(tensor) != q_expr.type do - Defn.to_type(tensor, q_expr.type) - else - tensor - end - - {q, r} = - Value.qr(tensor, Defn.expr_to_typespec(q_expr), Defn.expr_to_typespec(r_expr)) - - [q, r] + def function_name(_tag, {%{type: {kind, size}}, _r_expr}, [_input], %{platform: :host}) + when kind != :c and kind in [:f, :bf] and size in [16, 32, 64] do + "my_custom_qr_target" end - def call(_, _, _, _), do: :skip + def function_name(_, _, _, _), do: :skip + + def config(_, _, _, _), do: nil end Then use `Nx.block(%MyApp.CustomQrTag{}, ...)` inside a `defn` compiled with @@ -88,16 +84,14 @@ defprotocol EXLA.CustomCall do @fallback_to_any true @doc """ - Attempts to lower the block natively. - - Returns a list of `EXLA.MLIR.Value`s (or a single value) that represents the - block result, matching the shape of `out`. 
+ Returns the custom-call target name or `:skip`. + """ + def function_name(struct, out, args, client) - Returns `:skip` when this implementation does not apply (wrong types, - platform, arity, etc.). `EXLA.Defn` then compiles the block's default callback - instead. + @doc """ + Returns a map encoded into `backend_config`, or `nil`. """ - def call(struct, out, args, client) + def config(struct, out, args, client) end # Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live @@ -110,55 +104,43 @@ end defimpl EXLA.CustomCall, for: Any do @moduledoc false - alias EXLA.MLIR.Value - alias EXLA.Defn - - def call(%Nx.Block.LinAlg.QR{}, {%{type: {q_type_kind, _}} = q_expr, r_expr}, [tensor], client) - when q_type_kind != :c and client.platform == :host do - tensor = - if Defn.op_type(tensor) != q_expr.type do - Defn.to_type(tensor, q_expr.type) - else - tensor - end - - {q, r} = Value.qr(tensor, Defn.expr_to_typespec(q_expr), Defn.expr_to_typespec(r_expr)) - [q, r] + def function_name( + %Nx.Block.LinAlg.QR{}, + {%{type: {q_type_kind, q_size}}, _r_expr}, + [_tensor], + %{platform: :host} + ) + when q_type_kind != :c do + case {q_type_kind, q_size} do + {:f, 32} -> "qr_cpu_custom_call_f32" + {:f, 64} -> "qr_cpu_custom_call_f64" + {:f, 16} -> "qr_cpu_custom_call_f16" + {:bf, 16} -> "qr_cpu_custom_call_bf16" + _ -> :skip + end end - def call( + def function_name( %Nx.Block.LinAlg.Eigh{}, - {%{type: {eval_type_kind, _}} = eigenvals_expr, - %{type: {evec_type_kind, _}} = eigenvecs_expr}, - [tensor], - client + {%{type: {eval_type_kind, _}}, %{type: {evec_type_kind, evec_type_size}}}, + [_tensor], + %{platform: :host} ) - when eval_type_kind != :c and evec_type_kind != :c and client.platform == :host do - # Eigen only supports f32/f64, so promote to the smallest floating type - # wide enough to represent the requested output. 
- out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32}) - - tensor = - if Defn.op_type(tensor) != out_type do - Defn.to_type(tensor, out_type) - else - tensor - end - - {eigenvals, eigenvecs} = - Value.eigh( - tensor, - Defn.expr_to_typespec(%{eigenvals_expr | type: out_type}), - Defn.expr_to_typespec(%{eigenvecs_expr | type: out_type}) + when eval_type_kind != :c and evec_type_kind != :c do + out_type = + Nx.Type.merge( + Nx.Type.to_floating({evec_type_kind, evec_type_size}), + {:f, 32} ) - [ - Defn.to_type(eigenvals, eigenvals_expr.type), - Defn.to_type(eigenvecs, eigenvecs_expr.type) - ] + case out_type do + {:f, 32} -> "eigh_cpu_custom_call_f32" + {:f, 64} -> "eigh_cpu_custom_call_f64" + _ -> :skip + end end - def call(%Nx.Block.LinAlg.QR{}, _, _, _), do: :skip - def call(%Nx.Block.LinAlg.Eigh{}, _, _, _), do: :skip - def call(_, _, _, _), do: :skip + def function_name(_, _, _, _), do: :skip + + def config(_, _, _, _), do: nil end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 6dce31e7e9..5697dd2bb1 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -734,11 +734,35 @@ defmodule EXLA.Defn do ) do {call_args, cache} = Enum.map_reduce(in_args, cache, &recur_operator(&1, state, &2)) - case EXLA.CustomCall.call(struct, out, call_args, client) do + case EXLA.CustomCall.function_name(struct, out, in_args, client) do :skip -> default_block_implementation(struct, call_args, out, state, cache) - lowered -> + function_name -> + config = EXLA.CustomCall.config(struct, out, in_args, client) + + backend_config = + case config do + nil -> + nil + + %{} = map -> + map + + other -> + raise ArgumentError, + "EXLA.CustomCall.config/4 must return map() | nil, got: #{inspect(other)}" + end + + out_typespecs = + [out] + |> Composite.flatten_list() + |> Enum.map(&expr_to_typespec/1) + + lowered = + Value.custom_call(call_args, out_typespecs, function_name, backend_config) + |> wrap_tuple_result(out) + {lowered, cache} end end 
diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 77afb4fa23..a897844892 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -786,25 +786,28 @@ defmodule EXLA.MLIR.Value do end @doc false - def qr_with_call_target( - %Value{function: func} = value, - q_typespec, - r_typespec, - call_target_name + def custom_call( + [%Value{function: func} | _] = operands, + typespecs, + call_target_name, + backend_config \\ nil ) - when is_binary(call_target_name) do - operands = [value] - result_types = typespecs_to_mlir_types([q_typespec, r_typespec]) + when is_binary(call_target_name) and is_list(typespecs) do + result_types = typespecs_to_mlir_types(typespecs) attributes = [ call_target_name: attr_string(call_target_name), api_version: attr_i32(4) ] - [q, r] = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + attributes = + if is_map(backend_config) do + Keyword.put(attributes, :backend_config, backend_config_to_attr(backend_config)) + else + attributes + end - {q, r} + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) end def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do @@ -1110,6 +1113,44 @@ defmodule EXLA.MLIR.Value do "{#{content}}" end + defp backend_config_to_attr(map) when is_map(map) do + map + |> Enum.map(fn {k, v} -> {attr_dict_key(k), backend_config_value_to_attr(v)} end) + |> attr_dict() + end + + defp backend_config_value_to_attr(v) when is_boolean(v), do: attr_boolean(v) + defp backend_config_value_to_attr(v) when is_integer(v), do: attr_i64(v) + defp backend_config_value_to_attr(v) when is_float(v), do: "#{v} : f64" + defp backend_config_value_to_attr(v) when is_binary(v), do: attr_string(v) + + defp backend_config_value_to_attr(v) when is_list(v) do + "[" <> Enum.map_join(v, ", ", &backend_config_value_to_attr/1) <> "]" + end + + defp backend_config_value_to_attr(v) when is_map(v), do: backend_config_to_attr(v) + + 
defp backend_config_value_to_attr(v) do + raise ArgumentError, + "custom_call backend_config value is not encodable to MLIR DictionaryAttr: #{inspect(v)}" + end + + defp attr_dict_key(key) when is_atom(key), do: Atom.to_string(key) + + defp attr_dict_key(key) when is_binary(key) do + if Regex.match?(~r/^[A-Za-z_][A-Za-z0-9_]*$/, key) do + key + else + raise ArgumentError, + "custom_call backend_config key must match [A-Za-z_][A-Za-z0-9_]*, got: #{inspect(key)}" + end + end + + defp attr_dict_key(key) do + raise ArgumentError, + "custom_call backend_config key must be an atom or string, got: #{inspect(key)}" + end + defp join_list(list) do "[" <> Enum.join(list, ", ") <> "]" end diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index df0af731cf..9fab99366c 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -79,7 +79,7 @@ defmodule EXLA.NIF do def get_tpu_client(), do: err!() def get_c_api_client(_device_type), do: err!() def load_pjrt_plugin(_device_type, _library_path), do: err!() - def dlopen_test_plugin(_path), do: err!() + def load_dylib(_path), do: err!() def get_device_count(_client), do: err!() def get_supported_platforms, do: err!() def run_cpu(_executable, _arguments, _device_id, _callback_server_pid), do: err!() diff --git a/exla/test/exla/custom_call_alias_test.exs b/exla/test/exla/custom_call_alias_test.exs index fdfe052084..fc4fad82bc 100644 --- a/exla/test/exla/custom_call_alias_test.exs +++ b/exla/test/exla/custom_call_alias_test.exs @@ -26,7 +26,7 @@ defmodule EXLA.CustomCallAliasTest do end end - @plugin_relative ~c"test/exla_qr_alias_plugin.so" + @plugin_relative ~c"test/exla_qr_alias.so" defp plugin_path do :filename.join(:code.priv_dir(:exla), @plugin_relative) @@ -53,17 +53,17 @@ defmodule EXLA.CustomCallAliasTest do unless File.exists?(path) do flunk(""" - Missing #{path}. Build EXLA with MIX_ENV=test so the alias plugin is compiled \ - (see Makefile target exla_qr_alias_plugin.so). + Missing #{path}. 
Build EXLA with MIX_ENV=test so the alias dylib is compiled \ + (see Makefile target exla_qr_alias.so). """) end - case EXLA.NIF.dlopen_test_plugin(path) do + case EXLA.NIF.load_dylib(path) do :ok -> :ok other -> - flunk("dlopen_test_plugin(#{path}) expected :ok, got: #{inspect(other)}") + flunk("load_dylib(#{path}) expected :ok, got: #{inspect(other)}") end end diff --git a/exla/test/support/exla_test_qr_alias_block.ex b/exla/test/support/exla_test_qr_alias_block.ex index b371adcfe5..909f72b58f 100644 --- a/exla/test/support/exla_test_qr_alias_block.ex +++ b/exla/test/support/exla_test_qr_alias_block.ex @@ -1,34 +1,18 @@ # Test-only block tag + `EXLA.CustomCall` impl used to emit a StableHLO custom_call # with `call_target_name` `qr_cpu_custom_call_f32_exla_alias` (registered by -# `priv/test/exla_qr_alias_plugin.so` when built with `MIX_ENV=test`). +# `priv/test/exla_qr_alias.so` when built with `MIX_ENV=test`). defmodule EXLA.Test.QRAliasBlock do @moduledoc false defstruct [] end defimpl EXLA.CustomCall, for: EXLA.Test.QRAliasBlock do - alias EXLA.MLIR.Value - alias EXLA.Defn - - def call(_, {%{type: {q_kind, _}} = q_expr, r_expr}, [tensor], client) - when q_kind != :c and client.platform == :host do - tensor = - if Defn.op_type(tensor) != q_expr.type do - Defn.to_type(tensor, q_expr.type) - else - tensor - end - - {q, r} = - Value.qr_with_call_target( - tensor, - Defn.expr_to_typespec(q_expr), - Defn.expr_to_typespec(r_expr), - "qr_cpu_custom_call_f32_exla_alias" - ) - - [q, r] + def function_name(_, {%{type: {q_kind, q_size}}, _r_expr}, [_tensor], client) + when q_kind != :c and q_size == 32 and client.platform == :host do + "qr_cpu_custom_call_f32_exla_alias" end - def call(_, _, _, _), do: :skip + def function_name(_, _, _, _), do: :skip + + def config(_, _, _, _), do: nil end From adf236963a29adeca3710921f94933b8548f9e54 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:45:12 -0300 Subject: [PATCH 
07/11] upcast integers to float in C --- exla/c_src/exla/exla.cc | 147 +++++++++++++++++++++++++++++++++++ exla/lib/exla/custom_call.ex | 86 ++++++++++++++------ exla/lib/exla/mlir/value.ex | 44 +++++------ 3 files changed, 228 insertions(+), 49 deletions(-) diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 34816b3b7e..206cdb46e1 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -31,6 +31,12 @@ #include "xla/tsl/platform/statusor.h" #include "llvm/Support/ThreadPool.h" +#include + +#include "xla/extension/custom_calls/eigh.h" +#include "xla/extension/custom_calls/qr.h" +#include "xla/ffi/ffi_api.h" + namespace exla { using callback_bridge::Pending; @@ -730,4 +736,145 @@ FINE_NIF(write_to_pointer, 0); } // namespace exla +// Host QR custom calls: integer operands with f32 Q/R (see Nx.Type.to_floating/1 +// for integer matrices). Handlers live in libexla alongside the NIFs. +namespace { + +namespace ffi = xla::ffi; + +template +ffi::Error QrCpuCustomCallIntegerOperandF32ResultsImpl( + ffi::Buffer operand, ffi::ResultBuffer q, + ffi::ResultBuffer r) { + using IntT = ffi::NativeType; + auto operand_dims = operand.dimensions(); + auto q_dims = q->dimensions(); + auto r_dims = r->dimensions(); + + uint64_t m = q_dims[q_dims.size() - 2]; + uint64_t k = q_dims[q_dims.size() - 1]; + uint64_t n = r_dims[r_dims.size() - 1]; + uint64_t l = r_dims[r_dims.size() - 2]; + + bool complete = l == m; + + uint64_t batch_items = 1; + for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) { + batch_items *= static_cast(*it); + } + + uint64_t q_stride = m * k; + uint64_t r_stride = n * l; + uint64_t inner_stride = m * n; + + std::vector tmp(inner_stride); + const IntT *in_base = operand.typed_data(); + float *q_base = reinterpret_cast(q->untyped_data()); + float *r_base = reinterpret_cast(r->untyped_data()); + + for (uint64_t b = 0; b < batch_items; b++) { + const IntT *in = in_base + b * inner_stride; + for (uint64_t j = 0; j < 
inner_stride; j++) { + tmp[j] = static_cast(in[j]); + } + single_matrix_qr_cpu_custom_call( + q_base + b * q_stride, r_base + b * r_stride, tmp.data(), m, k, n, + complete); + } + + return ffi::Error::Success(); +} + +#define EXLA_REGISTER_QR_INT_F32(DTYPE, NAME) \ + static ffi::Error NAME##_impl(ffi::Buffer operand, \ + ffi::ResultBuffer q, \ + ffi::ResultBuffer r) { \ + return QrCpuCustomCallIntegerOperandF32ResultsImpl(operand, \ + q, r); \ + } \ + XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl, \ + ffi::Ffi::Bind() \ + .Arg>() \ + .Ret>() \ + .Ret>()); \ + XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME); + +EXLA_REGISTER_QR_INT_F32(S8, qr_cpu_custom_call_s8) +EXLA_REGISTER_QR_INT_F32(S16, qr_cpu_custom_call_s16) +EXLA_REGISTER_QR_INT_F32(S32, qr_cpu_custom_call_s32) +EXLA_REGISTER_QR_INT_F32(S64, qr_cpu_custom_call_s64) +EXLA_REGISTER_QR_INT_F32(U8, qr_cpu_custom_call_u8) +EXLA_REGISTER_QR_INT_F32(U16, qr_cpu_custom_call_u16) +EXLA_REGISTER_QR_INT_F32(U32, qr_cpu_custom_call_u32) +EXLA_REGISTER_QR_INT_F32(U64, qr_cpu_custom_call_u64) + +#undef EXLA_REGISTER_QR_INT_F32 + +template +ffi::Error EighCpuCustomCallIntegerOperandF32ResultsImpl( + ffi::Buffer operand, + ffi::ResultBuffer eigenvalues, + ffi::ResultBuffer eigenvectors) { + using IntT = ffi::NativeType; + auto operand_dims = operand.dimensions(); + auto eigenvalues_dims = eigenvalues->dimensions(); + auto eigenvectors_dims = eigenvectors->dimensions(); + + uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2]; + uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1]; + + uint64_t batch_items = 1; + for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) { + batch_items *= static_cast(*it); + } + + uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1]; + uint64_t eigenvectors_stride = m * n; + uint64_t inner_stride = m * n; + + std::vector tmp(inner_stride); + const IntT *in_base = operand.typed_data(); + float *eval_base = 
reinterpret_cast(eigenvalues->untyped_data()); + float *evec_base = reinterpret_cast(eigenvectors->untyped_data()); + + for (uint64_t b = 0; b < batch_items; b++) { + const IntT *in = in_base + b * inner_stride; + for (uint64_t j = 0; j < inner_stride; j++) { + tmp[j] = static_cast(in[j]); + } + single_matrix_eigh_cpu_custom_call( + eval_base + b * eigenvalues_stride, evec_base + b * eigenvectors_stride, + tmp.data(), m, n); + } + + return ffi::Error::Success(); +} + +#define EXLA_REGISTER_EIGH_INT_F32(DTYPE, NAME) \ + static ffi::Error NAME##_impl(ffi::Buffer operand, \ + ffi::ResultBuffer eigenvalues, \ + ffi::ResultBuffer eigenvectors) { \ + return EighCpuCustomCallIntegerOperandF32ResultsImpl( \ + operand, eigenvalues, eigenvectors); \ + } \ + XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl, \ + ffi::Ffi::Bind() \ + .Arg>() \ + .Ret>() \ + .Ret>()); \ + XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME); + +EXLA_REGISTER_EIGH_INT_F32(S8, eigh_cpu_custom_call_s8) +EXLA_REGISTER_EIGH_INT_F32(S16, eigh_cpu_custom_call_s16) +EXLA_REGISTER_EIGH_INT_F32(S32, eigh_cpu_custom_call_s32) +EXLA_REGISTER_EIGH_INT_F32(S64, eigh_cpu_custom_call_s64) +EXLA_REGISTER_EIGH_INT_F32(U8, eigh_cpu_custom_call_u8) +EXLA_REGISTER_EIGH_INT_F32(U16, eigh_cpu_custom_call_u16) +EXLA_REGISTER_EIGH_INT_F32(U32, eigh_cpu_custom_call_u32) +EXLA_REGISTER_EIGH_INT_F32(U64, eigh_cpu_custom_call_u64) + +#undef EXLA_REGISTER_EIGH_INT_F32 + +} // namespace + FINE_INIT("Elixir.EXLA.NIF"); diff --git a/exla/lib/exla/custom_call.ex b/exla/lib/exla/custom_call.ex index cb9e8eb671..79f217a2c2 100644 --- a/exla/lib/exla/custom_call.ex +++ b/exla/lib/exla/custom_call.ex @@ -94,6 +94,58 @@ defprotocol EXLA.CustomCall do def config(struct, out, args, client) end +defmodule EXLA.CustomCall.Builtins do + @moduledoc false + + @doc """ + Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.qr/2`, or `:skip`. 
+ + `operand_type` is the input matrix element type; `q_output_type` is the + element type of the `Q` factor from the block output template. + """ + def qr_cpu_target(operand_type, q_output_type) do + case {operand_type, q_output_type} do + {{:f, 32}, {:f, 32}} -> "qr_cpu_custom_call_f32" + {{:f, 64}, {:f, 64}} -> "qr_cpu_custom_call_f64" + {{:f, 16}, {:f, 16}} -> "qr_cpu_custom_call_f16" + {{:bf, 16}, {:bf, 16}} -> "qr_cpu_custom_call_bf16" + {{:s, 8}, {:f, 32}} -> "qr_cpu_custom_call_s8" + {{:s, 16}, {:f, 32}} -> "qr_cpu_custom_call_s16" + {{:s, 32}, {:f, 32}} -> "qr_cpu_custom_call_s32" + {{:s, 64}, {:f, 32}} -> "qr_cpu_custom_call_s64" + {{:u, 8}, {:f, 32}} -> "qr_cpu_custom_call_u8" + {{:u, 16}, {:f, 32}} -> "qr_cpu_custom_call_u16" + {{:u, 32}, {:f, 32}} -> "qr_cpu_custom_call_u32" + {{:u, 64}, {:f, 32}} -> "qr_cpu_custom_call_u64" + _ -> :skip + end + end + + @doc """ + Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.eigh/2`, or `:skip`. + + `operand_type` is the input matrix element type; `computation_type` is the + floating type used for eigenvalues and eigenvectors (same rule as + `Nx.Type.merge(Nx.Type.to_floating(evec_type), {:f, 32})` in the protocol). + Integer operands are promoted to that float inside the native handler. 
+ """ + def eigh_cpu_target(operand_type, computation_type) do + case {operand_type, computation_type} do + {{:f, 32}, {:f, 32}} -> "eigh_cpu_custom_call_f32" + {{:f, 64}, {:f, 64}} -> "eigh_cpu_custom_call_f64" + {{:s, 8}, {:f, 32}} -> "eigh_cpu_custom_call_s8" + {{:s, 16}, {:f, 32}} -> "eigh_cpu_custom_call_s16" + {{:s, 32}, {:f, 32}} -> "eigh_cpu_custom_call_s32" + {{:s, 64}, {:f, 32}} -> "eigh_cpu_custom_call_s64" + {{:u, 8}, {:f, 32}} -> "eigh_cpu_custom_call_u8" + {{:u, 16}, {:f, 32}} -> "eigh_cpu_custom_call_u16" + {{:u, 32}, {:f, 32}} -> "eigh_cpu_custom_call_u32" + {{:u, 64}, {:f, 32}} -> "eigh_cpu_custom_call_u64" + _ -> :skip + end + end +end + # Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live # in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the # protocol, applications and libraries can define their own @@ -106,38 +158,26 @@ defimpl EXLA.CustomCall, for: Any do def function_name( %Nx.Block.LinAlg.QR{}, - {%{type: {q_type_kind, q_size}}, _r_expr}, - [_tensor], + {%{type: q_type}, _r_expr}, + [%{type: in_type} | _], %{platform: :host} ) - when q_type_kind != :c do - case {q_type_kind, q_size} do - {:f, 32} -> "qr_cpu_custom_call_f32" - {:f, 64} -> "qr_cpu_custom_call_f64" - {:f, 16} -> "qr_cpu_custom_call_f16" - {:bf, 16} -> "qr_cpu_custom_call_bf16" - _ -> :skip - end + when elem(q_type, 0) != :c and elem(in_type, 0) != :c do + EXLA.CustomCall.Builtins.qr_cpu_target(in_type, q_type) end def function_name( %Nx.Block.LinAlg.Eigh{}, - {%{type: {eval_type_kind, _}}, %{type: {evec_type_kind, evec_type_size}}}, - [_tensor], + {%{type: eval_type}, %{type: evec_type}}, + [%{type: in_type} | _], %{platform: :host} ) - when eval_type_kind != :c and evec_type_kind != :c do - out_type = - Nx.Type.merge( - Nx.Type.to_floating({evec_type_kind, evec_type_size}), - {:f, 32} - ) + when elem(eval_type, 0) != :c and elem(evec_type, 0) != :c and + elem(in_type, 0) != :c do + computation_type = + 
Nx.Type.merge(Nx.Type.to_floating(evec_type), {:f, 32}) - case out_type do - {:f, 32} -> "eigh_cpu_custom_call_f32" - {:f, 64} -> "eigh_cpu_custom_call_f64" - _ -> :skip - end + EXLA.CustomCall.Builtins.eigh_cpu_target(in_type, computation_type) end def function_name(_, _, _, _), do: :skip diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index a897844892..3fb63b52cc 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -719,23 +719,25 @@ defmodule EXLA.MLIR.Value do op(func, "stablehlo.return", values, []) end - def eigh(%Value{function: func} = value, eigenvals_typespec, eigenvecs_typespec) do + def eigh( + %Value{function: func} = value, + %{type: eval_type} = eigenvals_typespec, + %{type: evec_type} = eigenvecs_typespec + ) do %{type: op_type} = get_typespec(value) operands = [value] result_types = typespecs_to_mlir_types([eigenvals_typespec, eigenvecs_typespec]) - call_target_name = - case op_type do - {:f, 32} -> - "eigh_cpu_custom_call_f32" + computation_type = Nx.Type.merge(Nx.Type.to_floating(evec_type), {:f, 32}) - {:f, 64} -> - "eigh_cpu_custom_call_f64" + call_target_name = + case EXLA.CustomCall.Builtins.eigh_cpu_target(op_type, computation_type) do + :skip -> + raise "Eigh decomposition not supported on :host device for operand type #{inspect(op_type)}, eigenvalue type #{inspect(eval_type)}, eigenvector type #{inspect(evec_type)}" - type -> - # Due to matching on EXLA.Defn, we are sure that the device here is always :host - raise "Eigh decomposition not supported on :host device for type #{inspect(type)}" + name when is_binary(name) -> + name end attributes = [ @@ -749,29 +751,19 @@ defmodule EXLA.MLIR.Value do {eigenvals, eigenvecs} end - def qr(%Value{function: func} = value, q_typespec, r_typespec) do + def qr(%Value{function: func} = value, %{type: q_type} = q_typespec, r_typespec) do %{type: op_type} = get_typespec(value) operands = [value] result_types = typespecs_to_mlir_types([q_typespec, 
r_typespec]) call_target_name = - case op_type do - {:f, 32} -> - "qr_cpu_custom_call_f32" - - {:f, 64} -> - "qr_cpu_custom_call_f64" - - {:f, 16} -> - "qr_cpu_custom_call_f16" + case EXLA.CustomCall.Builtins.qr_cpu_target(op_type, q_type) do + :skip -> + raise "QR decomposition not supported on :host device for operand type #{inspect(op_type)} and Q type #{inspect(q_type)}" - {:bf, 16} -> - "qr_cpu_custom_call_bf16" - - type -> - # Due to matching on EXLA.Defn, we are sure that the device here is always :host - raise "QR decomposition not supported on :host device for type #{inspect(type)}" + name when is_binary(name) -> + name end attributes = [ From c058794652191df2221299d08bbf611dc189fc69 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:07:42 -0300 Subject: [PATCH 08/11] remove output type argument from qr_cpu_target and eigh_cpu_target in custom_call.ex --- exla/lib/exla/custom_call.ex | 67 ++++++++++++++++-------------------- exla/lib/exla/mlir/value.ex | 6 ++-- exla/mix.exs | 1 - 3 files changed, 32 insertions(+), 42 deletions(-) diff --git a/exla/lib/exla/custom_call.ex b/exla/lib/exla/custom_call.ex index 79f217a2c2..70587f90a0 100644 --- a/exla/lib/exla/custom_call.ex +++ b/exla/lib/exla/custom_call.ex @@ -100,23 +100,22 @@ defmodule EXLA.CustomCall.Builtins do @doc """ Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.qr/2`, or `:skip`. - `operand_type` is the input matrix element type; `q_output_type` is the - element type of the `Q` factor from the block output template. + `operand_type` is the input matrix element type. 
""" - def qr_cpu_target(operand_type, q_output_type) do - case {operand_type, q_output_type} do - {{:f, 32}, {:f, 32}} -> "qr_cpu_custom_call_f32" - {{:f, 64}, {:f, 64}} -> "qr_cpu_custom_call_f64" - {{:f, 16}, {:f, 16}} -> "qr_cpu_custom_call_f16" - {{:bf, 16}, {:bf, 16}} -> "qr_cpu_custom_call_bf16" - {{:s, 8}, {:f, 32}} -> "qr_cpu_custom_call_s8" - {{:s, 16}, {:f, 32}} -> "qr_cpu_custom_call_s16" - {{:s, 32}, {:f, 32}} -> "qr_cpu_custom_call_s32" - {{:s, 64}, {:f, 32}} -> "qr_cpu_custom_call_s64" - {{:u, 8}, {:f, 32}} -> "qr_cpu_custom_call_u8" - {{:u, 16}, {:f, 32}} -> "qr_cpu_custom_call_u16" - {{:u, 32}, {:f, 32}} -> "qr_cpu_custom_call_u32" - {{:u, 64}, {:f, 32}} -> "qr_cpu_custom_call_u64" + def qr_cpu_target(operand_type) do + case operand_type do + {:f, 32} -> "qr_cpu_custom_call_f32" + {:f, 64} -> "qr_cpu_custom_call_f64" + {:f, 16} -> "qr_cpu_custom_call_f16" + {:bf, 16} -> "qr_cpu_custom_call_bf16" + {:s, 8} -> "qr_cpu_custom_call_s8" + {:s, 16} -> "qr_cpu_custom_call_s16" + {:s, 32} -> "qr_cpu_custom_call_s32" + {:s, 64} -> "qr_cpu_custom_call_s64" + {:u, 8} -> "qr_cpu_custom_call_u8" + {:u, 16} -> "qr_cpu_custom_call_u16" + {:u, 32} -> "qr_cpu_custom_call_u32" + {:u, 64} -> "qr_cpu_custom_call_u64" _ -> :skip end end @@ -124,23 +123,20 @@ defmodule EXLA.CustomCall.Builtins do @doc """ Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.eigh/2`, or `:skip`. - `operand_type` is the input matrix element type; `computation_type` is the - floating type used for eigenvalues and eigenvectors (same rule as - `Nx.Type.merge(Nx.Type.to_floating(evec_type), {:f, 32})` in the protocol). - Integer operands are promoted to that float inside the native handler. + `operand_type` is the input matrix element type. 
""" - def eigh_cpu_target(operand_type, computation_type) do - case {operand_type, computation_type} do - {{:f, 32}, {:f, 32}} -> "eigh_cpu_custom_call_f32" - {{:f, 64}, {:f, 64}} -> "eigh_cpu_custom_call_f64" - {{:s, 8}, {:f, 32}} -> "eigh_cpu_custom_call_s8" - {{:s, 16}, {:f, 32}} -> "eigh_cpu_custom_call_s16" - {{:s, 32}, {:f, 32}} -> "eigh_cpu_custom_call_s32" - {{:s, 64}, {:f, 32}} -> "eigh_cpu_custom_call_s64" - {{:u, 8}, {:f, 32}} -> "eigh_cpu_custom_call_u8" - {{:u, 16}, {:f, 32}} -> "eigh_cpu_custom_call_u16" - {{:u, 32}, {:f, 32}} -> "eigh_cpu_custom_call_u32" - {{:u, 64}, {:f, 32}} -> "eigh_cpu_custom_call_u64" + def eigh_cpu_target(operand_type) do + case operand_type do + {:f, 32} -> "eigh_cpu_custom_call_f32" + {:f, 64} -> "eigh_cpu_custom_call_f64" + {:s, 8} -> "eigh_cpu_custom_call_s8" + {:s, 16} -> "eigh_cpu_custom_call_s16" + {:s, 32} -> "eigh_cpu_custom_call_s32" + {:s, 64} -> "eigh_cpu_custom_call_s64" + {:u, 8} -> "eigh_cpu_custom_call_u8" + {:u, 16} -> "eigh_cpu_custom_call_u16" + {:u, 32} -> "eigh_cpu_custom_call_u32" + {:u, 64} -> "eigh_cpu_custom_call_u64" _ -> :skip end end @@ -163,7 +159,7 @@ defimpl EXLA.CustomCall, for: Any do %{platform: :host} ) when elem(q_type, 0) != :c and elem(in_type, 0) != :c do - EXLA.CustomCall.Builtins.qr_cpu_target(in_type, q_type) + EXLA.CustomCall.Builtins.qr_cpu_target(in_type) end def function_name( @@ -174,10 +170,7 @@ defimpl EXLA.CustomCall, for: Any do ) when elem(eval_type, 0) != :c and elem(evec_type, 0) != :c and elem(in_type, 0) != :c do - computation_type = - Nx.Type.merge(Nx.Type.to_floating(evec_type), {:f, 32}) - - EXLA.CustomCall.Builtins.eigh_cpu_target(in_type, computation_type) + EXLA.CustomCall.Builtins.eigh_cpu_target(in_type) end def function_name(_, _, _, _), do: :skip diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 3fb63b52cc..68ac35b539 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -729,10 +729,8 @@ defmodule 
EXLA.MLIR.Value do operands = [value] result_types = typespecs_to_mlir_types([eigenvals_typespec, eigenvecs_typespec]) - computation_type = Nx.Type.merge(Nx.Type.to_floating(evec_type), {:f, 32}) - call_target_name = - case EXLA.CustomCall.Builtins.eigh_cpu_target(op_type, computation_type) do + case EXLA.CustomCall.Builtins.eigh_cpu_target(op_type) do :skip -> raise "Eigh decomposition not supported on :host device for operand type #{inspect(op_type)}, eigenvalue type #{inspect(eval_type)}, eigenvector type #{inspect(evec_type)}" @@ -758,7 +756,7 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types([q_typespec, r_typespec]) call_target_name = - case EXLA.CustomCall.Builtins.qr_cpu_target(op_type, q_type) do + case EXLA.CustomCall.Builtins.qr_cpu_target(op_type) do :skip -> raise "QR decomposition not supported on :host device for operand type #{inspect(op_type)} and Q type #{inspect(q_type)}" diff --git a/exla/mix.exs b/exla/mix.exs index 404488aeef..c4631e47a6 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -34,7 +34,6 @@ defmodule EXLA.MixProject do "MIX_BUILD_EMBEDDED" => "#{Mix.Project.config()[:build_embedded]}", "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv, "EXLA_VERSION" => "#{@version}", - "BUILD_EXLA_TEST_PLUGIN" => if(Mix.env() == :test, do: "1", else: "0") } end, make_args: make_args From e424803c9a41d9d14dc1d08c0794a2485aca50ee Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:15:41 -0300 Subject: [PATCH 09/11] remove Value.qr/3 and Value.eigh/3 as they are not being used --- exla/lib/exla/mlir/value.ex | 56 ------------------------------------- 1 file changed, 56 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 68ac35b539..e5a6d6e1b5 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -719,62 +719,6 @@ defmodule EXLA.MLIR.Value do op(func, "stablehlo.return", values, []) end - def eigh( - 
%Value{function: func} = value, - %{type: eval_type} = eigenvals_typespec, - %{type: evec_type} = eigenvecs_typespec - ) do - %{type: op_type} = get_typespec(value) - - operands = [value] - result_types = typespecs_to_mlir_types([eigenvals_typespec, eigenvecs_typespec]) - - call_target_name = - case EXLA.CustomCall.Builtins.eigh_cpu_target(op_type) do - :skip -> - raise "Eigh decomposition not supported on :host device for operand type #{inspect(op_type)}, eigenvalue type #{inspect(eval_type)}, eigenvector type #{inspect(evec_type)}" - - name when is_binary(name) -> - name - end - - attributes = [ - call_target_name: attr_string(call_target_name), - api_version: attr_i32(4) - ] - - [eigenvals, eigenvecs] = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) - - {eigenvals, eigenvecs} - end - - def qr(%Value{function: func} = value, %{type: q_type} = q_typespec, r_typespec) do - %{type: op_type} = get_typespec(value) - - operands = [value] - result_types = typespecs_to_mlir_types([q_typespec, r_typespec]) - - call_target_name = - case EXLA.CustomCall.Builtins.qr_cpu_target(op_type) do - :skip -> - raise "QR decomposition not supported on :host device for operand type #{inspect(op_type)} and Q type #{inspect(q_type)}" - - name when is_binary(name) -> - name - end - - attributes = [ - call_target_name: attr_string(call_target_name), - api_version: attr_i32(4) - ] - - [q, r] = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) - - {q, r} - end - @doc false def custom_call( [%Value{function: func} | _] = operands, From 238743d8ab475d779d8c266bed8e66155906a510 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 4 May 2026 18:53:37 -0300 Subject: [PATCH 10/11] remove CustomCall.Builtins, add integer test on custom_call_alias_test --- exla/lib/exla/custom_call.ex | 72 ++++++++--------------- exla/test/exla/custom_call_alias_test.exs | 8 +++ 2 files changed, 33 
insertions(+), 47 deletions(-) diff --git a/exla/lib/exla/custom_call.ex b/exla/lib/exla/custom_call.ex index 70587f90a0..82005cfa59 100644 --- a/exla/lib/exla/custom_call.ex +++ b/exla/lib/exla/custom_call.ex @@ -94,16 +94,24 @@ defprotocol EXLA.CustomCall do def config(struct, out, args, client) end -defmodule EXLA.CustomCall.Builtins do +# Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live +# in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the +# protocol, applications and libraries can define their own +# `defimpl EXLA.CustomCall, for: SomeStruct` — protocol dispatch uses that +# implementation instead of this fallback when the block tag matches (you can +# also target a built-in struct such as `Nx.Block...` from your app if needed). +# +defimpl EXLA.CustomCall, for: Any do @moduledoc false - @doc """ - Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.qr/2`, or `:skip`. - - `operand_type` is the input matrix element type. - """ - def qr_cpu_target(operand_type) do - case operand_type do + def function_name( + %Nx.Block.LinAlg.QR{}, + {%{type: q_type}, _r_expr}, + [%{type: in_type} | _], + %{platform: :host} + ) + when elem(q_type, 0) != :c and elem(in_type, 0) != :c do + case in_type do {:f, 32} -> "qr_cpu_custom_call_f32" {:f, 64} -> "qr_cpu_custom_call_f64" {:f, 16} -> "qr_cpu_custom_call_f16" @@ -120,13 +128,15 @@ defmodule EXLA.CustomCall.Builtins do end end - @doc """ - Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.eigh/2`, or `:skip`. - - `operand_type` is the input matrix element type. 
- """ - def eigh_cpu_target(operand_type) do - case operand_type do + def function_name( + %Nx.Block.LinAlg.Eigh{}, + {%{type: eval_type}, %{type: evec_type}}, + [%{type: in_type} | _], + %{platform: :host} + ) + when elem(eval_type, 0) != :c and elem(evec_type, 0) != :c and + elem(in_type, 0) != :c do + case in_type do {:f, 32} -> "eigh_cpu_custom_call_f32" {:f, 64} -> "eigh_cpu_custom_call_f64" {:s, 8} -> "eigh_cpu_custom_call_s8" @@ -140,38 +150,6 @@ defmodule EXLA.CustomCall.Builtins do _ -> :skip end end -end - -# Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live -# in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the -# protocol, applications and libraries can define their own -# `defimpl EXLA.CustomCall, for: SomeStruct` — protocol dispatch uses that -# implementation instead of this fallback when the block tag matches (you can -# also target a built-in struct such as `Nx.Block...` from your app if needed). -# -defimpl EXLA.CustomCall, for: Any do - @moduledoc false - - def function_name( - %Nx.Block.LinAlg.QR{}, - {%{type: q_type}, _r_expr}, - [%{type: in_type} | _], - %{platform: :host} - ) - when elem(q_type, 0) != :c and elem(in_type, 0) != :c do - EXLA.CustomCall.Builtins.qr_cpu_target(in_type) - end - - def function_name( - %Nx.Block.LinAlg.Eigh{}, - {%{type: eval_type}, %{type: evec_type}}, - [%{type: in_type} | _], - %{platform: :host} - ) - when elem(eval_type, 0) != :c and elem(evec_type, 0) != :c and - elem(in_type, 0) != :c do - EXLA.CustomCall.Builtins.eigh_cpu_target(in_type) - end def function_name(_, _, _, _), do: :skip diff --git a/exla/test/exla/custom_call_alias_test.exs b/exla/test/exla/custom_call_alias_test.exs index fc4fad82bc..5eab675e76 100644 --- a/exla/test/exla/custom_call_alias_test.exs +++ b/exla/test/exla/custom_call_alias_test.exs @@ -75,6 +75,14 @@ defmodule EXLA.CustomCallAliasTest do refute mlir =~ "qr_cpu_custom_call_f32_exla_alias" end + test "builtin QR lowering includes 
qr_cpu_custom_call_s32 in MLIR" do + arg = Nx.iota({3, 4}, type: {:s, 32}) + assert %{mlir_module: mlir} = mlir_via_jit_apply!(&BuiltinFun.qr/1, [arg]) + + assert mlir =~ "@qr_cpu_custom_call_s32(" + refute mlir =~ "qr_cpu_custom_call_f32_exla_alias" + end + test "QR alias plugin: MLIR uses alias name and not the builtin target string" do load_plugin!() From 3f2cb9ed5a116ed149f59847572d3164650e1654 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 4 May 2026 18:58:48 -0300 Subject: [PATCH 11/11] remove trailing comma from env map in mix.exs --- exla/mix.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exla/mix.exs b/exla/mix.exs index c4631e47a6..22c8ee9460 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -33,7 +33,7 @@ defmodule EXLA.MixProject do "FINE_INCLUDE_DIR" => Fine.include_dir(), "MIX_BUILD_EMBEDDED" => "#{Mix.Project.config()[:build_embedded]}", "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv, - "EXLA_VERSION" => "#{@version}", + "EXLA_VERSION" => "#{@version}" } end, make_args: make_args