mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
[chronos-2] add support for SDPA (#331)
This pull request introduces configurable attention backends to the Chronos-2 model, allowing users to select between eager, SDPA, and FlashAttention-2 implementations. (Note: in the diff shown below, the configuration accepts only "eager" and "sdpa".)
---------
Co-authored-by: Oleksandr Shchur <oleks.shchur@gmail.com>
Co-authored-by: Abdul Fatir <Abdulfatirs@gmail.com>
This commit is contained in:
parent
0c51188db7
commit
ca9c3275a2
6 changed files with 212 additions and 15 deletions
|
|
@ -663,7 +663,6 @@ def main(
|
|||
lr_scheduler_type=lr_scheduler_type,
|
||||
warmup_ratio=warmup_ratio,
|
||||
optim=optim,
|
||||
logging_dir=str(output_dir / "logs"),
|
||||
logging_strategy="steps",
|
||||
logging_steps=log_steps,
|
||||
save_strategy="steps",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from typing import List, Literal
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
|
@ -39,6 +39,8 @@ class Chronos2CoreConfig(PretrainedConfig):
|
|||
Token ID for padding/missing value token, by default 0
|
||||
rope_theta
|
||||
The base theta for rotary position embedding (RoPE), by default 10000.0
|
||||
attn_implementation
|
||||
The attention implementation to use. Options: "eager" or "sdpa", by default None (uses "sdpa")
|
||||
"""
|
||||
|
||||
model_type = "t5"
|
||||
|
|
@ -63,6 +65,7 @@ class Chronos2CoreConfig(PretrainedConfig):
|
|||
vocab_size: int = 2,
|
||||
pad_token_id: int = 0,
|
||||
rope_theta: float = 10000.0,
|
||||
attn_implementation: Literal["eager", "sdpa"] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
|
|
@ -83,11 +86,17 @@ class Chronos2CoreConfig(PretrainedConfig):
|
|||
|
||||
assert not self.is_gated_act, "gated activation is not supported"
|
||||
|
||||
# Attention implementation - default to "sdpa" if not specified
|
||||
attn_implementation = attn_implementation or "sdpa"
|
||||
assert attn_implementation in ["eager", "sdpa"], f"attn_implementation {attn_implementation} not supported"
|
||||
|
||||
# unused
|
||||
kwargs.pop("is_encoder_decoder", None)
|
||||
kwargs.pop("eos_token_id", None)
|
||||
|
||||
super().__init__(pad_token_id=pad_token_id, is_encoder_decoder=False, **kwargs)
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id, is_encoder_decoder=False, attn_implementation=attn_implementation, **kwargs
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -155,6 +155,7 @@ class MHA(nn.Module):
|
|||
self.n_heads: int = config.num_heads
|
||||
self.dropout: float = config.dropout_rate
|
||||
self.inner_dim: int = self.n_heads * self.kv_proj_dim
|
||||
self.config = config
|
||||
|
||||
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
||||
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
||||
|
|
@ -165,6 +166,64 @@ class MHA(nn.Module):
|
|||
if use_rope:
|
||||
self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta)
|
||||
|
||||
def _eager_attention(
    self,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Eager attention implementation using manual matmul.

    Args:
        query_states: [batch, n_heads, seq_len, kv_proj_dim]
        key_states: [batch, n_heads, seq_len, kv_proj_dim]
        value_states: [batch, n_heads, seq_len, kv_proj_dim]
        mask: [batch, n_heads, q_len, kv_len], additive mask that is summed
            onto the raw scores before the softmax

    Returns:
        attn_output: [batch, n_heads, seq_len, kv_proj_dim]
        attn_weights: [batch, n_heads, q_len, kv_len]; these are the weights
            after dropout has been applied (identical to the softmax output
            when the module is not in training mode)
    """
    # Compute attention weights (no scaling - this is the original Chronos-2 implementation)
    scores = torch.matmul(query_states, key_states.transpose(3, 2))  # "bnqd,bnkd->bnqk"
    # In-place add: `scores` is a fresh tensor produced by the matmul above.
    scores += mask
    # The softmax is evaluated in float32 and cast back to the dtype of `scores`.
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    return attn_output, attn_weights
|
||||
|
||||
def _sdpa_attention(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
"""SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention.
|
||||
|
||||
Args:
|
||||
query_states: [batch, n_heads, seq_len, kv_proj_dim]
|
||||
key_states: [batch, n_heads, seq_len, kv_proj_dim]
|
||||
value_states: [batch, n_heads, seq_len, kv_proj_dim]
|
||||
mask: [batch, n_heads, q_len, kv_len] - additive mask (0 for valid, -inf for invalid)
|
||||
|
||||
Returns:
|
||||
attn_output: [batch, n_heads, seq_len, kv_proj_dim]
|
||||
attn_weights: None (SDPA doesn't return weights)
|
||||
"""
|
||||
attn_output = nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
scale=1.0, # Match eager implementation (no scaling)
|
||||
)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@ -190,6 +249,11 @@ class MHA(nn.Module):
|
|||
if self.use_rope:
|
||||
assert position_ids is not None, "position_ids must be provided when self.use_rope=True"
|
||||
|
||||
# Force eager attention if output_attentions is True (only eager returns weights)
|
||||
attn_implementation = self.config._attn_implementation
|
||||
if output_attentions:
|
||||
attn_implementation = "eager"
|
||||
|
||||
seq_length = hidden_states.shape[1]
|
||||
|
||||
def shape(states: torch.Tensor) -> torch.Tensor:
|
||||
|
|
@ -215,12 +279,10 @@ class MHA(nn.Module):
|
|||
cos, sin = self.rope_embed(value_states, position_ids)
|
||||
query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# Compute attention weights
|
||||
scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk"
|
||||
scores += mask
|
||||
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
if attn_implementation == "sdpa":
|
||||
attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)
|
||||
else: # eager
|
||||
attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask)
|
||||
|
||||
# Project attention output
|
||||
attn_output = unshape(attn_output)
|
||||
|
|
|
|||
|
|
@ -199,6 +199,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
config_class = Chronos2CoreConfig # type: ignore[assignment]
|
||||
_supports_long_horizon: bool = True
|
||||
_supports_future_covariates: bool = True
|
||||
_supports_sdpa: bool = True
|
||||
|
||||
def __init__(self, config: Chronos2CoreConfig):
|
||||
assert hasattr(config, "chronos_config"), "Not a valid Chronos config"
|
||||
|
|
|
|||
|
|
@ -211,7 +211,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
lr_scheduler_type="linear",
|
||||
warmup_ratio=0.0,
|
||||
optim="adamw_torch_fused",
|
||||
logging_dir=str(output_dir / "logs"),
|
||||
logging_strategy="steps",
|
||||
logging_steps=100,
|
||||
disable_tqdm=False,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ import torch
|
|||
|
||||
from chronos import BaseChronosPipeline, Chronos2Pipeline
|
||||
from chronos.chronos2.dataset import convert_df_input_to_list_of_dicts_input
|
||||
from chronos.chronos2.config import Chronos2CoreConfig
|
||||
from chronos.chronos2.layers import MHA
|
||||
|
||||
from test.util import validate_tensor
|
||||
|
||||
DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model"
|
||||
|
|
@ -317,13 +320,11 @@ def test_when_input_is_invalid_then_predict_raises_value_error(pipeline, inputs,
|
|||
_ = pipeline.predict(inputs, prediction_length=10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
def test_pipeline_predict_can_handle_different_model_and_input_dtypes(
|
||||
torch_dtype: torch.dtype, input_dtype: torch.dtype
|
||||
):
|
||||
def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = BaseChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", torch_dtype=torch_dtype
|
||||
Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", dtype=dtype
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 3, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
|
@ -936,3 +937,129 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future
|
|||
|
||||
# Check predictions from the fine-tuned model are different from the original predictions
|
||||
assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
def test_pipeline_works_with_different_attention_implementations(attn_implementation):
    """Test that the pipeline works with different attention implementations."""
    # Load the dummy model with the specified attention implementation
    model_path = Path(__file__).parent / "dummy-chronos2-model"
    pipeline = BaseChronosPipeline.from_pretrained(
        model_path, device_map="cpu", attn_implementation=attn_implementation
    )

    # Verify the config has the correct attention implementation
    assert pipeline.model.config._attn_implementation == attn_implementation

    # Test prediction with simple input
    num_series = 2
    prediction_length = 7
    inputs = torch.rand(num_series, 1, 16)

    outputs = pipeline.predict(inputs, prediction_length=prediction_length)

    # Check outputs are valid: one tensor per input series, and the last dimension
    # follows the requested prediction length rather than a repeated literal.
    assert isinstance(outputs, list) and len(outputs) == num_series
    for out in outputs:
        validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
@pytest.mark.parametrize("output_attentions", [False, True])
def test_attention_implementations_with_output_attentions(attn_implementation, output_attentions):
    """Test that attention implementations handle output_attentions correctly."""
    # Create config with specified attention implementation
    config = Chronos2CoreConfig(
        d_model=128,
        d_kv=32,
        num_heads=4,
        dropout_rate=0.1,
        attn_implementation=attn_implementation,
    )

    # Create MHA layer; eval mode so that dropout is inactive
    mha = MHA(config, use_rope=True)
    mha.eval()

    # Create dummy inputs (an all-zero additive mask leaves every position visible)
    batch_size = 2
    seq_len = 10
    hidden_states = torch.randn(batch_size, seq_len, config.d_model)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
    mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len)

    # Test forward pass
    output = mha(
        hidden_states=hidden_states,
        mask=mask,
        position_ids=position_ids,
        output_attentions=output_attentions,
    )

    # Check output shape
    assert output.hidden_states.shape == (batch_size, seq_len, config.d_model)

    if output_attentions:
        # Weights must be available for both implementations when requested
        assert output.attn_weights is not None
        assert output.attn_weights.shape == (batch_size, config.num_heads, seq_len, seq_len)
    elif attn_implementation == "sdpa":
        # SDPA doesn't return weights. Nothing is asserted for eager with
        # output_attentions=False.
        assert output.attn_weights is None
|
||||
|
||||
|
||||
def test_eager_and_sdpa_produce_identical_outputs(pipeline):
    """Test that eager and SDPA implementations produce matching outputs on the full pipeline.

    Both pipelines are loaded explicitly from the dummy model in float32, so the
    ``pipeline`` fixture argument is not used for the comparison itself. Outputs
    are compared with ``torch.allclose`` (atol=1e-5, rtol=1e-4), i.e. up to
    numerical precision rather than bit-for-bit.
    """
    model_path = Path(__file__).parent / "dummy-chronos2-model"
    pipeline_sdpa = BaseChronosPipeline.from_pretrained(
        model_path, device_map="cpu", attn_implementation="sdpa", dtype=torch.float32
    )
    pipeline_eager = BaseChronosPipeline.from_pretrained(
        model_path, device_map="cpu", attn_implementation="eager", dtype=torch.float32
    )

    # Locally seeded generators keep the inputs reproducible without touching
    # the global RNG state shared with other tests.
    torch_gen = torch.Generator().manual_seed(0)
    rng = np.random.default_rng(0)
    weather_types = ["sunny", "cloudy", "rainy"]
    context_length = 36
    prediction_length = 7

    # Case 1: simple univariate input
    inputs_simple = torch.rand(2, 1, 16, generator=torch_gen)

    # Case 2: multivariate inputs with covariates to test group attention
    inputs_grouped = [
        {
            "target": rng.standard_normal((2, context_length)),
            "past_covariates": {
                "temperature": rng.standard_normal(context_length),
                "weather_type": rng.choice(weather_types, size=context_length),
            },
            "future_covariates": {
                "temperature": rng.standard_normal(prediction_length),
                "weather_type": rng.choice(weather_types, size=prediction_length),
            },
        }
        for _ in range(5)
    ]

    for inputs in (inputs_simple, inputs_grouped):
        with torch.no_grad():
            outputs_eager = pipeline_eager.predict(inputs, prediction_length=prediction_length)
            outputs_sdpa = pipeline_sdpa.predict(inputs, prediction_length=prediction_length)

        # Same number of outputs, each pair equal up to numerical precision
        assert len(outputs_eager) == len(outputs_sdpa)
        for out_eager, out_sdpa in zip(outputs_eager, outputs_sdpa):
            assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
|
||||
|
|
|
|||
Loading…
Reference in a new issue