feat: add init cache type opts #455

Closed

polvalente wants to merge 2 commits into main from pv-fix/gqa-cache-and-inference-opts

Conversation

@polvalente
Contributor

This PR adds init_cache type and num_heads options for more flexibility in text generation models.
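
A minimal usage sketch of what the new option enables (the entry point and exact keyword names here are assumptions based on the diff discussed below, not a final API):

```elixir
# Hypothetical call: initialize the decoder cache directly in bf16 instead of
# the default f32, so the key/value buffers use half the memory.
cache = init_cache(spec, batch_size, max_length, inputs, cache_type: {:bf, 16})
```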

@polvalente polvalente self-assigned this May 5, 2026
Comment on lines +378 to +379
output_hidden_states: false,
output_attentions: false,
Member

Do we need these new options? We prune these by default, and in order to actually return them in the model output, the user needs to opt in by configuring global layer options.

Contributor Author

I think we can drop these. I must've missed them in my self-review. My focus was on the cache typing

Comment on lines +444 to +445
if is_nil(cross_hidden_state) do
{Layers.Decoder.get_self_attention_cache(block_cache), %Axon.None{}}
Member

Why do we need a separate function? If cross attention is not enabled then get_attention_caches already returns none in the second element (or rather a model that compiles to none).

Contributor Author

```elixir
@doc """
Retrieves self-attention and cross-attention caches from a block
cache.
"""
def get_attention_caches(block_cache) do
  {Axon.nx(block_cache, & &1.self_attention), Axon.nx(block_cache, & &1.cross_attention)}
end
```

It always returns cross-attention. Do you mean that cross attention is always Axon.None?

Also, being eager means fewer Axon.nx calls, which reduces the overall graph in Nx.Defn.Evaluator

Member

> Do you mean that cross attention is always Axon.None?

Yes!

> Also, being eager means fewer Axon.nx calls, which reduces the overall graph in Nx.Defn.Evaluator

I'd expect that to be marginal, but I am fine either way. If we go with this, I would rather have put_self_attention_cache than match on none in put_attention_caches, so it's more aligned on both ends.
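
For illustration, a put_self_attention_cache counterpart to the get_self_attention_cache call in this diff could look roughly like this (a sketch, not the actual Layers.Decoder implementation):

```elixir
# Hypothetical counterpart: write back only the self-attention part of the
# block cache, leaving the (possibly absent) cross-attention cache untouched.
def put_self_attention_cache(block_cache, self_attention_cache) do
  Axon.layer(
    fn block_cache, self_attention_cache, _opts ->
      %{block_cache | self_attention: self_attention_cache}
    end,
    [block_cache, self_attention_cache]
  )
end
```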

# produced by projection layers running in compute precision, so this
# matches what the model will actually return for the cache.
cache_type = output_policy.compute || {:f, 32}
cache = init_cache(spec, batch_size, max_length, inputs, cache_type: cache_type)
Member

Do we get anything from passing it downstream instead of casting as above?

Contributor Author

It means a smaller memory footprint: we can allocate bf16 directly instead of allocating f32 and then downcasting to bf16.
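
For a rough sense of the savings, some back-of-the-envelope arithmetic (the model dimensions are illustrative, not from this PR):

```elixir
# Hypothetical decoder: 32 blocks, 8 KV heads, head_dim 128, max_length 4096, batch 1.
# Total key + value elements across all blocks:
elements = 32 * 2 * 1 * 4096 * 8 * 128
IO.puts("f32:  #{elements * 4} bytes (~1.0 GiB)")
IO.puts("bf16: #{elements * 2} bytes (~0.5 GiB)")
```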

# Use the compute precision as the cache type. The key/value tensors are
# produced by projection layers running in compute precision, so this
# matches what the model will actually return for the cache.
cache_type = output_policy.compute || {:f, 32}
Member

IIRC the cache value returned from attention layers is cast using :output precision (since it's the layer output). That's why we cast as output here.

I'm not really sure how to model this with the mixed precision policy. It may be that we don't want to cast the cache at any point, but then we don't have the granularity to specify that, since it's a specific input/output.
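
For context, the kind of policy being discussed (bf16 compute, f32 output) is built with Axon's mixed-precision API; the types below are just an example:

```elixir
# Example policy: parameters and layer outputs in f32, compute in bf16.
# With :output at f32, the cache returned from attention layers is cast
# back to f32, which is the behaviour described above.
policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 32},
    compute: {:bf, 16},
    output: {:f, 32}
  )

model = Axon.MixedPrecision.apply_policy(model, policy)
```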

Contributor Author

I went with using the policy, but originally I wanted to introduce a new explicit parameter for the cache type. When I just used the output, at least in my use case, things ended up using f32 for the cache instead of the bf16 I wanted.

Member

> When I just used the output, at least in my use case, things ended up using f32 for the cache instead of the bf16 I wanted.

Right, my point is that the problem is bigger. We can initialize the cache as bf16, but when predict returns an updated cache it will be cast to f32, as per the output policy, no?

Member

On a side note, we are using the policy from the output node, but technically the attention layers could use a different policy.

Contributor Author

Can we just add a cache_type option that we can thread through?
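
A minimal sketch of what threading an explicit option could look like, assuming a simple helper that allocates the per-block key/value buffers (names and shapes are illustrative, not the final API):

```elixir
# Hypothetical helper: accept :cache_type and allocate the key/value buffers
# directly in that type instead of allocating f32 and casting afterwards.
def init_attention_cache(batch_size, max_length, num_heads, head_dim, opts \\ []) do
  cache_type = Keyword.get(opts, :cache_type, {:f, 32})
  shape = {batch_size, max_length, num_heads, head_dim}
  zero = Nx.tensor(0, type: cache_type)
  %{key: Nx.broadcast(zero, shape), value: Nx.broadcast(zero, shape)}
end
```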

Contributor Author

I'm creating an Axon rewrites library for EMLX and being able to define the cache in a type that won't upcast to f32 is the most important part. I think the rest of this PR can be ignored, but I'm not 100% sure. I'd have to check after removing stuff.

Member

Right, but any layer that returns the cache will upcast it to :f32 if that's the policy :output type, unless I am missing something.

Contributor Author

So you think we should be using the policy output type instead?

Member

That's what we do right now, so that the cache type going into the model is consistent with the cache type coming out of the model. I understand that it is an issue when using compute bf16 and output f32. An actual solution would be to avoid any casting of the cache at any point; it's just that I'm not sure what this would look like (this probably means rethinking how the policy and casting work in general, which also came up in #448 (comment)).

Member

So I'm fine with passing the output type to cache init. That should work as a memory optimisation, but my understanding is that the issue with compute bf16 and output f32 is still the same :)

@polvalente
Contributor Author

Closing for now as I was able to bypass the need for cache size control in my client library

@polvalente polvalente closed this May 6, 2026