YaRN #445

Merged: rrutmann merged 9 commits into main from yarn_hf on Jun 2, 2026
Collaborator

@rrutmann rrutmann commented May 11, 2026

What does this PR do?

This PR adds YaRN support to rotary position embeddings in the GPT-2 attention path.

General Changes

  • Implemented YaRN parameterization in rotary embeddings in gpt2_model.py
  • Added/updated YaRN configuration in config_lorem_ipsum_long_fsdp2_yarn.yaml
  • Refactored and strengthened rotary tests in test_rotary_qkv_transform.py
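For readers unfamiliar with YaRN, the frequency blending it applies to rotary embeddings can be sketched as follows. This is a minimal, dependency-free illustration of the commonly described YaRN scheme, not the code in gpt2_model.py. The function names, the example values for `factor` and `original_max_pos`, and the floor/ceil truncation of the correction range are assumptions made for the sketch; only the defaults `beta_fast=32.0` and `beta_slow=1.0` appear in this PR's diff.

```python
import math


def correction_dim(num_rotations: float, dim: int, base: float, original_max_pos: int) -> float:
    """Index of the rotary dimension that completes `num_rotations` turns over the original context."""
    return dim * math.log(original_max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


def yarn_inv_freq(
    dim: int,
    base: float = 10000.0,
    factor: float = 4.0,            # assumed context extension factor
    original_max_pos: int = 2048,   # assumed pre-training context length
    beta_fast: float = 32.0,
    beta_slow: float = 1.0,
) -> list[float]:
    """Blend extrapolated (unscaled) and interpolated (scaled) inverse frequencies."""
    low = max(math.floor(correction_dim(beta_fast, dim, base, original_max_pos)), 0)
    high = min(math.ceil(correction_dim(beta_slow, dim, base, original_max_pos)), dim - 1)
    if low == high:
        high += 0.001  # avoid division by zero in the ramp

    inv_freq = []
    for i in range(dim // 2):
        extrapolated = base ** (-2 * i / dim)   # plain RoPE frequency
        interpolated = extrapolated / factor    # position-interpolated frequency
        # Ramp is 0 for high-frequency dims (keep RoPE) and 1 for low-frequency dims (interpolate).
        ramp = min(max((i - low) / (high - low), 0.0), 1.0)
        inv_freq.append(interpolated * ramp + extrapolated * (1.0 - ramp))
    return inv_freq


freqs = yarn_inv_freq(dim=64)
```

With these example values, the highest-frequency dimension keeps its original RoPE frequency, while the lowest-frequency dimension is divided by `factor`; dimensions in between are mixed linearly.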

Breaking Changes

  • ..

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

rrutmann and others added 5 commits May 11, 2026 12:49
Co-authored-by: Copilot <copilot@github.com>
@rrutmann rrutmann requested a review from le1nux May 12, 2026 12:28
@rrutmann rrutmann self-assigned this May 12, 2026
@le1nux le1nux requested a review from BlueCrescent May 13, 2026 10:06
@rrutmann rrutmann requested review from therealdavidos and removed request for BlueCrescent May 20, 2026 08:32
Collaborator

@therealdavidos therealdavidos left a comment

Looks good from the math perspective!

Collaborator

Is this change related to the YaRN PR?

Collaborator Author

No, it isn't related to YaRN. I noticed this test was failing and fixed it. I can open a separate PR for the test fix if needed, but I'm not sure it's worth the extra overhead.

Collaborator

Nah, that's fine. Thanks for the clarification.


self.reset_parameters()

def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]:
Member

Please place private methods below the public interface of the class.

Collaborator Author

I addressed this in e12db1a

seq_length_dim: Annotated[int, Field(strict=True)]
base_freq: Annotated[int, Field(strict=True, ge=10000)]
max_position_embeddings: Optional[Annotated[int, Field(strict=True, ge=1)]] = None
rope_scaling: Optional[dict[str, object]] = None
Member

Does this play nicely with our config setup? Would it be possible to have something like "rope_scaling: YarnSettings | DefaultSettingsIfExists | SomeFutureRopeScalingSettings | None = None" with the Settings being BaseModels themselves?

Collaborator Author

Good point. I added Configs based on BaseModel in b91762a
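The typed-settings idea discussed above can be illustrated with a small sketch. It uses stdlib dataclasses as a dependency-free stand-in for pydantic BaseModel classes, so it is not the code added in b91762a. The class names, the field names `factor` and `original_max_position_embeddings`, and the `type` key used to select a settings class are all hypothetical; the `beta_fast` and `beta_slow` defaults are the ones visible in the diff.

```python
from dataclasses import dataclass, fields
from typing import Mapping, Optional, Union


@dataclass(frozen=True)
class YarnSettings:  # hypothetical name, taken from the review suggestion
    factor: float
    original_max_position_embeddings: int
    beta_fast: float = 32.0
    beta_slow: float = 1.0


@dataclass(frozen=True)
class LinearSettings:  # hypothetical second variant, to show the union
    factor: float


RopeScaling = Union[YarnSettings, LinearSettings]
SETTINGS_BY_TYPE = {"yarn": YarnSettings, "linear": LinearSettings}


def parse_rope_scaling(raw: Optional[Mapping[str, object]]) -> Optional[RopeScaling]:
    """Turn a raw config mapping into a typed settings object, rejecting unknown input."""
    if raw is None:
        return None
    options = dict(raw)
    kind = options.pop("type", None)  # assumed discriminator key
    if kind not in SETTINGS_BY_TYPE:
        raise ValueError(f"unknown rope_scaling type: {kind!r}")
    settings_cls = SETTINGS_BY_TYPE[kind]
    allowed = {f.name for f in fields(settings_cls)}
    unknown = set(options) - allowed
    if unknown:
        # A misspelled key fails loudly instead of being ignored.
        raise ValueError(f"unknown {kind} options: {sorted(unknown)}")
    return settings_cls(**options)


cfg = parse_rope_scaling(
    {"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 2048}
)
```

Compared with `dict[str, object]`, a typed variant per scaling method makes the accepted keys explicit, and a future scaling method only needs a new class and a new entry in the lookup table.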

Comment on lines +179 to +182
beta_fast_raw = self.rope_scaling.get("beta_fast")
beta_slow_raw = self.rope_scaling.get("beta_slow")
beta_fast = float(beta_fast_raw) if isinstance(beta_fast_raw, (int, float)) else 32.0
beta_slow = float(beta_slow_raw) if isinstance(beta_slow_raw, (int, float)) else 1.0
Member

I'm a bit worried that if these parameters are strings or torch types for some reason, they will be dropped silently here.

Collaborator Author

@rrutmann rrutmann Jun 2, 2026

I addressed this in 82019f1
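One way to avoid the silent fallback raised in this thread is to fall back to the default only when the key is absent, and to raise on any value of the wrong type. The sketch below shows that pattern; `read_float` is a hypothetical helper, and this is not necessarily how 82019f1 resolves the issue.

```python
from numbers import Real
from typing import Mapping


def read_float(settings: Mapping[str, object], key: str, default: float) -> float:
    """Return settings[key] as a float, using `default` only when the key is missing or None."""
    raw = settings.get(key)
    if raw is None:
        return default
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(raw, bool) or not isinstance(raw, Real):
        raise TypeError(f"{key} must be a real number, got {type(raw).__name__}: {raw!r}")
    return float(raw)


rope_scaling = {"beta_fast": 16}
beta_fast = read_float(rope_scaling, "beta_fast", 32.0)  # explicit value is used
beta_slow = read_float(rope_scaling, "beta_slow", 1.0)   # key absent, default is used
```

With this helper, a value such as `"16"` raises a `TypeError` instead of being replaced by `32.0` without notice.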

return 0.1 * mscale * math.log(scale) + 1.0

if attention_factor is None:
if isinstance(mscale, (int, float)) and isinstance(mscale_all_dim, (int, float)):
Member

I'm a bit worried that if these parameters are strings or torch types for some reason, they will be dropped silently here.

Collaborator Author

@rrutmann rrutmann Jun 2, 2026

I addressed this in 82019f1
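For context, the hunk under review derives the attention scaling factor from `mscale` and `mscale_all_dim`. A minimal sketch of that computation is shown below. Only the expression `0.1 * mscale * math.log(scale) + 1.0` and the two-parameter branch condition come from the diff; the `scale <= 1` guard, the ratio of the two mscale terms, and the function names are assumptions based on the usual YaRN formulation.

```python
import math
from typing import Optional


def get_mscale(scale: float, mscale: float = 1.0) -> float:
    """Attention temperature correction for a given context extension factor."""
    if scale <= 1.0:  # assumed guard: no correction without context extension
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def compute_attention_factor(
    factor: float,
    mscale: Optional[float] = None,
    mscale_all_dim: Optional[float] = None,
) -> float:
    """Use the ratio of both mscale terms when both are given, else the plain correction."""
    if mscale is not None and mscale_all_dim is not None:
        return get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)
    return get_mscale(factor)


attention_factor = compute_attention_factor(4.0)
```

Type validation is deliberately left out of this sketch; it assumes the inputs have already been checked upstream so that malformed values cannot reach the fallback branch unnoticed.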

@rrutmann rrutmann merged commit 7337fe4 into main Jun 2, 2026
3 checks passed