From 68aadadf7ddf4faf5e6ad3be8a43c594de227cef Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 17 Jun 2026 15:11:38 -0400 Subject: [PATCH 1/2] Make GSPO loss length-proportional --- fast_llm/functional/triton/gspo_loss.py | 24 +++++++++---------- .../language_model/loss/policy_gradient.py | 21 ++++++++-------- tests/layers/test_lm_losses.py | 2 +- 3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/fast_llm/functional/triton/gspo_loss.py b/fast_llm/functional/triton/gspo_loss.py index d1bcfef17..9244935b2 100644 --- a/fast_llm/functional/triton/gspo_loss.py +++ b/fast_llm/functional/triton/gspo_loss.py @@ -16,7 +16,7 @@ def triton_gspo_loss_backward_kernel( sum_exp_logits_ptr, probability_ratio_ptr, seg_advantage_ptr, - token_weight_ptr, + loss_weight_ptr, grad_logits_ptr, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, @@ -31,10 +31,8 @@ def triton_gspo_loss_backward_kernel( ): block_idx = tl.program_id(0).to(tl.int64) - # token_weight = mask_t / N_d, where N_d is the labeled-token count for the doc containing t. - # Zero for masked tokens (mask=0) and for tokens with N_d=0 after the kernel's clamp. - token_weight = tl.load(token_weight_ptr + block_idx).to(tl.float32) - if token_weight == 0.0: + loss_weight = tl.load(loss_weight_ptr + block_idx).to(tl.float32) + if loss_weight == 0.0: if not accumulate: for col_offset in tl.static_range(0, n_cols, block_size): col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) @@ -61,7 +59,7 @@ def triton_gspo_loss_backward_kernel( ) * probability_ratio * grad_scale - * token_weight + * loss_weight ) logits_ptr = logits_ptr + block_idx * logits_stride_0 @@ -107,7 +105,7 @@ def triton_gspo_loss_forward_backward( num_warps: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """Triton GSPO loss. Forward fuses softmax + predicted-logit lookup; backward fuses the - softmax chain rule with the per-token GSPO gradient factor (R_s * clip * token_weight). + softmax chain rule with the per-token GSPO gradient factor (R_s * clip * loss_weight). Segment aggregation, loss, and the SDP/SP all-reduce live in PyTorch between the two passes. See `fused_gspo_loss_forward_backward` in policy_gradient.py for the math derivation; @@ -163,13 +161,13 @@ def triton_gspo_loss_forward_backward( # contribution to the per-segment sum is already normalized, so SDP/SP all-reduce works # without a separate token-count tensor. flat_num_labels = num_labels_in_seq.reshape(-1).to(new_log_probs.dtype).clamp(min=1) - token_weight = loss_mask / flat_num_labels + mean_token_weight = loss_mask / flat_num_labels mean_log_ratio_per_segment = log_ratio.new_zeros(num_segments).index_add_( - 0, flat_document_index, log_ratio * token_weight + 0, flat_document_index, log_ratio * mean_token_weight ) mean_advantage_per_segment = log_ratio.new_zeros(num_segments).index_add_( - 0, flat_document_index, flat_advantages * token_weight + 0, flat_document_index, flat_advantages * mean_token_weight ) for reduce_group in (sdp_group, sp_group): if reduce_group is not None: @@ -185,13 +183,13 @@ def triton_gspo_loss_forward_backward( probability_ratio = segment_ratio[flat_document_index].contiguous() seg_advantage = segment_advantage[flat_document_index].contiguous() - token_weight = token_weight.contiguous() + loss_weight = loss_mask.contiguous() losses = -torch.min( probability_ratio * seg_advantage, torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * seg_advantage, ) - loss = (losses * token_weight).sum() / divisor + loss = (losses * loss_weight).sum() / divisor new_logprobs_mean = (new_log_probs * loss_mask / flat_num_labels).sum() @@ -208,7 +206,7 @@ def triton_gspo_loss_forward_backward( sum_exp_logits_ptr=sum_exp_logits, probability_ratio_ptr=probability_ratio, seg_advantage_ptr=seg_advantage, - token_weight_ptr=token_weight, + loss_weight_ptr=loss_weight, grad_logits_ptr=grad_logits, n_cols=n_cols, logits_stride_0=logits.stride(-2), diff --git a/fast_llm/layers/language_model/loss/policy_gradient.py b/fast_llm/layers/language_model/loss/policy_gradient.py index 62db342fe..a024d4232 100644 --- a/fast_llm/layers/language_model/loss/policy_gradient.py +++ b/fast_llm/layers/language_model/loss/policy_gradient.py @@ -460,13 +460,14 @@ def fused_gspo_loss_forward_backward( """GSPO loss: sequence-level geometric-mean IS-ratio clipping. Per-segment ratio R_s = exp(mean_t(log p_new_t / p_old_t)), clipped per segment. - Per-segment loss = -min(R_s * A_s, clip(R_s) * A_s), summed over segments and divided by `divisor`. + Per-segment loss = -min(R_s * A_s, clip(R_s) * A_s), multiplied by the segment + token count, summed over segments, and divided by `divisor`. Computed as an equivalent per-token sum so the gradient chain mirrors GRPO. `num_labels_in_seq[t]` is the labeled-token count for the document containing token `t` - (broadcast per token by the data preprocessor); it doubles as the geometric-mean denominator - and the per-token weight. Using it directly — rather than aggregating token counts inside the - kernel — is what makes the loss correct when a document spans SDP/SP ranks (numerator + (broadcast per token by the data preprocessor); it is the geometric-mean denominator. + Using it directly — rather than aggregating token counts inside the kernel — is what + makes the loss correct when a document spans SDP/SP ranks (numerator `log_ratio_sum` is all-reduced; denominator is constant and available locally). Constraint: each document must be visible to a single kernel call (modulo SDP/SP, where @@ -493,17 +494,17 @@ def fused_gspo_loss_forward_backward( # Per-token weight: mask / per-document label count, from the preprocessor. # Each labeled token contributes `1 / N_d` so all of doc d's tokens sum to 1 (across # SDP/SP ranks too), regardless of how the doc is sharded. - token_weight = flat_mask / num_labels_in_seq.reshape(-1).to(log_ratio.dtype).clamp(min=1) + mean_token_weight = flat_mask / num_labels_in_seq.reshape(-1).to(log_ratio.dtype).clamp(min=1) # Pre-divide the per-token contributions by the per-doc label count, then sum per segment. # All tokens in a segment share the same N_d, so this is mathematically equivalent to # `log_ratio_sum / N_d` but avoids any per-segment denominator extraction. mean_log_ratio_per_segment = log_ratio.new_zeros(num_segments).index_add_( - 0, flat_document_index, log_ratio.reshape(-1) * token_weight + 0, flat_document_index, log_ratio.reshape(-1) * mean_token_weight ) # Accumulate in `log_ratio.dtype` (fp32). Casting the product back to `advantages.dtype` # before summing would round each token's contribution to a possibly-low input dtype. mean_advantage_per_segment = log_ratio.new_zeros(num_segments).index_add_( - 0, flat_document_index, advantages.reshape(-1).to(log_ratio.dtype) * token_weight + 0, flat_document_index, advantages.reshape(-1).to(log_ratio.dtype) * mean_token_weight ) for reduce_group in (sdp_group, sp_group): if reduce_group is not None: @@ -519,13 +520,13 @@ def fused_gspo_loss_forward_backward( probability_ratio = segment_ratio[flat_document_index].reshape(log_ratio.shape) advantage_per_token = segment_advantage[flat_document_index].reshape(log_ratio.shape) - token_weight = token_weight.reshape(log_ratio.shape) + loss_weight = loss_mask.to(log_ratio.dtype) losses = -torch.min( probability_ratio * advantage_per_token, torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantage_per_token, ) - loss = (losses * token_weight).sum() / divisor + loss = (losses * loss_weight).sum() / divisor new_logprobs_mean = (new_log_probs * loss_mask / num_labels_in_seq.clamp(min=1)).sum() @@ -536,7 +537,7 @@ def fused_gspo_loss_forward_backward( torch.clamp_min(advantage_per_token, 0) * (probability_ratio <= 1 + epsilon_high) + torch.clamp_max(advantage_per_token, 0) * (probability_ratio >= 1 - epsilon_low) ) - * token_weight + * loss_weight ) predicted_probabilities = exp_logits / sum_exp_logits.unsqueeze_(-1) grad = (probability_ratio_grad * probability_ratio).unsqueeze(-1) * predicted_probabilities.scatter_add( diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 8026d9739..d7d14ad3e 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -206,7 +206,7 @@ def reference_gspo_loss( ratio = (flat_log_ratio[in_segment].sum() / count_float).exp() advantage = (flat_advantages[in_segment].sum() / count_float).detach() clipped_ratio = ratio.clamp(1 - epsilon_low, 1 + epsilon_high) - total = total + -torch.minimum(ratio * advantage, clipped_ratio * advantage) + total = total + -torch.minimum(ratio * advantage, clipped_ratio * advantage) * count_float # Matches the kernel's `sum_t logprob_t * mask_t / N_d` — sum of per-document mean logprobs. new_logprobs = new_logprobs + target_log_probabilities.reshape(-1)[in_segment].sum() / count_float total = total / num_segments From 17d3ccb95c0692749988d7b50b4418d8e56dea48 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 17 Jun 2026 17:56:05 -0400 Subject: [PATCH 2/2] Add GSPO uneven-document regression test --- tests/layers/test_lm_head.py | 61 ++++++++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 16 deletions(-) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9cf5fc4b1..6f09cb108 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -38,6 +38,7 @@ class LMHeadTestConfig: prediction_heads: int = 1 tied_embedding_weight: bool = False num_splits: int = 1 + gspo_document_lengths: tuple[int, ...] | None = None @property def actual_label_loss(self): @@ -97,6 +98,15 @@ def get_config(self) -> GPTModelConfig: }, ) + @property + def actual_gspo_document_lengths(self) -> tuple[int, ...]: + if self.gspo_document_lengths is not None: + Assert.eq(sum(self.gspo_document_lengths), NUM_TOKENS) + return self.gspo_document_lengths + document_length = NUM_TOKENS // GSPO_NUM_DOCUMENTS + Assert.eq(document_length * GSPO_NUM_DOCUMENTS, NUM_TOKENS) + return (document_length,) * GSPO_NUM_DOCUMENTS + def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: device = "cuda" if torch.cuda.is_available() else "cpu" input_ = torch.randn( @@ -159,23 +169,32 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: for labels_ in kwargs[LanguageModelKwargs.labels] ] kwargs[LanguageModelKwargs.num_documents_in_batch] = ( - GSPO_NUM_DOCUMENTS if self.gspo_loss is not False else 1 + len(self.actual_gspo_document_lengths) if self.gspo_loss is not False else 1 ) if self.gspo_loss is not False: - document_length = NUM_TOKENS // GSPO_NUM_DOCUMENTS + document_lengths = self.actual_gspo_document_lengths + document_length_repeats = torch.tensor(document_lengths, dtype=torch.int64, device=device) kwargs[BlockKwargs.global_document_index_q] = torch.repeat_interleave( - torch.arange(1, GSPO_NUM_DOCUMENTS + 1, dtype=torch.int32, device=device), document_length + torch.arange(1, len(document_lengths) + 1, dtype=torch.int32, device=device), document_length_repeats ) - kwargs[BlockKwargs.num_documents_in_sequence] = GSPO_NUM_DOCUMENTS - kwargs[BlockKwargs.lengths] = [document_length] * GSPO_NUM_DOCUMENTS + kwargs[BlockKwargs.num_documents_in_sequence] = len(document_lengths) + kwargs[BlockKwargs.lengths] = list(document_lengths) # Override label_counts: per-token broadcast of the containing document's masked-label count # (the kernel's per-document `new_logprobs` aggregation depends on this). kwargs[LanguageModelLossKwargs.label_counts] = [ - (labels_ >= 0) - .view(GSPO_NUM_DOCUMENTS, document_length) - .sum(-1) - .to(torch.float32) - .repeat_interleave(document_length) + torch.cat( + [ + torch.full( + (document_length,), + float(document_mask.sum()), + dtype=torch.float32, + device=device, + ) + for document_mask, document_length in zip( + torch.split(labels_ >= 0, document_lengths), document_lengths, strict=True + ) + ] + ) for labels_ in kwargs[LanguageModelKwargs.labels] ] return input_, kwargs @@ -258,18 +277,21 @@ def get_reference_outputs( ) # Average over documents of per-document mean log-prob — matches the kernel's # `sum_t logprob_t * mask_t / label_count_t` divided by `num_documents_in_batch`. - document_length = NUM_TOKENS // GSPO_NUM_DOCUMENTS target_log_probabilities = ( torch.nn.functional.log_softmax(logits, -1) .gather(-1, (labels * (labels >= 0)).unsqueeze(-1)) .squeeze(-1) ) label_mask = (labels >= 0).to(target_log_probabilities.dtype) - logprob_sums_per_document = ( - (target_log_probabilities * label_mask).view(GSPO_NUM_DOCUMENTS, document_length).sum(-1) - ) - label_counts_per_document = label_mask.view(GSPO_NUM_DOCUMENTS, document_length).sum(-1).clamp(min=1) - new_logprobs = (logprob_sums_per_document / label_counts_per_document).mean() + document_means = [ + (document_log_probabilities * document_label_mask).sum() / document_label_mask.sum().clamp(min=1) + for document_log_probabilities, document_label_mask in zip( + torch.split(target_log_probabilities, self.actual_gspo_document_lengths), + torch.split(label_mask, self.actual_gspo_document_lengths), + strict=True, + ) + ] + new_logprobs = torch.stack(document_means).mean() names_losses_weights.append(("gspo_loss", gspo_loss, float(self.gspo_loss))) names_losses_weights.append(("gspo_loss_new_logprobs", new_logprobs, 0.0)) @@ -327,6 +349,13 @@ def _add_configs(base_name: str, **kwargs): loss_masking=loss_masking, ) ) +_lm_head_test_configs.append( + LMHeadTestConfig( + "gspo_loss_uneven_documents", + gspo_loss=True, + gspo_document_lengths=(17, 31, 58, 94), + ) +) _add_configs("label_and_distillation_loss", label_loss=True, distillation_loss=True) _add_configs("label_and_z_loss_weighted", label_loss=True, z_loss=0.5) _add_configs("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0)