diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml
index ca2579d92..832efb343 100644
--- a/src/maxdiffusion/configs/base14.yml
+++ b/src/maxdiffusion/configs/base14.yml
@@ -206,6 +206,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 1.e-2 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml
index 65e7d19e0..84807830f 100644
--- a/src/maxdiffusion/configs/base21.yml
+++ b/src/maxdiffusion/configs/base21.yml
@@ -211,6 +211,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 1.e-2 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml
index 16948296a..95435d041 100644
--- a/src/maxdiffusion/configs/base_2_base.yml
+++ b/src/maxdiffusion/configs/base_2_base.yml
@@ -221,6 +221,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 1.e-2 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml
index 7a508095f..996ae177f 100644
--- a/src/maxdiffusion/configs/base_flux_dev.yml
+++ b/src/maxdiffusion/configs/base_flux_dev.yml
@@ -245,6 +245,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 0 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml
index 1aba7431f..af3a0bde2 100644
--- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml
+++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml
@@ -232,6 +232,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 1.e-2 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml
index 9ae399713..8454d4809 100644
--- a/src/maxdiffusion/configs/base_flux_schnell.yml
+++ b/src/maxdiffusion/configs/base_flux_schnell.yml
@@ -240,6 +240,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 1.e-2 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index ecc430339..10a814b8c 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -302,6 +302,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 0.0 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml
index b54cbbdd6..fde9efe8d 100644
--- a/src/maxdiffusion/configs/base_wan_1_3b.yml
+++ b/src/maxdiffusion/configs/base_wan_1_3b.yml
@@ -258,6 +258,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 0.0 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml
index 464689a14..0a013285b 100644
--- a/src/maxdiffusion/configs/base_wan_27b.yml
+++ b/src/maxdiffusion/configs/base_wan_27b.yml
@@ -268,6 +268,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 0.0 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml
index ee61c5e31..686f66280 100644
--- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -263,6 +263,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 0.0 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml
index 44ecf5166..3eac96ccd 100644
--- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml
+++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -264,6 +264,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 0.0 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml
index 3dbb1578e..34c74fc6b 100644
--- a/src/maxdiffusion/configs/base_xl.yml
+++ b/src/maxdiffusion/configs/base_xl.yml
@@ -205,6 +205,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 1.e-2 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml
index e487559a7..b140a6c1a 100644
--- a/src/maxdiffusion/configs/base_xl_lightning.yml
+++ b/src/maxdiffusion/configs/base_xl_lightning.yml
@@ -166,6 +166,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_weight_decay: 1.e-2 # AdamW Weight decay
+opt_enable_grad_clipping: False
+max_grad_value: 1.0
+opt_enable_grad_global_norm_clipping: False
 max_grad_norm: 1.0
 
 enable_profiler: False
diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index 04b3869fe..e66be865a 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -483,13 +483,34 @@ def create_learning_rate_schedule(learning_rate, learning_rate_schedule_steps, w
 
 
 def create_optimizer(config, learning_rate_scheduler):
-  return optax.adamw(
+  """Builds the AdamW optimizer, optionally preceded by gradient clipping.
+
+  Args:
+    config: Config object providing adam_b1, adam_b2, adam_eps and
+      adam_weight_decay, plus the clipping switches
+      opt_enable_grad_global_norm_clipping / max_grad_norm and
+      opt_enable_grad_clipping / max_grad_value.
+    learning_rate_scheduler: Passed to optax.adamw as learning_rate.
+
+  Returns:
+    The bare optax.adamw transformation when both clipping switches are off;
+    otherwise an optax.chain that applies the enabled clipping step(s) first.
+  """
+  opt = optax.adamw(
       learning_rate=learning_rate_scheduler,
       b1=config.adam_b1,
       b2=config.adam_b2,
       eps=config.adam_eps,
       weight_decay=config.adam_weight_decay,
   )
+  if config.opt_enable_grad_global_norm_clipping:
+    opt = optax.chain(optax.clip_by_global_norm(config.max_grad_norm), opt)
+
+  # Wrapped last, so with both switches on the element-wise clip runs before
+  # the global-norm clip.
+  if config.opt_enable_grad_clipping:
+    opt = optax.chain(optax.clip(config.max_grad_value), opt)
+  return opt
 
 
 def get_precision(config):