From e3bbda7207468acda35185fe36e79b1c4c8fbfc0 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Fri, 29 Nov 2024 16:54:48 +0100 Subject: [PATCH] Fix README example to use `predict_quantiles` (#220) *Issue #, if available:* *Description of changes:* `predict` returns different things based on model type. This fixes the example to use `predict_quantiles` which will give correct quantiles. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Abdul Fatir Ansari --- README.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 9489608..e921489 100644 --- a/README.md +++ b/README.md @@ -114,16 +114,16 @@ df = pd.read_csv( # context must be either a 1D tensor, a list of 1D tensors, # or a left-padded 2D tensor with batch as the first dimension -# The original Chronos models generate forecast samples, so forecast has shape -# [num_series, num_samples, prediction_length]. -# Chronos-Bolt models generate quantile forecasts, so forecast has shape -# [num_series, num_quantiles, prediction_length]. -forecast = pipeline.predict( - context=torch.tensor(df["#Passengers"]), prediction_length=12 +# quantiles is an fp32 tensor with shape [batch_size, prediction_length, num_quantile_levels] +# mean is an fp32 tensor with shape [batch_size, prediction_length] +quantiles, mean = pipeline.predict_quantiles( + context=torch.tensor(df["#Passengers"]), + prediction_length=12, + quantile_levels=[0.1, 0.5, 0.9], ) ``` -More options for `pipeline.predict` can be found with: +For the original Chronos models, `pipeline.predict` can be used to draw forecast samples. More options for `predict_kwargs` in `pipeline.predict_quantiles` can be found with: ```python from chronos import ChronosPipeline, ChronosBoltPipeline @@ -136,10 +136,9 @@ We can now visualize the forecast: ```python import matplotlib.pyplot as plt # requires: pip install matplotlib -import numpy as np forecast_index = range(len(df), len(df) + 12) -low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0) +low, median, high = quantiles[0, :, 0], quantiles[0, :, 1], quantiles[0, :, 2] plt.figure(figsize=(8, 4)) plt.plot(df["#Passengers"], color="royalblue", label="historical data")