From 4c43cfbdac9fa71911f6fadd532234385db9c72e Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Fri, 29 Nov 2024 16:54:21 +0100 Subject: [PATCH] Return predictions in fp32 on CPU (#219) *Issue #, if available:* N/A *Description of changes:* This PR ensures that predictions are returned in FP32 and on the CPU device. This choice is now better because we have two types of models which have different types of forecasts (samples vs. quantiles). Furthermore, `int64` input_type (our README example is one such case) ran into issues with `predict_quantiles` before. This choice also fixes that. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Abdul Fatir Ansari --- src/chronos/base.py | 4 ++- src/chronos/chronos.py | 5 +--- src/chronos/chronos_bolt.py | 13 +++++---- test/__init__.py | 2 ++ test/test_chronos.py | 55 +++++++++++++++++++------------------ test/test_chronos_bolt.py | 49 ++++++++++++++++++++------------- test/util.py | 13 +++++++++ 7 files changed, 85 insertions(+), 56 deletions(-) create mode 100644 test/__init__.py create mode 100644 test/util.py diff --git a/src/chronos/base.py b/src/chronos/base.py index 3dc1775..bf57b55 100644 --- a/src/chronos/base.py +++ b/src/chronos/base.py @@ -67,7 +67,8 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): **kwargs, ): """ - Get forecasts for the given time series. + Get forecasts for the given time series. Predictions will be + returned in fp32 on the cpu. Parameters ---------- @@ -97,6 +98,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): ) -> Tuple[torch.Tensor, torch.Tensor]: """ Get quantile and mean forecasts for given time series. + Predictions will be returned in fp32 on the cpu. Parameters ---------- diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index ef226f6..31df48a 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -500,9 +500,6 @@ class ChronosPipeline(BaseChronosPipeline): raise ValueError(msg) logger.warning(msg) - input_dtype = context_tensor.dtype - input_device = context_tensor.device - predictions = [] remaining = prediction_length @@ -533,7 +530,7 @@ class ChronosPipeline(BaseChronosPipeline): [context_tensor, prediction.median(dim=1).values], dim=-1 ) - return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device) + return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device="cpu") def predict_quantiles( self, diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index e3182f9..4825466 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -487,13 +487,14 @@ class ChronosBoltPipeline(BaseChronosPipeline): # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast # horizon that the model was trained with (i.e., 64). This results in variance collapsing # every 64 steps. + context_tensor = context_tensor.to( + device=self.model.device, + dtype=torch.float32, + ) while remaining > 0: with torch.no_grad(): prediction = self.model( - context=context_tensor.to( - device=self.model.device, - dtype=torch.float32, # scaling should be done in 32-bit precision - ), + context=context_tensor, ).quantile_preds.to(context_tensor) predictions.append(prediction) @@ -507,7 +508,9 @@ class ChronosBoltPipeline(BaseChronosPipeline): context_tensor = torch.cat([context_tensor, central_prediction], dim=-1) - return torch.cat(predictions, dim=-1)[..., :prediction_length] + return torch.cat(predictions, dim=-1)[..., :prediction_length].to( + dtype=torch.float32, device="cpu" + ) def predict_quantiles( self, diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..03f633a --- /dev/null +++ b/test/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file diff --git a/test/test_chronos.py b/test/test_chronos.py index b0235c0..763fde2 100644 --- a/test/test_chronos.py +++ b/test/test_chronos.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path -from typing import Optional, Tuple import pytest import torch @@ -13,6 +12,7 @@ from chronos import ( ChronosPipeline, MeanScaleUniformBins, ) +from test.util import validate_tensor def test_base_chronos_pipeline_loads_from_huggingface(): @@ -166,30 +166,21 @@ def test_tokenizer_random_data(use_eos_token: bool): assert samples.shape == (2, 10, 4) -def validate_tensor( - a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None -) -> None: - assert isinstance(a, torch.Tensor) - assert a.shape == shape - - if dtype is not None: - assert a.dtype == dtype - - @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-model", device_map="cpu", torch_dtype=model_dtype, ) - context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10 + context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) # input: tensor of shape (batch_size, context_length) samples = pipeline.predict(context, num_samples=12, prediction_length=3) - validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype) + validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32) with pytest.raises(ValueError): samples = pipeline.predict( @@ -199,12 +190,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): samples = pipeline.predict( context, num_samples=7, prediction_length=65, limit_prediction_length=False ) - validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype) + validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32) # input: batch_size-long list of tensors of shape (context_length,) samples = pipeline.predict(list(context), num_samples=12, prediction_length=3) - validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype) + validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32) with pytest.raises(ValueError): samples = pipeline.predict( @@ -220,12 +211,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): prediction_length=65, limit_prediction_length=False, ) - validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype) + validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32) # input: tensor of shape (context_length,) samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3) - validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype) + validate_tensor(samples, shape=(1, 12, 3), dtype=torch.float32) with pytest.raises(ValueError): samples = pipeline.predict( @@ -240,16 +231,18 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): num_samples=7, prediction_length=65, ) - validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype) + validate_tensor(samples, shape=(1, 7, 65), dtype=torch.float32) @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) @pytest.mark.parametrize("prediction_length", [3, 65]) @pytest.mark.parametrize( "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]] ) def test_pipeline_predict_quantiles( model_dtype: torch.dtype, + input_dtype: torch.dtype, prediction_length: int, quantile_levels: list[int], ): @@ -259,6 +252,7 @@ def test_pipeline_predict_quantiles( torch_dtype=model_dtype, ) context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) num_expected_quantiles = len(quantile_levels) # input: tensor of shape (batch_size, context_length) @@ -269,8 +263,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (4, prediction_length)) + validate_tensor( + quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (4, prediction_length), dtype=torch.float32) # input: batch_size-long list of tensors of shape (context_length,) @@ -280,8 +276,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (4, prediction_length)) + validate_tensor( + quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (4, prediction_length), dtype=torch.float32) # input: tensor of shape (context_length,) @@ -291,12 +289,14 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (1, prediction_length)) + validate_tensor( + quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (1, prediction_length), dtype=torch.float32) @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-model", @@ -304,7 +304,8 @@ def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): torch_dtype=model_dtype, ) d_model = pipeline.model.model.config.d_model - context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10 + context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0) # input: tensor of shape (batch_size, context_length) diff --git a/test/test_chronos_bolt.py b/test/test_chronos_bolt.py index c4c3db7..4b72568 100644 --- a/test/test_chronos_bolt.py +++ b/test/test_chronos_bolt.py @@ -1,16 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + from pathlib import Path -from typing import Tuple import pytest import torch from chronos import BaseChronosPipeline, ChronosBoltPipeline from chronos.chronos_bolt import InstanceNorm, Patch - - -def validate_tensor(input: torch.Tensor, shape: Tuple[int, ...]) -> None: - assert isinstance(input, torch.Tensor) - assert input.shape == shape +from test.util import validate_tensor def test_base_chronos_pipeline_loads_from_huggingface(): @@ -18,19 +16,21 @@ def test_base_chronos_pipeline_loads_from_huggingface(): @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) -def test_pipeline_predict(torch_dtype: str): +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) +def test_pipeline_predict(torch_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosBoltPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-bolt-model", device_map="cpu", torch_dtype=torch_dtype, ) context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) expected_num_quantiles = len(pipeline.quantiles) # input: tensor of shape (batch_size, context_length) quantiles = pipeline.predict(context, prediction_length=3) - validate_tensor(quantiles, (4, expected_num_quantiles, 3)) + validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32) with pytest.raises(ValueError): quantiles = pipeline.predict( @@ -43,7 +43,7 @@ def test_pipeline_predict(torch_dtype: str): # input: batch_size-long list of tensors of shape (context_length,) quantiles = pipeline.predict(list(context), prediction_length=3) - validate_tensor(quantiles, (4, expected_num_quantiles, 3)) + validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32) with pytest.raises(ValueError): quantiles = pipeline.predict( @@ -53,12 +53,12 @@ def test_pipeline_predict(torch_dtype: str): ) quantiles = pipeline.predict(list(context), prediction_length=65) - validate_tensor(quantiles, (4, expected_num_quantiles, 65)) + validate_tensor(quantiles, (4, expected_num_quantiles, 65), dtype=torch.float32) # input: tensor of shape (context_length,) quantiles = pipeline.predict(context[0, ...], prediction_length=3) - validate_tensor(quantiles, (1, expected_num_quantiles, 3)) + validate_tensor(quantiles, (1, expected_num_quantiles, 3), dtype=torch.float32) with pytest.raises(ValueError): quantiles = pipeline.predict( @@ -71,16 +71,20 @@ def test_pipeline_predict(torch_dtype: str): context[0, ...], prediction_length=65, ) - validate_tensor(quantiles, (1, expected_num_quantiles, 65)) + validate_tensor(quantiles, (1, expected_num_quantiles, 65), dtype=torch.float32) @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) @pytest.mark.parametrize("prediction_length", [3, 65]) @pytest.mark.parametrize( "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]] ) def test_pipeline_predict_quantiles( - torch_dtype: str, prediction_length: int, quantile_levels: list[int] + torch_dtype: torch.dtype, + input_dtype: torch.dtype, + prediction_length: int, + quantile_levels: list[int], ): pipeline = ChronosBoltPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-bolt-model", @@ -88,6 +92,7 @@ def test_pipeline_predict_quantiles( torch_dtype=torch_dtype, ) context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) num_expected_quantiles = len(quantile_levels) # input: tensor of shape (batch_size, context_length) @@ -97,8 +102,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (4, prediction_length)) + validate_tensor( + quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (4, prediction_length), dtype=torch.float32) # input: batch_size-long list of tensors of shape (context_length,) @@ -107,8 +114,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (4, prediction_length)) + validate_tensor( + quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (4, prediction_length), dtype=torch.float32) # input: tensor of shape (context_length,) @@ -117,8 +126,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (1, prediction_length)) + validate_tensor( + quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (1, prediction_length), dtype=torch.float32) # The following tests have been taken from diff --git a/test/util.py b/test/util.py new file mode 100644 index 0000000..37a2c3b --- /dev/null +++ b/test/util.py @@ -0,0 +1,13 @@ +from typing import Optional, Tuple + +import torch + + +def validate_tensor( + a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None +) -> None: + assert isinstance(a, torch.Tensor) + assert a.shape == shape + + if dtype is not None: + assert a.dtype == dtype \ No newline at end of file