Skip to content

Fix: Overhaul WAN checkpointers for robust multi-host restoration#379

Merged
copybara-service[bot] merged 1 commit intomainfrom
ninatu/fix-sharding-mismatch
Apr 17, 2026
Merged

Fix: Overhaul WAN checkpointers for robust multi-host restoration#379
copybara-service[bot] merged 1 commit intomainfrom
ninatu/fix-sharding-mismatch

Conversation

@ninatu
Copy link
Copy Markdown
Collaborator

@ninatu ninatu commented Apr 16, 2026

This commit resolves several interrelated checkpointing issues by updating how Orbax handles metadata, sharding, and PyTree restoration.

Key changes:

  • Add explicit item_handlers: Defined specific handlers (JsonCheckpointHandler for configs, StandardCheckpointHandler for states) in CheckpointManager. This ensures metadata is restored correctly, resolving known Orbax limitations (reference: Restoring to CPU google/orbax#986).

  • Bypass mesh validation during restore: Replaced ocp.utils.to_shape_dtype_struct with manual jax.ShapeDtypeStruct construction in add_sharding_to_struct. This makes restoration topology-agnostic, preventing ValueError when the current device mesh has fewer devices than the saved checkpoint's topology (e.g., restoring 32-device metadata on 4 devices).

  • Migrate to Standard API: Upgraded all WAN checkpointers from the PyTreeSave/PyTreeRestore APIs to StandardSave/StandardRestore to align with item_handlers defined in CheckpointManager.

@ninatu ninatu requested a review from entrpn as a code owner April 16, 2026 14:47
@github-actions
Copy link
Copy Markdown

@ninatu ninatu force-pushed the ninatu/fix-sharding-mismatch branch 7 times, most recently from 09deaaa to 4850947 Compare April 16, 2026 18:29
entrpn
entrpn previously approved these changes Apr 16, 2026
@ninatu ninatu force-pushed the ninatu/fix-sharding-mismatch branch from 4850947 to 68d7f5d Compare April 17, 2026 06:52
This commit resolves several interrelated checkpointing issues by updating
how Orbax handles metadata, sharding, and PyTree restoration.

Key changes:
* Add explicit `item_handlers`: Defined specific handlers (`JsonCheckpointHandler`
  for configs, `StandardCheckpointHandler` for states) in `CheckpointManager`.
  This ensures metadata is restored correctly, resolving known Orbax limitations
  (reference: google/orbax#986).

* Bypass mesh validation during restore: Replaced `ocp.utils.to_shape_dtype_struct`
  with manual `jax.ShapeDtypeStruct` construction in `add_sharding_to_struct`.
  This makes restoration topology-agnostic, preventing `ValueError` when the
  current device mesh has fewer devices than the saved checkpoint's topology
  (e.g., restoring 32-device metadata on 4 devices).

* Migrate to Standard API: Upgraded all WAN checkpointers from
  the `PyTreeSave`/`PyTreeRestore` APIs to `StandardSave`/`StandardRestore`
  to align with `item_handlers` defined in CheckpointManager.

Co-authored-by: martinarroyo <martinarroyo@google.com>
@ninatu ninatu force-pushed the ninatu/fix-sharding-mismatch branch from 68d7f5d to 69f7701 Compare April 17, 2026 06:53
@copybara-service copybara-service bot merged commit 2965670 into main Apr 17, 2026
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants