Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions src/maxdiffusion/models/embeddings_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,28 @@ def get_sinusoidal_embeddings(
"""Returns the positional encoding (same as Tensor2Tensor).

Args:
timesteps: a 1-D Tensor of N indices, one per batch element.
timesteps: a 1-D or 2-D Tensor of indices.
These may be fractional.
embedding_dim: The number of output channels.
min_timescale: The smallest time unit (should probably be 0.0).
max_timescale: The largest time unit.
Returns:
a Tensor of timing signals [N, num_channels]
a Tensor of timing signals [B, num_channels] or [B, N, num_channels]
"""
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
assert timesteps.ndim <= 2, "Timesteps should be a 1d or 2d-array"
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
num_timescales = float(embedding_dim // 2)
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
emb = jnp.expand_dims(timesteps, -1) * inv_timescales

# scale embeddings
scaled_time = scale * emb

if flip_sin_to_cos:
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=-1)
else:
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)
return signal


Expand All @@ -84,7 +83,7 @@ def __init__(
sample_proj_bias=True,
dtype: jnp.dtype = jnp.float32,
weights_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.Precision = None,
precision: jax.lax.Precision | None = None,
):
self.linear_1 = nnx.Linear(
rngs=rngs,
Expand Down Expand Up @@ -221,7 +220,7 @@ def __call__(self, timesteps):

def get_1d_rotary_pos_embed(
dim: int,
pos: Union[jnp.array, int],
pos: Union[jnp.ndarray, int],
theta: float = 10000.0,
linear_factor=1.0,
ntk_factor=1.0,
Expand Down Expand Up @@ -332,11 +331,11 @@ def __init__(
rngs: nnx.Rngs,
in_features: int,
hidden_size: int,
out_features: int = None,
out_features: int | None = None,
act_fn: str = "gelu_tanh",
dtype: jnp.dtype = jnp.float32,
weights_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.Precision = None,
precision: jax.lax.Precision | None = None,
):
if out_features is None:
out_features = hidden_size
Expand Down Expand Up @@ -392,11 +391,11 @@ class PixArtAlphaTextProjection(nn.Module):
"""

hidden_size: int
out_features: int = None
out_features: int | None = None
act_fn: str = "gelu_tanh"
dtype: jnp.dtype = jnp.float32
weights_dtype: jnp.dtype = jnp.float32
precision: jax.lax.Precision = None
precision: jax.lax.Precision | None = None

@nn.compact
def __call__(self, caption):
Expand Down Expand Up @@ -455,7 +454,7 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
pooled_projection_dim: int
dtype: jnp.dtype = jnp.float32
weights_dtype: jnp.dtype = jnp.float32
precision: jax.lax.Precision = None
precision: jax.lax.Precision | None = None

@nn.compact
def __call__(self, timestep, pooled_projection):
Expand All @@ -479,7 +478,7 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
pooled_projection_dim: int
dtype: jnp.dtype = jnp.float32
weights_dtype: jnp.dtype = jnp.float32
precision: jax.lax.Precision = None
precision: jax.lax.Precision | None = None

@nn.compact
def __call__(self, timestep, guidance, pooled_projection):
Expand Down
21 changes: 3 additions & 18 deletions src/maxdiffusion/models/wan/autoencoder_kl_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
feat_cache = _update_cache(feat_cache, idx, cache_x)
feat_idx += 1
x = x.reshape(b, t, h, w, 2, c)
x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
# x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
x = x.transpose(0, 1, 4, 2, 3, 5)
x = x.reshape(b, t * 2, h, w, c)
t = x.shape[1]
x = x.reshape(b * t, h, w, c)
Expand Down Expand Up @@ -1160,23 +1161,7 @@ def _decode(
out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
else:
out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)

# This works around an issue where frame[1] and frame[2] come out swapped.
# Ideally this would not be needed; however, it is unclear where the frames go out of sync.
# It is most likely caused by an incorrect reshape in the decoder.
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
# When the batch dimension is not greater than 1, expand along axis 0 for concatenation;
# otherwise, expand along axis 1 (the frame dim) so that the batch dim stays intact.
axis = 0
if fm1.shape[0] > 1:
axis = 1

if len(fm1.shape) == 4:
fm1 = jnp.expand_dims(fm1, axis=axis)
fm2 = jnp.expand_dims(fm2, axis=axis)
fm3 = jnp.expand_dims(fm3, axis=axis)
fm4 = jnp.expand_dims(fm4, axis=axis)
out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
out = jnp.concatenate([out, out_], axis=1)

feat_cache._feat_map = dec_feat_map

Expand Down
Loading
Loading