Speed up inference by avoiding unnecessary padding (#25)

*Issue #, if available:* Unnecessary context padding slows down
inference. We evaluated the models from HF with this change, and found
no concerning issue with accuracy.

Test code for a context of length 200:

```python
import torch
from chronos import ChronosPipeline
import time

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

context = torch.ones((8, 200))
prediction_length = 24
num_runs = 10

t0 = time.time()
for _ in range(num_runs):
    forecast = pipeline.predict(
        context,
        prediction_length,
        num_samples=20,
    )
t1 = time.time()

print(f"total time: {t1 - t0}")
```

Before the change:

```
total time: 20.005481481552124
```

After the change:

```
total time: 9.82350754737854
```

*Description of changes:* Remove padding in case the provided batch is
shorter than `context_length`.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Lorenzo Stella 2024-03-25 12:39:30 +01:00 committed by GitHub
parent 73be25042f
commit 28752931fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -139,13 +139,6 @@ class MeanScaleUniformBins(ChronosTokenizer):
if length > self.config.context_length:
context = context[..., -self.config.context_length :]
elif length < self.config.context_length:
padding_size = (
*context.shape[:-1],
self.config.context_length - length,
)
padding = torch.full(size=padding_size, fill_value=torch.nan)
context = torch.concat((padding, context), dim=-1)
attention_mask = ~torch.isnan(context)
scale = torch.nansum(