mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
tests: add no-torch / Intel Mac test suite (#4646)
* tests: add no-torch / Intel Mac test suite Add comprehensive test coverage for the no-torch / --no-torch installer and Studio backend changes introduced in #4624. Shell tests (tests/sh/test_mac_intel_compat.sh): - version_ge edge cases (9 tests) - Architecture detection + Python version resolution (4 tests) - get_torch_index_url on Darwin (2 tests) - UNSLOTH_NO_TORCH propagation via SKIP_TORCH (5 tests) - E2E uv venv creation at Python 3.12 (3 tests) - E2E torch skip with mock uv shim (4 tests) - UNSLOTH_NO_TORCH env propagation (4 tests) - --python override flag parsing + resolution (11 tests) - --no-torch flag parsing (4 tests) - SKIP_TORCH unification (3 tests) - CPU hint printing (2 tests) Python tests (tests/python/test_no_torch_filtering.py): - _filter_requirements unit tests with synthetic + real requirements files - NO_TORCH / IS_MACOS constant parsing - Subprocess mock of install_python_stack() across platform configs - install.sh --no-torch flag structural + subprocess tests Python tests (tests/python/test_studio_import_no_torch.py): - AST checks for data_collators.py, chat_templates.py, format_conversion.py - Parametrized venv tests (Python 3.12 + 3.13) for no-torch exec - Dataclass instantiation without torch - format_conversion convert functions without torch - Negative controls (import torch fails, torchao fails) Python tests (tests/python/test_e2e_no_torch_sandbox.py): - Before/after import chain tests - Edge cases (broken torch, fake torch, lazy import) - Hardware detection without torch - install.sh logic tests (flag parsing, version resolution) - install_python_stack filtering tests - Live server startup tests (opt-in via @server marker) * fix: address review comments on test suite - Fix always-true assertion in test_studio_import_no_torch.py (or True) - Make IS_MACOS test platform-aware instead of hardcoding Linux - Restore torchvision + torchaudio in server test cleanup (not just torch) - Include server stderr in skip message for easier debugging * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e9ac785346
commit
2ffc8d2cea
6 changed files with 3166 additions and 0 deletions
7
tests/python/conftest.py
Normal file
7
tests/python/conftest.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Shared pytest configuration for tests/python/."""
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line(
|
||||
"markers", "server: heavyweight tests requiring studio venv"
|
||||
)
|
||||
1239
tests/python/test_e2e_no_torch_sandbox.py
Normal file
1239
tests/python/test_e2e_no_torch_sandbox.py
Normal file
File diff suppressed because it is too large
Load diff
753
tests/python/test_no_torch_filtering.py
Normal file
753
tests/python/test_no_torch_filtering.py
Normal file
|
|
@ -0,0 +1,753 @@
|
|||
"""Tests for install_python_stack NO_TORCH / IS_MACOS filtering logic.
|
||||
|
||||
Covers:
|
||||
- _filter_requirements unit tests (synthetic + REAL requirements files)
|
||||
- NO_TORCH / IS_MACOS / IS_WINDOWS env var parsing
|
||||
- Subprocess-mock of install_python_stack() to verify overrides/triton/filtering
|
||||
actually happen (or get skipped) under each platform/config combination
|
||||
- VCS URL and environment marker edge cases in filtering
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the studio directory so we can import install_python_stack
|
||||
STUDIO_DIR = Path(__file__).resolve().parents[2] / "studio"
|
||||
sys.path.insert(0, str(STUDIO_DIR))
|
||||
|
||||
import install_python_stack as ips
|
||||
|
||||
# Paths to the REAL requirements files
|
||||
REQ_ROOT = Path(__file__).resolve().parents[2] / "studio" / "backend" / "requirements"
|
||||
EXTRAS_TXT = REQ_ROOT / "extras.txt"
|
||||
EXTRAS_NO_DEPS_TXT = REQ_ROOT / "extras-no-deps.txt"
|
||||
OVERRIDES_TXT = REQ_ROOT / "overrides.txt"
|
||||
TRITON_KERNELS_TXT = REQ_ROOT / "triton-kernels.txt"
|
||||
|
||||
|
||||
# ── _filter_requirements unit tests (synthetic) ───────────────────────
|
||||
|
||||
|
||||
class TestFilterRequirements:
|
||||
"""Verify _filter_requirements correctly removes packages by prefix."""
|
||||
|
||||
def _write_req(self, tmp_path: Path, content: str) -> Path:
|
||||
req = tmp_path / "requirements.txt"
|
||||
req.write_text(textwrap.dedent(content), encoding = "utf-8")
|
||||
return req
|
||||
|
||||
def test_filters_no_torch_packages(self, tmp_path):
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
torch-stoi==0.1
|
||||
timm>=1.0
|
||||
numpy
|
||||
torchcodec>=0.1
|
||||
torch-c-dlpack-ext
|
||||
""",
|
||||
)
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
lines = Path(result).read_text(encoding = "utf-8").splitlines()
|
||||
# Only numpy should remain (non-blank lines)
|
||||
non_blank = [l.strip() for l in lines if l.strip()]
|
||||
assert non_blank == ["numpy"], f"Expected only numpy, got: {non_blank}"
|
||||
|
||||
def test_empty_file(self, tmp_path):
|
||||
req = self._write_req(tmp_path, "")
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
content = Path(result).read_text(encoding = "utf-8")
|
||||
assert content.strip() == ""
|
||||
|
||||
def test_comments_preserved(self, tmp_path):
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
# torch-stoi is needed for audio
|
||||
numpy
|
||||
""",
|
||||
)
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
lines = Path(result).read_text(encoding = "utf-8").splitlines()
|
||||
non_blank = [l.strip() for l in lines if l.strip()]
|
||||
# Comment starts with "#", not "torch-stoi", so it's preserved
|
||||
assert len(non_blank) == 2
|
||||
assert non_blank[0].startswith("#")
|
||||
assert non_blank[1] == "numpy"
|
||||
|
||||
def test_version_specifiers_filtered(self, tmp_path):
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
torch-stoi>=0.1.0
|
||||
timm==1.2.3
|
||||
""",
|
||||
)
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
lines = Path(result).read_text(encoding = "utf-8").splitlines()
|
||||
non_blank = [l.strip() for l in lines if l.strip()]
|
||||
assert non_blank == [], f"Expected empty, got: {non_blank}"
|
||||
|
||||
def test_prefix_match_catches_extensions(self, tmp_path):
|
||||
"""Prefix matching catches torch-stoi-extra (correct for pip names)."""
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
torch-stoi-extra
|
||||
numpy
|
||||
""",
|
||||
)
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
lines = Path(result).read_text(encoding = "utf-8").splitlines()
|
||||
non_blank = [l.strip() for l in lines if l.strip()]
|
||||
assert non_blank == ["numpy"]
|
||||
|
||||
def test_mixed_case_filtered(self, tmp_path):
|
||||
"""Package names are lowercased before matching."""
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
Timm>=1.0
|
||||
TORCH-STOI
|
||||
numpy
|
||||
""",
|
||||
)
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
lines = Path(result).read_text(encoding = "utf-8").splitlines()
|
||||
non_blank = [l.strip() for l in lines if l.strip()]
|
||||
assert non_blank == ["numpy"]
|
||||
|
||||
def test_whitespace_and_blank_lines_preserved(self, tmp_path):
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
numpy
|
||||
|
||||
pandas
|
||||
|
||||
""",
|
||||
)
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
content = Path(result).read_text(encoding = "utf-8")
|
||||
# Blank lines should be preserved (not stripped)
|
||||
assert "\n\n" in content or content.count("\n") >= 3
|
||||
|
||||
def test_stacked_windows_and_no_torch_filters(self, tmp_path):
|
||||
"""Both WINDOWS_SKIP_PACKAGES and NO_TORCH_SKIP_PACKAGES applied."""
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
open_spiel
|
||||
triton_kernels
|
||||
torch-stoi
|
||||
timm
|
||||
numpy
|
||||
""",
|
||||
)
|
||||
# First filter Windows packages, then NO_TORCH packages
|
||||
intermediate = ips._filter_requirements(req, ips.WINDOWS_SKIP_PACKAGES)
|
||||
result = ips._filter_requirements(
|
||||
Path(intermediate), ips.NO_TORCH_SKIP_PACKAGES
|
||||
)
|
||||
lines = Path(result).read_text(encoding = "utf-8").splitlines()
|
||||
non_blank = [l.strip() for l in lines if l.strip()]
|
||||
assert non_blank == [
|
||||
"numpy"
|
||||
], f"Expected only numpy after stacked filters, got: {non_blank}"
|
||||
|
||||
def test_vcs_url_with_skip_package_name(self, tmp_path):
|
||||
"""VCS URLs like git+https://...torch-stoi should also be filtered (startswith matches)."""
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
numpy
|
||||
torch-stoi @ git+https://github.com/example/torch-stoi.git
|
||||
""",
|
||||
)
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
lines = Path(result).read_text(encoding = "utf-8").splitlines()
|
||||
non_blank = [l.strip() for l in lines if l.strip()]
|
||||
assert non_blank == [
|
||||
"numpy"
|
||||
], f"VCS URL line should be filtered, got: {non_blank}"
|
||||
|
||||
def test_env_marker_line_filtered(self, tmp_path):
|
||||
"""Package lines with env markers are still filtered by prefix."""
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
timm>=1.0; python_version>="3.10"
|
||||
numpy
|
||||
""",
|
||||
)
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
lines = Path(result).read_text(encoding = "utf-8").splitlines()
|
||||
non_blank = [l.strip() for l in lines if l.strip()]
|
||||
assert non_blank == [
|
||||
"numpy"
|
||||
], f"Env marker line should be filtered, got: {non_blank}"
|
||||
|
||||
def test_git_plus_url_not_over_matched(self, tmp_path):
|
||||
"""A git+ URL whose path contains a skip package name but does NOT start with it."""
|
||||
req = self._write_req(
|
||||
tmp_path,
|
||||
"""\
|
||||
git+https://github.com/meta-pytorch/OpenEnv.git
|
||||
numpy
|
||||
""",
|
||||
)
|
||||
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
lines = Path(result).read_text(encoding = "utf-8").splitlines()
|
||||
non_blank = [l.strip() for l in lines if l.strip()]
|
||||
# The git+ URL doesn't start with any skip package, so it is preserved
|
||||
assert len(non_blank) == 2, f"git+ URL should be preserved, got: {non_blank}"
|
||||
|
||||
|
||||
# ── Real requirements file filtering ──────────────────────────────────
|
||||
|
||||
|
||||
class TestRealRequirementsFiltering:
|
||||
"""Filter the ACTUAL extras.txt and extras-no-deps.txt with NO_TORCH_SKIP_PACKAGES."""
|
||||
|
||||
@pytest.fixture(autouse = True)
|
||||
def _check_req_files(self):
|
||||
if not EXTRAS_TXT.is_file():
|
||||
pytest.skip("extras.txt not found in repo")
|
||||
if not EXTRAS_NO_DEPS_TXT.is_file():
|
||||
pytest.skip("extras-no-deps.txt not found in repo")
|
||||
|
||||
def _non_blank_non_comment(self, path: Path) -> list[str]:
|
||||
"""Return non-blank, non-comment lines from a requirements file."""
|
||||
lines = path.read_text(encoding = "utf-8").splitlines()
|
||||
return [l.strip() for l in lines if l.strip() and not l.strip().startswith("#")]
|
||||
|
||||
def test_extras_txt_torch_packages_removed(self):
|
||||
"""extras.txt: all NO_TORCH_SKIP_PACKAGES must be removed, everything else preserved."""
|
||||
result = ips._filter_requirements(EXTRAS_TXT, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
filtered = self._non_blank_non_comment(Path(result))
|
||||
original = self._non_blank_non_comment(EXTRAS_TXT)
|
||||
|
||||
# These must be gone
|
||||
for pkg in ["torch-stoi", "timm", "openai-whisper", "transformers-cfg"]:
|
||||
assert not any(
|
||||
l.lower().startswith(pkg) for l in filtered
|
||||
), f"{pkg} should be removed from extras.txt"
|
||||
|
||||
# Everything else must remain
|
||||
expected = [
|
||||
l
|
||||
for l in original
|
||||
if not any(
|
||||
l.strip().lower().startswith(p) for p in ips.NO_TORCH_SKIP_PACKAGES
|
||||
)
|
||||
]
|
||||
assert filtered == expected, (
|
||||
f"Filtered extras.txt should match expected.\n"
|
||||
f"Missing: {set(expected) - set(filtered)}\n"
|
||||
f"Extra: {set(filtered) - set(expected)}"
|
||||
)
|
||||
|
||||
def test_extras_no_deps_txt_torchcodec_and_dlpack_removed(self):
|
||||
"""extras-no-deps.txt: torchcodec and torch-c-dlpack-ext must be removed."""
|
||||
result = ips._filter_requirements(
|
||||
EXTRAS_NO_DEPS_TXT, ips.NO_TORCH_SKIP_PACKAGES
|
||||
)
|
||||
filtered = self._non_blank_non_comment(Path(result))
|
||||
original = self._non_blank_non_comment(EXTRAS_NO_DEPS_TXT)
|
||||
|
||||
for pkg in ["torchcodec", "torch-c-dlpack-ext"]:
|
||||
assert not any(
|
||||
l.lower().startswith(pkg) for l in filtered
|
||||
), f"{pkg} should be removed from extras-no-deps.txt"
|
||||
|
||||
expected = [
|
||||
l
|
||||
for l in original
|
||||
if not any(
|
||||
l.strip().lower().startswith(p) for p in ips.NO_TORCH_SKIP_PACKAGES
|
||||
)
|
||||
]
|
||||
assert filtered == expected
|
||||
|
||||
def test_extras_txt_most_packages_preserved(self):
|
||||
"""Ensure a representative set of non-torch packages survive filtering."""
|
||||
result = ips._filter_requirements(EXTRAS_TXT, ips.NO_TORCH_SKIP_PACKAGES)
|
||||
filtered_text = Path(result).read_text(encoding = "utf-8").lower()
|
||||
|
||||
must_survive = ["scikit-learn", "loguru", "tiktoken", "einops", "tabulate"]
|
||||
for pkg in must_survive:
|
||||
if pkg in EXTRAS_TXT.read_text(encoding = "utf-8").lower():
|
||||
assert pkg in filtered_text, f"{pkg} should survive NO_TORCH filtering"
|
||||
|
||||
def test_extras_no_deps_txt_trl_preserved(self):
|
||||
"""trl should survive NO_TORCH filtering in extras-no-deps.txt."""
|
||||
result = ips._filter_requirements(
|
||||
EXTRAS_NO_DEPS_TXT, ips.NO_TORCH_SKIP_PACKAGES
|
||||
)
|
||||
filtered_text = Path(result).read_text(encoding = "utf-8").lower()
|
||||
assert "trl" in filtered_text, "trl should survive NO_TORCH filtering"
|
||||
|
||||
|
||||
# ── NO_TORCH constant tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNoTorchConstant:
|
||||
"""Verify NO_TORCH is derived correctly from UNSLOTH_NO_TORCH env var."""
|
||||
|
||||
def _reimport_no_torch(self) -> bool:
|
||||
return os.environ.get("UNSLOTH_NO_TORCH", "false").lower() in ("1", "true")
|
||||
|
||||
def test_true_lowercase(self):
|
||||
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "true"}):
|
||||
assert self._reimport_no_torch() is True
|
||||
|
||||
def test_true_one(self):
|
||||
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "1"}):
|
||||
assert self._reimport_no_torch() is True
|
||||
|
||||
def test_true_uppercase(self):
|
||||
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "TRUE"}):
|
||||
assert self._reimport_no_torch() is True
|
||||
|
||||
def test_false_string(self):
|
||||
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "false"}):
|
||||
assert self._reimport_no_torch() is False
|
||||
|
||||
def test_false_zero(self):
|
||||
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "0"}):
|
||||
assert self._reimport_no_torch() is False
|
||||
|
||||
def test_not_set(self):
|
||||
env = os.environ.copy()
|
||||
env.pop("UNSLOTH_NO_TORCH", None)
|
||||
with mock.patch.dict(os.environ, env, clear = True):
|
||||
assert self._reimport_no_torch() is False
|
||||
|
||||
def test_infer_no_torch_on_intel_mac(self):
|
||||
"""_infer_no_torch falls back to platform detection when env var is unset."""
|
||||
env = os.environ.copy()
|
||||
env.pop("UNSLOTH_NO_TORCH", None)
|
||||
with (
|
||||
mock.patch.dict(os.environ, env, clear = True),
|
||||
mock.patch.object(ips, "IS_MAC_INTEL", True),
|
||||
):
|
||||
assert ips._infer_no_torch() is True
|
||||
|
||||
def test_infer_no_torch_respects_explicit_false_on_intel_mac(self):
|
||||
"""Explicit UNSLOTH_NO_TORCH=false overrides platform detection."""
|
||||
with (
|
||||
mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "false"}),
|
||||
mock.patch.object(ips, "IS_MAC_INTEL", True),
|
||||
):
|
||||
assert ips._infer_no_torch() is False
|
||||
|
||||
def test_infer_no_torch_linux_unset(self):
|
||||
"""On Linux with env var unset, _infer_no_torch returns False."""
|
||||
env = os.environ.copy()
|
||||
env.pop("UNSLOTH_NO_TORCH", None)
|
||||
with (
|
||||
mock.patch.dict(os.environ, env, clear = True),
|
||||
mock.patch.object(ips, "IS_MAC_INTEL", False),
|
||||
):
|
||||
assert ips._infer_no_torch() is False
|
||||
|
||||
|
||||
# ── IS_MACOS constant tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIsMacosConstant:
|
||||
"""Verify IS_MACOS detection logic."""
|
||||
|
||||
def test_is_macos_matches_platform(self):
|
||||
import sys
|
||||
|
||||
expected = sys.platform == "darwin"
|
||||
assert ips.IS_MACOS is expected
|
||||
|
||||
|
||||
# ── Subprocess mock of install_python_stack() ─────────────────────────
|
||||
|
||||
|
||||
class TestInstallPythonStackSubprocessMock:
|
||||
"""Monkeypatch subprocess.run to capture all pip/uv commands,
|
||||
then verify which requirements files are used/skipped under
|
||||
different NO_TORCH / IS_MACOS / IS_WINDOWS configurations."""
|
||||
|
||||
@pytest.fixture(autouse = True)
|
||||
def _check_req_files(self):
|
||||
"""Skip if requirements files are missing."""
|
||||
for f in [EXTRAS_TXT, EXTRAS_NO_DEPS_TXT, OVERRIDES_TXT]:
|
||||
if not f.is_file():
|
||||
pytest.skip(f"{f.name} not found in repo")
|
||||
|
||||
def _capture_install(
|
||||
self,
|
||||
no_torch: bool,
|
||||
is_macos: bool,
|
||||
is_windows: bool,
|
||||
*,
|
||||
skip_base: bool = True,
|
||||
):
|
||||
"""Run install_python_stack() with mocked subprocess, capturing all commands.
|
||||
|
||||
Returns a list of string-joined commands (each element is ' '.join(cmd)).
|
||||
"""
|
||||
captured_cmds: list[list[str]] = []
|
||||
|
||||
def mock_run(cmd, **kw):
|
||||
captured_cmds.append(
|
||||
list(cmd) if isinstance(cmd, (list, tuple)) else [str(cmd)]
|
||||
)
|
||||
return subprocess.CompletedProcess(cmd, 0, b"", b"")
|
||||
|
||||
env = {"SKIP_STUDIO_BASE": "1"} if skip_base else {}
|
||||
|
||||
with (
|
||||
mock.patch.object(ips, "NO_TORCH", no_torch),
|
||||
mock.patch.object(ips, "IS_MACOS", is_macos),
|
||||
mock.patch.object(ips, "IS_WINDOWS", is_windows),
|
||||
mock.patch.object(ips, "USE_UV", True),
|
||||
mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
|
||||
mock.patch.object(ips, "VERBOSE", False),
|
||||
mock.patch("subprocess.run", side_effect = mock_run),
|
||||
mock.patch.object(ips, "_bootstrap_uv", return_value = True),
|
||||
mock.patch.object(
|
||||
ips, "LOCAL_DD_UNSTRUCTURED_PLUGIN", Path("/fake/plugin")
|
||||
),
|
||||
mock.patch("pathlib.Path.is_dir", return_value = True),
|
||||
mock.patch("pathlib.Path.is_file", return_value = True),
|
||||
):
|
||||
with mock.patch.dict(os.environ, env, clear = False):
|
||||
ips.install_python_stack()
|
||||
|
||||
return [" ".join(str(c) for c in cmd) for cmd in captured_cmds]
|
||||
|
||||
def _cmds_contain_file(self, cmds: list[str], filename: str) -> bool:
|
||||
"""Check if any captured command references the given filename."""
|
||||
return any(filename in cmd for cmd in cmds)
|
||||
|
||||
# -- NO_TORCH=True, IS_MACOS=True (Intel Mac scenario) --
|
||||
|
||||
def test_no_torch_macos_skips_overrides(self):
|
||||
"""With NO_TORCH=True, overrides.txt pip_install must NOT be called."""
|
||||
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
|
||||
assert not self._cmds_contain_file(
|
||||
cmds, "overrides.txt"
|
||||
), "overrides.txt should be skipped when NO_TORCH=True"
|
||||
|
||||
def test_no_torch_macos_skips_triton(self):
|
||||
"""With IS_MACOS=True, triton-kernels.txt must NOT be called."""
|
||||
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
|
||||
assert not self._cmds_contain_file(
|
||||
cmds, "triton-kernels.txt"
|
||||
), "triton-kernels.txt should be skipped on macOS"
|
||||
|
||||
def test_no_torch_macos_extras_called(self):
|
||||
"""With NO_TORCH=True, extras.txt is still called (but filtered)."""
|
||||
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
|
||||
has_extras = self._cmds_contain_file(cmds, "extras.txt") or any(
|
||||
"-r" in cmd and "tmp" in cmd.lower() for cmd in cmds
|
||||
)
|
||||
assert has_extras, "extras.txt (or its filtered temp) should be called"
|
||||
|
||||
def test_no_torch_macos_extras_no_deps_called(self):
|
||||
"""With NO_TORCH=True, extras-no-deps.txt is still called (but filtered)."""
|
||||
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
|
||||
has_extras_nd = self._cmds_contain_file(cmds, "extras-no-deps.txt") or any(
|
||||
"-r" in cmd and "tmp" in cmd.lower() for cmd in cmds
|
||||
)
|
||||
assert (
|
||||
has_extras_nd
|
||||
), "extras-no-deps.txt (or its filtered temp) should be called"
|
||||
|
||||
# -- IS_WINDOWS=True + NO_TORCH=True (stacked) --
|
||||
|
||||
def test_windows_no_torch_skips_overrides(self):
|
||||
"""Windows+NO_TORCH: overrides.txt must be skipped."""
|
||||
cmds = self._capture_install(no_torch = True, is_macos = False, is_windows = True)
|
||||
assert not self._cmds_contain_file(
|
||||
cmds, "overrides.txt"
|
||||
), "overrides.txt should be skipped with NO_TORCH=True on Windows"
|
||||
|
||||
def test_windows_no_torch_skips_triton(self):
|
||||
"""Windows: triton-kernels.txt must be skipped (IS_WINDOWS guard)."""
|
||||
cmds = self._capture_install(no_torch = True, is_macos = False, is_windows = True)
|
||||
assert not self._cmds_contain_file(
|
||||
cmds, "triton-kernels.txt"
|
||||
), "triton-kernels.txt should be skipped on Windows"
|
||||
|
||||
# -- Normal Linux path (NO_TORCH=False, IS_MACOS=False, IS_WINDOWS=False) --
|
||||
|
||||
def test_normal_linux_includes_overrides(self):
|
||||
"""Normal Linux: overrides.txt IS called."""
|
||||
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
|
||||
assert self._cmds_contain_file(
|
||||
cmds, "overrides.txt"
|
||||
), "overrides.txt should be called on normal Linux"
|
||||
|
||||
def test_normal_linux_includes_triton(self):
|
||||
"""Normal Linux: triton-kernels.txt IS called."""
|
||||
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
|
||||
assert self._cmds_contain_file(
|
||||
cmds, "triton-kernels.txt"
|
||||
), "triton-kernels.txt should be called on normal Linux"
|
||||
|
||||
def test_normal_linux_includes_extras(self):
|
||||
"""Normal Linux: extras.txt IS called (no filtering)."""
|
||||
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
|
||||
assert self._cmds_contain_file(
|
||||
cmds, "extras.txt"
|
||||
), "extras.txt should be called on normal Linux"
|
||||
|
||||
def test_normal_linux_includes_extras_no_deps(self):
|
||||
"""Normal Linux: extras-no-deps.txt IS called (no filtering)."""
|
||||
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
|
||||
assert self._cmds_contain_file(
|
||||
cmds, "extras-no-deps.txt"
|
||||
), "extras-no-deps.txt should be called on normal Linux"
|
||||
|
||||
# -- Windows-only (NO_TORCH=False) to verify triton is still skipped --
|
||||
|
||||
def test_windows_only_skips_triton(self):
|
||||
"""Windows (without NO_TORCH): triton still skipped."""
|
||||
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = True)
|
||||
assert not self._cmds_contain_file(
|
||||
cmds, "triton-kernels.txt"
|
||||
), "triton-kernels.txt should be skipped on Windows even without NO_TORCH"
|
||||
|
||||
def test_windows_only_includes_overrides(self):
|
||||
"""Windows (without NO_TORCH): overrides IS called (via filtered temp file).
|
||||
|
||||
On Windows, all req files go through _filter_requirements(WINDOWS_SKIP_PACKAGES),
|
||||
so the command uses a temp file, not overrides.txt directly. We check for
|
||||
--reinstall (uv translation of --force-reinstall) which is unique to overrides.
|
||||
"""
|
||||
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = True)
|
||||
assert any(
|
||||
"--reinstall" in cmd for cmd in cmds
|
||||
), "overrides step (--reinstall) should be called on Windows when NO_TORCH=False"
|
||||
|
||||
# -- Update path (skip_base=False) to verify no-torch mode is durable --
|
||||
|
||||
def test_update_path_intel_macos_still_skips_overrides(self):
|
||||
"""Update path (no SKIP_STUDIO_BASE): overrides still skipped on Intel Mac."""
|
||||
cmds = self._capture_install(
|
||||
no_torch = True, is_macos = True, is_windows = False, skip_base = False
|
||||
)
|
||||
assert not self._cmds_contain_file(
|
||||
cmds, "overrides.txt"
|
||||
), "overrides.txt should be skipped on Intel Mac even via studio update"
|
||||
|
||||
def test_update_path_intel_macos_still_skips_triton(self):
|
||||
"""Update path (no SKIP_STUDIO_BASE): triton still skipped on macOS."""
|
||||
cmds = self._capture_install(
|
||||
no_torch = True, is_macos = True, is_windows = False, skip_base = False
|
||||
)
|
||||
assert not self._cmds_contain_file(
|
||||
cmds, "triton-kernels.txt"
|
||||
), "triton-kernels.txt should be skipped on macOS even via studio update"
|
||||
|
||||
|
||||
# ── Overrides skip structural checks ─────────────────────────────────
|
||||
|
||||
|
||||
class TestOverridesSkip:
|
||||
"""Verify overrides.txt is skipped when NO_TORCH is True (source-level check)."""
|
||||
|
||||
def test_no_torch_guard_exists_in_source(self):
|
||||
"""The install_python_stack source must contain a NO_TORCH guard around overrides."""
|
||||
source = Path(ips.__file__).read_text(encoding = "utf-8")
|
||||
assert (
|
||||
"if NO_TORCH:" in source
|
||||
), "NO_TORCH guard not found in install_python_stack.py"
|
||||
|
||||
def test_overrides_skipped_when_no_torch(self):
|
||||
"""With NO_TORCH=True on the module, pip_install should NOT be called for overrides."""
|
||||
source = Path(ips.__file__).read_text(encoding = "utf-8")
|
||||
overrides_match = re.search(r"if NO_TORCH:.*?overrides", source, re.DOTALL)
|
||||
assert (
|
||||
overrides_match is not None
|
||||
), "Expected NO_TORCH conditional before overrides install"
|
||||
|
||||
|
||||
# ── install.sh --no-torch flag tests ──────────────────────────────────
|
||||
|
||||
|
||||
class TestInstallShNoTorchFlag:
|
||||
"""Verify install.sh has the --no-torch flag and SKIP_TORCH variable."""
|
||||
|
||||
@pytest.fixture(autouse = True)
|
||||
def _check_install_sh(self):
|
||||
install_sh = Path(__file__).resolve().parents[2] / "install.sh"
|
||||
if not install_sh.is_file():
|
||||
pytest.skip("install.sh not found")
|
||||
self.install_sh = install_sh
|
||||
self.source = install_sh.read_text(encoding = "utf-8")
|
||||
|
||||
def test_no_torch_flag_in_case_statement(self):
|
||||
"""--no-torch must appear in the flag parser case statement."""
|
||||
assert (
|
||||
"--no-torch)" in self.source
|
||||
), "--no-torch not found in install.sh flag parser"
|
||||
|
||||
def test_no_torch_flag_variable_initialized(self):
|
||||
"""_NO_TORCH_FLAG must be initialized to false."""
|
||||
assert (
|
||||
"_NO_TORCH_FLAG=false" in self.source
|
||||
), "_NO_TORCH_FLAG=false not found in install.sh"
|
||||
|
||||
def test_skip_torch_variable_exists(self):
|
||||
"""SKIP_TORCH variable must be defined."""
|
||||
assert (
|
||||
"SKIP_TORCH=false" in self.source
|
||||
), "SKIP_TORCH=false not found in install.sh"
|
||||
assert (
|
||||
"SKIP_TORCH=true" in self.source
|
||||
), "SKIP_TORCH=true not found in install.sh"
|
||||
|
||||
def test_skip_torch_driven_by_flag_and_mac_intel(self):
|
||||
"""SKIP_TORCH must check both _NO_TORCH_FLAG and MAC_INTEL."""
|
||||
assert (
|
||||
"_NO_TORCH_FLAG" in self.source
|
||||
), "_NO_TORCH_FLAG not referenced in SKIP_TORCH logic"
|
||||
assert (
|
||||
"MAC_INTEL" in self.source
|
||||
), "MAC_INTEL not referenced in SKIP_TORCH logic"
|
||||
|
||||
def test_unsloth_no_torch_uses_skip_torch(self):
|
||||
"""UNSLOTH_NO_TORCH must reference $SKIP_TORCH, not $MAC_INTEL."""
|
||||
import re
|
||||
|
||||
matches = re.findall(r'UNSLOTH_NO_TORCH="\$(\w+)"', self.source)
|
||||
for var in matches:
|
||||
assert (
|
||||
var == "SKIP_TORCH"
|
||||
), f"UNSLOTH_NO_TORCH references ${var} instead of $SKIP_TORCH"
|
||||
|
||||
def test_cpu_hint_message_exists(self):
|
||||
"""CPU hint message must exist in install.sh."""
|
||||
assert (
|
||||
"No NVIDIA GPU detected" in self.source
|
||||
), "CPU hint message not found in install.sh"
|
||||
assert (
|
||||
"--no-torch" in self.source
|
||||
), "--no-torch suggestion not found in CPU hint"
|
||||
|
||||
def test_no_torch_flag_parsing_subprocess(self):
|
||||
"""--no-torch flag sets _NO_TORCH_FLAG=true (subprocess test)."""
|
||||
script = textwrap.dedent("""\
|
||||
_NO_TORCH_FLAG=false
|
||||
_next_is_package=false
|
||||
STUDIO_LOCAL_INSTALL=false
|
||||
PACKAGE_NAME="unsloth"
|
||||
for arg in "$@"; do
|
||||
if [ "$_next_is_package" = true ]; then
|
||||
PACKAGE_NAME="$arg"
|
||||
_next_is_package=false
|
||||
continue
|
||||
fi
|
||||
case "$arg" in
|
||||
--local) STUDIO_LOCAL_INSTALL=true ;;
|
||||
--package) _next_is_package=true ;;
|
||||
--no-torch) _NO_TORCH_FLAG=true ;;
|
||||
esac
|
||||
done
|
||||
echo "$_NO_TORCH_FLAG"
|
||||
""")
|
||||
result = subprocess.run(
|
||||
["bash", "-c", script, "_", "--no-torch"],
|
||||
capture_output = True,
|
||||
text = True,
|
||||
)
|
||||
assert (
|
||||
result.stdout.strip() == "true"
|
||||
), f"Expected _NO_TORCH_FLAG=true, got: {result.stdout.strip()}"
|
||||
|
||||
def test_no_torch_with_local_flag(self):
|
||||
"""--no-torch and --local can be used together."""
|
||||
script = textwrap.dedent("""\
|
||||
_NO_TORCH_FLAG=false
|
||||
_next_is_package=false
|
||||
STUDIO_LOCAL_INSTALL=false
|
||||
PACKAGE_NAME="unsloth"
|
||||
for arg in "$@"; do
|
||||
if [ "$_next_is_package" = true ]; then
|
||||
PACKAGE_NAME="$arg"
|
||||
_next_is_package=false
|
||||
continue
|
||||
fi
|
||||
case "$arg" in
|
||||
--local) STUDIO_LOCAL_INSTALL=true ;;
|
||||
--package) _next_is_package=true ;;
|
||||
--no-torch) _NO_TORCH_FLAG=true ;;
|
||||
esac
|
||||
done
|
||||
echo "$_NO_TORCH_FLAG $STUDIO_LOCAL_INSTALL"
|
||||
""")
|
||||
result = subprocess.run(
|
||||
["bash", "-c", script, "_", "--local", "--no-torch"],
|
||||
capture_output = True,
|
||||
text = True,
|
||||
)
|
||||
assert (
|
||||
result.stdout.strip() == "true true"
|
||||
), f"Expected 'true true', got: {result.stdout.strip()}"
|
||||
|
||||
def test_cpu_hint_only_when_not_skip_torch(self):
|
||||
"""CPU hint should only print when SKIP_TORCH=false and OS!=macos."""
|
||||
script = textwrap.dedent("""\
|
||||
TORCH_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
SKIP_TORCH=false
|
||||
OS="linux"
|
||||
case "$TORCH_INDEX_URL" in
|
||||
*/cpu)
|
||||
if [ "$SKIP_TORCH" = false ] && [ "$OS" != "macos" ]; then
|
||||
echo "HINT_PRINTED"
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
""")
|
||||
result = subprocess.run(
|
||||
["bash", "-c", script],
|
||||
capture_output = True,
|
||||
text = True,
|
||||
)
|
||||
assert "HINT_PRINTED" in result.stdout, "CPU hint should print"
|
||||
|
||||
# With SKIP_TORCH=true, hint should NOT print
|
||||
script2 = script.replace("SKIP_TORCH=false", "SKIP_TORCH=true")
|
||||
result2 = subprocess.run(
|
||||
["bash", "-c", script2],
|
||||
capture_output = True,
|
||||
text = True,
|
||||
)
|
||||
assert (
|
||||
"HINT_PRINTED" not in result2.stdout
|
||||
), "CPU hint should NOT print when SKIP_TORCH=true"
|
||||
|
||||
|
||||
# ── Triton macOS skip structural checks ──────────────────────────────
|
||||
|
||||
|
||||
class TestTritonMacosSkip:
|
||||
"""Verify triton is skipped on macOS (source-level check)."""
|
||||
|
||||
def test_triton_guard_in_source(self):
|
||||
"""Source must skip triton on both Windows and macOS."""
|
||||
source = Path(ips.__file__).read_text(encoding = "utf-8")
|
||||
assert (
|
||||
"not IS_MACOS" in source
|
||||
), "IS_MACOS guard for triton not found in install_python_stack.py"
|
||||
assert (
|
||||
"not IS_WINDOWS and not IS_MACOS" in source
|
||||
), "Expected 'not IS_WINDOWS and not IS_MACOS' guard for triton"
|
||||
582
tests/python/test_studio_import_no_torch.py
Normal file
582
tests/python/test_studio_import_no_torch.py
Normal file
|
|
@ -0,0 +1,582 @@
|
|||
"""End-to-end sandbox tests: Studio modules in isolated no-torch venvs.
|
||||
|
||||
Covers:
|
||||
- Python 3.12 and 3.13 venv creation (Intel Mac uses 3.12, Apple Silicon/Linux 3.13)
|
||||
- data_collators.py loads and dataclasses instantiate without torch
|
||||
- chat_templates.py top-level exec works with stubs for relative imports
|
||||
- Negative control: prepending 'import torch' fails in no-torch venv
|
||||
- Negative control: installing torchao (from overrides.txt) fails in no-torch venv
|
||||
- AST structural checks for top-level torch imports
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
DATA_COLLATORS = (
|
||||
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "data_collators.py"
|
||||
)
|
||||
CHAT_TEMPLATES = (
|
||||
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "chat_templates.py"
|
||||
)
|
||||
FORMAT_CONVERSION = (
|
||||
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "format_conversion.py"
|
||||
)
|
||||
|
||||
|
||||
def _has_uv() -> bool:
|
||||
return shutil.which("uv") is not None
|
||||
|
||||
|
||||
def _create_venv(venv_dir: Path, python_version: str) -> Path | None:
|
||||
"""Create a uv venv at the given Python version. Returns python path or None."""
|
||||
result = subprocess.run(
|
||||
["uv", "venv", str(venv_dir), "--python", python_version],
|
||||
capture_output = True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
venv_python = venv_dir / "bin" / "python"
|
||||
if not venv_python.exists():
|
||||
venv_python = venv_dir / "Scripts" / "python.exe"
|
||||
return venv_python if venv_python.exists() else None
|
||||
|
||||
|
||||
@pytest.fixture(params = ["3.12", "3.13"], scope = "module")
|
||||
def no_torch_venv(request, tmp_path_factory):
|
||||
"""Create a temporary venv at the requested Python version with no torch.
|
||||
|
||||
Parametrized for 3.12 (Intel Mac) and 3.13 (Apple Silicon / Linux).
|
||||
"""
|
||||
if not _has_uv():
|
||||
pytest.skip("uv not available")
|
||||
|
||||
py_version = request.param
|
||||
venv_dir = tmp_path_factory.mktemp(f"no_torch_venv_{py_version}")
|
||||
venv_python = _create_venv(venv_dir, py_version)
|
||||
if venv_python is None:
|
||||
pytest.skip(f"Could not create Python {py_version} venv")
|
||||
|
||||
# Verify torch is NOT importable
|
||||
check = subprocess.run(
|
||||
[str(venv_python), "-c", "import torch"],
|
||||
capture_output = True,
|
||||
)
|
||||
assert (
|
||||
check.returncode != 0
|
||||
), f"torch should NOT be importable in fresh {py_version} venv"
|
||||
|
||||
return str(venv_python)
|
||||
|
||||
|
||||
# ── AST structural checks ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDataCollatorsAST:
|
||||
"""Static analysis: data_collators.py has no top-level torch imports."""
|
||||
|
||||
def test_ast_parse(self):
|
||||
"""data_collators.py must be valid Python syntax."""
|
||||
source = DATA_COLLATORS.read_text(encoding = "utf-8")
|
||||
tree = ast.parse(source, filename = str(DATA_COLLATORS))
|
||||
assert tree is not None
|
||||
|
||||
def test_no_top_level_torch_import(self):
|
||||
"""No top-level 'import torch' or 'from torch' statements."""
|
||||
source = DATA_COLLATORS.read_text(encoding = "utf-8")
|
||||
tree = ast.parse(source)
|
||||
for node in ast.iter_child_nodes(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
assert not alias.name.startswith(
|
||||
"torch"
|
||||
), f"Top-level 'import {alias.name}' found at line {node.lineno}"
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module:
|
||||
assert not node.module.startswith(
|
||||
"torch"
|
||||
), f"Top-level 'from {node.module}' found at line {node.lineno}"
|
||||
|
||||
|
||||
class TestChatTemplatesAST:
|
||||
"""Static analysis: chat_templates.py has no top-level torch imports."""
|
||||
|
||||
def test_ast_parse(self):
|
||||
"""chat_templates.py must be valid Python syntax."""
|
||||
source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
|
||||
tree = ast.parse(source, filename = str(CHAT_TEMPLATES))
|
||||
assert tree is not None
|
||||
|
||||
def test_no_top_level_torch_import(self):
|
||||
"""No top-level 'import torch' or 'from torch' at module level."""
|
||||
source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
|
||||
tree = ast.parse(source)
|
||||
for node in ast.iter_child_nodes(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
assert not alias.name.startswith(
|
||||
"torch"
|
||||
), f"Top-level 'import {alias.name}' found at line {node.lineno}"
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module:
|
||||
assert not node.module.startswith(
|
||||
"torch"
|
||||
), f"Top-level 'from {node.module}' found at line {node.lineno}"
|
||||
|
||||
def test_torch_imports_only_inside_functions(self):
|
||||
"""All 'from torch' imports must be inside function/method bodies."""
|
||||
source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
|
||||
tree = ast.parse(source)
|
||||
torch_imports = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
module = None
|
||||
if isinstance(node, ast.ImportFrom):
|
||||
module = node.module
|
||||
elif isinstance(node, ast.Import):
|
||||
module = node.names[0].name if node.names else None
|
||||
if module and module.startswith("torch"):
|
||||
torch_imports.append(node)
|
||||
|
||||
top_level = set(id(n) for n in ast.iter_child_nodes(tree))
|
||||
for imp in torch_imports:
|
||||
assert id(imp) not in top_level, (
|
||||
f"torch import at line {imp.lineno} is at top level"
|
||||
" (should be inside a function)"
|
||||
)
|
||||
|
||||
|
||||
# ── data_collators.py: exec + dataclass instantiation in no-torch venv ──
|
||||
|
||||
|
||||
class TestDataCollatorsNoTorchVenv:
|
||||
"""Run data_collators.py in an isolated no-torch venv, verify classes load."""
|
||||
|
||||
def test_exec_in_no_torch_venv(self, no_torch_venv):
|
||||
"""data_collators.py executes in a venv without torch (with loggers stub)."""
|
||||
code = textwrap.dedent(f"""\
|
||||
import sys, types
|
||||
loggers = types.ModuleType('loggers')
|
||||
loggers.get_logger = lambda n: None
|
||||
sys.modules['loggers'] = loggers
|
||||
exec(open({str(DATA_COLLATORS)!r}).read())
|
||||
print("OK: exec succeeded")
|
||||
""")
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", code],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"data_collators.py failed in no-torch venv:\n{result.stderr.decode()}"
|
||||
assert b"OK: exec succeeded" in result.stdout
|
||||
|
||||
def test_dataclass_speech_collator_instantiable(self, no_torch_venv):
|
||||
"""DataCollatorSpeechSeq2SeqWithPadding can be instantiated with processor=None."""
|
||||
code = textwrap.dedent(f"""\
|
||||
import sys, types
|
||||
loggers = types.ModuleType('loggers')
|
||||
loggers.get_logger = lambda n: None
|
||||
sys.modules['loggers'] = loggers
|
||||
exec(open({str(DATA_COLLATORS)!r}).read())
|
||||
obj = DataCollatorSpeechSeq2SeqWithPadding(processor=None)
|
||||
assert obj.processor is None, "processor should be None"
|
||||
print("OK: DataCollatorSpeechSeq2SeqWithPadding instantiated")
|
||||
""")
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", code],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"DataCollatorSpeechSeq2SeqWithPadding failed:\n{result.stderr.decode()}"
|
||||
assert b"OK: DataCollatorSpeechSeq2SeqWithPadding instantiated" in result.stdout
|
||||
|
||||
def test_dataclass_deepseek_collator_instantiable(self, no_torch_venv):
|
||||
"""DeepSeekOCRDataCollator can be instantiated with processor=None."""
|
||||
code = textwrap.dedent(f"""\
|
||||
import sys, types
|
||||
loggers = types.ModuleType('loggers')
|
||||
loggers.get_logger = lambda n: None
|
||||
sys.modules['loggers'] = loggers
|
||||
exec(open({str(DATA_COLLATORS)!r}).read())
|
||||
obj = DeepSeekOCRDataCollator(processor=None)
|
||||
assert obj.processor is None, "processor should be None"
|
||||
assert obj.max_length == 2048, "default max_length should be 2048"
|
||||
assert obj.ignore_index == -100, "default ignore_index should be -100"
|
||||
print("OK: DeepSeekOCRDataCollator instantiated")
|
||||
""")
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", code],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"DeepSeekOCRDataCollator failed:\n{result.stderr.decode()}"
|
||||
assert b"OK: DeepSeekOCRDataCollator instantiated" in result.stdout
|
||||
|
||||
def test_dataclass_vlm_collator_instantiable(self, no_torch_venv):
|
||||
"""VLMDataCollator can be instantiated with processor=None."""
|
||||
code = textwrap.dedent(f"""\
|
||||
import sys, types
|
||||
loggers = types.ModuleType('loggers')
|
||||
loggers.get_logger = lambda n: None
|
||||
sys.modules['loggers'] = loggers
|
||||
exec(open({str(DATA_COLLATORS)!r}).read())
|
||||
obj = VLMDataCollator(processor=None)
|
||||
assert obj.processor is None
|
||||
assert obj.mask_input_tokens is True, "default mask_input_tokens should be True"
|
||||
print("OK: VLMDataCollator instantiated")
|
||||
""")
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", code],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"VLMDataCollator failed:\n{result.stderr.decode()}"
|
||||
assert b"OK: VLMDataCollator instantiated" in result.stdout
|
||||
|
||||
|
||||
# ── chat_templates.py: exec in no-torch venv ─────────────────────────
|
||||
|
||||
|
||||
class TestChatTemplatesNoTorchVenv:
|
||||
"""Run chat_templates.py in an isolated no-torch venv with stubs."""
|
||||
|
||||
def test_exec_with_stubs(self, no_torch_venv):
|
||||
"""chat_templates.py top-level exec works with stubs for relative imports."""
|
||||
code = textwrap.dedent(f"""\
|
||||
import sys, types
|
||||
|
||||
# Stub loggers
|
||||
loggers = types.ModuleType('loggers')
|
||||
loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})()
|
||||
sys.modules['loggers'] = loggers
|
||||
|
||||
# Stub relative imports (.format_detection, .model_mappings)
|
||||
format_detection = types.ModuleType('format_detection')
|
||||
format_detection.detect_dataset_format = lambda *a, **k: None
|
||||
format_detection.detect_multimodal_dataset = lambda *a, **k: None
|
||||
format_detection.detect_custom_format_heuristic = lambda *a, **k: None
|
||||
sys.modules['format_detection'] = format_detection
|
||||
|
||||
model_mappings = types.ModuleType('model_mappings')
|
||||
model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}}
|
||||
sys.modules['model_mappings'] = model_mappings
|
||||
|
||||
# Read and transform the source: replace relative imports with absolute
|
||||
source = open({str(CHAT_TEMPLATES)!r}).read()
|
||||
source = source.replace('from .format_detection import', 'from format_detection import')
|
||||
source = source.replace('from .model_mappings import', 'from model_mappings import')
|
||||
|
||||
exec(source)
|
||||
|
||||
# Verify module-level constants are defined
|
||||
ns = dict(locals())
|
||||
assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined after exec"
|
||||
print("OK: chat_templates.py exec succeeded")
|
||||
""")
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", code],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"chat_templates.py failed in no-torch venv:\n{result.stderr.decode()}"
|
||||
assert b"OK: chat_templates.py exec succeeded" in result.stdout
|
||||
|
||||
def test_default_alpaca_template_defined(self, no_torch_venv):
|
||||
"""DEFAULT_ALPACA_TEMPLATE constant is accessible after exec."""
|
||||
code = textwrap.dedent(f"""\
|
||||
import sys, types
|
||||
|
||||
loggers = types.ModuleType('loggers')
|
||||
loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})()
|
||||
sys.modules['loggers'] = loggers
|
||||
|
||||
format_detection = types.ModuleType('format_detection')
|
||||
format_detection.detect_dataset_format = lambda *a, **k: None
|
||||
format_detection.detect_multimodal_dataset = lambda *a, **k: None
|
||||
format_detection.detect_custom_format_heuristic = lambda *a, **k: None
|
||||
sys.modules['format_detection'] = format_detection
|
||||
|
||||
model_mappings = types.ModuleType('model_mappings')
|
||||
model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}}
|
||||
sys.modules['model_mappings'] = model_mappings
|
||||
|
||||
ns = {{}}
|
||||
source = open({str(CHAT_TEMPLATES)!r}).read()
|
||||
source = source.replace('from .format_detection import', 'from format_detection import')
|
||||
source = source.replace('from .model_mappings import', 'from model_mappings import')
|
||||
exec(source, ns)
|
||||
|
||||
assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined"
|
||||
assert 'Instruction' in ns['DEFAULT_ALPACA_TEMPLATE'], "Template content unexpected"
|
||||
print("OK: DEFAULT_ALPACA_TEMPLATE defined and valid")
|
||||
""")
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", code],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"DEFAULT_ALPACA_TEMPLATE check failed:\n{result.stderr.decode()}"
|
||||
assert b"OK: DEFAULT_ALPACA_TEMPLATE defined and valid" in result.stdout
|
||||
|
||||
|
||||
# ── format_conversion.py: AST + runtime tests ────────────────────────
|
||||
|
||||
|
||||
class TestFormatConversionAST:
|
||||
"""Static analysis: format_conversion.py torch imports are guarded."""
|
||||
|
||||
def test_ast_parse(self):
|
||||
"""format_conversion.py must be valid Python syntax."""
|
||||
source = FORMAT_CONVERSION.read_text(encoding = "utf-8")
|
||||
tree = ast.parse(source, filename = str(FORMAT_CONVERSION))
|
||||
assert tree is not None
|
||||
|
||||
def test_no_bare_torch_import_in_functions(self):
|
||||
"""All 'from torch' imports in function bodies must be inside try/except."""
|
||||
source = FORMAT_CONVERSION.read_text(encoding = "utf-8")
|
||||
tree = ast.parse(source)
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
for child in ast.walk(node):
|
||||
if (
|
||||
isinstance(child, ast.ImportFrom)
|
||||
and child.module
|
||||
and child.module.startswith("torch")
|
||||
):
|
||||
# This torch import must be inside a Try node
|
||||
found_in_try = False
|
||||
for try_node in ast.walk(node):
|
||||
if isinstance(try_node, ast.Try):
|
||||
for try_child in ast.walk(try_node):
|
||||
if try_child is child:
|
||||
found_in_try = True
|
||||
break
|
||||
if found_in_try:
|
||||
break
|
||||
assert found_in_try, (
|
||||
f"torch import at line {child.lineno} in {node.name}() "
|
||||
"is not inside a try/except block"
|
||||
)
|
||||
|
||||
|
||||
class TestFormatConversionNoTorchVenv:
|
||||
"""Run format_conversion.py functions in a no-torch venv."""
|
||||
|
||||
def test_convert_chatml_to_alpaca_no_torch(self, no_torch_venv):
|
||||
"""convert_chatml_to_alpaca works without torch (via try/except ImportError)."""
|
||||
code = textwrap.dedent(f"""\
|
||||
import sys, types
|
||||
|
||||
# Stub loggers
|
||||
loggers = types.ModuleType('loggers')
|
||||
loggers.get_logger = lambda n: type('L', (), {{
|
||||
'info': lambda s, m: None,
|
||||
'warning': lambda s, m: None,
|
||||
'debug': lambda s, m: None,
|
||||
}})()
|
||||
sys.modules['loggers'] = loggers
|
||||
|
||||
# Stub datasets.IterableDataset (HF datasets, not torch)
|
||||
datasets_mod = types.ModuleType('datasets')
|
||||
datasets_mod.IterableDataset = type('IterableDataset', (), {{}})
|
||||
sys.modules['datasets'] = datasets_mod
|
||||
|
||||
# Stub utils.hardware
|
||||
utils_mod = types.ModuleType('utils')
|
||||
hardware_mod = types.ModuleType('utils.hardware')
|
||||
hardware_mod.dataset_map_num_proc = lambda n=None: 1
|
||||
utils_mod.hardware = hardware_mod
|
||||
sys.modules['utils'] = utils_mod
|
||||
sys.modules['utils.hardware'] = hardware_mod
|
||||
|
||||
# Read and exec format_conversion.py
|
||||
source = open({str(FORMAT_CONVERSION)!r}).read()
|
||||
source = source.replace('from .format_detection import', 'from format_detection import')
|
||||
ns = {{'__name__': '__test__'}}
|
||||
exec(source, ns)
|
||||
|
||||
# Test convert_chatml_to_alpaca with a simple dataset
|
||||
class FakeDataset:
|
||||
def map(self, fn, **kw):
|
||||
result = fn({{
|
||||
'messages': [[
|
||||
{{'role': 'user', 'content': 'Hello'}},
|
||||
{{'role': 'assistant', 'content': 'Hi there'}},
|
||||
]]
|
||||
}})
|
||||
return result
|
||||
|
||||
result = ns['convert_chatml_to_alpaca'](FakeDataset())
|
||||
assert 'instruction' in result, f"Expected 'instruction' in result, got {{result.keys()}}"
|
||||
assert result['instruction'] == ['Hello']
|
||||
assert result['output'] == ['Hi there']
|
||||
print("OK: convert_chatml_to_alpaca works without torch")
|
||||
""")
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", code],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"convert_chatml_to_alpaca failed without torch:\n{result.stderr.decode()}"
|
||||
assert b"OK: convert_chatml_to_alpaca works without torch" in result.stdout
|
||||
|
||||
def test_convert_alpaca_to_chatml_no_torch(self, no_torch_venv):
|
||||
"""convert_alpaca_to_chatml works without torch (via try/except ImportError)."""
|
||||
code = textwrap.dedent(f"""\
|
||||
import sys, types
|
||||
|
||||
loggers = types.ModuleType('loggers')
|
||||
loggers.get_logger = lambda n: type('L', (), {{
|
||||
'info': lambda s, m: None,
|
||||
'warning': lambda s, m: None,
|
||||
'debug': lambda s, m: None,
|
||||
}})()
|
||||
sys.modules['loggers'] = loggers
|
||||
|
||||
datasets_mod = types.ModuleType('datasets')
|
||||
datasets_mod.IterableDataset = type('IterableDataset', (), {{}})
|
||||
sys.modules['datasets'] = datasets_mod
|
||||
|
||||
utils_mod = types.ModuleType('utils')
|
||||
hardware_mod = types.ModuleType('utils.hardware')
|
||||
hardware_mod.dataset_map_num_proc = lambda n=None: 1
|
||||
utils_mod.hardware = hardware_mod
|
||||
sys.modules['utils'] = utils_mod
|
||||
sys.modules['utils.hardware'] = hardware_mod
|
||||
|
||||
source = open({str(FORMAT_CONVERSION)!r}).read()
|
||||
source = source.replace('from .format_detection import', 'from format_detection import')
|
||||
ns = {{'__name__': '__test__'}}
|
||||
exec(source, ns)
|
||||
|
||||
class FakeDataset:
|
||||
def map(self, fn, **kw):
|
||||
result = fn({{
|
||||
'instruction': ['Write a poem'],
|
||||
'input': [''],
|
||||
'output': ['Roses are red'],
|
||||
}})
|
||||
return result
|
||||
|
||||
result = ns['convert_alpaca_to_chatml'](FakeDataset())
|
||||
assert 'conversations' in result
|
||||
convo = result['conversations'][0]
|
||||
assert convo[0]['role'] == 'user'
|
||||
assert convo[1]['role'] == 'assistant'
|
||||
print("OK: convert_alpaca_to_chatml works without torch")
|
||||
""")
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", code],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"convert_alpaca_to_chatml failed without torch:\n{result.stderr.decode()}"
|
||||
assert b"OK: convert_alpaca_to_chatml works without torch" in result.stdout
|
||||
|
||||
|
||||
# ── Negative controls ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNegativeControls:
|
||||
"""Prove the fix is necessary by showing what fails WITHOUT it."""
|
||||
|
||||
def test_import_torch_prepended_fails(self, no_torch_venv):
|
||||
"""Prepending 'import torch' to data_collators.py causes ModuleNotFoundError."""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode = "w", suffix = ".py", delete = False, encoding = "utf-8"
|
||||
) as f:
|
||||
f.write("import torch\n")
|
||||
f.write(DATA_COLLATORS.read_text(encoding = "utf-8"))
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
code = textwrap.dedent(f"""\
|
||||
import sys, types
|
||||
loggers = types.ModuleType('loggers')
|
||||
loggers.get_logger = lambda n: None
|
||||
sys.modules['loggers'] = loggers
|
||||
exec(open({temp_file!r}).read())
|
||||
""")
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", code],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert (
|
||||
result.returncode != 0
|
||||
), "Expected failure when 'import torch' is prepended"
|
||||
assert (
|
||||
b"ModuleNotFoundError" in result.stderr
|
||||
or b"ImportError" in result.stderr
|
||||
), f"Expected ImportError, got:\n{result.stderr.decode()}"
|
||||
finally:
|
||||
os.unlink(temp_file)
|
||||
|
||||
def test_torchao_install_fails_no_torch_venv(self, no_torch_venv):
|
||||
"""Installing torchao (from overrides.txt) fails in a no-torch venv.
|
||||
|
||||
This proves the overrides.txt skip is necessary for Intel Mac.
|
||||
"""
|
||||
result = subprocess.run(
|
||||
[
|
||||
no_torch_venv,
|
||||
"-m",
|
||||
"pip",
|
||||
"install",
|
||||
"torchao==0.14.0",
|
||||
"--dry-run",
|
||||
],
|
||||
capture_output = True,
|
||||
timeout = 60,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
# torchao install/resolution failed as expected
|
||||
pass
|
||||
else:
|
||||
# pip dry-run may not catch dependency issues; verify torch is missing
|
||||
check = subprocess.run(
|
||||
[no_torch_venv, "-c", "import torch"],
|
||||
capture_output = True,
|
||||
)
|
||||
assert (
|
||||
check.returncode != 0
|
||||
), "torch should not be importable -- torchao would fail at runtime"
|
||||
|
||||
def test_direct_torch_import_fails(self, no_torch_venv):
|
||||
"""Direct 'import torch' fails in the no-torch venv."""
|
||||
result = subprocess.run(
|
||||
[no_torch_venv, "-c", "import torch; print('torch loaded')"],
|
||||
capture_output = True,
|
||||
timeout = 30,
|
||||
)
|
||||
assert result.returncode != 0, "import torch should fail in no-torch venv"
|
||||
assert (
|
||||
b"ModuleNotFoundError" in result.stderr or b"ImportError" in result.stderr
|
||||
)
|
||||
|
|
@ -6,11 +6,14 @@ TESTS_DIR="$(cd "$(dirname "$0")" && pwd)"
|
|||
|
||||
echo "=== Bash tests ==="
|
||||
sh "$TESTS_DIR/sh/test_get_torch_index_url.sh"
|
||||
sh "$TESTS_DIR/sh/test_mac_intel_compat.sh"
|
||||
|
||||
echo ""
|
||||
echo "=== Python tests ==="
|
||||
python -m pytest "$TESTS_DIR/python/test_install_python_stack.py" -v
|
||||
python -m pytest "$TESTS_DIR/python/test_cross_platform_parity.py" -v
|
||||
python -m pytest "$TESTS_DIR/python/test_no_torch_filtering.py" -v
|
||||
python -m pytest "$TESTS_DIR/python/test_studio_import_no_torch.py" -v
|
||||
|
||||
echo ""
|
||||
echo "All tests passed."
|
||||
|
|
|
|||
582
tests/sh/test_mac_intel_compat.sh
Normal file
582
tests/sh/test_mac_intel_compat.sh
Normal file
|
|
@ -0,0 +1,582 @@
|
|||
#!/bin/bash
|
||||
# End-to-end sandbox tests for Mac Intel compatibility and UNSLOTH_NO_TORCH propagation.
|
||||
# Tests version_ge, arch detection (existing), plus E2E venv creation, torch skip
|
||||
# via a mock uv shim, and UNSLOTH_NO_TORCH env propagation in install.sh.
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
INSTALL_SH="$SCRIPT_DIR/../../install.sh"
|
||||
PASS=0
|
||||
FAIL=0
|
||||
|
||||
assert_eq() {
|
||||
_label="$1"; _expected="$2"; _actual="$3"
|
||||
if [ "$_actual" = "$_expected" ]; then
|
||||
echo " PASS: $_label"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: $_label (expected '$_expected', got '$_actual')"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
}
|
||||
|
||||
assert_contains() {
|
||||
_label="$1"; _haystack="$2"; _needle="$3"
|
||||
if echo "$_haystack" | grep -qF "$_needle"; then
|
||||
echo " PASS: $_label"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: $_label (expected to find '$_needle')"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
}
|
||||
|
||||
assert_not_contains() {
|
||||
_label="$1"; _haystack="$2"; _needle="$3"
|
||||
if echo "$_haystack" | grep -qF "$_needle"; then
|
||||
echo " FAIL: $_label (found '$_needle' but should not)"
|
||||
FAIL=$((FAIL + 1))
|
||||
else
|
||||
echo " PASS: $_label"
|
||||
PASS=$((PASS + 1))
|
||||
fi
|
||||
}
|
||||
|
||||
# ── Extract version_ge function from install.sh ──
|
||||
_VGE_FILE=$(mktemp)
|
||||
sed -n '/^version_ge()/,/^}/p' "$INSTALL_SH" > "$_VGE_FILE"
|
||||
|
||||
echo "=== version_ge ==="
|
||||
|
||||
# Basic comparisons
|
||||
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.13' '3.12' && echo pass || echo fail")
|
||||
assert_eq "3.13 >= 3.12" "pass" "$_result"
|
||||
|
||||
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.12' '3.13' && echo pass || echo fail")
|
||||
assert_eq "3.12 >= 3.13" "fail" "$_result"
|
||||
|
||||
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.13' '3.13' && echo pass || echo fail")
|
||||
assert_eq "3.13 >= 3.13 (equal)" "pass" "$_result"
|
||||
|
||||
# Patch versions
|
||||
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.13.8' '3.13' && echo pass || echo fail")
|
||||
assert_eq "3.13.8 >= 3.13 (patch > implicit 0)" "pass" "$_result"
|
||||
|
||||
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.12.0' '3.13.0' && echo pass || echo fail")
|
||||
assert_eq "3.12.0 >= 3.13.0 (minor less)" "fail" "$_result"
|
||||
|
||||
# UV_MIN_VERSION edge cases
|
||||
_result=$(bash -c ". '$_VGE_FILE'; version_ge '0.7.14' '0.7.14' && echo pass || echo fail")
|
||||
assert_eq "0.7.14 >= 0.7.14 (exact UV_MIN_VERSION)" "pass" "$_result"
|
||||
|
||||
_result=$(bash -c ". '$_VGE_FILE'; version_ge '0.7.13' '0.7.14' && echo pass || echo fail")
|
||||
assert_eq "0.7.13 >= 0.7.14 (below minimum)" "fail" "$_result"
|
||||
|
||||
_result=$(bash -c ". '$_VGE_FILE'; version_ge '0.11.1' '0.7.14' && echo pass || echo fail")
|
||||
assert_eq "0.11.1 >= 0.7.14 (well above)" "pass" "$_result"
|
||||
|
||||
# Major jump
|
||||
_result=$(bash -c ". '$_VGE_FILE'; version_ge '1.0' '0.99.99' && echo pass || echo fail")
|
||||
assert_eq "1.0 >= 0.99.99 (major jump)" "pass" "$_result"
|
||||
|
||||
rm -f "$_VGE_FILE"
|
||||
|
||||
echo ""
|
||||
echo "=== Architecture detection + PYTHON_VERSION ==="
|
||||
|
||||
# Self-contained arch detection snippet matching install.sh logic
|
||||
_ARCH_SNIPPET=$(mktemp)
|
||||
cat > "$_ARCH_SNIPPET" << 'SNIPPET'
|
||||
OS="linux"
|
||||
if [ "$(uname)" = "Darwin" ]; then
|
||||
OS="macos"
|
||||
fi
|
||||
_ARCH=$(uname -m)
|
||||
MAC_INTEL=false
|
||||
if [ "$OS" = "macos" ] && [ "$_ARCH" = "x86_64" ]; then
|
||||
MAC_INTEL=true
|
||||
fi
|
||||
_USER_PYTHON=""
|
||||
if [ -n "$_USER_PYTHON" ]; then
|
||||
PYTHON_VERSION="$_USER_PYTHON"
|
||||
elif [ "$MAC_INTEL" = true ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
else
|
||||
PYTHON_VERSION="3.13"
|
||||
fi
|
||||
echo "$OS $MAC_INTEL $PYTHON_VERSION"
|
||||
SNIPPET
|
||||
|
||||
# Test: Darwin x86_64 -> macos true 3.12
|
||||
_result=$(bash -c '
|
||||
uname() {
|
||||
case "$1" in
|
||||
-m) echo "x86_64" ;;
|
||||
*) echo "Darwin" ;;
|
||||
esac
|
||||
}
|
||||
export -f uname
|
||||
'"source '$_ARCH_SNIPPET'")
|
||||
assert_eq "Darwin x86_64 -> macos true 3.12" "macos true 3.12" "$_result"
|
||||
|
||||
# Test: Darwin arm64 -> macos false 3.13
|
||||
_result=$(bash -c '
|
||||
uname() {
|
||||
case "$1" in
|
||||
-m) echo "arm64" ;;
|
||||
*) echo "Darwin" ;;
|
||||
esac
|
||||
}
|
||||
export -f uname
|
||||
'"source '$_ARCH_SNIPPET'")
|
||||
assert_eq "Darwin arm64 -> macos false 3.13" "macos false 3.13" "$_result"
|
||||
|
||||
# Test: Linux x86_64 -> linux false 3.13
|
||||
_result=$(bash -c '
|
||||
uname() {
|
||||
case "$1" in
|
||||
-m) echo "x86_64" ;;
|
||||
*) echo "Linux" ;;
|
||||
esac
|
||||
}
|
||||
export -f uname
|
||||
'"source '$_ARCH_SNIPPET'")
|
||||
assert_eq "Linux x86_64 -> linux false 3.13" "linux false 3.13" "$_result"
|
||||
|
||||
# Test: Linux aarch64 -> linux false 3.13
|
||||
_result=$(bash -c '
|
||||
uname() {
|
||||
case "$1" in
|
||||
-m) echo "aarch64" ;;
|
||||
*) echo "Linux" ;;
|
||||
esac
|
||||
}
|
||||
export -f uname
|
||||
'"source '$_ARCH_SNIPPET'")
|
||||
assert_eq "Linux aarch64 -> linux false 3.13" "linux false 3.13" "$_result"
|
||||
|
||||
rm -f "$_ARCH_SNIPPET"
|
||||
|
||||
echo ""
|
||||
echo "=== get_torch_index_url on Darwin ==="
|
||||
|
||||
# Extract get_torch_index_url and replace hardcoded nvidia-smi path
|
||||
_FUNC_FILE=$(mktemp)
|
||||
_FAKE_SMI_DIR=$(mktemp -d)
|
||||
sed -n '/^get_torch_index_url()/,/^}/p' "$INSTALL_SH" \
|
||||
| sed "s|/usr/bin/nvidia-smi|$_FAKE_SMI_DIR/nvidia-smi-absent|g" \
|
||||
> "$_FUNC_FILE"
|
||||
|
||||
# Build a minimal tools directory
|
||||
_TOOLS_DIR=$(mktemp -d)
|
||||
for _cmd in grep sed head sh bash cat; do
|
||||
_real=$(command -v "$_cmd" 2>/dev/null || true)
|
||||
[ -n "$_real" ] && ln -sf "$_real" "$_TOOLS_DIR/$_cmd"
|
||||
done
|
||||
|
||||
# Create a mock uname that returns Darwin
|
||||
_MOCK_UNAME_DIR=$(mktemp -d)
|
||||
cat > "$_MOCK_UNAME_DIR/uname" << 'MOCK_UNAME'
|
||||
#!/bin/sh
|
||||
case "$1" in
|
||||
-s) echo "Darwin" ;;
|
||||
-m) echo "arm64" ;;
|
||||
*) echo "Darwin" ;;
|
||||
esac
|
||||
MOCK_UNAME
|
||||
chmod +x "$_MOCK_UNAME_DIR/uname"
|
||||
|
||||
# Mock nvidia-smi that returns CUDA version (to prove macOS ignores it)
|
||||
_GPU_DIR=$(mktemp -d)
|
||||
cat > "$_GPU_DIR/nvidia-smi" << 'MOCK_SMI'
|
||||
#!/bin/sh
|
||||
cat <<'SMI_OUT'
|
||||
+-----------------------------------------------------------------------------------------+
|
||||
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.6 |
|
||||
+-----------------------------------------------------------------------------------------+
|
||||
SMI_OUT
|
||||
MOCK_SMI
|
||||
chmod +x "$_GPU_DIR/nvidia-smi"
|
||||
|
||||
# Test: Darwin always returns cpu (even with nvidia-smi present)
|
||||
_result=$(PATH="$_GPU_DIR:$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". '$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
|
||||
assert_eq "Darwin -> cpu (even with nvidia-smi)" "https://download.pytorch.org/whl/cpu" "$_result"
|
||||
|
||||
# Test: Darwin without nvidia-smi also returns cpu
|
||||
_result=$(PATH="$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". '$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
|
||||
assert_eq "Darwin -> cpu (no nvidia-smi)" "https://download.pytorch.org/whl/cpu" "$_result"
|
||||
|
||||
rm -f "$_FUNC_FILE"
|
||||
rm -rf "$_FAKE_SMI_DIR" "$_TOOLS_DIR" "$_MOCK_UNAME_DIR" "$_GPU_DIR"
|
||||
|
||||
echo ""
|
||||
echo "=== UNSLOTH_NO_TORCH propagation ==="
|
||||
|
||||
# Verify UNSLOTH_NO_TORCH is passed to setup.sh in BOTH the --local and non-local branches.
|
||||
_local_count=$(grep -c 'UNSLOTH_NO_TORCH=' "$INSTALL_SH" | head -1)
|
||||
if [ "$_local_count" -ge 2 ]; then
|
||||
echo " PASS: UNSLOTH_NO_TORCH appears in >= 2 setup.sh invocations ($_local_count found)"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: UNSLOTH_NO_TORCH should appear in >= 2 setup.sh invocations (found $_local_count)"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
# Verify the value passed is "$SKIP_TORCH" (the unified variable, not MAC_INTEL)
|
||||
_skip_torch_count=$(grep 'UNSLOTH_NO_TORCH="\$SKIP_TORCH"' "$INSTALL_SH" | wc -l)
|
||||
if [ "$_skip_torch_count" -ge 2 ]; then
|
||||
echo " PASS: UNSLOTH_NO_TORCH=\"\$SKIP_TORCH\" in both branches ($_skip_torch_count found)"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: UNSLOTH_NO_TORCH=\"\$SKIP_TORCH\" should appear in >= 2 branches (found $_skip_torch_count)"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
# Verify MAC_INTEL is set to true when Intel Mac is detected
|
||||
_mac_intel_set=$(grep -c 'MAC_INTEL=true' "$INSTALL_SH")
|
||||
if [ "$_mac_intel_set" -ge 1 ]; then
|
||||
echo " PASS: MAC_INTEL=true is set in install.sh"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: MAC_INTEL=true not found in install.sh"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
# Verify the PyTorch skip message exists (now covers both --no-torch and Intel Mac)
|
||||
if grep -q 'Skipping PyTorch' "$INSTALL_SH"; then
|
||||
echo " PASS: PyTorch skip message found"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: PyTorch skip message not found"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
# Verify SKIP_TORCH unified variable exists
|
||||
if grep -q 'SKIP_TORCH=true' "$INSTALL_SH"; then
|
||||
echo " PASS: SKIP_TORCH=true assignment found"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: SKIP_TORCH=true not found in install.sh"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== E2E: venv creation at Python 3.12 (simulated Intel Mac) ==="
|
||||
|
||||
# Actually create a uv venv at Python 3.12 to verify the path works
|
||||
if command -v uv >/dev/null 2>&1; then
|
||||
_VENV_DIR=$(mktemp -d)
|
||||
_uv_result=$(uv venv "$_VENV_DIR/test_venv" --python 3.12 2>&1) && _uv_rc=0 || _uv_rc=$?
|
||||
if [ "$_uv_rc" -eq 0 ]; then
|
||||
echo " PASS: uv venv created at Python 3.12"
|
||||
PASS=$((PASS + 1))
|
||||
|
||||
# Verify Python version inside the venv
|
||||
_py_ver=$("$_VENV_DIR/test_venv/bin/python" -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
|
||||
assert_eq "venv Python is 3.12" "3.12" "$_py_ver"
|
||||
|
||||
# Verify torch is NOT available (fresh venv has no torch)
|
||||
if "$_VENV_DIR/test_venv/bin/python" -c "import torch" 2>/dev/null; then
|
||||
echo " FAIL: torch should NOT be importable in fresh 3.12 venv"
|
||||
FAIL=$((FAIL + 1))
|
||||
else
|
||||
echo " PASS: torch not importable in fresh 3.12 venv (expected for Intel Mac)"
|
||||
PASS=$((PASS + 1))
|
||||
fi
|
||||
else
|
||||
echo " SKIP: Could not create Python 3.12 venv (python 3.12 not available)"
|
||||
fi
|
||||
rm -rf "$_VENV_DIR"
|
||||
else
|
||||
echo " SKIP: uv not available, cannot test venv creation"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== E2E: torch install skipped when SKIP_TORCH=true (mock uv shim) ==="
|
||||
|
||||
# Create a mock uv that logs all calls instead of running them
|
||||
_MOCK_UV_DIR=$(mktemp -d)
|
||||
_UV_LOG="$_MOCK_UV_DIR/uv_calls.log"
|
||||
touch "$_UV_LOG"
|
||||
cat > "$_MOCK_UV_DIR/uv" << MOCK_UV_EOF
|
||||
#!/bin/sh
|
||||
echo "UV_CALL: \$*" >> "$_UV_LOG"
|
||||
MOCK_UV_EOF
|
||||
chmod +x "$_MOCK_UV_DIR/uv"
|
||||
|
||||
# Simulates the torch install decision from install.sh using SKIP_TORCH
|
||||
_TORCH_BLOCK=$(mktemp)
|
||||
cat > "$_TORCH_BLOCK" << 'TORCH_EOF'
|
||||
# Simulates the torch install decision from install.sh
|
||||
TORCH_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
_VENV_PY="/fake/python"
|
||||
if [ "$SKIP_TORCH" = true ]; then
|
||||
echo "==> Skipping PyTorch (--no-torch or Intel Mac x86_64)."
|
||||
else
|
||||
echo "==> Installing PyTorch ($TORCH_INDEX_URL)..."
|
||||
uv pip install --python "$_VENV_PY" "torch>=2.4,<2.11.0" torchvision torchaudio \
|
||||
--index-url "$TORCH_INDEX_URL"
|
||||
fi
|
||||
TORCH_EOF
|
||||
|
||||
# Test: SKIP_TORCH=true -> torch install should be SKIPPED (no uv calls)
|
||||
> "$_UV_LOG" # clear log
|
||||
_torch_output=$(SKIP_TORCH=true PATH="$_MOCK_UV_DIR:$PATH" bash "$_TORCH_BLOCK" 2>&1)
|
||||
assert_contains "SKIP_TORCH=true prints skip message" "$_torch_output" "Skipping PyTorch"
|
||||
if [ -s "$_UV_LOG" ]; then
|
||||
echo " FAIL: uv was called when SKIP_TORCH=true (should be skipped)"
|
||||
echo " Log: $(cat "$_UV_LOG")"
|
||||
FAIL=$((FAIL + 1))
|
||||
else
|
||||
echo " PASS: no uv pip install torch when SKIP_TORCH=true"
|
||||
PASS=$((PASS + 1))
|
||||
fi
|
||||
|
||||
# Test: SKIP_TORCH=false -> torch install should EXECUTE (uv called with torch)
|
||||
> "$_UV_LOG" # clear log
|
||||
_torch_output=$(SKIP_TORCH=false PATH="$_MOCK_UV_DIR:$PATH" bash "$_TORCH_BLOCK" 2>&1)
|
||||
assert_contains "SKIP_TORCH=false prints install message" "$_torch_output" "Installing PyTorch"
|
||||
if grep -q "torch" "$_UV_LOG"; then
|
||||
echo " PASS: uv pip install torch called when SKIP_TORCH=false"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: uv pip install torch NOT called when SKIP_TORCH=false"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
rm -f "$_TORCH_BLOCK"
|
||||
rm -rf "$_MOCK_UV_DIR"
|
||||
|
||||
echo ""
|
||||
echo "=== E2E: UNSLOTH_NO_TORCH env propagation (dynamic test) ==="
|
||||
|
||||
# Simulates the setup.sh invocation using SKIP_TORCH
|
||||
_ENV_BLOCK=$(mktemp)
|
||||
cat > "$_ENV_BLOCK" << 'ENV_EOF'
|
||||
# Simulates the setup.sh invocation block from install.sh
|
||||
PACKAGE_NAME="unsloth"
|
||||
_REPO_ROOT="/fake/repo"
|
||||
SETUP_SH="/fake/setup.sh"
|
||||
|
||||
if [ "$STUDIO_LOCAL_INSTALL" = true ]; then
|
||||
SKIP_STUDIO_BASE=1 \
|
||||
STUDIO_PACKAGE_NAME="$PACKAGE_NAME" \
|
||||
STUDIO_LOCAL_INSTALL=1 \
|
||||
STUDIO_LOCAL_REPO="$_REPO_ROOT" \
|
||||
UNSLOTH_NO_TORCH="$SKIP_TORCH" \
|
||||
env | grep "^UNSLOTH_NO_TORCH="
|
||||
else
|
||||
SKIP_STUDIO_BASE=1 \
|
||||
STUDIO_PACKAGE_NAME="$PACKAGE_NAME" \
|
||||
UNSLOTH_NO_TORCH="$SKIP_TORCH" \
|
||||
env | grep "^UNSLOTH_NO_TORCH="
|
||||
fi
|
||||
ENV_EOF
|
||||
|
||||
# Test: SKIP_TORCH=true -> UNSLOTH_NO_TORCH=true in env
|
||||
_env_result=$(SKIP_TORCH=true STUDIO_LOCAL_INSTALL=false bash "$_ENV_BLOCK" 2>&1)
|
||||
assert_eq "non-local: UNSLOTH_NO_TORCH=true when SKIP_TORCH=true" "UNSLOTH_NO_TORCH=true" "$_env_result"
|
||||
|
||||
# Test: SKIP_TORCH=false -> UNSLOTH_NO_TORCH=false in env
|
||||
_env_result=$(SKIP_TORCH=false STUDIO_LOCAL_INSTALL=false bash "$_ENV_BLOCK" 2>&1)
|
||||
assert_eq "non-local: UNSLOTH_NO_TORCH=false when SKIP_TORCH=false" "UNSLOTH_NO_TORCH=false" "$_env_result"
|
||||
|
||||
# Test: local install path also propagates
|
||||
_env_result=$(SKIP_TORCH=true STUDIO_LOCAL_INSTALL=true bash "$_ENV_BLOCK" 2>&1)
|
||||
assert_eq "local: UNSLOTH_NO_TORCH=true when SKIP_TORCH=true" "UNSLOTH_NO_TORCH=true" "$_env_result"
|
||||
|
||||
_env_result=$(SKIP_TORCH=false STUDIO_LOCAL_INSTALL=true bash "$_ENV_BLOCK" 2>&1)
|
||||
assert_eq "local: UNSLOTH_NO_TORCH=false when SKIP_TORCH=false" "UNSLOTH_NO_TORCH=false" "$_env_result"
|
||||
|
||||
rm -f "$_ENV_BLOCK"
|
||||
|
||||
echo ""
|
||||
echo "=== --python override flag ==="
|
||||
|
||||
# Test: flag parsing extracts version correctly
|
||||
_PARSE_BLOCK=$(mktemp)
|
||||
cat > "$_PARSE_BLOCK" << 'PARSE_EOF'
|
||||
_USER_PYTHON=""
|
||||
_next_is_python=false
|
||||
_next_is_package=false
|
||||
STUDIO_LOCAL_INSTALL=false
|
||||
PACKAGE_NAME="unsloth"
|
||||
for arg in "$@"; do
|
||||
if [ "$_next_is_package" = true ]; then PACKAGE_NAME="$arg"; _next_is_package=false; continue; fi
|
||||
if [ "$_next_is_python" = true ]; then _USER_PYTHON="$arg"; _next_is_python=false; continue; fi
|
||||
case "$arg" in
|
||||
--local) STUDIO_LOCAL_INSTALL=true ;;
|
||||
--package) _next_is_package=true ;;
|
||||
--python) _next_is_python=true ;;
|
||||
esac
|
||||
done
|
||||
if [ "$_next_is_python" = true ]; then echo "ERROR"; exit 1; fi
|
||||
echo "$_USER_PYTHON"
|
||||
PARSE_EOF
|
||||
|
||||
_result=$(bash "$_PARSE_BLOCK" --python 3.12)
|
||||
assert_eq "--python 3.12 parsed" "3.12" "$_result"
|
||||
|
||||
_result=$(bash "$_PARSE_BLOCK" --local --python 3.11)
|
||||
assert_eq "--local --python 3.11 parsed" "3.11" "$_result"
|
||||
|
||||
_result=$(bash "$_PARSE_BLOCK" --python 3.12 --local --package foo)
|
||||
assert_eq "--python with --local --package" "3.12" "$_result"
|
||||
|
||||
_result=$(bash "$_PARSE_BLOCK" 2>&1) # no --python
|
||||
assert_eq "no --python -> empty" "" "$_result"
|
||||
|
||||
_rc=0
|
||||
bash "$_PARSE_BLOCK" --python >/dev/null 2>&1 || _rc=$?
|
||||
assert_eq "--python without arg -> error" "1" "$_rc"
|
||||
|
||||
rm -f "$_PARSE_BLOCK"
|
||||
|
||||
# Test: --python overrides auto-detected version in PYTHON_VERSION resolution
|
||||
_RESOLVE_BLOCK=$(mktemp)
|
||||
cat > "$_RESOLVE_BLOCK" << 'RESOLVE_EOF'
|
||||
_USER_PYTHON="$1"
|
||||
MAC_INTEL="$2"
|
||||
if [ -n "$_USER_PYTHON" ]; then
|
||||
PYTHON_VERSION="$_USER_PYTHON"
|
||||
elif [ "$MAC_INTEL" = true ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
else
|
||||
PYTHON_VERSION="3.13"
|
||||
fi
|
||||
echo "$PYTHON_VERSION"
|
||||
RESOLVE_EOF
|
||||
|
||||
_result=$(bash "$_RESOLVE_BLOCK" "3.11" "true")
|
||||
assert_eq "--python 3.11 overrides Intel Mac 3.12" "3.11" "$_result"
|
||||
|
||||
_result=$(bash "$_RESOLVE_BLOCK" "3.12" "false")
|
||||
assert_eq "--python 3.12 overrides default 3.13" "3.12" "$_result"
|
||||
|
||||
_result=$(bash "$_RESOLVE_BLOCK" "" "true")
|
||||
assert_eq "no override -> Intel Mac gets 3.12" "3.12" "$_result"
|
||||
|
||||
_result=$(bash "$_RESOLVE_BLOCK" "" "false")
|
||||
assert_eq "no override -> non-Intel gets 3.13" "3.13" "$_result"
|
||||
|
||||
rm -f "$_RESOLVE_BLOCK"
|
||||
|
||||
# Test: --python flag exists in install.sh
|
||||
if grep -q '\-\-python)' "$INSTALL_SH"; then
|
||||
echo " PASS: --python case exists in install.sh"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: --python case not found in install.sh"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
# Test: _USER_PYTHON guards exist for stale-venv and 3.13.8 checks
|
||||
_user_py_guards=$(grep -c '_USER_PYTHON' "$INSTALL_SH")
|
||||
if [ "$_user_py_guards" -ge 4 ]; then
|
||||
echo " PASS: _USER_PYTHON referenced >= 4 times in install.sh (flag + resolution + guards)"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: _USER_PYTHON should appear >= 4 times (found $_user_py_guards)"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== --no-torch flag parsing ==="
|
||||
|
||||
# Test: --no-torch sets _NO_TORCH_FLAG=true
|
||||
_FLAG_SNIPPET=$(mktemp)
|
||||
cat > "$_FLAG_SNIPPET" << 'SNIPPET'
|
||||
_NO_TORCH_FLAG=false
|
||||
_next_is_package=false
|
||||
STUDIO_LOCAL_INSTALL=false
|
||||
PACKAGE_NAME="unsloth"
|
||||
for arg in "$@"; do
|
||||
if [ "$_next_is_package" = true ]; then
|
||||
PACKAGE_NAME="$arg"
|
||||
_next_is_package=false
|
||||
continue
|
||||
fi
|
||||
case "$arg" in
|
||||
--local) STUDIO_LOCAL_INSTALL=true ;;
|
||||
--package) _next_is_package=true ;;
|
||||
--no-torch) _NO_TORCH_FLAG=true ;;
|
||||
esac
|
||||
done
|
||||
echo "$_NO_TORCH_FLAG"
|
||||
SNIPPET
|
||||
|
||||
_result=$(bash "$_FLAG_SNIPPET" --no-torch)
|
||||
assert_eq "--no-torch sets flag to true" "true" "$_result"
|
||||
|
||||
_result=$(bash "$_FLAG_SNIPPET")
|
||||
assert_eq "no flags -> flag is false" "false" "$_result"
|
||||
|
||||
_result=$(bash "$_FLAG_SNIPPET" --local --no-torch)
|
||||
assert_eq "--local --no-torch both work" "true" "$_result"
|
||||
|
||||
_result=$(bash "$_FLAG_SNIPPET" --no-torch --package custom-pkg)
|
||||
assert_eq "--no-torch with --package works" "true" "$_result"
|
||||
|
||||
rm -f "$_FLAG_SNIPPET"
|
||||
|
||||
echo ""
|
||||
echo "=== SKIP_TORCH unification ==="
|
||||
|
||||
# Test: SKIP_TORCH is set to true when --no-torch flag is set (even without MAC_INTEL)
|
||||
_SKIP_SNIPPET=$(mktemp)
|
||||
cat > "$_SKIP_SNIPPET" << 'SNIPPET'
|
||||
MAC_INTEL=false
|
||||
_NO_TORCH_FLAG=$1
|
||||
SKIP_TORCH=false
|
||||
if [ "$_NO_TORCH_FLAG" = true ] || [ "$MAC_INTEL" = true ]; then
|
||||
SKIP_TORCH=true
|
||||
fi
|
||||
echo "$SKIP_TORCH"
|
||||
SNIPPET
|
||||
|
||||
_result=$(bash "$_SKIP_SNIPPET" true)
|
||||
assert_eq "--no-torch flag alone sets SKIP_TORCH=true" "true" "$_result"
|
||||
|
||||
_result=$(bash "$_SKIP_SNIPPET" false)
|
||||
assert_eq "no flag, no MAC_INTEL -> SKIP_TORCH=false" "false" "$_result"
|
||||
|
||||
# Test: MAC_INTEL=true alone also sets SKIP_TORCH=true
|
||||
_SKIP_SNIPPET2=$(mktemp)
|
||||
cat > "$_SKIP_SNIPPET2" << 'SNIPPET'
|
||||
MAC_INTEL=true
|
||||
_NO_TORCH_FLAG=false
|
||||
SKIP_TORCH=false
|
||||
if [ "$_NO_TORCH_FLAG" = true ] || [ "$MAC_INTEL" = true ]; then
|
||||
SKIP_TORCH=true
|
||||
fi
|
||||
echo "$SKIP_TORCH"
|
||||
SNIPPET
|
||||
|
||||
_result=$(bash "$_SKIP_SNIPPET2")
|
||||
assert_eq "MAC_INTEL=true alone sets SKIP_TORCH=true" "true" "$_result"
|
||||
|
||||
rm -f "$_SKIP_SNIPPET" "$_SKIP_SNIPPET2"
|
||||
|
||||
echo ""
|
||||
echo "=== CPU hint printing ==="
|
||||
|
||||
# Verify the CPU hint is present in install.sh source
|
||||
if grep -q 'No NVIDIA GPU detected' "$INSTALL_SH"; then
|
||||
echo " PASS: CPU hint message found in install.sh"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: CPU hint message not found in install.sh"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
if grep -q '\-\-no-torch' "$INSTALL_SH"; then
|
||||
echo " PASS: --no-torch appears in install.sh"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: --no-torch not found in install.sh"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Results: $PASS passed, $FAIL failed"
|
||||
[ "$FAIL" -eq 0 ] || exit 1
|
||||
Loading…
Reference in a new issue