Merge branch 'main' into custom_group_ids

This commit is contained in:
Alexander März 2025-12-10 16:00:38 +01:00 committed by GitHub
commit 0e621fac5c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -649,7 +649,9 @@ class Chronos2Pipeline(BaseChronosPipeline):
output_patch_size=self.model_output_patch_size,
mode=DatasetMode.TEST,
)
test_loader = DataLoader(test_dataset, batch_size=None, pin_memory=True, shuffle=False, drop_last=False)
test_loader = DataLoader(
test_dataset, batch_size=None, pin_memory=self.model.device.type == "cuda", shuffle=False, drop_last=False
)
all_predictions: list[torch.Tensor] = []
# Track the current task index for custom group ID mapping
@ -1247,7 +1249,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
mode=DatasetMode.TEST,
)
test_loader = DataLoader(
test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False
test_dataset,
batch_size=None,
num_workers=0,
pin_memory=self.model.device.type == "cuda",
shuffle=False,
drop_last=False,
)
all_embeds: list[torch.Tensor] = []
all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []