[chronos-2] add support for SDPA (#331)

This pull request introduces configurable attention backends to the
Chronos-2 model, allowing users to select between eager, SDPA, and
FlashAttention-2 implementations.

---------

Co-authored-by: Oleksandr Shchur <oleks.shchur@gmail.com>
Co-authored-by: Abdul Fatir <Abdulfatirs@gmail.com>
This commit is contained in:
Kashif Rasul 2025-10-22 14:02:09 +02:00 committed by GitHub
parent 0c51188db7
commit ca9c3275a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 212 additions and 15 deletions

View file

@ -663,7 +663,6 @@ def main(
lr_scheduler_type=lr_scheduler_type,
warmup_ratio=warmup_ratio,
optim=optim,
logging_dir=str(output_dir / "logs"),
logging_strategy="steps",
logging_steps=log_steps,
save_strategy="steps",

View file

@ -4,7 +4,7 @@
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>
from dataclasses import dataclass
from typing import List
from typing import List, Literal
from transformers.configuration_utils import PretrainedConfig
@ -39,6 +39,8 @@ class Chronos2CoreConfig(PretrainedConfig):
Token ID for padding/missing value token, by default 0
rope_theta
The base theta for rotary position embedding (RoPE), by default 10000.0
attn_implementation
The attention implementation to use. Options: "eager" or "sdpa", by default None (uses "sdpa")
"""
model_type = "t5"
@ -63,6 +65,7 @@ class Chronos2CoreConfig(PretrainedConfig):
vocab_size: int = 2,
pad_token_id: int = 0,
rope_theta: float = 10000.0,
attn_implementation: Literal["eager", "sdpa"] | None = None,
**kwargs,
):
self.vocab_size = vocab_size
@ -83,11 +86,17 @@ class Chronos2CoreConfig(PretrainedConfig):
assert not self.is_gated_act, "gated activation is not supported"
# Attention implementation - default to "sdpa" if not specified
attn_implementation = attn_implementation or "sdpa"
assert attn_implementation in ["eager", "sdpa"], f"attn_implementation {attn_implementation} not supported"
# unused
kwargs.pop("is_encoder_decoder", None)
kwargs.pop("eos_token_id", None)
super().__init__(pad_token_id=pad_token_id, is_encoder_decoder=False, **kwargs)
super().__init__(
pad_token_id=pad_token_id, is_encoder_decoder=False, attn_implementation=attn_implementation, **kwargs
)
@dataclass

View file

@ -155,6 +155,7 @@ class MHA(nn.Module):
self.n_heads: int = config.num_heads
self.dropout: float = config.dropout_rate
self.inner_dim: int = self.n_heads * self.kv_proj_dim
self.config = config
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
@ -165,6 +166,64 @@ class MHA(nn.Module):
if use_rope:
self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta)
def _eager_attention(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Eager attention implementation using manual matmul.
Args:
query_states: [batch, n_heads, seq_len, kv_proj_dim]
key_states: [batch, n_heads, seq_len, kv_proj_dim]
value_states: [batch, n_heads, seq_len, kv_proj_dim]
mask: [batch, n_heads, q_len, kv_len]
Returns:
attn_output: [batch, n_heads, seq_len, kv_proj_dim]
attn_weights: [batch, n_heads, q_len, kv_len]
"""
# Compute attention weights (no scaling - this is the original Chronos-2 implementation)
scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk"
scores += mask
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
return attn_output, attn_weights
def _sdpa_attention(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
mask: torch.Tensor,
) -> tuple[torch.Tensor, None]:
"""SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention.
Args:
query_states: [batch, n_heads, seq_len, kv_proj_dim]
key_states: [batch, n_heads, seq_len, kv_proj_dim]
value_states: [batch, n_heads, seq_len, kv_proj_dim]
mask: [batch, n_heads, q_len, kv_len] - additive mask (0 for valid, -inf for invalid)
Returns:
attn_output: [batch, n_heads, seq_len, kv_proj_dim]
attn_weights: None (SDPA doesn't return weights)
"""
attn_output = nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=mask,
dropout_p=self.dropout if self.training else 0.0,
scale=1.0, # Match eager implementation (no scaling)
)
return attn_output, None
def forward(
self,
hidden_states: torch.Tensor,
@ -190,6 +249,11 @@ class MHA(nn.Module):
if self.use_rope:
assert position_ids is not None, "position_ids must be provided when self.use_rope=True"
# Force eager attention if output_attentions is True (only eager returns weights)
attn_implementation = self.config._attn_implementation
if output_attentions:
attn_implementation = "eager"
seq_length = hidden_states.shape[1]
def shape(states: torch.Tensor) -> torch.Tensor:
@ -215,12 +279,10 @@ class MHA(nn.Module):
cos, sin = self.rope_embed(value_states, position_ids)
query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Compute attention weights
scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk"
scores += mask
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_implementation == "sdpa":
attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)
else: # eager
attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask)
# Project attention output
attn_output = unshape(attn_output)

View file

@ -199,6 +199,7 @@ class Chronos2Model(PreTrainedModel):
config_class = Chronos2CoreConfig # type: ignore[assignment]
_supports_long_horizon: bool = True
_supports_future_covariates: bool = True
_supports_sdpa: bool = True
def __init__(self, config: Chronos2CoreConfig):
assert hasattr(config, "chronos_config"), "Not a valid Chronos config"

View file

@ -211,7 +211,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
lr_scheduler_type="linear",
warmup_ratio=0.0,
optim="adamw_torch_fused",
logging_dir=str(output_dir / "logs"),
logging_strategy="steps",
logging_steps=100,
disable_tqdm=False,

View file

@ -14,6 +14,9 @@ import torch
from chronos import BaseChronosPipeline, Chronos2Pipeline
from chronos.chronos2.dataset import convert_df_input_to_list_of_dicts_input
from chronos.chronos2.config import Chronos2CoreConfig
from chronos.chronos2.layers import MHA
from test.util import validate_tensor
DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model"
@ -317,13 +320,11 @@ def test_when_input_is_invalid_then_predict_raises_value_error(pipeline, inputs,
_ = pipeline.predict(inputs, prediction_length=10)
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
def test_pipeline_predict_can_handle_different_model_and_input_dtypes(
torch_dtype: torch.dtype, input_dtype: torch.dtype
):
def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: torch.dtype, input_dtype: torch.dtype):
pipeline = BaseChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", torch_dtype=torch_dtype
Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", dtype=dtype
)
context = 10 * torch.rand(size=(4, 3, 16)) + 10
context = context.to(dtype=input_dtype)
@ -936,3 +937,129 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future
# Check predictions from the fine-tuned model are different from the original predictions
assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy())
@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
def test_pipeline_works_with_different_attention_implementations(attn_implementation):
"""Test that the pipeline works with different attention implementations."""
# Load the dummy model
model_path = Path(__file__).parent / "dummy-chronos2-model"
# Load with specified attention implementation
pipeline = BaseChronosPipeline.from_pretrained(
model_path, device_map="cpu", attn_implementation=attn_implementation
)
# Verify the config has the correct attention implementation
assert pipeline.model.config._attn_implementation == attn_implementation
# Test prediction with simple input
inputs = torch.rand(2, 1, 16)
prediction_length = 7
outputs = pipeline.predict(inputs, prediction_length=prediction_length)
# Check outputs are valid
assert isinstance(outputs, list) and len(outputs) == 2
for out in outputs:
validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 7), dtype=torch.float32)
@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
@pytest.mark.parametrize("output_attentions", [False, True])
def test_attention_implementations_with_output_attentions(attn_implementation, output_attentions):
"""Test that attention implementations handle output_attentions correctly."""
# Create config with specified attention implementation
config = Chronos2CoreConfig(
d_model=128,
d_kv=32,
num_heads=4,
dropout_rate=0.1,
attn_implementation=attn_implementation,
)
# Create MHA layer
mha = MHA(config, use_rope=True)
mha.eval()
# Create dummy inputs
batch_size = 2
seq_len = 10
hidden_states = torch.randn(batch_size, seq_len, config.d_model)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len)
# Test forward pass
output = mha(
hidden_states=hidden_states,
mask=mask,
position_ids=position_ids,
output_attentions=output_attentions,
)
# Check output shape
assert output.hidden_states.shape == (batch_size, seq_len, config.d_model)
# Check attention weights - should only be returned when output_attentions=True
if output_attentions:
assert output.attn_weights is not None
assert output.attn_weights.shape == (batch_size, config.num_heads, seq_len, seq_len)
else:
# SDPA doesn't return weights
if attn_implementation == "sdpa":
assert output.attn_weights is None
def test_eager_and_sdpa_produce_identical_outputs(pipeline):
"""Test that eager and SDPA implementations produce identical outputs on full pipeline."""
# Reload pipeline with SDPA
model_path = Path(__file__).parent / "dummy-chronos2-model"
pipeline_sdpa = BaseChronosPipeline.from_pretrained(
model_path, device_map="cpu", attn_implementation="sdpa", dtype=torch.float32
)
# Note: the original pipeline fixture uses default attn_implementation which should be sdpa
# Force eager for comparison
pipeline_eager = BaseChronosPipeline.from_pretrained(
model_path, device_map="cpu", attn_implementation="eager", dtype=torch.float32
)
# Test 1: Simple univariate input
inputs_simple = torch.rand(2, 1, 16)
prediction_length = 7
with torch.no_grad():
outputs_eager = pipeline_eager.predict(inputs_simple, prediction_length=prediction_length)
outputs_sdpa = pipeline_sdpa.predict(inputs_simple, prediction_length=prediction_length)
# Verify outputs match exactly
assert len(outputs_eager) == len(outputs_sdpa)
for out_eager, out_sdpa in zip(outputs_eager, outputs_sdpa):
# Should match exactly or very close (numerical precision)
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
# Test 2: Multivariate inputs with covariates to test group attention
inputs_grouped = [
{
"target": np.random.randn(2, 36),
"past_covariates": {
"temperature": np.random.randn(36),
"weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=36),
},
"future_covariates": {
"temperature": np.random.randn(prediction_length),
"weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=prediction_length),
},
}
for _ in range(5)
]
with torch.no_grad():
outputs_eager_grouped = pipeline_eager.predict(inputs_grouped, prediction_length=prediction_length)
outputs_sdpa_grouped = pipeline_sdpa.predict(inputs_grouped, prediction_length=prediction_length)
# Verify outputs match for grouped inputs
assert len(outputs_eager_grouped) == len(outputs_sdpa_grouped)
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
# Should match exactly or very close (numerical precision)
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)