support cp, fix qwen3.5 gdn sp #138

Open

meichangsu1 wants to merge 3 commits into modelscope:main from meichangsu1:fsdp_cp_ljl

Conversation

@meichangsu1 (Collaborator) commented Apr 2, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR adds context parallel and Qwen3.5 Gated DeltaNet sequence parallel support to the transformers stack, and refactors sequence parallel into a package-based implementation.

Main changes:

  • Refactor sequence_parallel.py into sequence_parallel/ and add shared utilities.
  • Add derived ring / zigzag ring attention support for CP + SP.
  • Add Qwen3.5 linear attention SP support in linear_attention_sp.py; ring attention is not supported for this path yet.
  • Update transformers model / processor paths to work with the new SP+CP flow.
  • Adjust loss metric aggregation for Ulysses replicated loss behavior.
  • Update cookbook examples for sp_fsdp_dense.
  • Add test coverage for:
    • Qwen3.5 linear attention SP alignment
    • sequence parallel + context parallel behavior
  • Remove outdated tests/moe/test_expert_parallel_qwen3_fsdp_sp.py.
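
The zigzag ring variant listed in the changes above balances causal attention across ranks by pairing each front chunk of the sequence with its mirror chunk from the back, so no rank is stuck with only late (expensive) tokens. A minimal sketch of that token partitioning — illustrative only, not the PR's actual API:

```python
from typing import List

def zigzag_shard(seq_len: int, rank: int, world_size: int) -> List[range]:
    """Split [0, seq_len) into 2 * world_size chunks; rank r takes chunk r
    from the front and chunk (2 * world_size - 1 - r) from the back, so the
    causal-attention FLOPs are balanced across ranks."""
    n_chunks = 2 * world_size
    assert seq_len % n_chunks == 0, 'sequence must divide evenly into chunks'
    chunk = seq_len // n_chunks
    front = rank
    back = n_chunks - 1 - rank
    return [range(front * chunk, (front + 1) * chunk),
            range(back * chunk, (back + 1) * chunk)]

# With seq_len=16 and 2 ranks (4 chunks of 4 tokens):
# rank 0 holds tokens [0..4) and [12..16); rank 1 holds [4..8) and [8..12).
```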

Experiment results

@meichangsu1 meichangsu1 changed the title from "Fsdp cp ljl" to "support cp, fix qwen3.5 gdn sp" on Apr 2, 2026
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request significantly enhances sequence parallelism support by implementing ZigZag Ring Attention for long-sequence training and Ulysses-style sequence parallelism for Qwen3.5 linear attention. It also introduces multimodal deepstack patching for Qwen3-VL and refactors the SequenceParallel strategy to better handle complex device meshes and packed/varlen inputs. Feedback focuses on improving code maintainability and robustness, specifically by grouping attributes in the SequenceParallel constructor, removing redundant logic and unused imports, replacing deprecated inspection methods, and centralizing duplicated loss-gathering logic.

Comment on lines +37 to 53
self.seq_world_size = None
self.sp_world_size = None
self.rp_world_size = None
self.dp_world_size = None
self.world_size = None
self.attn_implementation = None
self.model_dtype = None
self.tokenizer = None
self.device_mesh = None
self._sp_group = None
self._rp_group = None
self._data_rank_group = None
self._sp_rank = 0
self._rp_rank = 0
self.num_heads = None
self.causal_mask_func = None
self.extra_kwargs = {}
Severity: medium

The init method is becoming quite large with many attributes. Consider grouping related attributes into a dataclass or a separate configuration object to improve maintainability.
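
One way to act on this suggestion is to collect the mesh- and rank-related attributes into a dataclass. The sketch below is hypothetical (the dataclass name is invented); the field names are copied from the __init__ snippet above:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class SPMeshState:
    """Hypothetical grouping of the mesh-related state from
    SequenceParallel.__init__, so related attributes travel together."""
    seq_world_size: Optional[int] = None
    sp_world_size: Optional[int] = None
    rp_world_size: Optional[int] = None
    dp_world_size: Optional[int] = None
    world_size: Optional[int] = None
    sp_rank: int = 0
    rp_rank: int = 0
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)
```

__init__ then shrinks to a single `self.mesh = SPMeshState()` for this group, and the remaining attributes (tokenizer, dtype, process groups) can be grouped the same way.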

Comment on lines +236 to +239
if query.shape[2] != total_tokens:
raise ValueError('Packed/varlen flash_attention_2 expects query sequence length to match '
f'cu_seqlens total tokens, got query_seq_len={query.shape[2]} '
f'and cu_seqlens_total={total_tokens}.')
Severity: medium

The check if world_size and world_size > 1 is redundant if world_size is guaranteed to be an integer. If it can be None, consider a more explicit check or default value handling.
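
The explicit-default pattern the comment asks for can be sketched with a small hypothetical helper, so call sites compare against a real integer instead of relying on truthiness:

```python
from typing import Optional

def effective_world_size(world_size: Optional[int]) -> int:
    """Normalize a possibly-missing world size to an explicit default of 1,
    instead of relying on `if world_size and world_size > 1` truthiness."""
    return 1 if world_size is None else world_size

# Call sites then read: if effective_world_size(ws) > 1: ...
```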

Comment on lines +481 to +484
if self.rp_world_size > 1:
attn_impl = getattr(model.config, '_attn_implementation', None)
if attn_impl != 'flash_attention_2':
raise NotImplementedError('Derived ring attention only supports flash_attention_2 backend.')
Severity: medium

The check for flash_attention_2 backend should ideally be done using a constant or a centralized configuration check to avoid hardcoded strings.
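
Centralizing the backend string as the comment suggests could look like the following sketch (the constant and helper names are illustrative, not part of the PR):

```python
from typing import Optional

# Hypothetical shared constant; the PR hardcodes this string at call sites.
FLASH_ATTENTION_2 = 'flash_attention_2'

def require_flash_attention_2(attn_impl: Optional[str]) -> None:
    """Raise unless the configured attention backend is flash_attention_2."""
    if attn_impl != FLASH_ATTENTION_2:
        raise NotImplementedError(
            f'Derived ring attention only supports the {FLASH_ATTENTION_2} '
            f'backend, got {attn_impl!r}.')
```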

@@ -0,0 +1,283 @@
import os
Severity: medium

The import of os is unused in this file. Please remove it.

Comment on lines +115 to +123
@cache
def _get_default_args(func):
spec = inspect.getfullargspec(func)
defaults = spec.defaults if spec.defaults is not None else ()
padded_defaults = (None, ) * (len(spec.args) - len(defaults)) + defaults
args = dict(zip(spec.args, padded_defaults))
if 'softcap' in args:
args['softcap'] = 0.0
return args
Severity: medium

The _get_default_args function uses inspect.getfullargspec which is deprecated in newer Python versions. Consider using inspect.signature instead.

Comment on lines +490 to +491
if self.sp_strategy is not None:
loss_inputs, loss_outputs = self.sp_strategy.gather_loss_tensors(inputs, outputs)
Severity: medium

The logic for gathering loss tensors is duplicated or very similar to logic in other parts of the codebase. Consider centralizing this loss gathering logic to avoid drift.

@meichangsu1 meichangsu1 changed the title to "support cp, fix qwen3.5 gdn sp" on Apr 2, 2026
# local labels still count only the shard-local tokens. Normalize the loss
# contribution here so metric-side averaging matches the non-SP path.
if ulysses_size > 1:
loss = loss / float(ulysses_size)
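
The replicated-loss arithmetic above can be checked with plain numbers: under Ulysses SP every rank in the group reports the same full-sequence loss, so dividing each contribution by the group size makes the metric-side sum match the single-rank value. A toy check with no distributed setup (the helper name is illustrative):

```python
def aggregate_replicated_loss(per_rank_losses, ulysses_size):
    """Metric aggregation sums contributions across ranks; with Ulysses SP
    every rank reports the same full-sequence loss, so each contribution is
    divided by the group size before summing."""
    return sum(loss / float(ulysses_size) for loss in per_rank_losses)

# Four Ulysses ranks each replicate the same loss value 2.0; the aggregated
# metric matches the non-SP value:
assert aggregate_replicated_loss([2.0] * 4, ulysses_size=4) == 2.0
```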
Collaborator:

Why is this division placed here? Put differently: does the loss used for the model's backward pass also need to be divided by ulysses_size?

from twinkle.model.transformers.strategy.sequence_parallel.utils import head_to_seq_shard, seq_to_head_shard
from twinkle.patch import Patch

if is_flash_linear_attention_available():
Collaborator:

Can this part also be used in swift? swift is missing SP for linear attention as well.

Collaborator (Author):

Yes, it can.

from twinkle.utils.grad_clip import normalize_and_clip_grad_norm


def _get_raw_dp_fsdp_world_size(device_mesh: Optional[DeviceMesh]) -> int:
Collaborator:

This looks identical to device_mesh's dp_world_size? Can that be reused?

Collaborator (Author):

No, they differ: this computes dp_world_size * fsdp_world_size, whereas device_mesh's dp_world_size is:

    @property
    def dp_world_size(self) -> int:
        return self._get_world_size_for_dim('dp')
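
The distinction in the reply above — the raw replica count is the product of the dp and fsdp mesh dimensions, not the dp dimension alone — can be sketched with a plain dict standing in for the device mesh (hypothetical shape, not the actual DeviceMesh API):

```python
from math import prod
from typing import Mapping, Optional

def raw_dp_fsdp_world_size(mesh_dims: Optional[Mapping[str, int]]) -> int:
    """dp_world_size * fsdp_world_size; falls back to 1 when the mesh or
    either dimension is absent."""
    if mesh_dims is None:
        return 1
    return prod(mesh_dims.get(dim, 1) for dim in ('dp', 'fsdp'))

# With dp=2 and fsdp=4 the raw data-parallel replica count is 8, while the
# mesh's own dp_world_size property would report only 2.
```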

result = loss_instance(inputs, outputs, **kwargs)
loss_inputs = inputs
loss_outputs = outputs
if self.sp_strategy is not None:
Collaborator:

Could this use the input processor instead? Since the sharding is done by the input processor, shouldn't the gather also live there?

Collaborator (Author):

That doesn't seem like the right place: by this point we are already in the loss-computation stage, and the input processor's responsibility should be processing inputs.
