mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
Merge bbbdfac392 into 32111085d8
This commit is contained in:
commit
10acf2ee3e
2 changed files with 86 additions and 0 deletions
|
|
@ -74,6 +74,28 @@ class Chronos2Trainer(Trainer):
|
|||
|
||||
return DataLoader(train_dataset, **dataloader_params) # type: ignore
|
||||
|
||||
def _move_model_to_device(self, model, device):
|
||||
"""
|
||||
Keep the model on its existing CUDA device when fine-tuning a single-device model.
|
||||
|
||||
`Trainer` may otherwise move a model loaded on e.g. `cuda:5` to `args.device` (often `cuda:0`).
|
||||
"""
|
||||
model_device = getattr(model, "device", None)
|
||||
model_device_type = getattr(model_device, "type", None)
|
||||
target_device_type = getattr(device, "type", None)
|
||||
has_hf_device_map = getattr(model, "hf_device_map", None) is not None
|
||||
|
||||
if (
|
||||
not has_hf_device_map
|
||||
and model_device is not None
|
||||
and model_device_type == "cuda"
|
||||
and target_device_type == "cuda"
|
||||
and model_device != device
|
||||
):
|
||||
device = model_device
|
||||
|
||||
super()._move_model_to_device(model, device)
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader:
|
||||
if self.eval_dataset is None:
|
||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||
|
|
|
|||
64
test/test_chronos2_trainer.py
Normal file
64
test/test_chronos2_trainer.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
|
||||
from chronos.chronos2.trainer import Chronos2Trainer
|
||||
|
||||
|
||||
class _DummyModel:
|
||||
def __init__(self, device: torch.device, hf_device_map=None):
|
||||
self.device = device
|
||||
self.hf_device_map = hf_device_map
|
||||
|
||||
|
||||
def test_move_model_to_device_preserves_loaded_cuda_device(monkeypatch):
|
||||
"""When model is on a single CUDA device, keep that device instead of forcing cuda:0."""
|
||||
captured = {}
|
||||
|
||||
def fake_move(self, model, device):
|
||||
captured["device"] = device
|
||||
|
||||
monkeypatch.setattr(Trainer, "_move_model_to_device", fake_move)
|
||||
|
||||
trainer = object.__new__(Chronos2Trainer)
|
||||
model = _DummyModel(torch.device("cuda:5"))
|
||||
|
||||
Chronos2Trainer._move_model_to_device(trainer, model, torch.device("cuda:0"))
|
||||
|
||||
assert captured["device"] == torch.device("cuda:5")
|
||||
|
||||
|
||||
def test_move_model_to_device_keeps_requested_cpu_device(monkeypatch):
|
||||
"""CPU fine-tuning should preserve existing Trainer behavior."""
|
||||
captured = {}
|
||||
|
||||
def fake_move(self, model, device):
|
||||
captured["device"] = device
|
||||
|
||||
monkeypatch.setattr(Trainer, "_move_model_to_device", fake_move)
|
||||
|
||||
trainer = object.__new__(Chronos2Trainer)
|
||||
model = _DummyModel(torch.device("cpu"))
|
||||
|
||||
Chronos2Trainer._move_model_to_device(trainer, model, torch.device("cpu"))
|
||||
|
||||
assert captured["device"] == torch.device("cpu")
|
||||
|
||||
|
||||
def test_move_model_to_device_keeps_requested_device_for_hf_device_map(monkeypatch):
|
||||
"""Do not override device movement for models managed via hf_device_map."""
|
||||
captured = {}
|
||||
|
||||
def fake_move(self, model, device):
|
||||
captured["device"] = device
|
||||
|
||||
monkeypatch.setattr(Trainer, "_move_model_to_device", fake_move)
|
||||
|
||||
trainer = object.__new__(Chronos2Trainer)
|
||||
model = _DummyModel(torch.device("cuda:5"), hf_device_map={"": "cuda:5"})
|
||||
|
||||
Chronos2Trainer._move_model_to_device(trainer, model, torch.device("cuda:0"))
|
||||
|
||||
assert captured["device"] == torch.device("cuda:0")
|
||||
Loading…
Reference in a new issue