From face46d188b7f2e03444b7ed68e206d5d02c4bc2 Mon Sep 17 00:00:00 2001 From: vangmay Date: Tue, 18 Nov 2025 21:46:41 +0800 Subject: [PATCH 01/21] Write file and template for raw_text dataprep --- unsloth/dataprep/raw_text.py | 76 ++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 unsloth/dataprep/raw_text.py diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py new file mode 100644 index 000000000..93c8c6855 --- /dev/null +++ b/unsloth/dataprep/raw_text.py @@ -0,0 +1,76 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import json +import csv +from typing import List, Dict, Any, Union, Optional +from datasets import Dataset +from pathlib import Path + +class RawTextDataLoader: + def __init__(self, tokenizer, chunk_size=2048, stride=512): + self.tokenizer = tokenizer + self.chunk_size = chunk_size + self.stride = stride + + def load_from_file(self, file_path): + """Load raw text and convert to dataset""" + + def load_from_files(self, file_paths): + """Load multiple text files""" + + def chunk_text(self, text): + """Split text into overlapping chunks""" + + def create_causal_dataset(self, chunks): + """Create dataset for causal language modeling""" + + def smart_chunk_text(self, text, chunk_size, stride): + """ + Intelligent chunking that: + 1. Respects sentence/paragraph boundaries + 2. Handles various text formats (.txt, .md, .json, etc.) + 3. Maintains context with stride overlap + 4. Adds proper EOS tokens + """ + + def tokenize_and_chunk(self, text): + """ + Tokenize first, then chunk by token count: + 1. More precise length control + 2. Avoids mid-token splits + 3. Handles different languages better + """ + +class TextPreprocessor: + def clean_text(self, text): + """Remove unwanted characters, normalize whitespace""" + + def extract_sections(self, text, patterns): + """Extract specific sections (e.g., code blocks, quotes)""" + + def add_structure_tokens(self, text): + """Add special tokens for structure (chapters, sections)""" + +def validate_dataset(self, dataset): + """ + Check for: + - Minimum/maximum sequence lengths + - Character encoding issues + - Repeated content + - Empty chunks + """ + From d75fbb5d0aba4c796b8580a6bd73655e88f5d086 Mon Sep 17 00:00:00 2001 From: vangmay Date: Tue, 18 Nov 2025 21:53:20 +0800 Subject: [PATCH 02/21] Add implementation to cli --- unsloth-cli.py | 48 ++++++++++++++++++++++++++++++++++++ unsloth/dataprep/raw_text.py | 25 +++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/unsloth-cli.py b/unsloth-cli.py index fb6e39266..aac0e7f7e 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -42,6 +42,7 @@ def run(args): from transformers import TrainingArguments from unsloth import is_bfloat16_supported import logging + from unsloth import RawTextDataLoader logging.getLogger("hf-to-gguf").setLevel(logging.WARNING) @@ -98,6 +99,21 @@ def run(args): texts.append(text) return {"text": texts} + def load_dataset_smart(args): + if args.raw_text_file: + # Use raw text loader + loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) + dataset = loader.load_from_file(args.raw_text_file) + elif args.dataset.endswith(('.txt', '.md', '.json', '.jsonl')): + # Auto-detect local raw text files + loader = RawTextDataLoader(tokenizer) + dataset = loader.load_from_file(args.dataset) + else: + # Existing HuggingFace dataset logic + dataset = load_dataset(args.dataset, split="train") + dataset = dataset.map(formatting_prompts_func, batched=True) + return dataset + use_modelscope = strtobool(os.environ.get("UNSLOTH_USE_MODELSCOPE", "False")) if use_modelscope: from modelscope import MsDataset @@ -389,5 +405,37 @@ if __name__ == "__main__": "--hub_token", type = str, help = "Token for pushing the model to Hugging Face hub" ) + parser.add_argument( + "--raw_text_file", + type=str, + help="Path to raw text file for training" + ) + parser.add_argument( + "--chunk_size", + type=int, + default=2048, + help="Size of text chunks for training" + ) + parser.add_argument( + "--stride", + type=int, + default=512, + help="Overlap between chunks" + ) + + TRAINING_MODES = { + 'instruction': 'Standard instruction-following', + 'causal': 'Causal language modeling (raw text)', + 'completion': 'Text completion tasks' + } + + parser.add_argument( + "--training_mode", + type=str, + default="instruction", + choices=list(TRAINING_MODES.keys()), + help="Training mode for the model" + ) + args = parser.parse_args() run(args) diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index 93c8c6855..e36978bf1 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -20,12 +20,28 @@ from typing import List, Dict, Any, Union, Optional from datasets import Dataset from pathlib import Path +__all__ = [ + "RawTextDataLoader", + "TextPreprocessor", +] + +SUPPORTED_FORMATS = { + '.txt': 'plain_text', + '.md': 'markdown', + '.json': 'json_lines', + '.jsonl': 'json_lines', + '.csv': 'csv_text_column' +} + class RawTextDataLoader: def __init__(self, tokenizer, chunk_size=2048, stride=512): self.tokenizer = tokenizer self.chunk_size = chunk_size self.stride = stride + def detect_format(self, file_path): + """Auto-detect file format and parse accordingly""" + def load_from_file(self, file_path): """Load raw text and convert to dataset""" @@ -64,6 +80,15 @@ class TextPreprocessor: def add_structure_tokens(self, text): """Add special tokens for structure (chapters, sections)""" + + def validate_dataset(self, dataset): + """ + Check for: + - Minimum/maximum sequence lengths + - Character encoding issues + - Repeated content + - Empty chunks + """ def validate_dataset(self, dataset): """ From aecfbe1fff782b95c00280baee264bf61385bd00 Mon Sep 17 00:00:00 2001 From: vangmay Date: Tue, 18 Nov 2025 21:59:01 +0800 Subject: [PATCH 03/21] Add support for multiple files --- unsloth/dataprep/raw_text.py | 56 ++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index e36978bf1..8b705b3a9 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -41,12 +41,26 @@ class RawTextDataLoader: def detect_format(self, file_path): """Auto-detect file format and parse accordingly""" + extension = Path(file_path).suffix.lower() + return SUPPORTED_FORMATS.get(extension, 'plain_text') def load_from_file(self, file_path): """Load raw text and convert to dataset""" + file_format = self.detect_format(file_path) + text_content = self._read_file_by_format(file_path, file_format) + chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride) + return self.create_causal_dataset(chunks) def load_from_files(self, file_paths): """Load multiple text files""" + all_chunks = [] + for file_path in file_paths: + file_format = self.detect_format(file_path) + text_content = self._read_file_by_format(file_path, file_format) + chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride) + all_chunks.extend(chunks) + return self.create_causal_dataset(all_chunks) + def chunk_text(self, text): """Split text into overlapping chunks""" @@ -71,6 +85,48 @@ class RawTextDataLoader: 3. Handles different languages better """ + def _read_file_by_format(self, file_path, file_format): + """Read file content based on detected format.""" + with open(file_path, 'r', encoding='utf-8') as f: + if file_format == 'plain_text' or file_format == 'markdown': + return f.read() + elif file_format == 'json_lines': + lines = [] + for line in f: + try: + data = json.loads(line.strip()) + text = self._extract_text_from_json(data) + if text: + lines.append(text) + except json.JSONDecodeError: + continue + return '\n\n'.join(lines) + elif file_format == 'csv_text_column': + reader = csv.DictReader(f) + texts = [] + for row in reader: + text = self._extract_text_from_csv_row(row) + if text: + texts.append(text) + return '\n\n'.join(texts) + return "" + + def _extract_text_from_json(self, data): + """Extract text from JSON object using common field names.""" + text_fields = ['text', 'content', 'message', 'body', 'description', 'prompt'] + for field in text_fields: + if field in data and isinstance(data[field], str): + return data[field] + return "" + + def _extract_text_from_csv_row(self, row): + """Extract text from CSV row using common column names.""" + text_columns = ['text', 'content', 'message', 'body', 'description', 'prompt'] + for column in text_columns: + if column in row and row[column]: + return row[column] + return "" + class TextPreprocessor: def clean_text(self, text): """Remove unwanted characters, normalize whitespace""" From ed5820e667f98ca02cad25db827c2868ad56b994 Mon Sep 17 00:00:00 2001 From: vangmay Date: Tue, 18 Nov 2025 22:00:07 +0800 Subject: [PATCH 04/21] Write chunking logic --- unsloth/dataprep/raw_text.py | 47 ++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index 8b705b3a9..44582c76a 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -64,9 +64,11 @@ class RawTextDataLoader: def chunk_text(self, text): """Split text into overlapping chunks""" + return self.smart_chunk_text(text, self.chunk_size, self.stride) def create_causal_dataset(self, chunks): """Create dataset for causal language modeling""" + return Dataset.from_dict({"text": chunks}) def smart_chunk_text(self, text, chunk_size, stride): """ @@ -76,6 +78,51 @@ class RawTextDataLoader: 3. Maintains context with stride overlap 4. Adds proper EOS tokens """ + # First pass: tokenize the entire text to get accurate token counts + tokenized = self.tokenizer(text, return_tensors="pt", add_special_tokens=False) + tokens = tokenized["input_ids"] + + # Handle different tokenizer return formats + if hasattr(tokens, '__len__') and len(tokens) > 0: + # If it's a nested structure, get the first element + if hasattr(tokens[0], '__len__'): + tokens = tokens[0] + elif isinstance(tokens, int): + # If tokenizer returns just a count, create a simple range + tokens = list(range(tokens)) + + if len(tokens) <= chunk_size: + # Text is small enough to fit in one chunk + eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "" + return [text + eos_token] + + chunks = [] + start_idx = 0 + + while start_idx < len(tokens): + # Calculate end index for this chunk + end_idx = min(start_idx + chunk_size, len(tokens)) + + # Extract tokens for this chunk + chunk_tokens = tokens[start_idx:end_idx] + + # Decode back to text + chunk_text = self.tokenizer.decode(chunk_tokens, skip_special_tokens=True) + + # Add EOS token if it's the last chunk or chunk is complete + if end_idx == len(tokens) or len(chunk_tokens) == chunk_size: + eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "" + chunk_text += eos_token + + chunks.append(chunk_text) + + # Move to next chunk with stride overlap + if end_idx == len(tokens): + break + start_idx += chunk_size - stride + + return chunks + def tokenize_and_chunk(self, text): """ From 6014bb4dd21597c7579711f95a895f045622fb97 Mon Sep 17 00:00:00 2001 From: vangmay Date: Tue, 18 Nov 2025 22:01:36 +0800 Subject: [PATCH 05/21] Add logic to clean and extract text sections --- unsloth/dataprep/raw_text.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index 44582c76a..b31120cc5 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -177,12 +177,27 @@ class RawTextDataLoader: class TextPreprocessor: def clean_text(self, text): """Remove unwanted characters, normalize whitespace""" + text = re.sub(r'\s+', ' ', text) + text = re.sub(r'[^\x20-\x7E\n\t]', '', text) + text = text.replace('\r\n', '\n').replace('\r', '\n') + text = re.sub(r'\n{3,}', '\n\n', text) + return text.strip() def extract_sections(self, text, patterns): """Extract specific sections (e.g., code blocks, quotes)""" - + sections = [] + for pattern in patterns: + matches = re.findall(pattern, text, re.MULTILINE | re.DOTALL) + sections.extend(matches) + return sections + def add_structure_tokens(self, text): """Add special tokens for structure (chapters, sections)""" + text = re.sub(r'^# (.+)$', r'<|chapter|>\1<|/chapter|>', text, flags=re.MULTILINE) + text = re.sub(r'^## (.+)$', r'<|section|>\1<|/section|>', text, flags=re.MULTILINE) + text = re.sub(r'^### (.+)$', r'<|subsection|>\1<|/subsection|>', text, flags=re.MULTILINE) + text = re.sub(r'```(\w*)\n(.*?)\n```', r'<|code|\1|>\2<|/code|>', text, flags=re.DOTALL) + return text def validate_dataset(self, dataset): """ From 8d482c212987da6e6ce445ae2138033fd3237ddc Mon Sep 17 00:00:00 2001 From: vangmay Date: Tue, 18 Nov 2025 22:02:35 +0800 Subject: [PATCH 06/21] Add validation code --- unsloth/dataprep/raw_text.py | 67 +++++++++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 9 deletions(-) diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index b31120cc5..dd6d6b96d 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -207,13 +207,62 @@ class TextPreprocessor: - Repeated content - Empty chunks """ - -def validate_dataset(self, dataset): - """ - Check for: - - Minimum/maximum sequence lengths - - Character encoding issues - - Repeated content - - Empty chunks - """ + stats = { + 'total_samples': len(dataset), + 'empty_samples': 0, + 'min_length': float('inf'), + 'max_length': 0, + 'avg_length': 0, + 'repeated_content': 0, + 'encoding_issues': 0, + 'warnings': [] + } + + texts = dataset['text'] + text_lengths = [] + seen_texts = set() + + for i, text in enumerate(texts): + if not text or len(text.strip()) == 0: + stats['empty_samples'] += 1 + continue + + # Check for encoding issues + try: + text.encode('utf-8') + except UnicodeEncodeError: + stats['encoding_issues'] += 1 + + # Calculate lengths + length = len(text) + text_lengths.append(length) + stats['min_length'] = min(stats['min_length'], length) + stats['max_length'] = max(stats['max_length'], length) + + # Check for repeated content + text_hash = hash(text.strip()) + if text_hash in seen_texts: + stats['repeated_content'] += 1 + else: + seen_texts.add(text_hash) + + # Calculate average length + if text_lengths: + stats['avg_length'] = sum(text_lengths) / len(text_lengths) + stats['min_length'] = stats['min_length'] if stats['min_length'] != float('inf') else 0 + + # Generate warnings + if stats['empty_samples'] > 0: + stats['warnings'].append(f"Found {stats['empty_samples']} empty samples") + + if stats['repeated_content'] > 0: + stats['warnings'].append(f"Found {stats['repeated_content']} repeated samples") + + if stats['encoding_issues'] > 0: + stats['warnings'].append(f"Found {stats['encoding_issues']} encoding issues") + + if stats['min_length'] < 10: + stats['warnings'].append("Some samples are very short (< 10 characters)") + + return stats From ee37dd9f92144a8622a8b91b20808d2348924dac Mon Sep 17 00:00:00 2001 From: vangmay Date: Tue, 18 Nov 2025 22:36:38 +0800 Subject: [PATCH 07/21] Write simple test --- tests/test_raw_text.py | 121 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 tests/test_raw_text.py diff --git a/tests/test_raw_text.py b/tests/test_raw_text.py new file mode 100644 index 000000000..503fbeb4c --- /dev/null +++ b/tests/test_raw_text.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +""" +Minimal test for raw text training implementation. +Tests basic functionality without heavy dependencies. +""" + +import sys +import os +import tempfile +from pathlib import Path +import importlib.util + +# Mock the datasets module since it's not installed +class MockDataset: + def __init__(self, data_dict): + self.data = data_dict + self.column_names = list(data_dict.keys()) + + def __len__(self): + return len(next(iter(self.data.values()))) + + def __getitem__(self, idx): + if isinstance(idx, str): + # Allow accessing columns by name like dataset['text'] + return self.data[idx] + elif isinstance(idx, int): + # Allow accessing individual rows by index + return {key: values[idx] for key, values in self.data.items()} + else: + raise TypeError(f"Invalid index type: {type(idx)}") + + @classmethod + def from_dict(cls, data_dict): + return cls(data_dict) + +# Mock datasets module +datasets_mock = type(sys)('datasets') +datasets_mock.Dataset = MockDataset +sys.modules['datasets'] = datasets_mock + +# Import the raw_text module directly to avoid unsloth/__init__.py dependencies +current_dir = os.path.dirname(__file__) +raw_text_path = os.path.join(os.path.dirname(current_dir), 'unsloth', 'dataprep', 'raw_text.py') + +spec = importlib.util.spec_from_file_location("raw_text", raw_text_path) +raw_text_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(raw_text_module) + +RawTextDataLoader = raw_text_module.RawTextDataLoader +TextPreprocessor = raw_text_module.TextPreprocessor + +def test_raw_text_loader(): + """Test basic RawTextDataLoader functionality.""" + + # Mock tokenizer for testing + class MockTokenizer: + def __init__(self): + self.eos_token = "" + + def __call__(self, text, return_tensors=None, add_special_tokens=False): + words = text.split() + token_ids = list(range(len(words))) + + if return_tensors == "pt": + # Mock tensor-like object + class MockTensor: + def __init__(self, data): + self.data = data + def __getitem__(self, idx): + return self.data + def __len__(self): + return len(self.data) + return {"input_ids": [MockTensor(token_ids)]} + return {"input_ids": token_ids} + + def decode(self, token_ids, skip_special_tokens=False): + return " ".join([f"word_{i}" for i in token_ids]) + + + + + # Create test file + test_content = "This is a test file for raw text training. " * 10 + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write(test_content) + test_file = f.name + + try: + # Test loader + tokenizer = MockTokenizer() + loader = RawTextDataLoader(tokenizer, chunk_size=5, stride=2) + + # Test loading + dataset = loader.load_from_file(test_file) + assert len(dataset) > 0, "Should create at least one chunk" + assert 'text' in dataset.column_names, "Dataset should have 'text' column" + + # Test preprocessor + preprocessor = TextPreprocessor() + clean_text = preprocessor.clean_text(" messy text \n\n\n ") + assert "messy text" in clean_text, "Should clean text properly" + + # Test validation + stats = preprocessor.validate_dataset(dataset) + assert stats['total_samples'] > 0, "Should count samples" + assert 'warnings' in stats, "Should include warnings" + + print("✅ All tests passed!") + return True + + except Exception as e: + print(f"❌ Test failed: {e}") + return False + + finally: + # Cleanup + os.unlink(test_file) + +if __name__ == "__main__": + success = test_raw_text_loader() + sys.exit(0 if success else 1) \ No newline at end of file From 171fb1257390349f6434f81f97051cda0af7e7d4 Mon Sep 17 00:00:00 2001 From: vangmay Date: Tue, 18 Nov 2025 22:44:48 +0800 Subject: [PATCH 08/21] Add module to init --- unsloth/dataprep/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/__init__.py b/unsloth/dataprep/__init__.py index b36122eb7..b6840f247 100644 --- a/unsloth/dataprep/__init__.py +++ b/unsloth/dataprep/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .synthetic import * +from raw_text import * From d429363c23dda17da5623849c63cc7e7019f308f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 12:51:17 +0000 Subject: [PATCH 09/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_raw_text.py | 64 ++++++------ unsloth-cli.py | 34 +++---- unsloth/dataprep/raw_text.py | 183 +++++++++++++++++++---------------- 3 files changed, 146 insertions(+), 135 deletions(-) diff --git a/tests/test_raw_text.py b/tests/test_raw_text.py index 503fbeb4c..9bbfee92a 100644 --- a/tests/test_raw_text.py +++ b/tests/test_raw_text.py @@ -10,15 +10,16 @@ import tempfile from pathlib import Path import importlib.util + # Mock the datasets module since it's not installed class MockDataset: def __init__(self, data_dict): self.data = data_dict self.column_names = list(data_dict.keys()) - + def __len__(self): return len(next(iter(self.data.values()))) - + def __getitem__(self, idx): if isinstance(idx, str): # Allow accessing columns by name like dataset['text'] @@ -28,19 +29,22 @@ class MockDataset: return {key: values[idx] for key, values in self.data.items()} else: raise TypeError(f"Invalid index type: {type(idx)}") - + @classmethod def from_dict(cls, data_dict): return cls(data_dict) + # Mock datasets module -datasets_mock = type(sys)('datasets') +datasets_mock = type(sys)("datasets") datasets_mock.Dataset = MockDataset -sys.modules['datasets'] = datasets_mock +sys.modules["datasets"] = datasets_mock # Import the raw_text module directly to avoid unsloth/__init__.py dependencies current_dir = os.path.dirname(__file__) -raw_text_path = os.path.join(os.path.dirname(current_dir), 'unsloth', 'dataprep', 'raw_text.py') +raw_text_path = os.path.join( + os.path.dirname(current_dir), "unsloth", "dataprep", "raw_text.py" +) spec = importlib.util.spec_from_file_location("raw_text", raw_text_path) raw_text_module = importlib.util.module_from_spec(spec) @@ -49,73 +53,75 @@ spec.loader.exec_module(raw_text_module) RawTextDataLoader = raw_text_module.RawTextDataLoader TextPreprocessor = raw_text_module.TextPreprocessor + def test_raw_text_loader(): """Test basic RawTextDataLoader functionality.""" - + # Mock tokenizer for testing class MockTokenizer: def __init__(self): self.eos_token = "" - - def __call__(self, text, return_tensors=None, add_special_tokens=False): + + def __call__(self, text, return_tensors = None, add_special_tokens = False): words = text.split() token_ids = list(range(len(words))) - + if return_tensors == "pt": # Mock tensor-like object class MockTensor: def __init__(self, data): self.data = data + def __getitem__(self, idx): return self.data + def __len__(self): return len(self.data) + return {"input_ids": [MockTensor(token_ids)]} return {"input_ids": token_ids} - - def decode(self, token_ids, skip_special_tokens=False): + + def decode(self, token_ids, skip_special_tokens = False): return " ".join([f"word_{i}" for i in token_ids]) - - - - + # Create test file test_content = "This is a test file for raw text training. " * 10 - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + with tempfile.NamedTemporaryFile(mode = "w", suffix = ".txt", delete = False) as f: f.write(test_content) test_file = f.name - + try: # Test loader tokenizer = MockTokenizer() - loader = RawTextDataLoader(tokenizer, chunk_size=5, stride=2) - + loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2) + # Test loading dataset = loader.load_from_file(test_file) assert len(dataset) > 0, "Should create at least one chunk" - assert 'text' in dataset.column_names, "Dataset should have 'text' column" - + assert "text" in dataset.column_names, "Dataset should have 'text' column" + # Test preprocessor preprocessor = TextPreprocessor() clean_text = preprocessor.clean_text(" messy text \n\n\n ") assert "messy text" in clean_text, "Should clean text properly" - + # Test validation stats = preprocessor.validate_dataset(dataset) - assert stats['total_samples'] > 0, "Should count samples" - assert 'warnings' in stats, "Should include warnings" - + assert stats["total_samples"] > 0, "Should count samples" + assert "warnings" in stats, "Should include warnings" + print("✅ All tests passed!") return True - + except Exception as e: print(f"❌ Test failed: {e}") return False - + finally: # Cleanup os.unlink(test_file) + if __name__ == "__main__": success = test_raw_text_loader() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/unsloth-cli.py b/unsloth-cli.py index aac0e7f7e..e454a47c4 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -104,14 +104,14 @@ def run(args): # Use raw text loader loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) dataset = loader.load_from_file(args.raw_text_file) - elif args.dataset.endswith(('.txt', '.md', '.json', '.jsonl')): + elif args.dataset.endswith((".txt", ".md", ".json", ".jsonl")): # Auto-detect local raw text files loader = RawTextDataLoader(tokenizer) dataset = loader.load_from_file(args.dataset) else: # Existing HuggingFace dataset logic - dataset = load_dataset(args.dataset, split="train") - dataset = dataset.map(formatting_prompts_func, batched=True) + dataset = load_dataset(args.dataset, split = "train") + dataset = dataset.map(formatting_prompts_func, batched = True) return dataset use_modelscope = strtobool(os.environ.get("UNSLOTH_USE_MODELSCOPE", "False")) @@ -406,35 +406,27 @@ if __name__ == "__main__": ) parser.add_argument( - "--raw_text_file", - type=str, - help="Path to raw text file for training" + "--raw_text_file", type = str, help = "Path to raw text file for training" ) parser.add_argument( - "--chunk_size", - type=int, - default=2048, - help="Size of text chunks for training" + "--chunk_size", type = int, default = 2048, help = "Size of text chunks for training" ) parser.add_argument( - "--stride", - type=int, - default=512, - help="Overlap between chunks" + "--stride", type = int, default = 512, help = "Overlap between chunks" ) TRAINING_MODES = { - 'instruction': 'Standard instruction-following', - 'causal': 'Causal language modeling (raw text)', - 'completion': 'Text completion tasks' + "instruction": "Standard instruction-following", + "causal": "Causal language modeling (raw text)", + "completion": "Text completion tasks", } parser.add_argument( "--training_mode", - type=str, - default="instruction", - choices=list(TRAINING_MODES.keys()), - help="Training mode for the model" + type = str, + default = "instruction", + choices = list(TRAINING_MODES.keys()), + help = "Training mode for the model", ) args = parser.parse_args() diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index dd6d6b96d..d809eb312 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -26,23 +26,24 @@ __all__ = [ ] SUPPORTED_FORMATS = { - '.txt': 'plain_text', - '.md': 'markdown', - '.json': 'json_lines', - '.jsonl': 'json_lines', - '.csv': 'csv_text_column' + ".txt": "plain_text", + ".md": "markdown", + ".json": "json_lines", + ".jsonl": "json_lines", + ".csv": "csv_text_column", } + class RawTextDataLoader: - def __init__(self, tokenizer, chunk_size=2048, stride=512): + def __init__(self, tokenizer, chunk_size = 2048, stride = 512): self.tokenizer = tokenizer - self.chunk_size = chunk_size + self.chunk_size = chunk_size self.stride = stride - + def detect_format(self, file_path): """Auto-detect file format and parse accordingly""" extension = Path(file_path).suffix.lower() - return SUPPORTED_FORMATS.get(extension, 'plain_text') + return SUPPORTED_FORMATS.get(extension, "plain_text") def load_from_file(self, file_path): """Load raw text and convert to dataset""" @@ -50,7 +51,7 @@ class RawTextDataLoader: text_content = self._read_file_by_format(file_path, file_format) chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride) return self.create_causal_dataset(chunks) - + def load_from_files(self, file_paths): """Load multiple text files""" all_chunks = [] @@ -61,11 +62,10 @@ class RawTextDataLoader: all_chunks.extend(chunks) return self.create_causal_dataset(all_chunks) - def chunk_text(self, text): """Split text into overlapping chunks""" return self.smart_chunk_text(text, self.chunk_size, self.stride) - + def create_causal_dataset(self, chunks): """Create dataset for causal language modeling""" return Dataset.from_dict({"text": chunks}) @@ -79,51 +79,50 @@ class RawTextDataLoader: 4. Adds proper EOS tokens """ # First pass: tokenize the entire text to get accurate token counts - tokenized = self.tokenizer(text, return_tensors="pt", add_special_tokens=False) + tokenized = self.tokenizer(text, return_tensors = "pt", add_special_tokens = False) tokens = tokenized["input_ids"] - + # Handle different tokenizer return formats - if hasattr(tokens, '__len__') and len(tokens) > 0: + if hasattr(tokens, "__len__") and len(tokens) > 0: # If it's a nested structure, get the first element - if hasattr(tokens[0], '__len__'): + if hasattr(tokens[0], "__len__"): tokens = tokens[0] elif isinstance(tokens, int): # If tokenizer returns just a count, create a simple range tokens = list(range(tokens)) - + if len(tokens) <= chunk_size: # Text is small enough to fit in one chunk eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "" return [text + eos_token] - + chunks = [] start_idx = 0 - + while start_idx < len(tokens): # Calculate end index for this chunk end_idx = min(start_idx + chunk_size, len(tokens)) - + # Extract tokens for this chunk chunk_tokens = tokens[start_idx:end_idx] - + # Decode back to text - chunk_text = self.tokenizer.decode(chunk_tokens, skip_special_tokens=True) - + chunk_text = self.tokenizer.decode(chunk_tokens, skip_special_tokens = True) + # Add EOS token if it's the last chunk or chunk is complete if end_idx == len(tokens) or len(chunk_tokens) == chunk_size: eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "" chunk_text += eos_token - + chunks.append(chunk_text) - + # Move to next chunk with stride overlap if end_idx == len(tokens): break start_idx += chunk_size - stride - + return chunks - def tokenize_and_chunk(self, text): """ Tokenize first, then chunk by token count: @@ -134,10 +133,10 @@ class RawTextDataLoader: def _read_file_by_format(self, file_path, file_format): """Read file content based on detected format.""" - with open(file_path, 'r', encoding='utf-8') as f: - if file_format == 'plain_text' or file_format == 'markdown': + with open(file_path, "r", encoding = "utf-8") as f: + if file_format == "plain_text" or file_format == "markdown": return f.read() - elif file_format == 'json_lines': + elif file_format == "json_lines": lines = [] for line in f: try: @@ -147,42 +146,43 @@ class RawTextDataLoader: lines.append(text) except json.JSONDecodeError: continue - return '\n\n'.join(lines) - elif file_format == 'csv_text_column': + return "\n\n".join(lines) + elif file_format == "csv_text_column": reader = csv.DictReader(f) texts = [] for row in reader: text = self._extract_text_from_csv_row(row) if text: texts.append(text) - return '\n\n'.join(texts) + return "\n\n".join(texts) return "" - + def _extract_text_from_json(self, data): """Extract text from JSON object using common field names.""" - text_fields = ['text', 'content', 'message', 'body', 'description', 'prompt'] + text_fields = ["text", "content", "message", "body", "description", "prompt"] for field in text_fields: if field in data and isinstance(data[field], str): return data[field] return "" - + def _extract_text_from_csv_row(self, row): """Extract text from CSV row using common column names.""" - text_columns = ['text', 'content', 'message', 'body', 'description', 'prompt'] + text_columns = ["text", "content", "message", "body", "description", "prompt"] for column in text_columns: if column in row and row[column]: return row[column] return "" + class TextPreprocessor: def clean_text(self, text): """Remove unwanted characters, normalize whitespace""" - text = re.sub(r'\s+', ' ', text) - text = re.sub(r'[^\x20-\x7E\n\t]', '', text) - text = text.replace('\r\n', '\n').replace('\r', '\n') - text = re.sub(r'\n{3,}', '\n\n', text) + text = re.sub(r"\s+", " ", text) + text = re.sub(r"[^\x20-\x7E\n\t]", "", text) + text = text.replace("\r\n", "\n").replace("\r", "\n") + text = re.sub(r"\n{3,}", "\n\n", text) return text.strip() - + def extract_sections(self, text, patterns): """Extract specific sections (e.g., code blocks, quotes)""" sections = [] @@ -193,12 +193,20 @@ class TextPreprocessor: def add_structure_tokens(self, text): """Add special tokens for structure (chapters, sections)""" - text = re.sub(r'^# (.+)$', r'<|chapter|>\1<|/chapter|>', text, flags=re.MULTILINE) - text = re.sub(r'^## (.+)$', r'<|section|>\1<|/section|>', text, flags=re.MULTILINE) - text = re.sub(r'^### (.+)$', r'<|subsection|>\1<|/subsection|>', text, flags=re.MULTILINE) - text = re.sub(r'```(\w*)\n(.*?)\n```', r'<|code|\1|>\2<|/code|>', text, flags=re.DOTALL) + text = re.sub( + r"^# (.+)$", r"<|chapter|>\1<|/chapter|>", text, flags = re.MULTILINE + ) + text = re.sub( + r"^## (.+)$", r"<|section|>\1<|/section|>", text, flags = re.MULTILINE + ) + text = re.sub( + r"^### (.+)$", r"<|subsection|>\1<|/subsection|>", text, flags = re.MULTILINE + ) + text = re.sub( + r"```(\w*)\n(.*?)\n```", r"<|code|\1|>\2<|/code|>", text, flags = re.DOTALL + ) return text - + def validate_dataset(self, dataset): """ Check for: @@ -208,61 +216,66 @@ class TextPreprocessor: - Empty chunks """ stats = { - 'total_samples': len(dataset), - 'empty_samples': 0, - 'min_length': float('inf'), - 'max_length': 0, - 'avg_length': 0, - 'repeated_content': 0, - 'encoding_issues': 0, - 'warnings': [] + "total_samples": len(dataset), + "empty_samples": 0, + "min_length": float("inf"), + "max_length": 0, + "avg_length": 0, + "repeated_content": 0, + "encoding_issues": 0, + "warnings": [], } - - texts = dataset['text'] + + texts = dataset["text"] text_lengths = [] seen_texts = set() - + for i, text in enumerate(texts): if not text or len(text.strip()) == 0: - stats['empty_samples'] += 1 + stats["empty_samples"] += 1 continue - + # Check for encoding issues try: - text.encode('utf-8') + text.encode("utf-8") except UnicodeEncodeError: - stats['encoding_issues'] += 1 - + stats["encoding_issues"] += 1 + # Calculate lengths length = len(text) text_lengths.append(length) - stats['min_length'] = min(stats['min_length'], length) - stats['max_length'] = max(stats['max_length'], length) - + stats["min_length"] = min(stats["min_length"], length) + stats["max_length"] = max(stats["max_length"], length) + # Check for repeated content text_hash = hash(text.strip()) if text_hash in seen_texts: - stats['repeated_content'] += 1 + stats["repeated_content"] += 1 else: seen_texts.add(text_hash) - + # Calculate average length if text_lengths: - stats['avg_length'] = sum(text_lengths) / len(text_lengths) - stats['min_length'] = stats['min_length'] if stats['min_length'] != float('inf') else 0 - - # Generate warnings - if stats['empty_samples'] > 0: - stats['warnings'].append(f"Found {stats['empty_samples']} empty samples") - - if stats['repeated_content'] > 0: - stats['warnings'].append(f"Found {stats['repeated_content']} repeated samples") - - if stats['encoding_issues'] > 0: - stats['warnings'].append(f"Found {stats['encoding_issues']} encoding issues") - - if stats['min_length'] < 10: - stats['warnings'].append("Some samples are very short (< 10 characters)") - - return stats + stats["avg_length"] = sum(text_lengths) / len(text_lengths) + stats["min_length"] = ( + stats["min_length"] if stats["min_length"] != float("inf") else 0 + ) + # Generate warnings + if stats["empty_samples"] > 0: + stats["warnings"].append(f"Found {stats['empty_samples']} empty samples") + + if stats["repeated_content"] > 0: + stats["warnings"].append( + f"Found {stats['repeated_content']} repeated samples" + ) + + if stats["encoding_issues"] > 0: + stats["warnings"].append( + f"Found {stats['encoding_issues']} encoding issues" + ) + + if stats["min_length"] < 10: + stats["warnings"].append("Some samples are very short (< 10 characters)") + + return stats From c20a3b40ee5b51786e54206cd340ec43368d1d17 Mon Sep 17 00:00:00 2001 From: vangmay Date: Thu, 20 Nov 2025 20:53:22 +0800 Subject: [PATCH 10/21] Integrate smart dataset loader --- unsloth-cli.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/unsloth-cli.py b/unsloth-cli.py index aac0e7f7e..044ad93f3 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -100,6 +100,8 @@ def run(args): return {"text": texts} def load_dataset_smart(args): + from transformers.utils import strtobool + if args.raw_text_file: # Use raw text loader loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) @@ -109,20 +111,21 @@ def run(args): loader = RawTextDataLoader(tokenizer) dataset = loader.load_from_file(args.dataset) else: - # Existing HuggingFace dataset logic - dataset = load_dataset(args.dataset, split="train") + # Check for modelscope usage + use_modelscope = strtobool(os.environ.get("UNSLOTH_USE_MODELSCOPE", "False")) + if use_modelscope: + from modelscope import MsDataset + dataset = MsDataset.load(args.dataset, split="train") + else: + # Existing HuggingFace dataset logic + dataset = load_dataset(args.dataset, split="train") + + # Apply formatting for structured datasets dataset = dataset.map(formatting_prompts_func, batched=True) return dataset - use_modelscope = strtobool(os.environ.get("UNSLOTH_USE_MODELSCOPE", "False")) - if use_modelscope: - from modelscope import MsDataset - - dataset = MsDataset.load(args.dataset, split = "train") - else: - # Load and format dataset - dataset = load_dataset(args.dataset, split = "train") - dataset = dataset.map(formatting_prompts_func, batched = True) + # Load dataset using smart loader + dataset = load_dataset_smart(args) print("Data is formatted and ready!") # Configure training arguments From 25e69f2d36ca4212a26ae9f828e338bc15c1e9e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 12:57:53 +0000 Subject: [PATCH 11/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth-cli.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/unsloth-cli.py b/unsloth-cli.py index 44232e19d..79b3ef529 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -101,7 +101,7 @@ def run(args): def load_dataset_smart(args): from transformers.utils import strtobool - + if args.raw_text_file: # Use raw text loader loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) @@ -112,16 +112,19 @@ def run(args): dataset = loader.load_from_file(args.dataset) else: # Check for modelscope usage - use_modelscope = strtobool(os.environ.get("UNSLOTH_USE_MODELSCOPE", "False")) + use_modelscope = strtobool( + os.environ.get("UNSLOTH_USE_MODELSCOPE", "False") + ) if use_modelscope: from modelscope import MsDataset - dataset = MsDataset.load(args.dataset, split="train") + + dataset = MsDataset.load(args.dataset, split = "train") else: # Existing HuggingFace dataset logic - dataset = load_dataset(args.dataset, split="train") - + dataset = load_dataset(args.dataset, split = "train") + # Apply formatting for structured datasets - dataset = dataset.map(formatting_prompts_func, batched=True) + dataset = dataset.map(formatting_prompts_func, batched = True) return dataset # Load dataset using smart loader From f05169e56a7ec5e5af5cbeba7b4781db972054c9 Mon Sep 17 00:00:00 2001 From: vangmay Date: Thu, 20 Nov 2025 21:08:33 +0800 Subject: [PATCH 12/21] Make the chunk function efficient --- tests/test_raw_text.py | 24 ++++++++-- unsloth-cli.py | 20 ++++++--- unsloth/dataprep/raw_text.py | 85 ++++++++++++++++++++++++++++-------- 3 files changed, 99 insertions(+), 30 deletions(-) diff --git a/tests/test_raw_text.py b/tests/test_raw_text.py index 9bbfee92a..88dac2604 100644 --- a/tests/test_raw_text.py +++ b/tests/test_raw_text.py @@ -61,6 +61,7 @@ def test_raw_text_loader(): class MockTokenizer: def __init__(self): self.eos_token = "" + self.eos_token_id = 2 # Mock EOS token ID def __call__(self, text, return_tensors = None, add_special_tokens = False): words = text.split() @@ -77,6 +78,9 @@ def test_raw_text_loader(): def __len__(self): return len(self.data) + + def tolist(self): + return self.data return {"input_ids": [MockTensor(token_ids)]} return {"input_ids": token_ids} @@ -95,10 +99,22 @@ def test_raw_text_loader(): tokenizer = MockTokenizer() loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2) - # Test loading - dataset = loader.load_from_file(test_file) - assert len(dataset) > 0, "Should create at least one chunk" - assert "text" in dataset.column_names, "Dataset should have 'text' column" + # Test loading with text output (legacy mode) + text_dataset = loader.load_from_file(test_file, return_tensors=False) + assert len(text_dataset) > 0, "Should create at least one chunk" + assert "text" in text_dataset.column_names, "Dataset should have 'text' column" + + # Test loading with tokenized output (new efficient mode) + tokenized_dataset = loader.load_from_file(test_file, return_tensors=True) + assert len(tokenized_dataset) > 0, "Should create at least one tokenized chunk" + assert "input_ids" in tokenized_dataset.column_names, "Dataset should have 'input_ids' column" + assert "attention_mask" in tokenized_dataset.column_names, "Dataset should have 'attention_mask' column" + + # Verify tokenized data structure + first_sample = tokenized_dataset[0] + assert isinstance(first_sample["input_ids"], list), "input_ids should be a list" + assert isinstance(first_sample["attention_mask"], list), "attention_mask should be a list" + assert len(first_sample["input_ids"]) == len(first_sample["attention_mask"]), "input_ids and attention_mask should have same length" # Test preprocessor preprocessor = TextPreprocessor() diff --git a/unsloth-cli.py b/unsloth-cli.py index 79b3ef529..efdc0b3f4 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -103,13 +103,19 @@ def run(args): from transformers.utils import strtobool if args.raw_text_file: - # Use raw text loader + # Use raw text loader - returns pre-tokenized data loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) - dataset = loader.load_from_file(args.raw_text_file) + dataset = loader.load_from_file(args.raw_text_file, return_tensors=True) + # Mark dataset as pre-tokenized to skip text formatting + dataset._is_pretokenized = True + return dataset elif args.dataset.endswith((".txt", ".md", ".json", ".jsonl")): - # Auto-detect local raw text files - loader = RawTextDataLoader(tokenizer) - dataset = loader.load_from_file(args.dataset) + # Auto-detect local raw text files - returns pre-tokenized data + loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) + dataset = loader.load_from_file(args.dataset, return_tensors=True) + # Mark dataset as pre-tokenized to skip text formatting + dataset._is_pretokenized = True + return dataset else: # Check for modelscope usage use_modelscope = strtobool( @@ -123,9 +129,9 @@ def run(args): # Existing HuggingFace dataset logic dataset = load_dataset(args.dataset, split = "train") - # Apply formatting for structured datasets + # Apply formatting for structured datasets (text-based) dataset = dataset.map(formatting_prompts_func, batched = True) - return dataset + return dataset # Load dataset using smart loader dataset = load_dataset_smart(args) diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index d809eb312..50612c1b8 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -35,48 +35,66 @@ SUPPORTED_FORMATS = { class RawTextDataLoader: - def __init__(self, tokenizer, chunk_size = 2048, stride = 512): + def __init__(self, tokenizer, chunk_size = 2048, stride = 512, return_tokenized = True): self.tokenizer = tokenizer self.chunk_size = chunk_size self.stride = stride + self.return_tokenized = return_tokenized def detect_format(self, file_path): """Auto-detect file format and parse accordingly""" extension = Path(file_path).suffix.lower() return SUPPORTED_FORMATS.get(extension, "plain_text") - def load_from_file(self, file_path): + def load_from_file(self, file_path, return_tokenized=None): """Load raw text and convert to dataset""" + if return_tokenized is None: + return_tokenized = self.return_tokenized file_format = self.detect_format(file_path) text_content = self._read_file_by_format(file_path, file_format) - chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride) + chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride, return_tokenized) return self.create_causal_dataset(chunks) - def load_from_files(self, file_paths): + def load_from_files(self, file_paths, return_tokenized=None): """Load multiple text files""" + if return_tokenized is None: + return_tokenized = self.return_tokenized all_chunks = [] for file_path in file_paths: file_format = self.detect_format(file_path) text_content = self._read_file_by_format(file_path, file_format) - chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride) + chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride, return_tokenized) all_chunks.extend(chunks) return self.create_causal_dataset(all_chunks) - def chunk_text(self, text): + def chunk_text(self, text, return_tokenized=None): """Split text into overlapping chunks""" - return self.smart_chunk_text(text, self.chunk_size, self.stride) + if return_tokenized is None: + return_tokenized = self.return_tokenized + return self.smart_chunk_text(text, self.chunk_size, self.stride, return_tokenized) def create_causal_dataset(self, chunks): """Create dataset for causal language modeling""" - return Dataset.from_dict({"text": chunks}) + if chunks and isinstance(chunks[0], dict): + # If chunks are already tokenized (dict with input_ids, attention_mask) + # Reorganize the data structure for Dataset.from_dict + input_ids = [chunk["input_ids"] for chunk in chunks] + attention_mask = [chunk["attention_mask"] for chunk in chunks] + return Dataset.from_dict({ + "input_ids": input_ids, + "attention_mask": attention_mask + }) + else: + # If chunks are text strings (backward compatibility) + return Dataset.from_dict({"text": chunks}) - def smart_chunk_text(self, text, chunk_size, stride): + def smart_chunk_text(self, text, chunk_size, stride, return_tokenized=True): """ Intelligent chunking that: 1. Respects sentence/paragraph boundaries 2. Handles various text formats (.txt, .md, .json, etc.) 3. Maintains context with stride overlap - 4. Adds proper EOS tokens + 4. Returns tokenized chunks directly (more efficient) or text chunks """ # First pass: tokenize the entire text to get accurate token counts tokenized = self.tokenizer(text, return_tensors = "pt", add_special_tokens = False) @@ -93,8 +111,19 @@ class RawTextDataLoader: if len(tokens) <= chunk_size: # Text is small enough to fit in one chunk - eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "" - return [text + eos_token] + if return_tokenized: + # Add EOS token to the tokens if available + eos_token_id = getattr(self.tokenizer, 'eos_token_id', None) + if eos_token_id is not None: + tokens = tokens.tolist() if hasattr(tokens, 'tolist') else list(tokens) + tokens.append(eos_token_id) + + # Create attention mask + attention_mask = [1] * len(tokens) + return [{"input_ids": tokens, "attention_mask": attention_mask}] + else: + eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "" + return [text + eos_token] chunks = [] start_idx = 0 @@ -106,15 +135,33 @@ class RawTextDataLoader: # Extract tokens for this chunk chunk_tokens = tokens[start_idx:end_idx] - # Decode back to text - chunk_text = self.tokenizer.decode(chunk_tokens, skip_special_tokens = True) + if return_tokenized: + # Convert to list if it's a tensor + chunk_tokens_list = chunk_tokens.tolist() if hasattr(chunk_tokens, 'tolist') else list(chunk_tokens) + + # Add EOS token if it's the last chunk or chunk is complete + if end_idx == len(tokens) or len(chunk_tokens_list) == chunk_size: + eos_token_id = getattr(self.tokenizer, 'eos_token_id', None) + if eos_token_id is not None: + chunk_tokens_list.append(eos_token_id) - # Add EOS token if it's the last chunk or chunk is complete - if end_idx == len(tokens) or len(chunk_tokens) == chunk_size: - eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "" - chunk_text += eos_token + # Create attention mask (all tokens are attended to) + attention_mask = [1] * len(chunk_tokens_list) + + chunks.append({ + "input_ids": chunk_tokens_list, + "attention_mask": attention_mask + }) + else: + # Decode back to text (backward compatibility) + chunk_text = self.tokenizer.decode(chunk_tokens, skip_special_tokens = True) - chunks.append(chunk_text) + # Add EOS token if it's the last chunk or chunk is complete + if end_idx == len(tokens) or len(chunk_tokens) == chunk_size: + eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "" + chunk_text += eos_token + + chunks.append(chunk_text) # Move to next chunk with stride overlap if end_idx == len(tokens): From 3bf8ca7da21d31f4ea2294b0d00f2f690f5c8de3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 13:09:07 +0000 Subject: [PATCH 13/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_raw_text.py | 24 +++++++++----- unsloth-cli.py | 4 +-- unsloth/dataprep/raw_text.py | 62 ++++++++++++++++++++++-------------- 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/tests/test_raw_text.py b/tests/test_raw_text.py index 88dac2604..5306c68fa 100644 --- a/tests/test_raw_text.py +++ b/tests/test_raw_text.py @@ -78,7 +78,7 @@ def test_raw_text_loader(): def __len__(self): return len(self.data) - + def tolist(self): return self.data @@ -100,21 +100,29 @@ def test_raw_text_loader(): loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2) # Test loading with text output (legacy mode) - text_dataset = loader.load_from_file(test_file, return_tensors=False) + text_dataset = loader.load_from_file(test_file, return_tensors = False) assert len(text_dataset) > 0, "Should create at least one chunk" assert "text" in text_dataset.column_names, "Dataset should have 'text' column" # Test loading with tokenized output (new efficient mode) - tokenized_dataset = loader.load_from_file(test_file, return_tensors=True) + tokenized_dataset = loader.load_from_file(test_file, return_tensors = True) assert len(tokenized_dataset) > 0, "Should create at least one tokenized chunk" - assert "input_ids" in tokenized_dataset.column_names, "Dataset should have 'input_ids' column" - assert "attention_mask" in tokenized_dataset.column_names, "Dataset should have 'attention_mask' column" - + assert ( + "input_ids" in tokenized_dataset.column_names + ), "Dataset should have 'input_ids' column" + assert ( + "attention_mask" in tokenized_dataset.column_names + ), "Dataset should have 'attention_mask' column" + # Verify tokenized data structure first_sample = tokenized_dataset[0] assert isinstance(first_sample["input_ids"], list), "input_ids should be a list" - assert isinstance(first_sample["attention_mask"], list), "attention_mask should be a list" - assert len(first_sample["input_ids"]) == len(first_sample["attention_mask"]), "input_ids and attention_mask should have same length" + assert isinstance( + first_sample["attention_mask"], list + ), "attention_mask should be a list" + assert len(first_sample["input_ids"]) == len( + first_sample["attention_mask"] + ), "input_ids and attention_mask should have same length" # Test preprocessor preprocessor = TextPreprocessor() diff --git a/unsloth-cli.py b/unsloth-cli.py index efdc0b3f4..5135d0049 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -105,14 +105,14 @@ def run(args): if args.raw_text_file: # Use raw text loader - returns pre-tokenized data loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) - dataset = loader.load_from_file(args.raw_text_file, return_tensors=True) + dataset = loader.load_from_file(args.raw_text_file, return_tensors = True) # Mark dataset as pre-tokenized to skip text formatting dataset._is_pretokenized = True return dataset elif args.dataset.endswith((".txt", ".md", ".json", ".jsonl")): # Auto-detect local raw text files - returns pre-tokenized data loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) - dataset = loader.load_from_file(args.dataset, return_tensors=True) + dataset = loader.load_from_file(args.dataset, return_tensors = True) # Mark dataset as pre-tokenized to skip text formatting dataset._is_pretokenized = True return dataset diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index 50612c1b8..d6a9a6a04 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -46,16 +46,18 @@ class RawTextDataLoader: extension = Path(file_path).suffix.lower() return SUPPORTED_FORMATS.get(extension, "plain_text") - def load_from_file(self, file_path, return_tokenized=None): + def load_from_file(self, file_path, return_tokenized = None): """Load raw text and convert to dataset""" if return_tokenized is None: return_tokenized = self.return_tokenized file_format = self.detect_format(file_path) text_content = self._read_file_by_format(file_path, file_format) - chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride, return_tokenized) + chunks = self.smart_chunk_text( + text_content, self.chunk_size, self.stride, return_tokenized + ) return self.create_causal_dataset(chunks) - def load_from_files(self, file_paths, return_tokenized=None): + def load_from_files(self, file_paths, return_tokenized = None): """Load multiple text files""" if return_tokenized is None: return_tokenized = self.return_tokenized @@ -63,15 +65,19 @@ class RawTextDataLoader: for file_path in file_paths: file_format = self.detect_format(file_path) text_content = self._read_file_by_format(file_path, file_format) - chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride, return_tokenized) + chunks = self.smart_chunk_text( + text_content, self.chunk_size, self.stride, return_tokenized + ) all_chunks.extend(chunks) return self.create_causal_dataset(all_chunks) - def chunk_text(self, text, return_tokenized=None): + def chunk_text(self, text, return_tokenized = None): """Split text into overlapping chunks""" if return_tokenized is None: return_tokenized = self.return_tokenized - return self.smart_chunk_text(text, self.chunk_size, self.stride, return_tokenized) + return self.smart_chunk_text( + text, self.chunk_size, self.stride, return_tokenized + ) def create_causal_dataset(self, chunks): """Create dataset for causal language modeling""" @@ -80,15 +86,14 @@ class RawTextDataLoader: # Reorganize the data structure for Dataset.from_dict input_ids = [chunk["input_ids"] for chunk in chunks] attention_mask = [chunk["attention_mask"] for chunk in chunks] - return Dataset.from_dict({ - "input_ids": input_ids, - "attention_mask": attention_mask - }) + return Dataset.from_dict( + {"input_ids": input_ids, "attention_mask": attention_mask} + ) else: # If chunks are text strings (backward compatibility) return Dataset.from_dict({"text": chunks}) - def smart_chunk_text(self, text, chunk_size, stride, return_tokenized=True): + def smart_chunk_text(self, text, chunk_size, stride, return_tokenized = True): """ Intelligent chunking that: 1. Respects sentence/paragraph boundaries @@ -113,11 +118,13 @@ class RawTextDataLoader: # Text is small enough to fit in one chunk if return_tokenized: # Add EOS token to the tokens if available - eos_token_id = getattr(self.tokenizer, 'eos_token_id', None) + eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if eos_token_id is not None: - tokens = tokens.tolist() if hasattr(tokens, 'tolist') else list(tokens) + tokens = ( + tokens.tolist() if hasattr(tokens, "tolist") else list(tokens) + ) tokens.append(eos_token_id) - + # Create attention mask attention_mask = [1] * len(tokens) return [{"input_ids": tokens, "attention_mask": attention_mask}] @@ -137,28 +144,35 @@ class RawTextDataLoader: if return_tokenized: # Convert to list if it's a tensor - chunk_tokens_list = chunk_tokens.tolist() if hasattr(chunk_tokens, 'tolist') else list(chunk_tokens) - + chunk_tokens_list = ( + chunk_tokens.tolist() + if hasattr(chunk_tokens, "tolist") + else list(chunk_tokens) + ) + # Add EOS token if it's the last chunk or chunk is complete if end_idx == len(tokens) or len(chunk_tokens_list) == chunk_size: - eos_token_id = getattr(self.tokenizer, 'eos_token_id', None) + eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if eos_token_id is not None: chunk_tokens_list.append(eos_token_id) # Create attention mask (all tokens are attended to) attention_mask = [1] * len(chunk_tokens_list) - - chunks.append({ - "input_ids": chunk_tokens_list, - "attention_mask": attention_mask - }) + + chunks.append( + {"input_ids": chunk_tokens_list, "attention_mask": attention_mask} + ) else: # Decode back to text (backward compatibility) - chunk_text = self.tokenizer.decode(chunk_tokens, skip_special_tokens = True) + chunk_text = self.tokenizer.decode( + chunk_tokens, skip_special_tokens = True + ) # Add EOS token if it's the last chunk or chunk is complete if end_idx == len(tokens) or len(chunk_tokens) == chunk_size: - eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "" + eos_token = ( + self.tokenizer.eos_token if self.tokenizer.eos_token else "" + ) chunk_text += eos_token chunks.append(chunk_text) From 082da69cc4f39637567ac69ac78d482bf1666f60 Mon Sep 17 00:00:00 2001 From: vangmay Date: Thu, 20 Nov 2025 21:40:45 +0800 Subject: [PATCH 14/21] remove old function --- unsloth/dataprep/raw_text.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index d6a9a6a04..b880e9733 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -184,14 +184,6 @@ class RawTextDataLoader: return chunks - def tokenize_and_chunk(self, text): - """ - Tokenize first, then chunk by token count: - 1. More precise length control - 2. Avoids mid-token splits - 3. Handles different languages better - """ - def _read_file_by_format(self, file_path, file_format): """Read file content based on detected format.""" with open(file_path, "r", encoding = "utf-8") as f: From 646629884b58655fc56bc0d597960be13bef36f6 Mon Sep 17 00:00:00 2001 From: vangmay Date: Tue, 25 Nov 2025 21:01:43 +0800 Subject: [PATCH 15/21] Remove training mode arg --- unsloth-cli.py | 34 +++++++--------------------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/unsloth-cli.py b/unsloth-cli.py index 5135d0049..ef8167ad3 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -103,19 +103,13 @@ def run(args): from transformers.utils import strtobool if args.raw_text_file: - # Use raw text loader - returns pre-tokenized data + # Use raw text loader loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) - dataset = loader.load_from_file(args.raw_text_file, return_tensors = True) - # Mark dataset as pre-tokenized to skip text formatting - dataset._is_pretokenized = True - return dataset + dataset = loader.load_from_file(args.raw_text_file) elif args.dataset.endswith((".txt", ".md", ".json", ".jsonl")): - # Auto-detect local raw text files - returns pre-tokenized data - loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride) - dataset = loader.load_from_file(args.dataset, return_tensors = True) - # Mark dataset as pre-tokenized to skip text formatting - dataset._is_pretokenized = True - return dataset + # Auto-detect local raw text files + loader = RawTextDataLoader(tokenizer) + dataset = loader.load_from_file(args.dataset) else: # Check for modelscope usage use_modelscope = strtobool( @@ -129,9 +123,9 @@ def run(args): # Existing HuggingFace dataset logic dataset = load_dataset(args.dataset, split = "train") - # Apply formatting for structured datasets (text-based) + # Apply formatting for structured datasets dataset = dataset.map(formatting_prompts_func, batched = True) - return dataset + return dataset # Load dataset using smart loader dataset = load_dataset_smart(args) @@ -427,19 +421,5 @@ if __name__ == "__main__": "--stride", type = int, default = 512, help = "Overlap between chunks" ) - TRAINING_MODES = { - "instruction": "Standard instruction-following", - "causal": "Causal language modeling (raw text)", - "completion": "Text completion tasks", - } - - parser.add_argument( - "--training_mode", - type = str, - default = "instruction", - choices = list(TRAINING_MODES.keys()), - help = "Training mode for the model", - ) - args = parser.parse_args() run(args) From fe36643c66439b0492e9cc547b151b337f553f56 Mon Sep 17 00:00:00 2001 From: vangmay Date: Wed, 10 Dec 2025 10:15:56 +0530 Subject: [PATCH 16/21] Fix RawTextDataLoader import issue --- unsloth/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 8b48ce3ba..b77af7c24 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -247,6 +247,8 @@ from .save import * from .chat_templates import * from .tokenizer_utils import * from .trainer import * +# Export dataprep utilities for CLI and downstream users +from .dataprep.raw_text import RawTextDataLoader, TextPreprocessor from unsloth_zoo.rl_environments import ( check_python_modules, create_locked_down_function, From 07966659d816b934f2048db05c25e41384687160 Mon Sep 17 00:00:00 2001 From: vangmay Date: Wed, 10 Dec 2025 10:17:23 +0530 Subject: [PATCH 17/21] Fix Incorrect non-relative import in dataprep package --- unsloth/dataprep/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/__init__.py b/unsloth/dataprep/__init__.py index b6840f247..048f9b801 100644 --- a/unsloth/dataprep/__init__.py +++ b/unsloth/dataprep/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. from .synthetic import * -from raw_text import * +from .raw_text import * From 96eba88c90ac46bfe60bf7c07a20051408eb39be Mon Sep 17 00:00:00 2001 From: vangmay Date: Wed, 10 Dec 2025 10:46:29 +0530 Subject: [PATCH 18/21] =?UTF-8?q?Fix=20Chunking=20loop=20can=20hang=20when?= =?UTF-8?q?=20stride=20=E2=89=A5=20chunk=5Fsize?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- unsloth/dataprep/raw_text.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index b880e9733..f7c1bf785 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -101,6 +101,13 @@ class RawTextDataLoader: 3. Maintains context with stride overlap 4. Returns tokenized chunks directly (more efficient) or text chunks """ + if chunk_size <= 0: + raise ValueError(f"chunk_size must be positive, got {chunk_size}") + if stride >= chunk_size: + raise ValueError( + f"stride ({stride}) must be smaller than chunk_size ({chunk_size}) to progress the chunking loop" + ) + # First pass: tokenize the entire text to get accurate token counts tokenized = self.tokenizer(text, return_tensors = "pt", add_special_tokens = False) tokens = tokenized["input_ids"] From fb565d52f01d68c9b234214e79cfbd7ab890b820 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 05:17:01 +0000 Subject: [PATCH 19/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index b77af7c24..26e43eec8 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -247,6 +247,7 @@ from .save import * from .chat_templates import * from .tokenizer_utils import * from .trainer import * + # Export dataprep utilities for CLI and downstream users from .dataprep.raw_text import RawTextDataLoader, TextPreprocessor from unsloth_zoo.rl_environments import ( From 16a2d901fae32e9eb570cb9cddf0b78bc09861d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 8 Jan 2026 11:35:00 +0000 Subject: [PATCH 20/21] Fix bugs and add improvements to RawTextDataLoader - Fix test file: use return_tokenized instead of return_tensors - Fix test file: use text_dataset instead of undefined dataset variable - Move parameter validation to constructor (fail fast on invalid params) - Add labels field in tokenized output for causal LM training - Add empty file handling with clear error message - Add tests for constructor validation and labels field --- tests/test_raw_text.py | 23 ++++++++++++++++++++--- unsloth/dataprep/raw_text.py | 19 +++++++++++-------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/tests/test_raw_text.py b/tests/test_raw_text.py index 5306c68fa..7c7272a55 100644 --- a/tests/test_raw_text.py +++ b/tests/test_raw_text.py @@ -100,12 +100,12 @@ def test_raw_text_loader(): loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2) # Test loading with text output (legacy mode) - text_dataset = loader.load_from_file(test_file, return_tensors = False) + text_dataset = loader.load_from_file(test_file, return_tokenized = False) assert len(text_dataset) > 0, "Should create at least one chunk" assert "text" in text_dataset.column_names, "Dataset should have 'text' column" # Test loading with tokenized output (new efficient mode) - tokenized_dataset = loader.load_from_file(test_file, return_tensors = True) + tokenized_dataset = loader.load_from_file(test_file, return_tokenized = True) assert len(tokenized_dataset) > 0, "Should create at least one tokenized chunk" assert ( "input_ids" in tokenized_dataset.column_names @@ -124,13 +124,30 @@ def test_raw_text_loader(): first_sample["attention_mask"] ), "input_ids and attention_mask should have same length" + # Verify labels field exists (for causal LM training) + assert "labels" in tokenized_dataset.column_names, "Dataset should have 'labels' column" + assert first_sample["labels"] == first_sample["input_ids"], "labels should match input_ids" + + # Test constructor validation + try: + bad_loader = RawTextDataLoader(tokenizer, chunk_size = 0, stride = 2) + assert False, "Should raise ValueError for chunk_size=0" + except ValueError as e: + assert "chunk_size must be positive" in str(e) + + try: + bad_loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 10) + assert False, "Should raise ValueError for stride >= chunk_size" + except ValueError as e: + assert "stride" in str(e) and "chunk_size" in str(e) + # Test preprocessor preprocessor = TextPreprocessor() clean_text = preprocessor.clean_text(" messy text \n\n\n ") assert "messy text" in clean_text, "Should clean text properly" # Test validation - stats = preprocessor.validate_dataset(dataset) + stats = preprocessor.validate_dataset(text_dataset) assert stats["total_samples"] > 0, "Should count samples" assert "warnings" in stats, "Should include warnings" diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index f7c1bf785..da64565bb 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -36,6 +36,12 @@ SUPPORTED_FORMATS = { class RawTextDataLoader: def __init__(self, tokenizer, chunk_size = 2048, stride = 512, return_tokenized = True): + if chunk_size <= 0: + raise ValueError(f"chunk_size must be positive, got {chunk_size}") + if stride >= chunk_size: + raise ValueError( + f"stride ({stride}) must be smaller than chunk_size ({chunk_size})" + ) self.tokenizer = tokenizer self.chunk_size = chunk_size self.stride = stride @@ -52,6 +58,8 @@ class RawTextDataLoader: return_tokenized = self.return_tokenized file_format = self.detect_format(file_path) text_content = self._read_file_by_format(file_path, file_format) + if not text_content or not text_content.strip(): + raise ValueError(f"File '{file_path}' is empty or contains only whitespace") chunks = self.smart_chunk_text( text_content, self.chunk_size, self.stride, return_tokenized ) @@ -86,8 +94,10 @@ class RawTextDataLoader: # Reorganize the data structure for Dataset.from_dict input_ids = [chunk["input_ids"] for chunk in chunks] attention_mask = [chunk["attention_mask"] for chunk in chunks] + # Labels are same as input_ids for causal LM training + labels = [list(ids) for ids in input_ids] return Dataset.from_dict( - {"input_ids": input_ids, "attention_mask": attention_mask} + {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} ) else: # If chunks are text strings (backward compatibility) @@ -101,13 +111,6 @@ class RawTextDataLoader: 3. Maintains context with stride overlap 4. Returns tokenized chunks directly (more efficient) or text chunks """ - if chunk_size <= 0: - raise ValueError(f"chunk_size must be positive, got {chunk_size}") - if stride >= chunk_size: - raise ValueError( - f"stride ({stride}) must be smaller than chunk_size ({chunk_size}) to progress the chunking loop" - ) - # First pass: tokenize the entire text to get accurate token counts tokenized = self.tokenizer(text, return_tensors = "pt", add_special_tokens = False) tokens = tokenized["input_ids"] From 362056402517da0277849dd83e9b60ca0641945f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 11:35:21 +0000 Subject: [PATCH 21/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_raw_text.py | 8 ++++++-- unsloth/dataprep/raw_text.py | 6 +++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/test_raw_text.py b/tests/test_raw_text.py index 7c7272a55..9f2e8cda4 100644 --- a/tests/test_raw_text.py +++ b/tests/test_raw_text.py @@ -125,8 +125,12 @@ def test_raw_text_loader(): ), "input_ids and attention_mask should have same length" # Verify labels field exists (for causal LM training) - assert "labels" in tokenized_dataset.column_names, "Dataset should have 'labels' column" - assert first_sample["labels"] == first_sample["input_ids"], "labels should match input_ids" + assert ( + "labels" in tokenized_dataset.column_names + ), "Dataset should have 'labels' column" + assert ( + first_sample["labels"] == first_sample["input_ids"] + ), "labels should match input_ids" # Test constructor validation try: diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py index da64565bb..ba010edab 100644 --- a/unsloth/dataprep/raw_text.py +++ b/unsloth/dataprep/raw_text.py @@ -97,7 +97,11 @@ class RawTextDataLoader: # Labels are same as input_ids for causal LM training labels = [list(ids) for ids in input_ids] return Dataset.from_dict( - {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } ) else: # If chunks are text strings (backward compatibility)