From ca9c3275a2c294000c764478fc951c91a1feea76 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 22 Oct 2025 14:02:09 +0200 Subject: [PATCH] [chronos-2] add support for SDPA (#331) This pull request introduces configurable attention backends to the Chronos-2 model, allowing users to select between eager, SDPA, and FlashAttention-2 implementations. --------- Co-authored-by: Oleksandr Shchur Co-authored-by: Abdul Fatir --- scripts/training/train.py | 1 - src/chronos/chronos2/config.py | 13 ++- src/chronos/chronos2/layers.py | 74 +++++++++++++++-- src/chronos/chronos2/model.py | 1 + src/chronos/chronos2/pipeline.py | 1 - test/test_chronos2.py | 137 +++++++++++++++++++++++++++++-- 6 files changed, 212 insertions(+), 15 deletions(-) diff --git a/scripts/training/train.py b/scripts/training/train.py index c16092e..09d5d8e 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -663,7 +663,6 @@ def main( lr_scheduler_type=lr_scheduler_type, warmup_ratio=warmup_ratio, optim=optim, - logging_dir=str(output_dir / "logs"), logging_strategy="steps", logging_steps=log_steps, save_strategy="steps", diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index f73fda5..c6e011c 100644 --- a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -4,7 +4,7 @@ # Authors: Abdul Fatir Ansari from dataclasses import dataclass -from typing import List +from typing import List, Literal from transformers.configuration_utils import PretrainedConfig @@ -39,6 +39,8 @@ class Chronos2CoreConfig(PretrainedConfig): Token ID for padding/missing value token, by default 0 rope_theta The base theta for rotary position embedding (RoPE), by default 10000.0 + attn_implementation + The attention implementation to use. Options: "eager" or "sdpa", by default None (uses "sdpa") """ model_type = "t5" @@ -63,6 +65,7 @@ class Chronos2CoreConfig(PretrainedConfig): vocab_size: int = 2, pad_token_id: int = 0, rope_theta: float = 10000.0, + attn_implementation: Literal["eager", "sdpa"] | None = None, **kwargs, ): self.vocab_size = vocab_size @@ -83,11 +86,17 @@ class Chronos2CoreConfig(PretrainedConfig): assert not self.is_gated_act, "gated activation is not supported" + # Attention implementation - default to "sdpa" if not specified + attn_implementation = attn_implementation or "sdpa" + assert attn_implementation in ["eager", "sdpa"], f"attn_implementation {attn_implementation} not supported" + # unused kwargs.pop("is_encoder_decoder", None) kwargs.pop("eos_token_id", None) - super().__init__(pad_token_id=pad_token_id, is_encoder_decoder=False, **kwargs) + super().__init__( + pad_token_id=pad_token_id, is_encoder_decoder=False, attn_implementation=attn_implementation, **kwargs + ) @dataclass diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index 2c4e6b3..b00e8a8 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -155,6 +155,7 @@ class MHA(nn.Module): self.n_heads: int = config.num_heads self.dropout: float = config.dropout_rate self.inner_dim: int = self.n_heads * self.kv_proj_dim + self.config = config self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -165,6 +166,64 @@ class MHA(nn.Module): if use_rope: self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta) + def _eager_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Eager attention implementation using manual matmul. + + Args: + query_states: [batch, n_heads, seq_len, kv_proj_dim] + key_states: [batch, n_heads, seq_len, kv_proj_dim] + value_states: [batch, n_heads, seq_len, kv_proj_dim] + mask: [batch, n_heads, q_len, kv_len] + + Returns: + attn_output: [batch, n_heads, seq_len, kv_proj_dim] + attn_weights: [batch, n_heads, q_len, kv_len] + """ + # Compute attention weights (no scaling - this is the original Chronos-2 implementation) + scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk" + scores += mask + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + return attn_output, attn_weights + + def _sdpa_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + mask: torch.Tensor, + ) -> tuple[torch.Tensor, None]: + """SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention. + + Args: + query_states: [batch, n_heads, seq_len, kv_proj_dim] + key_states: [batch, n_heads, seq_len, kv_proj_dim] + value_states: [batch, n_heads, seq_len, kv_proj_dim] + mask: [batch, n_heads, q_len, kv_len] - additive mask (0 for valid, -inf for invalid) + + Returns: + attn_output: [batch, n_heads, seq_len, kv_proj_dim] + attn_weights: None (SDPA doesn't return weights) + """ + attn_output = nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + scale=1.0, # Match eager implementation (no scaling) + ) + + return attn_output, None + def forward( self, hidden_states: torch.Tensor, @@ -190,6 +249,11 @@ class MHA(nn.Module): if self.use_rope: assert position_ids is not None, "position_ids must be provided when self.use_rope=True" + # Force eager attention if output_attentions is True (only eager returns weights) + attn_implementation = self.config._attn_implementation + if output_attentions: + attn_implementation = "eager" + seq_length = hidden_states.shape[1] def shape(states: torch.Tensor) -> torch.Tensor: @@ -215,12 +279,10 @@ class MHA(nn.Module): cos, sin = self.rope_embed(value_states, position_ids) query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin) - # Compute attention weights - scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk" - scores += mask - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + if attn_implementation == "sdpa": + attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) + else: # eager + attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask) # Project attention output attn_output = unshape(attn_output) diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index 06eb708..9b72f61 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -199,6 +199,7 @@ class Chronos2Model(PreTrainedModel): config_class = Chronos2CoreConfig # type: ignore[assignment] _supports_long_horizon: bool = True _supports_future_covariates: bool = True + _supports_sdpa: bool = True def __init__(self, config: Chronos2CoreConfig): assert hasattr(config, "chronos_config"), "Not a valid Chronos config" diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index e00d336..2250fb2 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -211,7 +211,6 @@ class Chronos2Pipeline(BaseChronosPipeline): lr_scheduler_type="linear", warmup_ratio=0.0, optim="adamw_torch_fused", - logging_dir=str(output_dir / "logs"), logging_strategy="steps", logging_steps=100, disable_tqdm=False, diff --git a/test/test_chronos2.py b/test/test_chronos2.py index 753a944..3c4281c 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -14,6 +14,9 @@ import torch from chronos import BaseChronosPipeline, Chronos2Pipeline from chronos.chronos2.dataset import convert_df_input_to_list_of_dicts_input +from chronos.chronos2.config import Chronos2CoreConfig +from chronos.chronos2.layers import MHA + from test.util import validate_tensor DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model" @@ -317,13 +320,11 @@ def test_when_input_is_invalid_then_predict_raises_value_error(pipeline, inputs, _ = pipeline.predict(inputs, prediction_length=10) -@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) -def test_pipeline_predict_can_handle_different_model_and_input_dtypes( - torch_dtype: torch.dtype, input_dtype: torch.dtype -): +def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: torch.dtype, input_dtype: torch.dtype): pipeline = BaseChronosPipeline.from_pretrained( - Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", torch_dtype=torch_dtype + Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", dtype=dtype ) context = 10 * torch.rand(size=(4, 3, 16)) + 10 context = context.to(dtype=input_dtype) @@ -936,3 +937,129 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future # Check predictions from the fine-tuned model are different from the original predictions assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy()) + + +@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) +def test_pipeline_works_with_different_attention_implementations(attn_implementation): + """Test that the pipeline works with different attention implementations.""" + # Load the dummy model + model_path = Path(__file__).parent / "dummy-chronos2-model" + + # Load with specified attention implementation + pipeline = BaseChronosPipeline.from_pretrained( + model_path, device_map="cpu", attn_implementation=attn_implementation + ) + + # Verify the config has the correct attention implementation + assert pipeline.model.config._attn_implementation == attn_implementation + + # Test prediction with simple input + inputs = torch.rand(2, 1, 16) + prediction_length = 7 + + outputs = pipeline.predict(inputs, prediction_length=prediction_length) + + # Check outputs are valid + assert isinstance(outputs, list) and len(outputs) == 2 + for out in outputs: + validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 7), dtype=torch.float32) + + +@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) +@pytest.mark.parametrize("output_attentions", [False, True]) +def test_attention_implementations_with_output_attentions(attn_implementation, output_attentions): + """Test that attention implementations handle output_attentions correctly.""" + # Create config with specified attention implementation + config = Chronos2CoreConfig( + d_model=128, + d_kv=32, + num_heads=4, + dropout_rate=0.1, + attn_implementation=attn_implementation, + ) + + # Create MHA layer + mha = MHA(config, use_rope=True) + mha.eval() + + # Create dummy inputs + batch_size = 2 + seq_len = 10 + hidden_states = torch.randn(batch_size, seq_len, config.d_model) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len) + + # Test forward pass + output = mha( + hidden_states=hidden_states, + mask=mask, + position_ids=position_ids, + output_attentions=output_attentions, + ) + + # Check output shape + assert output.hidden_states.shape == (batch_size, seq_len, config.d_model) + + # Check attention weights - should only be returned when output_attentions=True + if output_attentions: + assert output.attn_weights is not None + assert output.attn_weights.shape == (batch_size, config.num_heads, seq_len, seq_len) + else: + # SDPA doesn't return weights + if attn_implementation == "sdpa": + assert output.attn_weights is None + + +def test_eager_and_sdpa_produce_identical_outputs(pipeline): + """Test that eager and SDPA implementations produce identical outputs on full pipeline.""" + # Reload pipeline with SDPA + model_path = Path(__file__).parent / "dummy-chronos2-model" + pipeline_sdpa = BaseChronosPipeline.from_pretrained( + model_path, device_map="cpu", attn_implementation="sdpa", dtype=torch.float32 + ) + + # Note: the original pipeline fixture uses default attn_implementation which should be sdpa + # Force eager for comparison + pipeline_eager = BaseChronosPipeline.from_pretrained( + model_path, device_map="cpu", attn_implementation="eager", dtype=torch.float32 + ) + + # Test 1: Simple univariate input + inputs_simple = torch.rand(2, 1, 16) + prediction_length = 7 + + with torch.no_grad(): + outputs_eager = pipeline_eager.predict(inputs_simple, prediction_length=prediction_length) + outputs_sdpa = pipeline_sdpa.predict(inputs_simple, prediction_length=prediction_length) + + # Verify outputs match exactly + assert len(outputs_eager) == len(outputs_sdpa) + for out_eager, out_sdpa in zip(outputs_eager, outputs_sdpa): + # Should match exactly or very close (numerical precision) + assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) + + # Test 2: Multivariate inputs with covariates to test group attention + inputs_grouped = [ + { + "target": np.random.randn(2, 36), + "past_covariates": { + "temperature": np.random.randn(36), + "weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=36), + }, + "future_covariates": { + "temperature": np.random.randn(prediction_length), + "weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=prediction_length), + }, + } + for _ in range(5) + ] + + with torch.no_grad(): + outputs_eager_grouped = pipeline_eager.predict(inputs_grouped, prediction_length=prediction_length) + outputs_sdpa_grouped = pipeline_sdpa.predict(inputs_grouped, prediction_length=prediction_length) + + # Verify outputs match for grouped inputs + assert len(outputs_eager_grouped) == len(outputs_sdpa_grouped) + for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped): + # Should match exactly or very close (numerical precision) + assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)