diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 3d22a8e..3eddcd7 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -608,7 +608,9 @@ class Chronos2Pipeline(BaseChronosPipeline): output_patch_size=self.model_output_patch_size, mode=DatasetMode.TEST, ) - test_loader = DataLoader(test_dataset, batch_size=None, pin_memory=True, shuffle=False, drop_last=False) + test_loader = DataLoader( + test_dataset, batch_size=None, pin_memory=self.model.device.type == "cuda", shuffle=False, drop_last=False + ) all_predictions: list[torch.Tensor] = [] for batch in test_loader: @@ -1122,7 +1124,12 @@ class Chronos2Pipeline(BaseChronosPipeline): mode=DatasetMode.TEST, ) test_loader = DataLoader( - test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False + test_dataset, + batch_size=None, + num_workers=0, + pin_memory=self.model.device.type == "cuda", + shuffle=False, + drop_last=False, ) all_embeds: list[torch.Tensor] = [] all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []