Merge pull request #3612 from Vangmay/feature/raw-text-dataprep

Feature/raw text dataprep
This commit is contained in:
Daniel Han 2026-01-08 03:38:15 -08:00 committed by GitHub
commit 0f07e36813
5 changed files with 564 additions and 8 deletions

172
tests/test_raw_text.py Normal file
View file

@ -0,0 +1,172 @@
#!/usr/bin/env python3
"""
Minimal test for raw text training implementation.
Tests basic functionality without heavy dependencies.
"""
import sys
import os
import tempfile
from pathlib import Path
import importlib.util
# Mock the datasets module since it's not installed
class MockDataset:
    """Lightweight stand-in for ``datasets.Dataset`` used by these tests.

    Stores data column-wise (name -> list of values) and supports both
    column access by name and row access by integer position.
    """

    def __init__(self, data_dict):
        # Column-name -> values mapping; column order follows dict insertion.
        self.data = data_dict
        self.column_names = list(data_dict.keys())

    def __len__(self):
        # Row count is the length of any one column.
        first_column = next(iter(self.data.values()))
        return len(first_column)

    def __getitem__(self, idx):
        # String index -> whole column, like dataset['text'].
        if isinstance(idx, str):
            return self.data[idx]
        # Integer index -> single row assembled across all columns.
        if isinstance(idx, int):
            return {key: values[idx] for key, values in self.data.items()}
        raise TypeError(f"Invalid index type: {type(idx)}")

    @classmethod
    def from_dict(cls, data_dict):
        """Alternate constructor mirroring ``datasets.Dataset.from_dict``."""
        return cls(data_dict)
# Build a throwaway "datasets" module so importing raw_text.py below does not
# require the real HuggingFace datasets package to be installed.
datasets_mock = type(sys)("datasets")  # type(sys) is the module type
datasets_mock.Dataset = MockDataset
sys.modules["datasets"] = datasets_mock

# Import the raw_text module directly to avoid unsloth/__init__.py dependencies
# (the package __init__ pulls in heavy model/training imports).
current_dir = os.path.dirname(__file__)
raw_text_path = os.path.join(
    os.path.dirname(current_dir), "unsloth", "dataprep", "raw_text.py"
)
spec = importlib.util.spec_from_file_location("raw_text", raw_text_path)
raw_text_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(raw_text_module)

# Names under test, pulled from the dynamically loaded module.
RawTextDataLoader = raw_text_module.RawTextDataLoader
TextPreprocessor = raw_text_module.TextPreprocessor
def test_raw_text_loader():
    """Test basic RawTextDataLoader functionality.

    Exercises file loading in both text (legacy) and tokenized modes,
    constructor validation, and the TextPreprocessor helpers, using a stub
    tokenizer so no model downloads are needed. Returns True on success,
    False on any failure (all exceptions are caught and reported).
    """
    # Mock tokenizer for testing: one token id per whitespace-separated word.
    class MockTokenizer:
        def __init__(self):
            self.eos_token = "</s>"
            self.eos_token_id = 2  # Mock EOS token ID

        def __call__(self, text, return_tensors = None, add_special_tokens = False):
            words = text.split()
            token_ids = list(range(len(words)))
            if return_tensors == "pt":
                # Mock tensor-like object (supports [0], len(), tolist() the
                # way the loader expects from a real (1, n) tensor).
                class MockTensor:
                    def __init__(self, data):
                        self.data = data
                    def __getitem__(self, idx):
                        return self.data
                    def __len__(self):
                        return len(self.data)
                    def tolist(self):
                        return self.data
                return {"input_ids": [MockTensor(token_ids)]}
            return {"input_ids": token_ids}

        def decode(self, token_ids, skip_special_tokens = False):
            return " ".join([f"word_{i}" for i in token_ids])

    # Create test file (delete=False so the loader can reopen it by path).
    test_content = "This is a test file for raw text training. " * 10
    with tempfile.NamedTemporaryFile(mode = "w", suffix = ".txt", delete = False) as f:
        f.write(test_content)
        test_file = f.name
    try:
        # Test loader with small chunk sizes to force multiple chunks.
        tokenizer = MockTokenizer()
        loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2)
        # Test loading with text output (legacy mode)
        text_dataset = loader.load_from_file(test_file, return_tokenized = False)
        assert len(text_dataset) > 0, "Should create at least one chunk"
        assert "text" in text_dataset.column_names, "Dataset should have 'text' column"
        # Test loading with tokenized output (new efficient mode)
        tokenized_dataset = loader.load_from_file(test_file, return_tokenized = True)
        assert len(tokenized_dataset) > 0, "Should create at least one tokenized chunk"
        assert (
            "input_ids" in tokenized_dataset.column_names
        ), "Dataset should have 'input_ids' column"
        assert (
            "attention_mask" in tokenized_dataset.column_names
        ), "Dataset should have 'attention_mask' column"
        # Verify tokenized data structure
        first_sample = tokenized_dataset[0]
        assert isinstance(first_sample["input_ids"], list), "input_ids should be a list"
        assert isinstance(
            first_sample["attention_mask"], list
        ), "attention_mask should be a list"
        assert len(first_sample["input_ids"]) == len(
            first_sample["attention_mask"]
        ), "input_ids and attention_mask should have same length"
        # Verify labels field exists (for causal LM training)
        assert (
            "labels" in tokenized_dataset.column_names
        ), "Dataset should have 'labels' column"
        assert (
            first_sample["labels"] == first_sample["input_ids"]
        ), "labels should match input_ids"
        # Test constructor validation (both guards raise ValueError)
        try:
            bad_loader = RawTextDataLoader(tokenizer, chunk_size = 0, stride = 2)
            assert False, "Should raise ValueError for chunk_size=0"
        except ValueError as e:
            assert "chunk_size must be positive" in str(e)
        try:
            bad_loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 10)
            assert False, "Should raise ValueError for stride >= chunk_size"
        except ValueError as e:
            assert "stride" in str(e) and "chunk_size" in str(e)
        # Test preprocessor
        preprocessor = TextPreprocessor()
        clean_text = preprocessor.clean_text(" messy text \n\n\n ")
        assert "messy text" in clean_text, "Should clean text properly"
        # Test validation (expects the text-mode dataset with a 'text' column)
        stats = preprocessor.validate_dataset(text_dataset)
        assert stats["total_samples"] > 0, "Should count samples"
        assert "warnings" in stats, "Should include warnings"
        print("✅ All tests passed!")
        return True
    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False
    finally:
        # Cleanup the temp file even on failure.
        os.unlink(test_file)
if __name__ == "__main__":
    # Exit code 0 on success, 1 on failure, for CI consumption.
    sys.exit(0 if test_raw_text_loader() else 1)

View file

@ -41,6 +41,7 @@ def run(args):
from unsloth import is_bfloat16_supported
from unsloth.models.loader_utils import prepare_device_map
import logging
from unsloth import RawTextDataLoader
logging.getLogger("hf-to-gguf").setLevel(logging.WARNING)
@ -99,15 +100,36 @@ def run(args):
texts.append(text)
return {"text": texts}
use_modelscope = strtobool(os.environ.get("UNSLOTH_USE_MODELSCOPE", "False"))
if use_modelscope:
from modelscope import MsDataset
def load_dataset_smart(args):
from transformers.utils import strtobool
dataset = MsDataset.load(args.dataset, split = "train")
else:
# Load and format dataset
dataset = load_dataset(args.dataset, split = "train")
dataset = dataset.map(formatting_prompts_func, batched = True)
if args.raw_text_file:
# Use raw text loader
loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride)
dataset = loader.load_from_file(args.raw_text_file)
elif args.dataset.endswith((".txt", ".md", ".json", ".jsonl")):
# Auto-detect local raw text files
loader = RawTextDataLoader(tokenizer)
dataset = loader.load_from_file(args.dataset)
else:
# Check for modelscope usage
use_modelscope = strtobool(
os.environ.get("UNSLOTH_USE_MODELSCOPE", "False")
)
if use_modelscope:
from modelscope import MsDataset
dataset = MsDataset.load(args.dataset, split = "train")
else:
# Existing HuggingFace dataset logic
dataset = load_dataset(args.dataset, split = "train")
# Apply formatting for structured datasets
dataset = dataset.map(formatting_prompts_func, batched = True)
return dataset
# Load dataset using smart loader
dataset = load_dataset_smart(args)
print("Data is formatted and ready!")
# Configure training arguments
@ -437,5 +459,15 @@ if __name__ == "__main__":
help = "Token for pushing the model to Hugging Face hub",
)
parser.add_argument(
"--raw_text_file", type = str, help = "Path to raw text file for training"
)
parser.add_argument(
"--chunk_size", type = int, default = 2048, help = "Size of text chunks for training"
)
parser.add_argument(
"--stride", type = int, default = 512, help = "Overlap between chunks"
)
args = parser.parse_args()
run(args)

View file

@ -279,6 +279,9 @@ from .save import *
from .chat_templates import *
from .tokenizer_utils import *
from .trainer import *
# Export dataprep utilities for CLI and downstream users
from .dataprep.raw_text import RawTextDataLoader, TextPreprocessor
from unsloth_zoo.rl_environments import (
check_python_modules,
create_locked_down_function,

View file

@ -13,3 +13,4 @@
# limitations under the License.
from .synthetic import *
from .raw_text import *

View file

@ -0,0 +1,348 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import json
import csv
from typing import List, Dict, Any, Union, Optional
from datasets import Dataset
from pathlib import Path
# Public API of this module.
__all__ = [
    "RawTextDataLoader",
    "TextPreprocessor",
]

# Maps file extensions to the parsing strategy used by
# RawTextDataLoader.detect_format / _read_file_by_format.
SUPPORTED_FORMATS = {
    ".txt": "plain_text",
    ".md": "markdown",
    ".json": "json_lines",
    ".jsonl": "json_lines",
    ".csv": "csv_text_column",
}
class RawTextDataLoader:
    """Load raw text files and chunk them into datasets for causal LM training.

    Supports .txt/.md (plain text), .json/.jsonl (one JSON object per line,
    text pulled from common field names), and .csv (text pulled from common
    column names). Chunking is purely token-count based with a configurable
    overlap between consecutive chunks.

    Args:
        tokenizer: HuggingFace-style tokenizer — callable returning
            ``{"input_ids": ...}``, with ``decode`` and (optionally)
            ``eos_token`` / ``eos_token_id`` attributes.
        chunk_size: maximum tokens per chunk; must be positive.
        stride: token overlap between consecutive chunks; must be smaller
            than ``chunk_size`` (otherwise chunking could not advance).
        return_tokenized: default output mode — True yields token dicts
            (efficient path), False yields decoded text strings (legacy).
    """

    def __init__(self, tokenizer, chunk_size = 2048, stride = 512, return_tokenized = True):
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if stride >= chunk_size:
            raise ValueError(
                f"stride ({stride}) must be smaller than chunk_size ({chunk_size})"
            )
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.stride = stride
        self.return_tokenized = return_tokenized

    def detect_format(self, file_path):
        """Map a file extension to a parsing strategy; unknown extensions fall back to plain text."""
        extension = Path(file_path).suffix.lower()
        return SUPPORTED_FORMATS.get(extension, "plain_text")

    def load_from_file(self, file_path, return_tokenized = None):
        """Load a single raw text file and convert it to a Dataset.

        Args:
            file_path: path to the input file.
            return_tokenized: override the instance default output mode.

        Raises:
            ValueError: if the file is empty or whitespace-only.
        """
        if return_tokenized is None:
            return_tokenized = self.return_tokenized
        file_format = self.detect_format(file_path)
        text_content = self._read_file_by_format(file_path, file_format)
        if not text_content or not text_content.strip():
            raise ValueError(f"File '{file_path}' is empty or contains only whitespace")
        chunks = self.smart_chunk_text(
            text_content, self.chunk_size, self.stride, return_tokenized
        )
        return self.create_causal_dataset(chunks)

    def load_from_files(self, file_paths, return_tokenized = None):
        """Load several files and combine all their chunks into one Dataset.

        NOTE(review): unlike ``load_from_file``, empty files are not rejected
        here and may contribute an EOS-only chunk — confirm that is intended.
        """
        if return_tokenized is None:
            return_tokenized = self.return_tokenized
        all_chunks = []
        for file_path in file_paths:
            file_format = self.detect_format(file_path)
            text_content = self._read_file_by_format(file_path, file_format)
            chunks = self.smart_chunk_text(
                text_content, self.chunk_size, self.stride, return_tokenized
            )
            all_chunks.extend(chunks)
        return self.create_causal_dataset(all_chunks)

    def chunk_text(self, text, return_tokenized = None):
        """Split ``text`` into overlapping chunks using the configured chunk_size/stride."""
        if return_tokenized is None:
            return_tokenized = self.return_tokenized
        return self.smart_chunk_text(
            text, self.chunk_size, self.stride, return_tokenized
        )

    def create_causal_dataset(self, chunks):
        """Wrap chunks in a ``datasets.Dataset`` for causal LM training.

        Tokenized chunks gain a ``labels`` column copying ``input_ids``
        (standard next-token-prediction setup); text chunks produce a plain
        ``text`` column for backward compatibility.
        """
        if chunks and isinstance(chunks[0], dict):
            # Chunks are already tokenized: columnarize for Dataset.from_dict.
            input_ids = [chunk["input_ids"] for chunk in chunks]
            attention_mask = [chunk["attention_mask"] for chunk in chunks]
            # Labels are the same as input_ids for causal LM training; copied
            # so the two columns do not alias the same list objects.
            labels = [list(ids) for ids in input_ids]
            return Dataset.from_dict(
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                }
            )
        else:
            # Chunks are text strings (backward compatibility).
            return Dataset.from_dict({"text": chunks})

    def smart_chunk_text(self, text, chunk_size, stride, return_tokenized = True):
        """Split ``text`` into overlapping token-count-based chunks.

        The whole text is tokenized once, then sliced into windows of
        ``chunk_size`` tokens that advance by ``chunk_size - stride`` tokens
        per step. An EOS token is appended to the final chunk and to every
        full-size chunk. Boundaries are token-count based only — no
        sentence/paragraph awareness (the previous docstring's claim to the
        contrary was inaccurate).

        Returns:
            A list of ``{"input_ids", "attention_mask"}`` dicts when
            ``return_tokenized`` is True, else a list of decoded strings.

        Raises:
            ValueError: on non-positive ``chunk_size`` or on
                ``stride >= chunk_size``. Previously, calling this public
                method directly with ``stride >= chunk_size`` made
                ``start_idx`` never advance — an infinite loop; it now fails
                fast, matching the constructor's validation.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if stride >= chunk_size:
            raise ValueError(
                f"stride ({stride}) must be smaller than chunk_size ({chunk_size})"
            )
        # First pass: tokenize the entire text to get accurate token counts.
        tokenized = self.tokenizer(text, return_tensors = "pt", add_special_tokens = False)
        tokens = tokenized["input_ids"]
        # Normalize different tokenizer return formats to a flat sequence.
        if hasattr(tokens, "__len__") and len(tokens) > 0:
            # Batched output (e.g. a (1, n) tensor or nested list): take row 0.
            if hasattr(tokens[0], "__len__"):
                tokens = tokens[0]
        elif isinstance(tokens, int):
            # Defensive: if a tokenizer returned a bare count, synthesize ids.
            tokens = list(range(tokens))
        if len(tokens) <= chunk_size:
            # Text is small enough to fit in one chunk.
            if return_tokenized:
                # Append the EOS token id when the tokenizer defines one.
                eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
                if eos_token_id is not None:
                    tokens = (
                        tokens.tolist() if hasattr(tokens, "tolist") else list(tokens)
                    )
                    tokens.append(eos_token_id)
                # No padding is performed, so every position is attended to.
                attention_mask = [1] * len(tokens)
                return [{"input_ids": tokens, "attention_mask": attention_mask}]
            else:
                eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else ""
                return [text + eos_token]
        chunks = []
        start_idx = 0
        while start_idx < len(tokens):
            # Window end, clipped at the end of the token stream.
            end_idx = min(start_idx + chunk_size, len(tokens))
            chunk_tokens = tokens[start_idx:end_idx]
            if return_tokenized:
                # Convert to a plain list if it's a tensor slice.
                chunk_tokens_list = (
                    chunk_tokens.tolist()
                    if hasattr(chunk_tokens, "tolist")
                    else list(chunk_tokens)
                )
                # Add EOS if this is the last chunk or a full-size chunk.
                if end_idx == len(tokens) or len(chunk_tokens_list) == chunk_size:
                    eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
                    if eos_token_id is not None:
                        chunk_tokens_list.append(eos_token_id)
                # All tokens are attended to (no padding).
                attention_mask = [1] * len(chunk_tokens_list)
                chunks.append(
                    {"input_ids": chunk_tokens_list, "attention_mask": attention_mask}
                )
            else:
                # Decode back to text (backward compatibility).
                chunk_text = self.tokenizer.decode(
                    chunk_tokens, skip_special_tokens = True
                )
                # Add EOS if this is the last chunk or a full-size chunk.
                if end_idx == len(tokens) or len(chunk_tokens) == chunk_size:
                    eos_token = (
                        self.tokenizer.eos_token if self.tokenizer.eos_token else ""
                    )
                    chunk_text += eos_token
                chunks.append(chunk_text)
            # Advance with stride overlap; stop once the end is reached.
            if end_idx == len(tokens):
                break
            start_idx += chunk_size - stride
        return chunks

    def _read_file_by_format(self, file_path, file_format):
        """Read file content according to the detected format.

        JSON-lines: text is pulled from common field names line by line;
        malformed lines are skipped. CSV: text is pulled from common column
        names. Extracted pieces are joined with blank lines.
        """
        with open(file_path, "r", encoding = "utf-8") as f:
            if file_format == "plain_text" or file_format == "markdown":
                return f.read()
            elif file_format == "json_lines":
                lines = []
                for line in f:
                    try:
                        data = json.loads(line.strip())
                        text = self._extract_text_from_json(data)
                        if text:
                            lines.append(text)
                    except json.JSONDecodeError:
                        # Best-effort: skip malformed lines, do not abort.
                        continue
                return "\n\n".join(lines)
            elif file_format == "csv_text_column":
                reader = csv.DictReader(f)
                texts = []
                for row in reader:
                    text = self._extract_text_from_csv_row(row)
                    if text:
                        texts.append(text)
                return "\n\n".join(texts)
        # Unknown format: nothing extracted.
        return ""

    def _extract_text_from_json(self, data):
        """Return the first string value found under a common text field name, else ''."""
        text_fields = ["text", "content", "message", "body", "description", "prompt"]
        for field in text_fields:
            if field in data and isinstance(data[field], str):
                return data[field]
        return ""

    def _extract_text_from_csv_row(self, row):
        """Return the first non-empty value found under a common text column name, else ''."""
        text_columns = ["text", "content", "message", "body", "description", "prompt"]
        for column in text_columns:
            if column in row and row[column]:
                return row[column]
        return ""
class TextPreprocessor:
    """Utilities for cleaning, structuring and validating raw text datasets."""

    def clean_text(self, text):
        """Normalize whitespace and strip non-printable characters.

        Line endings are normalized to ``\\n`` first, then non-printable /
        non-ASCII characters are removed, runs of spaces/tabs are collapsed
        to one space, and runs of 3+ newlines are reduced to one blank line.

        Bug fix: the original implementation collapsed ALL whitespace
        (including newlines) to single spaces as its FIRST step, which made
        the subsequent line-ending normalization and blank-line collapsing
        dead code. The steps are now ordered so each one has effect and
        paragraph breaks survive cleaning.
        """
        # Normalize Windows/old-Mac line endings before any other handling.
        text = text.replace("\r\n", "\n").replace("\r", "\n")
        # Drop non-printable / non-ASCII characters but keep newlines and tabs.
        text = re.sub(r"[^\x20-\x7E\n\t]", "", text)
        # Collapse runs of horizontal whitespace only, preserving line breaks.
        text = re.sub(r"[ \t]+", " ", text)
        # Allow at most one blank line in a row.
        text = re.sub(r"\n{3,}", "\n\n", text)
        return text.strip()

    def extract_sections(self, text, patterns):
        """Extract specific sections (e.g., code blocks, quotes).

        Each regex in ``patterns`` is applied with MULTILINE | DOTALL; all
        matches (strings, or tuples for multi-group patterns, per
        ``re.findall`` semantics) are concatenated into one list.
        """
        sections = []
        for pattern in patterns:
            matches = re.findall(pattern, text, re.MULTILINE | re.DOTALL)
            sections.extend(matches)
        return sections

    def add_structure_tokens(self, text):
        """Wrap Markdown headings and fenced code blocks in special structure tokens.

        ``#``/``##``/``###`` headings become chapter/section/subsection tokens;
        fenced code blocks keep their language tag inside the opening token.
        """
        text = re.sub(
            r"^# (.+)$", r"<|chapter|>\1<|/chapter|>", text, flags = re.MULTILINE
        )
        text = re.sub(
            r"^## (.+)$", r"<|section|>\1<|/section|>", text, flags = re.MULTILINE
        )
        text = re.sub(
            r"^### (.+)$", r"<|subsection|>\1<|/subsection|>", text, flags = re.MULTILINE
        )
        text = re.sub(
            r"```(\w*)\n(.*?)\n```", r"<|code|\1|>\2<|/code|>", text, flags = re.DOTALL
        )
        return text

    def validate_dataset(self, dataset):
        """Compute quality statistics and warnings for a text dataset.

        Checks minimum/maximum/average lengths, character encoding issues,
        exact-duplicate content, and empty chunks.

        NOTE(review): assumes ``dataset`` has a ``"text"`` column — tokenized
        datasets (input_ids/attention_mask) are not supported here.
        """
        stats = {
            "total_samples": len(dataset),
            "empty_samples": 0,
            "min_length": float("inf"),
            "max_length": 0,
            "avg_length": 0,
            "repeated_content": 0,
            "encoding_issues": 0,
            "warnings": [],
        }
        texts = dataset["text"]
        text_lengths = []
        seen_texts = set()
        for i, text in enumerate(texts):
            if not text or len(text.strip()) == 0:
                stats["empty_samples"] += 1
                continue
            # Check for encoding issues (e.g. lone surrogates).
            try:
                text.encode("utf-8")
            except UnicodeEncodeError:
                stats["encoding_issues"] += 1
            # Track character-length extremes.
            length = len(text)
            text_lengths.append(length)
            stats["min_length"] = min(stats["min_length"], length)
            stats["max_length"] = max(stats["max_length"], length)
            # Detect exact duplicates via hash of the stripped text.
            text_hash = hash(text.strip())
            if text_hash in seen_texts:
                stats["repeated_content"] += 1
            else:
                seen_texts.add(text_hash)
        # Finalize aggregates; coerce min_length to 0 when no non-empty text.
        if text_lengths:
            stats["avg_length"] = sum(text_lengths) / len(text_lengths)
        stats["min_length"] = (
            stats["min_length"] if stats["min_length"] != float("inf") else 0
        )
        # Generate warnings.
        if stats["empty_samples"] > 0:
            stats["warnings"].append(f"Found {stats['empty_samples']} empty samples")
        if stats["repeated_content"] > 0:
            stats["warnings"].append(
                f"Found {stats['repeated_content']} repeated samples"
            )
        if stats["encoding_issues"] > 0:
            stats["warnings"].append(
                f"Found {stats['encoding_issues']} encoding issues"
            )
        # Bug fix: only warn about short samples when non-empty samples exist;
        # previously an empty dataset (min_length coerced to 0) spuriously
        # triggered this warning.
        if text_lengths and stats["min_length"] < 10:
            stats["warnings"].append("Some samples are very short (< 10 characters)")
        return stats