Mirror of https://github.com/unslothai/unsloth (synced 2026-04-21 13:37:39 +00:00)
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent f05169e56a
commit 3bf8ca7da2

3 changed files with 56 additions and 34 deletions
@@ -78,7 +78,7 @@ def test_raw_text_loader():
         def __len__(self):
             return len(self.data)
 
         def tolist(self):
             return self.data
@@ -100,21 +100,29 @@ def test_raw_text_loader():
     loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2)
 
     # Test loading with text output (legacy mode)
-    text_dataset = loader.load_from_file(test_file, return_tensors=False)
+    text_dataset = loader.load_from_file(test_file, return_tensors = False)
     assert len(text_dataset) > 0, "Should create at least one chunk"
     assert "text" in text_dataset.column_names, "Dataset should have 'text' column"
 
     # Test loading with tokenized output (new efficient mode)
-    tokenized_dataset = loader.load_from_file(test_file, return_tensors=True)
+    tokenized_dataset = loader.load_from_file(test_file, return_tensors = True)
     assert len(tokenized_dataset) > 0, "Should create at least one tokenized chunk"
-    assert "input_ids" in tokenized_dataset.column_names, "Dataset should have 'input_ids' column"
-    assert "attention_mask" in tokenized_dataset.column_names, "Dataset should have 'attention_mask' column"
+    assert (
+        "input_ids" in tokenized_dataset.column_names
+    ), "Dataset should have 'input_ids' column"
+    assert (
+        "attention_mask" in tokenized_dataset.column_names
+    ), "Dataset should have 'attention_mask' column"
 
     # Verify tokenized data structure
     first_sample = tokenized_dataset[0]
     assert isinstance(first_sample["input_ids"], list), "input_ids should be a list"
-    assert isinstance(first_sample["attention_mask"], list), "attention_mask should be a list"
-    assert len(first_sample["input_ids"]) == len(first_sample["attention_mask"]), "input_ids and attention_mask should have same length"
+    assert isinstance(
+        first_sample["attention_mask"], list
+    ), "attention_mask should be a list"
+    assert len(first_sample["input_ids"]) == len(
+        first_sample["attention_mask"]
+    ), "input_ids and attention_mask should have same length"
 
     # Test preprocessor
     preprocessor = TextPreprocessor()
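Not part of the diff: a minimal sketch of the two loader modes exercised by the test above. The tokenizer choice and the file name are assumptions, and RawTextDataLoader's import path is omitted because it is not shown in this commit.

# Sketch only: mirrors the test above with a real tokenizer.
# RawTextDataLoader is the class defined later in this commit (import not shown).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with an eos_token
loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2)

# Legacy mode: chunks come back as plain strings in a "text" column.
text_dataset = loader.load_from_file("notes.txt", return_tensors = False)  # placeholder file
print(text_dataset.column_names)   # ["text"]

# Efficient mode: chunks come back pre-tokenized with an attention mask.
tok_dataset = loader.load_from_file("notes.txt", return_tensors = True)
print(tok_dataset.column_names)    # ["input_ids", "attention_mask"]
assert len(tok_dataset[0]["input_ids"]) == len(tok_dataset[0]["attention_mask"])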
@@ -105,14 +105,14 @@ def run(args):
     if args.raw_text_file:
         # Use raw text loader - returns pre-tokenized data
         loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride)
-        dataset = loader.load_from_file(args.raw_text_file, return_tensors=True)
+        dataset = loader.load_from_file(args.raw_text_file, return_tensors = True)
         # Mark dataset as pre-tokenized to skip text formatting
         dataset._is_pretokenized = True
         return dataset
     elif args.dataset.endswith((".txt", ".md", ".json", ".jsonl")):
         # Auto-detect local raw text files - returns pre-tokenized data
         loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride)
-        dataset = loader.load_from_file(args.dataset, return_tensors=True)
+        dataset = loader.load_from_file(args.dataset, return_tensors = True)
         # Mark dataset as pre-tokenized to skip text formatting
         dataset._is_pretokenized = True
         return dataset
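The _is_pretokenized marker set above signals that the returned dataset already holds input_ids / attention_mask rows. A hypothetical downstream check might look like the sketch below; only the flag name comes from the diff, everything else is illustrative.

# Hypothetical consumer of the _is_pretokenized marker (illustrative only).
def prepare_for_training(dataset, formatting_func):
    if getattr(dataset, "_is_pretokenized", False):
        # Raw-text path: rows already hold input_ids / attention_mask,
        # so skip the usual text formatting step.
        return dataset
    # Normal path: build a "text" column from each row before tokenization.
    return dataset.map(lambda row: {"text": formatting_func(row)})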
@@ -46,16 +46,18 @@ class RawTextDataLoader:
         extension = Path(file_path).suffix.lower()
         return SUPPORTED_FORMATS.get(extension, "plain_text")
 
-    def load_from_file(self, file_path, return_tokenized=None):
+    def load_from_file(self, file_path, return_tokenized = None):
         """Load raw text and convert to dataset"""
         if return_tokenized is None:
             return_tokenized = self.return_tokenized
         file_format = self.detect_format(file_path)
         text_content = self._read_file_by_format(file_path, file_format)
-        chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride, return_tokenized)
+        chunks = self.smart_chunk_text(
+            text_content, self.chunk_size, self.stride, return_tokenized
+        )
         return self.create_causal_dataset(chunks)
 
-    def load_from_files(self, file_paths, return_tokenized=None):
+    def load_from_files(self, file_paths, return_tokenized = None):
         """Load multiple text files"""
         if return_tokenized is None:
             return_tokenized = self.return_tokenized
@@ -63,15 +65,19 @@ class RawTextDataLoader:
         for file_path in file_paths:
             file_format = self.detect_format(file_path)
             text_content = self._read_file_by_format(file_path, file_format)
-            chunks = self.smart_chunk_text(text_content, self.chunk_size, self.stride, return_tokenized)
+            chunks = self.smart_chunk_text(
+                text_content, self.chunk_size, self.stride, return_tokenized
+            )
             all_chunks.extend(chunks)
         return self.create_causal_dataset(all_chunks)
 
-    def chunk_text(self, text, return_tokenized=None):
+    def chunk_text(self, text, return_tokenized = None):
         """Split text into overlapping chunks"""
         if return_tokenized is None:
             return_tokenized = self.return_tokenized
-        return self.smart_chunk_text(text, self.chunk_size, self.stride, return_tokenized)
+        return self.smart_chunk_text(
+            text, self.chunk_size, self.stride, return_tokenized
+        )
 
     def create_causal_dataset(self, chunks):
         """Create dataset for causal language modeling"""
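For intuition on the chunk_size / stride pair that chunk_text forwards to smart_chunk_text, here is a standalone sketch of overlapping token windows. It assumes stride is the step between window starts (so consecutive windows overlap by chunk_size - stride); the loader's actual stepping logic is not visible in this hunk.

# Standalone illustration of overlapping windows over a token sequence.
# Assumption: "stride" is the step between window starts.
def sliding_windows(tokens, chunk_size, stride):
    windows = []
    for start in range(0, len(tokens), stride):
        window = tokens[start : start + chunk_size]
        if window:
            windows.append(window)
        if start + chunk_size >= len(tokens):
            break
    return windows

print(sliding_windows(list(range(12)), chunk_size = 5, stride = 2))
# [[0, 1, 2, 3, 4], [2, 3, 4, 5, 6], [4, 5, 6, 7, 8], [6, 7, 8, 9, 10], [8, 9, 10, 11]]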
@@ -80,15 +86,14 @@ class RawTextDataLoader:
             # Reorganize the data structure for Dataset.from_dict
             input_ids = [chunk["input_ids"] for chunk in chunks]
             attention_mask = [chunk["attention_mask"] for chunk in chunks]
-            return Dataset.from_dict({
-                "input_ids": input_ids,
-                "attention_mask": attention_mask
-            })
+            return Dataset.from_dict(
+                {"input_ids": input_ids, "attention_mask": attention_mask}
+            )
         else:
             # If chunks are text strings (backward compatibility)
             return Dataset.from_dict({"text": chunks})
 
-    def smart_chunk_text(self, text, chunk_size, stride, return_tokenized=True):
+    def smart_chunk_text(self, text, chunk_size, stride, return_tokenized = True):
         """
         Intelligent chunking that:
         1. Respects sentence/paragraph boundaries
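The two Dataset.from_dict calls in create_causal_dataset produce differently shaped datasets. A toy illustration, assuming the Hugging Face datasets package:

# Toy illustration of the two layouts returned by create_causal_dataset.
from datasets import Dataset

# Tokenized chunks -> columnar input_ids / attention_mask.
tokenized = Dataset.from_dict(
    {"input_ids": [[5, 6, 7], [8, 9]], "attention_mask": [[1, 1, 1], [1, 1]]}
)
print(tokenized.column_names)  # ['input_ids', 'attention_mask']

# Plain text chunks (backward compatibility) -> a single "text" column.
texts = Dataset.from_dict({"text": ["first chunk", "second chunk"]})
print(texts.column_names)      # ['text']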
@@ -113,11 +118,13 @@ class RawTextDataLoader:
             # Text is small enough to fit in one chunk
             if return_tokenized:
                 # Add EOS token to the tokens if available
-                eos_token_id = getattr(self.tokenizer, 'eos_token_id', None)
+                eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
                 if eos_token_id is not None:
-                    tokens = tokens.tolist() if hasattr(tokens, 'tolist') else list(tokens)
+                    tokens = (
+                        tokens.tolist() if hasattr(tokens, "tolist") else list(tokens)
+                    )
                     tokens.append(eos_token_id)
 
                 # Create attention mask
                 attention_mask = [1] * len(tokens)
                 return [{"input_ids": tokens, "attention_mask": attention_mask}]
@@ -137,28 +144,35 @@ class RawTextDataLoader:
 
             if return_tokenized:
                 # Convert to list if it's a tensor
-                chunk_tokens_list = chunk_tokens.tolist() if hasattr(chunk_tokens, 'tolist') else list(chunk_tokens)
+                chunk_tokens_list = (
+                    chunk_tokens.tolist()
+                    if hasattr(chunk_tokens, "tolist")
+                    else list(chunk_tokens)
+                )
 
                 # Add EOS token if it's the last chunk or chunk is complete
                 if end_idx == len(tokens) or len(chunk_tokens_list) == chunk_size:
-                    eos_token_id = getattr(self.tokenizer, 'eos_token_id', None)
+                    eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
                     if eos_token_id is not None:
                         chunk_tokens_list.append(eos_token_id)
 
                 # Create attention mask (all tokens are attended to)
                 attention_mask = [1] * len(chunk_tokens_list)
 
-                chunks.append({
-                    "input_ids": chunk_tokens_list,
-                    "attention_mask": attention_mask
-                })
+                chunks.append(
+                    {"input_ids": chunk_tokens_list, "attention_mask": attention_mask}
+                )
             else:
                 # Decode back to text (backward compatibility)
-                chunk_text = self.tokenizer.decode(chunk_tokens, skip_special_tokens = True)
+                chunk_text = self.tokenizer.decode(
+                    chunk_tokens, skip_special_tokens = True
+                )
 
                 # Add EOS token if it's the last chunk or chunk is complete
                 if end_idx == len(tokens) or len(chunk_tokens) == chunk_size:
-                    eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else ""
+                    eos_token = (
+                        self.tokenizer.eos_token if self.tokenizer.eos_token else ""
+                    )
                     chunk_text += eos_token
 
                 chunks.append(chunk_text)
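Tying the tokenized branch of the loop above together: each chunk becomes a dict of input_ids (with EOS appended when the chunk ends the text or exactly fills chunk_size) plus an all-ones attention_mask. A self-contained toy reconstruction; the stand-in tokenizer carries only the attribute the code relies on.

# Toy reconstruction of the tokenized-chunk record built in the loop above.
class ToyTokenizer:
    eos_token_id = 0  # stand-in value; real tokenizers expose their own id

def make_chunk_record(chunk_tokens, is_final, chunk_size, tokenizer = ToyTokenizer()):
    tokens = list(chunk_tokens)
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    # Append EOS when the chunk is the last one or exactly fills chunk_size.
    if (is_final or len(tokens) == chunk_size) and eos_token_id is not None:
        tokens.append(eos_token_id)
    return {"input_ids": tokens, "attention_mask": [1] * len(tokens)}

print(make_chunk_record([11, 12, 13, 14, 15], is_final = False, chunk_size = 5))
# {'input_ids': [11, 12, 13, 14, 15, 0], 'attention_mask': [1, 1, 1, 1, 1, 1]}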