chronos-forecasting/test/test_utils.py
Lorenzo Stella 67f008432a
Fix padding for int contexts (#227)
*Issue #, if available:* On Linux, the final call to `.to` causes
problems when the input tensors have an integer dtype: the NaN padding
is cast to a large negative integer instead of staying NaN. For example:

```
>>> a = torch.tensor([1])
>>> b = torch.stack([torch.full((1,), torch.nan), a])
>>> b
tensor([[nan],
        [1.]])
>>> b.to(a)
tensor([[-9223372036854775808],
        [                   1]])
```


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
2024-12-04 16:46:17 +01:00

29 lines
880 B
Python

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from chronos.utils import left_pad_and_stack_1D
@pytest.mark.parametrize(
    "tensors",
    [
        pytest.param(
            [
                torch.tensor([2.0, 3.0], dtype=dtype),
                torch.tensor([4.0, 5.0, 6.0], dtype=dtype),
                torch.tensor([7.0, 8.0, 9.0, 10.0], dtype=dtype),
            ],
            id=str(dtype),
        )
        # Integer dtypes are included on purpose: casting the NaN padding back
        # to an integer dtype used to turn it into a large negative integer
        # (see #227). int64 is the default integer dtype of torch.tensor.
        for dtype in [torch.int, torch.int64, torch.float16, torch.float32]
    ],
)
def test_pad_and_stack(tensors: list):
    """Check that ``left_pad_and_stack_1D`` left-pads with NaN and returns float32.

    For every input of length ``L`` and the longest input of length
    ``max_len``, the matching output row must contain ``max_len - L`` NaN
    values followed by the input's values converted to float32.
    """
    stacked_and_padded = left_pad_and_stack_1D(tensors)
    max_len = max(len(t) for t in tensors)

    assert stacked_and_padded.dtype == torch.float32
    assert stacked_and_padded.shape == (len(tensors), max_len)

    for row, original in zip(stacked_and_padded, tensors):
        n_pad = max_len - len(original)
        # Leading positions are padding: they must be NaN, not a finite value
        # such as 0 or an integer produced by a lossy cast.
        assert torch.isnan(row[:n_pad]).all()
        # Trailing positions hold the original values, right-aligned; the test
        # values are exactly representable in every parametrized dtype.
        assert torch.equal(row[n_pad:], original.to(torch.float32))