Conversation
Code Review
This pull request implements a comprehensive "Strict Resume" feature for Transformers models, enabling the restoration of full training state including optimizer, scheduler, scaler, RNG states, and data progress. Key changes involve implementing load_training_state and read_training_progress across the model, server, and client layers, alongside dataloader enhancements to support sample-level skipping for map-style datasets. Feedback highlights several critical improvements: ensuring deterministic RNG in distributed settings by avoiding unseeded random states, addressing the deprecated use of StopIteration in generators, improving security by using weights_only=True during checkpoint loading, and removing an accidental BOM character in the client generator. Additionally, a more robust approach for re-initializing the dataloader is suggested to avoid modifying private PyTorch attributes.
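The StopIteration point follows from PEP 479: since Python 3.7, a `StopIteration` that escapes inside a generator body is converted to a `RuntimeError`, so a resume/skip generator must `return` instead. A minimal, standalone sketch of the correct pattern (`skip_n_then_yield` is an illustrative name, not the PR's actual function):

```python
def skip_n_then_yield(iterable, n):
    """Yield items after skipping the first n.

    Pre-PEP 479 code sometimes wrote `raise StopIteration` when the
    source ran dry; under PEP 479 that becomes a RuntimeError.  The
    correct pattern is simply to `return` from the generator.
    """
    it = iter(iterable)
    for _ in range(n):
        try:
            next(it)
        except StopIteration:
            return  # correct: end the generator, do NOT re-raise
    yield from it


print(list(skip_n_then_yield(range(5), 2)))   # [2, 3, 4]
print(list(skip_n_then_yield(range(2), 10)))  # []
```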
```python
self.dataloader.__initialized = False
self._rebuild_sampler_stack()
self.dataloader.__initialized = True
```
Accessing and modifying the private attribute __initialized of torch.utils.data.DataLoader is brittle and relies on internal implementation details of PyTorch that could change. A safer approach to update the sampler stack after the dataloader has been created is to simply re-instantiate the underlying self.dataloader using the stored self.dataloader_params.
```diff
-self.dataloader.__initialized = False
-self._rebuild_sampler_stack()
-self.dataloader.__initialized = True
+if self.dataloader is not None:
+    self.dataloader = None
+    self._lazy_init_dataloader()
```
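The suggested re-instantiation approach can be sketched as follows; `DummyLoader` stands in for `torch.utils.data.DataLoader`, and the attribute names `dataloader_params` / `_lazy_init_dataloader` follow the review comment but are assumptions about the project's actual wrapper:

```python
class DummyLoader:
    """Stand-in for torch.utils.data.DataLoader."""
    def __init__(self, dataset, batch_size=1, sampler=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.sampler = sampler


class LoaderWrapper:
    def __init__(self, dataset, **params):
        self.dataset = dataset
        # Keep the constructor kwargs so the loader can be rebuilt later.
        self.dataloader_params = params
        self.dataloader = None
        self._lazy_init_dataloader()

    def _lazy_init_dataloader(self):
        if self.dataloader is None:
            self.dataloader = DummyLoader(self.dataset, **self.dataloader_params)

    def rebuild_with_sampler(self, sampler):
        # Instead of poking the private DataLoader.__initialized flag,
        # drop the old instance and build a fresh one from stored params.
        self.dataloader_params['sampler'] = sampler
        self.dataloader = None
        self._lazy_init_dataloader()


w = LoaderWrapper(list(range(8)), batch_size=4)
w.rebuild_with_sampler(sampler=[3, 2, 1, 0])
print(w.dataloader.sampler)  # [3, 2, 1, 0]
```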
/gemini summary
Summary of Changes

This pull request introduces robust checkpoint resumption capabilities to the training framework. By enabling the restoration of full training states (including optimizer, scheduler, and RNG configurations) and implementing precise data skipping in the dataloader, the changes ensure that training can be reliably resumed after interruptions. Additionally, the PR optimizes checkpoint handling for FSDP2 strategies and adds necessary API endpoints to support these features in distributed and remote training environments.

Highlights
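The "precise data skipping" for a map-style dataset usually amounts to dropping the first `consumed_train_samples` indices from a deterministically shuffled order. A minimal sketch of the idea (the `SkipSampler` class name and seeding scheme are illustrative, not the PR's implementation):

```python
import random


class SkipSampler:
    """Deterministically shuffles indices, then skips already-consumed ones."""

    def __init__(self, dataset_len, consumed_samples=0, seed=42):
        self.dataset_len = dataset_len
        self.consumed_samples = consumed_samples
        self.seed = seed

    def __iter__(self):
        order = list(range(self.dataset_len))
        # Seeded shuffle so every rank / restart sees the same order.
        random.Random(self.seed).shuffle(order)
        # Resume mid-epoch by skipping what was already trained on.
        return iter(order[self.consumed_samples % self.dataset_len:])

    def __len__(self):
        return self.dataset_len - self.consumed_samples % self.dataset_len


full = list(SkipSampler(10, consumed_samples=0))
resumed = list(SkipSampler(10, consumed_samples=4))
assert full[4:] == resumed  # the resumed run continues exactly where it left off
```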
```diff
@@ -99,21 +99,29 @@ def train():
 # model.set_lr_scheduler('LinearLR')
```
It is indeed a typo, but self_congnition.py already exists on main; would it be more appropriate to fix it in a separate PR?
```python
)
response.raise_for_status()

def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]:
```
What is the difference between load_training_state and read_training_progress? Could they be merged into one?
```python
twinkle_path = model.save(
    name=f'twinkle-epoch-{epoch}',
    save_optimizer=True,
    consumed_train_samples=consumed_train_samples,
```
dataloader.get_consumed_samples()?
Alternatively, dataloader.get_state() would be more general.
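A `get_state()` / `load_state()` pair along the lines suggested might look like this; the class, method names, and state fields are a sketch of the idea, not the project's actual interface:

```python
class ResumableDataloader:
    """Illustrative dataloader that tracks and restores its progress."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.consumed_samples = 0

    def __iter__(self):
        start = self.consumed_samples
        for i in range(start, len(self.dataset), self.batch_size):
            batch = self.dataset[i:i + self.batch_size]
            self.consumed_samples += len(batch)
            yield batch

    def get_state(self):
        # Everything needed to resume, in one serializable dict.
        return {'consumed_samples': self.consumed_samples,
                'batch_size': self.batch_size}

    def load_state(self, state):
        self.consumed_samples = state['consumed_samples']


dl = ResumableDataloader(list(range(10)), batch_size=2)
it = iter(dl)
next(it); next(it)        # consume two batches: [0, 1] and [2, 3]
state = dl.get_state()    # consumed_samples == 4

dl2 = ResumableDataloader(list(range(10)), batch_size=2)
dl2.load_state(state)
print(next(iter(dl2)))    # [4, 5] — resumes right after the consumed samples
```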
Also, please additionally test torchrun/ray compatibility here, as well as compatibility across both the megatron and transformers model backends.
- `model.save(name, save_optimizer=True, consumed_train_samples=...)` saves weights together with optimizer, scheduler, scaler, RNG, and `trainer_state.json`.
- `model.load(name, output_dir=..., adapter_name=...)` restores LoRA / adapter model weights.
- `model.read_training_progress(checkpoint_dir, ...)` reads checkpoint metadata such as `cur_step`, `gradient_accumulation_steps`, and `consumed_train_samples`.
These two are quite similar; would it be appropriate to merge them into one? For example:

```python
training_progress = model.resume_from_checkpoint(xxx)
dataloader.resume_from_checkpoint(training_progress.get('dataloader'))
```

Something like that?
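The combined API proposed in the comment could be shaped roughly like this; `resume_from_checkpoint` and the returned keys are hypothetical, with stub bodies standing in for the real restore logic:

```python
class Model:
    def resume_from_checkpoint(self, path):
        # In a real implementation this would read trainer_state.json
        # and restore optimizer / scheduler / RNG; here we just return
        # the metadata the caller needs (values are placeholders).
        return {'cur_step': 120,
                'dataloader': {'consumed_train_samples': 960}}


class Dataloader:
    def __init__(self):
        self.consumed_train_samples = 0

    def resume_from_checkpoint(self, state):
        if state:
            self.consumed_train_samples = state['consumed_train_samples']


model, dataloader = Model(), Dataloader()
progress = model.resume_from_checkpoint('ckpt/twinkle-epoch-1')
dataloader.resume_from_checkpoint(progress.get('dataloader'))
print(dataloader.consumed_train_samples)  # 960
```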
```python
if optimizer is not None:
    optimizer_path = os.path.join(output_dir, 'optimizer.pt')
    if hasattr(self.strategy, 'save_optimizer_checkpoint'):
```
The division of responsibility here is a bit unclear: why do some strategies have save_optimizer_checkpoint while others do not?
A reader of the code will be left wondering under exactly which circumstances the strategy is supposed to handle saving.
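One way to make the responsibility uniform is to give the strategy base class a default implementation, so the model layer always delegates instead of using a `hasattr` check. This is a sketch under assumed names (`BaseStrategy`, `Fsdp2Strategy`), not the project's code; `pickle` stands in for `torch.save`:

```python
import os
import pickle
import tempfile


class BaseStrategy:
    def save_optimizer_checkpoint(self, model, optimizer_state, output_path):
        # Default: plain single-process save.  Subclasses override only
        # when the optimizer state is sharded and must be gathered first.
        with open(output_path, 'wb') as f:
            pickle.dump(optimizer_state, f)


class Fsdp2Strategy(BaseStrategy):
    def save_optimizer_checkpoint(self, model, optimizer_state, output_path):
        # Real code would gather the sharded FSDP2 state dict here;
        # wrapping it is just a marker for the override point.
        gathered = {'gathered': optimizer_state}
        super().save_optimizer_checkpoint(model, gathered, output_path)


# The model layer can now delegate unconditionally, with no hasattr check:
path = os.path.join(tempfile.gettempdir(), 'optimizer.pt')
Fsdp2Strategy().save_optimizer_checkpoint(None, {'lr': 0.1}, path)
with open(path, 'rb') as f:
    print(pickle.load(f))  # {'gathered': {'lr': 0.1}}
```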
```python
adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
optimizer_config = self.optimizer_group[adapter_name]

if not Platform.is_master():
```
Both ray and torchrun need to be verified correct here, and the megatron side needs to be considered correspondingly as well.
```python
def save_optimizer_checkpoint(self, model, optimizer, output_path: str):
    fsdp_plugin = self._get_fsdp_plugin()
    if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2:
```
PR type
PR information
Implement full restoration of training state in TransformersModel and MultiLoraModel, including optimizer, scheduler, RNG configuration, and dataset skipping.