Refactor EXLA block lowering through EXLA.CustomCall protocol #1739

Draft
Chapaman wants to merge 9 commits into elixir-nx:main from Chapaman:exla_block_implementation

Conversation

@Chapaman
Contributor

Summary

  • Introduced EXLA.CustomCall (apply?/4, call/4) as the hook for native lowering of Nx.block/4 in EXLA, with @fallback_to_any true and implementations on Any for the blocks that previously had dedicated clauses in defn.ex.
  • EXLA.Defn: replaced the long chain of :block special cases with one cached_recur_operator(:block, …) path: recurse tensor args, then invoke call/4 if apply?/4 returns true, else default_block_implementation/5 (the previous generic subfunction + Value.call path).
  • Moved native paths for Nx.Block.LinAlg.QR, Nx.Block.LinAlg.Eigh, Nx.Block.Take, Nx.Block.TopK, Nx.Block.FFT2, Nx.Block.IFFT2, Nx.Block.RFFT, and Nx.Block.IRFFT into the protocol impl; behavior should match the old code paths.
  • Helpers: a few EXLA.Defn functions are now @doc false public (to_type, op_type, op_shape, expr_to_typespec, axes_for_rank, fft, fft2) so the protocol can call them; fft / fft2 no longer take state (builder comes from %Value{}.function).
  • Note: QR eligibility uses the type kind from the output template (e.g. {:f, _} vs {:c, _}), not q.type != :c (types are tuples).

Comment thread exla/lib/exla/custom_call.ex Outdated
Comment on lines +1 to +45
defprotocol EXLA.CustomCall do
  @moduledoc """
  Protocol used by `EXLA.Defn` to lower specific `Nx.block/4` tags natively
  instead of compiling the fallback callback.

  Implementations receive the block tag struct, the output template (`out`),
  the already-recursed MLIR `EXLA.MLIR.Value` arguments and the active
  `EXLA.Client`.
  """

  @fallback_to_any true

  @doc """
  Returns `true` when EXLA should lower the block natively via `call/4`.

  When it returns `false`, `EXLA.Defn` falls back to compiling the block's
  default callback implementation.
  """
  def apply?(struct, out, args, client)

  @fallback_to_any true

  @doc """
  Lowers the block natively.

  Must return the list of `EXLA.MLIR.Value`s (or a single value) that
  represents the block result, matching the shape of `out`.
  """
  def call(struct, out, args, client)
end

defimpl EXLA.CustomCall, for: Any do
  alias EXLA.MLIR.Value
  alias EXLA.Defn

  # --- apply?/4 ---

  def apply?(
        %Nx.Block.LinAlg.QR{},
        {%{type: {q_type_kind, _}}, _r},
        _args,
        client
      ) do
    q_type_kind != :c and client.platform == :host
  end

Contributor Author

We split this into apply?/4 and call/4 because a single callback would mix “should EXLA take the native path for this compile?” with “emit MLIR.” Keeping two functions separates eligibility from lowering.

  • apply? answers whether the native path applies for this block tag in context: e.g. for QR we still have %Nx.Block.LinAlg.QR{}, but we only native-lower when the output type is real (not complex) and the client is :host. Another block could gate on arity, struct fields, mesh, etc.—without duplicating the full call body.
  • call only runs when apply? is true and contains the actual Value.* / Defn lowering.

EXLA.Defn mirrors that split: recurse operands → apply? → either call or the generic block fallback (default_block_implementation).
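
The flow described above (recurse operands, ask apply?, then either call or the generic fallback) could be sketched roughly as follows. This is a self-contained illustration, not the PR's code: the DispatchSketch modules, the plain-map client, and the tagged-tuple return values are all hypothetical stand-ins for EXLA.Defn, EXLA.Client, and MLIR values.

```elixir
# Hypothetical stand-ins for block tag structs such as Nx.Block.LinAlg.QR.
defmodule DispatchSketch.QR do
  defstruct []
end

defmodule DispatchSketch.Other do
  defstruct []
end

defprotocol DispatchSketch.CustomCall do
  @fallback_to_any true

  # Eligibility: should the native path be taken for this compile?
  def apply?(tag, out, args, client)

  # Lowering: only invoked when apply?/4 returned true.
  def call(tag, out, args, client)
end

defimpl DispatchSketch.CustomCall, for: Any do
  # QR is only lowered natively for non-complex outputs on the :host platform.
  def apply?(%DispatchSketch.QR{}, {%{type: {kind, _bits}}, _r}, _args, client) do
    kind != :c and client.platform == :host
  end

  # Every other tag opts out, so the caller uses the generic fallback.
  def apply?(_tag, _out, _args, _client), do: false

  # Placeholder for the real Value.* lowering; returns a tagged tuple instead.
  def call(%DispatchSketch.QR{}, _out, args, _client), do: {:native_qr, args}
end

defmodule DispatchSketch.Lowering do
  # `fallback` plays the role of default_block_implementation in this sketch.
  def lower_block(tag, out, args, client, fallback) do
    if DispatchSketch.CustomCall.apply?(tag, out, args, client) do
      DispatchSketch.CustomCall.call(tag, out, args, client)
    else
      fallback.(args)
    end
  end
end
```

With this shape, a complex-typed QR or a non-host client never reaches call/4; it goes straight to the fallback function.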

Contributor

You could have a single callback and instead return :skip in call when it cannot be lowered.
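
A rough sketch of that single-callback variant, for comparison. Everything here is hypothetical (the SkipSketch module, the :qr atom tag, the map-based client, and the tagged-tuple result); it only illustrates folding eligibility into call by returning :skip.

```elixir
defmodule SkipSketch do
  # Eligible case: non-complex output type on the :host platform.
  def call(:qr, %{type: {kind, _bits}}, args, %{platform: :host}) when kind != :c do
    {:native_qr, args}
  end

  # Anything else cannot be lowered natively.
  def call(_tag, _out, _args, _client), do: :skip

  # The caller treats :skip as "use the generic fallback".
  def lower_block(tag, out, args, client, fallback) do
    case call(tag, out, args, client) do
      :skip -> fallback.(args)
      lowered -> lowered
    end
  end
end
```

The trade-off is that eligibility checks move into the function heads and guards of call, so there is one callback to implement instead of two.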

  def call(struct, out, args, client)
end

defimpl EXLA.CustomCall, for: Any do

Contributor

Should those be different implementations rather than having it all in Any?

Contributor

This way the user can provide their own overrides!

Contributor

Ah, I see! Let's add a comment explaining why then! :D
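
To make the override point concrete, here is a minimal sketch of how protocol dispatch behaves when the built-ins live in the Any implementation. The OverrideSketch modules and return values are hypothetical; the only behaviour relied on is that Elixir picks a struct-specific implementation before falling back to Any when @fallback_to_any true is set.

```elixir
defmodule OverrideSketch.QR do
  defstruct []
end

defmodule OverrideSketch.Eigh do
  defstruct []
end

defprotocol OverrideSketch.CustomCall do
  @fallback_to_any true
  def call(tag, args)
end

# Built-in lowerings are clauses on Any, so no struct-specific impl is claimed.
defimpl OverrideSketch.CustomCall, for: Any do
  def call(%OverrideSketch.QR{}, args), do: {:builtin_qr, args}
  def call(%OverrideSketch.Eigh{}, args), do: {:builtin_eigh, args}
  def call(_tag, _args), do: :skip
end

# A user-provided implementation for one struct takes precedence over Any.
defimpl OverrideSketch.CustomCall, for: OverrideSketch.QR do
  def call(_tag, args), do: {:user_qr, args}
end
```

In this sketch the QR struct resolves to the user implementation, while Eigh still reaches the built-in clause in Any.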

@josevalim
Contributor

@polvalente @Chapaman very nice! However, I think we should couple this more closely to actual custom calls (in C), because that's how it will work in practice. In other words, we should only move to the protocol blocks which are implemented as custom C calls. This will make it easier to see what they have in common (and we won't need to get distracted with things like take, which is quite different).

@Chapaman Chapaman marked this pull request as ready for review April 24, 2026 21:35
@Chapaman Chapaman marked this pull request as draft April 25, 2026 02:26
Comment thread exla/c_src/exla/exla.cc Outdated
Comment on lines +540 to +552
// Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run.
// Used from tests (e.g. alias custom-call registration plugin).
fine::Ok<> dlopen_test_plugin(ErlNifEnv *env, std::string path) {
  void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL);
  if (handle == nullptr) {
    const char *err = dlerror();
    throw std::invalid_argument(err ? err : "dlopen failed");
  }
  (void)handle;
  return fine::Ok();
}

FINE_NIF(dlopen_test_plugin, 0);

Contributor

Let's expose this as load_dylib so that it's not used just for tests

//
// Built when BUILD_EXLA_TEST_PLUGIN=1 (Mix test). Load with RTLD_GLOBAL via
// EXLA.NIF.dlopen_test_plugin/1 before compiling or running graphs that emit
// the alias call_target_name.

Contributor

We can wrap this file in ifndef EXLA_PROD to exclude its contents from MIX_ENV=prod builds.
Let's also remove the plugin suffix. I suggest exla_test/custom_calls.cc as the file path in c_src.

Comment thread exla/Makefile Outdated
Comment on lines +94 to +98
ifneq ($(BUILD_EXLA_TEST_PLUGIN),)
ifneq ($(BUILD_EXLA_TEST_PLUGIN),0)
$(EXLA_SO): $(TEST_PLUGIN_SO)
endif
endif

Contributor

Suggested change
- ifneq ($(BUILD_EXLA_TEST_PLUGIN),)
- ifneq ($(BUILD_EXLA_TEST_PLUGIN),0)
- $(EXLA_SO): $(TEST_PLUGIN_SO)
- endif
- endif

Comment thread exla/mix.exs Outdated
Comment on lines +36 to +37
"EXLA_VERSION" => "#{@version}",
"BUILD_EXLA_TEST_PLUGIN" => if(Mix.env() == :test, do: "1", else: "0")

Contributor

Suggested change
- "EXLA_VERSION" => "#{@version}",
- "BUILD_EXLA_TEST_PLUGIN" => if(Mix.env() == :test, do: "1", else: "0")
+ "EXLA_VERSION" => "#{@version}"

Comment thread exla/Makefile Outdated
Comment on lines 100 to 104
$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
@ mkdir -p $(dir $@)
$(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)

$(EXLA_SO): $(EXLA_CACHE_SO)

Contributor

Suggested change
- $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
- @ mkdir -p $(dir $@)
- $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)
- $(EXLA_SO): $(EXLA_CACHE_SO)
+ $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
+ @ mkdir -p $(PRIV_DIR)
+ $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)
+ EXLA_SO_DEPS = $(EXLA_CACHE_SO)
+ ifeq ($(MIX_ENV),test)
+ EXLA_SO_DEPS += $(TEST_PLUGIN_SO)
+ endif
+ $(EXLA_SO): $(EXLA_SO_DEPS)

Comment thread exla/lib/exla/mlir/value.ex Outdated
Comment on lines +788 to +795
@doc false
def qr_with_call_target(
      %Value{function: func} = value,
      q_typespec,
      r_typespec,
      call_target_name
    )
    when is_binary(call_target_name) do

Contributor

What if the protocol instead had 2 callbacks:

@spec function_name(struct, output_container, input_templates_list, client) :: String.t() | :skip
@spec config(struct, output_container, input_templates_list, client) :: map() | nil

function_name would return the string name registered when the .so is loaded (what we have here as call_target_name).

We can infer the input and output typespecs ourselves.

backend_config would be obtained from the config callback. We'd have to validate that all map values are encodable as mlir::DictionaryAttr (but we can encode them ourselves).
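
A minimal sketch of what that config validation might look like. The ConfigSketch module is hypothetical, and the set of value types treated as encodable (integers, floats, booleans, binaries, lists, and nested plain maps) is an assumption for illustration, not a statement of what mlir::DictionaryAttr or EXLA actually accepts.

```elixir
defmodule ConfigSketch do
  # Assumed set of attribute-encodable values; adjust to the real encoder.
  def encodable?(v) when is_integer(v) or is_float(v) or is_boolean(v) or is_binary(v),
    do: true

  def encodable?(v) when is_list(v), do: Enum.all?(v, &encodable?/1)

  # Nested plain maps are allowed if their keys are atoms or strings.
  def encodable?(v) when is_map(v) and not is_struct(v) do
    Enum.all?(v, fn {k, val} -> (is_atom(k) or is_binary(k)) and encodable?(val) end)
  end

  def encodable?(_other), do: false

  # A nil config (the `map() | nil` case) is treated as an empty config.
  def validate_config(nil), do: {:ok, %{}}

  def validate_config(config) when is_map(config) do
    case Enum.reject(config, fn {_key, value} -> encodable?(value) end) do
      [] -> {:ok, config}
      bad -> {:error, Enum.map(bad, fn {key, _value} -> key end)}
    end
  end
end
```

Returning the offending keys in the error tuple would let the compiler raise a message that points at the exact config entry that cannot be encoded.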

Contributor

There's a separate discussion on whether we expose other things such as has_side_effect and output_operand_aliases, but we can add these afterwards.

Comment thread exla/lib/exla/custom_call.ex Outdated
Comment on lines +133 to +145
case {operand_type, computation_type} do
  {{:f, 32}, {:f, 32}} -> "eigh_cpu_custom_call_f32"
  {{:f, 64}, {:f, 64}} -> "eigh_cpu_custom_call_f64"
  {{:s, 8}, {:f, 32}} -> "eigh_cpu_custom_call_s8"
  {{:s, 16}, {:f, 32}} -> "eigh_cpu_custom_call_s16"
  {{:s, 32}, {:f, 32}} -> "eigh_cpu_custom_call_s32"
  {{:s, 64}, {:f, 32}} -> "eigh_cpu_custom_call_s64"
  {{:u, 8}, {:f, 32}} -> "eigh_cpu_custom_call_u8"
  {{:u, 16}, {:f, 32}} -> "eigh_cpu_custom_call_u16"
  {{:u, 32}, {:f, 32}} -> "eigh_cpu_custom_call_u32"
  {{:u, 64}, {:f, 32}} -> "eigh_cpu_custom_call_u64"
  _ -> :skip
end

Contributor

I think the name is invariant to the output type. We can just use the input type
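
As a sketch of that simplification: if the target name really depends only on the input element type, the lookup could be reduced to something like the following. The TargetSketch module and function name are hypothetical; the supported types and the "eigh_cpu_custom_call_" prefix are taken from the case expression above.

```elixir
defmodule TargetSketch do
  # Input element types that appear in the case expression under review.
  @supported [
    {:f, 32}, {:f, 64},
    {:s, 8}, {:s, 16}, {:s, 32}, {:s, 64},
    {:u, 8}, {:u, 16}, {:u, 32}, {:u, 64}
  ]

  # Builds the target name from the input type alone, e.g. {:s, 16} gives "..._s16".
  def eigh_cpu_target({kind, bits} = type) when type in @supported do
    "eigh_cpu_custom_call_#{kind}#{bits}"
  end

  # Unsupported input types fall back to the default block implementation.
  def eigh_cpu_target(_type), do: :skip
end
```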

Contributor

Also, why are these not defps in the protocol impl?

Contributor Author

> Also, why are these not defps in the protocol impl?

I thought I needed to make it a def rather than a defp because I'm using it in both EXLA.CustomCall and EXLA.MLIR.Value. Should I change it?

Contributor

Are Value.qr and Value.eigh actually used? If they are, I think they should just be aliases to Value.custom_call (which will end up calling the protocol).

If that is indeed possible, then these functions will only be used in the protocol implementation.

Contributor Author

😅😅 it was not being used lol
removed it :~

@polvalente polvalente requested a review from josevalim April 30, 2026 05:09
Comment on lines +97 to +108
defmodule EXLA.CustomCall.Builtins do
  @moduledoc false

  @doc """
  Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.qr/2`, or `:skip`.

  `operand_type` is the input matrix element type.
  """
  def qr_cpu_target(operand_type) do
    case operand_type do
      {:f, 32} -> "qr_cpu_custom_call_f32"
      {:f, 64} -> "qr_cpu_custom_call_f64"

Contributor

Let's get rid of this module and only define the target resolution inside the call sites

3 participants