mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Return predictions in fp32 on CPU (#219)
*Issue #, if available:* N/A *Description of changes:* This PR ensures that predictions are always returned in FP32 and on the CPU device, instead of being cast back to the dtype and device of the input. This is now the better choice because we have two types of models that produce different types of forecasts (samples vs. quantiles). Furthermore, inputs of dtype `int64` (our README example is one such case) previously ran into issues with `predict_quantiles`; this change fixes that as well. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
parent
c887278706
commit
4c43cfbdac
7 changed files with 85 additions and 56 deletions
|
|
@ -67,7 +67,8 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Get forecasts for the given time series.
|
||||
Get forecasts for the given time series. Predictions will be
|
||||
returned in fp32 on the cpu.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -97,6 +98,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get quantile and mean forecasts for given time series.
|
||||
Predictions will be returned in fp32 on the cpu.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
|||
|
|
@ -500,9 +500,6 @@ class ChronosPipeline(BaseChronosPipeline):
|
|||
raise ValueError(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
input_dtype = context_tensor.dtype
|
||||
input_device = context_tensor.device
|
||||
|
||||
predictions = []
|
||||
remaining = prediction_length
|
||||
|
||||
|
|
@ -533,7 +530,7 @@ class ChronosPipeline(BaseChronosPipeline):
|
|||
[context_tensor, prediction.median(dim=1).values], dim=-1
|
||||
)
|
||||
|
||||
return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device)
|
||||
return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device="cpu")
|
||||
|
||||
def predict_quantiles(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -487,13 +487,14 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
# TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
|
||||
# horizon that the model was trained with (i.e., 64). This results in variance collapsing
|
||||
# every 64 steps.
|
||||
context_tensor = context_tensor.to(
|
||||
device=self.model.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
while remaining > 0:
|
||||
with torch.no_grad():
|
||||
prediction = self.model(
|
||||
context=context_tensor.to(
|
||||
device=self.model.device,
|
||||
dtype=torch.float32, # scaling should be done in 32-bit precision
|
||||
),
|
||||
context=context_tensor,
|
||||
).quantile_preds.to(context_tensor)
|
||||
|
||||
predictions.append(prediction)
|
||||
|
|
@ -507,7 +508,9 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
|
||||
context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
|
||||
|
||||
return torch.cat(predictions, dim=-1)[..., :prediction_length]
|
||||
return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
|
||||
dtype=torch.float32, device="cpu"
|
||||
)
|
||||
|
||||
def predict_quantiles(
|
||||
self,
|
||||
|
|
|
|||
2
test/__init__.py
Normal file
2
test/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
|
@ -13,6 +12,7 @@ from chronos import (
|
|||
ChronosPipeline,
|
||||
MeanScaleUniformBins,
|
||||
)
|
||||
from test.util import validate_tensor
|
||||
|
||||
|
||||
def test_base_chronos_pipeline_loads_from_huggingface():
|
||||
|
|
@ -166,30 +166,21 @@ def test_tokenizer_random_data(use_eos_token: bool):
|
|||
assert samples.shape == (2, 10, 4)
|
||||
|
||||
|
||||
def validate_tensor(
|
||||
a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
|
||||
) -> None:
|
||||
assert isinstance(a, torch.Tensor)
|
||||
assert a.shape == shape
|
||||
|
||||
if dtype is not None:
|
||||
assert a.dtype == dtype
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=model_dtype,
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
|
||||
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
|
||||
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(
|
||||
|
|
@ -199,12 +190,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
|||
samples = pipeline.predict(
|
||||
context, num_samples=7, prediction_length=65, limit_prediction_length=False
|
||||
)
|
||||
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
|
||||
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
|
||||
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
|
||||
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(
|
||||
|
|
@ -220,12 +211,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
|||
prediction_length=65,
|
||||
limit_prediction_length=False,
|
||||
)
|
||||
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
|
||||
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
|
||||
validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype)
|
||||
validate_tensor(samples, shape=(1, 12, 3), dtype=torch.float32)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(
|
||||
|
|
@ -240,16 +231,18 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
|||
num_samples=7,
|
||||
prediction_length=65,
|
||||
)
|
||||
validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype)
|
||||
validate_tensor(samples, shape=(1, 7, 65), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
@pytest.mark.parametrize("prediction_length", [3, 65])
|
||||
@pytest.mark.parametrize(
|
||||
"quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
|
||||
)
|
||||
def test_pipeline_predict_quantiles(
|
||||
model_dtype: torch.dtype,
|
||||
input_dtype: torch.dtype,
|
||||
prediction_length: int,
|
||||
quantile_levels: list[int],
|
||||
):
|
||||
|
|
@ -259,6 +252,7 @@ def test_pipeline_predict_quantiles(
|
|||
torch_dtype=model_dtype,
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
||||
num_expected_quantiles = len(quantile_levels)
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
|
@ -269,8 +263,10 @@ def test_pipeline_predict_quantiles(
|
|||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (4, prediction_length))
|
||||
validate_tensor(
|
||||
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
|
||||
)
|
||||
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
|
|
@ -280,8 +276,10 @@ def test_pipeline_predict_quantiles(
|
|||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (4, prediction_length))
|
||||
validate_tensor(
|
||||
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
|
||||
)
|
||||
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
|
|
@ -291,12 +289,14 @@ def test_pipeline_predict_quantiles(
|
|||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (1, prediction_length))
|
||||
validate_tensor(
|
||||
quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32
|
||||
)
|
||||
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
|
|
@ -304,7 +304,8 @@ def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
|||
torch_dtype=model_dtype,
|
||||
)
|
||||
d_model = pipeline.model.model.config.d_model
|
||||
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,14 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from chronos import BaseChronosPipeline, ChronosBoltPipeline
|
||||
from chronos.chronos_bolt import InstanceNorm, Patch
|
||||
|
||||
|
||||
def validate_tensor(input: torch.Tensor, shape: Tuple[int, ...]) -> None:
|
||||
assert isinstance(input, torch.Tensor)
|
||||
assert input.shape == shape
|
||||
from test.util import validate_tensor
|
||||
|
||||
|
||||
def test_base_chronos_pipeline_loads_from_huggingface():
|
||||
|
|
@ -18,19 +16,21 @@ def test_base_chronos_pipeline_loads_from_huggingface():
|
|||
|
||||
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
def test_pipeline_predict(torch_dtype: str):
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
def test_pipeline_predict(torch_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-bolt-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
expected_num_quantiles = len(pipeline.quantiles)
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
quantiles = pipeline.predict(context, prediction_length=3)
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 3))
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
quantiles = pipeline.predict(
|
||||
|
|
@ -43,7 +43,7 @@ def test_pipeline_predict(torch_dtype: str):
|
|||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
quantiles = pipeline.predict(list(context), prediction_length=3)
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 3))
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
quantiles = pipeline.predict(
|
||||
|
|
@ -53,12 +53,12 @@ def test_pipeline_predict(torch_dtype: str):
|
|||
)
|
||||
|
||||
quantiles = pipeline.predict(list(context), prediction_length=65)
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 65))
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 65), dtype=torch.float32)
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
quantiles = pipeline.predict(context[0, ...], prediction_length=3)
|
||||
validate_tensor(quantiles, (1, expected_num_quantiles, 3))
|
||||
validate_tensor(quantiles, (1, expected_num_quantiles, 3), dtype=torch.float32)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
quantiles = pipeline.predict(
|
||||
|
|
@ -71,16 +71,20 @@ def test_pipeline_predict(torch_dtype: str):
|
|||
context[0, ...],
|
||||
prediction_length=65,
|
||||
)
|
||||
validate_tensor(quantiles, (1, expected_num_quantiles, 65))
|
||||
validate_tensor(quantiles, (1, expected_num_quantiles, 65), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
@pytest.mark.parametrize("prediction_length", [3, 65])
|
||||
@pytest.mark.parametrize(
|
||||
"quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
|
||||
)
|
||||
def test_pipeline_predict_quantiles(
|
||||
torch_dtype: str, prediction_length: int, quantile_levels: list[int]
|
||||
torch_dtype: torch.dtype,
|
||||
input_dtype: torch.dtype,
|
||||
prediction_length: int,
|
||||
quantile_levels: list[int],
|
||||
):
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-bolt-model",
|
||||
|
|
@ -88,6 +92,7 @@ def test_pipeline_predict_quantiles(
|
|||
torch_dtype=torch_dtype,
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
||||
num_expected_quantiles = len(quantile_levels)
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
|
@ -97,8 +102,10 @@ def test_pipeline_predict_quantiles(
|
|||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (4, prediction_length))
|
||||
validate_tensor(
|
||||
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
|
||||
)
|
||||
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
|
|
@ -107,8 +114,10 @@ def test_pipeline_predict_quantiles(
|
|||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (4, prediction_length))
|
||||
validate_tensor(
|
||||
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
|
||||
)
|
||||
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
|
|
@ -117,8 +126,10 @@ def test_pipeline_predict_quantiles(
|
|||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (1, prediction_length))
|
||||
validate_tensor(
|
||||
quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32
|
||||
)
|
||||
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
|
||||
|
||||
|
||||
# The following tests have been taken from
|
||||
|
|
|
|||
13
test/util.py
Normal file
13
test/util.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def validate_tensor(
|
||||
a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
|
||||
) -> None:
|
||||
assert isinstance(a, torch.Tensor)
|
||||
assert a.shape == shape
|
||||
|
||||
if dtype is not None:
|
||||
assert a.dtype == dtype
|
||||
Loading…
Reference in a new issue