Conversation
Code Review
This pull request significantly enhances sequence parallelism support by implementing ZigZag Ring Attention for long-sequence training and Ulysses-style sequence parallelism for Qwen3.5 linear attention. It also introduces multimodal deepstack patching for Qwen3-VL and refactors the SequenceParallel strategy to better handle complex device meshes and packed/varlen inputs. Feedback focuses on improving code maintainability and robustness, specifically by grouping attributes in the SequenceParallel constructor, removing redundant logic and unused imports, replacing deprecated inspection methods, and centralizing duplicated loss-gathering logic.
```python
self.seq_world_size = None
self.sp_world_size = None
self.rp_world_size = None
self.dp_world_size = None
self.world_size = None
self.attn_implementation = None
self.model_dtype = None
self.tokenizer = None
self.device_mesh = None
self._sp_group = None
self._rp_group = None
self._data_rank_group = None
self._sp_rank = 0
self._rp_rank = 0
self.num_heads = None
self.causal_mask_func = None
self.extra_kwargs = {}
```
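The review summary asks for grouping these constructor attributes. One hedged sketch of how the flat attribute list above could be grouped, using hypothetical dataclass names (`MeshState`, `GroupState`, `ModelState` are not from the PR):

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class MeshState:
    # Mesh/topology sizes resolved during setup.
    seq_world_size: Optional[int] = None
    sp_world_size: Optional[int] = None
    rp_world_size: Optional[int] = None
    dp_world_size: Optional[int] = None
    world_size: Optional[int] = None
    device_mesh: Optional[Any] = None


@dataclass
class GroupState:
    # Process groups and ranks for the SP/RP dimensions.
    sp_group: Optional[Any] = None
    rp_group: Optional[Any] = None
    data_rank_group: Optional[Any] = None
    sp_rank: int = 0
    rp_rank: int = 0


@dataclass
class ModelState:
    # Model-side metadata captured at prepare time.
    attn_implementation: Optional[str] = None
    model_dtype: Optional[Any] = None
    tokenizer: Optional[Any] = None
    num_heads: Optional[int] = None
    causal_mask_func: Optional[Any] = None
    extra_kwargs: dict = field(default_factory=dict)
```

The constructor would then hold three objects instead of seventeen scalars, which also makes reset/validation logic easier to keep in one place.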
```python
if query.shape[2] != total_tokens:
    raise ValueError('Packed/varlen flash_attention_2 expects query sequence length to match '
                     f'cu_seqlens total tokens, got query_seq_len={query.shape[2]} '
                     f'and cu_seqlens_total={total_tokens}.')
```
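In the packed/varlen layout, `cu_seqlens` holds cumulative sequence boundaries, so the packed total token count is its last entry. A minimal list-based sketch of the invariant this check enforces (the helper names are illustrative, not from the PR):

```python
from itertools import accumulate


def cu_seqlens_from_lengths(lengths):
    # Cumulative boundaries: [0, l0, l0+l1, ...]; the last entry is the
    # total number of packed tokens across all sequences.
    return [0] + list(accumulate(lengths))


def check_packed_total(query_seq_len, cu_seqlens):
    # The packed query must contain exactly cu_seqlens[-1] tokens.
    total_tokens = cu_seqlens[-1]
    if query_seq_len != total_tokens:
        raise ValueError(f'expected {total_tokens} packed tokens, got {query_seq_len}')
    return total_tokens
```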
```python
if self.rp_world_size > 1:
    attn_impl = getattr(model.config, '_attn_implementation', None)
    if attn_impl != 'flash_attention_2':
        raise NotImplementedError('Derived ring attention only supports flash_attention_2 backend.')
```
| @@ -0,0 +1,283 @@ | |||
| import os | |||
| @cache | ||
| def _get_default_args(func): | ||
| spec = inspect.getfullargspec(func) | ||
| defaults = spec.defaults if spec.defaults is not None else () | ||
| padded_defaults = (None, ) * (len(spec.args) - len(defaults)) + defaults | ||
| args = dict(zip(spec.args, padded_defaults)) | ||
| if 'softcap' in args: | ||
| args['softcap'] = 0.0 | ||
| return args |
```python
if self.sp_strategy is not None:
    loss_inputs, loss_outputs = self.sp_strategy.gather_loss_tensors(inputs, outputs)
    # local labels still count only the shard-local tokens. Normalize the loss
    # contribution here so metric-side averaging matches the non-SP path.
    if ulysses_size > 1:
        loss = loss / float(ulysses_size)
```
Why is this done here? Put differently, does the loss used for the model's backward pass actually need to be divided by ulysses_size?
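One reading of the normalization comment in the diff: after the SP gather, the loss numerator covers the full sequence, while the label count on each rank still only covers its shard, so the per-token scale is off by a factor of ulysses_size. A toy sketch under that assumption (helper name and numbers are illustrative only):

```python
def normalized_loss(token_losses, local_token_count, ulysses_size):
    # token_losses spans the full gathered sequence, but local_token_count
    # only reflects this rank's shard; dividing by ulysses_size restores
    # the per-token scale of the non-SP path.
    return sum(token_losses) / local_token_count / ulysses_size


full = [1.0, 2.0, 3.0, 4.0]        # per-token losses over the whole sequence
non_sp = sum(full) / len(full)     # baseline per-token mean over 4 tokens
sp = normalized_loss(full, 2, 2)   # gathered numerator, shard-local count of 2
```

Under this reading the two values coincide, which is what the "metric-side averaging matches the non-SP path" comment asks for; whether the backward loss should also carry the division is exactly the reviewer's open question.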
```python
from twinkle.model.transformers.strategy.sequence_parallel.utils import head_to_seq_shard, seq_to_head_shard
from twinkle.patch import Patch


if is_flash_linear_attention_available():
```
Could this part be reused in swift? swift is also missing SP support for linear attention.
```python
from twinkle.utils.grad_clip import normalize_and_clip_grad_norm


def _get_raw_dp_fsdp_world_size(device_mesh: Optional[DeviceMesh]) -> int:
```
This looks the same as the device_mesh's dp_world_size? Can that be reused instead?
They are not the same. This computes dp_world_size * fsdp_world_size, whereas the device_mesh's dp_world_size is:

```python
@property
def dp_world_size(self) -> int:
    return self._get_world_size_for_dim('dp')
```
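The distinction in the reply above can be sketched with a hypothetical mesh-shape dict (the real code reads these sizes from a torch DeviceMesh; names and numbers here are illustrative):

```python
def get_raw_dp_fsdp_world_size(mesh_shape: dict) -> int:
    # The "raw" data-parallel replica count is the product of the 'dp' and
    # 'fsdp' mesh dimensions; a missing dimension contributes a factor of 1.
    return mesh_shape.get('dp', 1) * mesh_shape.get('fsdp', 1)


mesh = {'dp': 2, 'fsdp': 4, 'sp': 2}
raw = get_raw_dp_fsdp_world_size(mesh)  # 2 * 4 = 8, while mesh['dp'] is only 2
```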
```python
result = loss_instance(inputs, outputs, **kwargs)
loss_inputs = inputs
loss_outputs = outputs
if self.sp_strategy is not None:
```
Could this part use the InputProcessor instead? Since the sharding is done by the InputProcessor, shouldn't the gather also live there?
That probably wouldn't be appropriate: by this point we are already in the loss computation stage, and the InputProcessor's responsibility should be limited to processing inputs.
PR type
PR information
This PR adds context parallel and Qwen3.5 Gated DeltaNet sequence parallel support to the transformers stack, and refactors sequence parallel into a package-based implementation.
Main changes:
- Refactor `sequence_parallel.py` into `sequence_parallel/` and add shared utilities.
- Add `linear_attention_sp.py`; ring attention is not supported for this path yet.
- Add `sp_fsdp_dense` and `tests/moe/test_expert_parallel_qwen3_fsdp_sp.py` tests.