feat: add --batch per-GPU auto-scaler to phase1

Ultralytics scales the effective weight decay (wd_eff) with batch*accumulate/nbs but never
scales lr0, so larger global batches silently drift from the recipe's intended dynamics.
The new flag takes a per-GPU batch, computes the global batch as per_gpu * world_size, and
derives lr0, nbs, and warmup_epochs from scale = max(1, global_batch / NBS_CANONICAL), with
NBS_CANONICAL = 512, so wd_eff stays at the recipe value while the per-sample learning rate
and the optimizer-step warmup count remain invariant.
This commit is contained in:
Fatih Akyon 2026-04-21 02:14:49 -05:00
parent 5105796c0f
commit 846dc24666
No known key found for this signature in database

View file

@ -20,6 +20,11 @@ RECIPES = {
"radio": dict(lr0=1e-3, weight_decay=0.02, warmup_epochs=1, epochs=30, momentum=0.9, grad_clip=1.0, beta2=0.95),
}
# Reference global step-batch (nbs) that the recipes' lr0 and warmup_epochs values are
# tuned for. When the global batch (per_gpu_batch * world_size) exceeds this, lr0 and
# warmup_epochs are scaled linearly by global_batch / NBS_CANONICAL and nbs is raised to
# the global batch, keeping wd_eff (scaled by batch*accumulate/nbs) at the recipe value
# while the per-sample learning rate stays invariant.
NBS_CANONICAL = 512
def _pop_flag(argv: list[str], flag: str, is_bool: bool = False) -> tuple[list[str], str]:
"""Pop a --flag [value] pair from argv, return (remaining_argv, value).
@ -51,6 +56,10 @@ def main(argv: list[str]) -> None:
--cos_weight <float>: cosine loss weight (default 0.9)
--l1_weight <float>: smooth L1 loss weight (default 0.1)
--cls_l1: add smooth L1 to CLS token loss (default False)
--lr <float>: override recipe lr0 (applied before batch scaling)
--batch <int>: per-GPU (per-rank) batch. Global batch = per-GPU * world_size. When the
global batch exceeds NBS_CANONICAL (512), lr0 and warmup_epochs scale linearly and
nbs is raised to the global batch so wd_eff is invariant.
"""
args = argv[1:]
args, resume = _pop_flag(args, "--resume")
@ -58,6 +67,7 @@ def main(argv: list[str]) -> None:
args, l1_w = _pop_flag(args, "--l1_weight")
args, cls_l1_str = _pop_flag(args, "--cls_l1", is_bool=True)
args, lr_override = _pop_flag(args, "--lr")
args, batch_override = _pop_flag(args, "--batch")
args, fork_from = _pop_flag(args, "--fork_from") # format: <parent_run_id>:<fork_step>
cos_weight = float(cos_w) if cos_w else 0.9
@ -77,7 +87,13 @@ def main(argv: list[str]) -> None:
data = args[5] if len(args) > 5 else "/data/shared-datasets/datacomp-12m"
epochs = int(args[6]) if len(args) > 6 else None
r = RECIPES[recipe]
lr0 = float(lr_override) if lr_override else r["lr0"]
world_size = len(gpu.split(",")) if "," in gpu else 1
global_batch = int(batch_override or 64) * world_size # default per-GPU = 64 (anchor per-rank)
scale = max(1.0, global_batch / NBS_CANONICAL)
lr0 = float(lr_override or r["lr0"]) * scale
nbs = max(global_batch, NBS_CANONICAL)
warmup_epochs = r["warmup_epochs"] * scale
model = YOLO(model_yaml)
if r["grad_clip"]:
@ -112,16 +128,16 @@ def main(argv: list[str]) -> None:
device=gpu,
**paths.run_paths(name),
epochs=epochs or r["epochs"],
batch=128,
batch=global_batch,
imgsz=224,
patience=5,
nbs=512,
nbs=nbs,
cos_lr=True,
lr0=lr0,
lrf=0.01,
momentum=r["momentum"],
weight_decay=r["weight_decay"],
warmup_epochs=r["warmup_epochs"],
warmup_epochs=warmup_epochs,
warmup_bias_lr=0,
dropout=0,
optimizer="AdamW",