mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
Fix README example to use predict_quantiles (#220)
*Issue #, if available:* *Description of changes:* `predict` returns different things based on model type. This fixes the example to use `predict_quantiles` which will give correct quantiles. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
parent
4c43cfbdac
commit
e3bbda7207
1 changed files with 8 additions and 9 deletions
17
README.md
17
README.md
|
|
@ -114,16 +114,16 @@ df = pd.read_csv(
|
|||
|
||||
# context must be either a 1D tensor, a list of 1D tensors,
|
||||
# or a left-padded 2D tensor with batch as the first dimension
|
||||
# The original Chronos models generate forecast samples, so forecast has shape
|
||||
# [num_series, num_samples, prediction_length].
|
||||
# Chronos-Bolt models generate quantile forecasts, so forecast has shape
|
||||
# [num_series, num_quantiles, prediction_length].
|
||||
forecast = pipeline.predict(
|
||||
context=torch.tensor(df["#Passengers"]), prediction_length=12
|
||||
# quantiles is an fp32 tensor with shape [batch_size, prediction_length, num_quantile_levels]
|
||||
# mean is an fp32 tensor with shape [batch_size, prediction_length]
|
||||
quantiles, mean = pipeline.predict_quantiles(
|
||||
context=torch.tensor(df["#Passengers"]),
|
||||
prediction_length=12,
|
||||
quantile_levels=[0.1, 0.5, 0.9],
|
||||
)
|
||||
```
|
||||
|
||||
More options for `pipeline.predict` can be found with:
|
||||
For the original Chronos models, `pipeline.predict` can be used to draw forecast samples. More options for `predict_kwargs` in `pipeline.predict_quantiles` can be found with:
|
||||
|
||||
```python
|
||||
from chronos import ChronosPipeline, ChronosBoltPipeline
|
||||
|
|
@ -136,10 +136,9 @@ We can now visualize the forecast:
|
|||
|
||||
```python
|
||||
import matplotlib.pyplot as plt # requires: pip install matplotlib
|
||||
import numpy as np
|
||||
|
||||
forecast_index = range(len(df), len(df) + 12)
|
||||
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
|
||||
low, median, high = quantiles[0, :, 0], quantiles[0, :, 1], quantiles[0, :, 2]
|
||||
|
||||
plt.figure(figsize=(8, 4))
|
||||
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
|
||||
|
|
|
|||
Loading…
Reference in a new issue