From 846dc2466676a6773b14456eec4212603e1a88d8 Mon Sep 17 00:00:00 2001 From: Fatih Akyon Date: Tue, 21 Apr 2026 02:14:49 -0500 Subject: [PATCH] feat: add --batch per-GPU auto-scaler to phase1 Ultralytics scales wd_eff with batch*accumulate/nbs but never scales lr0, so larger global batches silently drift from the recipe's intended dynamics. The new flag takes a per-GPU batch, computes global = per_gpu * world_size, and derives lr0, nbs, and warmup_epochs from scale = max(1, global / NBS_CANONICAL=512) so wd_eff stays at the recipe value while per-sample lr and optimizer-step warmup count are invariant. --- run_enc_distill_phase1.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/run_enc_distill_phase1.py b/run_enc_distill_phase1.py index 7c6a12383e..90b7d44bc9 100644 --- a/run_enc_distill_phase1.py +++ b/run_enc_distill_phase1.py @@ -20,6 +20,11 @@ RECIPES = { "radio": dict(lr0=1e-3, weight_decay=0.02, warmup_epochs=1, epochs=30, momentum=0.9, grad_clip=1.0, beta2=0.95), } +# Reference global step-batch the recipes' lr0 and warmup_epochs are tuned for. When +# per_gpu_batch * world_size exceeds this, lr0 and warmup_epochs scale linearly and nbs rises +# to the global batch so wd_eff stays at the recipe value. +NBS_CANONICAL = 512 + def _pop_flag(argv: list[str], flag: str, is_bool: bool = False) -> tuple[list[str], str]: """Pop a --flag [value] pair from argv, return (remaining_argv, value). @@ -51,6 +56,10 @@ def main(argv: list[str]) -> None: --cos_weight : cosine loss weight (default 0.9) --l1_weight : smooth L1 loss weight (default 0.1) --cls_l1: add smooth L1 to CLS token loss (default False) + --lr : override recipe lr0 (applied before batch scaling) + --batch : per-GPU (per-rank) batch. Global batch = per-GPU * world_size. When the + global batch exceeds NBS_CANONICAL (512), lr0 and warmup_epochs scale linearly and + nbs is raised to the global batch so wd_eff is invariant. """ args = argv[1:] args, resume = _pop_flag(args, "--resume") @@ -58,6 +67,7 @@ def main(argv: list[str]) -> None: args, l1_w = _pop_flag(args, "--l1_weight") args, cls_l1_str = _pop_flag(args, "--cls_l1", is_bool=True) args, lr_override = _pop_flag(args, "--lr") + args, batch_override = _pop_flag(args, "--batch") args, fork_from = _pop_flag(args, "--fork_from") # format: : cos_weight = float(cos_w) if cos_w else 0.9 @@ -77,7 +87,13 @@ def main(argv: list[str]) -> None: data = args[5] if len(args) > 5 else "/data/shared-datasets/datacomp-12m" epochs = int(args[6]) if len(args) > 6 else None r = RECIPES[recipe] - lr0 = float(lr_override) if lr_override else r["lr0"] + + world_size = len(gpu.split(",")) if "," in gpu else 1 + global_batch = int(batch_override or 64) * world_size # default per-GPU = 64 (anchor per-rank) + scale = max(1.0, global_batch / NBS_CANONICAL) + lr0 = float(lr_override or r["lr0"]) * scale + nbs = max(global_batch, NBS_CANONICAL) + warmup_epochs = r["warmup_epochs"] * scale model = YOLO(model_yaml) if r["grad_clip"]: @@ -112,16 +128,16 @@ def main(argv: list[str]) -> None: device=gpu, **paths.run_paths(name), epochs=epochs or r["epochs"], - batch=128, + batch=global_batch, imgsz=224, patience=5, - nbs=512, + nbs=nbs, cos_lr=True, lr0=lr0, lrf=0.01, momentum=r["momentum"], weight_decay=r["weight_decay"], - warmup_epochs=r["warmup_epochs"], + warmup_epochs=warmup_epochs, warmup_bias_lr=0, dropout=0, optimizer="AdamW",