EMLX wants to add fused MLX transformer kernels in `EMLX.Fast` (`rms_norm`, `layer_norm`, `rope`, `scaled_dot_product_attention`, `einsum`). These will be `defn`-callable via `Nx.runtime_call/4` and call into `mx::fast::*` on the C++ side. Once that module lands, they will still not be reachable from Axon or Bumblebee model forward passes — those libraries build their computation graphs from composed Nx ops and have no mechanism to opt into backend-specific fused paths.
`Nx.Block` provides the extension point. `EMLX.Backend` already dispatches on `Nx.Block.*` structs (linalg, cumulative reductions, FFT, etc.) and falls through to `apply(fun, [struct | args])` for unknown blocks. Once Axon and Bumblebee define their own `Nx.block` wrappers with their respective structs for fusable subgraphs, EMLX can intercept them and call the appropriate `EMLX.Fast.*` kernel instead of running the composed-defn fallback.
This issue tracks adding that interception layer as two opt-in extension libraries.
## Motivation
Without this feature, a Bumblebee forward pass on EMLX will still compute SDPA as QKᵀ → softmax → V in three separate MLX kernel dispatches. With the fused kernel, it becomes a single `mx::fast::scaled_dot_product_attention` call — lower latency, lower memory bandwidth. The same applies to every RMSNorm and RoPE layer in a modern transformer. The gains are real but only accessible to users who manually call `EMLX.Fast.*` from hand-written `defn` functions or ad-hoc graph rewrites.
## Proposed design

### Upstream requirement (blocker)

This issue cannot land until Axon and Bumblebee each define `Nx.block` calls and structs for their fusable subgraphs. The structs needed are:
| Library | Block struct | Maps to |
| --- | --- | --- |
| Axon | `Axon.Block.RMSNorm` | `EMLX.Fast.rms_norm/3` |
| Axon | `Axon.Block.LayerNorm` | `EMLX.Fast.layer_norm/4` |
| Axon | `Axon.Block.RoPE` | `EMLX.Fast.rope/4` |
| Axon | `Axon.Block.SDPA` | `EMLX.Fast.scaled_dot_product_attention/5` |
| Bumblebee | `Bumblebee.Block.SDPA` | `EMLX.Fast.scaled_dot_product_attention/5` (+ attention sinks opt-in) |
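For illustration only (the field names below are hypothetical and subject to the upstream discussion), two of these structs might carry nothing more than the compile-time options the fused kernels need:

```elixir
# Hypothetical shapes; actual names and fields are up to the Axon maintainers.
defmodule Axon.Block.RMSNorm do
  # eps is the only compile-time option; the input tensor and weight
  # arrive as block arguments rather than struct fields.
  defstruct [:eps]
end

defmodule Axon.Block.SDPA do
  # scale and a causal flag are compile-time; q/k/v (and an optional mask)
  # arrive as block arguments.
  defstruct [:scale, :causal]
end
```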
The exact struct names and fields are to be ratified by the respective upstream maintainers. Tracking issues should be opened in:
- `elixir-nx/axon` — "Expose `Nx.Block` structs for fusable subgraphs (RMSNorm, LayerNorm, RoPE, SDPA)"
- `elixir-nx/bumblebee` — "Expose `Nx.Block` structs for SDPA (and attention sinks)"
### EMLX side — `EMLX.Block` protocol

`EMLX.Backend.block/4` already falls through for unknown structs:
```elixir
# Current catch-all in lib/emlx/backend.ex
def block(struct, _output, args, fun) do
  apply(fun, [struct | args])
end
```
The proposal is to replace this catch-all with a dispatch through a protocol defined in EMLX:
```elixir
defprotocol EMLX.Block do
  # Required so that structs without a dedicated implementation
  # dispatch to the Any implementation below.
  @fallback_to_any true

  @doc """
  Override a block with an EMLX-native fused kernel.

  The default `Any` implementation falls through to the composed-defn `fun`,
  preserving the existing behaviour for unrecognised block structs.
  """
  def call(block, output, args, fun)
end

defimpl EMLX.Block, for: Any do
  def call(struct, _output, args, fun), do: apply(fun, [struct | args])
end
```
EMLX can ship the `Any` implementation alongside `EMLX.Backend`, and the `block/4` catch-all then simply delegates every struct to the `EMLX.Block` protocol.
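A minimal sketch of that delegation (the existing specialised clauses for `Nx.Block.*` structs would sit alongside or ahead of it):

```elixir
# Sketch: EMLX.Backend's catch-all delegates to the protocol; structs without
# a dedicated EMLX.Block implementation hit Any and fall back to `fun`.
def block(struct, output, args, fun) do
  EMLX.Block.call(struct, output, args, fun)
end
```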
Extension libraries implement the protocol for their respective block structs:
```elixir
# In emlx_axon
defimpl EMLX.Block, for: Axon.Block.RMSNorm do
  def call(%{eps: eps}, _output, [x, weight], _fun) do
    EMLX.Fast.rms_norm(x, weight, eps: eps)
  end
end
```
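A Bumblebee-side implementation would follow the same pattern; the sketch below assumes a particular field name (`scale`) and argument order for `Bumblebee.Block.SDPA`, both of which are still to be decided upstream:

```elixir
# In emlx_bumblebee (sketch; struct fields and argument order are assumptions)
defimpl EMLX.Block, for: Bumblebee.Block.SDPA do
  def call(%{scale: scale}, _output, [query, key, value, mask], _fun) do
    EMLX.Fast.scaled_dot_product_attention(query, key, value, scale, mask)
  end
end
```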
### Extension libraries
- `emlx_axon` — depends on `emlx` and `axon`. Implements `EMLX.Block` for `Axon.Block.*` structs, calling `EMLX.Fast.*` with the correct arguments. No dependency on Bumblebee.
- `emlx_bumblebee` — depends on `emlx_axon` and `bumblebee`. Implements `EMLX.Block` for `Bumblebee.Block.*` structs. Bumblebee users get both sets of implementations transitively.

This layering means `emlx_axon` users (who do not use Bumblebee) do not pay Bumblebee's dependency cost, and Bumblebee users get Axon-level kernels included automatically.
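From an application's point of view the layering is plain Mix dependencies; a Bumblebee app would declare something like the following (package names come from this proposal, versions are placeholders):

```elixir
# mix.exs of an app using Bumblebee on Apple silicon (sketch)
defp deps do
  [
    {:bumblebee, "~> 0.6"},
    # brings in emlx_axon, emlx and axon transitively
    {:emlx_bumblebee, "~> 0.1"}
  ]
end
```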
### Kernel mapping (initial scope)
| `EMLX.Fast` kernel | Fuses | Notes |
| --- | --- | --- |
| `rms_norm/3` | pow(2) → mean → rsqrt → multiply × 2 | Most LLM norm layers |
| `layer_norm/4` | mean → var → normalize → scale + bias | BERT-family |
| `rope/4` | split → rotate → concat → scale | LLaMA, Mistral, Qwen, Falcon |
| `scaled_dot_product_attention/5` | QKᵀ → (mask) → softmax → V | All transformers |
| `einsum` | General tensor contraction | Optional; lower priority |
SDPA attention sinks (the `sinks` param on `mx::fast::scaled_dot_product_attention`) are a follow-up gated on the Bumblebee block landing.
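The intended end state is that no user code changes are required beyond selecting the backend. A rough sketch (the model repo is just an example, and assumes the upstream blocks and both extension libraries have landed):

```elixir
# Sketch: an unchanged Bumblebee pipeline on the EMLX backend; any block
# with an EMLX.Block implementation dispatches to the fused EMLX.Fast kernel.
Nx.global_default_backend(EMLX.Backend)

{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Fused attention on Apple silicon")
```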