mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Add MLX inference support (#41)
*Issue #, if available:* #28
*Description of changes:* This PR adds MLX inference support.
## Summary of changes
- Update `pyproject.toml` with`mlx` dependencies.
- Create `chronos_mlx` package which will hosts all mlx inference stuff.
- All classes from `main:src/chronos/chronos.py` are copy-pasted into
`mlx:src/chronos_mlx/chronos.py` and modified to use numpy and mlx
arrays instead. Note that the reason for using numpy arrays as input and
output is that mlx doesn't support some operations that are required for
input and output transform.
- MLX implementation of T5 is in `src/chronos_mlx/t5.py`. It has been
adapted from
[ml-explore/mlx-examples](b8a348c1b8/t5/t5.py)
with the following main modifications:
- Added support for attention mask.
- Added support for top_k and top_p sampling.
- `src/chronos_mlx/translate.py` translates weights from a torch HF
model to mlx.
- Add `THIRD-PARTY-LICENSES.txt` for third party code from
`mlx-examples`.
- Add tests and CI for `mlx` version.
## Sample inference code
```py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from chronos_mlx import ChronosPipeline
pipeline = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-small",
dtype="bfloat16",
)
df = pd.read_csv(
"https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)
# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = df["#Passengers"].values
prediction_length = 12
forecast = pipeline.predict(
context, prediction_length
) # shape [num_series, num_samples, prediction_length]
# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)
plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(
forecast_index,
low,
high,
color="tomato",
alpha=0.3,
label="80% prediction interval",
)
plt.legend()
plt.grid()
plt.show()
```
## Benchmark

```py
import timeit
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm
from chronos import ChronosPipeline as ChronosPipelineTorch
from chronos_mlx import ChronosPipeline as ChronosPipelineMLX
def benchmark_torch_model(
pipeline: ChronosPipelineTorch,
gluonts_dataset: str = "m4_hourly",
batch_size: int = 32,
):
dataset = get_dataset(gluonts_dataset)
prediction_length = dataset.metadata.prediction_length
_, test_template = split(dataset.test, offset=-prediction_length)
test_data = test_template.generate_instances(prediction_length)
test_data_input = list(test_data.input)
start_time = timeit.default_timer()
forecasts = []
for idx in tqdm(range(0, len(test_data_input), batch_size)):
batch = [
torch.tensor(item["target"])
for item in test_data_input[idx : idx + batch_size]
]
batch_forecasts = pipeline.predict(batch, prediction_length)
forecasts.append(batch_forecasts)
forecasts = torch.cat(forecasts)
end_time = timeit.default_timer()
print(f"Inference time: {end_time-start_time:.2f}s")
results_df = evaluate_forecasts(
forecasts=[
SampleForecast(fcst.numpy(), start_date=label["start"])
for fcst, label in zip(forecasts, test_data.label)
],
test_data=test_data,
metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
)
results_df["inference_time"] = end_time - start_time
return results_df
def benchmark_mlx_model(
pipeline: ChronosPipelineMLX,
gluonts_dataset: str = "m4_hourly",
batch_size: int = 32,
):
dataset = get_dataset(gluonts_dataset)
prediction_length = dataset.metadata.prediction_length
_, test_template = split(dataset.test, offset=-prediction_length)
test_data = test_template.generate_instances(prediction_length)
test_data_input = list(test_data.input)
start_time = timeit.default_timer()
forecasts = []
for idx in tqdm(range(0, len(test_data_input), batch_size)):
batch = [item["target"] for item in test_data_input[idx : idx + batch_size]]
batch_forecasts = pipeline.predict(batch, prediction_length)
forecasts.append(batch_forecasts)
forecasts = np.concatenate(forecasts)
end_time = timeit.default_timer()
print(f"Inference time: {end_time-start_time:.2f}s")
results_df = evaluate_forecasts(
forecasts=[
SampleForecast(fcst, start_date=label["start"])
for fcst, label in zip(forecasts, test_data.label)
],
test_data=test_data,
metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
)
results_df["inference_time"] = end_time - start_time
return results_df
def main(
version: str = "cpu", # cpu, mps, mlx
dtype: str = "bfloat16",
gluonts_dataset: str = "australian_electricity_demand",
model_name: str = "amazon/chronos-t5-small",
batch_size: int = 4,
):
if version == "cpu" or version == "mps":
pipeline = ChronosPipelineTorch.from_pretrained(
model_name,
device_map=version,
torch_dtype=getattr(torch, dtype),
)
benchmark_fn = benchmark_torch_model
else:
pipeline = ChronosPipelineMLX.from_pretrained(model_name, dtype=dtype)
benchmark_fn = benchmark_mlx_model
result_df = benchmark_fn(
pipeline, gluonts_dataset=gluonts_dataset, batch_size=batch_size
)
result_df["model"] = model_name
return result_df
if __name__ == "__main__":
gluonts_dataset: str = "m4_hourly"
model_name: str = "amazon/chronos-t5-mini"
batch_size: int = 8
dfs = []
for version in ["cpu", "mps", "mlx"]:
for dtype in ["float32"]:
try:
df = main(
version=version,
dtype=dtype,
model_name=model_name,
gluonts_dataset=gluonts_dataset,
batch_size=batch_size,
)
df["version"] = version
df["dtype"] = dtype
dfs.append(df)
except TypeError:
pass
result_df = pd.concat(dfs).reset_index(drop=True)
result_df.to_csv("benchmark.csv", index=False)
result_df["version"] = result_df["version"].map(
{"cpu": "Torch (CPU)", "mps": "Torch (MPS)", "mlx": "MLX"}
)
fig = plt.figure(figsize=(8, 5))
g = sns.barplot(
data=result_df,
x="dtype",
y="inference_time",
hue="version",
alpha=0.6,
)
plt.ylabel("Inference Time (on M1 Pro)")
plt.title(f"{model_name} inference times on {gluonts_dataset} dataset")
plt.savefig("benchmark.png", dpi=200)
```
## TODOs:
- [x] Implement `top_p` sampling.
- [x] Add tests.
- [x] Add CI.
By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
---------
Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>
This commit is contained in:
parent
2042779efa
commit
159ea36f7f
10 changed files with 830 additions and 398 deletions
45
.github/workflows/ci.yml
vendored
45
.github/workflows/ci.yml
vendored
|
|
@ -3,44 +3,23 @@ name: CI
|
|||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
type-check:
|
||||
test-mlx:
|
||||
strategy:
|
||||
max-parallel: 4
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
platform: [ubuntu-latest]
|
||||
python-version: ['3.11']
|
||||
platform: [macos-14]
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: pip install ".[typecheck]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
- name: Type checks with mypy
|
||||
run: mypy src test
|
||||
|
||||
test:
|
||||
strategy:
|
||||
max-parallel: 4
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
platform: [ubuntu-latest]
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
- name: Test with pytest
|
||||
run: pytest
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: pip install ".[test]" -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
- name: Test with pytest
|
||||
run: pytest test/
|
||||
|
|
|
|||
27
README.md
27
README.md
|
|
@ -1,4 +1,7 @@
|
|||
# Chronos: Learning the Language of Time Series
|
||||
# [🧪 MLX Version] Chronos: Learning the Language of Time Series
|
||||
|
||||
> [!IMPORTANT]
|
||||
> This is the **experimental** MLX version of Chronos for Apple Silicon Macs. Please use the `main` branch for the stable PyTorch version.
|
||||
|
||||
Chronos is a family of **pretrained time series forecasting models** based on language model architectures. A time series is transformed into a sequence of tokens via scaling and quantization, and a language model is trained on these tokens using the cross-entropy loss. Once trained, probabilistic forecasts are obtained by sampling multiple future trajectories given the historical context. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes.
|
||||
|
||||
|
|
@ -28,10 +31,10 @@ The models in this repository are based on the [T5 architecture](https://arxiv.o
|
|||
|
||||
## Usage
|
||||
|
||||
To perform inference with Chronos models, install this package by running:
|
||||
To perform inference with Chronos models on Apple Silcon devices, install this package by running:
|
||||
|
||||
```
|
||||
pip install git+https://github.com/amazon-science/chronos-forecasting.git
|
||||
pip install git+https://github.com/amazon-science/chronos-forecasting.git@mlx
|
||||
```
|
||||
|
||||
### Forecasting
|
||||
|
|
@ -43,20 +46,18 @@ A minimal example showing how to perform forecasting using Chronos models:
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from chronos import ChronosPipeline
|
||||
from chronos_mlx import ChronosPipeline
|
||||
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
"amazon/chronos-t5-small",
|
||||
device_map="cuda",
|
||||
torch_dtype=torch.bfloat16,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
|
||||
df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
|
||||
|
||||
# context must be either a 1D tensor, a list of 1D tensors,
|
||||
# or a left-padded 2D tensor with batch as the first dimension
|
||||
context = torch.tensor(df["#Passengers"])
|
||||
context = df["#Passengers"].values
|
||||
prediction_length = 12
|
||||
forecast = pipeline.predict(
|
||||
context,
|
||||
|
|
@ -69,7 +70,7 @@ forecast = pipeline.predict(
|
|||
|
||||
# visualize the forecast
|
||||
forecast_index = range(len(df), len(df) + prediction_length)
|
||||
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
|
||||
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)
|
||||
|
||||
plt.figure(figsize=(8, 4))
|
||||
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
|
||||
|
|
@ -86,20 +87,18 @@ A minimal example showing how to extract encoder embeddings from Chronos models:
|
|||
|
||||
```python
|
||||
import pandas as pd
|
||||
import torch
|
||||
from chronos import ChronosPipeline
|
||||
from chronos_mlx import ChronosPipeline
|
||||
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
"amazon/chronos-t5-small",
|
||||
device_map="cuda",
|
||||
torch_dtype=torch.bfloat16,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
|
||||
df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
|
||||
|
||||
# context must be either a 1D tensor, a list of 1D tensors,
|
||||
# or a left-padded 2D tensor with batch as the first dimension
|
||||
context = torch.tensor(df["#Passengers"])
|
||||
context = df["#Passengers"].values
|
||||
embeddings, tokenizer_state = pipeline.embed(context)
|
||||
```
|
||||
|
||||
|
|
|
|||
23
THIRD-PARTY-LICENSES.txt
Normal file
23
THIRD-PARTY-LICENSES.txt
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
** mlx-examples; version b8a348c -- https://github.com/ml-explore/mlx-examples
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright © 2023 Apple Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
[project]
|
||||
name = "chronos"
|
||||
name = "chronos_mlx"
|
||||
version = "1.1.0"
|
||||
requires-python = ">=3.8"
|
||||
license = { file = "LICENSE" }
|
||||
|
|
@ -7,6 +7,7 @@ dependencies = [
|
|||
"torch~=2.1", # package was tested on 2.2
|
||||
"transformers~=4.31",
|
||||
"accelerate",
|
||||
"mlx~=0.9.0"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
|
|
@ -3,18 +3,18 @@
|
|||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import chronos
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
GenerationConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
from transformers import T5Config
|
||||
|
||||
import chronos_mlx
|
||||
from chronos_mlx.t5 import T5
|
||||
from chronos_mlx.translate import translate_weights
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -46,13 +46,13 @@ class ChronosConfig:
|
|||
), f"Special token id's must be smaller than {self.n_special_tokens=}"
|
||||
|
||||
def create_tokenizer(self) -> "ChronosTokenizer":
|
||||
class_ = getattr(chronos, self.tokenizer_class)
|
||||
class_ = getattr(chronos_mlx, self.tokenizer_class)
|
||||
return class_(**self.tokenizer_kwargs, config=self)
|
||||
|
||||
|
||||
class ChronosTokenizer:
|
||||
"""
|
||||
A ``ChronosTokenizer`` definines how time series are mapped into token IDs
|
||||
A ``ChronosTokenizer`` defines how time series are mapped into token IDs
|
||||
and back.
|
||||
|
||||
For details, see the ``input_transform`` and ``output_transform`` methods,
|
||||
|
|
@ -60,27 +60,27 @@ class ChronosTokenizer:
|
|||
"""
|
||||
|
||||
def input_transform(
|
||||
self, context: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
|
||||
self, context: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray, Any]:
|
||||
"""
|
||||
Turn a batch of time series into token IDs, attention map, and scale.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
context
|
||||
A tensor shaped (batch_size, time_length), containing the
|
||||
timeseries to forecast. Use left-padding with ``torch.nan``
|
||||
A numpy array shaped (batch_size, time_length), containing the
|
||||
timeseries to forecast. Use left-padding with ``np.nan``
|
||||
to align time series of different lengths.
|
||||
|
||||
Returns
|
||||
-------
|
||||
token_ids
|
||||
A tensor of integers, shaped (batch_size, time_length + 1)
|
||||
A numpy array of integers, shaped (batch_size, time_length + 1)
|
||||
if ``config.use_eos_token`` and (batch_size, time_length)
|
||||
otherwise, containing token IDs for the input series.
|
||||
attention_mask
|
||||
A boolean tensor, same shape as ``token_ids``, indicating
|
||||
which input observations are not ``torch.nan`` (i.e. not
|
||||
A boolean numpy array, same shape as ``token_ids``, indicating
|
||||
which input observations are not ``np.nan`` (i.e. not
|
||||
missing nor padding).
|
||||
tokenizer_state
|
||||
An object that will be passed to ``output_transform``.
|
||||
|
|
@ -89,16 +89,14 @@ class ChronosTokenizer:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def output_transform(
|
||||
self, samples: torch.Tensor, tokenizer_state: Any
|
||||
) -> torch.Tensor:
|
||||
def output_transform(self, samples: np.ndarray, tokenizer_state: Any) -> np.ndarray:
|
||||
"""
|
||||
Turn a batch of sample token IDs into real values.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
samples
|
||||
A tensor of integers, shaped (batch_size, num_samples, time_length),
|
||||
A numpy array of integers, shaped (batch_size, num_samples, time_length),
|
||||
containing token IDs of sample trajectories.
|
||||
tokenizer_state
|
||||
An object returned by ``input_transform`` containing
|
||||
|
|
@ -108,7 +106,7 @@ class ChronosTokenizer:
|
|||
Returns
|
||||
-------
|
||||
forecasts
|
||||
A real tensor, shaped (batch_size, num_samples, time_length),
|
||||
A real numpy array, shaped (batch_size, num_samples, time_length),
|
||||
containing forecasted sample paths.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
|
@ -119,70 +117,60 @@ class MeanScaleUniformBins(ChronosTokenizer):
|
|||
self, low_limit: float, high_limit: float, config: ChronosConfig
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.centers = torch.linspace(
|
||||
self.centers = np.linspace(
|
||||
low_limit,
|
||||
high_limit,
|
||||
config.n_tokens - config.n_special_tokens - 1,
|
||||
)
|
||||
self.boundaries = torch.concat(
|
||||
self.boundaries = np.concatenate(
|
||||
(
|
||||
torch.tensor([-1e20], device=self.centers.device),
|
||||
np.array([-1e20]),
|
||||
(self.centers[1:] + self.centers[:-1]) / 2,
|
||||
torch.tensor([1e20], device=self.centers.device),
|
||||
np.array([1e20]),
|
||||
)
|
||||
)
|
||||
|
||||
def input_transform(
|
||||
self, context: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
self, context: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
batch_size, length = context.shape
|
||||
|
||||
if length > self.config.context_length:
|
||||
context = context[..., -self.config.context_length :]
|
||||
|
||||
attention_mask = ~torch.isnan(context)
|
||||
scale = torch.nansum(
|
||||
torch.abs(context) * attention_mask, dim=-1
|
||||
) / torch.nansum(attention_mask, dim=-1)
|
||||
attention_mask = ~np.isnan(context)
|
||||
scale = np.nansum(np.abs(context) * attention_mask, axis=-1) / np.nansum(
|
||||
attention_mask, axis=-1
|
||||
)
|
||||
scale[~(scale > 0)] = 1.0
|
||||
scaled_context = context / scale.unsqueeze(dim=-1)
|
||||
scaled_context = context / scale[..., np.newaxis]
|
||||
token_ids = (
|
||||
torch.bucketize(
|
||||
input=scaled_context,
|
||||
boundaries=self.boundaries,
|
||||
# buckets are open to the right, see:
|
||||
# https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize
|
||||
right=True,
|
||||
)
|
||||
np.digitize(scaled_context, bins=self.boundaries)
|
||||
+ self.config.n_special_tokens
|
||||
)
|
||||
token_ids[~attention_mask] = self.config.pad_token_id
|
||||
|
||||
if self.config.use_eos_token:
|
||||
eos_tokens = torch.full(
|
||||
(batch_size, 1), fill_value=self.config.eos_token_id
|
||||
)
|
||||
token_ids = torch.concat((token_ids, eos_tokens), dim=1)
|
||||
eos_mask = torch.full((batch_size, 1), fill_value=True)
|
||||
attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
|
||||
eos_tokens = np.full((batch_size, 1), fill_value=self.config.eos_token_id)
|
||||
token_ids = np.concatenate((token_ids, eos_tokens), axis=1)
|
||||
eos_mask = np.full((batch_size, 1), fill_value=True)
|
||||
attention_mask = np.concatenate((attention_mask, eos_mask), axis=1)
|
||||
|
||||
return token_ids, attention_mask, scale
|
||||
|
||||
def output_transform(
|
||||
self, samples: torch.Tensor, scale: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
|
||||
indices = torch.clamp(
|
||||
def output_transform(self, samples: np.ndarray, scale: np.ndarray) -> np.ndarray:
|
||||
scale_unsqueezed = scale[..., np.newaxis, np.newaxis]
|
||||
indices = np.clip(
|
||||
samples - self.config.n_special_tokens,
|
||||
min=0,
|
||||
max=len(self.centers) - 1,
|
||||
a_min=0,
|
||||
a_max=len(self.centers) - 1,
|
||||
)
|
||||
return self.centers[indices] * scale_unsqueezed
|
||||
|
||||
|
||||
class ChronosModel(nn.Module):
|
||||
"""
|
||||
A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers``
|
||||
A ``ChronosModel`` wraps a ``T5`` object from ``chronos.mlx.t5``
|
||||
and uses it to predict sample paths for time series tokens.
|
||||
|
||||
Parameters
|
||||
|
|
@ -190,22 +178,21 @@ class ChronosModel(nn.Module):
|
|||
config
|
||||
The configuration to use.
|
||||
model
|
||||
The pre-trained model to use.
|
||||
The pretrained T5 model to use.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None:
|
||||
def __init__(self, config: ChronosConfig, model: T5) -> None:
|
||||
super().__init__()
|
||||
assert config.model_type == "seq2seq" and isinstance(
|
||||
model, T5
|
||||
), "Only the T5 model is currently supported in MLX"
|
||||
self.config = config
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.model.device
|
||||
|
||||
def encode(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
input_ids: np.ndarray,
|
||||
attention_mask: np.ndarray,
|
||||
):
|
||||
"""
|
||||
Extract the encoder embedding for the given token sequences.
|
||||
|
|
@ -213,35 +200,33 @@ class ChronosModel(nn.Module):
|
|||
Parameters
|
||||
----------
|
||||
input_ids
|
||||
Tensor of indices of input sequence tokens in the vocabulary
|
||||
Array of indices of input sequence tokens in the vocabulary
|
||||
with shape (batch_size, sequence_length).
|
||||
attention_mask
|
||||
A mask tensor of the same shape as input_ids to avoid attending
|
||||
A mask array of the same shape as input_ids to avoid attending
|
||||
on padding or missing tokens.
|
||||
|
||||
Returns
|
||||
-------
|
||||
embedding
|
||||
A tensor of encoder embeddings with shape
|
||||
An array of encoder embeddings with shape
|
||||
(batch_size, sequence_length, d_model).
|
||||
"""
|
||||
assert (
|
||||
self.config.model_type == "seq2seq"
|
||||
), "Encoder embeddings are only supported for encoder-decoder models"
|
||||
return self.model.encoder(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).last_hidden_state
|
||||
return self.model.encode(inputs=input_ids, mask=attention_mask)
|
||||
|
||||
def forward(
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
input_ids: mx.array,
|
||||
attention_mask: mx.array,
|
||||
prediction_length: Optional[int] = None,
|
||||
num_samples: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> mx.array:
|
||||
"""
|
||||
Predict future sample tokens for the given token sequences.
|
||||
|
||||
|
|
@ -253,7 +238,7 @@ class ChronosModel(nn.Module):
|
|||
Returns
|
||||
-------
|
||||
samples
|
||||
A tensor of integers, shaped (batch_size, num_samples, time_length),
|
||||
A numpy array of integers, shaped (batch_size, num_samples, time_length),
|
||||
containing forecasted sample paths.
|
||||
"""
|
||||
if prediction_length is None:
|
||||
|
|
@ -270,40 +255,31 @@ class ChronosModel(nn.Module):
|
|||
preds = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=GenerationConfig(
|
||||
min_new_tokens=prediction_length,
|
||||
max_new_tokens=prediction_length,
|
||||
do_sample=True,
|
||||
num_return_sequences=num_samples,
|
||||
eos_token_id=self.config.eos_token_id,
|
||||
pad_token_id=self.config.pad_token_id,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
),
|
||||
min_new_tokens=prediction_length,
|
||||
max_new_tokens=prediction_length,
|
||||
do_sample=True,
|
||||
num_return_sequences=num_samples,
|
||||
eos_token_id=self.config.eos_token_id,
|
||||
pad_token_id=self.config.pad_token_id,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
if self.config.model_type == "seq2seq":
|
||||
preds = preds[..., 1:] # remove the decoder start token
|
||||
else:
|
||||
assert self.config.model_type == "causal"
|
||||
assert preds.size(-1) == input_ids.size(-1) + prediction_length
|
||||
preds = preds[..., -prediction_length:]
|
||||
preds = preds[..., 1:] # remove the decoder start token
|
||||
|
||||
return preds.reshape(input_ids.size(0), num_samples, -1)
|
||||
return preds.reshape(input_ids.shape[0], num_samples, -1)
|
||||
|
||||
|
||||
def left_pad_and_stack_1D(tensors: List[torch.Tensor]):
|
||||
max_len = max(len(c) for c in tensors)
|
||||
def left_pad_and_stack_1D(arrays: List[np.ndarray]):
|
||||
max_len = max(len(c) for c in arrays)
|
||||
padded = []
|
||||
for c in tensors:
|
||||
assert isinstance(c, torch.Tensor)
|
||||
for c in arrays:
|
||||
assert isinstance(c, np.ndarray)
|
||||
assert c.ndim == 1
|
||||
padding = torch.full(
|
||||
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
|
||||
)
|
||||
padded.append(torch.concat((padding, c), dim=-1))
|
||||
return torch.stack(padded)
|
||||
padding = np.full(shape=(max_len - len(c),), fill_value=np.nan)
|
||||
padded.append(np.concatenate((padding, c), axis=-1))
|
||||
return np.stack(padded)
|
||||
|
||||
|
||||
class ChronosPipeline:
|
||||
|
|
@ -330,21 +306,20 @@ class ChronosPipeline:
|
|||
self.model = model
|
||||
|
||||
def _prepare_and_validate_context(
|
||||
self, context: Union[torch.Tensor, List[torch.Tensor]]
|
||||
):
|
||||
self, context: Union[np.ndarray, List[np.ndarray]]
|
||||
) -> np.ndarray:
|
||||
if isinstance(context, list):
|
||||
context = left_pad_and_stack_1D(context)
|
||||
assert isinstance(context, torch.Tensor)
|
||||
assert isinstance(context, np.ndarray)
|
||||
if context.ndim == 1:
|
||||
context = context.unsqueeze(0)
|
||||
context = context[np.newaxis, ...]
|
||||
assert context.ndim == 2
|
||||
|
||||
return context
|
||||
|
||||
@torch.no_grad()
|
||||
def embed(
|
||||
self, context: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
self, context: Union[np.ndarray, List[np.ndarray]]
|
||||
) -> Tuple[np.ndarray, Any]:
|
||||
"""
|
||||
Get encoder embeddings for the given time series.
|
||||
|
||||
|
|
@ -367,36 +342,36 @@ class ChronosPipeline:
|
|||
or the length of the longest time series, if a list of 1D tensors was
|
||||
provided, and the extra 1 is for EOS.
|
||||
"""
|
||||
context_tensor = self._prepare_and_validate_context(context=context)
|
||||
context_array = self._prepare_and_validate_context(context=context)
|
||||
token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
|
||||
context_tensor
|
||||
context_array
|
||||
)
|
||||
embeddings = self.model.encode(
|
||||
input_ids=token_ids.to(self.model.device),
|
||||
attention_mask=attention_mask.to(self.model.device),
|
||||
).cpu()
|
||||
return embeddings, tokenizer_state
|
||||
input_ids=mx.array(token_ids),
|
||||
attention_mask=mx.array(attention_mask),
|
||||
)
|
||||
return np.array(embeddings.astype(mx.float32)), tokenizer_state
|
||||
|
||||
def predict(
|
||||
self,
|
||||
context: Union[torch.Tensor, List[torch.Tensor]],
|
||||
context: Union[np.ndarray, List[np.ndarray]],
|
||||
prediction_length: Optional[int] = None,
|
||||
num_samples: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
limit_prediction_length: bool = True,
|
||||
) -> torch.Tensor:
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Get forecasts for the given time series.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
context
|
||||
Input series. This is either a 1D tensor, or a list
|
||||
of 1D tensors, or a 2D tensor whose first dimension
|
||||
Input series. This is either a 1D numpy array, or a list
|
||||
of 1D numpy arrays, or a 2D numpy array whose first dimension
|
||||
is batch. In the latter case, use left-padding with
|
||||
``torch.nan`` to align series of different lengths.
|
||||
``np.nan`` to align series of different lengths.
|
||||
prediction_length
|
||||
Time steps to predict. Defaults to what specified
|
||||
in ``self.model.config``.
|
||||
|
|
@ -421,10 +396,10 @@ class ChronosPipeline:
|
|||
Returns
|
||||
-------
|
||||
samples
|
||||
Tensor of sample forecasts, of shape
|
||||
Numpy array of sample forecasts, of shape
|
||||
(batch_size, num_samples, prediction_length).
|
||||
"""
|
||||
context_tensor = self._prepare_and_validate_context(context=context)
|
||||
context_array = self._prepare_and_validate_context(context=context)
|
||||
|
||||
if prediction_length is None:
|
||||
prediction_length = self.model.config.prediction_length
|
||||
|
|
@ -432,7 +407,7 @@ class ChronosPipeline:
|
|||
if prediction_length > self.model.config.prediction_length:
|
||||
msg = (
|
||||
f"We recommend keeping prediction length <= {self.model.config.prediction_length}. "
|
||||
"The quality of longer predictions may degrade since the model is not optimized for it. "
|
||||
f"The quality of longer predictions may degrade since the model is not optimized for it. "
|
||||
)
|
||||
if limit_prediction_length:
|
||||
msg += "You can turn off this check by setting `limit_prediction_length=False`."
|
||||
|
|
@ -444,20 +419,19 @@ class ChronosPipeline:
|
|||
|
||||
while remaining > 0:
|
||||
token_ids, attention_mask, scale = self.tokenizer.input_transform(
|
||||
context_tensor
|
||||
context_array
|
||||
)
|
||||
token_ids, attention_mask = mx.array(token_ids), mx.array(attention_mask)
|
||||
samples = self.model(
|
||||
token_ids.to(self.model.device),
|
||||
attention_mask.to(self.model.device),
|
||||
token_ids,
|
||||
attention_mask,
|
||||
min(remaining, self.model.config.prediction_length),
|
||||
num_samples,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
)
|
||||
prediction = self.tokenizer.output_transform(
|
||||
samples.to(scale.device), scale
|
||||
)
|
||||
prediction = self.tokenizer.output_transform(np.array(samples), scale)
|
||||
|
||||
predictions.append(prediction)
|
||||
remaining -= prediction.shape[-1]
|
||||
|
|
@ -465,31 +439,44 @@ class ChronosPipeline:
|
|||
if remaining <= 0:
|
||||
break
|
||||
|
||||
context_tensor = torch.cat(
|
||||
[context_tensor, prediction.median(dim=1).values], dim=-1
|
||||
context_array = np.concatenate(
|
||||
[context_array, np.median(prediction, axis=1)], axis=-1
|
||||
)
|
||||
|
||||
return torch.cat(predictions, dim=-1)
|
||||
return np.concatenate(predictions, axis=-1)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
def from_pretrained(
|
||||
cls, model_name_or_path: Union[str, Path], dtype: str = "float32"
|
||||
):
|
||||
"""
|
||||
Load the model, either from a local path or from the HuggingFace Hub.
|
||||
Supports the same arguments as ``AutoConfig`` and ``AutoModel``
|
||||
from ``transformers``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name_or_path
|
||||
Model ID on HuggingFace Hub or local path.
|
||||
dtype, optional
|
||||
String denoting the float dtype of the mlx model,
|
||||
by default "float32"
|
||||
|
||||
Returns
|
||||
-------
|
||||
A ChronosPipeline
|
||||
"""
|
||||
|
||||
config = AutoConfig.from_pretrained(*args, **kwargs)
|
||||
config = T5Config.from_pretrained(model_name_or_path)
|
||||
|
||||
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
|
||||
|
||||
dtype = getattr(mx, dtype)
|
||||
chronos_config = ChronosConfig(**config.chronos_config)
|
||||
|
||||
if chronos_config.model_type == "seq2seq":
|
||||
inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
|
||||
else:
|
||||
assert config.model_type == "causal"
|
||||
inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
|
||||
inner_model = T5(config=config)
|
||||
weights = translate_weights(model_name_or_path=model_name_or_path, dtype=dtype)
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
inner_model.update(weights)
|
||||
mx.eval(inner_model.parameters())
|
||||
|
||||
return cls(
|
||||
tokenizer=chronos_config.create_tokenizer(),
|
||||
420
src/chronos_mlx/t5.py
Normal file
420
src/chronos_mlx/t5.py
Normal file
|
|
@ -0,0 +1,420 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from ml-explore/mlx-examples:
|
||||
# https://github.com/ml-explore/mlx-examples/blob/b8a348c1b8df4433cfacb9adbeb89b8aa3979ab2/t5/t5.py
|
||||
# Modifications:
|
||||
# - Added support for attention mask.
|
||||
# - Added support for top_k and top_p sampling.
|
||||
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from transformers import T5Config
|
||||
|
||||
|
||||
def _relative_position_bucket(
|
||||
relative_position, bidirectional=True, num_buckets=32, max_distance=128
|
||||
):
|
||||
# Adapted from HuggingFace transformers:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
relative_buckets = 0
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
|
||||
relative_position = mx.abs(relative_position)
|
||||
else:
|
||||
relative_position = -mx.minimum(
|
||||
relative_position, mx.zeros_like(relative_position)
|
||||
)
|
||||
# now relative_position is in the range [0, inf)
|
||||
|
||||
# half of the buckets are for exact increments in positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_position < max_exact
|
||||
|
||||
# The other half of the buckets are for logarithmically bigger bins
|
||||
# in positions up to max_distance
|
||||
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
|
||||
relative_position_if_large = max_exact + (
|
||||
mx.log(relative_position.astype(mx.float32) / max_exact) * scale
|
||||
).astype(mx.int16)
|
||||
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
|
||||
relative_buckets += mx.where(
|
||||
is_small, relative_position, relative_position_if_large
|
||||
)
|
||||
return relative_buckets
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, config: T5Config, bidirectional: bool):
|
||||
self.bidirectional = bidirectional
|
||||
self.num_buckets = config.relative_attention_num_buckets
|
||||
self.max_distance = config.relative_attention_max_distance
|
||||
self.n_heads = config.num_heads
|
||||
self.embeddings = nn.Embedding(
|
||||
config.relative_attention_num_buckets, config.num_heads
|
||||
)
|
||||
|
||||
def __call__(self, query_length: int, key_length: int, offset: int = 0):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = mx.arange(offset, query_length)[:, None]
|
||||
memory_position = mx.arange(key_length)[None, :]
|
||||
|
||||
# shape (query_length, key_length)
|
||||
relative_position = memory_position - context_position
|
||||
relative_position_bucket = _relative_position_bucket(
|
||||
relative_position,
|
||||
bidirectional=self.bidirectional,
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
)
|
||||
|
||||
# shape (query_length, key_length, num_heads)
|
||||
values = self.embeddings(relative_position_bucket)
|
||||
|
||||
# shape (num_heads, query_length, key_length)
|
||||
return values.transpose(2, 0, 1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
inner_dim = config.d_kv * config.num_heads
|
||||
self.num_heads = config.num_heads
|
||||
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
queries: mx.array,
|
||||
keys: mx.array,
|
||||
values: mx.array,
|
||||
mask: Optional[mx.array],
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, _ = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
keys = mx.concatenate([key_cache, keys], axis=3)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
|
||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||
scores = queries @ keys
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class DenseActivation(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
mlp_dims = config.d_ff or config.d_model * 4
|
||||
self.gated = config.feed_forward_proj.startswith("gated")
|
||||
if self.gated:
|
||||
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
else:
|
||||
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||
activation = config.feed_forward_proj.removeprefix("gated-")
|
||||
if activation == "relu":
|
||||
self.act = nn.relu
|
||||
elif activation == "gelu":
|
||||
self.act = nn.gelu
|
||||
elif activation == "silu":
|
||||
self.act = nn.silu
|
||||
else:
|
||||
raise ValueError(f"Unknown activation: {activation}")
|
||||
|
||||
def __call__(self, x):
|
||||
if self.gated:
|
||||
hidden_act = self.act(self.wi_0(x))
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_act * hidden_linear
|
||||
else:
|
||||
x = self.act(self.wi(x))
|
||||
return self.wo(x)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadAttention(config)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
|
||||
def __call__(self, x, mask):
|
||||
y = self.ln1(x)
|
||||
y, _ = self.attention(y, y, y, mask=mask)
|
||||
x = x + y
|
||||
|
||||
y = self.ln2(x)
|
||||
y = self.dense(y)
|
||||
return x + y
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||
|
||||
def __call__(self, x: mx.array, mask: Optional[mx.array] = None):
|
||||
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])[None]
|
||||
if mask is not None:
|
||||
mask = mask[:, None, None, :]
|
||||
pos_bias += mask
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask=pos_bias)
|
||||
return self.ln(x)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.self_attention = MultiHeadAttention(config)
|
||||
self.cross_attention = MultiHeadAttention(config)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
memory: mx.array,
|
||||
mask: mx.array,
|
||||
memory_mask: mx.array,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
):
|
||||
y = self.ln1(x)
|
||||
y, cache = self.self_attention(y, y, y, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.ln2(x)
|
||||
y, _ = self.cross_attention(y, memory, memory, memory_mask)
|
||||
x = x + y
|
||||
|
||||
y = self.ln3(x)
|
||||
y = self.dense(y)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
|
||||
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
|
||||
|
||||
def __call__(self, x, memory, mask, memory_mask, cache=None):
|
||||
if cache is not None:
|
||||
offset = cache[0][0].shape[3]
|
||||
else:
|
||||
offset = 0
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
T = offset + x.shape[1]
|
||||
pos_bias = self.relative_attention_bias(T, T, offset=offset)
|
||||
if mask is not None:
|
||||
mask += pos_bias
|
||||
else:
|
||||
mask = pos_bias
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
|
||||
x = self.ln(x)
|
||||
|
||||
return x, cache
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, inputs):
|
||||
return self.linear(inputs)
|
||||
|
||||
|
||||
def apply_top_p(logits: mx.array, top_p: float, min_tokens_to_keep=1):
|
||||
assert min_tokens_to_keep <= logits.shape[-1]
|
||||
logits_dtype = logits.dtype
|
||||
# FIXME: The following is needed because mlx doesn't have the cumsum
|
||||
# kernel for bfloat16. Once that is supported natively, this casting
|
||||
# should be removed. @abdulfatir
|
||||
logits = logits.astype(mx.float32)
|
||||
sorted_indices = mx.argsort(logits, axis=-1)
|
||||
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
|
||||
cumulative_probs = mx.softmax(sorted_logits, axis=-1).cumsum(axis=-1, reverse=True)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
|
||||
masked_sorted_logits = mx.where(sorted_indices_to_remove, -mx.inf, sorted_logits)
|
||||
unsorted_indices = mx.argsort(sorted_indices, axis=-1)
|
||||
return mx.take_along_axis(masked_sorted_logits, unsorted_indices, axis=-1).astype(
|
||||
logits_dtype
|
||||
)
|
||||
|
||||
|
||||
def sample(logits, top_k=1, top_p=1.0, temperature=1.0):
|
||||
vocab_size = logits.shape[-1]
|
||||
assert top_p <= 1.0, f"{top_p=} should be <= 1.0"
|
||||
|
||||
if temperature == 0 or top_k == 1:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
# Apply temperature term
|
||||
if temperature != 1.0:
|
||||
logits /= temperature
|
||||
|
||||
# Apply top_k
|
||||
if top_k >= vocab_size:
|
||||
return mx.random.categorical(
|
||||
apply_top_p(logits, top_p=top_p) if top_p < 1.0 else logits
|
||||
)
|
||||
|
||||
top_k_indices = mx.argpartition(logits, top_k, axis=-1)[..., -top_k:]
|
||||
top_k_logits = mx.take_along_axis(logits, top_k_indices, axis=-1)
|
||||
|
||||
# Apply top_p
|
||||
if top_p < 1.0:
|
||||
top_k_logits = apply_top_p(top_k_logits, top_p=top_p)
|
||||
|
||||
return top_k_indices[
|
||||
mx.arange(top_k_indices.shape[0]), mx.random.categorical(top_k_logits)
|
||||
]
|
||||
|
||||
|
||||
class T5(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||
self.encoder = TransformerEncoder(config)
|
||||
self.decoder = TransformerDecoder(config)
|
||||
self.tie_word_embeddings = config.tie_word_embeddings
|
||||
if not self.tie_word_embeddings:
|
||||
self.lm_head = OutputHead(config)
|
||||
self.model_dim = config.d_model
|
||||
|
||||
def encode(self, inputs: mx.array, mask: mx.array):
|
||||
return self.encoder(self.wte(inputs), mask)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
memory: mx.array,
|
||||
memory_mask: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
inputs = self.wte(inputs)
|
||||
T = inputs.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(inputs.dtype)
|
||||
else:
|
||||
mask = None
|
||||
|
||||
memory_mask = memory_mask[:, None, None, :]
|
||||
y, cache = self.decoder(
|
||||
inputs, memory=memory, mask=mask, memory_mask=memory_mask, cache=cache
|
||||
)
|
||||
if not self.tie_word_embeddings:
|
||||
y = self.lm_head(y)
|
||||
else:
|
||||
y *= self.model_dim**-0.5
|
||||
y = y @ self.wte.weight.T
|
||||
return y, cache
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array,
|
||||
decoder_inputs: mx.array,
|
||||
):
|
||||
memory = self.encode(inputs, mask=mask)
|
||||
return self.decode(decoder_inputs, memory=memory, memory_mask=mask)[0]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids: mx.array,
|
||||
attention_mask: mx.array,
|
||||
min_new_tokens: Optional[int] = None,
|
||||
max_new_tokens: int = 64,
|
||||
do_sample: bool = True,
|
||||
num_return_sequences: int = 1,
|
||||
pad_token_id: int = 0,
|
||||
eos_token_id: Optional[int] = None,
|
||||
temperature: Optional[float] = 1.0,
|
||||
top_k: int = 50,
|
||||
top_p: float = 1.0,
|
||||
):
|
||||
self.eval()
|
||||
|
||||
def should_stop(current_token, num_sampled_tokens):
|
||||
if eos_token_id is not None and (current_token == eos_token_id).all():
|
||||
return True
|
||||
if num_sampled_tokens >= max_new_tokens:
|
||||
return True
|
||||
return False
|
||||
|
||||
top_k = top_k if do_sample else 1
|
||||
attention_mask = (1.0 - attention_mask.astype(mx.float32)) * -1e9
|
||||
memory = self.encode(input_ids, mask=attention_mask)
|
||||
|
||||
repeated_memory = mx.repeat(memory, num_return_sequences, axis=0)
|
||||
repeated_attention_mask = mx.repeat(
|
||||
attention_mask, num_return_sequences, axis=0
|
||||
)
|
||||
decoder_start_id = pad_token_id
|
||||
decoder_inputs = mx.array([decoder_start_id] * len(repeated_attention_mask))[
|
||||
:, None
|
||||
]
|
||||
|
||||
cache = None
|
||||
prediction = [decoder_inputs]
|
||||
num_sampled_tokens = 0
|
||||
while not should_stop(prediction[-1], num_sampled_tokens):
|
||||
logits, cache = self.decode(
|
||||
prediction[-1],
|
||||
repeated_memory,
|
||||
memory_mask=repeated_attention_mask,
|
||||
cache=cache,
|
||||
)
|
||||
if (
|
||||
min_new_tokens is not None
|
||||
and eos_token_id is not None
|
||||
and num_sampled_tokens < min_new_tokens
|
||||
):
|
||||
logits[..., eos_token_id] = -float("inf")
|
||||
|
||||
y = sample(
|
||||
logits[:, -1, :], top_k=top_k, top_p=top_p, temperature=temperature
|
||||
)
|
||||
num_sampled_tokens += 1
|
||||
prediction.append(y[:, None])
|
||||
|
||||
return mx.concatenate(prediction, axis=-1)
|
||||
77
src/chronos_mlx/translate.py
Normal file
77
src/chronos_mlx/translate.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from ml-explore/mlx-examples:
|
||||
# https://github.com/ml-explore/mlx-examples/blob/b8a348c1b8df4433cfacb9adbeb89b8aa3979ab2/t5/convert.py
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".block.", ".layers."),
|
||||
(".k.", ".key_proj."),
|
||||
(".o.", ".out_proj."),
|
||||
(".q.", ".query_proj."),
|
||||
(".v.", ".value_proj."),
|
||||
("shared.", "wte."),
|
||||
("lm_head.", "lm_head.linear."),
|
||||
(".layer.0.layer_norm.", ".ln1."),
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||
"relative_attention_bias.embeddings.",
|
||||
),
|
||||
]
|
||||
|
||||
ENCODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".attention."),
|
||||
(".layer.1.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
DECODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".self_attention."),
|
||||
(".layer.1.EncDecAttention.", ".cross_attention."),
|
||||
(".layer.2.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
for old, new in SHARED_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
if key.startswith("encoder."):
|
||||
for old, new in ENCODER_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
elif key.startswith("decoder."):
|
||||
for old, new in DECODER_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
return key
|
||||
|
||||
|
||||
def translate_weights(model_name_or_path: Union[str, Path], dtype: mx.Dtype):
|
||||
"""Translate a HuggingFace transformers T5 model to MLX.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name
|
||||
HuggingFace model name or local path.
|
||||
dtype
|
||||
mlx dtype for the resulting mlx weights.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A state dictionary with weights as mlx arrays.
|
||||
"""
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
model_name_or_path, torch_dtype=torch.float32
|
||||
)
|
||||
weights = {
|
||||
replace_key(k): mx.array(v.numpy(), dtype=dtype)
|
||||
for k, v in model.state_dict().items()
|
||||
}
|
||||
return weights
|
||||
|
|
@ -1,208 +0,0 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from chronos import ChronosConfig, ChronosPipeline
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
|
||||
@pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
|
||||
@pytest.mark.parametrize("use_eos_token", [False, True])
|
||||
def test_tokenizer_fixed_data(
|
||||
n_numerical_tokens: int, n_special_tokens: int, use_eos_token: bool
|
||||
):
|
||||
n_tokens = n_numerical_tokens + n_special_tokens
|
||||
context_length = 3
|
||||
|
||||
config = ChronosConfig(
|
||||
tokenizer_class="MeanScaleUniformBins",
|
||||
tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
|
||||
n_tokens=n_tokens,
|
||||
n_special_tokens=n_special_tokens,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
use_eos_token=use_eos_token,
|
||||
model_type="seq2seq",
|
||||
context_length=512,
|
||||
prediction_length=64,
|
||||
num_samples=20,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
)
|
||||
|
||||
tokenizer = config.create_tokenizer()
|
||||
|
||||
context = torch.tensor(
|
||||
[
|
||||
[-3.7, 3.7],
|
||||
[-42.0, 42.0],
|
||||
]
|
||||
)
|
||||
batch_size, _ = context.shape
|
||||
|
||||
token_ids, attention_mask, scale = tokenizer.input_transform(context)
|
||||
|
||||
assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token)
|
||||
assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size))
|
||||
assert all(token_ids[:, 1] == torch.tensor([n_special_tokens]).repeat(batch_size))
|
||||
assert all(token_ids[:, 2] == torch.tensor([n_tokens - 1]).repeat(batch_size))
|
||||
|
||||
if use_eos_token:
|
||||
assert all(token_ids[:, 3] == torch.tensor([1]).repeat(batch_size))
|
||||
|
||||
samples = tokenizer.output_transform(
|
||||
torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1),
|
||||
tokenizer_state=scale,
|
||||
)
|
||||
|
||||
assert (samples[:, 0, [0, -1]] == context).all()
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize("use_eos_token", [False, True])
|
||||
def test_tokenizer_random_data(use_eos_token: bool):
|
||||
context_length = 8
|
||||
n_tokens = 256
|
||||
n_special_tokens = 2
|
||||
|
||||
config = ChronosConfig(
|
||||
tokenizer_class="MeanScaleUniformBins",
|
||||
tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
|
||||
n_tokens=n_tokens,
|
||||
n_special_tokens=n_special_tokens,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
use_eos_token=use_eos_token,
|
||||
model_type="seq2seq",
|
||||
context_length=context_length,
|
||||
prediction_length=64,
|
||||
num_samples=20,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
)
|
||||
|
||||
tokenizer = config.create_tokenizer()
|
||||
|
||||
context = torch.tensor(
|
||||
[
|
||||
[torch.nan, torch.nan, 1.0, 1.1, torch.nan, 2.0],
|
||||
[3.0, torch.nan, 3.9, 4.0, 4.1, 4.9],
|
||||
]
|
||||
)
|
||||
|
||||
token_ids, attention_mask, scale = tokenizer.input_transform(context)
|
||||
|
||||
assert token_ids.shape == (
|
||||
*context.shape[:-1],
|
||||
context_length + 1 * use_eos_token,
|
||||
)
|
||||
assert attention_mask.shape == (
|
||||
*context.shape[:-1],
|
||||
context_length + 1 * use_eos_token,
|
||||
)
|
||||
assert scale.shape == context.shape[:1]
|
||||
|
||||
sample_ids = torch.randint(low=n_special_tokens, high=n_tokens, size=(2, 10, 4))
|
||||
sample_ids[0, 0, 0] = n_special_tokens
|
||||
sample_ids[-1, -1, -1] = n_tokens - 1
|
||||
|
||||
samples = tokenizer.output_transform(sample_ids, scale)
|
||||
|
||||
assert samples.shape == (2, 10, 4)
|
||||
|
||||
|
||||
def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
|
||||
assert isinstance(samples, torch.Tensor)
|
||||
assert samples.shape == shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
def test_pipeline_predict(torch_dtype: str):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
|
||||
validate_tensor(samples, (4, 12, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(context, num_samples=7, prediction_length=65)
|
||||
|
||||
samples = pipeline.predict(
|
||||
context, num_samples=7, prediction_length=65, limit_prediction_length=False
|
||||
)
|
||||
validate_tensor(samples, (4, 7, 65))
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
|
||||
validate_tensor(samples, (4, 12, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
|
||||
|
||||
samples = pipeline.predict(
|
||||
list(context),
|
||||
num_samples=7,
|
||||
prediction_length=65,
|
||||
limit_prediction_length=False,
|
||||
)
|
||||
validate_tensor(samples, (4, 7, 65))
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
|
||||
validate_tensor(samples, (1, 12, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
|
||||
|
||||
samples = pipeline.predict(
|
||||
context[0, ...],
|
||||
num_samples=7,
|
||||
prediction_length=65,
|
||||
limit_prediction_length=False,
|
||||
)
|
||||
validate_tensor(samples, (1, 7, 65))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
def test_pipeline_embed(torch_dtype: str):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
d_model = pipeline.model.model.config.d_model
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
embedding, scale = pipeline.embed(context)
|
||||
validate_tensor(embedding, (4, expected_embed_length, d_model))
|
||||
validate_tensor(scale, (4,))
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
embedding, scale = pipeline.embed(list(context))
|
||||
validate_tensor(embedding, (4, expected_embed_length, d_model))
|
||||
validate_tensor(scale, (4,))
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
embedding, scale = pipeline.embed(context[0, ...])
|
||||
validate_tensor(embedding, (1, expected_embed_length, d_model))
|
||||
validate_tensor(scale, (1,))
|
||||
154
test/test_chronos_mlx.py
Normal file
154
test/test_chronos_mlx.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from chronos_mlx.t5 import apply_top_p
|
||||
from chronos_mlx import ChronosPipeline
|
||||
|
||||
|
||||
def validate_array(samples: np.ndarray, shape: Tuple[int, ...]) -> None:
|
||||
assert isinstance(samples, np.ndarray)
|
||||
assert samples.shape == shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
|
||||
def test_pipeline_predict(dtype: str):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
dtype=dtype,
|
||||
)
|
||||
context = 10 * np.random.rand(4, 16) + 10
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
|
||||
validate_array(samples, (4, 12, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(context, num_samples=7, prediction_length=65)
|
||||
|
||||
samples = pipeline.predict(
|
||||
context, num_samples=7, prediction_length=65, limit_prediction_length=False
|
||||
)
|
||||
validate_array(samples, (4, 7, 65))
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
|
||||
validate_array(samples, (4, 12, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
|
||||
|
||||
samples = pipeline.predict(
|
||||
list(context),
|
||||
num_samples=7,
|
||||
prediction_length=65,
|
||||
limit_prediction_length=False,
|
||||
)
|
||||
validate_array(samples, (4, 7, 65))
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
|
||||
validate_array(samples, (1, 12, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
|
||||
|
||||
samples = pipeline.predict(
|
||||
context[0, ...],
|
||||
num_samples=7,
|
||||
prediction_length=65,
|
||||
limit_prediction_length=False,
|
||||
)
|
||||
validate_array(samples, (1, 7, 65))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
|
||||
def test_pipeline_embed(dtype: str):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
dtype=dtype,
|
||||
)
|
||||
d_model = pipeline.model.model.model_dim
|
||||
context = 10 * np.random.rand(4, 16) + 10
|
||||
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
embedding, scale = pipeline.embed(context)
|
||||
validate_array(embedding, (4, expected_embed_length, d_model))
|
||||
validate_array(scale, (4,))
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
embedding, scale = pipeline.embed(list(context))
|
||||
validate_array(embedding, (4, expected_embed_length, d_model))
|
||||
validate_array(scale, (4,))
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
embedding, scale = pipeline.embed(context[0, ...])
|
||||
validate_array(embedding, (1, expected_embed_length, d_model))
|
||||
validate_array(scale, (1,))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"top_p,expected_non_zero_probs",
|
||||
[
|
||||
(
|
||||
0.1,
|
||||
mx.array(
|
||||
[
|
||||
[False, True, False, False],
|
||||
[False, True, False, False],
|
||||
[True, False, False, False],
|
||||
[True, False, False, False],
|
||||
[False, False, False, True],
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
0.5,
|
||||
mx.array(
|
||||
[
|
||||
[False, True, False, False],
|
||||
[False, True, False, False],
|
||||
[True, False, False, False],
|
||||
[True, False, False, False],
|
||||
[False, False, True, True],
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
0.95,
|
||||
mx.array(
|
||||
[
|
||||
[False, True, True, True],
|
||||
[False, True, False, True],
|
||||
[True, False, False, False],
|
||||
[True, True, False, False],
|
||||
[False, True, True, True],
|
||||
]
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_apply_top_p(top_p: float, expected_non_zero_probs: mx.array):
|
||||
probs = mx.array(
|
||||
[
|
||||
[0.1, 0.4, 0.3, 0.2],
|
||||
[0.01, 0.39, 0.25, 0.35],
|
||||
[0.9, 0.01, 0.01, 0.08],
|
||||
[0.7, 0.2, 0.05, 0.05],
|
||||
[0.25, 0.25, 0.25, 0.25],
|
||||
],
|
||||
)
|
||||
top_p_probs = mx.softmax(apply_top_p(probs.log(), top_p=top_p), axis=-1)
|
||||
assert mx.all(mx.not_equal(top_p_probs, 0.0) == expected_non_zero_probs)
|
||||
Loading…
Reference in a new issue