Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions fast_llm/functional/triton/gspo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def triton_gspo_loss_backward_kernel(
sum_exp_logits_ptr,
probability_ratio_ptr,
seg_advantage_ptr,
token_weight_ptr,
loss_weight_ptr,
grad_logits_ptr,
n_cols: tl_constexpr,
logits_stride_0: tl_constexpr,
Expand All @@ -31,10 +31,8 @@ def triton_gspo_loss_backward_kernel(
):
block_idx = tl.program_id(0).to(tl.int64)

# token_weight = mask_t / N_d, where N_d is the labeled-token count for the doc containing t.
# Zero for masked tokens (mask=0) and for tokens with N_d=0 after the kernel's clamp.
token_weight = tl.load(token_weight_ptr + block_idx).to(tl.float32)
if token_weight == 0.0:
loss_weight = tl.load(loss_weight_ptr + block_idx).to(tl.float32)
if loss_weight == 0.0:
if not accumulate:
for col_offset in tl.static_range(0, n_cols, block_size):
col_offsets = tl_arange(int(col_offset), int(col_offset + block_size))
Expand All @@ -61,7 +59,7 @@ def triton_gspo_loss_backward_kernel(
)
* probability_ratio
* grad_scale
* token_weight
* loss_weight
)

logits_ptr = logits_ptr + block_idx * logits_stride_0
Expand Down Expand Up @@ -107,7 +105,7 @@ def triton_gspo_loss_forward_backward(
num_warps: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""Triton GSPO loss. Forward fuses softmax + predicted-logit lookup; backward fuses the
softmax chain rule with the per-token GSPO gradient factor (R_s * clip * token_weight).
softmax chain rule with the per-token GSPO gradient factor (R_s * clip * loss_weight).
Segment aggregation, loss, and the SDP/SP all-reduce live in PyTorch between the two passes.

See `fused_gspo_loss_forward_backward` in policy_gradient.py for the math derivation;
Expand Down Expand Up @@ -163,13 +161,13 @@ def triton_gspo_loss_forward_backward(
# contribution to the per-segment sum is already normalized, so SDP/SP all-reduce works
# without a separate token-count tensor.
flat_num_labels = num_labels_in_seq.reshape(-1).to(new_log_probs.dtype).clamp(min=1)
token_weight = loss_mask / flat_num_labels
mean_token_weight = loss_mask / flat_num_labels

mean_log_ratio_per_segment = log_ratio.new_zeros(num_segments).index_add_(
0, flat_document_index, log_ratio * token_weight
0, flat_document_index, log_ratio * mean_token_weight
)
mean_advantage_per_segment = log_ratio.new_zeros(num_segments).index_add_(
0, flat_document_index, flat_advantages * token_weight
0, flat_document_index, flat_advantages * mean_token_weight
)
for reduce_group in (sdp_group, sp_group):
if reduce_group is not None:
Expand All @@ -185,13 +183,13 @@ def triton_gspo_loss_forward_backward(

probability_ratio = segment_ratio[flat_document_index].contiguous()
seg_advantage = segment_advantage[flat_document_index].contiguous()
token_weight = token_weight.contiguous()
loss_weight = loss_mask.contiguous()

losses = -torch.min(
probability_ratio * seg_advantage,
torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * seg_advantage,
)
loss = (losses * token_weight).sum() / divisor
loss = (losses * loss_weight).sum() / divisor

new_logprobs_mean = (new_log_probs * loss_mask / flat_num_labels).sum()

Expand All @@ -208,7 +206,7 @@ def triton_gspo_loss_forward_backward(
sum_exp_logits_ptr=sum_exp_logits,
probability_ratio_ptr=probability_ratio,
seg_advantage_ptr=seg_advantage,
token_weight_ptr=token_weight,
loss_weight_ptr=loss_weight,
grad_logits_ptr=grad_logits,
n_cols=n_cols,
logits_stride_0=logits.stride(-2),
Expand Down
21 changes: 11 additions & 10 deletions fast_llm/layers/language_model/loss/policy_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,13 +460,14 @@ def fused_gspo_loss_forward_backward(
"""GSPO loss: sequence-level geometric-mean IS-ratio clipping.

Per-segment ratio R_s = exp(mean_t(log p_new_t / p_old_t)), clipped per segment.
Per-segment loss = -min(R_s * A_s, clip(R_s) * A_s), summed over segments and divided by `divisor`.
Per-segment loss = -min(R_s * A_s, clip(R_s) * A_s), multiplied by the segment
token count, summed over segments, and divided by `divisor`.
Computed as an equivalent per-token sum so the gradient chain mirrors GRPO.

`num_labels_in_seq[t]` is the labeled-token count for the document containing token `t`
(broadcast per token by the data preprocessor); it doubles as the geometric-mean denominator
and the per-token weight. Using it directly — rather than aggregating token counts inside the
kernel — is what makes the loss correct when a document spans SDP/SP ranks (numerator
(broadcast per token by the data preprocessor); it is the geometric-mean denominator.
Using it directly — rather than aggregating token counts inside the kernel — is what
makes the loss correct when a document spans SDP/SP ranks (numerator
`log_ratio_sum` is all-reduced; denominator is constant and available locally).

Constraint: each document must be visible to a single kernel call (modulo SDP/SP, where
Expand All @@ -493,17 +494,17 @@ def fused_gspo_loss_forward_backward(
# Per-token weight: mask / per-document label count, from the preprocessor.
# Each labeled token contributes `1 / N_d` so all of doc d's tokens sum to 1 (across
# SDP/SP ranks too), regardless of how the doc is sharded.
token_weight = flat_mask / num_labels_in_seq.reshape(-1).to(log_ratio.dtype).clamp(min=1)
mean_token_weight = flat_mask / num_labels_in_seq.reshape(-1).to(log_ratio.dtype).clamp(min=1)
# Pre-divide the per-token contributions by the per-doc label count, then sum per segment.
# All tokens in a segment share the same N_d, so this is mathematically equivalent to
# `log_ratio_sum / N_d` but avoids any per-segment denominator extraction.
mean_log_ratio_per_segment = log_ratio.new_zeros(num_segments).index_add_(
0, flat_document_index, log_ratio.reshape(-1) * token_weight
0, flat_document_index, log_ratio.reshape(-1) * mean_token_weight
)
# Accumulate in `log_ratio.dtype` (fp32). Casting the product back to `advantages.dtype`
# before summing would round each token's contribution to a possibly-low input dtype.
mean_advantage_per_segment = log_ratio.new_zeros(num_segments).index_add_(
0, flat_document_index, advantages.reshape(-1).to(log_ratio.dtype) * token_weight
0, flat_document_index, advantages.reshape(-1).to(log_ratio.dtype) * mean_token_weight
)
for reduce_group in (sdp_group, sp_group):
if reduce_group is not None:
Expand All @@ -519,13 +520,13 @@ def fused_gspo_loss_forward_backward(

probability_ratio = segment_ratio[flat_document_index].reshape(log_ratio.shape)
advantage_per_token = segment_advantage[flat_document_index].reshape(log_ratio.shape)
token_weight = token_weight.reshape(log_ratio.shape)
loss_weight = loss_mask.to(log_ratio.dtype)

losses = -torch.min(
probability_ratio * advantage_per_token,
torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantage_per_token,
)
loss = (losses * token_weight).sum() / divisor
loss = (losses * loss_weight).sum() / divisor

new_logprobs_mean = (new_log_probs * loss_mask / num_labels_in_seq.clamp(min=1)).sum()

Expand All @@ -536,7 +537,7 @@ def fused_gspo_loss_forward_backward(
torch.clamp_min(advantage_per_token, 0) * (probability_ratio <= 1 + epsilon_high)
+ torch.clamp_max(advantage_per_token, 0) * (probability_ratio >= 1 - epsilon_low)
)
* token_weight
* loss_weight
)
predicted_probabilities = exp_logits / sum_exp_logits.unsqueeze_(-1)
grad = (probability_ratio_grad * probability_ratio).unsqueeze(-1) * predicted_probabilities.scatter_add(
Expand Down
61 changes: 45 additions & 16 deletions tests/layers/test_lm_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class LMHeadTestConfig:
prediction_heads: int = 1
tied_embedding_weight: bool = False
num_splits: int = 1
gspo_document_lengths: tuple[int, ...] | None = None

@property
def actual_label_loss(self):
Expand Down Expand Up @@ -97,6 +98,15 @@ def get_config(self) -> GPTModelConfig:
},
)

@property
def actual_gspo_document_lengths(self) -> tuple[int, ...]:
    """Return the per-document token lengths for the GSPO test layout.

    When `gspo_document_lengths` is unset, the tokens are divided evenly into
    `GSPO_NUM_DOCUMENTS` documents. Otherwise the configured lengths are returned as-is.
    In both cases the lengths are asserted to add up to exactly `NUM_TOKENS`.
    """
    configured_lengths = self.gspo_document_lengths
    if configured_lengths is None:
        # Default layout: equal-sized documents; the split must leave no remainder.
        tokens_per_document = NUM_TOKENS // GSPO_NUM_DOCUMENTS
        Assert.eq(tokens_per_document * GSPO_NUM_DOCUMENTS, NUM_TOKENS)
        return (tokens_per_document,) * GSPO_NUM_DOCUMENTS
    # Explicit layout: the documents must cover the token dimension exactly.
    Assert.eq(sum(configured_lengths), NUM_TOKENS)
    return configured_lengths

def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
device = "cuda" if torch.cuda.is_available() else "cpu"
input_ = torch.randn(
Expand Down Expand Up @@ -159,23 +169,32 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
for labels_ in kwargs[LanguageModelKwargs.labels]
]
kwargs[LanguageModelKwargs.num_documents_in_batch] = (
GSPO_NUM_DOCUMENTS if self.gspo_loss is not False else 1
len(self.actual_gspo_document_lengths) if self.gspo_loss is not False else 1
)
if self.gspo_loss is not False:
document_length = NUM_TOKENS // GSPO_NUM_DOCUMENTS
document_lengths = self.actual_gspo_document_lengths
document_length_repeats = torch.tensor(document_lengths, dtype=torch.int64, device=device)
kwargs[BlockKwargs.global_document_index_q] = torch.repeat_interleave(
torch.arange(1, GSPO_NUM_DOCUMENTS + 1, dtype=torch.int32, device=device), document_length
torch.arange(1, len(document_lengths) + 1, dtype=torch.int32, device=device), document_length_repeats
)
kwargs[BlockKwargs.num_documents_in_sequence] = GSPO_NUM_DOCUMENTS
kwargs[BlockKwargs.lengths] = [document_length] * GSPO_NUM_DOCUMENTS
kwargs[BlockKwargs.num_documents_in_sequence] = len(document_lengths)
kwargs[BlockKwargs.lengths] = list(document_lengths)
# Override label_counts: per-token broadcast of the containing document's masked-label count
# (the kernel's per-document `new_logprobs` aggregation depends on this).
kwargs[LanguageModelLossKwargs.label_counts] = [
(labels_ >= 0)
.view(GSPO_NUM_DOCUMENTS, document_length)
.sum(-1)
.to(torch.float32)
.repeat_interleave(document_length)
torch.cat(
[
torch.full(
(document_length,),
float(document_mask.sum()),
dtype=torch.float32,
device=device,
)
for document_mask, document_length in zip(
torch.split(labels_ >= 0, document_lengths), document_lengths, strict=True
)
]
)
for labels_ in kwargs[LanguageModelKwargs.labels]
]
return input_, kwargs
Expand Down Expand Up @@ -258,18 +277,21 @@ def get_reference_outputs(
)
# Average over documents of per-document mean log-prob — matches the kernel's
# `sum_t logprob_t * mask_t / label_count_t` divided by `num_documents_in_batch`.
document_length = NUM_TOKENS // GSPO_NUM_DOCUMENTS
target_log_probabilities = (
torch.nn.functional.log_softmax(logits, -1)
.gather(-1, (labels * (labels >= 0)).unsqueeze(-1))
.squeeze(-1)
)
label_mask = (labels >= 0).to(target_log_probabilities.dtype)
logprob_sums_per_document = (
(target_log_probabilities * label_mask).view(GSPO_NUM_DOCUMENTS, document_length).sum(-1)
)
label_counts_per_document = label_mask.view(GSPO_NUM_DOCUMENTS, document_length).sum(-1).clamp(min=1)
new_logprobs = (logprob_sums_per_document / label_counts_per_document).mean()
document_means = [
(document_log_probabilities * document_label_mask).sum() / document_label_mask.sum().clamp(min=1)
for document_log_probabilities, document_label_mask in zip(
torch.split(target_log_probabilities, self.actual_gspo_document_lengths),
torch.split(label_mask, self.actual_gspo_document_lengths),
strict=True,
)
]
new_logprobs = torch.stack(document_means).mean()
names_losses_weights.append(("gspo_loss", gspo_loss, float(self.gspo_loss)))
names_losses_weights.append(("gspo_loss_new_logprobs", new_logprobs, 0.0))

Expand Down Expand Up @@ -327,6 +349,13 @@ def _add_configs(base_name: str, **kwargs):
loss_masking=loss_masking,
)
)
# GSPO case whose documents have unequal lengths. `actual_gspo_document_lengths` asserts that
# these lengths sum to NUM_TOKENS, so update them together if NUM_TOKENS changes.
_lm_head_test_configs.append(
LMHeadTestConfig(
"gspo_loss_uneven_documents",
gspo_loss=True,
gspo_document_lengths=(17, 31, 58, 94),
)
)
_add_configs("label_and_distillation_loss", label_loss=True, distillation_loss=True)
_add_configs("label_and_z_loss_weighted", label_loss=True, z_loss=0.5)
_add_configs("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0)
Expand Down
2 changes: 1 addition & 1 deletion tests/layers/test_lm_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def reference_gspo_loss(
ratio = (flat_log_ratio[in_segment].sum() / count_float).exp()
advantage = (flat_advantages[in_segment].sum() / count_float).detach()
clipped_ratio = ratio.clamp(1 - epsilon_low, 1 + epsilon_high)
total = total + -torch.minimum(ratio * advantage, clipped_ratio * advantage)
total = total + -torch.minimum(ratio * advantage, clipped_ratio * advantage) * count_float
# Matches the kernel's `sum_t logprob_t * mask_t / N_d` — sum of per-document mean logprobs.
new_logprobs = new_logprobs + target_log_probabilities.reshape(-1)[in_segment].sum() / count_float
total = total / num_segments
Expand Down
Loading