mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Fix padding for int contexts (#227)
*Issue #, if available:* On Linux, the final call to `.to` creates
trouble when input tensors are integer. For example:
```
>>> a = torch.tensor([1])
>>> b = torch.stack([torch.full((1,), torch.nan), a])
>>> b
tensor([[nan],
[1.]])
>>> b.to(a)
tensor([[-9223372036854775808],
[ 1]])
```
By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
parent
47cac082c1
commit
67f008432a
4 changed files with 32 additions and 3 deletions
|
|
@ -17,4 +17,4 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
|
|||
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
|
||||
)
|
||||
padded.append(torch.concat((padding, c), dim=-1))
|
||||
return torch.stack(padded).to(tensors[0])
|
||||
return torch.stack(padded)
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
|
|
|||
29
test/test_utils.py
Normal file
29
test/test_utils.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from chronos.utils import left_pad_and_stack_1D
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tensors",
|
||||
[
|
||||
[
|
||||
torch.tensor([2.0, 3.0], dtype=dtype),
|
||||
torch.tensor([4.0, 5.0, 6.0], dtype=dtype),
|
||||
torch.tensor([7.0, 8.0, 9.0, 10.0], dtype=dtype),
|
||||
]
|
||||
for dtype in [torch.int, torch.float16, torch.float32]
|
||||
],
|
||||
)
|
||||
def test_pad_and_stack(tensors: list):
|
||||
stacked_and_padded = left_pad_and_stack_1D(tensors)
|
||||
|
||||
assert stacked_and_padded.dtype == torch.float32
|
||||
assert stacked_and_padded.shape == (len(tensors), max(len(t) for t in tensors))
|
||||
|
||||
ref = torch.concat(tensors).to(dtype=stacked_and_padded.dtype)
|
||||
|
||||
assert torch.sum(torch.nan_to_num(stacked_and_padded, nan=0)) == torch.sum(ref)
|
||||
|
|
@ -10,4 +10,4 @@ def validate_tensor(
|
|||
assert a.shape == shape
|
||||
|
||||
if dtype is not None:
|
||||
assert a.dtype == dtype
|
||||
assert a.dtype == dtype
|
||||
|
|
|
|||
Loading…
Reference in a new issue