Skip to content
15 changes: 15 additions & 0 deletions exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ else
LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib'
endif

# Optional test plugin: registers qr_cpu_custom_call_f32_exla_alias -> same
# handler as qr_cpu_custom_call_f32 (Mix sets BUILD_EXLA_TEST_PLUGIN=1 in :test).
TEST_PLUGIN_CC = c_src/exla_test_plugin/qr_alias_registration.cc
TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias_plugin.so

ifneq ($(BUILD_EXLA_TEST_PLUGIN),)
ifneq ($(BUILD_EXLA_TEST_PLUGIN),0)
$(EXLA_SO): $(TEST_PLUGIN_SO)
endif
endif
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ifneq ($(BUILD_EXLA_TEST_PLUGIN),)
ifneq ($(BUILD_EXLA_TEST_PLUGIN),0)
$(EXLA_SO): $(TEST_PLUGIN_SO)
endif
endif


$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
@ mkdir -p $(dir $@)
$(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)

$(EXLA_SO): $(EXLA_CACHE_SO)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
@ mkdir -p $(dir $@)
$(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)
$(EXLA_SO): $(EXLA_CACHE_SO)
$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
@ mkdir -p $(PRIV_DIR)
$(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)
EXLA_SO_DEPS = $(EXLA_CACHE_SO)
ifeq($(MIX_ENV),test)
EXLA_SO_DEPS += $(TEST_PLUGIN_SO)
end
$(EXLA_SO): $(EXLA_SO_DEPS)

@ mkdir -p $(PRIV_DIR)
@ mkdir -p $(PRIV_DIR)/xla_extension
Expand Down
16 changes: 16 additions & 0 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <dlfcn.h>

#include <cstring>
#include <fine.hpp>
#include <stdexcept>
Expand Down Expand Up @@ -535,6 +537,20 @@ fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type,

FINE_NIF(load_pjrt_plugin, 0);

// Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run.
// Used from tests (e.g. alias custom-call registration plugin).
fine::Ok<> dlopen_test_plugin(ErlNifEnv *env, std::string path) {
void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (handle == nullptr) {
const char *err = dlerror();
throw std::invalid_argument(err ? err : "dlopen failed");
}
(void)handle;
return fine::Ok();
}

FINE_NIF(dlopen_test_plugin, 0);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's expose this as load_dylib so that it's not used just for tests


int64_t get_device_count(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client) {
return client->client()->device_count();
}
Expand Down
16 changes: 16 additions & 0 deletions exla/c_src/exla_test_plugin/qr_alias_registration.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Test-only shared library: registers an alias FFI name that reuses the
// existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so.
//
// Built when BUILD_EXLA_TEST_PLUGIN=1 (Mix test). Load with RTLD_GLOBAL via
// EXLA.NIF.dlopen_test_plugin/1 before compiling or running graphs that emit
// the alias call_target_name.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can wrap this file in ifndef EXLA_PROD to exclude its contents from in MIX_ENV=prod.
Let's also remove the plugin suffix. I suggest exla_test/custom_calls.cc as the file path in c_src


#include "xla/ffi/api/api.h"
#include "xla/ffi/ffi_api.h"

namespace ffi = xla::ffi;

extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame);

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias",
"Host", qr_cpu_custom_call_f32);
7 changes: 7 additions & 0 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ defmodule EXLA do
* `:highest` - Slowest but most accurate. Performs computations in float32
or float64 as applicable

## Native custom calls (`EXLA.CustomCall`)

Some `Nx.block/4` tags can be lowered to XLA **custom calls** (StableHLO plus
a registered native handler). Implement the `EXLA.CustomCall` protocol for
your block tag struct; see `EXLA.CustomCall` for the `call/4` contract,
including returning `:skip` to fall back to the block's default Elixir callback.

## Clients

The `EXLA` library uses a client for compiling and executing code.
Expand Down
164 changes: 164 additions & 0 deletions exla/lib/exla/custom_call.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
defprotocol EXLA.CustomCall do
@moduledoc """
Extension point for lowering selected `Nx.block/4` tags to **XLA custom calls**
(`stablehlo.custom_call` in MLIR), the same style as helpers on
`EXLA.MLIR.Value` such as `qr/3` and `eigh/3`.

Other blocks (for example gather-based `take` or FFT) are lowered inline in
`EXLA.Defn` and do not use this protocol.

## When `EXLA.Defn` calls it

During compilation with `compiler: EXLA`, when the builder is an MLIR
`EXLA.MLIR.Function`, each `Nx.block(tag, inputs, outputs, fn ... end)` is
passed here: `EXLA.Defn` invokes `call(tag, outputs_template, lowered_inputs, client)`.

If `call/4` returns `:skip`, EXLA compiles the block's **default callback**
(the anonymous function body) instead of emitting a custom call.

## `call/4` arguments

* `struct` — the **tag** passed as the first argument to `Nx.block/4`
(your own `defstruct` or an existing tag such as `%Nx.Block.LinAlg.QR{}`).

* `out` — the **output template** tuple passed to `Nx.block/4` (expression
metadata for shapes and types, not runtime tensors).

* `args` — list of already-lowered **operands** as `EXLA.MLIR.Value`s, in
the same order as `inputs` in `Nx.block/4`.

* `client` — the active `EXLA.Client` (use e.g. `client.platform` to gate
host-only lowerings).

## Return value

* **Success** — return a list of `EXLA.MLIR.Value` (or a single value) that
matches the block result shape implied by `out`.

* **`:skip`** — this implementation does not apply (unsupported type,
non-host platform, wrong arity, etc.). The default block implementation is
used instead.

## Dispatch

The protocol uses `@fallback_to_any true`. Built-in lowerings for known tags
live in `defimpl EXLA.CustomCall, for: Any`. Your application or dependency can
add `defimpl EXLA.CustomCall, for: YourStruct`; that implementation is chosen
whenever the block tag is `%YourStruct{}`, instead of the `Any` fallback.

## Native handlers

Emitting a custom call in MLIR is only half of the story: the **target name**
must be registered with XLA on the relevant platform (typically via a native
library loaded into the process). That registration is **not** configured
through `config :exla, ...`; you load or link the native code by the same
means you would for any other NIF-backed extension.

## Example

defmodule MyApp.CustomQrTag do
defstruct []
end

defimpl EXLA.CustomCall, for: MyApp.CustomQrTag do
alias EXLA.Defn
alias EXLA.MLIR.Value

def call(_tag, {q_expr, r_expr}, [tensor], %{platform: :host}) do
tensor =
if Defn.op_type(tensor) != q_expr.type do
Defn.to_type(tensor, q_expr.type)
else
tensor
end

{q, r} =
Value.qr(tensor, Defn.expr_to_typespec(q_expr), Defn.expr_to_typespec(r_expr))

[q, r]
end

def call(_, _, _, _), do: :skip
end

Then use `Nx.block(%MyApp.CustomQrTag{}, ...)` inside a `defn` compiled with
`compiler: EXLA`.
"""

@fallback_to_any true

@doc """
Attempts to lower the block natively.

Returns a list of `EXLA.MLIR.Value`s (or a single value) that represents the
block result, matching the shape of `out`.

Returns `:skip` when this implementation does not apply (wrong types,
platform, arity, etc.). `EXLA.Defn` then compiles the block's default callback
instead.
"""
def call(struct, out, args, client)
end

# Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live
# in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the
# protocol, applications and libraries can define their own
# `defimpl EXLA.CustomCall, for: SomeStruct` — protocol dispatch uses that
# implementation instead of this fallback when the block tag matches (you can
# also target a built-in struct such as `Nx.Block...` from your app if needed).
#
defimpl EXLA.CustomCall, for: Any do
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should those be different implementations rather than having it all in Any?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way the user can provide their own overrides!

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see! Let's add a comment explaining why then! :D

@moduledoc false

alias EXLA.MLIR.Value
alias EXLA.Defn

def call(%Nx.Block.LinAlg.QR{}, {%{type: {q_type_kind, _}} = q_expr, r_expr}, [tensor], client)
when q_type_kind != :c and client.platform == :host do
tensor =
if Defn.op_type(tensor) != q_expr.type do
Defn.to_type(tensor, q_expr.type)
else
tensor
end

{q, r} = Value.qr(tensor, Defn.expr_to_typespec(q_expr), Defn.expr_to_typespec(r_expr))
[q, r]
end

def call(
%Nx.Block.LinAlg.Eigh{},
{%{type: {eval_type_kind, _}} = eigenvals_expr,
%{type: {evec_type_kind, _}} = eigenvecs_expr},
[tensor],
client
)
when eval_type_kind != :c and evec_type_kind != :c and client.platform == :host do
# Eigen only supports f32/f64, so promote to the smallest floating type
# wide enough to represent the requested output.
out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32})

tensor =
if Defn.op_type(tensor) != out_type do
Defn.to_type(tensor, out_type)
else
tensor
end

{eigenvals, eigenvecs} =
Value.eigh(
tensor,
Defn.expr_to_typespec(%{eigenvals_expr | type: out_type}),
Defn.expr_to_typespec(%{eigenvecs_expr | type: out_type})
)

[
Defn.to_type(eigenvals, eigenvals_expr.type),
Defn.to_type(eigenvecs, eigenvecs_expr.type)
]
end

def call(%Nx.Block.LinAlg.QR{}, _, _, _), do: :skip
def call(%Nx.Block.LinAlg.Eigh{}, _, _, _), do: :skip
def call(_, _, _, _), do: :skip
end
Loading
Loading