-
Notifications
You must be signed in to change notification settings - Fork 218
Refactor EXLA block lowering through EXLA.CustomCall protocol #1739
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
014eb64
5db679f
4011521
3152003
f2aa558
2d8e9c2
adf2369
c058794
e424803
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -86,6 +86,21 @@ else | |||||||||||||||||||||||||||||||
| LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib' | ||||||||||||||||||||||||||||||||
| endif | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Optional test plugin: registers qr_cpu_custom_call_f32_exla_alias -> same | ||||||||||||||||||||||||||||||||
| # handler as qr_cpu_custom_call_f32 (Mix sets BUILD_EXLA_TEST_PLUGIN=1 in :test). | ||||||||||||||||||||||||||||||||
| TEST_PLUGIN_CC = c_src/exla_test_plugin/qr_alias_registration.cc | ||||||||||||||||||||||||||||||||
| TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias_plugin.so | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| ifneq ($(BUILD_EXLA_TEST_PLUGIN),) | ||||||||||||||||||||||||||||||||
| ifneq ($(BUILD_EXLA_TEST_PLUGIN),0) | ||||||||||||||||||||||||||||||||
| $(EXLA_SO): $(TEST_PLUGIN_SO) | ||||||||||||||||||||||||||||||||
| endif | ||||||||||||||||||||||||||||||||
| endif | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) | ||||||||||||||||||||||||||||||||
| @ mkdir -p $(dir $@) | ||||||||||||||||||||||||||||||||
| $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| $(EXLA_SO): $(EXLA_CACHE_SO) | ||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
| @ mkdir -p $(PRIV_DIR) | ||||||||||||||||||||||||||||||||
| @ mkdir -p $(PRIV_DIR)/xla_extension | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| #include <dlfcn.h> | ||
|
|
||
| #include <cstring> | ||
| #include <fine.hpp> | ||
| #include <stdexcept> | ||
|
|
@@ -535,6 +537,20 @@ fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type, | |
|
|
||
| FINE_NIF(load_pjrt_plugin, 0); | ||
|
|
||
| // Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run. | ||
| // Used from tests (e.g. alias custom-call registration plugin). | ||
| fine::Ok<> dlopen_test_plugin(ErlNifEnv *env, std::string path) { | ||
| void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL); | ||
| if (handle == nullptr) { | ||
| const char *err = dlerror(); | ||
| throw std::invalid_argument(err ? err : "dlopen failed"); | ||
| } | ||
| (void)handle; | ||
| return fine::Ok(); | ||
| } | ||
|
|
||
| FINE_NIF(dlopen_test_plugin, 0); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's expose this as load_dylib so that it's not used just for tests |
||
|
|
||
| int64_t get_device_count(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client) { | ||
| return client->client()->device_count(); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| // Test-only shared library: registers an alias FFI name that reuses the | ||
| // existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so. | ||
| // | ||
| // Built when BUILD_EXLA_TEST_PLUGIN=1 (Mix test). Load with RTLD_GLOBAL via | ||
| // EXLA.NIF.dlopen_test_plugin/1 before compiling or running graphs that emit | ||
| // the alias call_target_name. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can wrap this file in |
||
|
|
||
| #include "xla/ffi/api/api.h" | ||
| #include "xla/ffi/ffi_api.h" | ||
|
|
||
| namespace ffi = xla::ffi; | ||
|
|
||
| extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame); | ||
|
|
||
| XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias", | ||
| "Host", qr_cpu_custom_call_f32); | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| defprotocol EXLA.CustomCall do | ||
| @moduledoc """ | ||
| Extension point for lowering selected `Nx.block/4` tags to **XLA custom calls** | ||
| (`stablehlo.custom_call` in MLIR), the same style as helpers on | ||
| `EXLA.MLIR.Value` such as `qr/3` and `eigh/3`. | ||
|
|
||
| Other blocks (for example gather-based `take` or FFT) are lowered inline in | ||
| `EXLA.Defn` and do not use this protocol. | ||
|
|
||
| ## When `EXLA.Defn` calls it | ||
|
|
||
| During compilation with `compiler: EXLA`, when the builder is an MLIR | ||
| `EXLA.MLIR.Function`, each `Nx.block(tag, inputs, outputs, fn ... end)` is | ||
| passed here: `EXLA.Defn` invokes `call(tag, outputs_template, lowered_inputs, client)`. | ||
|
|
||
| If `call/4` returns `:skip`, EXLA compiles the block's **default callback** | ||
| (the anonymous function body) instead of emitting a custom call. | ||
|
|
||
| ## `call/4` arguments | ||
|
|
||
| * `struct` — the **tag** passed as the first argument to `Nx.block/4` | ||
| (your own `defstruct` or an existing tag such as `%Nx.Block.LinAlg.QR{}`). | ||
|
|
||
| * `out` — the **output template** tuple passed to `Nx.block/4` (expression | ||
| metadata for shapes and types, not runtime tensors). | ||
|
|
||
| * `args` — list of already-lowered **operands** as `EXLA.MLIR.Value`s, in | ||
| the same order as `inputs` in `Nx.block/4`. | ||
|
|
||
| * `client` — the active `EXLA.Client` (use e.g. `client.platform` to gate | ||
| host-only lowerings). | ||
|
|
||
| ## Return value | ||
|
|
||
| * **Success** — return a list of `EXLA.MLIR.Value` (or a single value) that | ||
| matches the block result shape implied by `out`. | ||
|
|
||
| * **`:skip`** — this implementation does not apply (unsupported type, | ||
| non-host platform, wrong arity, etc.). The default block implementation is | ||
| used instead. | ||
|
|
||
| ## Dispatch | ||
|
|
||
| The protocol uses `@fallback_to_any true`. Built-in lowerings for known tags | ||
| live in `defimpl EXLA.CustomCall, for: Any`. Your application or dependency can | ||
| add `defimpl EXLA.CustomCall, for: YourStruct`; that implementation is chosen | ||
| whenever the block tag is `%YourStruct{}`, instead of the `Any` fallback. | ||
|
|
||
| ## Native handlers | ||
|
|
||
| Emitting a custom call in MLIR is only half of the story: the **target name** | ||
| must be registered with XLA on the relevant platform (typically via a native | ||
| library loaded into the process). That registration is **not** configured | ||
| through `config :exla, ...`; you load or link the native code by the same | ||
| means you would for any other NIF-backed extension. | ||
|
|
||
| ## Example | ||
|
|
||
| defmodule MyApp.CustomQrTag do | ||
| defstruct [] | ||
| end | ||
|
|
||
| defimpl EXLA.CustomCall, for: MyApp.CustomQrTag do | ||
| alias EXLA.Defn | ||
| alias EXLA.MLIR.Value | ||
|
|
||
| def call(_tag, {q_expr, r_expr}, [tensor], %{platform: :host}) do | ||
| tensor = | ||
| if Defn.op_type(tensor) != q_expr.type do | ||
| Defn.to_type(tensor, q_expr.type) | ||
| else | ||
| tensor | ||
| end | ||
|
|
||
| {q, r} = | ||
| Value.qr(tensor, Defn.expr_to_typespec(q_expr), Defn.expr_to_typespec(r_expr)) | ||
|
|
||
| [q, r] | ||
| end | ||
|
|
||
| def call(_, _, _, _), do: :skip | ||
| end | ||
|
|
||
| Then use `Nx.block(%MyApp.CustomQrTag{}, ...)` inside a `defn` compiled with | ||
| `compiler: EXLA`. | ||
| """ | ||
|
|
||
| @fallback_to_any true | ||
|
|
||
| @doc """ | ||
| Attempts to lower the block natively. | ||
|
|
||
| Returns a list of `EXLA.MLIR.Value`s (or a single value) that represents the | ||
| block result, matching the shape of `out`. | ||
|
|
||
| Returns `:skip` when this implementation does not apply (wrong types, | ||
| platform, arity, etc.). `EXLA.Defn` then compiles the block's default callback | ||
| instead. | ||
| """ | ||
| def call(struct, out, args, client) | ||
| end | ||
|
|
||
| # Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live | ||
| # in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the | ||
| # protocol, applications and libraries can define their own | ||
| # `defimpl EXLA.CustomCall, for: SomeStruct` — protocol dispatch uses that | ||
| # implementation instead of this fallback when the block tag matches (you can | ||
| # also target a built-in struct such as `Nx.Block...` from your app if needed). | ||
| # | ||
| defimpl EXLA.CustomCall, for: Any do | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should those be different implementations rather than having it all in
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This way the user can provide their own overrides!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I see! Let's add a comment explaining why then! :D |
||
| @moduledoc false | ||
|
|
||
| alias EXLA.MLIR.Value | ||
| alias EXLA.Defn | ||
|
|
||
| def call(%Nx.Block.LinAlg.QR{}, {%{type: {q_type_kind, _}} = q_expr, r_expr}, [tensor], client) | ||
| when q_type_kind != :c and client.platform == :host do | ||
| tensor = | ||
| if Defn.op_type(tensor) != q_expr.type do | ||
| Defn.to_type(tensor, q_expr.type) | ||
| else | ||
| tensor | ||
| end | ||
|
|
||
| {q, r} = Value.qr(tensor, Defn.expr_to_typespec(q_expr), Defn.expr_to_typespec(r_expr)) | ||
| [q, r] | ||
| end | ||
|
|
||
| def call( | ||
| %Nx.Block.LinAlg.Eigh{}, | ||
| {%{type: {eval_type_kind, _}} = eigenvals_expr, | ||
| %{type: {evec_type_kind, _}} = eigenvecs_expr}, | ||
| [tensor], | ||
| client | ||
| ) | ||
| when eval_type_kind != :c and evec_type_kind != :c and client.platform == :host do | ||
| # Eigen only supports f32/f64, so promote to the smallest floating type | ||
| # wide enough to represent the requested output. | ||
| out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32}) | ||
|
|
||
| tensor = | ||
| if Defn.op_type(tensor) != out_type do | ||
| Defn.to_type(tensor, out_type) | ||
| else | ||
| tensor | ||
| end | ||
|
|
||
| {eigenvals, eigenvecs} = | ||
| Value.eigh( | ||
| tensor, | ||
| Defn.expr_to_typespec(%{eigenvals_expr | type: out_type}), | ||
| Defn.expr_to_typespec(%{eigenvecs_expr | type: out_type}) | ||
| ) | ||
|
|
||
| [ | ||
| Defn.to_type(eigenvals, eigenvals_expr.type), | ||
| Defn.to_type(eigenvecs, eigenvecs_expr.type) | ||
| ] | ||
| end | ||
|
|
||
| def call(%Nx.Block.LinAlg.QR{}, _, _, _), do: :skip | ||
| def call(%Nx.Block.LinAlg.Eigh{}, _, _, _), do: :skip | ||
| def call(_, _, _, _), do: :skip | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.