diff --git a/README.md b/README.md index b82a181..96e6da2 100644 --- a/README.md +++ b/README.md @@ -68,18 +68,12 @@ pip install git+https://github.com/amazon-science/chronos-forecasting.git > [!TIP] > The recommended way of using Chronos for production use cases is through [AutoGluon](https://auto.gluon.ai), which features ensembling with other statistical and machine learning models for time series forecasting as well as seamless deployments on AWS with SageMaker 🧠. Check out the AutoGluon Chronos [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html). -> [!NOTE] -> We have added 🧪experimental support for [MLX](https://github.com/ml-explore/mlx) inference. If you have an Apple Silicon Mac, check out the [`mlx`](https://github.com/amazon-science/chronos-forecasting/tree/mlx) branch of this repository for instructions on how to install and use the MLX version of Chronos. - ### Forecasting A minimal example showing how to perform forecasting using Chronos models: ```python -# for plotting, run: pip install pandas matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd +import pandas as pd # requires: pip install pandas import torch from chronos import ChronosPipeline @@ -93,19 +87,27 @@ df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnal # context must be either a 1D tensor, a list of 1D tensors, # or a left-padded 2D tensor with batch as the first dimension -context = torch.tensor(df["#Passengers"]) -prediction_length = 12 +# forecast shape: [num_series, num_samples, prediction_length] forecast = pipeline.predict( - context, - prediction_length, + context=torch.tensor(df["#Passengers"]), + prediction_length=12, num_samples=20, - temperature=1.0, - top_k=50, - top_p=1.0, -) # forecast shape: [num_series, num_samples, prediction_length] +) +``` -# visualize the forecast -forecast_index = range(len(df), len(df) + prediction_length) +More options for `pipeline.predict` can be found with: + +```python +print(ChronosPipeline.predict.__doc__) +``` + +We can now visualize the forecast: + +```python +import matplotlib.pyplot as plt # requires: pip install matplotlib +import numpy as np + +forecast_index = range(len(df), len(df) + 12) low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0) plt.figure(figsize=(8, 4)) @@ -142,7 +144,7 @@ embeddings, tokenizer_state = pipeline.embed(context) ### Pretraining and fine-tuning -Scripts for pretraining and fine-tuning Chronos models can be found in [this folder](./scripts/training). +Scripts for pretraining and fine-tuning Chronos models can be found in [this folder](./scripts/). ## 🔥 Coverage