Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 58 additions & 34 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,49 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
return tensor, kv_size, seq_len


def _flash_sequence_length(tensor: Array) -> int:
Comment thread
csgoogle marked this conversation as resolved.
if tensor.ndim == 3:
return tensor.shape[1]
if tensor.ndim == 4:
return tensor.shape[2]
raise ValueError(f"Flash attention expects rank-3 or rank-4 inputs, got rank {tensor.ndim}.")


def _select_flash_block_sizes(
Comment thread
csgoogle marked this conversation as resolved.
query: Array,
key: Array,
flash_block_sizes: BlockSizes,
dtype: jnp.dtype,
attention_kernel: str,
) -> BlockSizes:
query_seq_len = _flash_sequence_length(query)
key_seq_len = _flash_sequence_length(key)

q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
if key_seq_len != query_seq_len:
kv_max_block_size = ((key_seq_len + 127) // 128) * 128
else:
kv_max_block_size = q_max_block_size

# Keep configured block sizes for self-attention, but let
# cross-attention derive safe KV-aware sizes when q_len != kv_len.
if flash_block_sizes and key_seq_len == query_seq_len:
return flash_block_sizes

block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
return splash_attention_kernel.BlockSizes(
block_q=block_size_q,
block_kv_compute=min(kv_max_block_size, key_seq_len),
block_kv=min(kv_max_block_size, key_seq_len),
block_q_dkv=block_size_q,
block_kv_dkv=min(kv_max_block_size, key_seq_len),
block_kv_dkv_compute=min(kv_max_block_size, query_seq_len),
block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query_seq_len),
use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
)


def convert_to_tokamax_splash_config(
block_sizes: BlockSizes,
q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
Expand Down Expand Up @@ -244,28 +287,7 @@ def _tpu_flash_attention(
) -> jax.Array:
"""TPU Flash Attention"""

q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
# This is the case for cross-attn.
if key.shape[1] != query.shape[1]:
kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
else:
kv_max_block_size = q_max_block_size
# ensure that for cross attention we override the block sizes.
if flash_block_sizes and key.shape[1] == query.shape[1]:
block_sizes = flash_block_sizes
else:
block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
block_sizes = splash_attention_kernel.BlockSizes(
block_q=block_size_q,
block_kv_compute=min(kv_max_block_size, key.shape[2]),
block_kv=min(kv_max_block_size, key.shape[2]),
block_q_dkv=block_size_q,
block_kv_dkv=min(kv_max_block_size, key.shape[2]),
block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
)
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
num_context_shards = mesh.shape["context"]
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
Expand Down Expand Up @@ -717,8 +739,8 @@ def __init__(
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
Comment thread
csgoogle marked this conversation as resolved.
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
)
self.act = get_activation(activation_fn)
self.net_2 = nnx.Linear(
Expand All @@ -729,8 +751,8 @@ def __init__(
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
)

def __call__(self, hidden_states: Array) -> Array:
Expand Down Expand Up @@ -979,7 +1001,7 @@ def __init__(
precision=precision,
bias_init=nnx.with_partitioning(
nnx.initializers.zeros,
("embed",),
("heads",),
Comment thread
csgoogle marked this conversation as resolved.
),
)

Expand All @@ -993,7 +1015,7 @@ def __init__(
precision=precision,
bias_init=nnx.with_partitioning(
nnx.initializers.zeros,
("embed",),
("heads",),
),
)

Expand All @@ -1007,7 +1029,7 @@ def __init__(
precision=precision,
bias_init=nnx.with_partitioning(
nnx.initializers.zeros,
("embed",),
("heads",),
),
)

Expand All @@ -1021,7 +1043,7 @@ def __init__(
precision=precision,
bias_init=nnx.with_partitioning(
nnx.initializers.zeros,
("heads",),
("embed",),
Comment thread
csgoogle marked this conversation as resolved.
),
)

Expand Down Expand Up @@ -1333,11 +1355,13 @@ def setup(self):
precision=self.precision,
)

proj_attn_kernel_axes = ("heads", "embed")

self.proj_attn = nn.Dense(
self.query_dim,
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
use_bias=True,
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
dtype=self.dtype,
param_dtype=self.weights_dtype,
name="i_proj",
Expand All @@ -1346,9 +1370,9 @@ def setup(self):

self.encoder_proj_attn = nn.Dense(
self.query_dim,
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
use_bias=True,
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
dtype=self.dtype,
param_dtype=self.weights_dtype,
name="e_proj",
Expand Down
8 changes: 4 additions & 4 deletions src/maxdiffusion/models/ltx2/attention_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,13 @@ def __init__(
# 1. Define Partitioned Initializers (Logical Axes)
# Q, K, V kernels: [in_features (embed), out_features (heads)]
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
# Q, K, V biases: [out_features (embed)]
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
# Q, K, V biases: [out_features (heads)]
Comment thread
csgoogle marked this conversation as resolved.
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))

# Out kernel: [in_features (heads), out_features (embed)]
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
# Out bias: [out_features (heads)]
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
# Out bias: [out_features (embed)]
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))

# Norm scales
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
Expand Down
6 changes: 3 additions & 3 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,11 @@ def __init__(
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(
"mlp",
"embed",
Comment thread
csgoogle marked this conversation as resolved.
"mlp",
),
),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
)

def __call__(self, x: jax.Array) -> jax.Array:
Expand Down Expand Up @@ -249,8 +249,8 @@ def __init__(
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(
"embed",
"mlp",
"embed",
),
),
)
Expand Down
56 changes: 55 additions & 1 deletion src/maxdiffusion/tests/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import jax
from jax.sharding import Mesh
import jax.numpy as jnp
from ..models.attention_flax import FlaxAttention
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from ..models.attention_flax import FlaxAttention, _select_flash_block_sizes
from .. import max_utils
from .. import pyconfig

Expand Down Expand Up @@ -92,6 +93,59 @@ def test_splash_attention(self):

assert diff_norm < 1.0

def test_cross_attention_overrides_configured_flash_block_sizes(self):
query = jnp.zeros((1, 1024, 256), dtype=jnp.bfloat16)
key = jnp.zeros((1, 257, 256), dtype=jnp.bfloat16)
configured_block_sizes = splash_attention_kernel.BlockSizes(
block_q=384,
block_kv_compute=192,
block_kv=320,
block_q_dkv=256,
block_kv_dkv=288,
block_kv_dkv_compute=160,
block_q_dq=128,
block_kv_dq=96,
use_fused_bwd_kernel=False,
)

block_sizes = _select_flash_block_sizes(
query=query,
key=key,
flash_block_sizes=configured_block_sizes,
dtype=jnp.bfloat16,
attention_kernel="flash",
)

assert block_sizes.block_q == configured_block_sizes.block_q
assert block_sizes.block_q_dkv == configured_block_sizes.block_q
assert block_sizes.block_q_dq == configured_block_sizes.block_q
assert block_sizes.block_kv_compute == 257
assert block_sizes.block_kv == 257
assert block_sizes.block_kv_dkv == 257
assert block_sizes.block_kv_dkv_compute == 384
assert block_sizes.block_kv_dq == 384

def test_default_flash_block_sizes_use_sequence_axis_for_3d_inputs(self):
query = jnp.zeros((1, 128, 4096), dtype=jnp.bfloat16)
key = jnp.zeros((1, 257, 4096), dtype=jnp.bfloat16)

block_sizes = _select_flash_block_sizes(
query=query,
key=key,
flash_block_sizes=None,
dtype=jnp.bfloat16,
attention_kernel="flash",
)

assert block_sizes.block_q == 1024
assert block_sizes.block_kv_compute == 257
assert block_sizes.block_kv == 257
assert block_sizes.block_q_dkv == 1024
assert block_sizes.block_kv_dkv == 257
assert block_sizes.block_kv_dkv_compute == 128
assert block_sizes.block_q_dq == 1024
assert block_sizes.block_kv_dq == 128


if __name__ == "__main__":
absltest.main()
Loading