mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Speed up predict_df (#437)
*Issue #, if available:* *Description of changes:* - Remove for-loop with numpy operations + single pd.DataFrame construction By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
efb86e02c2
commit
c50fed93df
2 changed files with 38 additions and 32 deletions
|
|
@ -218,22 +218,26 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
quantiles_np = quantiles.numpy() # [n_series, horizon, num_quantiles]
|
||||
mean_np = mean.numpy() # [n_series, horizon]
|
||||
|
||||
results_dfs = []
|
||||
for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
|
||||
q_pred = quantiles_np[i] # (horizon, num_quantiles)
|
||||
point_pred = mean_np[i] # (horizon)
|
||||
series_ids = list(prediction_timestamps.keys())
|
||||
future_ts = list(prediction_timestamps.values())
|
||||
|
||||
series_forecast_data = {id_column: series_id, timestamp_column: future_ts, "target_name": target}
|
||||
series_forecast_data["predictions"] = point_pred
|
||||
for q_idx, q_level in enumerate(quantile_levels):
|
||||
series_forecast_data[str(q_level)] = q_pred[:, q_idx]
|
||||
data = {
|
||||
id_column: np.repeat(series_ids, prediction_length),
|
||||
timestamp_column: np.concatenate(future_ts),
|
||||
"target_name": target,
|
||||
"predictions": mean_np.ravel(),
|
||||
}
|
||||
|
||||
results_dfs.append(pd.DataFrame(series_forecast_data))
|
||||
quantiles_flat = quantiles_np.reshape(-1, len(quantile_levels))
|
||||
for q_idx, q_level in enumerate(quantile_levels):
|
||||
data[str(q_level)] = quantiles_flat[:, q_idx]
|
||||
|
||||
predictions_df = pd.concat(results_dfs, ignore_index=True)
|
||||
predictions_df.set_index(id_column, inplace=True)
|
||||
predictions_df = predictions_df.loc[original_order]
|
||||
predictions_df.reset_index(inplace=True)
|
||||
predictions_df = pd.DataFrame(data)
|
||||
# If validate_inputs=False, the df is used as-is without sorting by item_id, no reordering required
|
||||
if validate_inputs:
|
||||
predictions_df.set_index(id_column, inplace=True)
|
||||
predictions_df = predictions_df.loc[original_order]
|
||||
predictions_df.reset_index(inplace=True)
|
||||
|
||||
return predictions_df
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import time
|
|||
import warnings
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Callable
|
||||
from typing import TYPE_CHECKING, Callable, Literal, Mapping, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -914,27 +914,29 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
quantiles_np = torch.stack(quantiles).numpy() # [n_tasks, n_variates, horizon, num_quantiles]
|
||||
mean_np = torch.stack(mean).numpy() # [n_tasks, n_variates, horizon]
|
||||
|
||||
results_dfs = []
|
||||
for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
|
||||
q_pred = quantiles_np[i] # (n_variates, prediction_length, len(quantile_levels))
|
||||
point_pred = mean_np[i] # (n_variates, prediction_length)
|
||||
n_tasks = len(prediction_timestamps)
|
||||
n_variates = len(target)
|
||||
|
||||
for target_idx, target_col in enumerate(target):
|
||||
series_forecast_data: dict[str | tuple[str, str], Any] = {
|
||||
id_column: series_id,
|
||||
timestamp_column: future_ts,
|
||||
"target_name": target_col,
|
||||
}
|
||||
series_forecast_data["predictions"] = point_pred[target_idx]
|
||||
for q_idx, q_level in enumerate(quantile_levels):
|
||||
series_forecast_data[str(q_level)] = q_pred[target_idx, :, q_idx]
|
||||
series_ids = list(prediction_timestamps.keys())
|
||||
future_ts = list(prediction_timestamps.values())
|
||||
|
||||
results_dfs.append(pd.DataFrame(series_forecast_data))
|
||||
data = {
|
||||
id_column: np.repeat(series_ids, n_variates * prediction_length),
|
||||
timestamp_column: np.concatenate([np.tile(ts, n_variates) for ts in future_ts]),
|
||||
"target_name": np.tile(np.repeat(target, prediction_length), n_tasks),
|
||||
"predictions": mean_np.ravel(),
|
||||
}
|
||||
|
||||
predictions_df = pd.concat(results_dfs, ignore_index=True)
|
||||
predictions_df.set_index(id_column, inplace=True)
|
||||
predictions_df = predictions_df.loc[original_order]
|
||||
predictions_df.reset_index(inplace=True)
|
||||
quantiles_flat = quantiles_np.reshape(-1, len(quantile_levels))
|
||||
for q_idx, q_level in enumerate(quantile_levels):
|
||||
data[str(q_level)] = quantiles_flat[:, q_idx]
|
||||
|
||||
predictions_df = pd.DataFrame(data)
|
||||
# If validate_inputs=False, the df is used as-is without sorting by item_id, no reordering required
|
||||
if validate_inputs:
|
||||
predictions_df.set_index(id_column, inplace=True)
|
||||
predictions_df = predictions_df.loc[original_order]
|
||||
predictions_df.reset_index(inplace=True)
|
||||
|
||||
return predictions_df
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue