Fix README example to use predict_quantiles (#220)

*Issue #, if available:*

*Description of changes:* `predict` returns different things based on
model type. This fixes the example to use `predict_quantiles` which will
give correct quantiles.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
Abdul Fatir 2024-11-29 16:54:48 +01:00 committed by GitHub
parent 4c43cfbdac
commit e3bbda7207
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -114,16 +114,16 @@ df = pd.read_csv(
# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
# The original Chronos models generate forecast samples, so forecast has shape
# [num_series, num_samples, prediction_length].
# Chronos-Bolt models generate quantile forecasts, so forecast has shape
# [num_series, num_quantiles, prediction_length].
forecast = pipeline.predict(
context=torch.tensor(df["#Passengers"]), prediction_length=12
# quantiles is an fp32 tensor with shape [batch_size, prediction_length, num_quantile_levels]
# mean is an fp32 tensor with shape [batch_size, prediction_length]
quantiles, mean = pipeline.predict_quantiles(
context=torch.tensor(df["#Passengers"]),
prediction_length=12,
quantile_levels=[0.1, 0.5, 0.9],
)
```
More options for `pipeline.predict` can be found with:
For the original Chronos models, `pipeline.predict` can be used to draw forecast samples. More options for `predict_kwargs` in `pipeline.predict_quantiles` can be found with:
```python
from chronos import ChronosPipeline, ChronosBoltPipeline
@ -136,10 +136,9 @@ We can now visualize the forecast:
```python
import matplotlib.pyplot as plt # requires: pip install matplotlib
import numpy as np
forecast_index = range(len(df), len(df) + 12)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
low, median, high = quantiles[0, :, 0], quantiles[0, :, 1], quantiles[0, :, 2]
plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")