This commit is contained in:
Dario Fumarola 2026-04-27 14:52:27 +02:00 committed by GitHub
commit 10acf2ee3e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 86 additions and 0 deletions

View file

@ -74,6 +74,28 @@ class Chronos2Trainer(Trainer):
return DataLoader(train_dataset, **dataloader_params) # type: ignore
def _move_model_to_device(self, model, device):
    """
    Forward to `Trainer._move_model_to_device`, preferring the model's own CUDA device.

    When the model has no `hf_device_map`, already reports a CUDA device, and the
    requested `device` is a different CUDA device, the model's current device is
    passed to the parent implementation instead of the requested one (e.g. a model
    loaded on `cuda:5` is not relocated to a requested `cuda:0`). In every other
    case the requested `device` is forwarded unchanged.
    """
    current_device = getattr(model, "device", None)
    current_kind = getattr(current_device, "type", None)
    requested_kind = getattr(device, "type", None)
    uses_device_map = getattr(model, "hf_device_map", None) is not None

    # Only a CUDA -> CUDA request is eligible for the override; CPU and other
    # device types always follow the requested device.
    cuda_to_cuda = current_kind == "cuda" and requested_kind == "cuda"

    if (
        not uses_device_map
        and current_device is not None
        and cuda_to_cuda
        and current_device != device
    ):
        # Keep the model where it was loaded rather than relocating it.
        device = current_device

    super()._move_model_to_device(model, device)
def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader:
if self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")

View file

@ -0,0 +1,64 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from transformers import Trainer
from chronos.chronos2.trainer import Chronos2Trainer
class _DummyModel:
    """Minimal stand-in for a model that exposes only `device` and `hf_device_map`."""

    def __init__(self, device: torch.device, hf_device_map=None):
        # Store both attributes exactly as given; `hf_device_map` stays None
        # unless a mapping is supplied.
        self.device, self.hf_device_map = device, hf_device_map
def test_move_model_to_device_preserves_loaded_cuda_device(monkeypatch):
    """A model on a single CUDA device stays there even when a different CUDA device is requested."""
    forwarded = []

    def record_device(self, model, device):
        # Stand-in for the parent implementation: remember the device it was handed.
        forwarded.append(device)

    monkeypatch.setattr(Trainer, "_move_model_to_device", record_device)

    # Build the trainer without running __init__.
    trainer = object.__new__(Chronos2Trainer)
    loaded_on = torch.device("cuda:5")
    requested = torch.device("cuda:0")

    Chronos2Trainer._move_model_to_device(trainer, _DummyModel(loaded_on), requested)

    assert forwarded[-1] == torch.device("cuda:5")
def test_move_model_to_device_keeps_requested_cpu_device(monkeypatch):
    """With a CPU model and a CPU request, the requested device is forwarded unchanged."""
    forwarded = []

    def record_device(self, model, device):
        # Stand-in for the parent implementation: remember the device it was handed.
        forwarded.append(device)

    monkeypatch.setattr(Trainer, "_move_model_to_device", record_device)

    # Build the trainer without running __init__.
    trainer = object.__new__(Chronos2Trainer)
    cpu_model = _DummyModel(torch.device("cpu"))

    Chronos2Trainer._move_model_to_device(trainer, cpu_model, torch.device("cpu"))

    assert forwarded[-1] == torch.device("cpu")
def test_move_model_to_device_keeps_requested_device_for_hf_device_map(monkeypatch):
    """A model carrying an `hf_device_map` gets the requested device, not its current one."""
    forwarded = []

    def record_device(self, model, device):
        # Stand-in for the parent implementation: remember the device it was handed.
        forwarded.append(device)

    monkeypatch.setattr(Trainer, "_move_model_to_device", record_device)

    # Build the trainer without running __init__.
    trainer = object.__new__(Chronos2Trainer)
    mapped_model = _DummyModel(torch.device("cuda:5"), hf_device_map={"": "cuda:5"})
    requested = torch.device("cuda:0")

    Chronos2Trainer._move_model_to_device(trainer, mapped_model, requested)

    assert forwarded[-1] == torch.device("cuda:0")