mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
* Add ROCm detection to install.sh and expand shell tests

  Add AMD ROCm GPU detection to get_torch_index_url() in install.sh. When
  nvidia-smi is not found, probe for ROCm via amd-smi, the /opt/rocm version
  file, hipconfig, dpkg-query, and rpm. Includes a validation guard for
  malformed _rocm_tag, Debian epoch prefix stripping, a ROCm 7.2+ cap to the
  rocm7.1 index, a bitsandbytes AMD install, and status messaging. Shell
  tests expanded to 23 cases.

  Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Add ROCm torch reinstall support to install_python_stack.py

  Add _detect_rocm_version() and _ensure_rocm_torch() to detect when a Linux
  host has ROCm but the venv received CPU-only torch, and reinstall with the
  correct ROCm wheels. Covers ROCm 6.0 through 7.1, with a 30-second timeout
  on the torch GPU probe subprocess.

  Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Add ROCm support to llama.cpp prebuilt installer

  Add a has_rocm field to HostInfo, extend detect_host() to probe for ROCm
  via hipcc/amd-smi/rocm-smi/ROCM_PATH, and route ROCm hosts to upstream
  prebuilts (Linux ROCm 7.2 prebuilt with source fallback, Windows HIP
  prebuilt with CPU fallback). Add linux-rocm and windows-hip install kinds
  to runtime_patterns_for_choice().

  Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Add IS_ROCM hardware flag and fix AMD error message

  Add an IS_ROCM flag to detect_hardware() in hardware.py (set when
  torch.version.hip is present; DeviceType stays CUDA). Export IS_ROCM from
  __init__.py. Add a "rocm" key to get_package_versions(). Replace the "We do
  not support AMD" error in tokenizer_utils.py with a helpful message
  pointing to the ROCm installation docs.

  Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Add comprehensive ROCm support test suite (68 tests)

  Add tests/studio/install/test_rocm_support.py covering all ROCm code paths
  across install_llama_prebuilt.py, install_python_stack.py, hardware.py,
  tokenizer_utils.py, and install.sh. All tests use mocks and run without AMD
  hardware.
  Covers: asset selection (11), runtime patterns (5), HostInfo (4), ROCm
  version detection (9), torch reinstall (9), index mapping (8), hardware
  flag (8), tokenizer message (2), install.sh structure (10), and live
  regression (1).

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see
  https://pre-commit.ci)

* Harden ROCm support: probe error handling, version cap, validation

  Address review findings from 8 independent reviewers:

  - Wrap the _ensure_rocm_torch() torch probe in try/except for
    TimeoutExpired and OSError so a hung or broken torch import does not
    crash the installer (8/8 reviewers flagged this)
  - Add a torch>=2.4,<2.11.0 version cap to the ROCm reinstall path to
    prevent installing unsupported torch 2.11.0 from the rocm7.1 index
  - Use a with-statement for file reads in _detect_rocm_version() to avoid
    resource leaks
  - Handle ROCM_PATH="" correctly (use `or "/opt/rocm"` instead of a default
    parameter to avoid relative path resolution)
  - Strengthen the shell validation guard from rocm[0-9] to rocm[1-9] to
    reject rocm0.x tags that would produce nonexistent PyTorch index URLs
  - Switch the shell version cap from a blocklist to an allowlist
    (rocm6.*|rocm7.0*|rocm7.1* pass through; everything else caps to rocm7.1)
    so future ROCm 10+ does not fall through to a nonexistent index
  - Add sorted() to the _ROCM_TORCH_INDEX lookup for defensive ordering
  - Fix test_probe_timeout_handled: replace a zero-assertion test with proper
    assertions verifying that the reinstall proceeds after a timeout

* Clean up rocm_paths list construction in detect_host()

  Filter None from the ROCM_PATH env var lookup at list construction time
  instead of relying on the inline `if p` guard in the any() call.

* Require actual AMD GPU presence before selecting ROCm paths

  All 8 reviewers across 2 cycles independently flagged that ROCm detection
  used toolkit/filesystem hints (hipcc, /opt/rocm, rocm-core) as a proxy for
  GPU presence, which would misroute CPU-only or NVIDIA hosts that happen to
  have ROCm tools installed. Now all 3 detection points (install.sh,
  install_python_stack.py, install_llama_prebuilt.py) probe for an actual AMD
  GPU before entering the ROCm path:

  - install.sh: check rocminfo for gfx* GPU names, or amd-smi list for device
    rows, before version detection
  - install_python_stack.py: a new _has_rocm_gpu() function probes rocminfo
    and amd-smi list before _ensure_rocm_torch() proceeds
  - install_llama_prebuilt.py: detect_host() probes rocminfo/amd-smi list
    instead of just checking tool existence or directory paths

  Also:

  - The shell test mock amd-smi now handles the "list" subcommand
  - Python tests updated to mock _has_rocm_gpu where needed
  - Added test_no_gpu_with_rocm_tools_skips to verify the new guard
  - Test index lookups now use sorted() to match production code

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see
  https://pre-commit.ci)

* Harden hipconfig version parsing and torch probe compatibility

  - Add a parts[1].isdigit() check in hipconfig version parsing to handle
    versions like "6.3-HIP" where the minor component has a non-numeric
    suffix (strip the "-" suffix before the int() conversion)
  - Use getattr() in the torch probe subprocess to safely handle old or
    custom torch builds that may lack the torch.version.hip/cuda attributes

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see
  https://pre-commit.ci)

* Strengthen AMD GPU detection and add NVIDIA precedence guard

  - Change amd-smi list detection from any-non-empty-output to requiring a
    "gpu" marker in the output, matching the shell-side NR>1 check. Prevents
    false positives from header-only amd-smi list output.
  - Add an nvidia-smi check at the top of _ensure_rocm_torch() so mixed
    AMD+NVIDIA hosts preserve NVIDIA precedence (matching install.sh and
    install_llama_prebuilt.py behavior).
  - Apply the same amd-smi marker fix to detect_host() in
    install_llama_prebuilt.py for consistency.
* Add Windows-specific ROCm/HIP detection in detect_host()

  The previous detect_host() ROCm check used rocminfo and amd-smi list, which
  are Linux-only tools. On Windows, has_rocm would always be False, making
  the Windows HIP prebuilt path at line 1794 unreachable. Now detect_host()
  uses platform-specific detection:

  - Linux: rocminfo (check for gfx GPU names) or amd-smi list
  - Windows: hipinfo.exe, amd-smi, or amdhip64.dll on PATH

  This allows Windows AMD users to get the HIP prebuilt binary instead of
  silently falling through to the CPU prebuilt.

* Add AMD ROCm gaps: Mamba/SSM source builds, GPU monitoring, Windows
  messaging, RDNA expansion

  - worker.py: Add HIP detection to the causal-conv1d/mamba-ssm probe, check
    for hipcc before ROCm source builds, improve status messages and error
    reporting, and add timeout and uv support for the source build fallback
  - amd.py: New AMD GPU monitoring module via amd-smi metric --json,
    mirroring the nvidia.py structure (utilization, temperature, power, VRAM)
  - hardware.py: Branch to amd.py when IS_ROCM is True for GPU utilization,
    visible GPU queries, and physical GPU count
  - install_python_stack.py: Detect AMD GPUs on Windows and warn that
    ROCm-enabled PyTorch must be installed manually
  - kernels/utils.py: Expand is_rdna() to cover RDNA2 (gfx1030-1032), RDNA3
    (gfx1102-1103), and RDNA3.5 (gfx1150-1152) alongside the existing entries
  - tests: Add 32 new tests covering all changes (95/95 pass)

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see
  https://pre-commit.ci)

* Harden ROCm detection, fix VRAM heuristic, and expand RDNA2 coverage

  - Windows ROCm detection: validate actual GPU presence via hipinfo/amd-smi
    output markers instead of just checking tool existence on PATH
  - _ensure_rocm_torch: validate that nvidia-smi actually reports a GPU
    before giving NVIDIA precedence (fixes AMD-only hosts with stale NVIDIA
    tools)
  - amd.py _parse_numeric: handle dict-shaped metric objects from newer
    amd-smi versions ({"value": 10, "unit": "W"}) and strip MiB/GiB units
  - amd.py VRAM heuristic: raise the threshold from 100k to 10M to correctly
    handle the MI300X (192 GB = 196608 MB) and other high-VRAM GPUs
  - amd.py visible GPU: use AMD-reported GPU IDs instead of the enumerate
    index so non-dense sets like CUDA_VISIBLE_DEVICES=1,3 report correctly
  - install.sh: add a ROCm <6.0 minimum version guard (no PyTorch wheels
    exist for older versions); fix the rocm7.1* glob to not match rocm7.10+
  - is_rdna: add gfx1033-1036 for RDNA2 mobile GPUs (RX 6600M etc.)
  - worker.py: increase the ROCm source build timeout from 600s to 1800s; fix
    the success log message for ROCm source builds
  - Tests: update mocks for _has_usable_nvidia_gpu, add RDNA2 target asserts

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see
  https://pre-commit.ci)

* Add HIP_VISIBLE_DEVICES support, unit-aware VRAM parsing, Windows GPU
  validation

  - hardware.py: check HIP_VISIBLE_DEVICES and ROCR_VISIBLE_DEVICES on ROCm
    before falling back to CUDA_VISIBLE_DEVICES, so multi-GPU AMD setups with
    HIP-specific env vars report the correct visible device set
  - amd.py: add _parse_memory_mb(), which reads "unit" from dict-shaped
    amd-smi JSON (e.g. {"value": 192, "unit": "GiB"}) and converts to MB
    correctly; fixes MI300X VRAM being misreported as 0.19 GB instead of
    192 GB
  - install_python_stack.py: the Windows AMD warning now validates actual GPU
    presence via hipinfo/amd-smi output markers before printing
  - install_llama_prebuilt.py: restore the amdhip64.dll fallback for Windows
    HIP detection after the tool-based checks, so Windows HIP installs
    without CLI tools on PATH are still detected
  - hardware.py: fix the IS_ROCM comment to accurately describe its role

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see
  https://pre-commit.ci)

* Fix HIP_VISIBLE_DEVICES empty-string handling in GPU visibility spec

  Use explicit None checks instead of the Python `or` operator when reading
  HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES, so that an empty string ("") is
  correctly honored as "no visible GPUs" rather than silently falling through
  to CUDA_VISIBLE_DEVICES on mixed ROCm+CUDA systems.

* Fix IS_ROCM test assertion for multi-line formatting

* Cap torchvision/torchaudio versions, remove amdhip64.dll fallback, fix
  visible GPU count

  - Cap torchvision<0.26.0 and torchaudio<2.11.0 alongside torch<2.11.0 in
    both install.sh and install_python_stack.py to prevent the resolver from
    selecting incompatible companion packages from the ROCm wheel index
  - Remove the amdhip64.dll fallback in Windows ROCm detection (DLL presence
    without hipinfo/amd-smi is not proof of GPU existence)
  - Fix get_visible_gpu_count() to use _get_parent_visible_gpu_spec(), which
    respects HIP_VISIBLE_DEVICES/ROCR_VISIBLE_DEVICES on ROCm hosts

* Attribute is_rdna() RDNA2/3/3.5/4 expansion to PR #4428

  The is_rdna() expansion to cover the RDNA2 (gfx1030-1036), RDNA3
  (gfx1100-1103), RDNA3.5 (gfx1150-1152), and RDNA4 (gfx1200-1201)
  architectures is based on the original work from PR #4428.
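For illustration, an is_rdna()-style check over the gfx targets credited
above might look like the sketch below. The target list is taken from the
commit text; the function shape and names are hypothetical, not the
repository's actual kernels/utils.py implementation:

```python
# Sketch: RDNA-family gfx targets named in the attribution above.
# RDNA2: gfx1030-1036, RDNA3: gfx1100-1103, RDNA3.5: gfx1150-1152,
# RDNA4: gfx1200-1201.
_RDNA_GFX_TARGETS = tuple(
    f"gfx{n}"
    for n in (
        *range(1030, 1037),  # RDNA2, incl. 1033-1036 mobile parts
        *range(1100, 1104),  # RDNA3
        *range(1150, 1153),  # RDNA3.5
        *range(1200, 1202),  # RDNA4
    )
)

def is_rdna(gcn_arch: str) -> bool:
    """True when the gfx target string names an RDNA-family GPU."""
    # Strip feature suffixes like ":sramecc+:xnack-" before matching.
    arch = gcn_arch.lower().split(":")[0]
    return arch.startswith(_RDNA_GFX_TARGETS)
```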
  Co-authored-by: GoldenGrapeGentleman <yueyuan@amd.com>
  Co-authored-by: billishyahao <bill.he@amd.com>

* Support AMD Radeon for studio (#4770)

  Co-authored-by: Iswarya Alex <iswarya.alex@amd.com>

* Remove ROCm test files from main PR

  Move test_rocm_support.py and the shell test additions to a separate PR to
  keep the main ROCm support PR focused on implementation changes.

* Fix installer and hardware detection issues for PR #4720

  - Fix the empty _tri_arg passed to uv pip install in the Radeon path
    (caused an "Empty field is not allowed for PEP508" error)
  - Fix the Radeon fallback: use the ROCm index instead of CPU-only when
    repo.radeon.com is unreachable (TORCH_INDEX_URL already has ROCm)
  - Use $TORCH_CONSTRAINT in the fallback paths instead of hardcoded strings
  - Fix _pick_radeon_wheel: relax the suffix to match manylinux_2_28_x86_64
    wheels (the AMD Radeon repo does not use the bare linux_x86_64 platform
    tag)
  - Fix the IS_ROCM export: use __getattr__ so callers always see the live
    value after detect_hardware() runs
  - Fix apply_gpu_ids: set HIP_VISIBLE_DEVICES and ROCR_VISIBLE_DEVICES on
    ROCm so _get_parent_visible_gpu_spec picks up the narrowed GPU set
  - Fix _parse_memory_mb: distinguish GB (1000 MB) from GiB (1024 MiB)
  - Add amd-smi version as a fallback in _detect_rocm_version
  - Fix trailing whitespace and a missing newline at EOF in install.sh

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see
  https://pre-commit.ci)

* Fix GPU detection false positives and add missing health groups

  - Fix a _has_rocm_gpu() false positive: require "GPU: <number>" data rows
    from amd-smi list, not just a header containing "gpu"
  - Apply the same fix in detect_host() in install_llama_prebuilt.py
  - Add runtime_payload_health_groups for linux-rocm and windows-hip so
    partial/corrupt ROCm/HIP prebuilt installs are properly detected
  - Add the bitsandbytes install to the Radeon fallback paths (it was only in
    the success path, skipped when repo.radeon.com was unreachable)
  - Keep DEVICE/CHAT_ONLY as direct imports in __init__.py (matching main)
    and only use __getattr__ for IS_ROCM

* Fix _ensure_rocm_torch and Windows AMD warning false positives

  - _ensure_rocm_torch: only skip when HIP is already present, not for CUDA
    builds (which are unusable on AMD-only hosts). Fixes the case where a
    venv has a stale CUDA wheel and the repair step is skipped.
  - Windows AMD warning: use the GPU data row check (same as the Linux fix)
    to avoid false positives from header-only amd-smi list output.

* Fix amd-smi GPU detection for GPU[N] output format

  Older amd-smi versions output "GPU[0] : Card series: ..." instead of
  "GPU: 0". The regex now matches both the "GPU: <digit>" and "GPU[<digit>"
  formats to detect actual GPU data rows.

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see
  https://pre-commit.ci)

* Harden AMD GPU detection against false positives

  - install.sh: replace the weak amd-smi list check (awk 'NR>1 && NF') with
    strict pattern matching of GPU data rows (/^GPU[[:space:]]*[:\[]/)
  - All files: reject the rocminfo gfx000 (CPU HSA agent) by requiring
    gfx[1-9] instead of gfx[0-9] in the rocminfo GPU probe
  - Fixes false positives on hosts with ROCm tools but no AMD GPU

* Remove duplicate comment from pre-commit merge

* Refactor: deduplicate AMD detection, consolidate bitsandbytes, clean up
  imports

  - Extract a _has_amd_rocm_gpu() shell function to avoid duplicating the
    rocminfo/amd-smi GPU detection logic in get_torch_index_url and the
    Radeon auto-detect block
  - Consolidate the bitsandbytes install into a single case block after the
    torch install (it was duplicated 4 times across the Radeon
    success/fallback paths)
  - Move the math and re imports to the top of amd.py (they were inline in
    functions)
  - Add a _smi_query() helper in hardware.py to centralize IS_ROCM backend
    selection for get_gpu_utilization and get_visible_gpu_utilization

  Addresses Gemini code review suggestions.

* Fix VRAM parsing for string values and GB/GiB consistency

  - Extract the unit from string-valued VRAM fields (e.g. "192 GiB") so
    _parse_memory_mb correctly applies the unit multiplier instead of
    treating the value as bare MB
  - Treat GB and GiB identically (both as binary x1024), since GPU tools
    including amd-smi use binary units even when labeling them "GB"
  - Fixes incorrect VRAM reporting on MI300-class cards (was showing ~0.19 GB
    instead of 192 GB for string-valued outputs)

* Add --no-cache to uv for ROCm HIP source builds

  Avoid stale cache artifacts from partial HIP source builds when uv is used
  for causal-conv1d/mamba-ssm compilation on ROCm. The pip path already uses
  --no-cache-dir; this adds the uv equivalent (--no-cache) only when is_hip
  is True.

* Fix critical: initialize _amd_gpu_radeon before case block

  _amd_gpu_radeon was only set inside the */rocm*) case arm, so on
  NVIDIA/CPU/macOS paths where TORCH_INDEX_URL does not contain "rocm", the
  variable was unbound. With set -u (nounset) enabled, this crashes the
  installer for every non-AMD user. Move the initialization to before the
  case block so it is always defined.

* Fix Windows AMD: route has_rocm hosts to HIP prebuilt path

  resolve_release_asset_choice was selecting windows-cpu for all Windows
  x86_64 hosts, including those with has_rocm=True. Windows AMD users should
  fall through to resolve_upstream_asset_choice, which tries the HIP prebuilt
  first. Add a "not host.has_rocm" guard to the published windows-cpu
  selection.

* Harden ROCm detection, Radeon wheel fallback, and HIP visibility

  Addresses review findings from parallel reviewers on PR #4720:

  - install.sh: add a _has_usable_nvidia_gpu() helper requiring nvidia-smi -L
    to actually list a GPU before treating the host as NVIDIA. Fixes the
    stale-nvidia-smi-on-PATH regression where AMD-only hosts fell into the
    CUDA branch.
  - install.sh: fix the hipconfig awk blocks to propagate a non-zero exit
    code when the output is not a recognisable version string, so the
    ||-chain continues to dpkg-query / rpm instead of terminating early.
  - install.sh: fail closed on the Radeon wheel fallback. When torch,
    torchvision, or torchaudio is missing from the Radeon repo for the active
    Python tag, fall back to the standard ROCm index instead of silently
    mixing Radeon wheels with PyPI defaults. Quote all wheel arguments
    individually so wheel filenames cannot be word-split or glob-expanded.
  - install_llama_prebuilt.py: detect_host() now requires nvidia-smi -L to
    list a GPU before setting has_physical_nvidia. Routes AMD ROCm hosts with
    a broken leftover nvidia-smi to the ROCm path instead of misclassifying
    them as NVIDIA.
  - install_llama_prebuilt.py: scan upstream assets for any rocm-<version>
    prebuilt instead of hard-coding rocm-7.2, so ROCm 6.x / 7.0 / 7.1 / 7.3+
    users pick up a matching upstream prebuilt when one exists.
  - install_llama_prebuilt.py: validate_server() adds --n-gpu-layers 1 for
    linux-rocm and windows-hip hosts, so new HIP prebuilts are preflighted on
    the GPU path instead of passing validation on CPU only.
  - install_llama_prebuilt.py: restore the published windows-cpu fallback for
    AMD Windows hosts without a HIP prebuilt, so hash-approved bundles are
    still preferred over the raw upstream CPU asset.
  - install_python_stack.py: drop the /opt/rocm / hipcc gate in
    _ensure_rocm_torch() and rely on _has_rocm_gpu(). Runtime-only ROCm
    installs (package-managed minimal installs, Radeon software) that ship
    amd-smi / rocminfo without hipcc can now repair a CPU-only venv via
    "unsloth studio update". Adds an explicit IS_WINDOWS / IS_MACOS guard.
  - studio/backend/utils/hardware/amd.py: honour HIP_VISIBLE_DEVICES /
    ROCR_VISIBLE_DEVICES / CUDA_VISIBLE_DEVICES in
    get_primary_gpu_utilization(). A process restricted to GPU 2 now reports
    metrics for GPU 2 instead of physical GPU 0. Tighten the plain bytes unit
    detection to an explicit allowlist.
  - studio/backend/utils/hardware/hardware.py: route
    get_backend_visible_gpu_info()'s backend_cuda_visible_devices field
    through a helper that reads HIP_VISIBLE_DEVICES on ROCm. Drop the
    unconditional "(rocm=False)" suffix in apply_gpu_ids() logs.

* Fix round 2 regressions: ROCm validate_server and Windows HIP routing

  Follow-up to 810b833b, addressing review findings on the first round of
  hardening commits:

  - install_llama_prebuilt.py validate_server: gate --n-gpu-layers on the
    resolved install_kind instead of host.has_rocm. AMD Windows hosts without
    a HIP prebuilt fall back to windows-cpu and must not be validated with
    GPU layers; thread install_kind through from the caller.
  - install_llama_prebuilt.py resolve_release_asset_choice: reinstate the
    "not has_rocm" guard on the published windows-cpu bundle so AMD Windows
    hosts reach resolve_upstream_asset_choice(), where the new HIP prebuilt
    path lives. Prefer a published windows-hip bundle first when one exists,
    and fall through to upstream HIP + upstream CPU otherwise.
  - install_llama_prebuilt.py detect_host: also set has_physical_nvidia when
    the secondary --query-gpu block confirms a working NVIDIA GPU, so older
    nvidia-smi versions without -L support do not silently skip the Linux
    diagnostics that key off has_physical_nvidia.
  - install_llama_prebuilt.py: drop the redundant "import re as _re" /
    "import re as _re_rocm" local aliases in favour of the existing top-level
    "import re".
  - install_python_stack.py _ensure_rocm_torch: run the AMD bitsandbytes
    install unconditionally after the HIP-torch probe so "unsloth studio
    update" on venvs that already have ROCm torch still gains the AMD
    bitsandbytes build.
  - install.sh: add a non-x86_64 early exit to get_torch_index_url() so
    aarch64 / arm64 Linux hosts do not hit the ROCm wheel index (PyTorch only
    publishes ROCm wheels for linux_x86_64).
  - install.sh: add the bitsandbytes install to the migrated-environment
    branch so upgrades pick it up for ROCm hosts instead of only the
    fresh-install path.
  - install.sh: in the Radeon wheel path, pass version constraints +
    --no-index --find-links to uv instead of explicit wheel URLs, so a
    version-compatible torch / torchvision / torchaudio triple is resolved
    rather than picking the highest-version wheel for each package
    independently.
  - studio/backend/utils/hardware/amd.py _first_visible_amd_gpu_id: fall
    through to lower-priority visibility env vars when the first entry is
    malformed (leading comma, all-whitespace first token) instead of silently
    returning GPU 0.

* Fix round 3 findings: x86_64 guard, ROCm version clip, Radeon deps

  Address issues surfaced by the round 3 reviewers on top of 8636fa63:

  - install_python_stack.py _ensure_rocm_torch: add the same `x86_64` guard
    that install.sh already has. Linux aarch64 / arm64 ROCm hosts must skip
    the repair path entirely; PyTorch only publishes ROCm wheels for
    linux_x86_64, and without this guard `unsloth studio update` aborts with
    a missing-wheel error on non-x86_64 hosts.
  - install_llama_prebuilt.py resolve_upstream_asset_choice: add a
    best-effort _detect_host_rocm_version() helper (reading
    /opt/rocm/.info/version, amd-smi version, hipconfig --version) and filter
    rocm_candidates to entries whose major.minor is <= the host version. Fall
    back to the newest candidate only when no compatible one exists, so a
    ROCm 6.4 host downloads rocm-6.4 instead of being handed the numerically
    newest rocm-7.2 bundle (which fails preflight and forces a source build).
  - install.sh: remove the round 2 --no-index switch from the Radeon wheel
    branch. --no-index forced uv to ignore PyPI entirely, which broke
    transitive dependency resolution (filelock, sympy, networkx, jinja2,
    fsspec, setuptools, typing-extensions, ...) on a fresh venv. Restore the
    round 1 explicit wheel URL invocation, but add a torch / torchvision /
    torchaudio version-pair sanity check so a mismatched trio (e.g. torch
    2.9.1 + torchvision 0.23.0 + torchaudio 2.9.0) falls back to the standard
    ROCm index instead of installing a broken combination.
  - install_python_stack.py _ensure_rocm_torch: restructure the "tag is None"
    path so it no longer short-circuits the bitsandbytes install. On a ROCm
    runtime older than anything in _ROCM_TORCH_INDEX, print the "no wheel"
    warning but still run the AMD bitsandbytes install.
  - studio/backend/core/training/worker.py: restore the pre-PR "no timeout"
    behaviour for non-HIP causal-conv1d / mamba-ssm source builds. The round
    2 "timeout = 1800 if is_hip else 300" cap aborts slow non-HIP builds
    (Linux aarch64, unsupported torch/CUDA combos) after 5 minutes; omit the
    timeout for the non-HIP branch so the cap only applies to ROCm source
    builds.

* Fix round 4 findings: apply_gpu_ids env inheritance, Radeon X.Y,
  bitsandbytes gate

  Address the remaining issues surfaced by the round 4 reviewers:

  - studio/backend/utils/hardware/hardware.py apply_gpu_ids: mirror the
    selection into HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES whenever the
    caller already had a ROCm visibility env var set, not only when IS_ROCM
    has already been set by detect_hardware(). Training and inference workers
    call apply_gpu_ids() before detect_hardware() runs, so the old guard
    would leave a forked ROCm worker with a stale HIP_VISIBLE_DEVICES mask
    that no longer matched the narrowed CUDA_VISIBLE_DEVICES selection.
  - install.sh get_radeon_wheel_url: accept X.Y ROCm versions in addition to
    X.Y.Z. The `/opt/rocm/.info/version` file and some hipconfig versions
    report only two components, and the Radeon repository publishes both
    rocm-rel-X.Y.Z/ and rocm-rel-X.Y/ directories, so treating X.Y as invalid
    caused Radeon hosts to fall back to the generic ROCm index even when a
    matching AMD wheel set existed.
  - install_python_stack.py _ensure_rocm_torch: only install the AMD
    bitsandbytes build when the venv actually has a ROCm-compatible torch
    (either already present or just installed by this function). Previously
    the bitsandbytes install ran unconditionally, which could leave an AMD
    bitsandbytes layered on top of a CPU/CUDA torch on hosts where the ROCm
    runtime is older than any entry in _ROCM_TORCH_INDEX. Also add
    --force-reinstall so an existing CPU/CUDA bitsandbytes is replaced by the
    AMD build during upgrades.

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see
  https://pre-commit.ci)

* Fix gemini findings: amd-smi metric envelope validation and dict-wrapped
  GPU id

  Two medium-severity defensive fixes from the gemini-code-assist review on
  the AMD monitoring backend:

  1. _extract_gpu_metrics may return a dict where every value is None when
     amd-smi succeeds (zero exit) but the JSON envelope contains no usable
     fields (error response, unsupported card). The new _has_real_metrics
     helper lets get_primary_gpu_utilization surface available:False and lets
     get_visible_gpu_utilization skip ghost device rows, so the UI does not
     render placeholder cards with empty numbers.
  2. Newer amd-smi versions wrap scalar fields as {"value": 0, "unit":
     "none"}, including the per-GPU id. The previous int(raw_id) call
     silently fell back to the enumeration index in that case, losing the
     real GPU id. Routing raw_id through the existing _parse_numeric helper
     handles bare ints, floats, strings, and the dict shape uniformly, with a
     debug log on parse failure.

* Fix gemini round 2 findings: explicit length guard on ROCm version file
  parser

  Both _detect_rocm_version (install_python_stack.py) and
  _detect_host_rocm_version (install_llama_prebuilt.py) read
  /opt/rocm/.info/version or $ROCM_PATH/lib/rocm_version, split on "." and
  unconditionally accessed parts[1]. The surrounding broad `except Exception:
  pass` already swallowed the resulting IndexError, so a one-component file
  like "6\n" did fall through to the next detection source -- but the control
  flow relied on exception handling instead of an explicit check.
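A guarded version of that parse, together with the kind of wheel-index lookup
it feeds, might look like the sketch below. The _ROCM_TORCH_INDEX table and
the helper names here are illustrative stand-ins, not the shipped code; the
table's entries follow the supported-index set and the 7.2+ -> rocm7.1 /
6.5+ -> rocm6.4 clipping behaviour described in this log:

```python
from typing import Optional, Tuple

# Illustrative stand-in for the real _ROCM_TORCH_INDEX table.
_ROCM_TORCH_INDEX = {
    (6, 0): "https://download.pytorch.org/whl/rocm6.0",
    (6, 1): "https://download.pytorch.org/whl/rocm6.1",
    (6, 2): "https://download.pytorch.org/whl/rocm6.2",
    (6, 3): "https://download.pytorch.org/whl/rocm6.3",
    (6, 4): "https://download.pytorch.org/whl/rocm6.4",
    (7, 0): "https://download.pytorch.org/whl/rocm7.0",
    (7, 1): "https://download.pytorch.org/whl/rocm7.1",
}

def parse_rocm_version(text: str) -> Optional[Tuple[int, int]]:
    """Parse 'X.Y[.Z]' with an explicit length guard, no IndexError relied on."""
    parts = text.strip().split(".")
    if len(parts) >= 2 and parts[0].isdigit() and parts[1].isdigit():
        return int(parts[0]), int(parts[1])
    return None  # one-component files like "6\n" fall through explicitly

def torch_index_for(version: Tuple[int, int]) -> Optional[str]:
    """Pick the newest supported index tag <= the detected ROCm version."""
    candidates = [v for v in sorted(_ROCM_TORCH_INDEX) if v <= version]
    return _ROCM_TORCH_INDEX[candidates[-1]] if candidates else None
```

Returning None both for an unparseable file and for a pre-6.0 runtime mirrors
the fall-through behaviour the log describes: the caller moves on to the next
detection source or skips the repair rather than crashing.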
  Add `if len(parts) >= 2:` guards in both helpers so the loop falls through
  on its own without raising. Behaviour is unchanged for the common
  multi-component case; the previously-silent IndexError path becomes an
  explicit no-op.

* Fix gemini round 3: include has_rocm in validate_server fallback path

  When validate_server is called without an explicit install_kind (older call
  sites that have not been updated), the fallback was only enabling
  --n-gpu-layers for NVIDIA and macOS arm64 hosts. AMD ROCm Linux hosts fell
  through to the CPU validation path even though the prebuilt being exercised
  was a HIP binary. Add host.has_rocm to the fallback expression so the GPU
  offload flag is applied consistently with the install_kind == 'linux-rocm'
  / 'windows-hip' branches above.

* Fix gemini round 4: remove risky bytes-vs-MB heuristic in _parse_memory_mb

  The previous heuristic divided any bare number above 10_000_000 by
  1024*1024 on the assumption that large unit-less values were bytes. This
  misclassified small VRAM allocations: 5 MB of used VRAM, reported as
  5_242_880 bytes without a unit, would be taken at face value and render as
  5_242_880 MB (~5 TB) in the monitoring UI. Modern amd-smi always provides
  explicit units (MiB/GiB dict form), and legacy amd-smi returns bare numbers
  in MB -- the heuristic never had a real workload to handle. Drop it and
  default to MB for bare numeric input, keeping the existing unit-aware
  branches for dict / string inputs unchanged.

  The unrelated gemini suggestion to "default minor to 0" in the amd-smi
  version awk parser was intentionally NOT applied: rocm7.0 and rocm7.1 ship
  different wheel sets, so silently substituting 0 for a missing minor could
  install the wrong wheels. The existing reject-and-fall-through behaviour is
  safer.

* Fix gemini round 5: POSIX compliance and leading-comma visibility parsing

  Three medium findings from gemini-code-assist addressed in this commit:

  1. _pick_radeon_wheel used grep -o and sort -V, both GNU extensions that
     are not in POSIX and break on BSD/BusyBox coreutils. install.sh has a
     #!/bin/sh shebang, so the whole pipeline was rewritten as a single awk
     script that extracts all href="..." hits on each line, filters to wheels
     matching the package prefix and python tag, and picks the newest version
     via zero-padded lexical comparison. No external sort or grep is needed.
  2. _first_visible_amd_gpu_id in the AMD monitoring backend treated a
     leading comma (e.g. HIP_VISIBLE_DEVICES=",1") as "fall through to the
     next env var", which is surprising given the clear intent to narrow to
     device 1. Filter empty tokens after the split and return the first real
     one. An all-commas value ("," / ",,,") still falls through because no
     real tokens exist; the empty-string and "-1" explicit-zero cases are
     unchanged.

  The unrelated amd-smi version awk parser suggestion was not applied (see
  the round 4 commit message for the rationale: defaulting a missing minor to
  0 could silently install the wrong ROCm wheel set).

* Fix 20-reviewer.py findings: base drift, Radeon %2B, dpkg/rpm fallback,
  bnb, backend label

  Consolidated fix batch from a 20-parallel reviewer.py run on the current
  head. Each fix is drawn from a high-consensus finding and addresses a real
  bug or feature gap, not a stylistic preference.

  1. install.sh: bump `unsloth>=2026.4.2` -> `unsloth>=2026.4.4` at five call
     sites so this branch no longer regresses main's version floor (main
     bumped to 2026.4.4 in #4876). Without this, merging 4720 would silently
     downgrade the minimum version pin for fresh installs.
  2. install.sh: URL-decode Radeon wheel names before extracting the torch /
     torchvision / torchaudio version strings. Real wheel URLs from
     repo.radeon.com are percent-encoded ("torch-2.10.0%2Brocm7.2.0...") so
     the previous `[+-]` terminator in the sed regex never matched,
     `_torch_ver` stayed empty, `_radeon_versions_match` stayed false, and
     every Radeon consumer install silently fell back to the generic ROCm
     index. Now decode %2B -> + first, then extract, then validate.
  3. install.sh: the two AMD bitsandbytes install lines were running `uv pip
     install "bitsandbytes>=0.49.1"` without `--force-reinstall`, so upgrades
     where the venv already has a CPU/CUDA bitsandbytes satisfying the
     constraint would keep the stale non-AMD wheel. Add `--force-reinstall
     --no-cache-dir` to both call sites, matching the pattern already used in
     install_python_stack.py::_ensure_rocm_torch.
  4. install_python_stack.py and install_llama_prebuilt.py: add `dpkg-query
     -W rocm-core` and `rpm -q rocm-core` fallbacks to the Python-side ROCm
     version detectors so they match the chain in
     install.sh::get_torch_index_url. Package-managed ROCm installs
     (Debian/Ubuntu/RHEL/Fedora distro packages) can expose GPUs via
     rocminfo/amd-smi but still lack /opt/rocm/.info/version, hipconfig, or
     amd-smi `version` output -- without these fallbacks, `unsloth studio
     update` on such hosts returned None and skipped the ROCm torch repair.
     Also strip the dpkg epoch prefix ("1:6.3.0-1") before parsing so
     epoch-annotated packages parse correctly.
  5. hardware.py: add a _backend_label(device) helper that returns "rocm"
     when IS_ROCM is set and the device is DeviceType.CUDA, and use it for
     every `"backend": ...` emission in JSON responses served to the Studio
     frontend. Internally we still represent ROCm hosts as DeviceType.CUDA
     (ROCm torch reuses the whole torch.cuda.* API surface), but the
     user-facing API now correctly reports "rocm" on AMD boxes instead of
     labeling them as "cuda".
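A minimal sketch of such a _backend_label helper, with DeviceType and the
IS_ROCM flag as stand-ins for the backend's real definitions (the real helper
lives in hardware.py per the commit text):

```python
from enum import Enum

# Stand-in for the backend's DeviceType enum.
class DeviceType(Enum):
    CUDA = "cuda"
    MPS = "mps"
    CPU = "cpu"

def backend_label(device: DeviceType, is_rocm: bool) -> str:
    """Report 'rocm' for AMD hosts that are internally modeled as CUDA.

    ROCm torch reuses the torch.cuda.* API surface, so DeviceType stays
    CUDA internally; only the user-facing label is swapped.
    """
    if is_rocm and device is DeviceType.CUDA:
        return "rocm"
    return device.value
```

Keeping the swap in one helper rather than at each emission site is what lets
every JSON response agree on the label.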
All 250 simulation scenarios pass (was 233 before this batch: added 17 new regression tests covering the version pin, %2B decoding, bnb force-reinstall flags, dpkg/rpm fallback presence, and the _backend_label helper's four-way truth table).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix gemini round 6 + URL audit: amd.py defensive checks, rocm6.5+ clip to 6.4

Two rounds of fixes in one commit, plus a full URL audit of every PyPI / download.pytorch.org / repo.radeon.com reference the PR introduces.

amd.py (4 medium gemini findings on commit b3627bc2):

1. _extract_gpu_metrics used `and vram_total_mb` as part of the vram_util gate. The follow-up `vram_total_mb > 0` already handles the division guard, but the truthiness check was redundant and slightly surprising for a 0.0 valid value. Replace with explicit `is not None and > 0` for both vram_util and power_util.

2. get_physical_gpu_count called `data.get("gpu", ...)` without guarding for non-dict envelopes. A scalar / string JSON response from amd-smi would raise AttributeError. Add an isinstance(data, dict) check and return None for unexpected shapes.

3. get_visible_gpu_utilization had the same .get() exposure on the outer envelope. Rewrite the gpu_list extraction as an explicit list/dict/else cascade so a malformed scalar envelope produces gpu_list=[data] and continues without raising.

4. The same function's per-entry loop also called gpu_data.get() on whatever was inside gpu_list. If a scalar ever leaks into the list (directly or via the previous fix's fallback), _extract_gpu_metrics would raise on the first .get() inside the helper. Skip non-dict entries in the loop before extracting metrics.

install.sh (URL audit finding, previously flagged by 20-reviewer as #13):

5. get_torch_index_url used `rocm6.*` in the rocm tag case statement, which matched rocm6.5 and rocm6.6 and emitted download.pytorch.org/whl/rocm6.5 -- which returns HTTP 403 because PyTorch only publishes rocm 5.7, 6.0-6.4, 7.0-7.2. Enumerate the supported 6.x minors explicitly and add a rocm6.* fallback branch that clips to rocm6.4 (the last supported 6.x wheel set).

URL audit results (all URLs PR 4720 references):

- 14/14 download.pytorch.org/whl/{cpu,cu118,cu124,cu126,cu128,cu130,rocm6.0..6.4,rocm7.0..7.2} return HTTP 200.
- 9/9 repo.radeon.com/rocm/manylinux/rocm-rel-{5.7,6.0,6.1,6.2,6.3,6.4,7.0,7.1,7.2}/ return HTTP 200.
- X.Y.Z patch directories exist for 7.0.2, 7.1.1, 7.2.1 but NOT for 6.3.0, 6.4.0, 6.2.1 -- install.sh already handles this via the X.Y.Z -> X.Y fallback sed in the Radeon wheel install block.
- Docs links (rocm.docs.amd.com, docs.unsloth.ai AMD guide) and the llama.cpp GitHub releases API endpoint all return 200.

Test suite: 255 -> 258. New regression coverage:

- U17: get_physical_gpu_count tolerates scalar amd-smi envelope
- U18: get_visible_gpu_utilization tolerates scalar envelope
- U19a-c: vram_util / power_util return None on zero total, but vram_total_gb still echoes 0.0 (not None)
- A_rocm{6.5,6.6,6.9}_clips_to_rocm64: install.sh clips unsupported 6.x minors to rocm6.4 instead of producing a 403 index URL

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix reviewer.py round 2: tokenizer AMD multi-GPU, --no-torch bnb, main.py backend label

Three high-confidence findings from a second 20-parallel reviewer.py run on commit 7effb3ae. Triaged 15 total findings and applied the three that were confirmed as real bugs; the rest were either false positives (e.g. "migrated AMD venv not repaired" -- _ensure_rocm_torch runs downstream via setup.sh regardless), design decisions (e.g. visibility mask env vars not consulted in installer detection), or edge cases the existing fallback logic already handles.

1. unsloth/tokenizer_utils.py [6/20]: the multi-GPU guard's shell probe runs `nvidia-smi --query-gpu=memory.used`, catches the failure, then only raises if `torch.cuda.is_available()` is False. On ROCm torch, torch.cuda.is_available() returns True (ROCm reuses the torch.cuda.* API), so the guard becomes dead code on AMD hosts and multi-GPU AMD setups slip through even though unsloth does not support them yet. Add a torch.cuda.device_count() > 1 fallback inside the except so AMD multi-visible-device setups are flagged consistently with the original CUDA memory check.

2. install.sh [1/20]: the fresh-install bitsandbytes block for AMD ROCm ran unconditionally when TORCH_INDEX_URL matched `*/rocm*`, even when SKIP_TORCH=true (from --no-torch or Intel Mac auto-detect). A user running `install.sh --no-torch` on an AMD host would still pull in bitsandbytes despite explicitly asking for GGUF-only mode. Wrap the case block in an outer `[ "$SKIP_TORCH" = false ]` guard.

3. studio/backend/main.py [3/20]: the /api/system endpoint returned `"device_backend": get_device().value`, which is "cuda" on ROCm hosts (because ROCm torch piggybacks on torch.cuda). Other endpoints (hardware.py) already use the _backend_label helper which swaps "cuda" -> "rocm" when IS_ROCM. Route /api/system through the same helper so the Studio UI reports the backend consistently across all endpoints.

4. studio/backend/tests/test_utils.py: update test_backend_matches_device to call _backend_label(get_device()) instead of raw get_device().value so the test matches the new contract and still passes on CUDA hosts.

Tests: 258 -> 261.
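The multi-GPU guard from finding 1 can be condensed into a sketch. The probe/raise structure is paraphrased from the description, not copied from tokenizer_utils.py; the probe is injectable here purely so the fallback path can be exercised without hardware:

```python
import subprocess


def _nvidia_smi_probe():
    # the real guard shells out like this; it fails on AMD hosts
    subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv"]
    )


def assert_single_gpu(torch_cuda, probe = _nvidia_smi_probe):
    # On AMD hosts the nvidia-smi probe raises, but torch.cuda.is_available()
    # is True under ROCm, so without the device_count() fallback the except
    # branch never flags multi-GPU AMD setups.
    try:
        probe()
    except Exception:
        if not torch_cuda.is_available():
            raise RuntimeError("No NVIDIA or ROCm GPU detected")
        # ROCm fallback: count visible devices directly
        if torch_cuda.device_count() > 1:
            raise RuntimeError("Multi-GPU AMD setups are not supported yet")
```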
New regression coverage:

- X08 main.py /api/system uses _backend_label
- X09 tokenizer multi-GPU guard has device_count() fallback
- X10 fresh-install bnb case block gated on SKIP_TORCH=false

* fix: prevent bitsandbytes from overwriting ROCm torch with CUDA wheels

During install, bitsandbytes was installed without --no-deps, causing uv to resolve torch from PyPI (CUDA build) and silently overwrite the ROCm wheels that were just installed in the previous step. This happened in three places:

- install.sh: bitsandbytes install in both migrated and fresh paths
- install_python_stack.py: bitsandbytes install inside _ensure_rocm_torch()

Additionally, multiple install steps in install_python_stack.py (extras, overrides, studio deps) can pull in CUDA torch via transitive dependencies. A final _ensure_rocm_torch() call at the end of the install sequence ensures ROCm torch is always in place at runtime.

All changes are gated behind ROCm-specific conditions and do not affect NVIDIA, CPU-only, macOS, or Windows install paths. Tested on AMD Instinct MI300X VF with ROCm 7.2.0 -- confirms torch==2.10.0+rocm7.1 with HIP 7.1.25424 after install.

* fix: ROCm inference fallback -- skip Unsloth patching and bnb 4-bit on HIP

On AMD ROCm (HIP), two issues prevent the normal Unsloth inference path:

1. Unsloth's global monkey-patching of transformers model classes (LlamaRotaryEmbedding, attention modules) triggers _assert_async_cuda_kernel crashes on HIP during generation. Training uses different code paths and works fine.
2. bitsandbytes 4-bit matmul kernels also trigger HIP assertion failures on MI300X (CDNA3 / gfx942), even without Unsloth patching.

This commit adds a ROCm-specific inference fallback that:

- Skips importing Unsloth at module level (prevents global patching)
- Loads models in 16-bit with plain transformers + PEFT instead
- Resolves pre-quantized model names (e.g. "xxx-bnb-4bit" -> "xxx") since pre-quantized HF repos still trigger bnb codepaths
- Guards get_chat_template calls (unavailable without Unsloth import)
- Fixes max_seq_length=0 being passed to from_pretrained (GGUF semantics don't apply to transformers path)

The NVIDIA path is completely unchanged -- Unsloth import and for_inference() optimization remain active. GGUF inference (via llama-server/HIP) is unaffected since it never imports Python model classes. AMD GPUs typically have large VRAM (e.g. 192GB on MI300X) so 16-bit loading is practical for inference.

Tested on AMD Instinct MI300X VF (ROCm 7.2, HIP 7.1.25424):

- Simple generation: PASS
- Compare mode (base vs finetuned): PASS
- GGUF inference + tool calling: PASS (unaffected by this change)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: guard audio/vision inference on ROCm, remove unused import

- Add clear RuntimeError for audio/vision model inference on ROCm (these paths use Unsloth's FastModel/FastVisionModel which would crash on HIP; GGUF inference is the supported path on AMD)
- Remove unused `import os as _os` from the ROCm changes

* fix: amd-smi parsing for newer output format (gpu_data wrapper, mem_usage, temperature)

amd-smi on recent ROCm versions (7.x) wraps metric output in a {"gpu_data": [...]} envelope instead of returning a raw list. This caused get_primary_gpu_utilization() and get_visible_gpu_utilization() to fail silently (returning available=False) because the GPU data dict was never unwrapped.

Additionally:

- VRAM data moved from "vram" to "mem_usage" with "total_vram" / "used_vram" keys. Added fallback key lookup.
- Temperature "edge" sensor returns "N/A" on MI300X VF; the previous dict.get() chain returned the "N/A" string instead of falling through to "hotspot". Changed to a loop that checks each key until a parseable value is found.

Tested on AMD Instinct MI300X VF (ROCm 7.2, amd-smi 24.x):

- GPU utilization: 0% (idle), up to 100% during training
- Temperature: 40-44C (from hotspot sensor)
- VRAM: 0.28/191.69 GB (idle)
- Power: 158-211W draw

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Bug fix detecting radeon (#4940)

* Bug fix detecting radeon

* Expanding GPU target for gfx1100*

* Generalize gfx family-prefix filter to cover gfx10/gfx12 as well

rocminfo on ROCm 6.1+ emits LLVM generic-family ISA lines alongside the specific GPU (e.g. gfx11-generic next to gfx1100). The outer grep captures the bare family prefix from the generic line, and passing that to -DGPU_TARGETS breaks the HIP build because clang only accepts specific gfxNNN ids.

The previous filter only special-cased gfx11. Generalize it so any bare 2-digit family prefix (gfx10, gfx11, gfx12, ...) is dropped whenever a specific sibling target is present in the same list. No real AMD GPU has a 2-digit gfx id, so the filter can only ever drop family prefixes and never a real target. Covers the existing gfx11 cases unchanged, and extends the same fix to gfx10-1-generic / gfx10-3-generic (RDNA1/2) and gfx12-generic (RDNA4), which would otherwise hit the same build failure on newer rocminfo.

---------

Co-authored-by: Iswarya Alex <iswarya.alex@amd.com>
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>

---------

Co-authored-by: Eda Z <eda.zhou@amd.com>
Co-authored-by: GoldenGrapeGentleman <yueyuan@amd.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: billishyahao <bill.he@amd.com>
Co-authored-by: Iswarya Alex <47045679+iswaryaalex@users.noreply.github.com>
Co-authored-by: Iswarya Alex <iswarya.alex@amd.com>
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
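The generalized family-prefix filter in the last commit can be sketched in Python (the real filter is a shell pipeline over rocminfo output; the function name and regexes here are illustrative):

```python
import re


def drop_generic_family_prefixes(targets):
    """Drop bare two-digit family ids (gfx10/gfx11/gfx12, leaked from
    rocminfo's *-generic ISA lines) whenever a specific sibling target
    such as gfx1100 is present, since clang's GPU_TARGETS rejects bare
    prefixes. No real AMD GPU has a two-digit gfx id, so real targets
    are never dropped."""
    specific = [t for t in targets if re.fullmatch(r"gfx\d{3,}[a-z0-9]*", t)]
    kept = []
    for t in targets:
        if re.fullmatch(r"gfx\d{2}", t) and any(s.startswith(t) for s in specific):
            continue  # bare family prefix with a specific sibling: drop it
        kept.append(t)
    return kept
```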
1062 lines
33 KiB
Python
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import triton
import ctypes

MAX_FUSED_SIZE: int = 65536
next_power_of_2 = triton.next_power_of_2
import functools
from typing import Optional

from ..device_type import (
    is_hip,
    get_device_type,
    DEVICE_TYPE,
    DEVICE_TYPE_TORCH,
    DEVICE_COUNT,
    ALLOW_PREQUANTIZED_MODELS,
)
from .fp8 import weight_dequant, fp8_linear
import functools

# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch

torch_Tensor = torch.Tensor
from unsloth_zoo.utils import Version

if DEVICE_TYPE == "xpu" and Version(torch.__version__) < Version("2.6.0"):
    raise RuntimeError(
        "Intel xpu currently supports unsloth with torch.version >= 2.6.0"
    )

if Version(torch.__version__) < Version("2.4.0"):
    torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
    torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")

if DEVICE_TYPE == "xpu":
    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")


# tl.math.tanh now is libdevice.tanh
import triton
import triton.language as tl

if Version(triton.__version__) >= Version("3.0.0"):
    if DEVICE_TYPE == "xpu":
        triton_tanh = tl.extra.intel.libdevice.tanh
    else:
        from triton.language.extra import libdevice

        triton_tanh = libdevice.tanh
    triton_cast = tl.cast
else:
    triton_tanh = tl.math.tanh

    # No casting in old Triton versions
    @triton.jit
    def triton_cast(x, dtype):
        return x.to(dtype)


@functools.lru_cache(1)
def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
        "gfx940",
        "gfx941",
        "gfx942",
        "gfx950",  # CDNA4 (MI350/MI355X)
    )


@functools.lru_cache(1)
def is_rdna():
    """Detect ROCm-supported RDNA consumer/workstation GPUs (RDNA2, RDNA3, RDNA3.5, RDNA4)."""
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
        # RDNA2 (Navi 21-24)
        "gfx1030",
        "gfx1031",
        "gfx1032",
        "gfx1033",
        "gfx1034",
        "gfx1035",
        "gfx1036",
        # RDNA3 (Navi 31-33)
        "gfx1100",
        "gfx1101",
        "gfx1102",
        "gfx1103",
        # RDNA3.5 (Strix Point / Strix Halo)
        "gfx1150",
        "gfx1151",
        "gfx1152",
        # RDNA4 (Navi 48-44)
        "gfx1200",
        "gfx1201",
    )


def calculate_settings(
    n: int,
) -> (
    int,
    int,
):
    BLOCK_SIZE: int = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
        )
    num_warps: int = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps


HAS_CUDA_STREAM = False
import bitsandbytes as bnb

# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
get_ptr = bnb.functional.get_ptr

if DEVICE_TYPE == "xpu":
    HAS_XPU_STREAM = True

if DEVICE_COUNT > 1:
    if DEVICE_TYPE in ("cuda", "hip"):
        torch_gpu_device = torch.cuda.device
    elif DEVICE_TYPE == "xpu":
        torch_gpu_device = torch.xpu.device
else:
    from contextlib import nullcontext

    def torch_gpu_device(device):
        return nullcontext()


# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu":
    _gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
# NVIDIA GPU Default Logic
else:
    _gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream

c_void_p = ctypes.c_void_p


def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
    return c_void_p(_gpu_getCurrentRawStream(tensor.device.index))


# Get array of CUDA streams and other buffers
global CUDA_STREAMS
global XPU_STREAMS
global WEIGHT_BUFFERS
global ABSMAX_BUFFERS

# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu":
    _XPU_STREAMS = {
        (index := torch.xpu.device(i).idx): ctypes.c_void_p(
            torch._C._xpu_getCurrentRawStream(index)
        )
        for i in range(DEVICE_COUNT)
    }
    XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1)
    WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
    ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
    for k, v in _XPU_STREAMS.items():
        XPU_STREAMS[k] = v
    XPU_STREAMS = tuple(XPU_STREAMS)
    del _XPU_STREAMS
else:
    # NVIDIA GPU Default Logic
    _CUDA_STREAMS = {
        (index := torch.cuda.device(i).idx): ctypes.c_void_p(
            torch._C._cuda_getCurrentRawStream(index)
        )
        for i in range(DEVICE_COUNT)
    }
    CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
    WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
    ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
    for k, v in _CUDA_STREAMS.items():
        CUDA_STREAMS[k] = v
    CUDA_STREAMS = tuple(CUDA_STREAMS)
    del _CUDA_STREAMS

# Bitsandbytes operations
ctypes_c_int = ctypes.c_int
ctypes_c_int32 = ctypes.c_int32
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4

if DEVICE_TYPE == "xpu":
    # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/c3b8de268fdb55a88f92feada23fc811a1e6877a/bitsandbytes/backends/xpu/ops.py#L115
    # for xpu, inference gemv using above link
    cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemv_4bit_inference_fp16
    cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemv_4bit_inference_bf16
else:
    cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
    cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16


torch_device_stream = (
    torch.xpu.current_stream if DEVICE_TYPE == "xpu" else torch.cuda.current_stream
)

torch_mm = torch.mm
torch_mv = torch.mv
torch_matmul = torch.matmul
torch_addmm = torch.addmm
torch_empty = torch.empty
torch_float32 = torch.float32
torch_float16 = torch.float16
torch_bfloat16 = torch.bfloat16


# Check whether torchao can be imported to get Float8Tensor
if importlib.util.find_spec("torchao") is not None:
    try:
        from torchao.quantization import Float8Tensor
    except:
        import torchao

        if Version(torchao.__version__) >= Version("0.15.0"):
            print(
                f"Unsloth: `from torchao.quantization import Float8Tensor` failed on version={torchao.__version__}"
            )
        Float8Tensor = type(None)
else:
    Float8Tensor = type(None)


def QUANT_STATE(W):
    return getattr(W, "quant_state", None)


def get_lora_parameters(proj):
    """
    Return a 5-tuple of (weight, weight quant_state, lora A, lora B, and lora scale).
    If QAT is enabled, additionally fake quantize the base layer and lora weights.
    """
    # For DPO or disabled adapters
    base_layer = getattr(
        proj, "base_layer", proj
    )  # (proj.base_layer if hasattr(proj, "base_layer") else proj)
    W = base_layer.weight

    # Optionally apply fake quantization to base layer weights for QAT
    if hasattr(base_layer, "weight_fake_quantizer"):
        weight_fake_quantizer = getattr(base_layer, "weight_fake_quantizer", None)
        if weight_fake_quantizer is not None:
            W = weight_fake_quantizer(W)

    # Get quant state for 4bit or FP8
    W_quant = getattr(W, "quant_state", None)
    if W_quant is None:
        W_quant = getattr(base_layer, "weight_scale_inv", None)
    if W_quant is None:
        W_quant = getattr(base_layer, "weight_scale", None)

    if getattr(base_layer, "quant_method", None) == "fp8":
        # we need to somehow store and pass this information :)
        W.block_size = getattr(base_layer, "block_size", [128, 128])
        W_quant.block_size = W.block_size

    # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
    if getattr(proj, "disable_adapters", True) or proj.merged:
        return W, W_quant, None, None, None

    adapter = getattr(proj, "active_adapters", None)
    if adapter is None:
        adapter = getattr(proj, "active_adapter", ("default"))
    adapter = adapter[0]

    # Optionally apply fake quantization to lora weights for QAT
    lora_A_linear = proj.lora_A[adapter]
    lora_B_linear = proj.lora_B[adapter]
    A = lora_A_linear.weight
    B = lora_B_linear.weight
    if hasattr(lora_A_linear, "weight_fake_quantizer"):
        lora_A_fake_quantizer = getattr(lora_A_linear, "weight_fake_quantizer", None)
        if lora_A_fake_quantizer is not None:
            A = lora_A_fake_quantizer(A)
    if hasattr(lora_B_linear, "weight_fake_quantizer"):
        lora_B_fake_quantizer = getattr(lora_B_linear, "weight_fake_quantizer", None)
        if lora_B_fake_quantizer is not None:
            B = lora_B_fake_quantizer(B)

    return (
        W,
        W_quant,
        A,
        B,
        proj.scaling[adapter],
    )


def get_lora_parameters_bias(proj):
    # For DPO or disabled adapters
    base_layer = getattr(
        proj, "base_layer", proj
    )  # (proj.base_layer if hasattr(proj, "base_layer") else proj)
    W = base_layer.weight

    # Get quant state for 4bit or FP8
    W_quant = getattr(W, "quant_state", None)
    if W_quant is None:
        W_quant = getattr(base_layer, "weight_scale_inv", None)
    if W_quant is None:
        W_quant = getattr(base_layer, "weight_scale", None)

    # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
    if getattr(proj, "disable_adapters", True) or proj.merged:
        return W, W_quant, None, None, None, base_layer.bias

    if getattr(base_layer, "quant_method", None) == "fp8":
        # we need to somehow store and pass this information :)
        W.block_size = getattr(base_layer, "block_size", [128, 128])
        W_quant.block_size = W.block_size

    adapter = getattr(proj, "active_adapters", None)
    if adapter is None:
        adapter = getattr(proj, "active_adapter", ("default"))
    adapter = adapter[0]

    return (
        W,
        W_quant,
        proj.lora_A[adapter].weight,
        proj.lora_B[adapter].weight,
        proj.scaling[adapter],
        base_layer.bias,
    )


def _maybe_fake_quantize_activations(
    X: torch.Tensor, proj: torch.nn.Module
) -> torch.Tensor:
    """
    If QAT is enabled, fake quantize the input activations.
    Otherwise, just return the input activations as is.
    Weights are fake quantized separately in `get_lora_parameters`.
    """
    base_layer = getattr(proj, "base_layer", proj)
    activation_fake_quantizer = getattr(base_layer, "activation_fake_quantizer", None)
    if activation_fake_quantizer is not None:
        X = activation_fake_quantizer(X)
    return X


# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:

    @torch.inference_mode
    def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
        # TODO: After adding XPU BNB support, check this function
        if isinstance(W, Float8Tensor):
            return W.dequantize()
        if quant_state is None:
            return W
        if W.dtype == torch.float8_e4m3fn:
            return weight_dequant(W, quant_state)
        if type(quant_state) is not list:
            # New quant_state as a class
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            offset = quant_state.offset
            state2 = quant_state.state2
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            # Old quant_state as a list of lists
            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        global XPU_STREAMS
        device = W.device
        device_index = device.index
        XPU_STREAM = XPU_STREAMS[device_index]

        n_elements_absmax = absmax.numel()
        # Create weight matrix
        if use_global_buffer:
            # Use same buffers for faster inference
            size = shape[0] * shape[1]
            global WEIGHT_BUFFERS
            global ABSMAX_BUFFERS
            WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
            ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
            if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:
                WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
                    size, dtype = dtype, device = device, requires_grad = False
                )
                ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
                    n_elements_absmax,
                    dtype = torch.float32,
                    device = device,
                    requires_grad = False,
                )

            if size > WEIGHT_BUFFER.numel():
                WEIGHT_BUFFER.resize_(size)
            if n_elements_absmax > ABSMAX_BUFFER.numel():
                ABSMAX_BUFFER.resize_(n_elements_absmax)

            out = WEIGHT_BUFFER[:size].view(shape)
            out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
        else:
            if out is None:
                out = torch_empty(
                    shape, dtype = dtype, device = device, requires_grad = False
                )
            else:
                assert out.shape == shape
                assert out.dtype == dtype
            out_absmax = torch_empty(
                n_elements_absmax,
                dtype = torch_float32,
                device = device,
                requires_grad = False,
            )

        # NF4 dequantization of statistics
        ptr_out_absmax = get_ptr(out_absmax)
        with torch_gpu_device(device):
            cdequantize_blockwise_fp32(
                get_ptr(code2),
                get_ptr(absmax),
                get_ptr(absmax2),
                ptr_out_absmax,
                ctypes_c_int(blocksize2),
                ctypes_c_int(n_elements_absmax),
                XPU_STREAM,
            )
            out_absmax += offset

            # Dequantize W
            fx = (
                cdequantize_blockwise_fp16_nf4
                if dtype == torch_float16
                else cdequantize_blockwise_bf16_nf4
            )
            fx(
                get_ptr(None),
                get_ptr(W),
                ptr_out_absmax,
                get_ptr(out),
                ctypes_c_int(blocksize),
                ctypes_c_int(out.numel()),
                XPU_STREAM,
            )
        # Careful returning transposed data
        is_transposed = True if W.shape[0] == 1 else False
        return out.t() if is_transposed else out

# NVIDIA GPU Default Logic
elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:

    @torch.inference_mode
    def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
        if isinstance(W, Float8Tensor):
            return W.dequantize()
        if quant_state is None:
            return W
        if W.dtype == torch.float8_e4m3fn:
            return weight_dequant(W, quant_state)
        if type(quant_state) is not list:
            # New quant_state as a class
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            offset = quant_state.offset
            state2 = quant_state.state2
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            # Old quant_state as a list of lists
            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass
        global CUDA_STREAMS
        device = W.device
        device_index = device.index
        CUDA_STREAM = CUDA_STREAMS[device_index]

        n_elements_absmax = absmax.numel()

        # Create weight matrix
        if use_global_buffer:
            # Use same buffers for faster inference
            size = shape[0] * shape[1]
            global WEIGHT_BUFFERS
            global ABSMAX_BUFFERS
            WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
            ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
            if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:
                WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
                    size, dtype = dtype, device = device, requires_grad = False
                )
                ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
                    n_elements_absmax,
                    dtype = torch_float32,
                    device = device,
                    requires_grad = False,
                )

            if size > WEIGHT_BUFFER.numel():
                WEIGHT_BUFFER.resize_(size)
            if n_elements_absmax > ABSMAX_BUFFER.numel():
                ABSMAX_BUFFER.resize_(n_elements_absmax)

            out = WEIGHT_BUFFER[:size].view(shape)
            out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
        else:
            if out is None:
                out = torch_empty(
                    shape, dtype = dtype, device = device, requires_grad = False
                )
            else:
                assert out.shape == shape
                assert out.dtype == dtype
            out_absmax = torch_empty(
                n_elements_absmax,
                dtype = torch_float32,
                device = device,
                requires_grad = False,
            )
        pass

        # NF4 dequantization of statistics
        ptr_out_absmax = get_ptr(out_absmax)
        with torch_gpu_device(device):
            cdequantize_blockwise_fp32(
                get_ptr(code2),
                get_ptr(absmax),
                get_ptr(absmax2),
                ptr_out_absmax,
                ctypes_c_int(blocksize2),
                ctypes_c_int(n_elements_absmax),
                CUDA_STREAM,
            )
            out_absmax += offset

            # Dequantize W
            fx = (
                cdequantize_blockwise_fp16_nf4
                if dtype == torch_float16
                else cdequantize_blockwise_bf16_nf4
            )
            fx(
                get_ptr(None),
                get_ptr(W),
                ptr_out_absmax,
                get_ptr(out),
                ctypes_c_int(blocksize),
                ctypes_c_int(out.numel()),
                CUDA_STREAM,
            )
        pass
        # Careful returning transposed data
        is_transposed = True if W.shape[0] == 1 else False
        return out.t() if is_transposed else out

    pass
else:

    @torch.inference_mode
    def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
        if isinstance(W, Float8Tensor):
            return W.dequantize()
        if quant_state is None:
            return W
        if W.dtype == torch.float8_e4m3fn:
            return weight_dequant(W, quant_state)
        if type(quant_state) is not list:
            # New quant_state as a class
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            offset = quant_state.offset
            state2 = quant_state.state2
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            # Old quant_state as a list of lists
            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass

        n_elements_absmax = absmax.numel()
        device = W.device

        # Create weight matrix
        if out is None:
            out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
        else:
            assert out.shape == shape
            assert out.dtype == dtype
        out_absmax = torch_empty(
            n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False
        )

        # Do dequantization
        ptr_out_absmax = get_ptr(out_absmax)
        cdequantize_blockwise_fp32(
            get_ptr(code2),
            get_ptr(absmax),
            get_ptr(absmax2),
            ptr_out_absmax,
            ctypes_c_int(blocksize2),
            ctypes_c_int(n_elements_absmax),
        )
        out_absmax += offset

        fx = (
            cdequantize_blockwise_fp16_nf4
            if dtype == torch_float16
            else cdequantize_blockwise_bf16_nf4
        )
        fx(
            get_ptr(None),
            get_ptr(W),
            ptr_out_absmax,
            get_ptr(out),
            ctypes_c_int(blocksize),
            ctypes_c_int(out.numel()),
        )

        # Careful returning transposed data
        is_transposed = True if W.shape[0] == 1 else False
        return out.t() if is_transposed else out

    pass


# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:

    def fast_gemv(X, W, quant_state, out = None):
        if quant_state is None:
            return torch_matmul(X, W, out = out)
        # For fast X @ W where seq_len == 1
        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
        _, q_len, hd = X.shape
        # assert(q_len == 1)

        if type(quant_state) is not list:
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            stats = quant_state.code
            offset = quant_state.offset
            state2 = quant_state.state2
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (
                quant_state
            )
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        global XPU_STREAMS
        device = W.device
        device_index = device.index
        XPU_STREAM = XPU_STREAMS[device_index]

        # assert(dtype == X.dtype)
        bout = shape[0]

        if out is None:
            out = torch_empty(
                (
                    1,
                    1,
                    bout,
                ),
                dtype = dtype,
                device = device,
            )
        # else:
        #     assert(out.shape == (1, 1, bout,))
        # pass

        if DEVICE_TYPE == "xpu":
            m = 1
            n = shape[0]
        else:
            n = 1
            m = shape[0]
        k = shape[1]
        lda = shape[0]
        ldc = shape[0]
        ldb = (hd + 1) // 2
        m = ctypes_c_int32(m)
        n = ctypes_c_int32(n)
        k = ctypes_c_int32(k)
        lda = ctypes_c_int32(lda)
        ldb = ctypes_c_int32(ldb)
        ldc = ctypes_c_int32(ldc)

        df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
        with torch_gpu_device(device):
            cdequantize_blockwise_fp32(
                get_ptr(code2),
                get_ptr(absmax),
                get_ptr(absmax2),
                get_ptr(df),
                ctypes_c_int(blocksize2),
                ctypes_c_int(df.numel()),
                XPU_STREAM,
            )
            df += offset
            absmax = df

            fx = (
                cgemm_4bit_inference_naive_fp16
                if dtype == torch_float16
                else cgemm_4bit_inference_naive_bf16
            )

            blocksize = ctypes_c_int32(blocksize)
            fx(
                m,
                n,
                k,
                get_ptr(X),
                get_ptr(W),
                get_ptr(absmax),
                get_ptr(stats),
                get_ptr(out),
                lda,
                ldb,
                ldc,
                blocksize,
                XPU_STREAM,
            )

        return out

elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:

    def fast_gemv(X, W, quant_state, out = None):
        if quant_state is None:
            return torch_matmul(X, W, out = out)
        # For fast X @ W where seq_len == 1
        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
        _, q_len, hd = X.shape
        # assert(q_len == 1)

        if type(quant_state) is not list:
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            stats = quant_state.code
            offset = quant_state.offset
            state2 = quant_state.state2
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (
                quant_state
            )
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass

        global CUDA_STREAMS
        device = W.device
        device_index = device.index
        CUDA_STREAM = CUDA_STREAMS[device_index]

        # assert(dtype == X.dtype)
        bout = shape[0]

        if out is None:
            out = torch_empty(
                (
                    1,
                    1,
                    bout,
                ),
                dtype = dtype,
                device = device,
            )
        # else:
        #     assert(out.shape == (1, 1, bout,))
        # pass

        n = 1
        m = shape[0]
        k = shape[1]
        lda = shape[0]
        ldc = shape[0]
        ldb = (hd + 1) // 2
        m = ctypes_c_int32(m)
        n = ctypes_c_int32(n)
        k = ctypes_c_int32(k)
        lda = ctypes_c_int32(lda)
        ldb = ctypes_c_int32(ldb)
        ldc = ctypes_c_int32(ldc)

        df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
        with torch_gpu_device(device):
            cdequantize_blockwise_fp32(
                get_ptr(code2),
                get_ptr(absmax),
                get_ptr(absmax2),
                get_ptr(df),
                ctypes_c_int(blocksize2),
                ctypes_c_int(df.numel()),
                CUDA_STREAM,
            )
            df += offset
            absmax = df

            fx = (
                cgemm_4bit_inference_naive_fp16
                if dtype == torch_float16
                else cgemm_4bit_inference_naive_bf16
            )

            blocksize = ctypes_c_int32(blocksize)
            fx(
                m,
                n,
                k,
                get_ptr(X),
                get_ptr(W),
                get_ptr(absmax),
                get_ptr(stats),
                get_ptr(out),
                lda,
                ldb,
                ldc,
                blocksize,
                CUDA_STREAM,
            )
        pass

        return out

    pass

else:

    def fast_gemv(X, W, quant_state, out = None):
        if quant_state is None:
            return torch_matmul(X, W, out = out)
        # For fast X @ W where seq_len == 1
        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
        _, q_len, hd = X.shape
        # assert(q_len == 1)

        if type(quant_state) is not list:
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            stats = quant_state.code
            offset = quant_state.offset
            state2 = quant_state.state2
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (
                quant_state
            )
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass

        # assert(dtype == X.dtype)
        bout = shape[0]
        device = W.device

        if out is None:
            out = torch_empty(
                (
                    1,
                    1,
                    bout,
                ),
                dtype = dtype,
                device = device,
            )
        # else:
        #     assert(out.shape == (1, 1, bout,))
        # pass

        n = 1
        m = shape[0]
        k = shape[1]
        lda = shape[0]
        ldc = shape[0]
        ldb = (hd + 1) // 2
        m = ctypes_c_int32(m)
        n = ctypes_c_int32(n)
        k = ctypes_c_int32(k)
        lda = ctypes_c_int32(lda)
        ldb = ctypes_c_int32(ldb)
        ldc = ctypes_c_int32(ldc)

        df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
        cdequantize_blockwise_fp32(
            get_ptr(code2),
            get_ptr(absmax),
            get_ptr(absmax2),
            get_ptr(df),
            ctypes_c_int(blocksize2),
            ctypes_c_int(df.numel()),
        )
        df += offset
        absmax = df

        fx = (
            cgemm_4bit_inference_naive_fp16
            if dtype == torch_float16
            else cgemm_4bit_inference_naive_bf16
        )

        blocksize = ctypes_c_int32(blocksize)
        fx(
            m,
            n,
            k,
            get_ptr(X),
            get_ptr(W),
            get_ptr(absmax),
            get_ptr(stats),
            get_ptr(out),
            lda,
            ldb,
            ldc,
            blocksize,
        )

        return out

    pass

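# All three fast_gemv paths above first recover the per-block absmax statistics
# from their *second-level* quantization (bitsandbytes "double quantization"):
# the absmax values are stored as uint8 codes plus coarse block scales, and
# cdequantize_blockwise_fp32 expands them before the 4-bit GEMV kernel runs.
# A minimal NumPy sketch of that blockwise step for reference only — the names
# and layout here are illustrative, not the bitsandbytes C API:

```python
import numpy as np

def dequantize_blockwise_sketch(codes, code_table, block_scales, blocksize, offset):
    # codes: uint8 indices into code_table, one per quantized absmax value
    # code_table: 256-entry lookup table of normalized float values
    # block_scales: one float scale per block of `blocksize` codes
    values = code_table[codes].astype(np.float32)
    scales = np.repeat(block_scales, blocksize)[: values.size]
    # offset re-centers the statistics, mirroring `df += offset` above
    return values * scales + offset

# Tiny worked example: 4 codes, blocksize 2, two block scales
code_table = np.linspace(-1.0, 1.0, 256, dtype = np.float32)
codes = np.array([0, 255, 128, 64], dtype = np.uint8)
absmax = dequantize_blockwise_sketch(codes, code_table, np.array([2.0, 0.5]), 2, 0.1)
```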
def fast_linear_forward(proj, X, temp_lora = None, out = None):
    W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
    bsz, q_len, in_dim = X.shape
    if q_len != 1:
        return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)

    if W_quant is None:
        out = torch_matmul(X, W.t(), out = out)
    elif W.dtype == torch.float8_e4m3fn:
        out = fp8_linear(X, W, W_quant, bias)
    elif bsz == 1 and q_len == 1:
        out = fast_gemv(X, W, W_quant, out = out)
    else:
        W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
        out = torch_matmul(X, W, out = out)

    # Add in LoRA weights
    if lora_A is not None:
        out_dim = out.shape[2]
        dtype = X.dtype

        if not hasattr(lora_A, "_fast_lora"):
            lora_A._fast_lora = lora_A.to(dtype)
            lora_B._fast_lora = lora_B.to(dtype)

        if bsz == 1:
            out = out.view(out_dim)
            temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
            out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
        else:
            out = out.view(bsz, out_dim)
            temp_lora = torch_mm(
                X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora
            )
            out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
        out = out.view(bsz, 1, out_dim)

    if bias is not None:
        out += bias

    return out

def matmul_lora(X, W, W_quant, A, B, s, out = None):
    dtype = X.dtype

    if X.dim() == 3:
        batch, seq_len, d = X.shape
        X = X.view(-1, X.shape[-1])
        reshape = True
    else:
        reshape = False

    if isinstance(W, Float8Tensor):
        assert W.ndim == 2
        if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
            # In the backward pass, rowwise scaled becomes colwise scaled after we
            # transpose the weight tensor. Use this case to detect backward.
            # TODO: would be simpler if we simply don't call `matmul_lora` in backward
            W = W.dequantize()
        else:
            W = W.contiguous()
        out = torch_matmul(X, W.t(), out = out)
    elif W.dtype == torch.float8_e4m3fn:
        out = fp8_linear(X, W, W_quant)
    else:
        W = fast_dequantize(W, W_quant, use_global_buffer = True)
        out = torch_matmul(X, W.t(), out = out)
    if W_quant is not None:
        del W

    if A is not None:
        # LoRA is enabled
        A, B = A.t(), B.t()
        XA = torch_matmul(X, A.to(dtype))
        out.addmm_(XA, B.to(dtype), alpha = s)
        # out += (X @ A.to(dtype)) @ (s * B.to(dtype))

    return out.view(batch, seq_len, -1) if reshape else out
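
# matmul_lora above computes out = X @ W.T + s * (X @ A.T) @ B.T: the base
# projection plus a rank-r LoRA update applied as two thin matmuls, never
# materializing the full (s * B @ A) weight delta. A NumPy sketch of the same
# arithmetic for reference only (shapes follow the PEFT convention: A is the
# (r, d) down-projection, B the (out_dim, r) up-projection):

```python
import numpy as np

def matmul_lora_reference(X, W, A, B, s):
    # X: (n, d) activations, W: (out_dim, d) base weight
    out = X @ W.T
    out += (X @ A.T) @ (s * B.T)  # low-rank update, never forms B @ A
    return out

rng = np.random.default_rng(0)
X, W = rng.standard_normal((3, 8)), rng.standard_normal((16, 8))
A, B = rng.standard_normal((2, 8)), rng.standard_normal((16, 2))
dense = X @ (W + 0.5 * B @ A).T          # equivalent dense computation
fused = matmul_lora_reference(X, W, A, B, 0.5)
```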