Speed up predict_df (#437)

*Issue #, if available:*

*Description of changes:*
- Replace the per-series for-loop with vectorized numpy operations and a
single pd.DataFrame construction


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Oleksandr Shchur 2025-12-17 13:44:43 +01:00 committed by GitHub
parent efb86e02c2
commit c50fed93df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 38 additions and 32 deletions

View file

@ -218,22 +218,26 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
quantiles_np = quantiles.numpy() # [n_series, horizon, num_quantiles]
mean_np = mean.numpy() # [n_series, horizon]
results_dfs = []
for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
q_pred = quantiles_np[i] # (horizon, num_quantiles)
point_pred = mean_np[i] # (horizon)
series_ids = list(prediction_timestamps.keys())
future_ts = list(prediction_timestamps.values())
series_forecast_data = {id_column: series_id, timestamp_column: future_ts, "target_name": target}
series_forecast_data["predictions"] = point_pred
for q_idx, q_level in enumerate(quantile_levels):
series_forecast_data[str(q_level)] = q_pred[:, q_idx]
data = {
id_column: np.repeat(series_ids, prediction_length),
timestamp_column: np.concatenate(future_ts),
"target_name": target,
"predictions": mean_np.ravel(),
}
results_dfs.append(pd.DataFrame(series_forecast_data))
quantiles_flat = quantiles_np.reshape(-1, len(quantile_levels))
for q_idx, q_level in enumerate(quantile_levels):
data[str(q_level)] = quantiles_flat[:, q_idx]
predictions_df = pd.concat(results_dfs, ignore_index=True)
predictions_df.set_index(id_column, inplace=True)
predictions_df = predictions_df.loc[original_order]
predictions_df.reset_index(inplace=True)
predictions_df = pd.DataFrame(data)
# If validate_inputs=False, the df is used as-is without sorting by item_id, no reordering required
if validate_inputs:
predictions_df.set_index(id_column, inplace=True)
predictions_df = predictions_df.loc[original_order]
predictions_df.reset_index(inplace=True)
return predictions_df

View file

@ -9,7 +9,7 @@ import time
import warnings
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Callable
from typing import TYPE_CHECKING, Callable, Literal, Mapping, Sequence
import numpy as np
import torch
@ -914,27 +914,29 @@ class Chronos2Pipeline(BaseChronosPipeline):
quantiles_np = torch.stack(quantiles).numpy() # [n_tasks, n_variates, horizon, num_quantiles]
mean_np = torch.stack(mean).numpy() # [n_tasks, n_variates, horizon]
results_dfs = []
for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
q_pred = quantiles_np[i] # (n_variates, prediction_length, len(quantile_levels))
point_pred = mean_np[i] # (n_variates, prediction_length)
n_tasks = len(prediction_timestamps)
n_variates = len(target)
for target_idx, target_col in enumerate(target):
series_forecast_data: dict[str | tuple[str, str], Any] = {
id_column: series_id,
timestamp_column: future_ts,
"target_name": target_col,
}
series_forecast_data["predictions"] = point_pred[target_idx]
for q_idx, q_level in enumerate(quantile_levels):
series_forecast_data[str(q_level)] = q_pred[target_idx, :, q_idx]
series_ids = list(prediction_timestamps.keys())
future_ts = list(prediction_timestamps.values())
results_dfs.append(pd.DataFrame(series_forecast_data))
data = {
id_column: np.repeat(series_ids, n_variates * prediction_length),
timestamp_column: np.concatenate([np.tile(ts, n_variates) for ts in future_ts]),
"target_name": np.tile(np.repeat(target, prediction_length), n_tasks),
"predictions": mean_np.ravel(),
}
predictions_df = pd.concat(results_dfs, ignore_index=True)
predictions_df.set_index(id_column, inplace=True)
predictions_df = predictions_df.loc[original_order]
predictions_df.reset_index(inplace=True)
quantiles_flat = quantiles_np.reshape(-1, len(quantile_levels))
for q_idx, q_level in enumerate(quantile_levels):
data[str(q_level)] = quantiles_flat[:, q_idx]
predictions_df = pd.DataFrame(data)
# If validate_inputs=False, the df is used as-is without sorting by item_id, no reordering required
if validate_inputs:
predictions_df.set_index(id_column, inplace=True)
predictions_df = predictions_df.loc[original_order]
predictions_df.reset_index(inplace=True)
return predictions_df