unsloth/studio/backend/utils/datasets/format_conversion.py
Daniel Han e9ac785346
fix: install.sh Mac Intel compatibility + Studio no-torch support (#4624)
* fix: install.sh Mac Intel compatibility + Studio no-torch support (#4621)

On Intel Macs (x86_64), PyTorch has no wheels for torch >= 2.3, so the
installer crashes. Even when torch is absent, Studio crashes on startup
because two files have bare top-level torch imports.

Studio's GGUF inference (llama.cpp) does not need PyTorch. Training and
HF-inference already isolate torch to subprocesses. Only 2 files in the
server startup chain had top-level torch imports preventing startup.

Changes:
- install.sh: detect architecture, default to Python 3.12 on Intel Mac,
  skip torch install, add Python 3.13.8 guard for arm64, pass
  UNSLOTH_NO_TORCH env var to setup.sh
- data_collators.py: remove unused `import torch` (no torch.* refs)
- chat_templates.py: lazy-import IterableDataset into function bodies
- install_python_stack.py: add IS_MACOS/NO_TORCH constants, skip
  torch-dependent packages, skip overrides.txt, skip triton on macOS

No existing working flow changes. Linux/WSL and macOS arm64 behavior is
identical.

* tests: add test suite for Mac Intel compat + no-torch mode

Shell tests (test_mac_intel_compat.sh):
- version_ge edge cases (9 tests)
- Architecture detection for Darwin x86_64/arm64, Linux x86_64/aarch64
- get_torch_index_url returns cpu on simulated Darwin
- UNSLOTH_NO_TORCH propagation to both setup.sh branches

Python unit tests (test_no_torch_filtering.py):
- _filter_requirements with NO_TORCH_SKIP_PACKAGES
- NO_TORCH env var parsing (true/1/TRUE/false/0/unset)
- IS_MACOS constant check
- Overrides skip and triton macOS skip guards
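
The env var parsing these tests exercise could look roughly like the sketch below. It is not the installer's actual code: the helper name `parse_no_torch` and the exact set of accepted truthy strings are assumptions, and only the tested inputs (true/1/TRUE/false/0/unset) come from the list above.

```python
import os


def parse_no_torch(environ=None):
    """Hypothetical helper: interpret UNSLOTH_NO_TORCH as a boolean."""
    env = os.environ if environ is None else environ
    raw = env.get("UNSLOTH_NO_TORCH", "")
    # Assumption: only "1" and "true" (any case) enable no-torch mode;
    # everything else, including an unset variable, leaves it disabled.
    return raw.strip().lower() in ("1", "true")
```

Passing a plain dict instead of `os.environ` keeps the helper easy to test without mutating the process environment.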

Python import tests (test_studio_import_no_torch.py):
- data_collators.py loads in isolated no-torch venv
- chat_templates.py has no top-level torch imports
- Negative control confirms import torch fails without torch

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests: add E2E sandbox tests for Mac Intel no-torch mode

Replace static/synthetic test stubs with real sandbox tests:

- Shell: E2E uv venv creation at Python 3.12, mock uv shim to verify
  torch install is skipped when MAC_INTEL=true, dynamic env propagation
  test for UNSLOTH_NO_TORCH in both local and non-local install paths
- Python filtering: test real extras.txt and extras-no-deps.txt with
  NO_TORCH_SKIP_PACKAGES, subprocess mock of install_python_stack() for
  5 platform configs (NO_TORCH+macOS, Windows+NO_TORCH, normal Linux,
  Windows-only, macOS-only), VCS URL and env marker edge cases
- Python imports: parametrized Python 3.12+3.13 venv fixture, dataclass
  instantiation for all 3 collator classes, chat_templates.py exec with
  stubs, negative controls proving import torch and torchao install fail
  in no-torch venvs
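
For readers unfamiliar with the filtering being tested, here is a minimal sketch of skipping torch-dependent lines in a requirements file. The function names, the name-parsing rules, and the contents of the skip set are assumptions for illustration; the real `_filter_requirements` and `NO_TORCH_SKIP_PACKAGES` may behave differently.

```python
import re

# Illustrative skip set (assumption); the real NO_TORCH_SKIP_PACKAGES may differ.
NO_TORCH_SKIP_PACKAGES = {"openai-whisper", "transformers-cfg"}


def requirement_name(line):
    """Best-effort project name for one requirements line, or None."""
    # Drop trailing comments; a '#' inside a URL fragment is not a comment.
    line = re.sub(r"(^|\s)#.*$", "", line).strip()
    if not line or line.startswith("-"):
        return None  # blank line, comment, or pip option such as "-r other.txt"
    # VCS URLs may carry the project name in an "#egg=" fragment.
    egg = re.search(r"#egg=([A-Za-z0-9._-]+)", line)
    match = egg or re.match(r"([A-Za-z0-9][A-Za-z0-9._-]*)", line)
    if match is None:
        return None
    # Normalize as PEP 503 does: lowercase, runs of "-", "_", "." become "-".
    return re.sub(r"[-_.]+", "-", match.group(1)).lower()


def filter_requirements(lines, skip):
    """Keep every line whose project name is not in the skip set."""
    return [line for line in lines if requirement_name(line) not in skip]
```

Lines that cannot be parsed into a name (comments, options) are kept unchanged, so the filtered file stays valid input for the installer.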

91 total tests, all passing.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: address reviewer findings for Intel Mac no-torch mode

P1 fixes:
- Auto-infer NO_TORCH in install_python_stack.py via platform.machine()
  so `unsloth studio update` preserves GGUF-only mode without needing
  the UNSLOTH_NO_TORCH env var (6/10 reviewers)
- Add openai-whisper and transformers-cfg to NO_TORCH_SKIP_PACKAGES
  since both have unconditional torch dependencies (4/10 reviewers)
- Skip unsloth-zoo on Intel Mac --local installs (depends on torch)
  in both migrated and fresh install paths (1/10)
- Recreate stale 3.13 venvs as 3.12 on Intel Mac re-runs (1/10)
- Detect Apple Silicon under Rosetta via sysctl hw.optional.arm64
  and warn user to use native arm64 terminal (1/10)
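
The platform-based auto-detection in the first P1 item can be sketched as follows. The function name and signature are hypothetical, and the check shown (macOS plus `x86_64`) is an assumption about what "Intel Mac" means here, not the code in install_python_stack.py.

```python
import platform


def infer_no_torch(system=None, machine=None):
    """Hypothetical: True on Intel macOS, where no-torch mode is the default."""
    system = platform.system() if system is None else system
    machine = platform.machine() if machine is None else machine
    # Assumption: "Intel Mac" means a Darwin kernel reporting an x86_64 CPU.
    return system == "Darwin" and machine == "x86_64"
```

A Python process running under Rosetta on Apple Silicon typically also reports `x86_64`, which is presumably why the commit adds a separate `sysctl hw.optional.arm64` check rather than relying on this test alone.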

P2 fixes:
- Wire new test files into tests/run_all.sh (4/10 reviewers)
- Add update-path tests (skip_base=False) for Intel Mac
- Add _infer_no_torch tests for platform auto-detection

P3 fixes:
- Fix macOS progress bar total (triton step skipped but was counted)
- Fix temp file leak when Windows + NO_TORCH filters stack

All tests pass: 30 shell, 66 Python (96 total).

* feat: add --python override flag to install.sh

Lets users force a specific Python version, e.g. ./install.sh --python 3.12.
Addresses M2 Mac users whose systems resolve to a problematic 3.13.x patch.
When --python is set, the Intel Mac stale-venv guard and 3.13.8 auto-downgrade
are skipped so the user's choice is respected.

* tests: add comprehensive E2E sandbox tests for no-torch mode

Add test_e2e_no_torch_sandbox.py with 7 test groups (43 tests total)
covering the full no-torch import chain, edge cases, and install logic:

- Group 1: BEFORE vs AFTER import chain comparison (proves the bug
  existed and the fix works by synthetically prepending top-level torch
  imports)
- Group 2: Dataclass instantiation without torch
- Group 3: Edge cases with broken/fake torch modules on sys.path
- Group 4: Hardware detection fallback to CPU without torch
- Group 5: install.sh flag parsing, version resolution, arch detection
- Group 6: install_python_stack.py NO_TORCH filtering
- Group 7: Live server startup without torch (marked @server, skipped
  when studio venv is unavailable)

All 43 tests pass on both Python 3.12 and 3.13 isolated venvs.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: add --no-torch flag to install.sh/ps1, fix lazy import bug in dataset formatting

- Fix chat_templates.py: narrow torch IterableDataset import into inner
  try/except ImportError so dataset.map() works without torch installed
- Fix format_conversion.py: same lazy import fix for convert_chatml_to_alpaca
  and convert_alpaca_to_chatml
- Add --no-torch flag to install.sh with unified SKIP_TORCH variable
  (driven by --no-torch flag OR MAC_INTEL auto-detection)
- Add --no-torch flag to install.ps1 with $SkipTorch variable
- Print CPU hint when no GPU detected and --no-torch not set
- Replace MAC_INTEL guards with SKIP_TORCH in torch install sections
- Update shell tests (40 pass) and Python tests (90 pass)
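
The lazy import fix described in the first two items follows a common pattern, sketched below in isolation. The helper name `is_torch_iterable` is an assumption made for illustration; in format_conversion.py the same try/except sits inline at the top of each conversion function.

```python
def is_torch_iterable(dataset):
    """Return True only if torch is importable and dataset is a torch IterableDataset."""
    try:
        # Imported lazily so that loading this module never requires torch.
        from torch.utils.data import IterableDataset
    except ImportError:
        # No torch installed: the object cannot be a torch IterableDataset.
        return False
    return isinstance(dataset, IterableDataset)
```

Keeping only the import inside the `try` block means an unrelated `ImportError` raised later, for example inside `dataset.map()`, is not swallowed by accident.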

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: address reviewer findings for --no-torch installer paths

- Fix migrated-env branch in install.sh and install.ps1: check
  SKIP_TORCH first, then branch on STUDIO_LOCAL_INSTALL. Previously
  SKIP_TORCH+non-local fell into else and installed unsloth-zoo (which
  depends on torch), defeating --no-torch mode.
- Fix $env:UNSLOTH_NO_TORCH leak in install.ps1: always set to "true"
  or "false" instead of only setting on the true branch. Prevents stale
  no-torch state from leaking across runs in the same PS session.
- Fix install_python_stack.py update path: add NO_TORCH guard around
  base.txt install so unsloth studio update does not reinstall
  unsloth-zoo (which depends on torch) in no-torch mode.

* fix: install unsloth + unsloth-zoo with --no-deps in no-torch mode

Instead of skipping unsloth-zoo entirely (which breaks unsloth's
dependency on it), install both packages with --no-deps so they are
present but torch is not pulled in transitively. Applied consistently
across all no-torch paths: migrated-env, fresh-local, fresh-non-local
in install.sh, install.ps1, and install_python_stack.py.

* chore: temporarily remove test files (will be added in a follow-up)

* refactor: deduplicate SKIP_TORCH conditional branches in installers

Collapse if/else blocks that differ only by --no-deps into a single
branch with a conditional flag variable. Applied to migrated-env and
fresh-local paths in install.sh, install.ps1, and install_python_stack.py.
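
The deduplication idea, one branch plus a conditional flag variable, can be shown in Python as a rough sketch. The function name and package list are hypothetical; only the `--no-deps` flag and the use of `uv pip install` come from the commit message.

```python
def build_install_cmd(packages, skip_torch):
    """Hypothetical: build one uv command, adding --no-deps only in no-torch mode."""
    # An empty list expands to nothing, so a single code path serves both modes.
    no_deps_arg = ["--no-deps"] if skip_torch else []
    return ["uv", "pip", "install", *no_deps_arg, *packages]
```

The resulting list can be passed to `subprocess.run` as-is, which avoids maintaining two nearly identical command invocations.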

* fix: apply --no-deps to fresh non-local --no-torch install path

The non-local else branch was missing $_no_deps_arg/$noDepsArg, so
uv pip install unsloth would resolve torch from PyPI metadata (the
published unsloth package still declares torch as a hard dep). Now
--no-deps is applied consistently to all SKIP_TORCH code paths.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 02:09:21 -07:00


# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""
Format conversion utilities for dataset processing.
This module contains functions for converting between dataset formats
(Alpaca, ShareGPT, ChatML) and standardizing chat formats.
"""
import os
from datasets import IterableDataset
from loggers import get_logger
logger = get_logger(__name__)


def standardize_chat_format(
    dataset,
    tokenizer = None,
    aliases_for_system = [
        "system",
    ],
    aliases_for_user = [
        "user",
        "human",
        "input",
    ],
    aliases_for_assistant = [
        "gpt",
        "assistant",
        "output",
    ],
    batch_size = 1000,
    num_proc = None,
):
    """
    Our own standardization function that handles BOTH messages and conversations.
    Converts non-standard role names and keys to standard format.
    """
    import collections
    import itertools

    from datasets import IterableDataset

    # Check if vision tokenizer is used
    is_vlm = False
    if tokenizer is not None:
        if hasattr(tokenizer, "image_processor") or hasattr(tokenizer, "tokenizer"):
            is_vlm = True

    column_names = set(next(iter(dataset)).keys())

    # Check for both 'conversations' and 'messages'
    chat_column = None
    if "conversations" in column_names:
        chat_column = "conversations"
    elif "messages" in column_names:
        chat_column = "messages"
    elif "texts" in column_names:
        chat_column = "texts"
    else:
        return dataset  # No chat column found

    # Inspect structure
    examples = itertools.islice(dataset, 10)
    uniques = collections.defaultdict(list)
    for example in examples:
        for message in example[chat_column]:
            for key, value in message.items():
                if type(value) is not str:
                    continue  # Skip non-string values
                uniques[key].append(value)

    if len(uniques.keys()) != 2:
        return dataset  # Unexpected structure

    keys = list(uniques.keys())
    length_first = len(set(uniques[keys[0]]))
    length_second = len(set(uniques[keys[1]]))

    # Determine which is role and which is content
    if length_first < length_second:
        role_key = keys[0]
        content_key = keys[1]
    else:
        role_key = keys[1]
        content_key = keys[0]

    # Mapping for aliases
    aliases_mapping = {}
    for x in aliases_for_system:
        aliases_mapping[x] = "system"
    for x in aliases_for_user:
        aliases_mapping[x] = "user"
    for x in aliases_for_assistant:
        aliases_mapping[x] = "assistant"

    def _standardize_dataset(examples):
        convos = examples[chat_column]
        all_convos = []
        for convo in convos:
            new_convo = []
            for message in convo:
                # Get original role and content
                original_role = message.get(role_key, "")
                original_content = message.get(content_key, "")

                # Map to standard role name
                standard_role = aliases_mapping.get(original_role, original_role)

                # Handle VLM format
                if is_vlm:
                    original_content = [{"type": "text", "text": original_content}]

                # Create dict with EXPLICIT ORDER
                new_message = {"role": standard_role, "content": original_content}
                new_convo.append(new_message)
            all_convos.append(new_convo)
        return {chat_column: all_convos}

    dataset_map_kwargs = {
        "batched": True,
        "batch_size": batch_size,
    }

    if not isinstance(dataset, IterableDataset):
        from utils.hardware import dataset_map_num_proc

        if num_proc is None or type(num_proc) is not int:
            num_proc = dataset_map_num_proc()
        else:
            num_proc = dataset_map_num_proc(num_proc)

        dataset_map_kwargs["num_proc"] = num_proc
        dataset_map_kwargs["desc"] = "Standardizing chat format"

    return dataset.map(_standardize_dataset, **dataset_map_kwargs)


def convert_chatml_to_alpaca(dataset, batch_size = 1000, num_proc = None):
    """
    Converts ChatML format (messages OR conversations) to Alpaca format.
    Handles both standardized and ShareGPT formats.

    Supports:
    - "messages" or "conversations" column
    - "role"/"content" (standard) or "from"/"value" (ShareGPT)
    """
    try:
        from torch.utils.data import IterableDataset

        _is_torch_iterable = isinstance(dataset, IterableDataset)
    except ImportError:
        _is_torch_iterable = False

    def _convert(examples):
        # Auto-detect which column name is used
        chatml_data = (
            examples.get("messages")
            or examples.get("conversations")
            or examples.get("texts")
        )
        if chatml_data is None:
            raise ValueError(
                "No 'messages' or 'conversations' or 'texts' column found."
            )

        instructions = []
        outputs = []
        inputs = []
        for convo in chatml_data:
            instruction = ""
            output = ""
            for msg in convo:
                # Handle both standard and ShareGPT formats
                role = msg.get("role") or msg.get("from")
                content = msg.get("content") or msg.get("value")

                # Get first user message as instruction
                if role in ["user", "human", "input"] and not instruction:
                    instruction = content
                # Get first assistant message as output
                elif role in ["assistant", "gpt", "output"] and not output:
                    output = content
                    break  # Stop after first assistant response

            instructions.append(instruction)
            inputs.append("")  # Alpaca typically has empty input
            outputs.append(output)

        return {"instruction": instructions, "input": inputs, "output": outputs}

    dataset_map_kwargs = {
        "batched": True,
        "batch_size": batch_size,
    }

    if not _is_torch_iterable:
        from utils.hardware import dataset_map_num_proc

        if num_proc is None or type(num_proc) is not int:
            num_proc = dataset_map_num_proc()
        else:
            num_proc = dataset_map_num_proc(num_proc)

        dataset_map_kwargs["num_proc"] = num_proc
        dataset_map_kwargs["desc"] = "Converting ChatML to Alpaca format"

    return dataset.map(_convert, **dataset_map_kwargs)


def convert_alpaca_to_chatml(dataset, batch_size = 1000, num_proc = None):
    """
    Converts Alpaca format to ChatML format.
    Output format: Uses 'conversations' column with standard 'role'/'content' structure.
    """
    try:
        from torch.utils.data import IterableDataset

        _is_torch_iterable = isinstance(dataset, IterableDataset)
    except ImportError:
        _is_torch_iterable = False

    def _convert(examples):
        conversations = []
        for i in range(len(examples["instruction"])):
            instruction = examples["instruction"][i]
            input_text = examples.get("input", [""] * len(examples["instruction"]))[i]
            output = examples["output"][i]

            # Combine instruction and input (if exists) for user message
            if input_text and input_text.strip():
                user_content = f"{instruction}\n\n{input_text}".strip()
            else:
                user_content = instruction

            # Build conversation in standard ChatML format
            convo = [
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": output},
            ]
            conversations.append(convo)

        return {"conversations": conversations}

    dataset_map_kwargs = {
        "batched": True,
        "batch_size": batch_size,
    }

    if not _is_torch_iterable:
        from utils.hardware import dataset_map_num_proc

        if num_proc is None or type(num_proc) is not int:
            num_proc = dataset_map_num_proc()
        else:
            num_proc = dataset_map_num_proc(num_proc)

        dataset_map_kwargs["num_proc"] = num_proc
        dataset_map_kwargs["desc"] = "Converting Alpaca to ChatML format"

    return dataset.map(_convert, **dataset_map_kwargs)


def _format_eta(seconds):
    """Format seconds into a human-readable ETA string."""
    if seconds < 60:
        return f"{seconds:.0f}s"
    elif seconds < 3600:
        m, s = divmod(int(seconds), 60)
        return f"{m}m {s}s"
    else:
        h, remainder = divmod(int(seconds), 3600)
        m, _ = divmod(remainder, 60)
        return f"{h}h {m}m"


def convert_to_vlm_format(
    dataset,
    instruction = None,
    text_column = "text",
    image_column = "image",
    dataset_name = None,
    progress_callback = None,
):
    """
    Converts simple {image, text} format to VLM messages format.

    Returns a LIST, not a HuggingFace Dataset (to preserve PIL Images).

    For URL-based image datasets, runs a 200-sample parallel probe first to
    estimate download speed and failure rate, then reports time estimate or
    warning through progress_callback before proceeding with the full conversion.

    Args:
        progress_callback: Optional callable(status_message=str) to report
            progress to the training overlay.

    Returns:
        list: List of dicts with 'messages' field
    """
    from PIL import Image
    from .vlm_processing import generate_smart_vlm_instruction

    def _notify(msg):
        """Send status update to the training overlay if callback is available."""
        if progress_callback:
            progress_callback(status_message = msg)

    # Generate smart instruction if not provided
    if instruction is None:
        instruction_info = generate_smart_vlm_instruction(
            dataset,
            text_column = text_column,
            image_column = image_column,
            dataset_name = dataset_name,
        )
        instruction = instruction_info["instruction"]
        instruction_column = instruction_info.get("instruction_column")
        uses_dynamic = instruction_info["uses_dynamic_instruction"]

        logger.info(
            f"📝 Auto-detected instruction type: {instruction_info['instruction_type']}"
        )
        logger.info(f"📝 Confidence: {instruction_info['confidence']:.2f}")
        if not uses_dynamic:
            logger.info(f"📝 Using instruction: '{instruction}'")
        else:
            logger.info(
                f"📝 Using dynamic instructions from column: '{instruction_column}'"
            )
    else:
        instruction_column = None
        uses_dynamic = False

    def _convert_single_sample(sample):
        """Convert a single sample to VLM format."""
        # Get image (might be PIL Image, local path, URL, or bare filename)
        image_data = sample[image_column]
        if isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                import fsspec
                from io import BytesIO

                with fsspec.open(image_data, "rb", expand = True) as f:
                    image_data = Image.open(BytesIO(f.read())).convert("RGB")
            elif _image_lookup is not None and image_data in _image_lookup:
                # Bare filename → resolve via HF repo lookup
                from huggingface_hub import hf_hub_download

                local_path = hf_hub_download(
                    dataset_name,
                    _image_lookup[image_data],
                    repo_type = "dataset",
                )
                image_data = Image.open(local_path).convert("RGB")
            else:
                image_data = Image.open(image_data).convert("RGB")

        # Get text (if list of strings, pick a random one — e.g. multiple captions)
        text_data = sample[text_column]
        if isinstance(text_data, list) and len(text_data) > 0:
            import random

            text_data = random.choice(text_data)

        # Get instruction (static or dynamic)
        if uses_dynamic and instruction_column:
            current_instruction = sample[instruction_column]
        else:
            current_instruction = instruction

        # Build VLM messages - simple structure
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": current_instruction},
                    {"type": "image", "image": image_data},  # PIL object
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": text_data}]},
        ]

        # Return dict with messages
        return {"messages": messages}

    total = len(dataset)
    first_image = next(iter(dataset))[image_column]
    has_urls = isinstance(first_image, str) and first_image.startswith(
        ("http://", "https://")
    )

    # ── Bare-filename detection: images stored as filenames (e.g. "img_001.png")
    #    that don't exist locally. Build a basename→repo_path lookup so we can
    #    resolve them via hf_hub_download during conversion.
    _image_lookup = None
    _IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ".tiff")
    if (
        not has_urls
        and isinstance(first_image, str)
        and not os.path.exists(first_image)
        and dataset_name
    ):
        try:
            from huggingface_hub import HfApi

            _notify("Resolving image filenames from HF repo...")
            logger.info(
                f"🔍 Image column contains bare filenames (e.g. '{first_image}') — building repo lookup..."
            )
            repo_files = HfApi().list_repo_files(dataset_name, repo_type = "dataset")
            _image_lookup = {
                os.path.basename(f): f
                for f in repo_files
                if any(f.lower().endswith(ext) for ext in _IMAGE_EXTS)
            }
            if first_image in _image_lookup:
                logger.info(
                    f"✅ Matched {len(_image_lookup)} image files in repo (e.g. '{first_image}' → '{_image_lookup[first_image]}')"
                )
            else:
                logger.info(
                    f"⚠️ Built lookup with {len(_image_lookup)} images but '{first_image}' not found — falling back to local open"
                )
                _image_lookup = None
        except Exception as e:
            logger.info(f"⚠️ Failed to build HF repo image lookup: {e}")
            _image_lookup = None

    # ── URL probe: 200 samples with parallel workers to estimate speed + failure rate ──
    PROBE_SIZE = 200
    MAX_FAIL_RATE = 0.3

    if has_urls and total > PROBE_SIZE:
        import time
        from concurrent.futures import ThreadPoolExecutor, as_completed
        from utils.hardware import safe_thread_num_proc

        num_workers = safe_thread_num_proc()
        _notify(f"Probing {PROBE_SIZE} image URLs with {num_workers} workers...")
        logger.info(
            f"🔍 Probing {PROBE_SIZE}/{total} image URLs with {num_workers} workers..."
        )

        probe_samples = [dataset[i] for i in range(PROBE_SIZE)]
        probe_ok = 0
        probe_fail = 0
        probe_start = time.time()

        with ThreadPoolExecutor(max_workers = num_workers) as executor:
            futures = {
                executor.submit(_convert_single_sample, s): s for s in probe_samples
            }
            for future in as_completed(futures):
                try:
                    future.result()
                    probe_ok += 1
                except Exception:
                    probe_fail += 1

        probe_elapsed = time.time() - probe_start
        probe_total = probe_ok + probe_fail
        fail_rate = probe_fail / probe_total if probe_total > 0 else 0
        throughput = probe_total / probe_elapsed if probe_elapsed > 0 else 0

        if fail_rate >= MAX_FAIL_RATE:
            issues = [
                f"{fail_rate:.0%} of the first {PROBE_SIZE} image URLs failed to download ({probe_fail}/{probe_total})",
                "Images are external URLs, not embedded in the dataset",
            ]
            # Try LLM-friendly warning
            friendly = None
            try:
                from .llm_assist import llm_generate_dataset_warning

                friendly = llm_generate_dataset_warning(
                    issues,
                    dataset_name = dataset_name,
                    modality = "vision",
                    column_names = [image_column, text_column],
                )
            except Exception:
                pass
            msg = friendly or (
                f"⚠️ {fail_rate:.0%} of the first {PROBE_SIZE} images failed to download "
                f"({probe_fail}/{probe_total}). "
                "This dataset has too many broken or unreachable image URLs. "
                "Consider using a dataset with embedded images instead."
            )
            logger.info(msg)
            _notify(msg)
            raise ValueError(msg)

        # Estimate total time for remaining samples
        remaining = total - PROBE_SIZE
        estimated_seconds = remaining / throughput if throughput > 0 else 0
        eta_str = _format_eta(estimated_seconds)

        info_msg = (
            f"Downloading {total:,} images ({num_workers} workers, ~{throughput:.1f} img/s). "
            f"Estimated time: ~{eta_str}"
        )
        if probe_fail > 0:
            info_msg += f" | {fail_rate:.0%} broken URLs will be skipped"

        logger.info(
            f"✅ Probe passed: {probe_ok}/{probe_total} ok, {probe_fail} failed ({fail_rate:.0%}), {throughput:.1f} img/s"
        )
        logger.info(f"⏱️ Estimated time for {total:,} samples: ~{eta_str}")
        _notify(info_msg)

    # ── Full conversion with progress ──
    from tqdm import tqdm

    logger.info(f"🔄 Converting {total} samples to VLM format...")
    converted_list = []
    failed_count = 0

    if has_urls:
        # Parallel conversion for URL-based datasets
        import time
        from concurrent.futures import ThreadPoolExecutor, as_completed
        from utils.hardware import safe_thread_num_proc

        num_workers = safe_thread_num_proc()
        batch_size = 500
        start_time = time.time()

        for batch_start in range(0, total, batch_size):
            batch_end = min(batch_start + batch_size, total)
            batch_samples = [dataset[i] for i in range(batch_start, batch_end)]

            with ThreadPoolExecutor(max_workers = num_workers) as executor:
                futures = {
                    executor.submit(_convert_single_sample, s): i
                    for i, s in enumerate(batch_samples)
                }
                batch_results = [None] * len(batch_samples)
                for future in as_completed(futures):
                    idx = futures[future]
                    try:
                        batch_results[idx] = future.result()
                    except Exception as e:
                        failed_count += 1
                        if failed_count == 1:
                            # Report the first failure (stdout and log) to aid debugging
                            print(
                                f"⚠️ First VLM conversion failure: {type(e).__name__}: {e}"
                            )
                            logger.info(
                                f"⚠️ First VLM conversion failure: {type(e).__name__}: {e}"
                            )

            converted_list.extend(r for r in batch_results if r is not None)

            # Progress update every batch
            elapsed = time.time() - start_time
            done = batch_end
            rate = done / elapsed if elapsed > 0 else 0
            remaining_time = (total - done) / rate if rate > 0 else 0
            eta_str = _format_eta(remaining_time)
            progress_msg = f"Downloading images: {done:,}/{total:,} ({done*100//total}%) | ~{eta_str} remaining | {failed_count} skipped"
            logger.info(
                f" [{done}/{total}] {rate:.1f} img/s, {failed_count} failed, ETA {eta_str}"
            )
            _notify(progress_msg)
    else:
        # Sequential conversion for local/embedded images (fast, no I/O bottleneck)
        pbar = tqdm(dataset, total = total, desc = "Converting VLM samples", unit = "sample")
        for sample in pbar:
            try:
                converted_list.append(_convert_single_sample(sample))
            except Exception as e:
                failed_count += 1
                if failed_count == 1:
                    # Report the first failure (stdout and log) to aid debugging
                    print(f"⚠️ First VLM conversion failure: {type(e).__name__}: {e}")
                    logger.info(
                        f"⚠️ First VLM conversion failure: {type(e).__name__}: {e}"
                    )
            pbar.set_postfix(ok = len(converted_list), failed = failed_count, refresh = False)
        pbar.close()

    if failed_count > 0:
        fail_rate = failed_count / total
        logger.info(
            f"⚠️ Skipped {failed_count}/{total} ({fail_rate:.0%}) samples with broken/unreachable images"
        )

        # For datasets that skipped the probe (small URL datasets), check fail rate now
        if has_urls and fail_rate >= MAX_FAIL_RATE:
            issues = [
                f"{fail_rate:.0%} of images failed to download ({failed_count}/{total})",
                "Images are external URLs, not embedded in the dataset",
            ]
            friendly = None
            try:
                from .llm_assist import llm_generate_dataset_warning

                friendly = llm_generate_dataset_warning(
                    issues,
                    dataset_name = dataset_name,
                    modality = "vision",
                    column_names = [image_column, text_column],
                )
            except Exception:
                pass
            msg = friendly or (
                f"⚠️ {fail_rate:.0%} of images failed to download ({failed_count}/{total}). "
                "This dataset has too many broken or unreachable image URLs. "
                "Consider using a dataset with embedded images instead."
            )
            _notify(msg)
            raise ValueError(msg)

    if len(converted_list) == 0:
        issues = [
            f"All {total} samples failed during VLM conversion — no usable images found",
            f"Image column '{image_column}' may contain URLs that are no longer accessible, "
            "or local file paths that don't exist",
        ]
        friendly = None
        try:
            from .llm_assist import llm_generate_dataset_warning

            friendly = llm_generate_dataset_warning(
                issues,
                dataset_name = dataset_name,
                modality = "vision",
                column_names = [image_column, text_column],
            )
        except Exception:
            pass
        raise ValueError(
            friendly
            or (
                f"All {total} samples failed during VLM conversion — no usable images found. "
                "This dataset may contain only image URLs that are no longer accessible."
            )
        )

    logger.info(f"✅ Converted {len(converted_list)}/{total} samples")
    _notify(f"Converted {len(converted_list):,}/{total:,} images successfully")

    # Return list, NOT Dataset
    return converted_list


def convert_sharegpt_with_images_to_vlm_format(
    dataset,
    image_column = "image",
    messages_column = "conversations",
    dataset_name = None,
    progress_callback = None,
):
    """
    Converts ShareGPT/ChatML datasets that have a separate image column and
    ``<image>`` placeholders inside the conversation text.

    Example input::

        {
            "image": "sam/images/sa_545504.jpg",
            "conversations": [
                {"from": "human", "value": "<image>\\nWhat is this photo about?"},
                {"from": "gpt", "value": "The image captures..."}
            ]
        }

    Returns a list of dicts in standard VLM messages format (PIL Images inline).
    """
    from PIL import Image
    from tqdm import tqdm

    _IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ".tiff")
    _ROLE_MAP = {
        "human": "user",
        "user": "user",
        "gpt": "assistant",
        "assistant": "assistant",
        "system": "system",
    }

    def _notify(msg):
        if progress_callback:
            progress_callback(status_message = msg)

    # ── Resolve image loading strategy (same 3-tier as convert_to_vlm_format) ──
    total = len(dataset)
    first_image = next(iter(dataset))[image_column]
    _image_lookup = None
    if (
        isinstance(first_image, str)
        and not first_image.startswith(("http://", "https://"))
        and not os.path.exists(first_image)
        and dataset_name
    ):
        try:
            from huggingface_hub import HfApi

            _notify("Resolving image filenames from HF repo...")
            logger.info(
                f"🔍 Image column contains bare filenames (e.g. '{first_image}') — building repo lookup..."
            )
            repo_files = HfApi().list_repo_files(dataset_name, repo_type = "dataset")
            _image_lookup = {
                os.path.basename(f): f
                for f in repo_files
                if any(f.lower().endswith(ext) for ext in _IMAGE_EXTS)
            }
            # Also add the full relative paths as keys (for paths like "sam/images/sa_545504.jpg")
            for f in repo_files:
                if any(f.lower().endswith(ext) for ext in _IMAGE_EXTS):
                    _image_lookup[f] = f
            if first_image in _image_lookup:
                logger.info(
                    f"✅ Matched {len(_image_lookup)} image files in repo (e.g. '{first_image}' → '{_image_lookup[first_image]}')"
                )
            else:
                logger.info(
                    f"⚠️ Built lookup with {len(_image_lookup)} images but '{first_image}' not found — falling back to local open"
                )
                _image_lookup = None
        except Exception as e:
            logger.info(f"⚠️ Failed to build HF repo image lookup: {e}")
            _image_lookup = None

    def _resolve_image(image_data):
        """Resolve image data to a PIL Image object."""
        if hasattr(image_data, "size") and hasattr(image_data, "mode"):
            return image_data  # Already PIL
        if isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                import fsspec
                from io import BytesIO

                with fsspec.open(image_data, "rb", expand = True) as f:
                    return Image.open(BytesIO(f.read())).convert("RGB")
            elif _image_lookup is not None and image_data in _image_lookup:
                from huggingface_hub import hf_hub_download

                local_path = hf_hub_download(
                    dataset_name,
                    _image_lookup[image_data],
                    repo_type = "dataset",
                )
                return Image.open(local_path).convert("RGB")
            else:
                return Image.open(image_data).convert("RGB")
        if isinstance(image_data, dict) and (
            "bytes" in image_data or "path" in image_data
        ):
            if image_data.get("bytes"):
                from io import BytesIO

                return Image.open(BytesIO(image_data["bytes"])).convert("RGB")
            if image_data.get("path"):
                return Image.open(image_data["path"]).convert("RGB")
        raise ValueError(f"Cannot resolve image: {type(image_data)}")

    def _convert_single_sample(sample):
        """Convert a single ShareGPT+image sample to standard VLM format."""
        pil_image = _resolve_image(sample[image_column])
        conversation = sample[messages_column]

        new_messages = []
        for msg in conversation:
            role_raw = msg.get("from") or msg.get("role", "user")
            role = _ROLE_MAP.get(role_raw.lower(), role_raw.lower())
            text = msg.get("value") or msg.get("content") or ""

            # Split on <image> to interleave text and image content blocks
            if "<image>" in text:
                parts = text.split("<image>")
                content = []
                for i, part in enumerate(parts):
                    part = part.strip()
                    if part:
                        content.append({"type": "text", "text": part})
                    if i < len(parts) - 1:
                        content.append({"type": "image", "image": pil_image})
                # If <image> was the entire text, content might just be the image
                if not content:
                    content.append({"type": "image", "image": pil_image})
            else:
                content = [{"type": "text", "text": text}]

            new_messages.append({"role": role, "content": content})

        return {"messages": new_messages}

    # ── Full conversion with progress ──
    logger.info(f"🔄 Converting {total} samples from ShareGPT+image format...")
    converted_list = []
    failed_count = 0

    pbar = tqdm(dataset, total = total, desc = "Converting ShareGPT+image", unit = "sample")
    for sample in pbar:
        try:
            converted_list.append(_convert_single_sample(sample))
        except Exception as e:
            failed_count += 1
            if failed_count == 1:
                logger.info(f"⚠️ First conversion failure: {type(e).__name__}: {e}")
        pbar.set_postfix(ok = len(converted_list), failed = failed_count, refresh = False)
    pbar.close()

    if failed_count > 0:
        logger.info(
            f"⚠️ Skipped {failed_count}/{total} ({failed_count*100//total}%) samples"
        )

    if len(converted_list) == 0:
        raise ValueError(
            f"All {total} samples failed during ShareGPT+image conversion — "
            "no usable samples found."
        )

    logger.info(f"✅ Converted {len(converted_list)}/{total} samples")
    _notify(f"Converted {len(converted_list):,}/{total:,} samples successfully")
    return converted_list


def convert_llava_to_vlm_format(dataset):
    """
    Converts Llava format to standard VLM format.

    Llava format:
    - messages: [{'content': [{'type': 'image', 'index': 0}, {'type': 'text', 'text': '...'}]}]
    - images: [PIL_Image1, PIL_Image2, ...]

    Standard VLM format:
    - messages: [{'content': [{'type': 'image', 'image': PIL_Image}, {'type': 'text', 'text': '...'}]}]
    """
    from PIL import Image

    logger.info(
        f"🔄 Converting {len(dataset)} samples from Llava format to standard VLM format..."
    )

    def _convert_single_sample(sample):
        """Convert a single llava sample to standard VLM format."""
        messages = sample["messages"]
        images = sample.get("images", [])

        # Process each message
        new_messages = []
        for msg in messages:
            new_content = []
            for item in msg["content"]:
                if item["type"] == "image":
                    # Replace index with actual PIL image
                    if "index" in item and item["index"] is not None:
                        img_idx = item["index"]
                        if img_idx < len(images):
                            pil_image = images[img_idx]
                            # Ensure it's PIL
                            if isinstance(pil_image, str):
                                pil_image = Image.open(pil_image).convert("RGB")
                            new_content.append(
                                {
                                    "type": "image",
                                    "image": pil_image,  # Actual PIL object
                                }
                            )
                    else:
                        # No index, try to use first image
                        if len(images) > 0:
                            pil_image = images[0]
                            if isinstance(pil_image, str):
                                pil_image = Image.open(pil_image).convert("RGB")
                            new_content.append({"type": "image", "image": pil_image})
                elif item["type"] == "text":
                    # Keep text as-is (only type + text)
                    new_content.append({"type": "text", "text": item.get("text", "")})

            new_messages.append({"role": msg["role"], "content": new_content})

        return {"messages": new_messages}

    # Convert using list comprehension
    converted_list = [_convert_single_sample(sample) for sample in dataset]
    logger.info(f"✅ Converted {len(converted_list)} samples")
    return converted_list