mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Relax transformers lower bound to >=4.41 (#364)
This commit is contained in:
parent
c23d34cd88
commit
93419cfe9f
2 changed files with 4 additions and 4 deletions
|
|
@ -15,7 +15,7 @@ license = { file = "LICENSE" }
|
|||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.0,<3",
|
||||
"transformers>=4.49,<5",
|
||||
"transformers>=4.41,<5",
|
||||
"accelerate>=0.34,<2",
|
||||
"numpy>=1.21,<3",
|
||||
"einops>=0.7.0,<1",
|
||||
|
|
|
|||
|
|
@ -327,7 +327,7 @@ def test_when_input_is_invalid_then_predict_raises_value_error(pipeline, inputs,
|
|||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = BaseChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", dtype=dtype
|
||||
Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", torch_dtype=dtype
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 3, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
|
@ -1018,13 +1018,13 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
|
|||
# Reload pipeline with SDPA
|
||||
model_path = Path(__file__).parent / "dummy-chronos2-model"
|
||||
pipeline_sdpa = BaseChronosPipeline.from_pretrained(
|
||||
model_path, device_map="cpu", attn_implementation="sdpa", dtype=torch.float32
|
||||
model_path, device_map="cpu", attn_implementation="sdpa", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
# Note: the original pipeline fixture uses default attn_implementation which should be sdpa
|
||||
# Force eager for comparison
|
||||
pipeline_eager = BaseChronosPipeline.from_pretrained(
|
||||
model_path, device_map="cpu", attn_implementation="eager", dtype=torch.float32
|
||||
model_path, device_map="cpu", attn_implementation="eager", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
# Test 1: Simple univariate input
|
||||
|
|
|
|||
Loading…
Reference in a new issue