Return predictions in fp32 on CPU (#219)

*Issue #, if available:* N/A

*Description of changes:* This PR ensures that predictions are always
returned in FP32 and on the CPU device, instead of following the dtype
and device of the input. This choice is now better because we have two
types of models which produce different types of forecasts (samples vs.
quantiles). Furthermore, inputs of dtype `int64` (our README example is
one such case) ran into issues with `predict_quantiles` before. This
choice also fixes that.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
Abdul Fatir 2024-11-29 16:54:21 +01:00 committed by GitHub
parent c887278706
commit 4c43cfbdac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 85 additions and 56 deletions

View file

@ -67,7 +67,8 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
**kwargs,
):
"""
Get forecasts for the given time series.
Get forecasts for the given time series. Predictions will be
returned in fp32 on the cpu.
Parameters
----------
@ -97,6 +98,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get quantile and mean forecasts for given time series.
Predictions will be returned in fp32 on the cpu.
Parameters
----------

View file

@ -500,9 +500,6 @@ class ChronosPipeline(BaseChronosPipeline):
raise ValueError(msg)
logger.warning(msg)
input_dtype = context_tensor.dtype
input_device = context_tensor.device
predictions = []
remaining = prediction_length
@ -533,7 +530,7 @@ class ChronosPipeline(BaseChronosPipeline):
[context_tensor, prediction.median(dim=1).values], dim=-1
)
return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device)
return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device="cpu")
def predict_quantiles(
self,

View file

@ -487,13 +487,14 @@ class ChronosBoltPipeline(BaseChronosPipeline):
# TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
# horizon that the model was trained with (i.e., 64). This results in variance collapsing
# every 64 steps.
context_tensor = context_tensor.to(
device=self.model.device,
dtype=torch.float32,
)
while remaining > 0:
with torch.no_grad():
prediction = self.model(
context=context_tensor.to(
device=self.model.device,
dtype=torch.float32, # scaling should be done in 32-bit precision
),
context=context_tensor,
).quantile_preds.to(context_tensor)
predictions.append(prediction)
@ -507,7 +508,9 @@ class ChronosBoltPipeline(BaseChronosPipeline):
context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
return torch.cat(predictions, dim=-1)[..., :prediction_length]
return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
dtype=torch.float32, device="cpu"
)
def predict_quantiles(
self,

2
test/__init__.py Normal file
View file

@ -0,0 +1,2 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

View file

@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Optional, Tuple
import pytest
import torch
@ -13,6 +12,7 @@ from chronos import (
ChronosPipeline,
MeanScaleUniformBins,
)
from test.util import validate_tensor
def test_base_chronos_pipeline_loads_from_huggingface():
@ -166,30 +166,21 @@ def test_tokenizer_random_data(use_eos_token: bool):
assert samples.shape == (2, 10, 4)
def validate_tensor(
a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
) -> None:
assert isinstance(a, torch.Tensor)
assert a.shape == shape
if dtype is not None:
assert a.dtype == dtype
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
device_map="cpu",
torch_dtype=model_dtype,
)
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
context = 10 * torch.rand(size=(4, 16)) + 10
context = context.to(dtype=input_dtype)
# input: tensor of shape (batch_size, context_length)
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
with pytest.raises(ValueError):
samples = pipeline.predict(
@ -199,12 +190,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
samples = pipeline.predict(
context, num_samples=7, prediction_length=65, limit_prediction_length=False
)
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
# input: batch_size-long list of tensors of shape (context_length,)
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
with pytest.raises(ValueError):
samples = pipeline.predict(
@ -220,12 +211,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
prediction_length=65,
limit_prediction_length=False,
)
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
# input: tensor of shape (context_length,)
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype)
validate_tensor(samples, shape=(1, 12, 3), dtype=torch.float32)
with pytest.raises(ValueError):
samples = pipeline.predict(
@ -240,16 +231,18 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
num_samples=7,
prediction_length=65,
)
validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype)
validate_tensor(samples, shape=(1, 7, 65), dtype=torch.float32)
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
@pytest.mark.parametrize("prediction_length", [3, 65])
@pytest.mark.parametrize(
"quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
)
def test_pipeline_predict_quantiles(
model_dtype: torch.dtype,
input_dtype: torch.dtype,
prediction_length: int,
quantile_levels: list[int],
):
@ -259,6 +252,7 @@ def test_pipeline_predict_quantiles(
torch_dtype=model_dtype,
)
context = 10 * torch.rand(size=(4, 16)) + 10
context = context.to(dtype=input_dtype)
num_expected_quantiles = len(quantile_levels)
# input: tensor of shape (batch_size, context_length)
@ -269,8 +263,10 @@ def test_pipeline_predict_quantiles(
prediction_length=prediction_length,
quantile_levels=quantile_levels,
)
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
validate_tensor(mean, (4, prediction_length))
validate_tensor(
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
)
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
# input: batch_size-long list of tensors of shape (context_length,)
@ -280,8 +276,10 @@ def test_pipeline_predict_quantiles(
prediction_length=prediction_length,
quantile_levels=quantile_levels,
)
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
validate_tensor(mean, (4, prediction_length))
validate_tensor(
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
)
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
# input: tensor of shape (context_length,)
@ -291,12 +289,14 @@ def test_pipeline_predict_quantiles(
prediction_length=prediction_length,
quantile_levels=quantile_levels,
)
validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
validate_tensor(mean, (1, prediction_length))
validate_tensor(
quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32
)
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
@ -304,7 +304,8 @@ def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
torch_dtype=model_dtype,
)
d_model = pipeline.model.model.config.d_model
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
context = 10 * torch.rand(size=(4, 16)) + 10
context = context.to(dtype=input_dtype)
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
# input: tensor of shape (batch_size, context_length)

View file

@ -1,16 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Tuple
import pytest
import torch
from chronos import BaseChronosPipeline, ChronosBoltPipeline
from chronos.chronos_bolt import InstanceNorm, Patch
def validate_tensor(input: torch.Tensor, shape: Tuple[int, ...]) -> None:
assert isinstance(input, torch.Tensor)
assert input.shape == shape
from test.util import validate_tensor
def test_base_chronos_pipeline_loads_from_huggingface():
@ -18,19 +16,21 @@ def test_base_chronos_pipeline_loads_from_huggingface():
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
def test_pipeline_predict(torch_dtype: str):
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
def test_pipeline_predict(torch_dtype: torch.dtype, input_dtype: torch.dtype):
pipeline = ChronosBoltPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-bolt-model",
device_map="cpu",
torch_dtype=torch_dtype,
)
context = 10 * torch.rand(size=(4, 16)) + 10
context = context.to(dtype=input_dtype)
expected_num_quantiles = len(pipeline.quantiles)
# input: tensor of shape (batch_size, context_length)
quantiles = pipeline.predict(context, prediction_length=3)
validate_tensor(quantiles, (4, expected_num_quantiles, 3))
validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32)
with pytest.raises(ValueError):
quantiles = pipeline.predict(
@ -43,7 +43,7 @@ def test_pipeline_predict(torch_dtype: str):
# input: batch_size-long list of tensors of shape (context_length,)
quantiles = pipeline.predict(list(context), prediction_length=3)
validate_tensor(quantiles, (4, expected_num_quantiles, 3))
validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32)
with pytest.raises(ValueError):
quantiles = pipeline.predict(
@ -53,12 +53,12 @@ def test_pipeline_predict(torch_dtype: str):
)
quantiles = pipeline.predict(list(context), prediction_length=65)
validate_tensor(quantiles, (4, expected_num_quantiles, 65))
validate_tensor(quantiles, (4, expected_num_quantiles, 65), dtype=torch.float32)
# input: tensor of shape (context_length,)
quantiles = pipeline.predict(context[0, ...], prediction_length=3)
validate_tensor(quantiles, (1, expected_num_quantiles, 3))
validate_tensor(quantiles, (1, expected_num_quantiles, 3), dtype=torch.float32)
with pytest.raises(ValueError):
quantiles = pipeline.predict(
@ -71,16 +71,20 @@ def test_pipeline_predict(torch_dtype: str):
context[0, ...],
prediction_length=65,
)
validate_tensor(quantiles, (1, expected_num_quantiles, 65))
validate_tensor(quantiles, (1, expected_num_quantiles, 65), dtype=torch.float32)
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
@pytest.mark.parametrize("prediction_length", [3, 65])
@pytest.mark.parametrize(
"quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
)
def test_pipeline_predict_quantiles(
torch_dtype: str, prediction_length: int, quantile_levels: list[int]
torch_dtype: torch.dtype,
input_dtype: torch.dtype,
prediction_length: int,
quantile_levels: list[int],
):
pipeline = ChronosBoltPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-bolt-model",
@ -88,6 +92,7 @@ def test_pipeline_predict_quantiles(
torch_dtype=torch_dtype,
)
context = 10 * torch.rand(size=(4, 16)) + 10
context = context.to(dtype=input_dtype)
num_expected_quantiles = len(quantile_levels)
# input: tensor of shape (batch_size, context_length)
@ -97,8 +102,10 @@ def test_pipeline_predict_quantiles(
prediction_length=prediction_length,
quantile_levels=quantile_levels,
)
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
validate_tensor(mean, (4, prediction_length))
validate_tensor(
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
)
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
# input: batch_size-long list of tensors of shape (context_length,)
@ -107,8 +114,10 @@ def test_pipeline_predict_quantiles(
prediction_length=prediction_length,
quantile_levels=quantile_levels,
)
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
validate_tensor(mean, (4, prediction_length))
validate_tensor(
quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32
)
validate_tensor(mean, (4, prediction_length), dtype=torch.float32)
# input: tensor of shape (context_length,)
@ -117,8 +126,10 @@ def test_pipeline_predict_quantiles(
prediction_length=prediction_length,
quantile_levels=quantile_levels,
)
validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
validate_tensor(mean, (1, prediction_length))
validate_tensor(
quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32
)
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
# The following tests have been taken from

13
test/util.py Normal file
View file

@ -0,0 +1,13 @@
from typing import Optional, Tuple
import torch
def validate_tensor(
    a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
) -> None:
    """
    Assert that ``a`` is a tensor with the expected shape and, optionally, dtype.

    Each assertion carries an explicit message so that a failure reports the
    expected and actual values on its own.
    NOTE(review): this helper is a plain utility module rather than a
    ``test_*.py`` file, so pytest's assertion rewriting presumably does not
    cover it -- confirm against the project's pytest configuration.

    Parameters
    ----------
    a
        Object to validate.
    shape
        Expected shape, compared for equality against ``a.shape``.
    dtype
        Expected dtype. The dtype check is skipped when this is ``None``.

    Raises
    ------
    AssertionError
        If ``a`` is not a ``torch.Tensor``, if its shape differs from
        ``shape``, or if ``dtype`` is given and differs from ``a.dtype``.
    """
    assert isinstance(
        a, torch.Tensor
    ), f"expected a torch.Tensor, got {type(a).__name__}"
    assert a.shape == shape, f"expected shape {shape}, got {tuple(a.shape)}"
    if dtype is not None:
        assert a.dtype == dtype, f"expected dtype {dtype}, got {a.dtype}"