
Change bias initialization from 'embed' to 'heads' #371

Open

csgoogle wants to merge 1 commit into main from fixbiassharding

Conversation

csgoogle (Collaborator) commented Apr 6, 2026

  • Fix incorrect logical partitioning axes for attention and feed-forward parameters in the Flax/WAN/LTX2 modules.
  • Refactor flash-attention block-size selection into a helper and add unit tests.

doc: https://docs.google.com/document/d/1absFkpQAMM3YaYWxO_FYeqzDpypYeDbPsJRAV86nFQ0/edit?usp=sharing&resourcekey=0-FOzOmM0UdfU1LcDd_7epvw

Results

| Metric         | main    | fixbiassharding | Δ              |
|----------------|---------|-----------------|----------------|
| Compile time   | 1913.9s | 1906.4s         | -7.5s          |
| Inference time | 1656.4s | 1642.1s         | -14.3s (-0.9%) |

Notes

  • No difference observed with tp=1 configs; the improvement only surfaces when tensor parallelism is active, as the axis fixes reduce parameter all-gather overhead in MLP layers.
  • The primary motivation for this change is correctness: incorrect sharding axes can cause OOMs or numerical issues at other parallelism configs.
  • Larger gains are expected at tp=4 or tp=8, where parameter communication is a larger fraction of step time.

Video Quality Comparison

| Branch          | Video               |
|-----------------|---------------------|
| main            | main.mp4            |
| fixbiassharding | fixbiassharding.mp4 |

PSNR/SSIM (frame-by-frame, 81 frames):

| Metric    | Mean   | Min    | Max    |
|-----------|--------|--------|--------|
| PSNR (dB) | 19.37  | 18.83  | 20.17  |
| SSIM      | 0.7884 | 0.7654 | 0.8043 |

The low PSNR/SSIM reflects floating-point non-determinism from the different sharding layouts accumulating over 50 denoising steps (bfloat16 plus different collective patterns); the videos are visually identical.
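For reference, the per-frame PSNR in the table can be computed with a simple helper (a generic sketch, not the exact script used for these numbers; SSIM is typically taken from an existing implementation such as `skimage.metrics.structural_similarity`):

```python
import numpy as np


def psnr(a: np.ndarray, b: np.ndarray, max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between two frames (higher = closer)."""
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    if mse == 0.0:
        return float("inf")  # identical frames
    return 10.0 * np.log10(max_val**2 / mse)


def mean_psnr(frames_a, frames_b) -> float:
    """Clip-level PSNR as the mean of per-frame values (as in the table)."""
    return float(np.mean([psnr(a, b) for a, b in zip(frames_a, frames_b)]))
```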

Video and Xprof after fix:

https://console.cloud.google.com/storage/browser/sagarchapara/shardingfixes

@csgoogle csgoogle requested a review from entrpn as a code owner April 6, 2026 10:09
github-actions bot commented Apr 6, 2026

csgoogle force-pushed the fixbiassharding branch 2 times, most recently from d822acb to 15af39f on April 13, 2026 at 10:41
entrpn previously approved these changes Apr 14, 2026
Perseus14 (Collaborator) commented:

Could you add more details and results on the new commits? @csgoogle

csgoogle force-pushed the fixbiassharding branch 5 times, most recently from 9780b17 to 7a6ab88 on April 15, 2026 at 14:59
```python
    raise ValueError(f"Flash attention expects rank-3 or rank-4 inputs, got rank {tensor.ndim}.")


def _select_flash_block_sizes(
```
csgoogle (Collaborator, Author) commented:

Just refactoring, no logic changes, plus added unit tests; this will be helpful for the other Ulysses attention PR.
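A helper of this shape might look like the following (a hypothetical sketch based on the comment above, not the actual `_select_flash_block_sizes` from this PR):

```python
def select_flash_block_sizes(seq_len: int, candidates=(1024, 512, 256, 128)) -> int:
    """Pick the largest candidate flash-attention block size dividing seq_len.

    Hypothetical sketch: the real helper may also account for head_dim,
    hardware limits, and separate query vs. key/value block sizes.
    """
    for block in candidates:
        if seq_len % block == 0:
            return block
    return 1  # degenerate fallback: no blocking
```

Extracting the selection into a pure function like this is what makes it straightforward to unit-test across sequence lengths, which appears to be the point of the refactor.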

csgoogle (Collaborator, Author) commented:

> Could you add more details and results on the new commits?

Done.

@github-actions

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @csgoogle, but I was unable to process your request. Please see the logs for more details.

3 participants