Add MLX inference support (#41)

*Issue #, if available:* #28

*Description of changes:* This PR adds MLX inference support.

## Summary of changes
- Update `pyproject.toml` with`mlx` dependencies.
- Create `chronos_mlx` package which will hosts all mlx inference stuff.
- All classes from `main:src/chronos/chronos.py` are copy-pasted into
`mlx:src/chronos_mlx/chronos.py` and modified to use numpy and mlx
arrays instead. Note that the reason for using numpy arrays as input and
output is that mlx doesn't support some operations that are required for
input and output transform.
- MLX implementation of T5 is in `src/chronos_mlx/t5.py`. It has been
adapted from
[ml-explore/mlx-examples](b8a348c1b8/t5/t5.py)
with the following main modifications:
      - Added support for attention mask.
      - Added support for top_k and top_p sampling.
- `src/chronos_mlx/translate.py` translates weights from a torch HF
model to mlx.
- Add `THIRD-PARTY-LICENSES.txt` for third party code from
`mlx-examples`.
- Add tests and CI for `mlx` version. 

## Sample inference code

```py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from chronos_mlx import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    dtype="bfloat16",
)

df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = df["#Passengers"].values
prediction_length = 12
forecast = pipeline.predict(
    context, prediction_length
)  # shape [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(
    forecast_index,
    low,
    high,
    color="tomato",
    alpha=0.3,
    label="80% prediction interval",
)
plt.legend()
plt.grid()
plt.show()

```

## Benchmark


![benchmark](https://github.com/amazon-science/chronos-forecasting/assets/4028948/ee5d1b17-d33e-473c-aa7a-55dbe1059b9c)


```py
import timeit

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from chronos import ChronosPipeline as ChronosPipelineTorch
from chronos_mlx import ChronosPipeline as ChronosPipelineMLX


def benchmark_torch_model(
    pipeline: ChronosPipelineTorch,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [
            torch.tensor(item["target"])
            for item in test_data_input[idx : idx + batch_size]
        ]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = torch.cat(forecasts)
    end_time = timeit.default_timer()

    print(f"Inference time: {end_time-start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst.numpy(), start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df


def benchmark_mlx_model(
    pipeline: ChronosPipelineMLX,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [item["target"] for item in test_data_input[idx : idx + batch_size]]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = np.concatenate(forecasts)
    end_time = timeit.default_timer()

    print(f"Inference time: {end_time-start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst, start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df


def main(
    version: str = "cpu",  # cpu, mps, mlx
    dtype: str = "bfloat16",
    gluonts_dataset: str = "australian_electricity_demand",
    model_name: str = "amazon/chronos-t5-small",
    batch_size: int = 4,
):
    if version == "cpu" or version == "mps":
        pipeline = ChronosPipelineTorch.from_pretrained(
            model_name,
            device_map=version,
            torch_dtype=getattr(torch, dtype),
        )
        benchmark_fn = benchmark_torch_model
    else:
        pipeline = ChronosPipelineMLX.from_pretrained(model_name, dtype=dtype)
        benchmark_fn = benchmark_mlx_model

    result_df = benchmark_fn(
        pipeline, gluonts_dataset=gluonts_dataset, batch_size=batch_size
    )
    result_df["model"] = model_name
    return result_df


if __name__ == "__main__":
    gluonts_dataset: str = "m4_hourly"
    model_name: str = "amazon/chronos-t5-mini"
    batch_size: int = 8
    dfs = []
    for version in ["cpu", "mps", "mlx"]:
        for dtype in ["float32"]:
            try:
                df = main(
                    version=version,
                    dtype=dtype,
                    model_name=model_name,
                    gluonts_dataset=gluonts_dataset,
                    batch_size=batch_size,
                )
                df["version"] = version
                df["dtype"] = dtype
                dfs.append(df)
            except TypeError:
                pass

    result_df = pd.concat(dfs).reset_index(drop=True)
    result_df.to_csv("benchmark.csv", index=False)

    result_df["version"] = result_df["version"].map(
        {"cpu": "Torch (CPU)", "mps": "Torch (MPS)", "mlx": "MLX"}
    )
    fig = plt.figure(figsize=(8, 5))
    g = sns.barplot(
        data=result_df,
        x="dtype",
        y="inference_time",
        hue="version",
        alpha=0.6,
    )
    plt.ylabel("Inference Time (on M1 Pro)")
    plt.title(f"{model_name} inference times on {gluonts_dataset} dataset")
    plt.savefig("benchmark.png", dpi=200)

```

## TODOs:
- [x] Implement `top_p` sampling.
- [x] Add tests.
- [x] Add CI.

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>
This commit is contained in:
Abdul Fatir 2024-04-08 15:03:44 +02:00 committed by GitHub
parent 2042779efa
commit 159ea36f7f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 830 additions and 398 deletions

View file

@ -3,44 +3,23 @@ name: CI
on: [push, pull_request]
jobs:
type-check:
test-mlx:
strategy:
max-parallel: 4
fail-fast: false
matrix:
python-version: ["3.11"]
platform: [ubuntu-latest]
python-version: ['3.11']
platform: [macos-14]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install ".[typecheck]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Type checks with mypy
run: mypy src test
test:
strategy:
max-parallel: 4
fail-fast: false
matrix:
python-version: ["3.11"]
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Test with pytest
run: pytest
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Test with pytest
run: pytest test/

View file

@ -1,4 +1,7 @@
# Chronos: Learning the Language of Time Series
# [🧪 MLX Version] Chronos: Learning the Language of Time Series
> [!IMPORTANT]
> This is the **experimental** MLX version of Chronos for Apple Silicon Macs. Please use the `main` branch for the stable PyTorch version.
Chronos is a family of **pretrained time series forecasting models** based on language model architectures. A time series is transformed into a sequence of tokens via scaling and quantization, and a language model is trained on these tokens using the cross-entropy loss. Once trained, probabilistic forecasts are obtained by sampling multiple future trajectories given the historical context. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes.
@ -28,10 +31,10 @@ The models in this repository are based on the [T5 architecture](https://arxiv.o
## Usage
To perform inference with Chronos models, install this package by running:
To perform inference with Chronos models on Apple Silcon devices, install this package by running:
```
pip install git+https://github.com/amazon-science/chronos-forecasting.git
pip install git+https://github.com/amazon-science/chronos-forecasting.git@mlx
```
### Forecasting
@ -43,20 +46,18 @@ A minimal example showing how to perform forecasting using Chronos models:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline
from chronos_mlx import ChronosPipeline
pipeline = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-small",
device_map="cuda",
torch_dtype=torch.bfloat16,
dtype="bfloat16",
)
df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(df["#Passengers"])
context = df["#Passengers"].values
prediction_length = 12
forecast = pipeline.predict(
context,
@ -69,7 +70,7 @@ forecast = pipeline.predict(
# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)
plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
@ -86,20 +87,18 @@ A minimal example showing how to extract encoder embeddings from Chronos models:
```python
import pandas as pd
import torch
from chronos import ChronosPipeline
from chronos_mlx import ChronosPipeline
pipeline = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-small",
device_map="cuda",
torch_dtype=torch.bfloat16,
dtype="bfloat16",
)
df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(df["#Passengers"])
context = df["#Passengers"].values
embeddings, tokenizer_state = pipeline.embed(context)
```

23
THIRD-PARTY-LICENSES.txt Normal file
View file

@ -0,0 +1,23 @@
** mlx-examples; version b8a348c -- https://github.com/ml-explore/mlx-examples
MIT License
Copyright © 2023 Apple Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,5 +1,5 @@
[project]
name = "chronos"
name = "chronos_mlx"
version = "1.1.0"
requires-python = ">=3.8"
license = { file = "LICENSE" }
@ -7,6 +7,7 @@ dependencies = [
"torch~=2.1", # package was tested on 2.2
"transformers~=4.31",
"accelerate",
"mlx~=0.9.0"
]
[project.optional-dependencies]

View file

@ -3,18 +3,18 @@
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import chronos
import torch
import torch.nn as nn
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
GenerationConfig,
PreTrainedModel,
)
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map, tree_unflatten
from transformers import T5Config
import chronos_mlx
from chronos_mlx.t5 import T5
from chronos_mlx.translate import translate_weights
@dataclass
@ -46,13 +46,13 @@ class ChronosConfig:
), f"Special token id's must be smaller than {self.n_special_tokens=}"
def create_tokenizer(self) -> "ChronosTokenizer":
class_ = getattr(chronos, self.tokenizer_class)
class_ = getattr(chronos_mlx, self.tokenizer_class)
return class_(**self.tokenizer_kwargs, config=self)
class ChronosTokenizer:
"""
A ``ChronosTokenizer`` definines how time series are mapped into token IDs
A ``ChronosTokenizer`` defines how time series are mapped into token IDs
and back.
For details, see the ``input_transform`` and ``output_transform`` methods,
@ -60,27 +60,27 @@ class ChronosTokenizer:
"""
def input_transform(
self, context: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
self, context: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Any]:
"""
Turn a batch of time series into token IDs, attention map, and scale.
Parameters
----------
context
A tensor shaped (batch_size, time_length), containing the
timeseries to forecast. Use left-padding with ``torch.nan``
A numpy array shaped (batch_size, time_length), containing the
timeseries to forecast. Use left-padding with ``np.nan``
to align time series of different lengths.
Returns
-------
token_ids
A tensor of integers, shaped (batch_size, time_length + 1)
A numpy array of integers, shaped (batch_size, time_length + 1)
if ``config.use_eos_token`` and (batch_size, time_length)
otherwise, containing token IDs for the input series.
attention_mask
A boolean tensor, same shape as ``token_ids``, indicating
which input observations are not ``torch.nan`` (i.e. not
A boolean numpy array, same shape as ``token_ids``, indicating
which input observations are not ``np.nan`` (i.e. not
missing nor padding).
tokenizer_state
An object that will be passed to ``output_transform``.
@ -89,16 +89,14 @@ class ChronosTokenizer:
"""
raise NotImplementedError()
def output_transform(
self, samples: torch.Tensor, tokenizer_state: Any
) -> torch.Tensor:
def output_transform(self, samples: np.ndarray, tokenizer_state: Any) -> np.ndarray:
"""
Turn a batch of sample token IDs into real values.
Parameters
----------
samples
A tensor of integers, shaped (batch_size, num_samples, time_length),
A numpy array of integers, shaped (batch_size, num_samples, time_length),
containing token IDs of sample trajectories.
tokenizer_state
An object returned by ``input_transform`` containing
@ -108,7 +106,7 @@ class ChronosTokenizer:
Returns
-------
forecasts
A real tensor, shaped (batch_size, num_samples, time_length),
A real numpy array, shaped (batch_size, num_samples, time_length),
containing forecasted sample paths.
"""
raise NotImplementedError()
@ -119,70 +117,60 @@ class MeanScaleUniformBins(ChronosTokenizer):
self, low_limit: float, high_limit: float, config: ChronosConfig
) -> None:
self.config = config
self.centers = torch.linspace(
self.centers = np.linspace(
low_limit,
high_limit,
config.n_tokens - config.n_special_tokens - 1,
)
self.boundaries = torch.concat(
self.boundaries = np.concatenate(
(
torch.tensor([-1e20], device=self.centers.device),
np.array([-1e20]),
(self.centers[1:] + self.centers[:-1]) / 2,
torch.tensor([1e20], device=self.centers.device),
np.array([1e20]),
)
)
def input_transform(
self, context: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
self, context: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
batch_size, length = context.shape
if length > self.config.context_length:
context = context[..., -self.config.context_length :]
attention_mask = ~torch.isnan(context)
scale = torch.nansum(
torch.abs(context) * attention_mask, dim=-1
) / torch.nansum(attention_mask, dim=-1)
attention_mask = ~np.isnan(context)
scale = np.nansum(np.abs(context) * attention_mask, axis=-1) / np.nansum(
attention_mask, axis=-1
)
scale[~(scale > 0)] = 1.0
scaled_context = context / scale.unsqueeze(dim=-1)
scaled_context = context / scale[..., np.newaxis]
token_ids = (
torch.bucketize(
input=scaled_context,
boundaries=self.boundaries,
# buckets are open to the right, see:
# https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize
right=True,
)
np.digitize(scaled_context, bins=self.boundaries)
+ self.config.n_special_tokens
)
token_ids[~attention_mask] = self.config.pad_token_id
if self.config.use_eos_token:
eos_tokens = torch.full(
(batch_size, 1), fill_value=self.config.eos_token_id
)
token_ids = torch.concat((token_ids, eos_tokens), dim=1)
eos_mask = torch.full((batch_size, 1), fill_value=True)
attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
eos_tokens = np.full((batch_size, 1), fill_value=self.config.eos_token_id)
token_ids = np.concatenate((token_ids, eos_tokens), axis=1)
eos_mask = np.full((batch_size, 1), fill_value=True)
attention_mask = np.concatenate((attention_mask, eos_mask), axis=1)
return token_ids, attention_mask, scale
def output_transform(
self, samples: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
indices = torch.clamp(
def output_transform(self, samples: np.ndarray, scale: np.ndarray) -> np.ndarray:
scale_unsqueezed = scale[..., np.newaxis, np.newaxis]
indices = np.clip(
samples - self.config.n_special_tokens,
min=0,
max=len(self.centers) - 1,
a_min=0,
a_max=len(self.centers) - 1,
)
return self.centers[indices] * scale_unsqueezed
class ChronosModel(nn.Module):
"""
A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers``
A ``ChronosModel`` wraps a ``T5`` object from ``chronos.mlx.t5``
and uses it to predict sample paths for time series tokens.
Parameters
@ -190,22 +178,21 @@ class ChronosModel(nn.Module):
config
The configuration to use.
model
The pre-trained model to use.
The pretrained T5 model to use.
"""
def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None:
def __init__(self, config: ChronosConfig, model: T5) -> None:
super().__init__()
assert config.model_type == "seq2seq" and isinstance(
model, T5
), "Only the T5 model is currently supported in MLX"
self.config = config
self.model = model
@property
def device(self):
return self.model.device
def encode(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
input_ids: np.ndarray,
attention_mask: np.ndarray,
):
"""
Extract the encoder embedding for the given token sequences.
@ -213,35 +200,33 @@ class ChronosModel(nn.Module):
Parameters
----------
input_ids
Tensor of indices of input sequence tokens in the vocabulary
Array of indices of input sequence tokens in the vocabulary
with shape (batch_size, sequence_length).
attention_mask
A mask tensor of the same shape as input_ids to avoid attending
A mask array of the same shape as input_ids to avoid attending
on padding or missing tokens.
Returns
-------
embedding
A tensor of encoder embeddings with shape
An array of encoder embeddings with shape
(batch_size, sequence_length, d_model).
"""
assert (
self.config.model_type == "seq2seq"
), "Encoder embeddings are only supported for encoder-decoder models"
return self.model.encoder(
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state
return self.model.encode(inputs=input_ids, mask=attention_mask)
def forward(
def __call__(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
input_ids: mx.array,
attention_mask: mx.array,
prediction_length: Optional[int] = None,
num_samples: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> torch.Tensor:
) -> mx.array:
"""
Predict future sample tokens for the given token sequences.
@ -253,7 +238,7 @@ class ChronosModel(nn.Module):
Returns
-------
samples
A tensor of integers, shaped (batch_size, num_samples, time_length),
A numpy array of integers, shaped (batch_size, num_samples, time_length),
containing forecasted sample paths.
"""
if prediction_length is None:
@ -270,40 +255,31 @@ class ChronosModel(nn.Module):
preds = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
generation_config=GenerationConfig(
min_new_tokens=prediction_length,
max_new_tokens=prediction_length,
do_sample=True,
num_return_sequences=num_samples,
eos_token_id=self.config.eos_token_id,
pad_token_id=self.config.pad_token_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
),
min_new_tokens=prediction_length,
max_new_tokens=prediction_length,
do_sample=True,
num_return_sequences=num_samples,
eos_token_id=self.config.eos_token_id,
pad_token_id=self.config.pad_token_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
if self.config.model_type == "seq2seq":
preds = preds[..., 1:] # remove the decoder start token
else:
assert self.config.model_type == "causal"
assert preds.size(-1) == input_ids.size(-1) + prediction_length
preds = preds[..., -prediction_length:]
preds = preds[..., 1:] # remove the decoder start token
return preds.reshape(input_ids.size(0), num_samples, -1)
return preds.reshape(input_ids.shape[0], num_samples, -1)
def left_pad_and_stack_1D(tensors: List[torch.Tensor]):
max_len = max(len(c) for c in tensors)
def left_pad_and_stack_1D(arrays: List[np.ndarray]):
max_len = max(len(c) for c in arrays)
padded = []
for c in tensors:
assert isinstance(c, torch.Tensor)
for c in arrays:
assert isinstance(c, np.ndarray)
assert c.ndim == 1
padding = torch.full(
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
)
padded.append(torch.concat((padding, c), dim=-1))
return torch.stack(padded)
padding = np.full(shape=(max_len - len(c),), fill_value=np.nan)
padded.append(np.concatenate((padding, c), axis=-1))
return np.stack(padded)
class ChronosPipeline:
@ -330,21 +306,20 @@ class ChronosPipeline:
self.model = model
def _prepare_and_validate_context(
self, context: Union[torch.Tensor, List[torch.Tensor]]
):
self, context: Union[np.ndarray, List[np.ndarray]]
) -> np.ndarray:
if isinstance(context, list):
context = left_pad_and_stack_1D(context)
assert isinstance(context, torch.Tensor)
assert isinstance(context, np.ndarray)
if context.ndim == 1:
context = context.unsqueeze(0)
context = context[np.newaxis, ...]
assert context.ndim == 2
return context
@torch.no_grad()
def embed(
self, context: Union[torch.Tensor, List[torch.Tensor]]
) -> Tuple[torch.Tensor, Any]:
self, context: Union[np.ndarray, List[np.ndarray]]
) -> Tuple[np.ndarray, Any]:
"""
Get encoder embeddings for the given time series.
@ -367,36 +342,36 @@ class ChronosPipeline:
or the length of the longest time series, if a list of 1D tensors was
provided, and the extra 1 is for EOS.
"""
context_tensor = self._prepare_and_validate_context(context=context)
context_array = self._prepare_and_validate_context(context=context)
token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
context_tensor
context_array
)
embeddings = self.model.encode(
input_ids=token_ids.to(self.model.device),
attention_mask=attention_mask.to(self.model.device),
).cpu()
return embeddings, tokenizer_state
input_ids=mx.array(token_ids),
attention_mask=mx.array(attention_mask),
)
return np.array(embeddings.astype(mx.float32)), tokenizer_state
def predict(
self,
context: Union[torch.Tensor, List[torch.Tensor]],
context: Union[np.ndarray, List[np.ndarray]],
prediction_length: Optional[int] = None,
num_samples: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
limit_prediction_length: bool = True,
) -> torch.Tensor:
) -> np.ndarray:
"""
Get forecasts for the given time series.
Parameters
----------
context
Input series. This is either a 1D tensor, or a list
of 1D tensors, or a 2D tensor whose first dimension
Input series. This is either a 1D numpy array, or a list
of 1D numpy arrays, or a 2D numpy array whose first dimension
is batch. In the latter case, use left-padding with
``torch.nan`` to align series of different lengths.
``np.nan`` to align series of different lengths.
prediction_length
Time steps to predict. Defaults to what specified
in ``self.model.config``.
@ -421,10 +396,10 @@ class ChronosPipeline:
Returns
-------
samples
Tensor of sample forecasts, of shape
Numpy array of sample forecasts, of shape
(batch_size, num_samples, prediction_length).
"""
context_tensor = self._prepare_and_validate_context(context=context)
context_array = self._prepare_and_validate_context(context=context)
if prediction_length is None:
prediction_length = self.model.config.prediction_length
@ -432,7 +407,7 @@ class ChronosPipeline:
if prediction_length > self.model.config.prediction_length:
msg = (
f"We recommend keeping prediction length <= {self.model.config.prediction_length}. "
"The quality of longer predictions may degrade since the model is not optimized for it. "
f"The quality of longer predictions may degrade since the model is not optimized for it. "
)
if limit_prediction_length:
msg += "You can turn off this check by setting `limit_prediction_length=False`."
@ -444,20 +419,19 @@ class ChronosPipeline:
while remaining > 0:
token_ids, attention_mask, scale = self.tokenizer.input_transform(
context_tensor
context_array
)
token_ids, attention_mask = mx.array(token_ids), mx.array(attention_mask)
samples = self.model(
token_ids.to(self.model.device),
attention_mask.to(self.model.device),
token_ids,
attention_mask,
min(remaining, self.model.config.prediction_length),
num_samples,
temperature,
top_k,
top_p,
)
prediction = self.tokenizer.output_transform(
samples.to(scale.device), scale
)
prediction = self.tokenizer.output_transform(np.array(samples), scale)
predictions.append(prediction)
remaining -= prediction.shape[-1]
@ -465,31 +439,44 @@ class ChronosPipeline:
if remaining <= 0:
break
context_tensor = torch.cat(
[context_tensor, prediction.median(dim=1).values], dim=-1
context_array = np.concatenate(
[context_array, np.median(prediction, axis=1)], axis=-1
)
return torch.cat(predictions, dim=-1)
return np.concatenate(predictions, axis=-1)
@classmethod
def from_pretrained(cls, *args, **kwargs):
def from_pretrained(
cls, model_name_or_path: Union[str, Path], dtype: str = "float32"
):
"""
Load the model, either from a local path or from the HuggingFace Hub.
Supports the same arguments as ``AutoConfig`` and ``AutoModel``
from ``transformers``.
Parameters
----------
model_name_or_path
Model ID on HuggingFace Hub or local path.
dtype, optional
String denoting the float dtype of the mlx model,
by default "float32"
Returns
-------
A ChronosPipeline
"""
config = AutoConfig.from_pretrained(*args, **kwargs)
config = T5Config.from_pretrained(model_name_or_path)
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
dtype = getattr(mx, dtype)
chronos_config = ChronosConfig(**config.chronos_config)
if chronos_config.model_type == "seq2seq":
inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
else:
assert config.model_type == "causal"
inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
inner_model = T5(config=config)
weights = translate_weights(model_name_or_path=model_name_or_path, dtype=dtype)
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
inner_model.update(weights)
mx.eval(inner_model.parameters())
return cls(
tokenizer=chronos_config.create_tokenizer(),

420
src/chronos_mlx/t5.py Normal file
View file

@ -0,0 +1,420 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# Adapted from ml-explore/mlx-examples:
# https://github.com/ml-explore/mlx-examples/blob/b8a348c1b8df4433cfacb9adbeb89b8aa3979ab2/t5/t5.py
# Modifications:
# - Added support for attention mask.
# - Added support for top_k and top_p sampling.
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from transformers import T5Config
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
# Adapted from HuggingFace transformers:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
relative_position = mx.abs(relative_position)
else:
relative_position = -mx.minimum(
relative_position, mx.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins
# in positions up to max_distance
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
relative_position_if_large = max_exact + (
mx.log(relative_position.astype(mx.float32) / max_exact) * scale
).astype(mx.int16)
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
relative_buckets += mx.where(
is_small, relative_position, relative_position_if_large
)
return relative_buckets
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(
config.relative_attention_num_buckets, config.num_heads
)
def __call__(self, query_length: int, key_length: int, offset: int = 0):
"""Compute binned relative position bias"""
context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length)
relative_position = memory_position - context_position
relative_position_bucket = _relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
# shape (query_length, key_length, num_heads)
values = self.embeddings(relative_position_bucket)
# shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1)
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, _ = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
keys = mx.concatenate([key_cache, keys], axis=3)
values = mx.concatenate([value_cache, values], axis=2)
# Dimensions are [batch x num heads x sequence x hidden dim]
scores = queries @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
def __call__(self, x):
if self.gated:
hidden_act = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_act * hidden_linear
else:
x = self.act(self.wi(x))
return self.wo(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(self, x, mask):
y = self.ln1(x)
y, _ = self.attention(y, y, y, mask=mask)
x = x + y
y = self.ln2(x)
y = self.dense(y)
return x + y
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
]
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x: mx.array, mask: Optional[mx.array] = None):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])[None]
if mask is not None:
mask = mask[:, None, None, :]
pos_bias += mask
for layer in self.layers:
x = layer(x, mask=pos_bias)
return self.ln(x)
class TransformerDecoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(
self,
x: mx.array,
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[Tuple[mx.array, mx.array]] = None,
):
y = self.ln1(x)
y, cache = self.self_attention(y, y, y, mask, cache)
x = x + y
y = self.ln2(x)
y, _ = self.cross_attention(y, memory, memory, memory_mask)
x = x + y
y = self.ln3(x)
y = self.dense(y)
x = x + y
return x, cache
class TransformerDecoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
def __call__(self, x, memory, mask, memory_mask, cache=None):
if cache is not None:
offset = cache[0][0].shape[3]
else:
offset = 0
cache = [None] * len(self.layers)
T = offset + x.shape[1]
pos_bias = self.relative_attention_bias(T, T, offset=offset)
if mask is not None:
mask += pos_bias
else:
mask = pos_bias
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
x = self.ln(x)
return x, cache
class OutputHead(nn.Module):
def __init__(self, config: T5Config):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs):
return self.linear(inputs)
def apply_top_p(logits: mx.array, top_p: float, min_tokens_to_keep=1):
assert min_tokens_to_keep <= logits.shape[-1]
logits_dtype = logits.dtype
# FIXME: The following is needed because mlx doesn't have the cumsum
# kernel for bfloat16. Once that is supported natively, this casting
# should be removed. @abdulfatir
logits = logits.astype(mx.float32)
sorted_indices = mx.argsort(logits, axis=-1)
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
cumulative_probs = mx.softmax(sorted_logits, axis=-1).cumsum(axis=-1, reverse=True)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
masked_sorted_logits = mx.where(sorted_indices_to_remove, -mx.inf, sorted_logits)
unsorted_indices = mx.argsort(sorted_indices, axis=-1)
return mx.take_along_axis(masked_sorted_logits, unsorted_indices, axis=-1).astype(
logits_dtype
)
def sample(logits, top_k=1, top_p=1.0, temperature=1.0):
vocab_size = logits.shape[-1]
assert top_p <= 1.0, f"{top_p=} should be <= 1.0"
if temperature == 0 or top_k == 1:
return mx.argmax(logits, axis=-1)
else:
# Apply temperature term
if temperature != 1.0:
logits /= temperature
# Apply top_k
if top_k >= vocab_size:
return mx.random.categorical(
apply_top_p(logits, top_p=top_p) if top_p < 1.0 else logits
)
top_k_indices = mx.argpartition(logits, top_k, axis=-1)[..., -top_k:]
top_k_logits = mx.take_along_axis(logits, top_k_indices, axis=-1)
# Apply top_p
if top_p < 1.0:
top_k_logits = apply_top_p(top_k_logits, top_p=top_p)
return top_k_indices[
mx.arange(top_k_indices.shape[0]), mx.random.categorical(top_k_logits)
]
class T5(nn.Module):
def __init__(self, config: T5Config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config)
self.tie_word_embeddings = config.tie_word_embeddings
if not self.tie_word_embeddings:
self.lm_head = OutputHead(config)
self.model_dim = config.d_model
def encode(self, inputs: mx.array, mask: mx.array):
return self.encoder(self.wte(inputs), mask)
def decode(
self,
inputs: mx.array,
memory: mx.array,
memory_mask: mx.array,
cache=None,
):
inputs = self.wte(inputs)
T = inputs.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(inputs.dtype)
else:
mask = None
memory_mask = memory_mask[:, None, None, :]
y, cache = self.decoder(
inputs, memory=memory, mask=mask, memory_mask=memory_mask, cache=cache
)
if not self.tie_word_embeddings:
y = self.lm_head(y)
else:
y *= self.model_dim**-0.5
y = y @ self.wte.weight.T
return y, cache
def __call__(
self,
inputs: mx.array,
mask: mx.array,
decoder_inputs: mx.array,
):
memory = self.encode(inputs, mask=mask)
return self.decode(decoder_inputs, memory=memory, memory_mask=mask)[0]
def generate(
self,
input_ids: mx.array,
attention_mask: mx.array,
min_new_tokens: Optional[int] = None,
max_new_tokens: int = 64,
do_sample: bool = True,
num_return_sequences: int = 1,
pad_token_id: int = 0,
eos_token_id: Optional[int] = None,
temperature: Optional[float] = 1.0,
top_k: int = 50,
top_p: float = 1.0,
):
self.eval()
def should_stop(current_token, num_sampled_tokens):
if eos_token_id is not None and (current_token == eos_token_id).all():
return True
if num_sampled_tokens >= max_new_tokens:
return True
return False
top_k = top_k if do_sample else 1
attention_mask = (1.0 - attention_mask.astype(mx.float32)) * -1e9
memory = self.encode(input_ids, mask=attention_mask)
repeated_memory = mx.repeat(memory, num_return_sequences, axis=0)
repeated_attention_mask = mx.repeat(
attention_mask, num_return_sequences, axis=0
)
decoder_start_id = pad_token_id
decoder_inputs = mx.array([decoder_start_id] * len(repeated_attention_mask))[
:, None
]
cache = None
prediction = [decoder_inputs]
num_sampled_tokens = 0
while not should_stop(prediction[-1], num_sampled_tokens):
logits, cache = self.decode(
prediction[-1],
repeated_memory,
memory_mask=repeated_attention_mask,
cache=cache,
)
if (
min_new_tokens is not None
and eos_token_id is not None
and num_sampled_tokens < min_new_tokens
):
logits[..., eos_token_id] = -float("inf")
y = sample(
logits[:, -1, :], top_k=top_k, top_p=top_p, temperature=temperature
)
num_sampled_tokens += 1
prediction.append(y[:, None])
return mx.concatenate(prediction, axis=-1)

View file

@ -0,0 +1,77 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# Adapted from ml-explore/mlx-examples:
# https://github.com/ml-explore/mlx-examples/blob/b8a348c1b8df4433cfacb9adbeb89b8aa3979ab2/t5/convert.py
from pathlib import Path
from typing import Union
import mlx.core as mx
import torch
from transformers import T5ForConditionalGeneration
SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
DECODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]
def replace_key(key: str) -> str:
for old, new in SHARED_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
if key.startswith("encoder."):
for old, new in ENCODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
elif key.startswith("decoder."):
for old, new in DECODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
return key
def translate_weights(model_name_or_path: Union[str, Path], dtype: mx.Dtype):
"""Translate a HuggingFace transformers T5 model to MLX.
Parameters
----------
model_name
HuggingFace model name or local path.
dtype
mlx dtype for the resulting mlx weights.
Returns
-------
A state dictionary with weights as mlx arrays.
"""
model = T5ForConditionalGeneration.from_pretrained(
model_name_or_path, torch_dtype=torch.float32
)
weights = {
replace_key(k): mx.array(v.numpy(), dtype=dtype)
for k, v in model.state_dict().items()
}
return weights

View file

@ -1,208 +0,0 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Tuple
import torch
import pytest
from chronos import ChronosConfig, ChronosPipeline
@pytest.mark.xfail
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
@pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
@pytest.mark.parametrize("use_eos_token", [False, True])
def test_tokenizer_fixed_data(
n_numerical_tokens: int, n_special_tokens: int, use_eos_token: bool
):
n_tokens = n_numerical_tokens + n_special_tokens
context_length = 3
config = ChronosConfig(
tokenizer_class="MeanScaleUniformBins",
tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
n_tokens=n_tokens,
n_special_tokens=n_special_tokens,
pad_token_id=0,
eos_token_id=1,
use_eos_token=use_eos_token,
model_type="seq2seq",
context_length=512,
prediction_length=64,
num_samples=20,
temperature=1.0,
top_k=50,
top_p=1.0,
)
tokenizer = config.create_tokenizer()
context = torch.tensor(
[
[-3.7, 3.7],
[-42.0, 42.0],
]
)
batch_size, _ = context.shape
token_ids, attention_mask, scale = tokenizer.input_transform(context)
assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token)
assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size))
assert all(token_ids[:, 1] == torch.tensor([n_special_tokens]).repeat(batch_size))
assert all(token_ids[:, 2] == torch.tensor([n_tokens - 1]).repeat(batch_size))
if use_eos_token:
assert all(token_ids[:, 3] == torch.tensor([1]).repeat(batch_size))
samples = tokenizer.output_transform(
torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1),
tokenizer_state=scale,
)
assert (samples[:, 0, [0, -1]] == context).all()
@pytest.mark.xfail
@pytest.mark.parametrize("use_eos_token", [False, True])
def test_tokenizer_random_data(use_eos_token: bool):
context_length = 8
n_tokens = 256
n_special_tokens = 2
config = ChronosConfig(
tokenizer_class="MeanScaleUniformBins",
tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
n_tokens=n_tokens,
n_special_tokens=n_special_tokens,
pad_token_id=0,
eos_token_id=1,
use_eos_token=use_eos_token,
model_type="seq2seq",
context_length=context_length,
prediction_length=64,
num_samples=20,
temperature=1.0,
top_k=50,
top_p=1.0,
)
tokenizer = config.create_tokenizer()
context = torch.tensor(
[
[torch.nan, torch.nan, 1.0, 1.1, torch.nan, 2.0],
[3.0, torch.nan, 3.9, 4.0, 4.1, 4.9],
]
)
token_ids, attention_mask, scale = tokenizer.input_transform(context)
assert token_ids.shape == (
*context.shape[:-1],
context_length + 1 * use_eos_token,
)
assert attention_mask.shape == (
*context.shape[:-1],
context_length + 1 * use_eos_token,
)
assert scale.shape == context.shape[:1]
sample_ids = torch.randint(low=n_special_tokens, high=n_tokens, size=(2, 10, 4))
sample_ids[0, 0, 0] = n_special_tokens
sample_ids[-1, -1, -1] = n_tokens - 1
samples = tokenizer.output_transform(sample_ids, scale)
assert samples.shape == (2, 10, 4)
def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
assert isinstance(samples, torch.Tensor)
assert samples.shape == shape
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
def test_pipeline_predict(torch_dtype: str):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
device_map="cpu",
torch_dtype=torch_dtype,
)
context = 10 * torch.rand(size=(4, 16)) + 10
# input: tensor of shape (batch_size, context_length)
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
validate_tensor(samples, (4, 12, 3))
with pytest.raises(ValueError):
samples = pipeline.predict(context, num_samples=7, prediction_length=65)
samples = pipeline.predict(
context, num_samples=7, prediction_length=65, limit_prediction_length=False
)
validate_tensor(samples, (4, 7, 65))
# input: batch_size-long list of tensors of shape (context_length,)
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
validate_tensor(samples, (4, 12, 3))
with pytest.raises(ValueError):
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
samples = pipeline.predict(
list(context),
num_samples=7,
prediction_length=65,
limit_prediction_length=False,
)
validate_tensor(samples, (4, 7, 65))
# input: tensor of shape (context_length,)
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
validate_tensor(samples, (1, 12, 3))
with pytest.raises(ValueError):
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
samples = pipeline.predict(
context[0, ...],
num_samples=7,
prediction_length=65,
limit_prediction_length=False,
)
validate_tensor(samples, (1, 7, 65))
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
def test_pipeline_embed(torch_dtype: str):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
device_map="cpu",
torch_dtype=torch_dtype,
)
d_model = pipeline.model.model.config.d_model
context = 10 * torch.rand(size=(4, 16)) + 10
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
# input: tensor of shape (batch_size, context_length)
embedding, scale = pipeline.embed(context)
validate_tensor(embedding, (4, expected_embed_length, d_model))
validate_tensor(scale, (4,))
# input: batch_size-long list of tensors of shape (context_length,)
embedding, scale = pipeline.embed(list(context))
validate_tensor(embedding, (4, expected_embed_length, d_model))
validate_tensor(scale, (4,))
# input: tensor of shape (context_length,)
embedding, scale = pipeline.embed(context[0, ...])
validate_tensor(embedding, (1, expected_embed_length, d_model))
validate_tensor(scale, (1,))

154
test/test_chronos_mlx.py Normal file
View file

@ -0,0 +1,154 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Tuple
import mlx.core as mx
import numpy as np
import pytest
from chronos_mlx.t5 import apply_top_p
from chronos_mlx import ChronosPipeline
def validate_array(samples: np.ndarray, shape: Tuple[int, ...]) -> None:
assert isinstance(samples, np.ndarray)
assert samples.shape == shape
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
def test_pipeline_predict(dtype: str):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
dtype=dtype,
)
context = 10 * np.random.rand(4, 16) + 10
# input: tensor of shape (batch_size, context_length)
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
validate_array(samples, (4, 12, 3))
with pytest.raises(ValueError):
samples = pipeline.predict(context, num_samples=7, prediction_length=65)
samples = pipeline.predict(
context, num_samples=7, prediction_length=65, limit_prediction_length=False
)
validate_array(samples, (4, 7, 65))
# input: batch_size-long list of tensors of shape (context_length,)
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
validate_array(samples, (4, 12, 3))
with pytest.raises(ValueError):
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
samples = pipeline.predict(
list(context),
num_samples=7,
prediction_length=65,
limit_prediction_length=False,
)
validate_array(samples, (4, 7, 65))
# input: tensor of shape (context_length,)
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
validate_array(samples, (1, 12, 3))
with pytest.raises(ValueError):
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
samples = pipeline.predict(
context[0, ...],
num_samples=7,
prediction_length=65,
limit_prediction_length=False,
)
validate_array(samples, (1, 7, 65))
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
def test_pipeline_embed(dtype: str):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
dtype=dtype,
)
d_model = pipeline.model.model.model_dim
context = 10 * np.random.rand(4, 16) + 10
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
# input: tensor of shape (batch_size, context_length)
embedding, scale = pipeline.embed(context)
validate_array(embedding, (4, expected_embed_length, d_model))
validate_array(scale, (4,))
# input: batch_size-long list of tensors of shape (context_length,)
embedding, scale = pipeline.embed(list(context))
validate_array(embedding, (4, expected_embed_length, d_model))
validate_array(scale, (4,))
# input: tensor of shape (context_length,)
embedding, scale = pipeline.embed(context[0, ...])
validate_array(embedding, (1, expected_embed_length, d_model))
validate_array(scale, (1,))
@pytest.mark.parametrize(
"top_p,expected_non_zero_probs",
[
(
0.1,
mx.array(
[
[False, True, False, False],
[False, True, False, False],
[True, False, False, False],
[True, False, False, False],
[False, False, False, True],
]
),
),
(
0.5,
mx.array(
[
[False, True, False, False],
[False, True, False, False],
[True, False, False, False],
[True, False, False, False],
[False, False, True, True],
]
),
),
(
0.95,
mx.array(
[
[False, True, True, True],
[False, True, False, True],
[True, False, False, False],
[True, True, False, False],
[False, True, True, True],
]
),
),
],
)
def test_apply_top_p(top_p: float, expected_non_zero_probs: mx.array):
probs = mx.array(
[
[0.1, 0.4, 0.3, 0.2],
[0.01, 0.39, 0.25, 0.35],
[0.9, 0.01, 0.01, 0.08],
[0.7, 0.2, 0.05, 0.05],
[0.25, 0.25, 0.25, 0.25],
],
)
top_p_probs = mx.softmax(apply_top_p(probs.log(), top_p=top_p), axis=-1)
assert mx.all(mx.not_equal(top_p_probs, 0.0) == expected_non_zero_probs)