From 28752931fd43e1e5963be0c8fcb55ad89cf8c5ce Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 25 Mar 2024 12:39:30 +0100 Subject: [PATCH] Speed up inference by avoiding unnecessary padding (#25) *Issue #, if available:* Unnecessary context padding slows down inference. We evaluated the models from HF with this change, and found no concerning issue with accuracy. Test code for a context of length 200: ```python import torch from chronos import ChronosPipeline import time pipeline = ChronosPipeline.from_pretrained( "amazon/chronos-t5-large", device_map="cuda", torch_dtype=torch.bfloat16, ) context = torch.ones((8, 200)) prediction_length = 24 num_runs = 10 t0 = time.time() for _ in range(num_runs): forecast = pipeline.predict( context, prediction_length, num_samples=20, ) t1 = time.time() print(f"total time: {t1 - t0}") ``` Before the change: ``` total time: 20.005481481552124 ``` After the change: ``` total time: 9.82350754737854 ``` *Description of changes:* Remove padding in case the provided batch is shorter than `context_length`. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- src/chronos/chronos.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index 2beb957..c9f50d0 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -139,13 +139,6 @@ class MeanScaleUniformBins(ChronosTokenizer): if length > self.config.context_length: context = context[..., -self.config.context_length :] - elif length < self.config.context_length: - padding_size = ( - *context.shape[:-1], - self.config.context_length - length, - ) - padding = torch.full(size=padding_size, fill_value=torch.nan) - context = torch.concat((padding, context), dim=-1) attention_mask = ~torch.isnan(context) scale = torch.nansum(