chronos-forecasting/README.md
Abdul Fatir 159ea36f7f
Add MLX inference support (#41)
*Issue #, if available:* #28

*Description of changes:* This PR adds MLX inference support.

## Summary of changes
- Update `pyproject.toml` with `mlx` dependencies.
- Create the `chronos_mlx` package, which hosts all MLX inference code.
- All classes from `main:src/chronos/chronos.py` are copied into
`mlx:src/chronos_mlx/chronos.py` and modified to use NumPy and MLX
arrays instead. NumPy arrays are used for input and output because MLX
does not support some operations required by the input and output
transforms.
- The MLX implementation of T5 is in `src/chronos_mlx/t5.py`. It has been
adapted from
[ml-explore/mlx-examples](b8a348c1b8/t5/t5.py)
with the following main modifications:
  - Added support for attention masks.
  - Added support for `top_k` and `top_p` sampling.
- `src/chronos_mlx/translate.py` translates weights from a PyTorch
Hugging Face model to MLX.
- Add `THIRD-PARTY-LICENSES.txt` for third-party code from
`mlx-examples`.
- Add tests and CI for the `mlx` version.
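
The top-k and top-p (nucleus) sampling mentioned above restricts which tokens may be drawn at each decoding step. Below is a minimal NumPy sketch of that filtering. It is not the code in `src/chronos_mlx/t5.py`. The helper names `filter_logits` and `sample_token` are hypothetical, and both operate on a single 1D logits vector.

```python
import numpy as np


def filter_logits(logits: np.ndarray, top_k: int = 0, top_p: float = 1.0) -> np.ndarray:
    """Set logits outside the top-k / nucleus (top-p) set to -inf (hypothetical helper)."""
    logits = logits.astype(np.float64).copy()
    if top_k > 0:
        k = min(top_k, logits.shape[-1])
        kth_largest = np.sort(logits)[-k]  # smallest logit that survives top-k
        logits[logits < kth_largest] = -np.inf
    if top_p < 1.0:
        order = np.argsort(logits)[::-1]  # token indices by descending logit
        sorted_logits = logits[order]
        probs = np.exp(sorted_logits - sorted_logits[0])
        probs /= probs.sum()
        mass_before = np.cumsum(probs) - probs  # probability mass ranked above each token
        remove = mass_before >= top_p  # nucleus already complete before this token
        remove[0] = False  # always keep the most likely token
        logits[order[remove]] = -np.inf
    return logits


def sample_token(logits, rng, temperature=1.0, top_k=0, top_p=1.0) -> int:
    """Apply temperature, filter, renormalize, and draw one token id."""
    filtered = filter_logits(np.asarray(logits) / temperature, top_k, top_p)
    probs = np.exp(filtered - filtered.max())
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))
```

With `top_k=0` and `top_p=1.0`, this reduces to plain temperature sampling.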

## Sample inference code

```py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from chronos_mlx import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    dtype="bfloat16",
)

df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)

# context must be either a 1D NumPy array, a list of 1D NumPy arrays,
# or a left-padded 2D NumPy array with batch as the first dimension
context = df["#Passengers"].values
prediction_length = 12
forecast = pipeline.predict(
    context, prediction_length
)  # shape [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(
    forecast_index,
    low,
    high,
    color="tomato",
    alpha=0.3,
    label="80% prediction interval",
)
plt.legend()
plt.grid()
plt.show()

```
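
The last step of the sample code summarizes the sampled trajectories with `np.quantile` along the sample axis. The sketch below runs without downloading a model. It uses random placeholder values (not model output) in the documented shape `[num_series, num_samples, prediction_length]`.

```python
import numpy as np

# Placeholder samples standing in for `forecast`:
# shape [num_series, num_samples, prediction_length].
rng = np.random.default_rng(0)
num_series, num_samples, prediction_length = 1, 20, 12
forecast = rng.normal(loc=400.0, scale=25.0, size=(num_series, num_samples, prediction_length))

# Quantiles over the sample axis of the first series yield one value per future step.
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)

print(low.shape, median.shape, high.shape)  # (12,) (12,) (12,)
```

The 0.1 and 0.9 quantiles bound the band labelled "80% prediction interval" in the plot.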

## Benchmark


![benchmark](https://github.com/amazon-science/chronos-forecasting/assets/4028948/ee5d1b17-d33e-473c-aa7a-55dbe1059b9c)


```py
import timeit

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from chronos import ChronosPipeline as ChronosPipelineTorch
from chronos_mlx import ChronosPipeline as ChronosPipelineMLX


def benchmark_torch_model(
    pipeline: ChronosPipelineTorch,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [
            torch.tensor(item["target"])
            for item in test_data_input[idx : idx + batch_size]
        ]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = torch.cat(forecasts)
    end_time = timeit.default_timer()

    print(f"Inference time: {end_time-start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst.numpy(), start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df


def benchmark_mlx_model(
    pipeline: ChronosPipelineMLX,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [item["target"] for item in test_data_input[idx : idx + batch_size]]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = np.concatenate(forecasts)
    end_time = timeit.default_timer()

    print(f"Inference time: {end_time-start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst, start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df


def main(
    version: str = "cpu",  # cpu, mps, mlx
    dtype: str = "bfloat16",
    gluonts_dataset: str = "australian_electricity_demand",
    model_name: str = "amazon/chronos-t5-small",
    batch_size: int = 4,
):
    if version == "cpu" or version == "mps":
        pipeline = ChronosPipelineTorch.from_pretrained(
            model_name,
            device_map=version,
            torch_dtype=getattr(torch, dtype),
        )
        benchmark_fn = benchmark_torch_model
    else:
        pipeline = ChronosPipelineMLX.from_pretrained(model_name, dtype=dtype)
        benchmark_fn = benchmark_mlx_model

    result_df = benchmark_fn(
        pipeline, gluonts_dataset=gluonts_dataset, batch_size=batch_size
    )
    result_df["model"] = model_name
    return result_df


if __name__ == "__main__":
    gluonts_dataset: str = "m4_hourly"
    model_name: str = "amazon/chronos-t5-mini"
    batch_size: int = 8
    dfs = []
    for version in ["cpu", "mps", "mlx"]:
        for dtype in ["float32"]:
            try:
                df = main(
                    version=version,
                    dtype=dtype,
                    model_name=model_name,
                    gluonts_dataset=gluonts_dataset,
                    batch_size=batch_size,
                )
                df["version"] = version
                df["dtype"] = dtype
                dfs.append(df)
            except TypeError:
                pass

    result_df = pd.concat(dfs).reset_index(drop=True)
    result_df.to_csv("benchmark.csv", index=False)

    result_df["version"] = result_df["version"].map(
        {"cpu": "Torch (CPU)", "mps": "Torch (MPS)", "mlx": "MLX"}
    )
    fig = plt.figure(figsize=(8, 5))
    g = sns.barplot(
        data=result_df,
        x="dtype",
        y="inference_time",
        hue="version",
        alpha=0.6,
    )
    plt.ylabel("Inference time (s) on M1 Pro")
    plt.title(f"{model_name} inference times on {gluonts_dataset} dataset")
    plt.savefig("benchmark.png", dpi=200)

```
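
The benchmark scores accuracy with `MeanWeightedSumQuantileLoss` at levels `np.arange(0.1, 1, 0.1)`. For intuition, here is a NumPy sketch of the usual weighted quantile loss definition for a single series. It is an illustration, not the gluonts implementation, and may not reproduce its numbers exactly: gluonts aggregates over the evaluated series and may derive quantiles from samples differently. `mean_weighted_quantile_loss` is a hypothetical name.

```python
import numpy as np


def mean_weighted_quantile_loss(samples: np.ndarray, target: np.ndarray, levels) -> float:
    """Weighted quantile loss for one series, averaged over quantile levels.

    samples: [num_samples, prediction_length] forecast samples
    target:  [prediction_length] observed values
    """
    losses = []
    for q in levels:
        pred_q = np.quantile(samples, q, axis=0)
        # pinball loss scaled by 2, so q = 0.5 gives the absolute error
        pinball = 2.0 * np.abs((target - pred_q) * ((target <= pred_q) - q))
        # normalize by the total absolute target ("weighted")
        losses.append(pinball.sum() / np.abs(target).sum())
    return float(np.mean(losses))
```

For example, `mean_weighted_quantile_loss(fcst, label, np.arange(0.1, 1, 0.1))` scores one series at the benchmark's quantile levels.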

## TODOs:
- [x] Implement `top_p` sampling.
- [x] Add tests.
- [x] Add CI.

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>
2024-04-08 15:03:44 +02:00


# [🧪 MLX Version] Chronos: Learning the Language of Time Series

> **Important:** This is the experimental MLX version of Chronos for Apple Silicon Macs. Please use the main branch for the stable PyTorch version.

Chronos is a family of pretrained time series forecasting models based on language model architectures. A time series is transformed into a sequence of tokens via scaling and quantization, and a language model is trained on these tokens using the cross-entropy loss. Once trained, probabilistic forecasts are obtained by sampling multiple future trajectories given the historical context. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes.
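
Below is a minimal NumPy sketch of the scaling and quantization step. It assumes mean scaling (dividing by the mean absolute value of the context) followed by uniform binning. The vocabulary split (4096 ids, 2 reserved for special tokens) and the bin range [-15, 15] are illustrative assumptions, not values read from a model config. Missing values are not handled here.

```python
import numpy as np

# Illustrative settings (assumptions): 4096 token ids in total, 2 reserved for
# special tokens (e.g. padding / end-of-sequence), uniform bins on [-15, 15].
N_TOKENS, N_SPECIAL, LOW, HIGH = 4096, 2, -15.0, 15.0
centers = np.linspace(LOW, HIGH, N_TOKENS - N_SPECIAL)
boundaries = (centers[1:] + centers[:-1]) / 2  # midpoints between adjacent centers


def tokenize(context: np.ndarray):
    """Mean-scale a 1D series and map each value to a token id."""
    scale = np.nanmean(np.abs(context))
    scale = scale if scale > 0 else 1.0  # avoid dividing by zero for all-zero input
    token_ids = np.digitize(context / scale, boundaries) + N_SPECIAL
    return token_ids, scale


def detokenize(token_ids: np.ndarray, scale: float) -> np.ndarray:
    """Map token ids back to bin centers and undo the scaling."""
    return centers[token_ids - N_SPECIAL] * scale
```

Scaled values outside the bin range fall into the outermost bins, so round-tripping is only approximate.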

For details on Chronos models, training data and procedures, and experimental results, please refer to the paper *Chronos: Learning the Language of Time Series*.


Fig. 1: High-level depiction of Chronos. (Left) The input time series is scaled and quantized to obtain a sequence of tokens. (Center) The tokens are fed into a language model which may either be an encoder-decoder or a decoder-only model. The model is trained using the cross-entropy loss. (Right) During inference, we autoregressively sample tokens from the model and map them back to numerical values. Multiple trajectories are sampled to obtain a predictive distribution.
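
The sampling loop in the right panel can be illustrated with a toy model. In this sketch, `toy_next_token_logits` is a hypothetical stand-in for the language model, and the nine bin centers are made up. Each trajectory feeds its own sampled tokens back in, and quantiles across trajectories summarize the predictive distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
centers = np.linspace(-2.0, 2.0, 9)  # toy bin centers in scaled units (assumption)


def toy_next_token_logits(history: list) -> np.ndarray:
    """Hypothetical stand-in for the model: favours tokens near the last one."""
    distance = np.abs(np.arange(len(centers)) - history[-1])
    return -distance.astype(float)


def sample_trajectories(context_ids, prediction_length: int, num_samples: int) -> np.ndarray:
    paths = np.empty((num_samples, prediction_length), dtype=int)
    for s in range(num_samples):
        history = list(context_ids)
        for t in range(prediction_length):
            logits = toy_next_token_logits(history)
            probs = np.exp(logits - logits.max())
            probs /= probs.sum()
            token = int(rng.choice(len(probs), p=probs))
            history.append(token)  # autoregression: condition on the sampled token
            paths[s, t] = token
    return centers[paths]  # map token ids back to (scaled) values


samples = sample_trajectories([4, 5, 4], prediction_length=6, num_samples=50)
low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
```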


## Architecture

The models in this repository are based on the T5 architecture. The only difference is in the vocabulary size: Chronos-T5 models use 4096 different tokens, compared to 32128 in the original T5 models, resulting in fewer parameters.

| Model | Parameters | Based on |
| --- | --- | --- |
| chronos-t5-tiny | 8M | t5-efficient-tiny |
| chronos-t5-mini | 20M | t5-efficient-mini |
| chronos-t5-small | 46M | t5-efficient-small |
| chronos-t5-base | 200M | t5-efficient-base |
| chronos-t5-large | 710M | t5-efficient-large |
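
The savings from the smaller vocabulary can be estimated with simple arithmetic. The sketch assumes a model width of `d_model = 512` for the small size (an assumption matching the original `t5-small`). Under that assumption, each `[vocab_size, d_model]` embedding matrix shrinks by:

```python
# Assumption: d_model = 512 for the small size; one matrix of shape [vocab_size, d_model].
d_model = 512
t5_vocab, chronos_vocab = 32128, 4096

saved = (t5_vocab - chronos_vocab) * d_model
print(f"{saved:,} fewer embedding parameters")  # 14,352,384 fewer embedding parameters
```

A model with a separate output projection of the same shape would save this amount a second time.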

## Usage

To perform inference with Chronos models on Apple Silicon devices, install this package by running:

```sh
pip install git+https://github.com/amazon-science/chronos-forecasting.git@mlx
```
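
Since this branch targets Apple Silicon, a quick check before installing can help. The sketch below uses the standard library `platform` module. `is_apple_silicon` is a hypothetical helper. An x86_64 Python running under Rosetta reports `x86_64`, so the check returns `False` there.

```python
import platform
from typing import Optional


def is_apple_silicon(system: Optional[str] = None, machine: Optional[str] = None) -> bool:
    """Return True when running natively on macOS with an arm64 CPU."""
    system = system or platform.system()
    machine = machine or platform.machine()
    return system == "Darwin" and machine == "arm64"


if not is_apple_silicon():
    print("MLX targets Apple Silicon; consider the PyTorch version on the main branch.")
```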

### Forecasting

A minimal example showing how to perform forecasting using Chronos models:

```python
# for plotting, run: pip install pandas matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from chronos_mlx import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    dtype="bfloat16",
)

df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")

# context must be either a 1D NumPy array, a list of 1D NumPy arrays,
# or a left-padded 2D NumPy array with batch as the first dimension
context = df["#Passengers"].values
prediction_length = 12
forecast = pipeline.predict(
    context,
    prediction_length,
    num_samples=20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
)  # forecast shape: [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
plt.legend()
plt.grid()
plt.show()
```
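
The example's comment also allows a left-padded 2D array as context. The sketch below builds one from series of different lengths. It assumes, as in the PyTorch version, that `NaN` marks padded positions. `left_pad_and_stack` is a hypothetical helper name.

```python
import numpy as np


def left_pad_and_stack(series_list) -> np.ndarray:
    """Stack 1D arrays of different lengths into [batch, max_len], left-padding with NaN."""
    max_len = max(len(s) for s in series_list)
    batch = np.full((len(series_list), max_len), np.nan, dtype=np.float32)
    for i, s in enumerate(series_list):
        batch[i, max_len - len(s):] = s  # right-align each series; NaNs fill the left
    return batch


context = left_pad_and_stack([np.array([1.0, 2.0, 3.0]), np.array([5.0, 6.0])])
```

The resulting `context` could then be passed to `pipeline.predict` in place of a single series. Passing a list of 1D arrays avoids manual padding altogether.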

### Extracting Encoder Embeddings

A minimal example showing how to extract encoder embeddings from Chronos models:

```python
import pandas as pd
from chronos_mlx import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    dtype="bfloat16",
)

df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")

# context must be either a 1D NumPy array, a list of 1D NumPy arrays,
# or a left-padded 2D NumPy array with batch as the first dimension
context = df["#Passengers"].values
embeddings, tokenizer_state = pipeline.embed(context)
```
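
`pipeline.embed` returns per-token embeddings along with the tokenizer state. A common follow-up is pooling them into one vector per series. The sketch assumes the embeddings are available as a NumPy array of shape `[batch, seq_len, d_model]`; the exact shape is an assumption here. A random placeholder replaces real model output. `mean_pool` is a hypothetical helper, and its optional mask lets padded positions be skipped.

```python
import numpy as np


def mean_pool(embeddings: np.ndarray, mask=None) -> np.ndarray:
    """Average per-token embeddings [batch, seq_len, d_model] into [batch, d_model].

    mask: optional [batch, seq_len] array, 1 for positions to keep, 0 for padding.
    """
    if mask is None:
        return embeddings.mean(axis=1)
    mask = mask.astype(embeddings.dtype)[..., None]  # [batch, seq_len, 1]
    return (embeddings * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1.0)


# Placeholder standing in for real embeddings; actual shapes depend on the model.
dummy = np.random.default_rng(0).normal(size=(2, 5, 8)).astype(np.float32)
pooled = mean_pool(dummy)
print(pooled.shape)  # (2, 8)
```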

## Citation

If you find Chronos models useful for your research, please consider citing the associated paper:

```bibtex
@article{ansari2024chronos,
  author  = {Ansari, Abdul Fatir and Stella, Lorenzo and Turkmen, Caner and Zhang, Xiyuan and Mercado, Pedro and Shen, Huibin and Shchur, Oleksandr and Rangapuram, Syama Sundar and Pineda Arango, Sebastian and Kapoor, Shubham and Zschiegner, Jasper and Maddix, Danielle C. and Mahoney, Michael W. and Torkkola, Kari and Gordon Wilson, Andrew and Bohlke-Schneider, Michael and Wang, Yuyang},
  title   = {Chronos: Learning the Language of Time Series},
  journal = {arXiv preprint arXiv:2403.07815},
  year    = {2024}
}
```

## Security

See CONTRIBUTING for more information.

## License

This project is licensed under the Apache-2.0 License.