Fast kernel overrides via Nx.Block: emlx_axon / emlx_bumblebee extension libraries #107

@polvalente

Description

EMLX wants to add fused MLX transformer kernels in EMLX.Fast (rms_norm, layer_norm, rope, scaled_dot_product_attention, einsum). These will be defn-callable via Nx.runtime_call/4 and call into mx::fast::* on the C++ side. Once that module lands, they will still not be reachable from Axon or Bumblebee model forward passes — those libraries build their computation graphs from composed Nx ops and have no mechanism to opt into backend-specific fused paths.

Nx.Block provides the extension point. EMLX.Backend already dispatches on Nx.Block.* structs (linalg, cumulative reductions, FFT, etc.) and falls through to apply(fun, [struct | args]) for unknown blocks. Once Axon and Bumblebee define their own Nx.block wrappers with their respective structs for fusable subgraphs, EMLX can intercept them and call the appropriate EMLX.Fast.* kernel instead of running the composed-defn fallback.

This issue tracks adding that interception layer as two opt-in extension libraries.

Motivation

Without this feature, a Bumblebee forward pass on EMLX will still compute SDPA as QKᵀ → softmax → V in three separate MLX kernel dispatches. With the fused kernel, it becomes a single mx::fast::scaled_dot_product_attention call — lower latency, lower memory bandwidth. The same applies to every RMSNorm and RoPE layer in a modern transformer. The gains are real but only accessible to users who manually call EMLX.Fast.* from hand-written defn functions or ad-hoc graph rewrites.

Proposed design

Upstream requirement (blocker)

This issue cannot land until Axon and Bumblebee each define Nx.block calls and structs for their fusable subgraphs. The structs needed are:

| Library   | Block struct         | Maps to                                                              |
|-----------|----------------------|----------------------------------------------------------------------|
| Axon      | Axon.Block.RMSNorm   | EMLX.Fast.rms_norm/3                                                 |
| Axon      | Axon.Block.LayerNorm | EMLX.Fast.layer_norm/4                                               |
| Axon      | Axon.Block.RoPE      | EMLX.Fast.rope/4                                                     |
| Axon      | Axon.Block.SDPA      | EMLX.Fast.scaled_dot_product_attention/5                             |
| Bumblebee | Bumblebee.Block.SDPA | EMLX.Fast.scaled_dot_product_attention/5 (+ attention sinks opt-in)  |

The exact struct names and fields are to be ratified by the respective upstream maintainers. Tracking issues should be opened in:

  • elixir-nx/axon: "Expose Nx.Block structs for fusable subgraphs (RMSNorm, LayerNorm, RoPE, SDPA)"
  • elixir-nx/bumblebee: "Expose Nx.Block structs for SDPA (and attention sinks)"
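
To make the upstream ask concrete, here is a rough sketch of the Axon side for the first row of the table above. Only the struct name is taken from this proposal; the :eps field, the module housing the wrapper, and the Nx.block/3 call shape (block struct, args, fallback fun receiving [block | args], mirroring the backend catch-all below) are assumptions pending upstream ratification:

# Hypothetical Axon-side wrapper (sketch only)
defmodule Axon.Block.RMSNorm do
  defstruct [:eps]
end

defmodule Axon.Layers.Fusable do
  import Nx.Defn

  defn rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-6)

    Nx.block(%Axon.Block.RMSNorm{eps: opts[:eps]}, [x, weight], fn _block, xb, wb ->
      # Composed fallback for backends without a fused kernel
      variance = xb |> Nx.pow(2) |> Nx.mean(axes: [-1], keep_axes: true)
      xb * Nx.rsqrt(variance + opts[:eps]) * wb
    end)
  end
end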

EMLX side — EMLX.Block protocol

EMLX.Backend.block/4 already falls through for unknown structs:

# Current catch-all in lib/emlx/backend.ex
def block(struct, _output, args, fun) do
  apply(fun, [struct | args])
end

The proposal is to replace this catch-all with a dispatch through a protocol defined in EMLX:

defprotocol EMLX.Block do
  @doc """
  Override a block with an EMLX-native fused kernel.

  The default `Any` implementation falls through to the composed-defn `fun`,
  preserving the existing behaviour for unrecognised block structs.
  """
  @fallback_to_any true
  def call(block, output, args, fun)
end

defimpl EMLX.Block, for: Any do
  # Unrecognised structs keep the existing behaviour: run the composed defn
  def call(struct, _output, args, fun), do: apply(fun, [struct | args])
end

EMLX.Backend provides the implementation for Any, and its block/4 callback simply delegates to the EMLX.Block protocol for every struct it receives.
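
A minimal sketch of that delegation, keeping the same callback shape as the catch-all above:

# Proposed replacement for the catch-all in lib/emlx/backend.ex
def block(struct, output, args, fun) do
  EMLX.Block.call(struct, output, args, fun)
end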

Extension libraries implement the protocol for their respective block structs:

# In emlx_axon
defimpl EMLX.Block, for: Axon.Block.RMSNorm do
  def call(%{eps: eps}, _output, [x, weight], _fun) do
    EMLX.Fast.rms_norm(x, weight, eps: eps)
  end
end
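
The Bumblebee-specific block follows the same pattern. The struct fields and argument order below are placeholders (Bumblebee.Block.SDPA is not ratified yet), and EMLX.Fast.scaled_dot_product_attention/5 is assumed to take q, k, v, scale, and mask positionally:

# In emlx_bumblebee (sketch; struct fields and argument order are assumptions)
defimpl EMLX.Block, for: Bumblebee.Block.SDPA do
  def call(%{scale: scale}, _output, [q, k, v, mask], _fun) do
    EMLX.Fast.scaled_dot_product_attention(q, k, v, scale, mask)
  end
end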

Extension libraries

emlx_axon — depends on emlx and axon. Implements EMLX.Block for Axon.Block.* structs, calling EMLX.Fast.* with the correct arguments. No dependency on Bumblebee.

emlx_bumblebee — depends on emlx_axon and bumblebee. Implements EMLX.Block for Bumblebee.Block.* structs. Bumblebee users get both sets of implementations transitively.

This layering means emlx_axon users (who do not use Bumblebee) do not pay Bumblebee's dependency cost, and Bumblebee users get Axon-level kernels included automatically.
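
In practice a Bumblebee application would opt in with a single extra dependency, e.g. (package names from this proposal; versions are placeholders):

# mix.exs of an application running Bumblebee models on EMLX (sketch)
defp deps do
  [
    {:bumblebee, "~> 0.6"},
    # Proposed extension library; pulls in emlx and emlx_axon transitively
    {:emlx_bumblebee, "~> 0.1"}
  ]
end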

Kernel mapping (initial scope)

| EMLX.Fast kernel                | Fuses                                  | Notes                        |
|---------------------------------|----------------------------------------|------------------------------|
| rms_norm/3                      | pow(2) → mean → rsqrt → multiply × 2   | Most LLM norm layers         |
| layer_norm/4                    | mean → var → normalize → scale + bias  | BERT-family                  |
| rope/4                          | split → rotate → concat → scale        | LLaMA, Mistral, Qwen, Falcon |
| scaled_dot_product_attention/5  | QKᵀ → (mask) → softmax → V             | All transformers             |
| einsum                          | General tensor contraction             | Optional; lower priority     |
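
For reference, the SDPA row collapses roughly the following composed defn into a single mx::fast::scaled_dot_product_attention dispatch. This is a sketch assuming rank-4 {batch, heads, seq, head_dim} inputs and an additive mask; the module name is hypothetical and real model code varies:

defmodule SDPA.Composed do
  import Nx.Defn

  # Three separate dispatches today: QKᵀ, softmax, weighted sum with V
  defn composed_sdpa(q, k, v, mask, scale) do
    # Contract the head dimension, batching over {batch, heads}
    scores = Nx.dot(q, [3], [0, 1], k, [3], [0, 1]) * scale
    weights = softmax(scores + mask)
    Nx.dot(weights, [3], [0, 1], v, [2], [0, 1])
  end

  defnp softmax(x) do
    x = x - Nx.reduce_max(x, axes: [-1], keep_axes: true)
    e = Nx.exp(x)
    e / Nx.sum(e, axes: [-1], keep_axes: true)
  end
end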

SDPA attention sinks (sinks param on mx::fast::scaled_dot_product_attention) are a follow-up gated on the Bumblebee block landing.
