From 2ddf8abdf1192288bbdad925307b0b0a1a407051 Mon Sep 17 00:00:00 2001
From: Sagar Chapara
Date: Wed, 15 Apr 2026 15:05:13 +0000
Subject: [PATCH] Fix transformer sharding and cross-attention flash block sizes

---
 src/maxdiffusion/models/attention_flax.py    | 92 ++++++++++++-------
 .../models/ltx2/attention_ltx2.py            |  8 +-
 .../wan/transformers/transformer_wan.py      |  6 +-
 src/maxdiffusion/tests/attention_test.py     | 56 ++++++++++-
 4 files changed, 120 insertions(+), 42 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 8c96a299b..9783d6480 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -190,6 +190,49 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
   return tensor, kv_size, seq_len
 
 
+def _flash_sequence_length(tensor: Array) -> int:
+  if tensor.ndim == 3:
+    return tensor.shape[1]
+  if tensor.ndim == 4:
+    return tensor.shape[2]
+  raise ValueError(f"Flash attention expects rank-3 or rank-4 inputs, got rank {tensor.ndim}.")
+
+
+def _select_flash_block_sizes(
+    query: Array,
+    key: Array,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype,
+    attention_kernel: str,
+) -> BlockSizes:
+  query_seq_len = _flash_sequence_length(query)
+  key_seq_len = _flash_sequence_length(key)
+
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  if key_seq_len != query_seq_len:
+    kv_max_block_size = ((key_seq_len + 127) // 128) * 128
+  else:
+    kv_max_block_size = q_max_block_size
+
+  # Keep configured block sizes for self-attention, but let
+  # cross-attention derive safe KV-aware sizes when q_len != kv_len.
+  if flash_block_sizes and key_seq_len == query_seq_len:
+    return flash_block_sizes
+
+  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
+  return splash_attention_kernel.BlockSizes(
+      block_q=block_size_q,
+      block_kv_compute=min(kv_max_block_size, key_seq_len),
+      block_kv=min(kv_max_block_size, key_seq_len),
+      block_q_dkv=block_size_q,
+      block_kv_dkv=min(kv_max_block_size, key_seq_len),
+      block_kv_dkv_compute=min(kv_max_block_size, query_seq_len),
+      block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
+      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query_seq_len),
+      use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
+  )
+
+
 def convert_to_tokamax_splash_config(
     block_sizes: BlockSizes,
     q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
@@ -244,28 +287,7 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""
 
-  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # This is the case for cross-attn.
-  if key.shape[1] != query.shape[1]:
-    kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
-  else:
-    kv_max_block_size = q_max_block_size
-  # ensure that for cross attention we override the block sizes.
-  if flash_block_sizes and key.shape[1] == query.shape[1]:
-    block_sizes = flash_block_sizes
-  else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
-    block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
-        block_kv_compute=min(kv_max_block_size, key.shape[2]),
-        block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
-        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
-        block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
-        use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
-    )
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
   num_context_shards = mesh.shape["context"]
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
@@ -717,8 +739,8 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )
     self.act = get_activation(activation_fn)
     self.net_2 = nnx.Linear(
@@ -729,8 +751,8 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )
 
   def __call__(self, hidden_states: Array) -> Array:
@@ -979,7 +1001,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )
 
@@ -993,7 +1015,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )
 
@@ -1007,7 +1029,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )
 
@@ -1021,7 +1043,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("heads",),
+            ("embed",),
         ),
     )
 
@@ -1333,11 +1355,13 @@ def setup(self):
         precision=self.precision,
     )
 
+    proj_attn_kernel_axes = ("heads", "embed")
+
     self.proj_attn = nn.Dense(
         self.query_dim,
-        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
         use_bias=True,
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         name="i_proj",
@@ -1346,9 +1370,9 @@
 
     self.encoder_proj_attn = nn.Dense(
         self.query_dim,
-        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
         use_bias=True,
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         name="e_proj",
diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py
index 7441a2038..398b0f473 100644
--- a/src/maxdiffusion/models/ltx2/attention_ltx2.py
+++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py
@@ -359,13 +359,13 @@ def __init__(
     # 1. Define Partitioned Initializers (Logical Axes)
     # Q, K, V kernels: [in_features (embed), out_features (heads)]
     qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
-    # Q, K, V biases: [out_features (embed)]
-    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
+    # Q, K, V biases: [out_features (heads)]
+    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
 
     # Out kernel: [in_features (heads), out_features (embed)]
     out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
-    # Out bias: [out_features (heads)]
-    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
+    # Out bias: [out_features (embed)]
+    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
 
     # Norm scales
     norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index edb450454..62512693c 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -193,11 +193,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "mlp",
                 "embed",
+                "mlp",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )
 
   def __call__(self, x: jax.Array) -> jax.Array:
@@ -249,8 +249,8 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "embed",
                 "mlp",
+                "embed",
             ),
         ),
     )
diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py
index f345ab113..910b479a6 100644
--- a/src/maxdiffusion/tests/attention_test.py
+++ b/src/maxdiffusion/tests/attention_test.py
@@ -20,7 +20,8 @@
 import jax
 from jax.sharding import Mesh
 import jax.numpy as jnp
-from ..models.attention_flax import FlaxAttention
+from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+from ..models.attention_flax import FlaxAttention, _select_flash_block_sizes
 from .. import max_utils
 from .. import pyconfig
 
@@ -92,6 +93,59 @@ def test_splash_attention(self):
 
     assert diff_norm < 1.0
+
+  def test_cross_attention_overrides_configured_flash_block_sizes(self):
+    query = jnp.zeros((1, 1024, 256), dtype=jnp.bfloat16)
+    key = jnp.zeros((1, 257, 256), dtype=jnp.bfloat16)
+    configured_block_sizes = splash_attention_kernel.BlockSizes(
+        block_q=384,
+        block_kv_compute=192,
+        block_kv=320,
+        block_q_dkv=256,
+        block_kv_dkv=288,
+        block_kv_dkv_compute=160,
+        block_q_dq=128,
+        block_kv_dq=96,
+        use_fused_bwd_kernel=False,
+    )
+
+    block_sizes = _select_flash_block_sizes(
+        query=query,
+        key=key,
+        flash_block_sizes=configured_block_sizes,
+        dtype=jnp.bfloat16,
+        attention_kernel="flash",
+    )
+
+    assert block_sizes.block_q == configured_block_sizes.block_q
+    assert block_sizes.block_q_dkv == configured_block_sizes.block_q
+    assert block_sizes.block_q_dq == configured_block_sizes.block_q
+    assert block_sizes.block_kv_compute == 257
+    assert block_sizes.block_kv == 257
+    assert block_sizes.block_kv_dkv == 257
+    assert block_sizes.block_kv_dkv_compute == 384
+    assert block_sizes.block_kv_dq == 384
+
+  def test_default_flash_block_sizes_use_sequence_axis_for_3d_inputs(self):
+    query = jnp.zeros((1, 128, 4096), dtype=jnp.bfloat16)
+    key = jnp.zeros((1, 257, 4096), dtype=jnp.bfloat16)
+
+    block_sizes = _select_flash_block_sizes(
+        query=query,
+        key=key,
+        flash_block_sizes=None,
+        dtype=jnp.bfloat16,
+        attention_kernel="flash",
+    )
+
+    assert block_sizes.block_q == 1024
+    assert block_sizes.block_kv_compute == 257
+    assert block_sizes.block_kv == 257
+    assert block_sizes.block_q_dkv == 1024
+    assert block_sizes.block_kv_dkv == 257
+    assert block_sizes.block_kv_dkv_compute == 128
+    assert block_sizes.block_q_dq == 1024
+    assert block_sizes.block_kv_dq == 128
 
 
 if __name__ == "__main__":
   absltest.main()
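
Reviewer note (illustrative only, not part of the changes above): the sketch below mirrors the
new unit tests and shows how _select_flash_block_sizes derives KV-aware block sizes for a
cross-attention shape. The absolute import path and the example shapes and dtypes are
assumptions for illustration; the function is imported via a relative path in the tests.

  # Hypothetical usage sketch; values follow the helper added in attention_flax.py.
  import jax.numpy as jnp
  from maxdiffusion.models.attention_flax import _select_flash_block_sizes  # assumed import path

  # Cross-attention: 1024 query tokens attend to 257 key/value tokens.
  query = jnp.zeros((1, 1024, 256), dtype=jnp.bfloat16)
  key = jnp.zeros((1, 257, 256), dtype=jnp.bfloat16)

  # With no configured sizes and q_len != kv_len, KV block sizes are derived from the
  # KV length: ((257 + 127) // 128) * 128 = 384, then clamped to the actual length 257.
  block_sizes = _select_flash_block_sizes(
      query=query,
      key=key,
      flash_block_sizes=None,
      dtype=jnp.bfloat16,
      attention_kernel="flash",
  )
  assert block_sizes.block_q == 1024  # bfloat16 default query block size
  assert block_sizes.block_kv == 257  # min(384, kv_len)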