mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
* [Studio] Install flash attn at setup time for linux * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup changes Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Test cases * wheel_utils: narrow url_exists exceptions and log at debug level --------- Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com> Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
282 lines
11 KiB
Python
282 lines
11 KiB
Python
"""Tests for the optional FlashAttention installer."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from unittest import mock
|
|
|
|
# Locate the repository's "studio" directory (two levels above this test file)
# so the installer script can be imported even though it is not a package.
STUDIO_DIR = Path(__file__).resolve().parents[2] / "studio"

# Prepend (not append) so this copy wins over any installed module of the
# same name; must happen before the import below.
sys.path.insert(0, str(STUDIO_DIR))

# Deliberately imported after the sys.path mutation above.
import install_python_stack as ips
|
|
|
|
|
|
class TestFlashAttnWheelSelection:
    """Unit tests for the torch-version -> flash-attn wheel mapping helpers."""

    def test_torch_210_maps_to_v281(self):
        selected = ips._select_flash_attn_version("2.10")
        assert selected == "2.8.1"

    def test_torch_29_maps_to_v283(self):
        selected = ips._select_flash_attn_version("2.9")
        assert selected == "2.8.3"

    def test_unsupported_torch_has_no_wheel_mapping(self):
        # Torch versions outside the known mapping must yield no wheel version.
        assert ips._select_flash_attn_version("2.11") is None

    def test_exact_wheel_url_uses_full_env_tuple(self):
        probe_env = {
            "python_tag": "cp313",
            "torch_mm": "2.10",
            "cuda_major": "12",
            "cxx11abi": "TRUE",
            "platform_tag": "linux_x86_64",
        }
        url = ips._build_flash_attn_wheel_url(probe_env)
        assert url is not None
        assert "v2.8.1" in url
        expected_wheel_name = (
            "flash_attn-2.8.1+cu12torch2.10cxx11abiTRUE-cp313-cp313-linux_x86_64.whl"
        )
        assert expected_wheel_name in url

    def test_missing_cuda_major_disables_wheel_lookup(self):
        # An empty cuda_major means we cannot name a wheel, so lookup is skipped.
        probe_env = {
            "python_tag": "cp313",
            "torch_mm": "2.10",
            "cuda_major": "",
            "cxx11abi": "TRUE",
            "platform_tag": "linux_x86_64",
        }
        assert ips._build_flash_attn_wheel_url(probe_env) is None
|
|
|
|
|
|
class TestEnsureFlashAttn:
    """Behavioral tests for ``ips._ensure_flash_attn``.

    Fix: the five-key probe-environment dict literal was duplicated verbatim
    in four tests; it now comes from one ``_wheel_env`` helper so the probe
    schema has a single point of truth.
    """

    @staticmethod
    def _wheel_env(**overrides):
        """Return a fresh fake ``probe_torch_wheel_env`` result.

        Keyword arguments override individual fields (e.g. ``cuda_major="13"``).
        A new dict is built per call so tests cannot leak mutations.
        """
        env = {
            "python_tag": "cp313",
            "torch_mm": "2.10",
            "cuda_major": "12",
            "cxx11abi": "TRUE",
            "platform_tag": "linux_x86_64",
        }
        env.update(overrides)
        return env

    def _import_check(self, code: int = 1):
        # Simulates `python -c "import flash_attn"`; the default non-zero
        # return code means flash_attn is NOT already importable, so the
        # installer proceeds.
        return subprocess.CompletedProcess(["python", "-c", "import flash_attn"], code)

    def test_prefers_exact_match_wheel(self):
        """When the exact wheel URL exists, it is installed with uv settings."""
        install_calls = []

        def fake_install_wheel(*args, **kwargs):
            install_calls.append((args, kwargs))
            return [("uv", subprocess.CompletedProcess(["uv"], 0, ""))]

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
            mock.patch.object(
                ips,
                "probe_torch_wheel_env",
                return_value=self._wheel_env(),
            ),
            mock.patch.object(ips, "url_exists", return_value=True),
            mock.patch.object(ips, "install_wheel", side_effect=fake_install_wheel),
            mock.patch("subprocess.run", return_value=self._import_check()),
        ):
            ips._ensure_flash_attn()

        assert len(install_calls) == 1
        args, kwargs = install_calls[0]
        # Exact-match wheel URL is derived from the full probe tuple.
        assert args == (
            "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.10cxx11abiTRUE-cp313-cp313-linux_x86_64.whl",
        )
        assert kwargs["python_executable"] == sys.executable
        assert kwargs["use_uv"] is True
        assert kwargs["uv_needs_system"] is False

    def test_uv_install_respects_system_flag(self):
        """UV_NEEDS_SYSTEM=True must propagate to install_wheel."""
        install_calls = []

        def fake_install_wheel(*args, **kwargs):
            install_calls.append((args, kwargs))
            return [("uv", subprocess.CompletedProcess(["uv"], 0, ""))]

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", True),
            mock.patch.object(
                ips,
                "probe_torch_wheel_env",
                return_value=self._wheel_env(),
            ),
            mock.patch.object(ips, "url_exists", return_value=True),
            mock.patch.object(ips, "install_wheel", side_effect=fake_install_wheel),
            mock.patch("subprocess.run", return_value=self._import_check()),
        ):
            ips._ensure_flash_attn()

        assert len(install_calls) == 1
        _, kwargs = install_calls[0]
        assert kwargs["uv_needs_system"] is True

    def test_wheel_failure_warns_and_continues(self):
        """Both uv and pip failing must warn, report each failure, and not raise."""
        step_messages: list[tuple[str, str]] = []
        printed_failures: list[str] = []

        def fake_step(label: str, value: str, color_fn=None):
            step_messages.append((label, value))

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
            mock.patch.object(
                ips,
                "probe_torch_wheel_env",
                return_value=self._wheel_env(),
            ),
            mock.patch.object(ips, "url_exists", return_value=True),
            mock.patch.object(
                ips,
                "install_wheel",
                # Simulate uv failing first, then the pip fallback failing too.
                return_value=[
                    ("uv", subprocess.CompletedProcess(["uv"], 1, "uv wheel failed")),
                    (
                        "pip",
                        subprocess.CompletedProcess(["pip"], 1, "pip wheel failed"),
                    ),
                ],
            ),
            mock.patch.object(
                ips,
                "_print_optional_install_failure",
                side_effect=lambda label, result: printed_failures.append(label),
            ),
            mock.patch.object(ips, "_step", side_effect=fake_step),
            mock.patch("subprocess.run", return_value=self._import_check()),
        ):
            ips._ensure_flash_attn()

        assert printed_failures == [
            "Installing flash-attn prebuilt wheel with uv",
            "Installing flash-attn prebuilt wheel with pip",
        ]
        assert ("warning", "Continuing without flash-attn") in step_messages

    def test_wheel_missing_skips_install_at_setup_time(self):
        """No published wheel for the env tuple -> warn and skip installation."""
        step_messages: list[tuple[str, str]] = []

        def fake_step(label: str, value: str, color_fn=None):
            step_messages.append((label, value))

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(
                ips,
                "probe_torch_wheel_env",
                # cuda_major 13 has no published wheel in the mapping.
                return_value=self._wheel_env(cuda_major="13"),
            ),
            mock.patch.object(ips, "url_exists", return_value=False),
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch.object(ips, "_step", side_effect=fake_step),
            mock.patch("subprocess.run", return_value=self._import_check()),
        ):
            ips._ensure_flash_attn()

        mock_install_wheel.assert_not_called()
        assert (
            "warning",
            "No published flash-attn prebuilt wheel found",
        ) in step_messages

    def test_skip_env_disables_setup_install(self):
        """The opt-out env var must short-circuit before any probe or install."""
        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.dict(os.environ, {"UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL": "1"}),
            mock.patch.object(ips, "probe_torch_wheel_env") as mock_probe,
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch("subprocess.run", return_value=self._import_check()),
        ):
            ips._ensure_flash_attn()

        mock_probe.assert_not_called()
        mock_install_wheel.assert_not_called()
|
|
|
|
|
|
class TestInstallPythonStackFlashAttnIntegration:
    """Verify install_python_stack only runs the flash-attn step on Linux
    installs that include torch (never on macOS, Windows, or --no-torch)."""

    def _run_install(self, *, no_torch: bool, is_macos: bool, is_windows: bool) -> int:
        # Counts how many times the flash-attn ensure step was invoked.
        ensure_call_count = 0

        def passthrough_run(cmd, **kw):
            # Every subprocess the installer launches "succeeds" with no output.
            return subprocess.CompletedProcess(cmd, 0, b"", b"")

        def record_ensure_call():
            nonlocal ensure_call_count
            ensure_call_count += 1

        with (
            mock.patch.object(ips, "NO_TORCH", no_torch),
            mock.patch.object(ips, "IS_MACOS", is_macos),
            mock.patch.object(ips, "IS_WINDOWS", is_windows),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
            mock.patch.object(ips, "VERBOSE", False),
            mock.patch.object(ips, "_bootstrap_uv", return_value=True),
            mock.patch.object(ips, "_ensure_flash_attn", side_effect=record_ensure_call),
            mock.patch("subprocess.run", side_effect=passthrough_run),
            mock.patch.object(ips, "_has_usable_nvidia_gpu", return_value=False),
            mock.patch.object(ips, "_has_rocm_gpu", return_value=False),
            mock.patch.object(
                ips, "LOCAL_DD_UNSTRUCTURED_PLUGIN", Path("/fake/plugin")
            ),
            mock.patch("pathlib.Path.is_dir", return_value=True),
            mock.patch("pathlib.Path.is_file", return_value=True),
            mock.patch.dict(os.environ, {"SKIP_STUDIO_BASE": "1"}, clear=False),
        ):
            ips.install_python_stack()

        return ensure_call_count

    def test_linux_torch_install_calls_flash_attn_step(self):
        count = self._run_install(no_torch=False, is_macos=False, is_windows=False)
        assert count == 1

    def test_no_torch_install_skips_flash_attn_step(self):
        count = self._run_install(no_torch=True, is_macos=False, is_windows=False)
        assert count == 0

    def test_macos_install_skips_flash_attn_step(self):
        count = self._run_install(no_torch=False, is_macos=True, is_windows=False)
        assert count == 0

    def test_windows_install_skips_flash_attn_step(self):
        count = self._run_install(no_torch=False, is_macos=False, is_windows=True)
        assert count == 0
|