Refactor EXLA block lowering through EXLA.CustomCall protocol#1739
Refactor EXLA block lowering through EXLA.CustomCall protocol #1739 — Chapaman wants to merge 9 commits into elixir-nx:main from
Conversation
| defprotocol EXLA.CustomCall do | ||
| @moduledoc """ | ||
| Protocol used by `EXLA.Defn` to lower specific `Nx.block/4` tags natively | ||
| instead of compiling the fallback callback. | ||
|
|
||
| Implementations receive the block tag struct, the output template (`out`), | ||
| the already-recursed MLIR `EXLA.MLIR.Value` arguments and the active | ||
| `EXLA.Client`. | ||
| """ | ||
|
|
||
| @fallback_to_any true | ||
|
|
||
| @doc """ | ||
| Returns `true` when EXLA should lower the block natively via `call/4`. | ||
|
|
||
| When it returns `false`, `EXLA.Defn` falls back to compiling the block's | ||
| default callback implementation. | ||
| """ | ||
| def apply?(struct, out, args, client) | ||
|
|
||
| @fallback_to_any true | ||
|
|
||
| @doc """ | ||
| Lowers the block natively. | ||
|
|
||
| Must return the list of `EXLA.MLIR.Value`s (or a single value) that | ||
| represents the block result, matching the shape of `out`. | ||
| """ | ||
| def call(struct, out, args, client) | ||
| end | ||
|
|
||
| defimpl EXLA.CustomCall, for: Any do | ||
| alias EXLA.MLIR.Value | ||
| alias EXLA.Defn | ||
|
|
||
| # --- apply?/4 --- | ||
|
|
||
| def apply?( | ||
| %Nx.Block.LinAlg.QR{}, | ||
| {%{type: {q_type_kind, _}}, _r}, | ||
| _args, | ||
| client | ||
| ) do | ||
| q_type_kind != :c and client.platform == :host | ||
| end |
There was a problem hiding this comment.
We split this into apply?/4 and call/4 because a single callback would mix “should EXLA take the native path for this compile?” with “emit MLIR.” Keeping two functions separates eligibility from lowering.
`apply?` answers whether the native path applies for this block tag in context: e.g. for QR we still have `%Nx.Block.LinAlg.QR{}`, but we only native-lower when the output type is real (not complex) and the client is `:host`. Another block could gate on arity, struct fields, mesh, etc.—without duplicating the full `call` body. `call` only runs when `apply?` is true and contains the actual `Value.*` / Defn lowering.
EXLA.Defn mirrors that split: recurse operands → apply? → either call or the generic block fallback (default_block_implementation).
There was a problem hiding this comment.
You could have a single callback and instead return :skip in call when it cannot be lowered.
| def call(struct, out, args, client) | ||
| end | ||
|
|
||
| defimpl EXLA.CustomCall, for: Any do |
There was a problem hiding this comment.
Should those be different implementations rather than having it all in Any?
There was a problem hiding this comment.
This way the user can provide their own overrides!
There was a problem hiding this comment.
Ah, I see! Let's add a comment explaining why then! :D
|
@polvalente @Chapaman very nice! However, I think we should couple this more closely to actual custom calls (in C), because that's how it will work in practice. In other words, we should only move to the protocol blocks which are implemented as custom C calls. This will make it easier to see what they have in common (and we won't need to get distracted with things like take, which is quite different). |
| // Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run. | ||
| // Used from tests (e.g. alias custom-call registration plugin). | ||
| fine::Ok<> dlopen_test_plugin(ErlNifEnv *env, std::string path) { | ||
| void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL); | ||
| if (handle == nullptr) { | ||
| const char *err = dlerror(); | ||
| throw std::invalid_argument(err ? err : "dlopen failed"); | ||
| } | ||
| (void)handle; | ||
| return fine::Ok(); | ||
| } | ||
|
|
||
| FINE_NIF(dlopen_test_plugin, 0); |
There was a problem hiding this comment.
Let's expose this as load_dylib so that it's not used just for tests
| // | ||
| // Built when BUILD_EXLA_TEST_PLUGIN=1 (Mix test). Load with RTLD_GLOBAL via | ||
| // EXLA.NIF.dlopen_test_plugin/1 before compiling or running graphs that emit | ||
| // the alias call_target_name. |
There was a problem hiding this comment.
We can wrap this file in ifndef EXLA_PROD to exclude its contents when building with MIX_ENV=prod.
Let's also remove the plugin suffix. I suggest exla_test/custom_calls.cc as the file path in c_src
| ifneq ($(BUILD_EXLA_TEST_PLUGIN),) | ||
| ifneq ($(BUILD_EXLA_TEST_PLUGIN),0) | ||
| $(EXLA_SO): $(TEST_PLUGIN_SO) | ||
| endif | ||
| endif |
There was a problem hiding this comment.
| ifneq ($(BUILD_EXLA_TEST_PLUGIN),) | |
| ifneq ($(BUILD_EXLA_TEST_PLUGIN),0) | |
| $(EXLA_SO): $(TEST_PLUGIN_SO) | |
| endif | |
| endif |
| "EXLA_VERSION" => "#{@version}", | ||
| "BUILD_EXLA_TEST_PLUGIN" => if(Mix.env() == :test, do: "1", else: "0") |
There was a problem hiding this comment.
| "EXLA_VERSION" => "#{@version}", | |
| "BUILD_EXLA_TEST_PLUGIN" => if(Mix.env() == :test, do: "1", else: "0") | |
| "EXLA_VERSION" => "#{@version}" |
| $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) | ||
| @ mkdir -p $(dir $@) | ||
| $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) | ||
|
|
||
| $(EXLA_SO): $(EXLA_CACHE_SO) |
There was a problem hiding this comment.
| $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) | |
| @ mkdir -p $(dir $@) | |
| $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) | |
| $(EXLA_SO): $(EXLA_CACHE_SO) | |
| $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) | |
| @ mkdir -p $(PRIV_DIR) | |
| $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) | |
| EXLA_SO_DEPS = $(EXLA_CACHE_SO) | |
| ifeq($(MIX_ENV),test) | |
| EXLA_SO_DEPS += $(TEST_PLUGIN_SO) | |
| end | |
| $(EXLA_SO): $(EXLA_SO_DEPS) |
| @doc false | ||
| def qr_with_call_target( | ||
| %Value{function: func} = value, | ||
| q_typespec, | ||
| r_typespec, | ||
| call_target_name | ||
| ) | ||
| when is_binary(call_target_name) do |
There was a problem hiding this comment.
What if the protocol instead had 2 callbacks:
@spec function_name(struct, output_container, input_templates_list, client) :: String.t() | :skip
@spec config(struct, output_container, input_templates_list, client) :: map() | nil
function_name would be the string name registered by invoking the .so (what we have here as call_target_name)
input and output typespecs we can infer ourselves
and backend_config would be obtained from the config callback. We'd have to validate all map values are encodable as mlir::DictionaryAttr (but we can encode them ourselves).
There was a problem hiding this comment.
There's a separate discussion on whether we expose other things such as has_side_effect and output_operand_aliases, but we can add these afterwards.
| case {operand_type, computation_type} do | ||
| {{:f, 32}, {:f, 32}} -> "eigh_cpu_custom_call_f32" | ||
| {{:f, 64}, {:f, 64}} -> "eigh_cpu_custom_call_f64" | ||
| {{:s, 8}, {:f, 32}} -> "eigh_cpu_custom_call_s8" | ||
| {{:s, 16}, {:f, 32}} -> "eigh_cpu_custom_call_s16" | ||
| {{:s, 32}, {:f, 32}} -> "eigh_cpu_custom_call_s32" | ||
| {{:s, 64}, {:f, 32}} -> "eigh_cpu_custom_call_s64" | ||
| {{:u, 8}, {:f, 32}} -> "eigh_cpu_custom_call_u8" | ||
| {{:u, 16}, {:f, 32}} -> "eigh_cpu_custom_call_u16" | ||
| {{:u, 32}, {:f, 32}} -> "eigh_cpu_custom_call_u32" | ||
| {{:u, 64}, {:f, 32}} -> "eigh_cpu_custom_call_u64" | ||
| _ -> :skip | ||
| end |
There was a problem hiding this comment.
I think the name is invariant to the output type. We can just use the input type
There was a problem hiding this comment.
Also, why are these not defps in the protocol impl?
There was a problem hiding this comment.
Also, why are these not defps in the protocol impl?
I thought I needed to make it a def and not a defp because I'm using it in both EXLA.CustomCall and EXLA.MLIR.Value. Should I change it?
There was a problem hiding this comment.
Are Value.qr and Value.eigh actually used? I think if they are, they should just be aliases to Value.custom_call (which will end up calling the protocol).
If this is indeed possible, then these functions will just be used in the protocol implementation.
There was a problem hiding this comment.
😅😅 it was not being used lol
removed it :~
| defmodule EXLA.CustomCall.Builtins do | ||
| @moduledoc false | ||
|
|
||
| @doc """ | ||
| Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.qr/2`, or `:skip`. | ||
|
|
||
| `operand_type` is the input matrix element type. | ||
| """ | ||
| def qr_cpu_target(operand_type) do | ||
| case operand_type do | ||
| {:f, 32} -> "qr_cpu_custom_call_f32" | ||
| {:f, 64} -> "qr_cpu_custom_call_f64" |
There was a problem hiding this comment.
Let's get rid of this module and only define the target resolution inside the call sites
Summary
- `EXLA.CustomCall` (`apply?/4`, `call/4`) as the hook for native lowering of `Nx.block/4` in EXLA, with `@fallback_to_any true` and implementations on `Any` for the blocks that previously had dedicated clauses in `defn.ex`.
- `EXLA.Defn`: replaced the long chain of `:block` special cases with one `cached_recur_operator(:block, …)` path — recurse tensor args, then `apply?` → `call` if true, else `default_block_implementation/5` (the previous generic subfunction + `Value.call` path).
- `Nx.Block.LinAlg.QR`, `Nx.Block.LinAlg.Eigh`, `Nx.Block.Take`, `Nx.Block.TopK`, `Nx.Block.FFT2`, `Nx.Block.IFFT2`, `Nx.Block.RFFT`, and `Nx.Block.IRFFT` into the protocol impl; behavior should match the old code paths.
- `EXLA.Defn` functions are now `@doc false` public (`to_type`, `op_type`, `op_shape`, `expr_to_typespec`, `axes_for_rank`, `fft`, `fft2`) so the protocol can call them; `fft`/`fft2` no longer take `state` (builder comes from `%Value{}.function`).
- … `{:f, _}` vs `{:c, _}`), not `q.type != :c` (types are tuples).