Simplify tokenizer creation (#44)

*Description of changes:* Minor simplification to how the tokenizer is
constructed from the config


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Lorenzo Stella 2024-04-05 17:15:33 +02:00 committed by GitHub
parent b4423b8c4d
commit 2042779efa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -5,6 +5,7 @@ import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union

+import chronos
 import torch
 import torch.nn as nn
 from transformers import (
@@ -45,9 +46,8 @@ class ChronosConfig:
         ), f"Special token id's must be smaller than {self.n_special_tokens=}"

     def create_tokenizer(self) -> "ChronosTokenizer":
-        if self.tokenizer_class == "MeanScaleUniformBins":
-            return MeanScaleUniformBins(**self.tokenizer_kwargs, config=self)
-        raise ValueError
+        class_ = getattr(chronos, self.tokenizer_class)
+        return class_(**self.tokenizer_kwargs, config=self)


 class ChronosTokenizer: