diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index ec1e8ba..67b8a0b 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -649,7 +649,9 @@ class Chronos2Pipeline(BaseChronosPipeline): output_patch_size=self.model_output_patch_size, mode=DatasetMode.TEST, ) - test_loader = DataLoader(test_dataset, batch_size=None, pin_memory=True, shuffle=False, drop_last=False) + test_loader = DataLoader( + test_dataset, batch_size=None, pin_memory=self.model.device.type == "cuda", shuffle=False, drop_last=False + ) all_predictions: list[torch.Tensor] = [] # Track the current task index for custom group ID mapping @@ -1247,7 +1249,12 @@ class Chronos2Pipeline(BaseChronosPipeline): mode=DatasetMode.TEST, ) test_loader = DataLoader( - test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False + test_dataset, + batch_size=None, + num_workers=0, + pin_memory=self.model.device.type == "cuda", + shuffle=False, + drop_last=False, ) all_embeds: list[torch.Tensor] = [] all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []