diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index c79cd51ff0..d387a5c08c 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -474,6 +474,9 @@ void sdpa_vector_2pass( blocks = 32; } } + if (int blocks_env = env::get_var("MLX_SDPA_BLOCKS", 0); blocks_env > 0) { + blocks = blocks_env; + } size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);