tests: add no-torch / Intel Mac test suite (#4646)

* tests: add no-torch / Intel Mac test suite

Add comprehensive test coverage for no-torch mode (the --no-torch installer
flag) and the Studio backend changes introduced in #4624.

Shell tests (tests/sh/test_mac_intel_compat.sh):
- version_ge edge cases (9 tests)
- Architecture detection + Python version resolution (4 tests)
- get_torch_index_url on Darwin (2 tests)
- UNSLOTH_NO_TORCH propagation via SKIP_TORCH (5 tests)
- E2E uv venv creation at Python 3.12 (3 tests)
- E2E torch skip with mock uv shim (4 tests)
- UNSLOTH_NO_TORCH env propagation (4 tests)
- --python override flag parsing + resolution (11 tests)
- --no-torch flag parsing (4 tests)
- SKIP_TORCH unification (3 tests)
- CPU hint printing (2 tests)
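The `version_ge` helper exercised above is shell code not shown in this view; as a rough sketch, the dotted-version comparison it presumably implements looks like the following in Python (the helper name and its exact edge-case behavior are assumptions):

```python
def version_ge(a: str, b: str) -> bool:
    """Return True if dotted version a >= b, comparing numeric components.

    Hypothetical re-implementation of the install.sh version_ge helper for
    illustration only; the real helper is shell and may handle more forms.
    """
    def parts(v: str) -> tuple[int, ...]:
        # "3.12" -> (3, 12); non-numeric components are ignored here.
        return tuple(int(p) for p in v.split(".") if p.isdigit())
    return parts(a) >= parts(b)
```

Tuple comparison gives the expected ordering for the Python-version checks the suite cares about, e.g. `version_ge("3.9", "3.10")` is false even though `"3.9" > "3.10"` lexically.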

Python tests (tests/python/test_no_torch_filtering.py):
- _filter_requirements unit tests with synthetic + real requirements files
- NO_TORCH / IS_MACOS constant parsing
- Subprocess mock of install_python_stack() across platform configs
- install.sh --no-torch flag structural + subprocess tests

Python tests (tests/python/test_studio_import_no_torch.py):
- AST checks for data_collators.py, chat_templates.py, format_conversion.py
- Parametrized venv tests (Python 3.12 + 3.13) for no-torch exec
- Dataclass instantiation without torch
- format_conversion convert functions without torch
- Negative controls (import torch fails, torchao fails)

Python tests (tests/python/test_e2e_no_torch_sandbox.py):
- Before/after import chain tests
- Edge cases (broken torch, fake torch, lazy import)
- Hardware detection without torch
- install.sh logic tests (flag parsing, version resolution)
- install_python_stack filtering tests
- Live server startup tests (opt-in via @server marker)

* fix: address review comments on test suite

- Fix always-true assertion in test_studio_import_no_torch.py (or True)
- Make IS_MACOS test platform-aware instead of hardcoding Linux
- Restore torchvision + torchaudio in server test cleanup (not just torch)
- Include server stderr in skip message for easier debugging
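For reference, the UNSLOTH_NO_TORCH truthiness rule the suite pins down (only "1" and case-insensitive "true" enable no-torch mode; unset or anything else means false) reduces to:

```python
import os

def no_torch_enabled(environ=os.environ) -> bool:
    # "1", "true", "TRUE" -> True; "0", "false", or unset -> False.
    # Mirrors the parsing exercised by TestNoTorchConstant.
    return environ.get("UNSLOTH_NO_TORCH", "false").lower() in ("1", "true")
```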

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Daniel Han 2026-03-27 02:33:45 -07:00, committed by GitHub
parent e9ac785346
commit 2ffc8d2cea
No known key found for this signature in database; GPG key ID: B5690EEEBB952194
6 changed files with 3166 additions and 0 deletions

tests/python/conftest.py
@@ -0,0 +1,7 @@
"""Shared pytest configuration for tests/python/."""
def pytest_configure(config):
config.addinivalue_line(
"markers", "server: heavyweight tests requiring studio venv"
)

File diff suppressed because it is too large

tests/python/test_no_torch_filtering.py

@@ -0,0 +1,753 @@
"""Tests for install_python_stack NO_TORCH / IS_MACOS filtering logic.
Covers:
- _filter_requirements unit tests (synthetic + REAL requirements files)
- NO_TORCH / IS_MACOS / IS_WINDOWS env var parsing
- Subprocess-mock of install_python_stack() to verify overrides/triton/filtering
actually happen (or get skipped) under each platform/config combination
- VCS URL and environment marker edge cases in filtering
"""
from __future__ import annotations
import importlib
import os
import re
import subprocess
import sys
import textwrap
from pathlib import Path
from unittest import mock
import pytest
# Add the studio directory so we can import install_python_stack
STUDIO_DIR = Path(__file__).resolve().parents[2] / "studio"
sys.path.insert(0, str(STUDIO_DIR))
import install_python_stack as ips
# Paths to the REAL requirements files
REQ_ROOT = Path(__file__).resolve().parents[2] / "studio" / "backend" / "requirements"
EXTRAS_TXT = REQ_ROOT / "extras.txt"
EXTRAS_NO_DEPS_TXT = REQ_ROOT / "extras-no-deps.txt"
OVERRIDES_TXT = REQ_ROOT / "overrides.txt"
TRITON_KERNELS_TXT = REQ_ROOT / "triton-kernels.txt"
# ── _filter_requirements unit tests (synthetic) ───────────────────────
class TestFilterRequirements:
"""Verify _filter_requirements correctly removes packages by prefix."""
def _write_req(self, tmp_path: Path, content: str) -> Path:
req = tmp_path / "requirements.txt"
req.write_text(textwrap.dedent(content), encoding = "utf-8")
return req
def test_filters_no_torch_packages(self, tmp_path):
req = self._write_req(
tmp_path,
"""\
torch-stoi==0.1
timm>=1.0
numpy
torchcodec>=0.1
torch-c-dlpack-ext
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
# Only numpy should remain (non-blank lines)
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == ["numpy"], f"Expected only numpy, got: {non_blank}"
def test_empty_file(self, tmp_path):
req = self._write_req(tmp_path, "")
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
content = Path(result).read_text(encoding = "utf-8")
assert content.strip() == ""
def test_comments_preserved(self, tmp_path):
req = self._write_req(
tmp_path,
"""\
# torch-stoi is needed for audio
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
# Comment starts with "#", not "torch-stoi", so it's preserved
assert len(non_blank) == 2
assert non_blank[0].startswith("#")
assert non_blank[1] == "numpy"
def test_version_specifiers_filtered(self, tmp_path):
req = self._write_req(
tmp_path,
"""\
torch-stoi>=0.1.0
timm==1.2.3
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == [], f"Expected empty, got: {non_blank}"
def test_prefix_match_catches_extensions(self, tmp_path):
"""Prefix matching catches torch-stoi-extra (correct for pip names)."""
req = self._write_req(
tmp_path,
"""\
torch-stoi-extra
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == ["numpy"]
def test_mixed_case_filtered(self, tmp_path):
"""Package names are lowercased before matching."""
req = self._write_req(
tmp_path,
"""\
Timm>=1.0
TORCH-STOI
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == ["numpy"]
def test_whitespace_and_blank_lines_preserved(self, tmp_path):
req = self._write_req(
tmp_path,
"""\
numpy

pandas
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
content = Path(result).read_text(encoding = "utf-8")
# Blank lines should be preserved (not stripped)
assert "\n\n" in content or content.count("\n") >= 3
def test_stacked_windows_and_no_torch_filters(self, tmp_path):
"""Both WINDOWS_SKIP_PACKAGES and NO_TORCH_SKIP_PACKAGES applied."""
req = self._write_req(
tmp_path,
"""\
open_spiel
triton_kernels
torch-stoi
timm
numpy
""",
)
# First filter Windows packages, then NO_TORCH packages
intermediate = ips._filter_requirements(req, ips.WINDOWS_SKIP_PACKAGES)
result = ips._filter_requirements(
Path(intermediate), ips.NO_TORCH_SKIP_PACKAGES
)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == [
"numpy"
], f"Expected only numpy after stacked filters, got: {non_blank}"
def test_vcs_url_with_skip_package_name(self, tmp_path):
"""VCS URLs like git+https://...torch-stoi should also be filtered (startswith matches)."""
req = self._write_req(
tmp_path,
"""\
numpy
torch-stoi @ git+https://github.com/example/torch-stoi.git
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == [
"numpy"
], f"VCS URL line should be filtered, got: {non_blank}"
def test_env_marker_line_filtered(self, tmp_path):
"""Package lines with env markers are still filtered by prefix."""
req = self._write_req(
tmp_path,
"""\
timm>=1.0; python_version>="3.10"
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == [
"numpy"
], f"Env marker line should be filtered, got: {non_blank}"
def test_git_plus_url_not_over_matched(self, tmp_path):
"""A git+ URL whose path contains a skip package name but does NOT start with it."""
req = self._write_req(
tmp_path,
"""\
git+https://github.com/meta-pytorch/OpenEnv.git
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
# The git+ URL doesn't start with any skip package, so it is preserved
assert len(non_blank) == 2, f"git+ URL should be preserved, got: {non_blank}"
# ── Real requirements file filtering ──────────────────────────────────
class TestRealRequirementsFiltering:
"""Filter the ACTUAL extras.txt and extras-no-deps.txt with NO_TORCH_SKIP_PACKAGES."""
@pytest.fixture(autouse = True)
def _check_req_files(self):
if not EXTRAS_TXT.is_file():
pytest.skip("extras.txt not found in repo")
if not EXTRAS_NO_DEPS_TXT.is_file():
pytest.skip("extras-no-deps.txt not found in repo")
def _non_blank_non_comment(self, path: Path) -> list[str]:
"""Return non-blank, non-comment lines from a requirements file."""
lines = path.read_text(encoding = "utf-8").splitlines()
return [l.strip() for l in lines if l.strip() and not l.strip().startswith("#")]
def test_extras_txt_torch_packages_removed(self):
"""extras.txt: all NO_TORCH_SKIP_PACKAGES must be removed, everything else preserved."""
result = ips._filter_requirements(EXTRAS_TXT, ips.NO_TORCH_SKIP_PACKAGES)
filtered = self._non_blank_non_comment(Path(result))
original = self._non_blank_non_comment(EXTRAS_TXT)
# These must be gone
for pkg in ["torch-stoi", "timm", "openai-whisper", "transformers-cfg"]:
assert not any(
l.lower().startswith(pkg) for l in filtered
), f"{pkg} should be removed from extras.txt"
# Everything else must remain
expected = [
l
for l in original
if not any(
l.strip().lower().startswith(p) for p in ips.NO_TORCH_SKIP_PACKAGES
)
]
assert filtered == expected, (
f"Filtered extras.txt should match expected.\n"
f"Missing: {set(expected) - set(filtered)}\n"
f"Extra: {set(filtered) - set(expected)}"
)
def test_extras_no_deps_txt_torchcodec_and_dlpack_removed(self):
"""extras-no-deps.txt: torchcodec and torch-c-dlpack-ext must be removed."""
result = ips._filter_requirements(
EXTRAS_NO_DEPS_TXT, ips.NO_TORCH_SKIP_PACKAGES
)
filtered = self._non_blank_non_comment(Path(result))
original = self._non_blank_non_comment(EXTRAS_NO_DEPS_TXT)
for pkg in ["torchcodec", "torch-c-dlpack-ext"]:
assert not any(
l.lower().startswith(pkg) for l in filtered
), f"{pkg} should be removed from extras-no-deps.txt"
expected = [
l
for l in original
if not any(
l.strip().lower().startswith(p) for p in ips.NO_TORCH_SKIP_PACKAGES
)
]
assert filtered == expected
def test_extras_txt_most_packages_preserved(self):
"""Ensure a representative set of non-torch packages survive filtering."""
result = ips._filter_requirements(EXTRAS_TXT, ips.NO_TORCH_SKIP_PACKAGES)
filtered_text = Path(result).read_text(encoding = "utf-8").lower()
must_survive = ["scikit-learn", "loguru", "tiktoken", "einops", "tabulate"]
for pkg in must_survive:
if pkg in EXTRAS_TXT.read_text(encoding = "utf-8").lower():
assert pkg in filtered_text, f"{pkg} should survive NO_TORCH filtering"
def test_extras_no_deps_txt_trl_preserved(self):
"""trl should survive NO_TORCH filtering in extras-no-deps.txt."""
result = ips._filter_requirements(
EXTRAS_NO_DEPS_TXT, ips.NO_TORCH_SKIP_PACKAGES
)
filtered_text = Path(result).read_text(encoding = "utf-8").lower()
assert "trl" in filtered_text, "trl should survive NO_TORCH filtering"
# ── NO_TORCH constant tests ──────────────────────────────────────────
class TestNoTorchConstant:
"""Verify NO_TORCH is derived correctly from UNSLOTH_NO_TORCH env var."""
def _reimport_no_torch(self) -> bool:
return os.environ.get("UNSLOTH_NO_TORCH", "false").lower() in ("1", "true")
def test_true_lowercase(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "true"}):
assert self._reimport_no_torch() is True
def test_true_one(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "1"}):
assert self._reimport_no_torch() is True
def test_true_uppercase(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "TRUE"}):
assert self._reimport_no_torch() is True
def test_false_string(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "false"}):
assert self._reimport_no_torch() is False
def test_false_zero(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "0"}):
assert self._reimport_no_torch() is False
def test_not_set(self):
env = os.environ.copy()
env.pop("UNSLOTH_NO_TORCH", None)
with mock.patch.dict(os.environ, env, clear = True):
assert self._reimport_no_torch() is False
def test_infer_no_torch_on_intel_mac(self):
"""_infer_no_torch falls back to platform detection when env var is unset."""
env = os.environ.copy()
env.pop("UNSLOTH_NO_TORCH", None)
with (
mock.patch.dict(os.environ, env, clear = True),
mock.patch.object(ips, "IS_MAC_INTEL", True),
):
assert ips._infer_no_torch() is True
def test_infer_no_torch_respects_explicit_false_on_intel_mac(self):
"""Explicit UNSLOTH_NO_TORCH=false overrides platform detection."""
with (
mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "false"}),
mock.patch.object(ips, "IS_MAC_INTEL", True),
):
assert ips._infer_no_torch() is False
def test_infer_no_torch_linux_unset(self):
"""On Linux with env var unset, _infer_no_torch returns False."""
env = os.environ.copy()
env.pop("UNSLOTH_NO_TORCH", None)
with (
mock.patch.dict(os.environ, env, clear = True),
mock.patch.object(ips, "IS_MAC_INTEL", False),
):
assert ips._infer_no_torch() is False
# ── IS_MACOS constant tests ──────────────────────────────────────────
class TestIsMacosConstant:
"""Verify IS_MACOS detection logic."""
def test_is_macos_matches_platform(self):
expected = sys.platform == "darwin"
assert ips.IS_MACOS is expected
# ── Subprocess mock of install_python_stack() ─────────────────────────
class TestInstallPythonStackSubprocessMock:
"""Monkeypatch subprocess.run to capture all pip/uv commands,
then verify which requirements files are used/skipped under
different NO_TORCH / IS_MACOS / IS_WINDOWS configurations."""
@pytest.fixture(autouse = True)
def _check_req_files(self):
"""Skip if requirements files are missing."""
for f in [EXTRAS_TXT, EXTRAS_NO_DEPS_TXT, OVERRIDES_TXT]:
if not f.is_file():
pytest.skip(f"{f.name} not found in repo")
def _capture_install(
self,
no_torch: bool,
is_macos: bool,
is_windows: bool,
*,
skip_base: bool = True,
):
"""Run install_python_stack() with mocked subprocess, capturing all commands.
Returns a list of string-joined commands (each element is ' '.join(cmd)).
"""
captured_cmds: list[list[str]] = []
def mock_run(cmd, **kw):
captured_cmds.append(
list(cmd) if isinstance(cmd, (list, tuple)) else [str(cmd)]
)
return subprocess.CompletedProcess(cmd, 0, b"", b"")
env = {"SKIP_STUDIO_BASE": "1"} if skip_base else {}
with (
mock.patch.object(ips, "NO_TORCH", no_torch),
mock.patch.object(ips, "IS_MACOS", is_macos),
mock.patch.object(ips, "IS_WINDOWS", is_windows),
mock.patch.object(ips, "USE_UV", True),
mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
mock.patch.object(ips, "VERBOSE", False),
mock.patch("subprocess.run", side_effect = mock_run),
mock.patch.object(ips, "_bootstrap_uv", return_value = True),
mock.patch.object(
ips, "LOCAL_DD_UNSTRUCTURED_PLUGIN", Path("/fake/plugin")
),
mock.patch("pathlib.Path.is_dir", return_value = True),
mock.patch("pathlib.Path.is_file", return_value = True),
):
with mock.patch.dict(os.environ, env, clear = False):
ips.install_python_stack()
return [" ".join(str(c) for c in cmd) for cmd in captured_cmds]
def _cmds_contain_file(self, cmds: list[str], filename: str) -> bool:
"""Check if any captured command references the given filename."""
return any(filename in cmd for cmd in cmds)
# -- NO_TORCH=True, IS_MACOS=True (Intel Mac scenario) --
def test_no_torch_macos_skips_overrides(self):
"""With NO_TORCH=True, overrides.txt pip_install must NOT be called."""
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
assert not self._cmds_contain_file(
cmds, "overrides.txt"
), "overrides.txt should be skipped when NO_TORCH=True"
def test_no_torch_macos_skips_triton(self):
"""With IS_MACOS=True, triton-kernels.txt must NOT be called."""
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
assert not self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be skipped on macOS"
def test_no_torch_macos_extras_called(self):
"""With NO_TORCH=True, extras.txt is still called (but filtered)."""
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
has_extras = self._cmds_contain_file(cmds, "extras.txt") or any(
"-r" in cmd and "tmp" in cmd.lower() for cmd in cmds
)
assert has_extras, "extras.txt (or its filtered temp) should be called"
def test_no_torch_macos_extras_no_deps_called(self):
"""With NO_TORCH=True, extras-no-deps.txt is still called (but filtered)."""
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
has_extras_nd = self._cmds_contain_file(cmds, "extras-no-deps.txt") or any(
"-r" in cmd and "tmp" in cmd.lower() for cmd in cmds
)
assert (
has_extras_nd
), "extras-no-deps.txt (or its filtered temp) should be called"
# -- IS_WINDOWS=True + NO_TORCH=True (stacked) --
def test_windows_no_torch_skips_overrides(self):
"""Windows+NO_TORCH: overrides.txt must be skipped."""
cmds = self._capture_install(no_torch = True, is_macos = False, is_windows = True)
assert not self._cmds_contain_file(
cmds, "overrides.txt"
), "overrides.txt should be skipped with NO_TORCH=True on Windows"
def test_windows_no_torch_skips_triton(self):
"""Windows: triton-kernels.txt must be skipped (IS_WINDOWS guard)."""
cmds = self._capture_install(no_torch = True, is_macos = False, is_windows = True)
assert not self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be skipped on Windows"
# -- Normal Linux path (NO_TORCH=False, IS_MACOS=False, IS_WINDOWS=False) --
def test_normal_linux_includes_overrides(self):
"""Normal Linux: overrides.txt IS called."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
assert self._cmds_contain_file(
cmds, "overrides.txt"
), "overrides.txt should be called on normal Linux"
def test_normal_linux_includes_triton(self):
"""Normal Linux: triton-kernels.txt IS called."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
assert self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be called on normal Linux"
def test_normal_linux_includes_extras(self):
"""Normal Linux: extras.txt IS called (no filtering)."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
assert self._cmds_contain_file(
cmds, "extras.txt"
), "extras.txt should be called on normal Linux"
def test_normal_linux_includes_extras_no_deps(self):
"""Normal Linux: extras-no-deps.txt IS called (no filtering)."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
assert self._cmds_contain_file(
cmds, "extras-no-deps.txt"
), "extras-no-deps.txt should be called on normal Linux"
# -- Windows-only (NO_TORCH=False) to verify triton is still skipped --
def test_windows_only_skips_triton(self):
"""Windows (without NO_TORCH): triton still skipped."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = True)
assert not self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be skipped on Windows even without NO_TORCH"
def test_windows_only_includes_overrides(self):
"""Windows (without NO_TORCH): overrides IS called (via filtered temp file).
On Windows, all req files go through _filter_requirements(WINDOWS_SKIP_PACKAGES),
so the command uses a temp file, not overrides.txt directly. We check for
--reinstall (uv translation of --force-reinstall) which is unique to overrides.
"""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = True)
assert any(
"--reinstall" in cmd for cmd in cmds
), "overrides step (--reinstall) should be called on Windows when NO_TORCH=False"
# -- Update path (skip_base=False) to verify no-torch mode is durable --
def test_update_path_intel_macos_still_skips_overrides(self):
"""Update path (no SKIP_STUDIO_BASE): overrides still skipped on Intel Mac."""
cmds = self._capture_install(
no_torch = True, is_macos = True, is_windows = False, skip_base = False
)
assert not self._cmds_contain_file(
cmds, "overrides.txt"
), "overrides.txt should be skipped on Intel Mac even via studio update"
def test_update_path_intel_macos_still_skips_triton(self):
"""Update path (no SKIP_STUDIO_BASE): triton still skipped on macOS."""
cmds = self._capture_install(
no_torch = True, is_macos = True, is_windows = False, skip_base = False
)
assert not self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be skipped on macOS even via studio update"
# ── Overrides skip structural checks ─────────────────────────────────
class TestOverridesSkip:
"""Verify overrides.txt is skipped when NO_TORCH is True (source-level check)."""
def test_no_torch_guard_exists_in_source(self):
"""The install_python_stack source must contain a NO_TORCH guard around overrides."""
source = Path(ips.__file__).read_text(encoding = "utf-8")
assert (
"if NO_TORCH:" in source
), "NO_TORCH guard not found in install_python_stack.py"
def test_overrides_skipped_when_no_torch(self):
"""With NO_TORCH=True on the module, pip_install should NOT be called for overrides."""
source = Path(ips.__file__).read_text(encoding = "utf-8")
overrides_match = re.search(r"if NO_TORCH:.*?overrides", source, re.DOTALL)
assert (
overrides_match is not None
), "Expected NO_TORCH conditional before overrides install"
# ── install.sh --no-torch flag tests ──────────────────────────────────
class TestInstallShNoTorchFlag:
"""Verify install.sh has the --no-torch flag and SKIP_TORCH variable."""
@pytest.fixture(autouse = True)
def _check_install_sh(self):
install_sh = Path(__file__).resolve().parents[2] / "install.sh"
if not install_sh.is_file():
pytest.skip("install.sh not found")
self.install_sh = install_sh
self.source = install_sh.read_text(encoding = "utf-8")
def test_no_torch_flag_in_case_statement(self):
"""--no-torch must appear in the flag parser case statement."""
assert (
"--no-torch)" in self.source
), "--no-torch not found in install.sh flag parser"
def test_no_torch_flag_variable_initialized(self):
"""_NO_TORCH_FLAG must be initialized to false."""
assert (
"_NO_TORCH_FLAG=false" in self.source
), "_NO_TORCH_FLAG=false not found in install.sh"
def test_skip_torch_variable_exists(self):
"""SKIP_TORCH variable must be defined."""
assert (
"SKIP_TORCH=false" in self.source
), "SKIP_TORCH=false not found in install.sh"
assert (
"SKIP_TORCH=true" in self.source
), "SKIP_TORCH=true not found in install.sh"
def test_skip_torch_driven_by_flag_and_mac_intel(self):
"""SKIP_TORCH must check both _NO_TORCH_FLAG and MAC_INTEL."""
assert (
"_NO_TORCH_FLAG" in self.source
), "_NO_TORCH_FLAG not referenced in SKIP_TORCH logic"
assert (
"MAC_INTEL" in self.source
), "MAC_INTEL not referenced in SKIP_TORCH logic"
def test_unsloth_no_torch_uses_skip_torch(self):
"""UNSLOTH_NO_TORCH must reference $SKIP_TORCH, not $MAC_INTEL."""
matches = re.findall(r'UNSLOTH_NO_TORCH="\$(\w+)"', self.source)
for var in matches:
assert (
var == "SKIP_TORCH"
), f"UNSLOTH_NO_TORCH references ${var} instead of $SKIP_TORCH"
def test_cpu_hint_message_exists(self):
"""CPU hint message must exist in install.sh."""
assert (
"No NVIDIA GPU detected" in self.source
), "CPU hint message not found in install.sh"
assert (
"--no-torch" in self.source
), "--no-torch suggestion not found in CPU hint"
def test_no_torch_flag_parsing_subprocess(self):
"""--no-torch flag sets _NO_TORCH_FLAG=true (subprocess test)."""
script = textwrap.dedent("""\
_NO_TORCH_FLAG=false
_next_is_package=false
STUDIO_LOCAL_INSTALL=false
PACKAGE_NAME="unsloth"
for arg in "$@"; do
if [ "$_next_is_package" = true ]; then
PACKAGE_NAME="$arg"
_next_is_package=false
continue
fi
case "$arg" in
--local) STUDIO_LOCAL_INSTALL=true ;;
--package) _next_is_package=true ;;
--no-torch) _NO_TORCH_FLAG=true ;;
esac
done
echo "$_NO_TORCH_FLAG"
""")
result = subprocess.run(
["bash", "-c", script, "_", "--no-torch"],
capture_output = True,
text = True,
)
assert (
result.stdout.strip() == "true"
), f"Expected _NO_TORCH_FLAG=true, got: {result.stdout.strip()}"
def test_no_torch_with_local_flag(self):
"""--no-torch and --local can be used together."""
script = textwrap.dedent("""\
_NO_TORCH_FLAG=false
_next_is_package=false
STUDIO_LOCAL_INSTALL=false
PACKAGE_NAME="unsloth"
for arg in "$@"; do
if [ "$_next_is_package" = true ]; then
PACKAGE_NAME="$arg"
_next_is_package=false
continue
fi
case "$arg" in
--local) STUDIO_LOCAL_INSTALL=true ;;
--package) _next_is_package=true ;;
--no-torch) _NO_TORCH_FLAG=true ;;
esac
done
echo "$_NO_TORCH_FLAG $STUDIO_LOCAL_INSTALL"
""")
result = subprocess.run(
["bash", "-c", script, "_", "--local", "--no-torch"],
capture_output = True,
text = True,
)
assert (
result.stdout.strip() == "true true"
), f"Expected 'true true', got: {result.stdout.strip()}"
def test_cpu_hint_only_when_not_skip_torch(self):
"""CPU hint should only print when SKIP_TORCH=false and OS!=macos."""
script = textwrap.dedent("""\
TORCH_INDEX_URL="https://download.pytorch.org/whl/cpu"
SKIP_TORCH=false
OS="linux"
case "$TORCH_INDEX_URL" in
*/cpu)
if [ "$SKIP_TORCH" = false ] && [ "$OS" != "macos" ]; then
echo "HINT_PRINTED"
fi
;;
esac
""")
result = subprocess.run(
["bash", "-c", script],
capture_output = True,
text = True,
)
assert "HINT_PRINTED" in result.stdout, "CPU hint should print"
# With SKIP_TORCH=true, hint should NOT print
script2 = script.replace("SKIP_TORCH=false", "SKIP_TORCH=true")
result2 = subprocess.run(
["bash", "-c", script2],
capture_output = True,
text = True,
)
assert (
"HINT_PRINTED" not in result2.stdout
), "CPU hint should NOT print when SKIP_TORCH=true"
# ── Triton macOS skip structural checks ──────────────────────────────
class TestTritonMacosSkip:
"""Verify triton is skipped on macOS (source-level check)."""
def test_triton_guard_in_source(self):
"""Source must skip triton on both Windows and macOS."""
source = Path(ips.__file__).read_text(encoding = "utf-8")
assert (
"not IS_MACOS" in source
), "IS_MACOS guard for triton not found in install_python_stack.py"
assert (
"not IS_WINDOWS and not IS_MACOS" in source
), "Expected 'not IS_WINDOWS and not IS_MACOS' guard for triton"


@@ -0,0 +1,582 @@
"""End-to-end sandbox tests: Studio modules in isolated no-torch venvs.
Covers:
- Python 3.12 and 3.13 venv creation (Intel Mac uses 3.12, Apple Silicon/Linux 3.13)
- data_collators.py loads and dataclasses instantiate without torch
- chat_templates.py top-level exec works with stubs for relative imports
- Negative control: prepending 'import torch' fails in no-torch venv
- Negative control: installing torchao (from overrides.txt) fails in no-torch venv
- AST structural checks for top-level torch imports
"""
from __future__ import annotations
import ast
import os
import shutil
import subprocess
import sys
import tempfile
import textwrap
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
DATA_COLLATORS = (
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "data_collators.py"
)
CHAT_TEMPLATES = (
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "chat_templates.py"
)
FORMAT_CONVERSION = (
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "format_conversion.py"
)
def _has_uv() -> bool:
return shutil.which("uv") is not None
def _create_venv(venv_dir: Path, python_version: str) -> Path | None:
"""Create a uv venv at the given Python version. Returns python path or None."""
result = subprocess.run(
["uv", "venv", str(venv_dir), "--python", python_version],
capture_output = True,
)
if result.returncode != 0:
return None
venv_python = venv_dir / "bin" / "python"
if not venv_python.exists():
venv_python = venv_dir / "Scripts" / "python.exe"
return venv_python if venv_python.exists() else None
@pytest.fixture(params = ["3.12", "3.13"], scope = "module")
def no_torch_venv(request, tmp_path_factory):
"""Create a temporary venv at the requested Python version with no torch.
Parametrized for 3.12 (Intel Mac) and 3.13 (Apple Silicon / Linux).
"""
if not _has_uv():
pytest.skip("uv not available")
py_version = request.param
venv_dir = tmp_path_factory.mktemp(f"no_torch_venv_{py_version}")
venv_python = _create_venv(venv_dir, py_version)
if venv_python is None:
pytest.skip(f"Could not create Python {py_version} venv")
# Verify torch is NOT importable
check = subprocess.run(
[str(venv_python), "-c", "import torch"],
capture_output = True,
)
assert (
check.returncode != 0
), f"torch should NOT be importable in fresh {py_version} venv"
return str(venv_python)
# ── AST structural checks ─────────────────────────────────────────────
class TestDataCollatorsAST:
"""Static analysis: data_collators.py has no top-level torch imports."""
def test_ast_parse(self):
"""data_collators.py must be valid Python syntax."""
source = DATA_COLLATORS.read_text(encoding = "utf-8")
tree = ast.parse(source, filename = str(DATA_COLLATORS))
assert tree is not None
def test_no_top_level_torch_import(self):
"""No top-level 'import torch' or 'from torch' statements."""
source = DATA_COLLATORS.read_text(encoding = "utf-8")
tree = ast.parse(source)
for node in ast.iter_child_nodes(tree):
if isinstance(node, ast.Import):
for alias in node.names:
assert not alias.name.startswith(
"torch"
), f"Top-level 'import {alias.name}' found at line {node.lineno}"
elif isinstance(node, ast.ImportFrom):
if node.module:
assert not node.module.startswith(
"torch"
), f"Top-level 'from {node.module}' found at line {node.lineno}"
class TestChatTemplatesAST:
"""Static analysis: chat_templates.py has no top-level torch imports."""
def test_ast_parse(self):
"""chat_templates.py must be valid Python syntax."""
source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
tree = ast.parse(source, filename = str(CHAT_TEMPLATES))
assert tree is not None
def test_no_top_level_torch_import(self):
"""No top-level 'import torch' or 'from torch' at module level."""
source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
tree = ast.parse(source)
for node in ast.iter_child_nodes(tree):
if isinstance(node, ast.Import):
for alias in node.names:
assert not alias.name.startswith(
"torch"
), f"Top-level 'import {alias.name}' found at line {node.lineno}"
elif isinstance(node, ast.ImportFrom):
if node.module:
assert not node.module.startswith(
"torch"
), f"Top-level 'from {node.module}' found at line {node.lineno}"
def test_torch_imports_only_inside_functions(self):
"""All 'from torch' imports must be inside function/method bodies."""
source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
tree = ast.parse(source)
torch_imports = []
for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
module = None
if isinstance(node, ast.ImportFrom):
module = node.module
elif isinstance(node, ast.Import):
module = node.names[0].name if node.names else None
if module and module.startswith("torch"):
torch_imports.append(node)
top_level = set(id(n) for n in ast.iter_child_nodes(tree))
for imp in torch_imports:
assert id(imp) not in top_level, (
f"torch import at line {imp.lineno} is at top level"
" (should be inside a function)"
)
# ── data_collators.py: exec + dataclass instantiation in no-torch venv ──
class TestDataCollatorsNoTorchVenv:
"""Run data_collators.py in an isolated no-torch venv, verify classes load."""
def test_exec_in_no_torch_venv(self, no_torch_venv):
"""data_collators.py executes in a venv without torch (with loggers stub)."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({str(DATA_COLLATORS)!r}).read())
print("OK: exec succeeded")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"data_collators.py failed in no-torch venv:\n{result.stderr.decode()}"
assert b"OK: exec succeeded" in result.stdout
def test_dataclass_speech_collator_instantiable(self, no_torch_venv):
"""DataCollatorSpeechSeq2SeqWithPadding can be instantiated with processor=None."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({str(DATA_COLLATORS)!r}).read())
obj = DataCollatorSpeechSeq2SeqWithPadding(processor=None)
assert obj.processor is None, "processor should be None"
print("OK: DataCollatorSpeechSeq2SeqWithPadding instantiated")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"DataCollatorSpeechSeq2SeqWithPadding failed:\n{result.stderr.decode()}"
assert b"OK: DataCollatorSpeechSeq2SeqWithPadding instantiated" in result.stdout
def test_dataclass_deepseek_collator_instantiable(self, no_torch_venv):
"""DeepSeekOCRDataCollator can be instantiated with processor=None."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({str(DATA_COLLATORS)!r}).read())
obj = DeepSeekOCRDataCollator(processor=None)
assert obj.processor is None, "processor should be None"
assert obj.max_length == 2048, "default max_length should be 2048"
assert obj.ignore_index == -100, "default ignore_index should be -100"
print("OK: DeepSeekOCRDataCollator instantiated")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"DeepSeekOCRDataCollator failed:\n{result.stderr.decode()}"
assert b"OK: DeepSeekOCRDataCollator instantiated" in result.stdout
def test_dataclass_vlm_collator_instantiable(self, no_torch_venv):
"""VLMDataCollator can be instantiated with processor=None."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({str(DATA_COLLATORS)!r}).read())
obj = VLMDataCollator(processor=None)
assert obj.processor is None
assert obj.mask_input_tokens is True, "default mask_input_tokens should be True"
print("OK: VLMDataCollator instantiated")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"VLMDataCollator failed:\n{result.stderr.decode()}"
assert b"OK: VLMDataCollator instantiated" in result.stdout
# ── chat_templates.py: exec in no-torch venv ─────────────────────────
class TestChatTemplatesNoTorchVenv:
"""Run chat_templates.py in an isolated no-torch venv with stubs."""
def test_exec_with_stubs(self, no_torch_venv):
"""chat_templates.py top-level exec works with stubs for relative imports."""
code = textwrap.dedent(f"""\
import sys, types
# Stub loggers
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})()
sys.modules['loggers'] = loggers
# Stub relative imports (.format_detection, .model_mappings)
format_detection = types.ModuleType('format_detection')
format_detection.detect_dataset_format = lambda *a, **k: None
format_detection.detect_multimodal_dataset = lambda *a, **k: None
format_detection.detect_custom_format_heuristic = lambda *a, **k: None
sys.modules['format_detection'] = format_detection
model_mappings = types.ModuleType('model_mappings')
model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}}
sys.modules['model_mappings'] = model_mappings
# Read and transform the source: replace relative imports with absolute
source = open({str(CHAT_TEMPLATES)!r}).read()
source = source.replace('from .format_detection import', 'from format_detection import')
source = source.replace('from .model_mappings import', 'from model_mappings import')
exec(source)
# Verify module-level constants are defined
ns = dict(locals())
assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined after exec"
print("OK: chat_templates.py exec succeeded")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"chat_templates.py failed in no-torch venv:\n{result.stderr.decode()}"
assert b"OK: chat_templates.py exec succeeded" in result.stdout
def test_default_alpaca_template_defined(self, no_torch_venv):
"""DEFAULT_ALPACA_TEMPLATE constant is accessible after exec."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})()
sys.modules['loggers'] = loggers
format_detection = types.ModuleType('format_detection')
format_detection.detect_dataset_format = lambda *a, **k: None
format_detection.detect_multimodal_dataset = lambda *a, **k: None
format_detection.detect_custom_format_heuristic = lambda *a, **k: None
sys.modules['format_detection'] = format_detection
model_mappings = types.ModuleType('model_mappings')
model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}}
sys.modules['model_mappings'] = model_mappings
ns = {{}}
source = open({str(CHAT_TEMPLATES)!r}).read()
source = source.replace('from .format_detection import', 'from format_detection import')
source = source.replace('from .model_mappings import', 'from model_mappings import')
exec(source, ns)
assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined"
assert 'Instruction' in ns['DEFAULT_ALPACA_TEMPLATE'], "Template content unexpected"
print("OK: DEFAULT_ALPACA_TEMPLATE defined and valid")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"DEFAULT_ALPACA_TEMPLATE check failed:\n{result.stderr.decode()}"
assert b"OK: DEFAULT_ALPACA_TEMPLATE defined and valid" in result.stdout
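The difference between the two tests above is where `exec` writes its names: a bare `exec(source)` mutates the calling scope, while `exec(source, ns)` confines every module-level definition to an explicit dict that can then be inspected. A small self-contained illustration of the namespace-dict pattern (the template text here is invented, not from chat_templates.py):

```python
import textwrap

# Hypothetical module source standing in for chat_templates.py.
source = textwrap.dedent("""
    DEFAULT_TEMPLATE = "### Instruction:\\n{prompt}"
    def render(prompt):
        return DEFAULT_TEMPLATE.format(prompt = prompt)
""")

ns = {}
exec(source, ns)  # all module-level names land in ns, not in our scope

# exec injects __builtins__ into ns; filter dunders to see only our names.
print(sorted(k for k in ns if not k.startswith("__")))  # → ['DEFAULT_TEMPLATE', 'render']
print(ns["render"]("hi"))
```

Because `render`'s globals *are* `ns`, it still resolves `DEFAULT_TEMPLATE` correctly after the exec, which is the same property the `DEFAULT_ALPACA_TEMPLATE` check relies on.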
# ── format_conversion.py: AST + runtime tests ────────────────────────
class TestFormatConversionAST:
"""Static analysis: format_conversion.py torch imports are guarded."""
def test_ast_parse(self):
"""format_conversion.py must be valid Python syntax."""
source = FORMAT_CONVERSION.read_text(encoding = "utf-8")
tree = ast.parse(source, filename = str(FORMAT_CONVERSION))
assert tree is not None
def test_no_bare_torch_import_in_functions(self):
"""All 'from torch' imports in function bodies must be inside try/except."""
source = FORMAT_CONVERSION.read_text(encoding = "utf-8")
tree = ast.parse(source)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
for child in ast.walk(node):
if (
isinstance(child, ast.ImportFrom)
and child.module
and child.module.startswith("torch")
):
# This torch import must be inside a Try node
found_in_try = False
for try_node in ast.walk(node):
if isinstance(try_node, ast.Try):
for try_child in ast.walk(try_node):
if try_child is child:
found_in_try = True
break
if found_in_try:
break
assert found_in_try, (
f"torch import at line {child.lineno} in {node.name}() "
"is not inside a try/except block"
)
class TestFormatConversionNoTorchVenv:
"""Run format_conversion.py functions in a no-torch venv."""
def test_convert_chatml_to_alpaca_no_torch(self, no_torch_venv):
"""convert_chatml_to_alpaca works without torch (via try/except ImportError)."""
code = textwrap.dedent(f"""\
import sys, types
# Stub loggers
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: type('L', (), {{
'info': lambda s, m: None,
'warning': lambda s, m: None,
'debug': lambda s, m: None,
}})()
sys.modules['loggers'] = loggers
# Stub datasets.IterableDataset (HF datasets, not torch)
datasets_mod = types.ModuleType('datasets')
datasets_mod.IterableDataset = type('IterableDataset', (), {{}})
sys.modules['datasets'] = datasets_mod
# Stub utils.hardware
utils_mod = types.ModuleType('utils')
hardware_mod = types.ModuleType('utils.hardware')
hardware_mod.dataset_map_num_proc = lambda n=None: 1
utils_mod.hardware = hardware_mod
sys.modules['utils'] = utils_mod
sys.modules['utils.hardware'] = hardware_mod
# Read and exec format_conversion.py
source = open({str(FORMAT_CONVERSION)!r}).read()
source = source.replace('from .format_detection import', 'from format_detection import')
ns = {{'__name__': '__test__'}}
exec(source, ns)
# Test convert_chatml_to_alpaca with a simple dataset
class FakeDataset:
def map(self, fn, **kw):
result = fn({{
'messages': [[
{{'role': 'user', 'content': 'Hello'}},
{{'role': 'assistant', 'content': 'Hi there'}},
]]
}})
return result
result = ns['convert_chatml_to_alpaca'](FakeDataset())
assert 'instruction' in result, f"Expected 'instruction' in result, got {{result.keys()}}"
assert result['instruction'] == ['Hello']
assert result['output'] == ['Hi there']
print("OK: convert_chatml_to_alpaca works without torch")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"convert_chatml_to_alpaca failed without torch:\n{result.stderr.decode()}"
assert b"OK: convert_chatml_to_alpaca works without torch" in result.stdout
def test_convert_alpaca_to_chatml_no_torch(self, no_torch_venv):
"""convert_alpaca_to_chatml works without torch (via try/except ImportError)."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: type('L', (), {{
'info': lambda s, m: None,
'warning': lambda s, m: None,
'debug': lambda s, m: None,
}})()
sys.modules['loggers'] = loggers
datasets_mod = types.ModuleType('datasets')
datasets_mod.IterableDataset = type('IterableDataset', (), {{}})
sys.modules['datasets'] = datasets_mod
utils_mod = types.ModuleType('utils')
hardware_mod = types.ModuleType('utils.hardware')
hardware_mod.dataset_map_num_proc = lambda n=None: 1
utils_mod.hardware = hardware_mod
sys.modules['utils'] = utils_mod
sys.modules['utils.hardware'] = hardware_mod
source = open({str(FORMAT_CONVERSION)!r}).read()
source = source.replace('from .format_detection import', 'from format_detection import')
ns = {{'__name__': '__test__'}}
exec(source, ns)
class FakeDataset:
def map(self, fn, **kw):
result = fn({{
'instruction': ['Write a poem'],
'input': [''],
'output': ['Roses are red'],
}})
return result
result = ns['convert_alpaca_to_chatml'](FakeDataset())
assert 'conversations' in result
convo = result['conversations'][0]
assert convo[0]['role'] == 'user'
assert convo[1]['role'] == 'assistant'
print("OK: convert_alpaca_to_chatml works without torch")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"convert_alpaca_to_chatml failed without torch:\n{result.stderr.decode()}"
assert b"OK: convert_alpaca_to_chatml works without torch" in result.stdout
# ── Negative controls ─────────────────────────────────────────────────
class TestNegativeControls:
"""Prove the fix is necessary by showing what fails WITHOUT it."""
def test_import_torch_prepended_fails(self, no_torch_venv):
"""Prepending 'import torch' to data_collators.py causes ModuleNotFoundError."""
with tempfile.NamedTemporaryFile(
mode = "w", suffix = ".py", delete = False, encoding = "utf-8"
) as f:
f.write("import torch\n")
f.write(DATA_COLLATORS.read_text(encoding = "utf-8"))
temp_file = f.name
try:
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({temp_file!r}).read())
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode != 0
), "Expected failure when 'import torch' is prepended"
assert (
b"ModuleNotFoundError" in result.stderr
or b"ImportError" in result.stderr
), f"Expected ImportError, got:\n{result.stderr.decode()}"
finally:
os.unlink(temp_file)
def test_torchao_install_fails_no_torch_venv(self, no_torch_venv):
"""Installing torchao (from overrides.txt) fails in a no-torch venv.
This proves the overrides.txt skip is necessary for Intel Mac.
"""
result = subprocess.run(
[
no_torch_venv,
"-m",
"pip",
"install",
"torchao==0.14.0",
"--dry-run",
],
capture_output = True,
timeout = 60,
)
if result.returncode != 0:
# torchao install/resolution failed as expected
pass
else:
# pip dry-run may not catch dependency issues; verify torch is missing
check = subprocess.run(
[no_torch_venv, "-c", "import torch"],
capture_output = True,
)
assert (
check.returncode != 0
), "torch should not be importable -- torchao would fail at runtime"
def test_direct_torch_import_fails(self, no_torch_venv):
"""Direct 'import torch' fails in the no-torch venv."""
result = subprocess.run(
[no_torch_venv, "-c", "import torch; print('torch loaded')"],
capture_output = True,
timeout = 30,
)
assert result.returncode != 0, "import torch should fail in no-torch venv"
assert (
b"ModuleNotFoundError" in result.stderr or b"ImportError" in result.stderr
)
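These classes all depend on a `no_torch_venv` fixture that is not shown in this hunk (it lives in the suite's conftest.py). A minimal sketch of the helper such a fixture could wrap: the function name and layout here are assumptions, but the key property is real, since a venv created without system site-packages cannot import a torch installed on the host:

```python
import subprocess
import sys
import venv
from pathlib import Path

def make_no_torch_python(base_dir: Path) -> str:
    """Create a bare venv (isolated site-packages, so no torch) and
    return the path to its interpreter."""
    venv_dir = Path(base_dir) / "no_torch"
    venv.create(venv_dir, with_pip = False)
    bindir = "Scripts" if sys.platform == "win32" else "bin"
    exe = "python.exe" if sys.platform == "win32" else "python"
    return str(venv_dir / bindir / exe)

if __name__ == "__main__":
    import tempfile
    with tempfile.TemporaryDirectory() as d:
        py = make_no_torch_python(Path(d))
        rc = subprocess.run([py, "-c", "import torch"], capture_output = True).returncode
        print("torch importable:", rc == 0)  # → torch importable: False
```

The real fixture presumably wraps something like this in `@pytest.fixture(scope = "session")` so the venv is built once per test run rather than per test.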


@@ -6,11 +6,14 @@ TESTS_DIR="$(cd "$(dirname "$0")" && pwd)"
echo "=== Bash tests ==="
sh "$TESTS_DIR/sh/test_get_torch_index_url.sh"
sh "$TESTS_DIR/sh/test_mac_intel_compat.sh"
echo ""
echo "=== Python tests ==="
python -m pytest "$TESTS_DIR/python/test_install_python_stack.py" -v
python -m pytest "$TESTS_DIR/python/test_cross_platform_parity.py" -v
python -m pytest "$TESTS_DIR/python/test_no_torch_filtering.py" -v
python -m pytest "$TESTS_DIR/python/test_studio_import_no_torch.py" -v
echo ""
echo "All tests passed."


@@ -0,0 +1,582 @@
#!/bin/bash
# End-to-end sandbox tests for Mac Intel compatibility and UNSLOTH_NO_TORCH propagation.
# Tests version_ge, arch detection (existing), plus E2E venv creation, torch skip
# via a mock uv shim, and UNSLOTH_NO_TORCH env propagation in install.sh.
set -e
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
INSTALL_SH="$SCRIPT_DIR/../../install.sh"
PASS=0
FAIL=0
assert_eq() {
_label="$1"; _expected="$2"; _actual="$3"
if [ "$_actual" = "$_expected" ]; then
echo " PASS: $_label"
PASS=$((PASS + 1))
else
echo " FAIL: $_label (expected '$_expected', got '$_actual')"
FAIL=$((FAIL + 1))
fi
}
assert_contains() {
_label="$1"; _haystack="$2"; _needle="$3"
if echo "$_haystack" | grep -qF "$_needle"; then
echo " PASS: $_label"
PASS=$((PASS + 1))
else
echo " FAIL: $_label (expected to find '$_needle')"
FAIL=$((FAIL + 1))
fi
}
assert_not_contains() {
_label="$1"; _haystack="$2"; _needle="$3"
if echo "$_haystack" | grep -qF "$_needle"; then
echo " FAIL: $_label (found '$_needle' but should not)"
FAIL=$((FAIL + 1))
else
echo " PASS: $_label"
PASS=$((PASS + 1))
fi
}
# ── Extract version_ge function from install.sh ──
_VGE_FILE=$(mktemp)
sed -n '/^version_ge()/,/^}/p' "$INSTALL_SH" > "$_VGE_FILE"
echo "=== version_ge ==="
# Basic comparisons
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.13' '3.12' && echo pass || echo fail")
assert_eq "3.13 >= 3.12" "pass" "$_result"
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.12' '3.13' && echo pass || echo fail")
assert_eq "3.12 >= 3.13" "fail" "$_result"
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.13' '3.13' && echo pass || echo fail")
assert_eq "3.13 >= 3.13 (equal)" "pass" "$_result"
# Patch versions
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.13.8' '3.13' && echo pass || echo fail")
assert_eq "3.13.8 >= 3.13 (patch > implicit 0)" "pass" "$_result"
_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.12.0' '3.13.0' && echo pass || echo fail")
assert_eq "3.12.0 >= 3.13.0 (minor less)" "fail" "$_result"
# UV_MIN_VERSION edge cases
_result=$(bash -c ". '$_VGE_FILE'; version_ge '0.7.14' '0.7.14' && echo pass || echo fail")
assert_eq "0.7.14 >= 0.7.14 (exact UV_MIN_VERSION)" "pass" "$_result"
_result=$(bash -c ". '$_VGE_FILE'; version_ge '0.7.13' '0.7.14' && echo pass || echo fail")
assert_eq "0.7.13 >= 0.7.14 (below minimum)" "fail" "$_result"
_result=$(bash -c ". '$_VGE_FILE'; version_ge '0.11.1' '0.7.14' && echo pass || echo fail")
assert_eq "0.11.1 >= 0.7.14 (well above)" "pass" "$_result"
# Major jump
_result=$(bash -c ". '$_VGE_FILE'; version_ge '1.0' '0.99.99' && echo pass || echo fail")
assert_eq "1.0 >= 0.99.99 (major jump)" "pass" "$_result"
rm -f "$_VGE_FILE"
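For reference, a comparator with the same contract can be sketched in one line on systems whose sort supports -V (version sort, a GNU/BSD extension, not POSIX); install.sh's real implementation is the one extracted and exercised above and may differ:

```shell
# Hypothetical sketch, not install.sh's code: $1 >= $2 exactly when $2
# sorts first (or the two are equal) under version sort.
version_ge_demo() {
    [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ]
}
version_ge_demo 3.13.8 3.13 && echo "3.13.8 >= 3.13"
version_ge_demo 3.12 3.13 || echo "3.12 < 3.13"
```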
echo ""
echo "=== Architecture detection + PYTHON_VERSION ==="
# Self-contained arch detection snippet matching install.sh logic
_ARCH_SNIPPET=$(mktemp)
cat > "$_ARCH_SNIPPET" << 'SNIPPET'
OS="linux"
if [ "$(uname)" = "Darwin" ]; then
OS="macos"
fi
_ARCH=$(uname -m)
MAC_INTEL=false
if [ "$OS" = "macos" ] && [ "$_ARCH" = "x86_64" ]; then
MAC_INTEL=true
fi
_USER_PYTHON=""
if [ -n "$_USER_PYTHON" ]; then
PYTHON_VERSION="$_USER_PYTHON"
elif [ "$MAC_INTEL" = true ]; then
PYTHON_VERSION="3.12"
else
PYTHON_VERSION="3.13"
fi
echo "$OS $MAC_INTEL $PYTHON_VERSION"
SNIPPET
# Test: Darwin x86_64 -> macos true 3.12
_result=$(bash -c '
uname() {
case "$1" in
-m) echo "x86_64" ;;
*) echo "Darwin" ;;
esac
}
export -f uname
'"source '$_ARCH_SNIPPET'")
assert_eq "Darwin x86_64 -> macos true 3.12" "macos true 3.12" "$_result"
# Test: Darwin arm64 -> macos false 3.13
_result=$(bash -c '
uname() {
case "$1" in
-m) echo "arm64" ;;
*) echo "Darwin" ;;
esac
}
export -f uname
'"source '$_ARCH_SNIPPET'")
assert_eq "Darwin arm64 -> macos false 3.13" "macos false 3.13" "$_result"
# Test: Linux x86_64 -> linux false 3.13
_result=$(bash -c '
uname() {
case "$1" in
-m) echo "x86_64" ;;
*) echo "Linux" ;;
esac
}
export -f uname
'"source '$_ARCH_SNIPPET'")
assert_eq "Linux x86_64 -> linux false 3.13" "linux false 3.13" "$_result"
# Test: Linux aarch64 -> linux false 3.13
_result=$(bash -c '
uname() {
case "$1" in
-m) echo "aarch64" ;;
*) echo "Linux" ;;
esac
}
export -f uname
'"source '$_ARCH_SNIPPET'")
assert_eq "Linux aarch64 -> linux false 3.13" "linux false 3.13" "$_result"
rm -f "$_ARCH_SNIPPET"
echo ""
echo "=== get_torch_index_url on Darwin ==="
# Extract get_torch_index_url and replace hardcoded nvidia-smi path
_FUNC_FILE=$(mktemp)
_FAKE_SMI_DIR=$(mktemp -d)
sed -n '/^get_torch_index_url()/,/^}/p' "$INSTALL_SH" \
| sed "s|/usr/bin/nvidia-smi|$_FAKE_SMI_DIR/nvidia-smi-absent|g" \
> "$_FUNC_FILE"
# Build a minimal tools directory
_TOOLS_DIR=$(mktemp -d)
for _cmd in grep sed head sh bash cat; do
_real=$(command -v "$_cmd" 2>/dev/null || true)
[ -n "$_real" ] && ln -sf "$_real" "$_TOOLS_DIR/$_cmd"
done
# Create a mock uname that returns Darwin
_MOCK_UNAME_DIR=$(mktemp -d)
cat > "$_MOCK_UNAME_DIR/uname" << 'MOCK_UNAME'
#!/bin/sh
case "$1" in
-s) echo "Darwin" ;;
-m) echo "arm64" ;;
*) echo "Darwin" ;;
esac
MOCK_UNAME
chmod +x "$_MOCK_UNAME_DIR/uname"
# Mock nvidia-smi that returns CUDA version (to prove macOS ignores it)
_GPU_DIR=$(mktemp -d)
cat > "$_GPU_DIR/nvidia-smi" << 'MOCK_SMI'
#!/bin/sh
cat <<'SMI_OUT'
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.6 |
+-----------------------------------------------------------------------------------------+
SMI_OUT
MOCK_SMI
chmod +x "$_GPU_DIR/nvidia-smi"
# Test: Darwin always returns cpu (even with nvidia-smi present)
_result=$(PATH="$_GPU_DIR:$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". '$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
assert_eq "Darwin -> cpu (even with nvidia-smi)" "https://download.pytorch.org/whl/cpu" "$_result"
# Test: Darwin without nvidia-smi also returns cpu
_result=$(PATH="$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". '$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
assert_eq "Darwin -> cpu (no nvidia-smi)" "https://download.pytorch.org/whl/cpu" "$_result"
rm -f "$_FUNC_FILE"
rm -rf "$_FAKE_SMI_DIR" "$_TOOLS_DIR" "$_MOCK_UNAME_DIR" "$_GPU_DIR"
echo ""
echo "=== UNSLOTH_NO_TORCH propagation ==="
# Verify UNSLOTH_NO_TORCH is passed to setup.sh in BOTH the --local and non-local branches.
_local_count=$(grep -c 'UNSLOTH_NO_TORCH=' "$INSTALL_SH" || true)  # || true: grep -c exits 1 on zero matches, which would trip set -e
if [ "$_local_count" -ge 2 ]; then
echo " PASS: UNSLOTH_NO_TORCH appears in >= 2 setup.sh invocations ($_local_count found)"
PASS=$((PASS + 1))
else
echo " FAIL: UNSLOTH_NO_TORCH should appear in >= 2 setup.sh invocations (found $_local_count)"
FAIL=$((FAIL + 1))
fi
# Verify the value passed is "$SKIP_TORCH" (the unified variable, not MAC_INTEL)
_skip_torch_count=$(grep 'UNSLOTH_NO_TORCH="\$SKIP_TORCH"' "$INSTALL_SH" | wc -l)
if [ "$_skip_torch_count" -ge 2 ]; then
echo " PASS: UNSLOTH_NO_TORCH=\"\$SKIP_TORCH\" in both branches ($_skip_torch_count found)"
PASS=$((PASS + 1))
else
echo " FAIL: UNSLOTH_NO_TORCH=\"\$SKIP_TORCH\" should appear in >= 2 branches (found $_skip_torch_count)"
FAIL=$((FAIL + 1))
fi
# Verify MAC_INTEL is set to true when Intel Mac is detected
_mac_intel_set=$(grep -c 'MAC_INTEL=true' "$INSTALL_SH" || true)  # || true: grep -c exits 1 on zero matches, which would trip set -e
if [ "$_mac_intel_set" -ge 1 ]; then
echo " PASS: MAC_INTEL=true is set in install.sh"
PASS=$((PASS + 1))
else
echo " FAIL: MAC_INTEL=true not found in install.sh"
FAIL=$((FAIL + 1))
fi
# Verify the PyTorch skip message exists (now covers both --no-torch and Intel Mac)
if grep -q 'Skipping PyTorch' "$INSTALL_SH"; then
echo " PASS: PyTorch skip message found"
PASS=$((PASS + 1))
else
echo " FAIL: PyTorch skip message not found"
FAIL=$((FAIL + 1))
fi
# Verify SKIP_TORCH unified variable exists
if grep -q 'SKIP_TORCH=true' "$INSTALL_SH"; then
echo " PASS: SKIP_TORCH=true assignment found"
PASS=$((PASS + 1))
else
echo " FAIL: SKIP_TORCH=true not found in install.sh"
FAIL=$((FAIL + 1))
fi
echo ""
echo "=== E2E: venv creation at Python 3.12 (simulated Intel Mac) ==="
# Actually create a uv venv at Python 3.12 to verify the path works
if command -v uv >/dev/null 2>&1; then
_VENV_DIR=$(mktemp -d)
_uv_result=$(uv venv "$_VENV_DIR/test_venv" --python 3.12 2>&1) && _uv_rc=0 || _uv_rc=$?
if [ "$_uv_rc" -eq 0 ]; then
echo " PASS: uv venv created at Python 3.12"
PASS=$((PASS + 1))
# Verify Python version inside the venv
_py_ver=$("$_VENV_DIR/test_venv/bin/python" -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
assert_eq "venv Python is 3.12" "3.12" "$_py_ver"
# Verify torch is NOT available (fresh venv has no torch)
if "$_VENV_DIR/test_venv/bin/python" -c "import torch" 2>/dev/null; then
echo " FAIL: torch should NOT be importable in fresh 3.12 venv"
FAIL=$((FAIL + 1))
else
echo " PASS: torch not importable in fresh 3.12 venv (expected for Intel Mac)"
PASS=$((PASS + 1))
fi
else
echo " SKIP: Could not create Python 3.12 venv (python 3.12 not available)"
fi
rm -rf "$_VENV_DIR"
else
echo " SKIP: uv not available, cannot test venv creation"
fi
echo ""
echo "=== E2E: torch install skipped when SKIP_TORCH=true (mock uv shim) ==="
# Create a mock uv that logs all calls instead of running them
_MOCK_UV_DIR=$(mktemp -d)
_UV_LOG="$_MOCK_UV_DIR/uv_calls.log"
touch "$_UV_LOG"
cat > "$_MOCK_UV_DIR/uv" << MOCK_UV_EOF
#!/bin/sh
echo "UV_CALL: \$*" >> "$_UV_LOG"
MOCK_UV_EOF
chmod +x "$_MOCK_UV_DIR/uv"
# Simulates the torch install decision from install.sh using SKIP_TORCH
_TORCH_BLOCK=$(mktemp)
cat > "$_TORCH_BLOCK" << 'TORCH_EOF'
# Simulates the torch install decision from install.sh
TORCH_INDEX_URL="https://download.pytorch.org/whl/cpu"
_VENV_PY="/fake/python"
if [ "$SKIP_TORCH" = true ]; then
echo "==> Skipping PyTorch (--no-torch or Intel Mac x86_64)."
else
echo "==> Installing PyTorch ($TORCH_INDEX_URL)..."
uv pip install --python "$_VENV_PY" "torch>=2.4,<2.11.0" torchvision torchaudio \
--index-url "$TORCH_INDEX_URL"
fi
TORCH_EOF
# Test: SKIP_TORCH=true -> torch install should be SKIPPED (no uv calls)
> "$_UV_LOG" # clear log
_torch_output=$(SKIP_TORCH=true PATH="$_MOCK_UV_DIR:$PATH" bash "$_TORCH_BLOCK" 2>&1)
assert_contains "SKIP_TORCH=true prints skip message" "$_torch_output" "Skipping PyTorch"
if [ -s "$_UV_LOG" ]; then
echo " FAIL: uv was called when SKIP_TORCH=true (should be skipped)"
echo " Log: $(cat "$_UV_LOG")"
FAIL=$((FAIL + 1))
else
echo " PASS: no uv pip install torch when SKIP_TORCH=true"
PASS=$((PASS + 1))
fi
# Test: SKIP_TORCH=false -> torch install should EXECUTE (uv called with torch)
> "$_UV_LOG" # clear log
_torch_output=$(SKIP_TORCH=false PATH="$_MOCK_UV_DIR:$PATH" bash "$_TORCH_BLOCK" 2>&1)
assert_contains "SKIP_TORCH=false prints install message" "$_torch_output" "Installing PyTorch"
if grep -q "torch" "$_UV_LOG"; then
echo " PASS: uv pip install torch called when SKIP_TORCH=false"
PASS=$((PASS + 1))
else
echo " FAIL: uv pip install torch NOT called when SKIP_TORCH=false"
FAIL=$((FAIL + 1))
fi
rm -f "$_TORCH_BLOCK"
rm -rf "$_MOCK_UV_DIR"
echo ""
echo "=== E2E: UNSLOTH_NO_TORCH env propagation (dynamic test) ==="
# Simulates the setup.sh invocation using SKIP_TORCH
_ENV_BLOCK=$(mktemp)
cat > "$_ENV_BLOCK" << 'ENV_EOF'
# Simulates the setup.sh invocation block from install.sh
PACKAGE_NAME="unsloth"
_REPO_ROOT="/fake/repo"
SETUP_SH="/fake/setup.sh"
if [ "$STUDIO_LOCAL_INSTALL" = true ]; then
SKIP_STUDIO_BASE=1 \
STUDIO_PACKAGE_NAME="$PACKAGE_NAME" \
STUDIO_LOCAL_INSTALL=1 \
STUDIO_LOCAL_REPO="$_REPO_ROOT" \
UNSLOTH_NO_TORCH="$SKIP_TORCH" \
env | grep "^UNSLOTH_NO_TORCH="
else
SKIP_STUDIO_BASE=1 \
STUDIO_PACKAGE_NAME="$PACKAGE_NAME" \
UNSLOTH_NO_TORCH="$SKIP_TORCH" \
env | grep "^UNSLOTH_NO_TORCH="
fi
ENV_EOF
# Test: SKIP_TORCH=true -> UNSLOTH_NO_TORCH=true in env
_env_result=$(SKIP_TORCH=true STUDIO_LOCAL_INSTALL=false bash "$_ENV_BLOCK" 2>&1)
assert_eq "non-local: UNSLOTH_NO_TORCH=true when SKIP_TORCH=true" "UNSLOTH_NO_TORCH=true" "$_env_result"
# Test: SKIP_TORCH=false -> UNSLOTH_NO_TORCH=false in env
_env_result=$(SKIP_TORCH=false STUDIO_LOCAL_INSTALL=false bash "$_ENV_BLOCK" 2>&1)
assert_eq "non-local: UNSLOTH_NO_TORCH=false when SKIP_TORCH=false" "UNSLOTH_NO_TORCH=false" "$_env_result"
# Test: local install path also propagates
_env_result=$(SKIP_TORCH=true STUDIO_LOCAL_INSTALL=true bash "$_ENV_BLOCK" 2>&1)
assert_eq "local: UNSLOTH_NO_TORCH=true when SKIP_TORCH=true" "UNSLOTH_NO_TORCH=true" "$_env_result"
_env_result=$(SKIP_TORCH=false STUDIO_LOCAL_INSTALL=true bash "$_ENV_BLOCK" 2>&1)
assert_eq "local: UNSLOTH_NO_TORCH=false when SKIP_TORCH=false" "UNSLOTH_NO_TORCH=false" "$_env_result"
rm -f "$_ENV_BLOCK"
echo ""
echo "=== --python override flag ==="
# Test: flag parsing extracts version correctly
_PARSE_BLOCK=$(mktemp)
cat > "$_PARSE_BLOCK" << 'PARSE_EOF'
_USER_PYTHON=""
_next_is_python=false
_next_is_package=false
STUDIO_LOCAL_INSTALL=false
PACKAGE_NAME="unsloth"
for arg in "$@"; do
if [ "$_next_is_package" = true ]; then PACKAGE_NAME="$arg"; _next_is_package=false; continue; fi
if [ "$_next_is_python" = true ]; then _USER_PYTHON="$arg"; _next_is_python=false; continue; fi
case "$arg" in
--local) STUDIO_LOCAL_INSTALL=true ;;
--package) _next_is_package=true ;;
--python) _next_is_python=true ;;
esac
done
if [ "$_next_is_python" = true ]; then echo "ERROR"; exit 1; fi
echo "$_USER_PYTHON"
PARSE_EOF
_result=$(bash "$_PARSE_BLOCK" --python 3.12)
assert_eq "--python 3.12 parsed" "3.12" "$_result"
_result=$(bash "$_PARSE_BLOCK" --local --python 3.11)
assert_eq "--local --python 3.11 parsed" "3.11" "$_result"
_result=$(bash "$_PARSE_BLOCK" --python 3.12 --local --package foo)
assert_eq "--python with --local --package" "3.12" "$_result"
_result=$(bash "$_PARSE_BLOCK" 2>&1) # no --python
assert_eq "no --python -> empty" "" "$_result"
_rc=0
bash "$_PARSE_BLOCK" --python >/dev/null 2>&1 || _rc=$?
assert_eq "--python without arg -> error" "1" "$_rc"
rm -f "$_PARSE_BLOCK"
# Test: --python overrides auto-detected version in PYTHON_VERSION resolution
_RESOLVE_BLOCK=$(mktemp)
cat > "$_RESOLVE_BLOCK" << 'RESOLVE_EOF'
_USER_PYTHON="$1"
MAC_INTEL="$2"
if [ -n "$_USER_PYTHON" ]; then
PYTHON_VERSION="$_USER_PYTHON"
elif [ "$MAC_INTEL" = true ]; then
PYTHON_VERSION="3.12"
else
PYTHON_VERSION="3.13"
fi
echo "$PYTHON_VERSION"
RESOLVE_EOF
_result=$(bash "$_RESOLVE_BLOCK" "3.11" "true")
assert_eq "--python 3.11 overrides Intel Mac 3.12" "3.11" "$_result"
_result=$(bash "$_RESOLVE_BLOCK" "3.12" "false")
assert_eq "--python 3.12 overrides default 3.13" "3.12" "$_result"
_result=$(bash "$_RESOLVE_BLOCK" "" "true")
assert_eq "no override -> Intel Mac gets 3.12" "3.12" "$_result"
_result=$(bash "$_RESOLVE_BLOCK" "" "false")
assert_eq "no override -> non-Intel gets 3.13" "3.13" "$_result"
rm -f "$_RESOLVE_BLOCK"
# Test: --python flag exists in install.sh
if grep -q -- '--python)' "$INSTALL_SH"; then
echo " PASS: --python case exists in install.sh"
PASS=$((PASS + 1))
else
echo " FAIL: --python case not found in install.sh"
FAIL=$((FAIL + 1))
fi
# Test: _USER_PYTHON guards exist for stale-venv and 3.13.8 checks
_user_py_guards=$(grep -c '_USER_PYTHON' "$INSTALL_SH" || true)  # || true: grep -c exits 1 on zero matches, which would trip set -e
if [ "$_user_py_guards" -ge 4 ]; then
echo " PASS: _USER_PYTHON referenced >= 4 times in install.sh (flag + resolution + guards)"
PASS=$((PASS + 1))
else
echo " FAIL: _USER_PYTHON should appear >= 4 times (found $_user_py_guards)"
FAIL=$((FAIL + 1))
fi
echo ""
echo "=== --no-torch flag parsing ==="
# Test: --no-torch sets _NO_TORCH_FLAG=true
_FLAG_SNIPPET=$(mktemp)
cat > "$_FLAG_SNIPPET" << 'SNIPPET'
_NO_TORCH_FLAG=false
_next_is_package=false
STUDIO_LOCAL_INSTALL=false
PACKAGE_NAME="unsloth"
for arg in "$@"; do
if [ "$_next_is_package" = true ]; then
PACKAGE_NAME="$arg"
_next_is_package=false
continue
fi
case "$arg" in
--local) STUDIO_LOCAL_INSTALL=true ;;
--package) _next_is_package=true ;;
--no-torch) _NO_TORCH_FLAG=true ;;
esac
done
echo "$_NO_TORCH_FLAG"
SNIPPET
_result=$(bash "$_FLAG_SNIPPET" --no-torch)
assert_eq "--no-torch sets flag to true" "true" "$_result"
_result=$(bash "$_FLAG_SNIPPET")
assert_eq "no flags -> flag is false" "false" "$_result"
_result=$(bash "$_FLAG_SNIPPET" --local --no-torch)
assert_eq "--local --no-torch both work" "true" "$_result"
_result=$(bash "$_FLAG_SNIPPET" --no-torch --package custom-pkg)
assert_eq "--no-torch with --package works" "true" "$_result"
rm -f "$_FLAG_SNIPPET"
echo ""
echo "=== SKIP_TORCH unification ==="
# Test: SKIP_TORCH is set to true when --no-torch flag is set (even without MAC_INTEL)
_SKIP_SNIPPET=$(mktemp)
cat > "$_SKIP_SNIPPET" << 'SNIPPET'
MAC_INTEL=false
_NO_TORCH_FLAG=$1
SKIP_TORCH=false
if [ "$_NO_TORCH_FLAG" = true ] || [ "$MAC_INTEL" = true ]; then
SKIP_TORCH=true
fi
echo "$SKIP_TORCH"
SNIPPET
_result=$(bash "$_SKIP_SNIPPET" true)
assert_eq "--no-torch flag alone sets SKIP_TORCH=true" "true" "$_result"
_result=$(bash "$_SKIP_SNIPPET" false)
assert_eq "no flag, no MAC_INTEL -> SKIP_TORCH=false" "false" "$_result"
# Test: MAC_INTEL=true alone also sets SKIP_TORCH=true
_SKIP_SNIPPET2=$(mktemp)
cat > "$_SKIP_SNIPPET2" << 'SNIPPET'
MAC_INTEL=true
_NO_TORCH_FLAG=false
SKIP_TORCH=false
if [ "$_NO_TORCH_FLAG" = true ] || [ "$MAC_INTEL" = true ]; then
SKIP_TORCH=true
fi
echo "$SKIP_TORCH"
SNIPPET
_result=$(bash "$_SKIP_SNIPPET2")
assert_eq "MAC_INTEL=true alone sets SKIP_TORCH=true" "true" "$_result"
rm -f "$_SKIP_SNIPPET" "$_SKIP_SNIPPET2"
echo ""
echo "=== CPU hint printing ==="
# Verify the CPU hint is present in install.sh source
if grep -q 'No NVIDIA GPU detected' "$INSTALL_SH"; then
echo " PASS: CPU hint message found in install.sh"
PASS=$((PASS + 1))
else
echo " FAIL: CPU hint message not found in install.sh"
FAIL=$((FAIL + 1))
fi
if grep -q -- '--no-torch' "$INSTALL_SH"; then
echo " PASS: --no-torch appears in install.sh"
PASS=$((PASS + 1))
else
echo " FAIL: --no-torch not found in install.sh"
FAIL=$((FAIL + 1))
fi
echo ""
echo "Results: $PASS passed, $FAIL failed"
[ "$FAIL" -eq 0 ] || exit 1