Skip to content
18 changes: 17 additions & 1 deletion exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ EXLA_LIB_DIR = $(PRIV_DIR)/xla_extension/lib
XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LIB)
EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)

.DEFAULT_GOAL := $(EXLA_SO)

# Build flags
#
# Note that XLA requires C++17, and so does the Fine library
Expand Down Expand Up @@ -86,7 +88,21 @@ else
LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib'
endif

$(EXLA_SO): $(EXLA_CACHE_SO)
# Optional test dylib: registers qr_cpu_custom_call_f32_exla_alias -> same
# handler as qr_cpu_custom_call_f32. Built only when MIX_ENV=test.
TEST_PLUGIN_CC = c_src/exla_test/custom_calls.cc
TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias.so

# Order-only prerequisite on $(XLA_EXTENSION_DIR): the headers must be
# unpacked before compiling, but a newer extension dir alone should not
# force a rebuild of the test plugin.
# NOTE(review): the recipe reuses $(CFLAGS) with $(CXX) — presumably CFLAGS
# already carries -std=c++17 and -fPIC (required for -shared on Linux);
# confirm against the flag definitions above this excerpt.
$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
	@ mkdir -p $(dir $@)
	$(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)

# The main NIF target additionally depends on the test plugin in the test
# environment, so `mix test` builds both in one pass.
EXLA_SO_DEPS = $(EXLA_CACHE_SO)
ifeq ($(MIX_ENV),test)
EXLA_SO_DEPS += $(TEST_PLUGIN_SO)
endif

$(EXLA_SO): $(EXLA_SO_DEPS)
@ mkdir -p $(PRIV_DIR)
@ mkdir -p $(PRIV_DIR)/xla_extension
@ if [ "${MIX_BUILD_EMBEDDED}" = "true" ]; then \
Expand Down
162 changes: 162 additions & 0 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <dlfcn.h>

#include <cstring>
#include <fine.hpp>
#include <stdexcept>
Expand Down Expand Up @@ -29,6 +31,12 @@
#include "xla/tsl/platform/statusor.h"
#include "llvm/Support/ThreadPool.h"

#include <vector>

#include "xla/extension/custom_calls/eigh.h"
#include "xla/extension/custom_calls/qr.h"
#include "xla/ffi/ffi_api.h"

namespace exla {

using callback_bridge::Pending;
Expand Down Expand Up @@ -535,6 +543,19 @@ fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type,

FINE_NIF(load_pjrt_plugin, 0);

// Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run.
//
// The handle is intentionally never dlclose()d: the library must remain
// mapped for the lifetime of the process so any handlers it registered at
// load time stay valid.
//
// Throws std::invalid_argument with the dlerror() message if dlopen fails.
fine::Ok<> load_dylib(ErlNifEnv *env, std::string path) {
  void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL);
  if (handle == nullptr) {
    const char *err = dlerror();
    throw std::invalid_argument(err ? err : "dlopen failed");
  }
  return fine::Ok();
}

FINE_NIF(load_dylib, 0);

int64_t get_device_count(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client) {
return client->client()->device_count();
}
Expand Down Expand Up @@ -715,4 +736,145 @@ FINE_NIF(write_to_pointer, 0);

} // namespace exla

// Host QR custom calls: integer operands with f32 Q/R (see Nx.Type.to_floating/1
// for integer matrices). Handlers live in libexla alongside the NIFs.
namespace {

namespace ffi = xla::ffi;

// Converts one integer operand matrix per batch to f32 and delegates to the
// f32 QR kernel, writing f32 Q and R factors in place. All leading
// dimensions of the operand are treated as batch dimensions.
template <ffi::DataType kIntDtype>
ffi::Error QrCpuCustomCallIntegerOperandF32ResultsImpl(
    ffi::Buffer<kIntDtype> operand, ffi::ResultBuffer<ffi::F32> q,
    ffi::ResultBuffer<ffi::F32> r) {
  using IntT = ffi::NativeType<kIntDtype>;
  auto in_dims = operand.dimensions();
  auto q_dims = q->dimensions();
  auto r_dims = r->dimensions();

  // Q is (..., q_rows, q_cols); R is (..., r_rows, r_cols).
  uint64_t q_rows = q_dims[q_dims.size() - 2];
  uint64_t q_cols = q_dims[q_dims.size() - 1];
  uint64_t r_rows = r_dims[r_dims.size() - 2];
  uint64_t r_cols = r_dims[r_dims.size() - 1];

  // A square R with q_rows rows signals the "complete" (full) factorization.
  bool complete = r_rows == q_rows;

  // Product of every dimension before the trailing matrix dimensions.
  uint64_t num_batches = 1;
  for (size_t i = 0; i + 2 < in_dims.size(); i++) {
    num_batches *= static_cast<uint64_t>(in_dims[i]);
  }

  // Per-batch element strides for each buffer.
  uint64_t q_stride = q_rows * q_cols;
  uint64_t r_stride = r_cols * r_rows;
  uint64_t in_stride = q_rows * r_cols;

  // Scratch buffer for the int -> f32 conversion, reused across batches.
  std::vector<float> converted(in_stride);
  const IntT *src = operand.typed_data();
  float *q_out = reinterpret_cast<float *>(q->untyped_data());
  float *r_out = reinterpret_cast<float *>(r->untyped_data());

  for (uint64_t batch = 0; batch < num_batches; batch++) {
    const IntT *in = src + batch * in_stride;
    for (uint64_t idx = 0; idx < in_stride; idx++) {
      converted[idx] = static_cast<float>(in[idx]);
    }
    single_matrix_qr_cpu_custom_call<float>(q_out + batch * q_stride,
                                            r_out + batch * r_stride,
                                            converted.data(), q_rows, q_cols,
                                            r_cols, complete);
  }

  return ffi::Error::Success();
}

// Defines an FFI handler NAME for an integer operand of DTYPE with f32 Q/R
// results, and registers it for the "Host" platform at static-initialization
// time. (Comments cannot go inside the macro body: backslash-newline splicing
// happens before comment removal, so a // there would swallow the rest.)
#define EXLA_REGISTER_QR_INT_F32(DTYPE, NAME)                             \
  static ffi::Error NAME##_impl(ffi::Buffer<ffi::DTYPE> operand,          \
                                ffi::ResultBuffer<ffi::F32> q,            \
                                ffi::ResultBuffer<ffi::F32> r) {          \
    return QrCpuCustomCallIntegerOperandF32ResultsImpl<ffi::DTYPE>(operand, \
                                                                   q, r);   \
  }                                                                       \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl,                        \
                                ffi::Ffi::Bind()                          \
                                    .Arg<ffi::Buffer<ffi::DTYPE>>()       \
                                    .Ret<ffi::Buffer<ffi::F32>>()         \
                                    .Ret<ffi::Buffer<ffi::F32>>());       \
  XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME);

// One registered handler per signed/unsigned integer width.
EXLA_REGISTER_QR_INT_F32(S8, qr_cpu_custom_call_s8)
EXLA_REGISTER_QR_INT_F32(S16, qr_cpu_custom_call_s16)
EXLA_REGISTER_QR_INT_F32(S32, qr_cpu_custom_call_s32)
EXLA_REGISTER_QR_INT_F32(S64, qr_cpu_custom_call_s64)
EXLA_REGISTER_QR_INT_F32(U8, qr_cpu_custom_call_u8)
EXLA_REGISTER_QR_INT_F32(U16, qr_cpu_custom_call_u16)
EXLA_REGISTER_QR_INT_F32(U32, qr_cpu_custom_call_u32)
EXLA_REGISTER_QR_INT_F32(U64, qr_cpu_custom_call_u64)

#undef EXLA_REGISTER_QR_INT_F32

// Converts one integer operand matrix per batch to f32 and delegates to the
// f32 eigh kernel, writing f32 eigenvalues and eigenvectors in place. All
// leading dimensions of the operand are treated as batch dimensions.
template <ffi::DataType kIntDtype>
ffi::Error EighCpuCustomCallIntegerOperandF32ResultsImpl(
    ffi::Buffer<kIntDtype> operand,
    ffi::ResultBuffer<ffi::F32> eigenvalues,
    ffi::ResultBuffer<ffi::F32> eigenvectors) {
  using IntT = ffi::NativeType<kIntDtype>;
  auto in_dims = operand.dimensions();
  auto eval_dims = eigenvalues->dimensions();
  auto evec_dims = eigenvectors->dimensions();

  // Eigenvector matrices are (..., rows, cols).
  uint64_t rows = evec_dims[evec_dims.size() - 2];
  uint64_t cols = evec_dims[evec_dims.size() - 1];

  // Product of every dimension before the trailing matrix dimensions.
  uint64_t num_batches = 1;
  for (size_t i = 0; i + 2 < in_dims.size(); i++) {
    num_batches *= static_cast<uint64_t>(in_dims[i]);
  }

  // Per-batch element strides for each buffer; eigenvalues are a vector per
  // batch, sized by their own trailing dimension.
  uint64_t eval_stride = eval_dims[eval_dims.size() - 1];
  uint64_t evec_stride = rows * cols;
  uint64_t in_stride = rows * cols;

  // Scratch buffer for the int -> f32 conversion, reused across batches.
  std::vector<float> converted(in_stride);
  const IntT *src = operand.typed_data();
  float *eval_out = reinterpret_cast<float *>(eigenvalues->untyped_data());
  float *evec_out = reinterpret_cast<float *>(eigenvectors->untyped_data());

  for (uint64_t batch = 0; batch < num_batches; batch++) {
    const IntT *in = src + batch * in_stride;
    for (uint64_t idx = 0; idx < in_stride; idx++) {
      converted[idx] = static_cast<float>(in[idx]);
    }
    single_matrix_eigh_cpu_custom_call<float>(eval_out + batch * eval_stride,
                                              evec_out + batch * evec_stride,
                                              converted.data(), rows, cols);
  }

  return ffi::Error::Success();
}

// Defines an FFI handler NAME for an integer operand of DTYPE with f32
// eigenvalue/eigenvector results, and registers it for the "Host" platform
// at static-initialization time. (Comments cannot go inside the macro body:
// backslash-newline splicing happens before comment removal.)
#define EXLA_REGISTER_EIGH_INT_F32(DTYPE, NAME)                           \
  static ffi::Error NAME##_impl(ffi::Buffer<ffi::DTYPE> operand,          \
                                ffi::ResultBuffer<ffi::F32> eigenvalues,  \
                                ffi::ResultBuffer<ffi::F32> eigenvectors) { \
    return EighCpuCustomCallIntegerOperandF32ResultsImpl<ffi::DTYPE>(     \
        operand, eigenvalues, eigenvectors);                              \
  }                                                                       \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl,                        \
                                ffi::Ffi::Bind()                          \
                                    .Arg<ffi::Buffer<ffi::DTYPE>>()       \
                                    .Ret<ffi::Buffer<ffi::F32>>()         \
                                    .Ret<ffi::Buffer<ffi::F32>>());       \
  XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME);

// One registered handler per signed/unsigned integer width.
EXLA_REGISTER_EIGH_INT_F32(S8, eigh_cpu_custom_call_s8)
EXLA_REGISTER_EIGH_INT_F32(S16, eigh_cpu_custom_call_s16)
EXLA_REGISTER_EIGH_INT_F32(S32, eigh_cpu_custom_call_s32)
EXLA_REGISTER_EIGH_INT_F32(S64, eigh_cpu_custom_call_s64)
EXLA_REGISTER_EIGH_INT_F32(U8, eigh_cpu_custom_call_u8)
EXLA_REGISTER_EIGH_INT_F32(U16, eigh_cpu_custom_call_u16)
EXLA_REGISTER_EIGH_INT_F32(U32, eigh_cpu_custom_call_u32)
EXLA_REGISTER_EIGH_INT_F32(U64, eigh_cpu_custom_call_u64)

#undef EXLA_REGISTER_EIGH_INT_F32

} // namespace

FINE_INIT("Elixir.EXLA.NIF");
15 changes: 15 additions & 0 deletions exla/c_src/exla_test/custom_calls.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Test-only shared library: registers an alias FFI name that reuses the
// existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so.
#ifndef EXLA_PROD

#include "xla/ffi/api/api.h"
#include "xla/ffi/ffi_api.h"

namespace ffi = xla::ffi;

// Declaration only: the symbol is defined in libxla_extension.so and is
// expected to be resolvable when this dylib is loaded — presumably via
// dlopen with RTLD_GLOBAL (see load_dylib in exla.cc); confirm load order.
extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame);

// Register the existing handler under an alias name so tests can exercise
// the custom-call registration path without a second native implementation.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias",
                         "Host", qr_cpu_custom_call_f32);

#endif  // EXLA_PROD
7 changes: 7 additions & 0 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ defmodule EXLA do
* `:highest` - Slowest but most accurate. Performs computations in float32
or float64 as applicable

## Native custom calls (`EXLA.CustomCall`)

Some `Nx.block/4` tags can be lowered to XLA **custom calls** (StableHLO plus
a registered native handler). Implement the `EXLA.CustomCall` protocol for
your block tag struct; see `EXLA.CustomCall` for the `call/4` contract,
including returning `:skip` to fall back to the block's default Elixir callback.

## Clients

The `EXLA` library uses a client for compiling and executing code.
Expand Down
Loading
Loading