chronos-forecasting/test/util.py
Lorenzo Stella 67f008432a
Fix padding for int contexts (#227)
*Issue #, if available:* On Linux, the final call to `.to` causes problems
when the input tensors are integers. For example:

```
>>> a = torch.tensor([1])
>>> b = torch.stack([torch.full((1,), torch.nan), a])
>>> b
tensor([[nan],
        [1.]])
>>> b.to(a)
tensor([[-9223372036854775808],
        [                   1]])
```


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
2024-12-04 16:46:17 +01:00
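The fix itself is not shown on this page. The following is a minimal sketch of the idea, assuming a hypothetical helper named `left_pad_and_stack` (not necessarily the name or code used in the repository). Integer contexts are promoted to floating point before NaN left-padding, and the result is never cast back to the input dtype, because NaN has no integer representation.

```python
from typing import List

import torch


def left_pad_and_stack(tensors: List[torch.Tensor]) -> torch.Tensor:
    # Hypothetical helper: left-pad 1D tensors with NaN to a common length,
    # then stack them into a 2D tensor.
    max_len = max(len(t) for t in tensors)
    padded = []
    for t in tensors:
        # NaN cannot be stored in an integer tensor, so promote integer
        # contexts to float32; floating-point inputs keep their dtype.
        if not torch.is_floating_point(t):
            t = t.to(torch.float32)
        padding = torch.full(
            (max_len - len(t),), torch.nan, dtype=t.dtype, device=t.device
        )
        padded.append(torch.cat((padding, t)))
    # No trailing `.to(tensors[0])`: casting back to int64 would turn NaN
    # into garbage such as -9223372036854775808.
    return torch.stack(padded)


out = left_pad_and_stack([torch.tensor([1, 2, 3]), torch.tensor([4])])
print(out)
```

For the integer inputs `[1, 2, 3]` and `[4]`, the result is a `float32` tensor of shape `(2, 3)` whose second row is `[nan, nan, 4.]`.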


```python
from typing import Optional, Tuple

import torch


def validate_tensor(
    a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
) -> None:
    """Assert that `a` is a tensor with the given shape and, optionally, dtype."""
    assert isinstance(a, torch.Tensor)
    assert a.shape == shape

    if dtype is not None:
        assert a.dtype == dtype
```
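A short sketch of how a helper like `validate_tensor` can guard against the integer-padding regression. The test name `test_int_context_is_padded_as_float` is hypothetical, and the helper is copied inline so the snippet runs on its own; it relies on `torch.cat` promoting a `float32` NaN padding and an `int64` context to `float32`.

```python
from typing import Optional, Tuple

import torch


def validate_tensor(
    a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
) -> None:
    # Copy of the helper in test/util.py: check type, exact shape, optional dtype.
    assert isinstance(a, torch.Tensor)
    assert a.shape == shape
    if dtype is not None:
        assert a.dtype == dtype


def test_int_context_is_padded_as_float() -> None:
    # Hypothetical test: an integer context left-padded with NaN should stay
    # floating point instead of being cast back to int64.
    context = torch.tensor([1, 2])
    padding = torch.full((1,), torch.nan)
    padded = torch.cat((padding, context))  # type promotion yields float32
    validate_tensor(padded, shape=(3,), dtype=torch.float32)
    assert torch.isnan(padded[0])


test_int_context_is_padded_as_float()
```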