
Introduce Nx.io_callback as a replacement for token-based hooks and EXLA's outfeed #1672

@josevalim

Goal

Replace the current outfeed/token-based mechanism used by hooks (e.g. print_value, hook/3) with a new io_callback mechanism built on top of stablehlo.custom_call. This would eliminate the dependency on stablehlo.outfeed (which lives outside the computation graph) and instead use the same infrastructure that runtime_call already uses, but specialized for side-effect-only operations.

Background: How hooks work today

Nx hooks (Nx.Defn.Kernel.hook/3, print_value/2, hook_token/4) allow users to execute side effects during a compiled computation. For example:

defn my_fun(x) do
  x = Nx.cos(x)
  x = print_value(x, label: "after cos")  # side effect: prints tensor value
  Nx.sin(x)
end

Under the hood, hooks are pure side effects: they don't produce new values. To ensure they survive Nx's tracing (which only retains expressions reachable from the output), the hook must be "attached" to an existing tensor via a token. The expression graph captures this as:

:attach_token(
  :token(%Nx.Defn.Token{hooks: [%{expr: <x>, name: :print_value, callback: fn ...}]}),
  <x>   # the original value passes through unchanged
)

The resulting tensor carries the same value as the input but now has a token dependency that keeps the side effect in the graph.
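To make the tracing argument concrete, here is a hypothetical sketch in plain Elixir. The TraceSketch module and its map-based graph are illustrative only, not Nx's real expression representation. Tracing is modeled as collecting every node reachable from the output: a detached hook node disappears, while a token reachable through :attach_token survives.

```elixir
defmodule TraceSketch do
  # Collect every node id reachable from `root` by following each node's :args.
  def reachable(graph, root), do: visit(graph, [root], MapSet.new())

  defp visit(_graph, [], seen), do: seen

  defp visit(graph, [id | rest], seen) do
    if MapSet.member?(seen, id) do
      visit(graph, rest, seen)
    else
      %{args: args} = Map.fetch!(graph, id)
      visit(graph, args ++ rest, MapSet.put(seen, id))
    end
  end
end

# Detached hook: nothing on the path to the output refers to :hook.
detached = %{
  x: %{op: :parameter, args: []},
  cos: %{op: :cos, args: [:x]},
  hook: %{op: :print_value, args: [:cos]},
  sin: %{op: :sin, args: [:cos]}
}

# Attached hook: :attach_token forwards :cos but also depends on :token.
attached = %{
  x: %{op: :parameter, args: []},
  cos: %{op: :cos, args: [:x]},
  token: %{op: :token, args: [:cos]},
  attached: %{op: :attach_token, args: [:token, :cos]},
  sin: %{op: :sin, args: [:attached]}
}

TraceSketch.reachable(detached, :sin) |> MapSet.member?(:hook)
# => false
TraceSketch.reachable(attached, :sin) |> MapSet.member?(:token)
# => true
```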

Current compilation path (EXLA)

When EXLA compiles hooks, it uses the outfeed mechanism:

  1. A stablehlo.create_token initializes a token at the start of the computation
  2. For each hook, the compiler emits a stablehlo.outfeed with:
    • A u16 flag identifying which hook to call
    • The tensor data to send to the host
  3. Tokens are threaded sequentially: each outfeed consumes the previous token and produces a new one, enforcing execution order
  4. A final outfeed(flag=0) signals completion
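The lowering steps above can be modeled as a list of symbolic ops. This is a hypothetical sketch (OutfeedLoweringSketch and its tuple format are illustrative, not EXLA code) that shows the u16 flag assignment and how each outfeed consumes the previous token and yields the next one.

```elixir
defmodule OutfeedLoweringSketch do
  # Flags are u16 and 0 is reserved for "done", so at most 65_535 hooks fit.
  @max_flag 65_535

  def lower(hook_tensors) when length(hook_tensors) <= @max_flag do
    # Hook i gets flag i; each outfeed threads token_in -> token_out.
    {outfeeds, last_token} =
      hook_tensors
      |> Enum.with_index(1)
      |> Enum.map_reduce(:token0, fn {tensor, flag}, token_in ->
        token_out = :"token#{flag}"
        {{:outfeed, [flag: flag, data: tensor], token_in, token_out}, token_out}
      end)

    # A final outfeed with flag 0 tells the host loop that no hooks remain.
    done = {:outfeed, [flag: 0, data: nil], last_token, :final_token}
    [{:create_token, :token0}] ++ outfeeds ++ [done]
  end
end

OutfeedLoweringSketch.lower([:after_cos, :after_sin])
# => [
#      {:create_token, :token0},
#      {:outfeed, [flag: 1, data: :after_cos], :token0, :token1},
#      {:outfeed, [flag: 2, data: :after_sin], :token1, :token2},
#      {:outfeed, [flag: 0, data: nil], :token2, :final_token}
#    ]
```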

On the host side, EXLA.Defn.Outfeed spawns a dedicated process that:

  • Calls EXLA.Client.from_outfeed to receive the flag
  • Looks up the hook callback in compiled_hooks[flag]
  • Receives the tensor buffers and invokes the callback
  • Loops until it receives flag=0
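The host-side loop can be sketched the same way. In this hypothetical version (OutfeedLoopSketch is illustrative), a list of {flag, buffers} pairs stands in for successive EXLA.Client.from_outfeed reads. The real process blocks on the device and receives the flag and the buffers in separate reads; here they arrive together for brevity.

```elixir
defmodule OutfeedLoopSketch do
  # Flag 0 means the computation has no more hooks to run.
  def run([{0, _buffers} | _rest], _compiled_hooks), do: :done

  def run([{flag, buffers} | rest], compiled_hooks) do
    # Look up the hook registered for this flag and hand it the tensor data.
    callback = Map.fetch!(compiled_hooks, flag)
    callback.(buffers)
    run(rest, compiled_hooks)
  end
end

compiled_hooks = %{
  1 => fn buffers -> IO.inspect(buffers, label: "after cos") end,
  2 => fn buffers -> IO.inspect(buffers, label: "after sin") end
}

OutfeedLoopSketch.run([{1, [1.0]}, {2, [0.84]}, {0, nil}], compiled_hooks)
# prints "after cos: [1.0]" and "after sin: [0.84]", then returns :done
```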

Problems with outfeed

The outfeed mechanism has several drawbacks:

  • Lives outside the computation graph: Outfeed operates as a separate channel outside the compiled computation, making it harder to reason about and optimize
  • Limited flexibility: The flag-based dispatch (u16 flags, 1-65535) and fixed outfeed protocol constrain how hooks can be implemented
  • Separate infrastructure: Hooks use outfeed while runtime_call uses stablehlo.custom_call — two completely different mechanisms for conceptually similar host-callback patterns. Unifying them would simplify the codebase and reduce the maintenance surface

Proposed solution: io_callback

Replace hooks' outfeed-based implementation with a new io_callback mechanism. This would be similar to runtime_call but specialized for side effects:

| Aspect | runtime_call | io_callback (proposed) |
| --- | --- | --- |
| Returns a value? | Yes (new tensor with a defined output shape) | No (output = input, passthrough) |
| Kept if its output is unused? | No (eliminated when its output is unused) | Yes (always executes; has_side_effect = true) |
| StableHLO lowering | stablehlo.custom_call | stablehlo.custom_call with has_side_effect = true |
| Use case | Call a host function and get a result back | Execute a side effect (print, log, write to disk) |
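The "kept if its output is unused" row comes down to dead-code elimination. A hypothetical sketch (DceSketch is illustrative only) of a backward liveness pass over a straight-line program: an op survives if it is an output, feeds a surviving op, or is marked as having a side effect.

```elixir
defmodule DceSketch do
  # `ops` is in topological order; each op is %{id:, args:, side_effect?:}.
  def eliminate(ops, outputs) do
    live =
      ops
      |> Enum.reverse()
      |> Enum.reduce(MapSet.new(outputs), fn op, live ->
        if op.side_effect? or MapSet.member?(live, op.id) do
          # Keeping an op also keeps everything it reads from.
          MapSet.union(live, MapSet.new([op.id | op.args]))
        else
          live
        end
      end)

    Enum.filter(ops, &MapSet.member?(live, &1.id))
  end
end

ops = [
  %{id: :x, args: [], side_effect?: false},
  # runtime_call-like op whose result nobody uses
  %{id: :rt, args: [:x], side_effect?: false},
  # io_callback-like op, also unused, but flagged as a side effect
  %{id: :io, args: [:x], side_effect?: true},
  %{id: :y, args: [:x], side_effect?: false}
]

DceSketch.eliminate(ops, [:y]) |> Enum.map(& &1.id)
# => [:x, :io, :y]
```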

How runtime_call already works

The existing runtime_call implementation provides the blueprint:

  1. Nx.Defn.Expr creates a :runtime_call node with [tensor_expr, fun, out_template]
  2. EXLA compiler registers the callback function with EXLA.CallbackServer (a GenServer)
  3. MLIR generation emits stablehlo.custom_call targeting exla_runtime_callback, encoding the CallbackServer PID and callback ID as FFI attributes
  4. At runtime, the native bridge sends the arguments to the GenServer, which calls the function and sends the result back
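The register-then-dispatch flow above can be sketched with a plain GenServer. CallbackServerSketch is a hypothetical stand-in for EXLA.CallbackServer; its message shapes and integer ids are assumptions, and the native bridge is replaced by a direct GenServer.call.

```elixir
defmodule CallbackServerSketch do
  use GenServer

  def start_link, do: GenServer.start_link(__MODULE__, %{})

  # Compile time: store the function and return the id that would be
  # encoded into the custom_call attributes.
  def register(server, fun), do: GenServer.call(server, {:register, fun})

  # Runtime: the caller (standing in for the native bridge) waits for the result.
  def runtime_call(server, id, args), do: GenServer.call(server, {:runtime_call, id, args})

  @impl true
  def init(callbacks), do: {:ok, callbacks}

  @impl true
  def handle_call({:register, fun}, _from, callbacks) do
    id = map_size(callbacks)
    {:reply, id, Map.put(callbacks, id, fun)}
  end

  def handle_call({:runtime_call, id, args}, _from, callbacks) do
    result = Map.fetch!(callbacks, id).(args)
    {:reply, {:ok, result}, callbacks}
  end
end

{:ok, server} = CallbackServerSketch.start_link()
id = CallbackServerSketch.register(server, fn [a, b] -> a + b end)
CallbackServerSketch.runtime_call(server, id, [2, 3])
# => {:ok, 5}
```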

For io_callback, the flow would be nearly identical except:

  • The callback does not need to return a value to the device
  • The output tensor is a new tensor with the same value as the input (passthrough semantics)
  • The has_side_effect attribute should be set to true on the custom call to prevent the compiler from eliminating it

How JAX handles this

JAX's jax.experimental.io_callback provides the reference implementation. Key design points:

  • API: io_callback(callback, result_shape_dtypes, *args, ordered=False) — the callback is a Python function, result_shape_dtypes defines expected output shapes
  • Lowering: Produces stablehlo.custom_call with has_side_effect = true
  • Ordering: When ordered=True, tokens are threaded between sequential io_callback calls to enforce execution order; when ordered=False, callbacks may be reordered
  • Guaranteed execution: Unlike pure_callback, io_callback is never eliminated even if its output is unused (because it has side effects)
  • Distinction from pure_callback: pure_callback can be reordered/eliminated by the compiler; io_callback cannot (similar to Nx's distinction between runtime_call and hooks)

JAX also has jax.debug.callback for debugging-only side effects that don't return values; this is closest to what Nx's hooks do today.

Implementation notes

Nx layer changes

The proposed io_callback would live in Nx (not auto-imported via Nx.Defn.Kernel), similar to how Nx.runtime_call/4 is defined today. In Nx.Defn.Expr, the corresponding operation would:

  • Accept input tensors and a callback function
  • Have no meaningful output — the outputs in Nx.Defn.Expr would be the inputs themselves (passthrough)
  • Use the same dispatch mechanism as runtime_call at the compiler level
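Passthrough semantics are easiest to see in a tiny interpreter. This hypothetical sketch (PassthroughEvalSketch and its tuple-based expressions are illustrative, not Nx.Defn.Expr or Nx.Defn.Evaluator) evaluates an :io_callback node by calling its function for the side effect and returning the input unchanged.

```elixir
defmodule PassthroughEvalSketch do
  def eval({:param, value}), do: value
  def eval({:cos, arg}), do: :math.cos(eval(arg))
  def eval({:sin, arg}), do: :math.sin(eval(arg))

  # io_callback: run the callback on the value, then pass the value through.
  def eval({:io_callback, arg, fun}) do
    value = eval(arg)
    fun.(value)
    value
  end
end

expr = {:sin, {:io_callback, {:cos, {:param, 0.0}}, &IO.inspect(&1, label: "after cos")}}
PassthroughEvalSketch.eval(expr)
# prints "after cos: 1.0" and returns :math.sin(1.0)
```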

EXLA layer changes

The EXLA compiler would:

  • Reuse the EXLA.CallbackServer infrastructure already built for runtime_call
  • Generate stablehlo.custom_call with the exla_runtime_callback target (or a new dedicated target)
  • Set has_side_effect = true to prevent elimination
  • Not wait for a return value from the callback (return the input unchanged)
  • Remove the outfeed-based code path for hooks

Zero-copy passthrough via output_operand_aliases

Since io_callback returns the same tensors it receives, we can avoid copying input buffers to output buffers entirely. StableHLO's custom_call supports an output_operand_aliases attribute that tells the compiler to alias output buffers to input buffers — the same physical memory is used for both.

The generated MLIR would look like:

%result = "stablehlo.custom_call"(%input) {
    call_target_name = "exla_io_callback",
    has_side_effect = true,
    api_version = 4 : i32,
    output_operand_aliases = [
      #stablehlo.output_operand_alias<
        output_tuple_indices = [],
        operand_index = 0,
        operand_tuple_indices = []>
    ],
    backend_config = { ... }
} : (tensor<2x3xf32>) -> tensor<2x3xf32>

With this aliasing, XLA can use a single buffer for both %input and %result (destination-passing still applies; the result destination is simply aliased to the operand). If %input is still needed after the call, XLA may insert a copy to honor the alias. The host callback should treat this as passthrough and return success without writing output data.

For multiple inputs, add one alias entry per passthrough pair. Because XLA treats a custom call with several results as returning a tuple, output_tuple_indices must name the result that each entry aliases:

%out0, %out1 = "stablehlo.custom_call"(%input0, %input1) {
    call_target_name = "exla_io_callback",
    has_side_effect = true,
    api_version = 4 : i32,
    output_operand_aliases = [
      #stablehlo.output_operand_alias<
        output_tuple_indices = [0],
        operand_index = 0,
        operand_tuple_indices = []>,
      #stablehlo.output_operand_alias<
        output_tuple_indices = [1],
        operand_index = 1,
        operand_tuple_indices = []>
    ],
    backend_config = { ... }
} : (tensor<2x3xf32>, tensor<4xf32>) -> (tensor<2x3xf32>, tensor<4xf32>)
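The alias entries follow a mechanical pattern, so the backend can generate them. A hypothetical sketch (AliasTextSketch is illustrative, not part of EXLA) that renders the attribute text for n passthrough operands, assuming XLA's convention that a single result is addressed with an empty index list and several results are addressed as tuple elements:

```elixir
defmodule AliasTextSketch do
  # Output i aliases operand i.
  def render(n) when is_integer(n) and n > 0 do
    entries =
      for i <- 0..(n - 1) do
        output_index = if n == 1, do: "[]", else: "[#{i}]"

        "#stablehlo.output_operand_alias<output_tuple_indices = #{output_index}, " <>
          "operand_index = #{i}, operand_tuple_indices = []>"
      end

    "output_operand_aliases = [" <> Enum.join(entries, ", ") <> "]"
  end
end

IO.puts(AliasTextSketch.render(2))
```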

EXLA.CallbackServer would handle io_callback requests by invoking the callback and, as soon as it returns, replying with :ok (no return data), unblocking the native thread. (Result buffers are still part of the XLA call frame; they are just aliased by XLA and left untouched by the callback.) The custom call itself, including the output_operand_aliases attribute, could be built in EXLA.MLIR.Value:

def io_callback(
      [%Value{function: func} | _] = operands,
      callback_server_pid,
      callback_id
    ) do
  # Output types match input types (passthrough)
  result_types = Enum.map(operands, &Value.type/1)

  {callback_server_pid_words, callback_server_pid_size} =
    term_to_int64_list(callback_server_pid)

  {callback_id_words, callback_id_size} =
    term_to_int64_list(callback_id)

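  # Build one alias per operand: output[i] aliases operand[i].
  # A single result is addressed with output_tuple_indices: [];
  # with several results XLA treats them as a tuple, so use [i].
  multiple_results? = length(operands) > 1

  aliases =
    Enum.with_index(operands, fn _operand, i ->
      attr_output_operand_alias(
        output_tuple_indices: if(multiple_results?, do: [i], else: []),
        operand_index: i,
        operand_tuple_indices: []
      )
    end)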

  attributes = [
    call_target_name: attr_string("exla_io_callback"),
    has_side_effect: attr_bool(true),
    api_version: attr_i32(4),
    output_operand_aliases: attr_array(aliases),
    backend_config:
      attr_dict(
        callback_id: attr_array_i64_elements(callback_id_words),
        callback_id_size: attr_ui64(callback_id_size),
        callback_server_pid: attr_array_i64_elements(callback_server_pid_words),
        callback_server_pid_size: attr_ui64(callback_server_pid_size)
      )
  ]

  op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes)
end
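The snippet relies on a term_to_int64_list/1 helper that is not shown here. One plausible shape for it is sketched below, assuming the PID and callback id travel as external-term-format binaries packed into 64-bit words, with the byte size passed separately so padding can be stripped. PidEncodingSketch, the little-endian choice, and the decoder standing in for the native side are all assumptions.

```elixir
defmodule PidEncodingSketch do
  # Serialize, pad to a multiple of 8 bytes, split into signed 64-bit words.
  def term_to_int64_list(term) do
    binary = :erlang.term_to_binary(term)
    size = byte_size(binary)
    padding = rem(8 - rem(size, 8), 8)
    padded = binary <> :binary.copy(<<0>>, padding)
    words = for <<word::signed-little-size(64) <- padded>>, do: word
    {words, size}
  end

  # Inverse, standing in for what the native bridge would do.
  def int64_list_to_term({words, size}) do
    padded = for word <- words, into: <<>>, do: <<word::signed-little-size(64)>>
    <<binary::binary-size(size), _padding::binary>> = padded
    :erlang.binary_to_term(binary)
  end
end

encoded = PidEncodingSketch.term_to_int64_list(self())
PidEncodingSketch.int64_list_to_term(encoded) == self()
# => true
```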

Eliminating tokens

A key simplification over the current outfeed approach: explicit token threading is no longer needed.

Today, outfeed requires stablehlo.create_token and sequential token threading because outfeed operations are "fire and forget" on a separate channel — without tokens, the compiler has no way to know they must execute in order.

With io_callback, ordering comes naturally from data dependencies. Since each io_callback passes its input through as output, chaining them creates an implicit execution order:

x = Nx.cos(x)
x = Nx.io_callback(x, &log/1)   # custom_call takes x, returns x (after calling log)
x = Nx.io_callback(x, &save/1)  # depends on x from above — must wait for log to complete
Nx.sin(x)

The second io_callback cannot execute before the first because it needs the first's output as its input. Combined with has_side_effect = true (which prevents the compiler from optimizing away the call even though output equals input), this gives us both guaranteed execution and correct ordering without any token machinery.

This means we can remove stablehlo.create_token, the token threading logic in EXLA.Defn, and the Nx.Defn.Token struct — replacing all of it with straightforward data-dependency chaining through passthrough custom calls.
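The ordering argument can be checked on a toy dataflow graph. In this hypothetical sketch (OrderingSketch is illustrative), node a must run before node b exactly when b transitively reads a's output. Chaining the callbacks through their passthrough outputs creates that relation, while two callbacks that both read the same tensor stay unordered.

```elixir
defmodule OrderingSketch do
  # `graph` maps each node to the list of nodes it reads from.
  def must_precede?(graph, a, b), do: a in ancestors(graph, b)

  defp ancestors(graph, node) do
    direct = Map.get(graph, node, [])
    direct ++ Enum.flat_map(direct, &ancestors(graph, &1))
  end
end

# Chained: save reads log's passthrough output, so log must finish first.
chained = %{cos: [:x], log: [:cos], save: [:log], sin: [:save]}

# Not chained: both callbacks read cos directly, so nothing orders them
# (and log is not reachable from sin, so only has_side_effect keeps it).
parallel = %{cos: [:x], log: [:cos], save: [:cos], sin: [:save]}

OrderingSketch.must_precede?(chained, :log, :save)
# => true
OrderingSketch.must_precede?(parallel, :log, :save)
# => false
```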
