mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Merge branch 'main' into custom_group_ids
This commit is contained in:
commit
0e621fac5c
1 changed files with 9 additions and 2 deletions
|
|
@ -649,7 +649,9 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
output_patch_size=self.model_output_patch_size,
|
||||
mode=DatasetMode.TEST,
|
||||
)
|
||||
test_loader = DataLoader(test_dataset, batch_size=None, pin_memory=True, shuffle=False, drop_last=False)
|
||||
test_loader = DataLoader(
|
||||
test_dataset, batch_size=None, pin_memory=self.model.device.type == "cuda", shuffle=False, drop_last=False
|
||||
)
|
||||
|
||||
all_predictions: list[torch.Tensor] = []
|
||||
# Track the current task index for custom group ID mapping
|
||||
|
|
@ -1247,7 +1249,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
mode=DatasetMode.TEST,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False
|
||||
test_dataset,
|
||||
batch_size=None,
|
||||
num_workers=0,
|
||||
pin_memory=self.model.device.type == "cuda",
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
)
|
||||
all_embeds: list[torch.Tensor] = []
|
||||
all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []
|
||||
|
|
|
|||
Loading…
Reference in a new issue