diff --git a/exla/Makefile b/exla/Makefile index b875ef59f7..0183a3ea0e 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -21,6 +21,8 @@ EXLA_LIB_DIR = $(PRIV_DIR)/xla_extension/lib XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LIB) EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO) +.DEFAULT_GOAL := $(EXLA_SO) + # Build flags # # Note that XLA requires c++17, Fine as well @@ -86,7 +88,21 @@ else LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib' endif -$(EXLA_SO): $(EXLA_CACHE_SO) +# Optional test dylib: registers qr_cpu_custom_call_f32_exla_alias -> same +# handler as qr_cpu_custom_call_f32. Built only when MIX_ENV=test. +TEST_PLUGIN_CC = c_src/exla_test/custom_calls.cc +TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias.so + +$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) + @ mkdir -p $(dir $@) + $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) + +EXLA_SO_DEPS = $(EXLA_CACHE_SO) +ifeq ($(MIX_ENV),test) +EXLA_SO_DEPS += $(TEST_PLUGIN_SO) +endif + +$(EXLA_SO): $(EXLA_SO_DEPS) @ mkdir -p $(PRIV_DIR) @ mkdir -p $(PRIV_DIR)/xla_extension @ if [ "${MIX_BUILD_EMBEDDED}" = "true" ]; then \ diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index c4e9085833..206cdb46e1 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -29,6 +31,12 @@ #include "xla/tsl/platform/statusor.h" #include "llvm/Support/ThreadPool.h" +#include + +#include "xla/extension/custom_calls/eigh.h" +#include "xla/extension/custom_calls/qr.h" +#include "xla/ffi/ffi_api.h" + namespace exla { using callback_bridge::Pending; @@ -535,6 +543,19 @@ fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type, FINE_NIF(load_pjrt_plugin, 0); +// Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run. 
+fine::Ok<> load_dylib(ErlNifEnv *env, std::string path) {
+  void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL);
+  if (handle == nullptr) {
+    const char *err = dlerror();
+    throw std::invalid_argument(err ? err : "dlopen failed");
+  }
+  (void)handle;
+  return fine::Ok();
+}
+
+FINE_NIF(load_dylib, 0);
+
 int64_t get_device_count(ErlNifEnv *env,
                          fine::ResourcePtr<ExlaClient> client) {
   return client->client()->device_count();
 }
@@ -715,4 +736,145 @@ FINE_NIF(write_to_pointer, 0);
 
 }  // namespace exla
 
+// Host QR custom calls: integer operands with f32 Q/R (see Nx.Type.to_floating/1
+// for integer matrices). Handlers live in libexla alongside the NIFs.
+namespace {
+
+namespace ffi = xla::ffi;
+
+template <ffi::DataType dtype>
+ffi::Error QrCpuCustomCallIntegerOperandF32ResultsImpl(
+    ffi::Buffer<dtype> operand, ffi::ResultBuffer<ffi::F32> q,
+    ffi::ResultBuffer<ffi::F32> r) {
+  using IntT = ffi::NativeType<dtype>;
+  auto operand_dims = operand.dimensions();
+  auto q_dims = q->dimensions();
+  auto r_dims = r->dimensions();
+
+  uint64_t m = q_dims[q_dims.size() - 2];
+  uint64_t k = q_dims[q_dims.size() - 1];
+  uint64_t n = r_dims[r_dims.size() - 1];
+  uint64_t l = r_dims[r_dims.size() - 2];
+
+  bool complete = l == m;
+
+  uint64_t batch_items = 1;
+  for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
+    batch_items *= static_cast<uint64_t>(*it);
+  }
+
+  uint64_t q_stride = m * k;
+  uint64_t r_stride = n * l;
+  uint64_t inner_stride = m * n;
+
+  std::vector<float> tmp(inner_stride);
+  const IntT *in_base = operand.typed_data();
+  float *q_base = reinterpret_cast<float *>(q->untyped_data());
+  float *r_base = reinterpret_cast<float *>(r->untyped_data());
+
+  for (uint64_t b = 0; b < batch_items; b++) {
+    const IntT *in = in_base + b * inner_stride;
+    for (uint64_t j = 0; j < inner_stride; j++) {
+      tmp[j] = static_cast<float>(in[j]);
+    }
+    single_matrix_qr_cpu_custom_call(
+        q_base + b * q_stride, r_base + b * r_stride, tmp.data(), m, k, n,
+        complete);
+  }
+
+  return ffi::Error::Success();
+}
+
+#define EXLA_REGISTER_QR_INT_F32(DTYPE, NAME)                                 \
+  static ffi::Error NAME##_impl(ffi::Buffer<ffi::DTYPE> operand,              \
+                                ffi::ResultBuffer<ffi::F32> q,                \
+                                ffi::ResultBuffer<ffi::F32> r) {              \
+    return QrCpuCustomCallIntegerOperandF32ResultsImpl<ffi::DTYPE>(operand,   \
+                                                                   q, r);     \
+  }                                                                           \
+  XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl,                            \
+                                ffi::Ffi::Bind()                              \
+                                    .Arg<ffi::Buffer<ffi::DTYPE>>()           \
+                                    .Ret<ffi::Buffer<ffi::F32>>()             \
+                                    .Ret<ffi::Buffer<ffi::F32>>());           \
+  XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME);
+
+EXLA_REGISTER_QR_INT_F32(S8, qr_cpu_custom_call_s8)
+EXLA_REGISTER_QR_INT_F32(S16, qr_cpu_custom_call_s16)
+EXLA_REGISTER_QR_INT_F32(S32, qr_cpu_custom_call_s32)
+EXLA_REGISTER_QR_INT_F32(S64, qr_cpu_custom_call_s64)
+EXLA_REGISTER_QR_INT_F32(U8, qr_cpu_custom_call_u8)
+EXLA_REGISTER_QR_INT_F32(U16, qr_cpu_custom_call_u16)
+EXLA_REGISTER_QR_INT_F32(U32, qr_cpu_custom_call_u32)
+EXLA_REGISTER_QR_INT_F32(U64, qr_cpu_custom_call_u64)
+
+#undef EXLA_REGISTER_QR_INT_F32
+
+template <ffi::DataType dtype>
+ffi::Error EighCpuCustomCallIntegerOperandF32ResultsImpl(
+    ffi::Buffer<dtype> operand,
+    ffi::ResultBuffer<ffi::F32> eigenvalues,
+    ffi::ResultBuffer<ffi::F32> eigenvectors) {
+  using IntT = ffi::NativeType<dtype>;
+  auto operand_dims = operand.dimensions();
+  auto eigenvalues_dims = eigenvalues->dimensions();
+  auto eigenvectors_dims = eigenvectors->dimensions();
+
+  uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
+  uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
+
+  uint64_t batch_items = 1;
+  for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
+    batch_items *= static_cast<uint64_t>(*it);
+  }
+
+  uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
+  uint64_t eigenvectors_stride = m * n;
+  uint64_t inner_stride = m * n;
+
+  std::vector<float> tmp(inner_stride);
+  const IntT *in_base = operand.typed_data();
+  float *eval_base = reinterpret_cast<float *>(eigenvalues->untyped_data());
+  float *evec_base = reinterpret_cast<float *>(eigenvectors->untyped_data());
+
+  for (uint64_t b = 0; b < batch_items; b++) {
+    const IntT *in = in_base + b * inner_stride;
+    for (uint64_t j = 0; j < inner_stride; j++) {
+      tmp[j] = static_cast<float>(in[j]);
+    }
+    single_matrix_eigh_cpu_custom_call(
+        eval_base + b * eigenvalues_stride, evec_base + b * eigenvectors_stride,
+        tmp.data(), m, n);
+  }
+
+  return ffi::Error::Success();
+}
+
+#define EXLA_REGISTER_EIGH_INT_F32(DTYPE, NAME)                               \
+  static ffi::Error NAME##_impl(ffi::Buffer<ffi::DTYPE> operand,              \
+                                ffi::ResultBuffer<ffi::F32> eigenvalues,      \
+                                ffi::ResultBuffer<ffi::F32> eigenvectors) {   \
+    return EighCpuCustomCallIntegerOperandF32ResultsImpl<ffi::DTYPE>(         \
+        operand, eigenvalues, eigenvectors);                                  \
+  }                                                                           \
+  XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl,                            \
+                                ffi::Ffi::Bind()                              \
+                                    .Arg<ffi::Buffer<ffi::DTYPE>>()           \
+                                    .Ret<ffi::Buffer<ffi::F32>>()             \
+                                    .Ret<ffi::Buffer<ffi::F32>>());           \
+  XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME);
+
+EXLA_REGISTER_EIGH_INT_F32(S8, eigh_cpu_custom_call_s8)
+EXLA_REGISTER_EIGH_INT_F32(S16, eigh_cpu_custom_call_s16)
+EXLA_REGISTER_EIGH_INT_F32(S32, eigh_cpu_custom_call_s32)
+EXLA_REGISTER_EIGH_INT_F32(S64, eigh_cpu_custom_call_s64)
+EXLA_REGISTER_EIGH_INT_F32(U8, eigh_cpu_custom_call_u8)
+EXLA_REGISTER_EIGH_INT_F32(U16, eigh_cpu_custom_call_u16)
+EXLA_REGISTER_EIGH_INT_F32(U32, eigh_cpu_custom_call_u32)
+EXLA_REGISTER_EIGH_INT_F32(U64, eigh_cpu_custom_call_u64)
+
+#undef EXLA_REGISTER_EIGH_INT_F32
+
+}  // namespace
+
 FINE_INIT("Elixir.EXLA.NIF");
diff --git a/exla/c_src/exla_test/custom_calls.cc b/exla/c_src/exla_test/custom_calls.cc
new file mode 100644
index 0000000000..e54b095bf1
--- /dev/null
+++ b/exla/c_src/exla_test/custom_calls.cc
@@ -0,0 +1,15 @@
+// Test-only shared library: registers an alias FFI name that reuses the
+// existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so.
+#ifndef EXLA_PROD + +#include "xla/ffi/api/api.h" +#include "xla/ffi/ffi_api.h" + +namespace ffi = xla::ffi; + +extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias", + "Host", qr_cpu_custom_call_f32); + +#endif diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index b558b0fd83..9be5dc13db 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -75,6 +75,13 @@ defmodule EXLA do * `:highest` - Slowest but most accurate. Performs computations in float32 or float64 as applicable + ## Native custom calls (`EXLA.CustomCall`) + + Some `Nx.block/4` tags can be lowered to XLA **custom calls** (StableHLO plus + a registered native handler). Implement the `EXLA.CustomCall` protocol for + your block tag struct; see `EXLA.CustomCall` for the `call/4` contract, + including returning `:skip` to fall back to the block's default Elixir callback. + ## Clients The `EXLA` library uses a client for compiling and executing code. diff --git a/exla/lib/exla/custom_call.ex b/exla/lib/exla/custom_call.ex new file mode 100644 index 0000000000..70587f90a0 --- /dev/null +++ b/exla/lib/exla/custom_call.ex @@ -0,0 +1,179 @@ +defprotocol EXLA.CustomCall do + @moduledoc """ + Extension point for lowering selected `Nx.block/4` tags to **XLA custom calls** + (`stablehlo.custom_call` in MLIR), the same style as helpers on + `EXLA.MLIR.Value` such as `qr/3` and `eigh/3`. + + Other blocks (for example gather-based `take` or FFT) are lowered inline in + `EXLA.Defn` and do not use this protocol. + + ## When `EXLA.Defn` calls it + + During compilation with `compiler: EXLA`, when the builder is an MLIR + `EXLA.MLIR.Function`, each `Nx.block(tag, inputs, outputs, fn ... end)` is + passed here. 
`EXLA.Defn` invokes: + + * `function_name(tag, outputs_template, input_templates, client)` + * `config(tag, outputs_template, input_templates, client)` + + If `function_name/4` returns `:skip`, EXLA compiles the block's **default + callback** (the anonymous function body) instead of emitting a custom call. + + ## `function_name/4` and `config/4` arguments + + * `struct` — the **tag** passed as the first argument to `Nx.block/4` + (your own `defstruct` or an existing tag such as `%Nx.Block.LinAlg.QR{}`). + + * `out` — the **output template** tuple passed to `Nx.block/4` (expression + metadata for shapes and types, not runtime tensors). + + * `args` — list of **input templates**, in the same order as `inputs` in + `Nx.block/4`. + + * `client` — the active `EXLA.Client` (use e.g. `client.platform` to gate + host-only lowerings). + + ## Return values + + * `function_name/4`: + * **Success** — return the native custom-call target name. + * **`:skip`** — this implementation does not apply (unsupported type, + non-host platform, wrong arity, etc.). The default block implementation + is used instead. + + * `config/4`: + * Return a `map()` to be encoded as `backend_config`. + * Return `nil` to omit `backend_config`. + + ## Dispatch + + The protocol uses `@fallback_to_any true`. Built-in lowerings for known tags + live in `defimpl EXLA.CustomCall, for: Any`. Your application or dependency can + add `defimpl EXLA.CustomCall, for: YourStruct`; that implementation is chosen + whenever the block tag is `%YourStruct{}`, instead of the `Any` fallback. + + ## Native handlers + + Emitting a custom call in MLIR is only half of the story: the **target name** + must be registered with XLA on the relevant platform (typically via a native + library loaded into the process). That registration is **not** configured + through `config :exla, ...`; you load or link the native code by the same + means you would for any other NIF-backed extension. 
+ + ## Example + + defmodule MyApp.CustomQrTag do + defstruct [] + end + + defimpl EXLA.CustomCall, for: MyApp.CustomQrTag do + def function_name(_tag, {%{type: {kind, size}}, _r_expr}, [_input], %{platform: :host}) + when kind != :c and kind in [:f, :bf] and size in [16, 32, 64] do + "my_custom_qr_target" + end + + def function_name(_, _, _, _), do: :skip + + def config(_, _, _, _), do: nil + end + + Then use `Nx.block(%MyApp.CustomQrTag{}, ...)` inside a `defn` compiled with + `compiler: EXLA`. + """ + + @fallback_to_any true + + @doc """ + Returns the custom-call target name or `:skip`. + """ + def function_name(struct, out, args, client) + + @doc """ + Returns a map encoded into `backend_config`, or `nil`. + """ + def config(struct, out, args, client) +end + +defmodule EXLA.CustomCall.Builtins do + @moduledoc false + + @doc """ + Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.qr/2`, or `:skip`. + + `operand_type` is the input matrix element type. + """ + def qr_cpu_target(operand_type) do + case operand_type do + {:f, 32} -> "qr_cpu_custom_call_f32" + {:f, 64} -> "qr_cpu_custom_call_f64" + {:f, 16} -> "qr_cpu_custom_call_f16" + {:bf, 16} -> "qr_cpu_custom_call_bf16" + {:s, 8} -> "qr_cpu_custom_call_s8" + {:s, 16} -> "qr_cpu_custom_call_s16" + {:s, 32} -> "qr_cpu_custom_call_s32" + {:s, 64} -> "qr_cpu_custom_call_s64" + {:u, 8} -> "qr_cpu_custom_call_u8" + {:u, 16} -> "qr_cpu_custom_call_u16" + {:u, 32} -> "qr_cpu_custom_call_u32" + {:u, 64} -> "qr_cpu_custom_call_u64" + _ -> :skip + end + end + + @doc """ + Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.eigh/2`, or `:skip`. + + `operand_type` is the input matrix element type. 
+ """ + def eigh_cpu_target(operand_type) do + case operand_type do + {:f, 32} -> "eigh_cpu_custom_call_f32" + {:f, 64} -> "eigh_cpu_custom_call_f64" + {:s, 8} -> "eigh_cpu_custom_call_s8" + {:s, 16} -> "eigh_cpu_custom_call_s16" + {:s, 32} -> "eigh_cpu_custom_call_s32" + {:s, 64} -> "eigh_cpu_custom_call_s64" + {:u, 8} -> "eigh_cpu_custom_call_u8" + {:u, 16} -> "eigh_cpu_custom_call_u16" + {:u, 32} -> "eigh_cpu_custom_call_u32" + {:u, 64} -> "eigh_cpu_custom_call_u64" + _ -> :skip + end + end +end + +# Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live +# in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the +# protocol, applications and libraries can define their own +# `defimpl EXLA.CustomCall, for: SomeStruct` — protocol dispatch uses that +# implementation instead of this fallback when the block tag matches (you can +# also target a built-in struct such as `Nx.Block...` from your app if needed). +# +defimpl EXLA.CustomCall, for: Any do + @moduledoc false + + def function_name( + %Nx.Block.LinAlg.QR{}, + {%{type: q_type}, _r_expr}, + [%{type: in_type} | _], + %{platform: :host} + ) + when elem(q_type, 0) != :c and elem(in_type, 0) != :c do + EXLA.CustomCall.Builtins.qr_cpu_target(in_type) + end + + def function_name( + %Nx.Block.LinAlg.Eigh{}, + {%{type: eval_type}, %{type: evec_type}}, + [%{type: in_type} | _], + %{platform: :host} + ) + when elem(eval_type, 0) != :c and elem(evec_type, 0) != :c and + elem(in_type, 0) != :c do + EXLA.CustomCall.Builtins.eigh_cpu_target(in_type) + end + + def function_name(_, _, _, _), do: :skip + + def config(_, _, _, _), do: nil +end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 602b4972c6..5697dd2bb1 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -600,78 +600,8 @@ defmodule EXLA.Defn do {fun_computation(args, expr, type, state), cache} end - defp cached_recur_operator( - :block, - %T{ - data: %Expr{ - args: [ - %Nx.Block.LinAlg.QR{}, - 
[tensor], - {%{type: {type_kind, _}} = q_expr, r_expr}, - _callback - ] - } - }, - %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, - cache - ) - when type_kind != :c do - # We match only on platform: :host for MLIR, as we want to support - # QR-on-cpu as a custom call only in this case - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - tensor = - if op_type(tensor) != q_expr.type do - to_type(tensor, q_expr.type) - else - tensor - end - - {q, r} = Value.qr(tensor, expr_to_typespec(q_expr), expr_to_typespec(r_expr)) - {[q, r], cache} - end - - defp cached_recur_operator( - :block, - %T{ - data: %Expr{ - args: [ - %Nx.Block.LinAlg.Eigh{}, - [tensor], - {%{type: {evec_type_kind, _}} = eigenvals_expr, - %{type: {eval_type_kind, _}} = eigenvecs_expr}, - _callback - ] - } - }, - %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, - cache - ) - when evec_type_kind != :c and eval_type_kind != :c do - # We match only on platform: :host for MLIR, as we want to support - # eigh-on-cpu as a custom call only in this case - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - # convert to float and ensure that we're either using f32 or f64, because Eigen - # only supports f32 and f64 easily. - out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32}) - - tensor = - if op_type(tensor) != out_type do - to_type(tensor, out_type) - else - tensor - end - - {eigenvals, eigenvecs} = - Value.eigh( - tensor, - expr_to_typespec(%{eigenvals_expr | type: out_type}), - expr_to_typespec(%{eigenvecs_expr | type: out_type}) - ) - - {[to_type(eigenvals, eigenvals_expr.type), to_type(eigenvecs, eigenvecs_expr.type)], cache} - end + # StableHLO-style lowering (gather, top_k, fft): not the C custom_call path; + # see `EXLA.CustomCall` for blocks that delegate to native CPU kernels. 
defp cached_recur_operator( :block, @@ -736,15 +666,9 @@ defmodule EXLA.Defn do {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() opts = [lengths: fft2_struct.lengths, axes: fft2_struct.axes] + opts = if eps = fft2_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts - opts = - if eps = fft2_struct.eps do - Keyword.put(opts, :eps, eps) - else - opts - end - - {fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], expr, state), cache} + {fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], expr), cache} end defp cached_recur_operator( @@ -756,15 +680,9 @@ defmodule EXLA.Defn do {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() opts = [lengths: ifft2_struct.lengths, axes: ifft2_struct.axes] + opts = if eps = ifft2_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts - opts = - if eps = ifft2_struct.eps do - Keyword.put(opts, :eps, eps) - else - opts - end - - {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache} + {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr), cache} end defp cached_recur_operator( @@ -776,19 +694,11 @@ defmodule EXLA.Defn do {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() opts = [length: rfft_struct.length, axis: rfft_struct.axis] + opts = if eps = rfft_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts - opts = - if eps = rfft_struct.eps do - Keyword.put(opts, :eps, eps) - else - opts - end - - # expr.type is complex; input tensor is real input_type = Nx.Type.to_real(expr.type) - {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr, state), - cache} + {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr), cache} end defp cached_recur_operator( @@ -800,16 +710,8 @@ defmodule EXLA.Defn do {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() opts = [length: irfft_struct.length, axis: irfft_struct.axis] + opts = if eps = 
irfft_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts - opts = - if eps = irfft_struct.eps do - Keyword.put(opts, :eps, eps) - else - opts - end - - # expr.type is real; input tensor is complex. - # pad_n = div(n,2)+1 (the expected input size), while fft_n = n (the output length). n = irfft_struct.length input_type = Nx.Type.to_complex(expr.type) @@ -819,44 +721,49 @@ defmodule EXLA.Defn do expr.type, div(n, 2) + 1, [tensor, opts], - expr, - state + expr ), cache} end - defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do - [struct, in_args, expr, _callback] = args - %module{} = struct - + # C-backed custom_call blocks (QR, Eigh, …): `EXLA.CustomCall`; else compile default callback. + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [struct, in_args, out, _callback]}}, + %{client: client, builder: %Function{}} = state, + cache + ) do {call_args, cache} = Enum.map_reduce(in_args, cache, &recur_operator(&1, state, &2)) - key = computation_key(module, [struct | call_args]) - {call_body, cache} = - case cache do - %{^key => computation} -> - {computation, cache} + case EXLA.CustomCall.function_name(struct, out, in_args, client) do + :skip -> + default_block_implementation(struct, call_args, out, state, cache) - %{} -> - {computation, cache} = - block_computation( - block_subfunction_description(struct), - call_args, - expr, - state, - cache - ) + function_name -> + config = EXLA.CustomCall.config(struct, out, in_args, client) - {computation, Map.put(cache, key, computation)} - end + backend_config = + case config do + nil -> + nil - if token = get_token(cache) do - typespecs = [Typespec.token() | container_to_typespecs(expr)] - [token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs) - {wrap_tuple_result(result, expr), update_token(cache, token)} - else - typespecs = container_to_typespecs(expr) - result = Value.call(state.builder, call_args, call_body, typespecs) - 
{wrap_tuple_result(result, expr), cache} + %{} = map -> + map + + other -> + raise ArgumentError, + "EXLA.CustomCall.config/4 must return map() | nil, got: #{inspect(other)}" + end + + out_typespecs = + [out] + |> Composite.flatten_list() + |> Enum.map(&expr_to_typespec/1) + + lowered = + Value.custom_call(call_args, out_typespecs, function_name, backend_config) + |> wrap_tuple_result(out) + + {lowered, cache} end end @@ -998,6 +905,39 @@ defmodule EXLA.Defn do {to_operator(op, args, expr, state), cache} end + defp default_block_implementation(struct, call_args, expr, state, cache) do + %module{} = struct + key = computation_key(module, [struct | call_args]) + + {call_body, cache} = + case cache do + %{^key => computation} -> + {computation, cache} + + %{} -> + {computation, cache} = + block_computation( + block_subfunction_description(struct), + call_args, + expr, + state, + cache + ) + + {computation, Map.put(cache, key, computation)} + end + + if token = get_token(cache) do + typespecs = [Typespec.token() | container_to_typespecs(expr)] + [token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs) + {wrap_tuple_result(result, expr), update_token(cache, token)} + else + typespecs = container_to_typespecs(expr) + result = Value.call(state.builder, call_args, call_body, typespecs) + {wrap_tuple_result(result, expr), cache} + end + end + ## to_operator creation defp to_operator(:constant, [constant], ans, state) do @@ -1289,11 +1229,11 @@ defmodule EXLA.Defn do apply(Value, op, [to_type(arg, type), expr_to_typespec(ans)]) end - defp to_operator(:fft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out, state) + defp to_operator(:fft, [%Value{} | _] = args, out, _state), + do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out) - defp to_operator(:ifft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out, state) + defp 
to_operator(:ifft, [%Value{} | _] = args, out, _state), + do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out) defp to_operator(:is_nan, [%Value{} = arg], out, _state), do: Value.is_nan(arg, expr_to_typespec(out)) @@ -1618,7 +1558,8 @@ defmodule EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end - defp fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans, state) do + @doc false + def fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans) do fft_n = opts[:length] pad_n = pad_n || fft_n axis = opts[:axis] @@ -1627,7 +1568,7 @@ defmodule EXLA.Defn do shape = op_shape(tensor) m = elem(shape, axis) - tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type, state) + tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type) last_axis = tuple_size(shape) - 1 @@ -1662,7 +1603,8 @@ defmodule EXLA.Defn do end end - defp fft2(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do + @doc false + def fft2(exla_op, [%Value{} = tensor, opts], %{type: type} = ans) do [l1, l2] = lengths = opts[:lengths] [ax1, ax2] = axes = opts[:axes] output_type = Nx.Type.to_complex(type) @@ -1672,8 +1614,8 @@ defmodule EXLA.Defn do m1 = elem(shape, ax1) m2 = elem(shape, ax2) - tensor = fft_pad_or_slice(tensor, m1, l1, ax1, shape, output_type, state) - tensor = fft_pad_or_slice(tensor, m2, l2, ax2, op_shape(tensor), output_type, state) + tensor = fft_pad_or_slice(tensor, m1, l1, ax1, shape, output_type) + tensor = fft_pad_or_slice(tensor, m2, l2, ax2, op_shape(tensor), output_type) last_axis = tuple_size(shape) - 1 penultimate_axis = last_axis - 1 @@ -1701,7 +1643,7 @@ defmodule EXLA.Defn do end end - defp fft_pad_or_slice(tensor, m, n, axis, shape, output_type, state) do + defp fft_pad_or_slice(%Value{function: builder} = tensor, m, n, axis, shape, output_type) do cond do m == n -> tensor @@ -1726,7 +1668,7 @@ defmodule EXLA.Defn do zero_value = if 
Nx.Type.complex?(output_type), do: Complex.new(0), else: 0 zero = - Value.constant(state.builder, [zero_value], Typespec.tensor(output_type, {})) + Value.constant(builder, [zero_value], Typespec.tensor(output_type, {})) padding_config = {0, 0, 0} @@ -2251,19 +2193,23 @@ defmodule EXLA.Defn do defp count_up(0, _n), do: [] defp count_up(i, n), do: [n | count_up(i - 1, n + 1)] - defp axes_for_rank(0), do: [] + @doc false + def axes_for_rank(0), do: [] - defp axes_for_rank(rank) do + def axes_for_rank(rank) do Enum.to_list(0..(rank - 1)) end ## Op Helpers - defp op_type(%Value{} = op), do: Value.get_typespec(op).type + @doc false + def op_type(%Value{} = op), do: Value.get_typespec(op).type - defp op_shape(%Value{} = op), do: Value.get_typespec(op).shape + @doc false + def op_shape(%Value{} = op), do: Value.get_typespec(op).shape - defp to_type(%Value{} = op, type) do + @doc false + def to_type(%Value{} = op, type) do typespec = Value.get_typespec(op) if typespec.type == type do @@ -2357,7 +2303,8 @@ defmodule EXLA.Defn do |> Enum.reduce(shape, &Tuple.delete_at(&2, &1)) end - defp expr_to_typespec(expr) do + @doc false + def expr_to_typespec(expr) do Typespec.tensor(expr.type, expr.shape) end diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 9d028ff6dd..e5a6d6e1b5 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -719,70 +719,29 @@ defmodule EXLA.MLIR.Value do op(func, "stablehlo.return", values, []) end - def eigh(%Value{function: func} = value, eigenvals_typespec, eigenvecs_typespec) do - %{type: op_type} = get_typespec(value) - - operands = [value] - result_types = typespecs_to_mlir_types([eigenvals_typespec, eigenvecs_typespec]) - - call_target_name = - case op_type do - {:f, 32} -> - "eigh_cpu_custom_call_f32" - - {:f, 64} -> - "eigh_cpu_custom_call_f64" - - type -> - # Due to matching on EXLA.Defn, we are sure that the device here is always :host - raise "Eigh decomposition not supported on :host device for 
type #{inspect(type)}" - end + @doc false + def custom_call( + [%Value{function: func} | _] = operands, + typespecs, + call_target_name, + backend_config \\ nil + ) + when is_binary(call_target_name) and is_list(typespecs) do + result_types = typespecs_to_mlir_types(typespecs) attributes = [ call_target_name: attr_string(call_target_name), api_version: attr_i32(4) ] - [eigenvals, eigenvecs] = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) - - {eigenvals, eigenvecs} - end - - def qr(%Value{function: func} = value, q_typespec, r_typespec) do - %{type: op_type} = get_typespec(value) - - operands = [value] - result_types = typespecs_to_mlir_types([q_typespec, r_typespec]) - - call_target_name = - case op_type do - {:f, 32} -> - "qr_cpu_custom_call_f32" - - {:f, 64} -> - "qr_cpu_custom_call_f64" - - {:f, 16} -> - "qr_cpu_custom_call_f16" - - {:bf, 16} -> - "qr_cpu_custom_call_bf16" - - type -> - # Due to matching on EXLA.Defn, we are sure that the device here is always :host - raise "QR decomposition not supported on :host device for type #{inspect(type)}" + attributes = + if is_map(backend_config) do + Keyword.put(attributes, :backend_config, backend_config_to_attr(backend_config)) + else + attributes end - attributes = [ - call_target_name: attr_string(call_target_name), - api_version: attr_i32(4) - ] - - [q, r] = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) - - {q, r} + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) end def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do @@ -1088,6 +1047,44 @@ defmodule EXLA.MLIR.Value do "{#{content}}" end + defp backend_config_to_attr(map) when is_map(map) do + map + |> Enum.map(fn {k, v} -> {attr_dict_key(k), backend_config_value_to_attr(v)} end) + |> attr_dict() + end + + defp backend_config_value_to_attr(v) when is_boolean(v), do: attr_boolean(v) + defp backend_config_value_to_attr(v) when 
is_integer(v), do: attr_i64(v) + defp backend_config_value_to_attr(v) when is_float(v), do: "#{v} : f64" + defp backend_config_value_to_attr(v) when is_binary(v), do: attr_string(v) + + defp backend_config_value_to_attr(v) when is_list(v) do + "[" <> Enum.map_join(v, ", ", &backend_config_value_to_attr/1) <> "]" + end + + defp backend_config_value_to_attr(v) when is_map(v), do: backend_config_to_attr(v) + + defp backend_config_value_to_attr(v) do + raise ArgumentError, + "custom_call backend_config value is not encodable to MLIR DictionaryAttr: #{inspect(v)}" + end + + defp attr_dict_key(key) when is_atom(key), do: Atom.to_string(key) + + defp attr_dict_key(key) when is_binary(key) do + if Regex.match?(~r/^[A-Za-z_][A-Za-z0-9_]*$/, key) do + key + else + raise ArgumentError, + "custom_call backend_config key must match [A-Za-z_][A-Za-z0-9_]*, got: #{inspect(key)}" + end + end + + defp attr_dict_key(key) do + raise ArgumentError, + "custom_call backend_config key must be an atom or string, got: #{inspect(key)}" + end + defp join_list(list) do "[" <> Enum.join(list, ", ") <> "]" end diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 2a0a99f1ef..9fab99366c 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -79,6 +79,7 @@ defmodule EXLA.NIF do def get_tpu_client(), do: err!() def get_c_api_client(_device_type), do: err!() def load_pjrt_plugin(_device_type, _library_path), do: err!() + def load_dylib(_path), do: err!() def get_device_count(_client), do: err!() def get_supported_platforms, do: err!() def run_cpu(_executable, _arguments, _device_id, _callback_server_pid), do: err!() diff --git a/exla/mix.exs b/exla/mix.exs index 22c8ee9460..c4631e47a6 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -33,7 +33,7 @@ defmodule EXLA.MixProject do "FINE_INCLUDE_DIR" => Fine.include_dir(), "MIX_BUILD_EMBEDDED" => "#{Mix.Project.config()[:build_embedded]}", "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv, - "EXLA_VERSION" => "#{@version}" + 
"EXLA_VERSION" => "#{@version}"
       }
     end,
   make_args: make_args
diff --git a/exla/test/exla/custom_call_alias_test.exs b/exla/test/exla/custom_call_alias_test.exs
new file mode 100644
index 0000000000..fc4fad82bc
--- /dev/null
+++ b/exla/test/exla/custom_call_alias_test.exs
@@ -0,0 +1,98 @@
+defmodule EXLA.CustomCallAliasTest do
+  use EXLA.Case, async: false
+
+  import Nx.Defn
+
+  alias EXLA.Test.QRAliasBlock
+
+  defmodule BuiltinFun do
+    import Nx.Defn
+
+    defn qr(t), do: Nx.LinAlg.qr(t)
+  end
+
+  defmodule Fun do
+    import Nx.Defn
+
+    alias EXLA.Test.QRAliasBlock
+
+    defn qr_alias_fn(t) do
+      q_out = Nx.template({3, 3}, {:f, 32})
+      r_out = Nx.template({3, 4}, {:f, 32})
+
+      Nx.block(%QRAliasBlock{}, [t], {q_out, r_out}, fn _, t2 ->
+        Nx.LinAlg.qr(t2, mode: :reduced)
+      end)
+    end
+  end
+
+  @plugin_relative ~c"test/exla_qr_alias.so"
+
+  defp plugin_path do
+    :filename.join(:code.priv_dir(:exla), @plugin_relative)
+  end
+
+  defp mlir_via_jit_apply!(fun, args) when is_function(fun) and is_list(args) do
+    try do
+      Nx.Defn.jit_apply(fun, args,
+        compiler: EXLA,
+        module_compilation: :to_mlir
+      )
+    catch
+      :throw, {:mlir_module, ref, used_inputs, output_container} ->
+        %{
+          mlir_module: EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref}),
+          used_inputs: used_inputs,
+          output_container: output_container
+        }
+    end
+  end
+
+  defp load_plugin! do
+    path = List.to_string(plugin_path())
+
+    unless File.exists?(path) do
+      flunk("""
+      Missing #{path}. Build EXLA with MIX_ENV=test so the alias dylib is compiled \
+      (see Makefile target exla_qr_alias.so).
+ """) + end + + case EXLA.NIF.load_dylib(path) do + :ok -> + :ok + + other -> + flunk("load_dylib(#{path}) expected :ok, got: #{inspect(other)}") + end + end + + test "builtin QR lowering includes qr_cpu_custom_call_f32 in MLIR" do + arg = Nx.iota({3, 4}, type: {:f, 32}) + assert %{mlir_module: mlir} = mlir_via_jit_apply!(&BuiltinFun.qr/1, [arg]) + + assert mlir =~ "@qr_cpu_custom_call_f32(" + refute mlir =~ "qr_cpu_custom_call_f32_exla_alias" + end + + test "QR alias plugin: MLIR uses alias name and not the builtin target string" do + load_plugin!() + + arg = Nx.iota({3, 4}, type: {:f, 32}) + assert %{mlir_module: mlir} = mlir_via_jit_apply!(&Fun.qr_alias_fn/1, [arg]) + + assert mlir =~ "qr_cpu_custom_call_f32_exla_alias" + refute mlir =~ "@qr_cpu_custom_call_f32(" + end + + test "QR alias plugin: JIT result matches builtin QR" do + load_plugin!() + + t = Nx.iota({3, 4}, type: {:f, 32}) + exp = EXLA.jit(fn t -> Nx.LinAlg.qr(t) end).(t) + act = EXLA.jit(&Fun.qr_alias_fn/1).(t) + + assert Nx.all_close(elem(exp, 0), elem(act, 0), atol: 1.0e-4, rtol: 1.0e-4) + assert Nx.all_close(elem(exp, 1), elem(act, 1), atol: 1.0e-4, rtol: 1.0e-4) + end +end diff --git a/exla/test/support/exla_test_qr_alias_block.ex b/exla/test/support/exla_test_qr_alias_block.ex new file mode 100644 index 0000000000..909f72b58f --- /dev/null +++ b/exla/test/support/exla_test_qr_alias_block.ex @@ -0,0 +1,18 @@ +# Test-only block tag + `EXLA.CustomCall` impl used to emit a StableHLO custom_call +# with `call_target_name` `qr_cpu_custom_call_f32_exla_alias` (registered by +# `priv/test/exla_qr_alias.so` when built with `MIX_ENV=test`). 
+defmodule EXLA.Test.QRAliasBlock do + @moduledoc false + defstruct [] +end + +defimpl EXLA.CustomCall, for: EXLA.Test.QRAliasBlock do + def function_name(_, {%{type: {q_kind, q_size}}, _r_expr}, [_tensor], client) + when q_kind != :c and q_size == 32 and client.platform == :host do + "qr_cpu_custom_call_f32_exla_alias" + end + + def function_name(_, _, _, _), do: :skip + + def config(_, _, _, _), do: nil +end