## Goal

Replace the current outfeed/token-based mechanism used by hooks (e.g. `print_value`, `hook/3`) with a new `io_callback` mechanism built on top of `stablehlo.custom_call`. This would eliminate the dependency on `stablehlo.outfeed` (which lives outside the computation graph) and instead use the same infrastructure that `runtime_call` already uses, but specialized for side-effect-only operations.
## Background: How hooks work today

Nx hooks (`Nx.Defn.Kernel.hook/3`, `print_value/2`, `hook_token/4`) allow users to execute side effects during compiled computation. For example:

```elixir
defn my_fun(x) do
  x = Nx.cos(x)
  x = print_value(x, label: "after cos") # side effect: prints tensor value
  Nx.sin(x)
end
```
Under the hood, hooks are pure side effects — they don't produce new values. To ensure they survive Nx's tracing (which only retains expressions reachable from the output), the hook must be "attached" to an existing tensor via a token. The expression graph captures this as:

```elixir
:attach_token(
  :token(%Nx.Defn.Token{hooks: [%{expr: <x>, name: :print_value, callback: fn ...}]}),
  <x> # the original value passes through unchanged
)
```
The resulting tensor carries the same value as the input but now has a token dependency that keeps the side effect in the graph.
## Current compilation path (EXLA)
When EXLA compiles hooks, it uses the outfeed mechanism:
- A `stablehlo.create_token` initializes a token at the start of the computation
- For each hook, the compiler emits a `stablehlo.outfeed` with:
  - A u16 flag identifying which hook to call
  - The tensor data to send to the host
- Tokens are threaded sequentially: each outfeed consumes the previous token and produces a new one, enforcing execution order
- A final `outfeed(flag=0)` signals completion
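The token threading described above can be sketched in StableHLO form. The shapes, operands, and attribute values below are illustrative, not EXLA's exact lowering:

```mlir
// Illustrative sketch of the current token-threaded outfeed lowering.
%token0 = "stablehlo.create_token"() : () -> !stablehlo.token

// Hook payload: consumes %token0, produces %token1, so it must run first.
%token1 = "stablehlo.outfeed"(%x, %token0) {outfeed_config = ""}
    : (tensor<2x3xf32>, !stablehlo.token) -> !stablehlo.token

// Completion marker (flag = 0): consumes %token1, so it must run last.
%token2 = "stablehlo.outfeed"(%done, %token1) {outfeed_config = ""}
    : (tensor<ui16>, !stablehlo.token) -> !stablehlo.token
```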
On the host side, `EXLA.Defn.Outfeed` spawns a dedicated process that:

- Calls `EXLA.Client.from_outfeed` to receive the flag
- Looks up the hook callback in `compiled_hooks[flag]`
- Receives the tensor buffers and invokes the callback
- Loops until it receives `flag=0`
## Problems with outfeed

The outfeed mechanism has several drawbacks:

- **Lives outside the computation graph**: Outfeed operates as a separate channel outside the compiled computation, making it harder to reason about and optimize
- **Limited flexibility**: The flag-based dispatch (u16 flags, 1-65535) and the fixed outfeed protocol constrain how hooks can be implemented
- **Separate infrastructure**: Hooks use outfeed while `runtime_call` uses `stablehlo.custom_call` — two completely different mechanisms for conceptually similar host-callback patterns. Unifying them would simplify the codebase and reduce the maintenance surface
## Proposed solution: `io_callback`

Replace hooks' outfeed-based implementation with a new `io_callback` mechanism. This would be similar to `runtime_call` but specialized for side effects:

| Aspect | `runtime_call` | `io_callback` (proposed) |
| --- | --- | --- |
| Returns a value? | Yes (new tensor with defined output shape) | No (output = input, passthrough) |
| Kept in graph if unused? | Only if output is used | Always executes (`has_side_effect = true`) |
| StableHLO lowering | `stablehlo.custom_call` | `stablehlo.custom_call` (with `has_side_effect = true`) |
| Use case | Call host function, get result back | Execute side effect (print, log, write to disk) |
## How `runtime_call` already works

The existing `runtime_call` implementation provides the blueprint:

- `Nx.Defn.Expr` creates a `:runtime_call` node with `[tensor_expr, fun, out_template]`
- The EXLA compiler registers the callback function with `EXLA.CallbackServer` (a GenServer)
- MLIR generation emits a `stablehlo.custom_call` targeting `exla_runtime_callback`, encoding the CallbackServer PID and callback ID as FFI attributes
- At runtime, the native bridge sends a message to the GenServer with the arguments; the GenServer calls the function and sends the result back
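The register-then-dispatch round trip can be sketched as a toy model. All names here are illustrative and the real server is an Elixir GenServer, but the id-based dispatch is the same idea:

```python
class CallbackServer:
    """Toy model of the callback registry: register at compile time,
    dispatch by id at runtime."""

    def __init__(self):
        self._callbacks = {}
        self._next_id = 1

    def register(self, fun):
        # Compile time: store the callback; the returned id is what gets
        # embedded in the compiled program as an FFI attribute.
        cid = self._next_id
        self._next_id += 1
        self._callbacks[cid] = fun
        return cid

    def dispatch(self, cid, args):
        # Runtime: the native bridge sends the arguments; the server
        # invokes the registered function and replies with the result.
        return self._callbacks[cid](*args)

server = CallbackServer()
cid = server.register(lambda x: x * 2)
result = server.dispatch(cid, (21,))
```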
For `io_callback`, the flow would be nearly identical, except:

- The callback does not need to return a value to the device
- The output tensor is a new tensor with the same value as the input (passthrough semantics)
- The `has_side_effect` attribute should be set to `true` on the custom call to prevent the compiler from eliminating it
## How JAX handles this

JAX's `jax.experimental.io_callback` provides the reference implementation. Key design points:

- **API**: `io_callback(callback, result_shape_dtypes, *args, ordered=False)` — the callback is a Python function, `result_shape_dtypes` defines expected output shapes
- **Lowering**: Produces `stablehlo.custom_call` with `has_side_effect = true`
- **Ordering**: When `ordered=True`, tokens are threaded between sequential `io_callback` calls to enforce execution order; when `ordered=False`, callbacks may be reordered
- **Guaranteed execution**: Unlike `pure_callback`, `io_callback` is never eliminated even if its output is unused (because it has side effects)
- **Distinction from `pure_callback`**: `pure_callback` can be reordered/eliminated by the compiler; `io_callback` cannot (similar to Nx's distinction between `runtime_call` and hooks)

JAX also has `jax.debug.callback` for debugging-only side effects that don't return values — this is closest to what Nx's hooks do today.
## Implementation notes

### Nx layer changes

The proposed `io_callback` would live in `Nx` (not auto-imported via `Nx.Defn.Kernel`), similar to how `Nx.runtime_call/4` is defined today. In `Nx.Defn.Expr`, the corresponding operation would:

- Accept input tensors and a callback function
- Have no meaningful output — the outputs in `Nx.Defn.Expr` would be the inputs themselves (passthrough)
- Use the same dispatch mechanism as `runtime_call` at the compiler level
### EXLA layer changes

The EXLA compiler would:

- Reuse the `EXLA.CallbackServer` infrastructure already built for `runtime_call`
- Generate a `stablehlo.custom_call` with the `exla_runtime_callback` target (or a new dedicated target)
- Set `has_side_effect = true` to prevent elimination
- Not wait for a return value from the callback (return the input unchanged)
- Remove the outfeed-based code path for hooks
### Zero-copy passthrough via `output_operand_aliases`

Since `io_callback` returns the same tensors it receives, we can avoid copying input buffers to output buffers entirely. StableHLO's `custom_call` supports an `output_operand_aliases` attribute that tells the compiler to alias output buffers to input buffers — the same physical memory is used for both.
The generated MLIR would look like:

```mlir
%result = "stablehlo.custom_call"(%input) {
  call_target_name = "exla_io_callback",
  has_side_effect = true,
  api_version = 4 : i32,
  output_operand_aliases = [
    #stablehlo.output_operand_alias<
      output_tuple_indices = [],
      operand_index = 0,
      operand_tuple_indices = []>
  ],
  backend_config = { ... }
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
```
With this aliasing, XLA allocates a single buffer shared by both `%input` and `%result` (destination-passing still applies; the result destination is just aliased). The host callback should treat this as passthrough and return success without writing output data.
For multiple inputs, add one alias entry per passthrough pair:

```mlir
%out0, %out1 = "stablehlo.custom_call"(%input0, %input1) {
  call_target_name = "exla_io_callback",
  has_side_effect = true,
  api_version = 4 : i32,
  output_operand_aliases = [
    #stablehlo.output_operand_alias<
      output_tuple_indices = [],
      operand_index = 0,
      operand_tuple_indices = []>,
    #stablehlo.output_operand_alias<
      output_tuple_indices = [],
      operand_index = 1,
      operand_tuple_indices = []>
  ],
  backend_config = { ... }
} : (tensor<2x3xf32>, tensor<4xf32>) -> (tensor<2x3xf32>, tensor<4xf32>)
```
The Elixir-side `EXLA.CallbackServer` would handle `io_callback` requests by invoking the callback and immediately replying with `:ok` (no return data), unblocking the native thread. (Result buffers are still part of the XLA call frame; they are just aliased by XLA and left untouched by the callback.) On the Elixir side, the `output_operand_aliases` attribute can be added in `EXLA.MLIR.Value`:
```elixir
def io_callback(
      [%Value{function: func} | _] = operands,
      callback_server_pid,
      callback_id
    ) do
  # Output types match input types (passthrough)
  result_types = Enum.map(operands, &Value.type/1)

  {callback_server_pid_words, callback_server_pid_size} =
    term_to_int64_list(callback_server_pid)

  {callback_id_words, callback_id_size} =
    term_to_int64_list(callback_id)

  # Build one alias per operand: output[i] aliases operand[i]
  aliases =
    operands
    |> Enum.with_index()
    |> Enum.map(fn {_, i} ->
      attr_output_operand_alias(
        output_tuple_indices: [],
        operand_index: i,
        operand_tuple_indices: []
      )
    end)

  attributes = [
    call_target_name: attr_string("exla_io_callback"),
    has_side_effect: attr_bool(true),
    api_version: attr_i32(4),
    output_operand_aliases: attr_array(aliases),
    backend_config:
      attr_dict(
        callback_id: attr_array_i64_elements(callback_id_words),
        callback_id_size: attr_ui64(callback_id_size),
        callback_server_pid: attr_array_i64_elements(callback_server_pid_words),
        callback_server_pid_size: attr_ui64(callback_server_pid_size)
      )
  ]

  op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes)
end
```
### Eliminating tokens
A key simplification over the current outfeed approach: explicit token threading is no longer needed.
Today, outfeed requires `stablehlo.create_token` and sequential token threading because outfeed operations are "fire and forget" on a separate channel — without tokens, the compiler has no way to know they must execute in order.
With `io_callback`, ordering comes naturally from data dependencies. Since each `io_callback` passes its input through as output, chaining them creates an implicit execution order:

```elixir
x = Nx.cos(x)
x = Nx.io_callback(x, &log/1)  # custom_call takes x, returns x (after calling log)
x = Nx.io_callback(x, &save/1) # depends on x from above — must wait for log to complete
Nx.sin(x)
```
The second `io_callback` cannot execute before the first because it needs the first's output as its input. Combined with `has_side_effect = true` (which prevents the compiler from optimizing away the call even though output equals input), this gives us both guaranteed execution and correct ordering without any token machinery.
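The passthrough-chaining argument can be demonstrated with a toy model in plain Python. The `io_callback` function here is purely schematic: it runs a side effect and returns its input unchanged, so the second call cannot observe a value until the first has run:

```python
events = []

def io_callback(x, side_effect):
    """Schematic passthrough callback: run the side effect, return the input."""
    side_effect(x)
    return x  # passthrough: output is the (aliased) input

# Chaining fixes execution order through the data dependency on x,
# with no token machinery involved.
x = 2.0
x = io_callback(x, lambda v: events.append(("log", v)))
x = io_callback(x, lambda v: events.append(("save", v)))
```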
This means we can remove `stablehlo.create_token`, the token threading logic in `EXLA.Defn`, and the `Nx.Defn.Token` struct — replacing all of it with straightforward data-dependency chaining through passthrough custom calls.
## References
- JAX `io_callback` documentation
- StableHLO `custom_call` spec (see the `output_operand_aliases` attribute)
- `nx/lib/nx/defn/expr.ex`, `exla/lib/exla/defn.ex`, `exla/lib/exla/defn/outfeed.ex`