mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
initial onnx scripts
This commit is contained in:
parent
93b90b7abe
commit
378c33837e
3 changed files with 1302 additions and 0 deletions
853
scripts/onnx/export_chronos2_to_onnx.py
Executable file
853
scripts/onnx/export_chronos2_to_onnx.py
Executable file
|
|
@ -0,0 +1,853 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
Export Chronos-2 models to ONNX format for use with transformers.js
|
||||
|
||||
This script:
|
||||
1. Loads a pretrained Chronos-2 model
|
||||
2. Exports it to ONNX format with proper dynamic axes
|
||||
3. Validates the ONNX export by comparing outputs with PyTorch
|
||||
4. Optionally quantizes the model for smaller size
|
||||
|
||||
Usage:
|
||||
python export_chronos2_to_onnx.py \
|
||||
--model_id amazon/chronos-2-small \
|
||||
--output_dir ./chronos2-small-onnx \
|
||||
--validate
|
||||
|
||||
Requirements:
|
||||
pip install torch onnx onnxruntime transformers chronos-forecasting
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from chronos import Chronos2Pipeline
|
||||
|
||||
# Register custom ONNX symbolic functions for operations that aren't properly mapped
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
|
||||
|
||||
def asinh_symbolic(g, input):
    """Emit an ONNX ``Asinh`` node for ``aten::asinh``.

    Args:
        g: Graph context supplied by the exporter; only its ``op`` method is used.
        input: Graph value the operator is applied to.

    Returns:
        The value returned by ``g.op("Asinh", input)``.
    """
    asinh_node = g.op("Asinh", input)
    return asinh_node
|
||||
|
||||
|
||||
def sinh_symbolic(g, input):
    """Emit an ONNX ``Sinh`` node for ``aten::sinh``.

    Args:
        g: Graph context supplied by the exporter; only its ``op`` method is used.
        input: Graph value the operator is applied to.

    Returns:
        The value returned by ``g.op("Sinh", input)``.
    """
    sinh_node = g.op("Sinh", input)
    return sinh_node
|
||||
|
||||
|
||||
# Hook both symbolic functions into the exporter, starting at opset 9.
_SYMBOLIC_MIN_OPSET = 9
for _aten_name, _symbolic_fn in (
    ("aten::asinh", asinh_symbolic),
    ("aten::sinh", sinh_symbolic),
):
    register_custom_op_symbolic(_aten_name, _symbolic_fn, _SYMBOLIC_MIN_OPSET)

# Script-wide logging: INFO level, timestamped records, one module-level logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Chronos2ONNXWrapper(nn.Module):
|
||||
"""
|
||||
Wrapper around Chronos2Model to handle ONNX export.
|
||||
|
||||
This wrapper simplifies the input/output interface for ONNX export
|
||||
by flattening the input dictionary structure.
|
||||
"""
|
||||
|
||||
def __init__(self, chronos2_model):
|
||||
super().__init__()
|
||||
self.model = chronos2_model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
group_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
future_covariates: torch.Tensor | None = None,
|
||||
num_output_patches: int = 1,
|
||||
):
|
||||
"""
|
||||
Forward pass compatible with ONNX export.
|
||||
|
||||
Args:
|
||||
context: Historical context tensor of shape (batch_size, context_length)
|
||||
group_ids: Group IDs tensor of shape (batch_size,)
|
||||
attention_mask: Optional attention mask of shape (batch_size, context_length)
|
||||
future_covariates: Optional future covariates of shape (batch_size, future_length)
|
||||
num_output_patches: Number of output patches to generate (int, will be symbolic in ONNX)
|
||||
|
||||
Returns:
|
||||
quantile_preds: Tensor of shape (batch_size, num_quantiles, prediction_length)
|
||||
"""
|
||||
# Prepare kwargs - num_output_patches is now directly an int that ONNX can trace symbolically
|
||||
kwargs = {
|
||||
"context": context,
|
||||
"group_ids": group_ids,
|
||||
"num_output_patches": num_output_patches,
|
||||
}
|
||||
|
||||
if attention_mask is not None:
|
||||
kwargs["context_mask"] = attention_mask
|
||||
|
||||
if future_covariates is not None:
|
||||
kwargs["future_covariates"] = future_covariates
|
||||
|
||||
# Run model forward pass
|
||||
outputs = self.model(**kwargs)
|
||||
|
||||
# Return only the quantile predictions (drop loss and attention weights)
|
||||
return outputs.quantile_preds
|
||||
|
||||
|
||||
def create_dummy_inputs(
|
||||
batch_size: int = 2,
|
||||
context_length: int = 512,
|
||||
num_output_patches: int = 1,
|
||||
include_future_covariates: bool = False,
|
||||
output_patch_size: int = 64,
|
||||
device: str = "cpu",
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Create dummy inputs for ONNX export.
|
||||
|
||||
Args:
|
||||
batch_size: Batch size
|
||||
context_length: Length of historical context
|
||||
num_output_patches: Number of output patches
|
||||
include_future_covariates: Whether to include future covariates
|
||||
output_patch_size: Size of each output patch
|
||||
device: Device to create tensors on
|
||||
|
||||
Returns:
|
||||
Dictionary of dummy inputs
|
||||
"""
|
||||
dummy_inputs = {
|
||||
"context": torch.randn(batch_size, context_length, device=device, dtype=torch.float32),
|
||||
"group_ids": torch.arange(batch_size, device=device, dtype=torch.long),
|
||||
"attention_mask": torch.ones(batch_size, context_length, device=device, dtype=torch.float32),
|
||||
"num_output_patches": num_output_patches, # int value, will be fixed in ONNX
|
||||
}
|
||||
|
||||
if include_future_covariates:
|
||||
future_length = num_output_patches * output_patch_size
|
||||
dummy_inputs["future_covariates"] = torch.randn(batch_size, future_length, device=device, dtype=torch.float32)
|
||||
|
||||
return dummy_inputs
|
||||
|
||||
|
||||
def export_to_onnx(
    model_id: str,
    output_dir: Path,
    opset_version: int = 17,
    use_fp16: bool = False,
    include_future_covariates: bool = True,
    device: str = None,
) -> Path:
    """
    Load a Chronos-2 pipeline and write its model to ``output_dir / "model.onnx"``.

    Alongside the ONNX file, the model config (and the pipeline's generation
    config, when the pipeline has one) is saved into ``output_dir``.

    Args:
        model_id: HuggingFace model ID or local path
        output_dir: Directory to save ONNX model (created if missing)
        opset_version: ONNX opset version (17 recommended for best compatibility)
        use_fp16: Whether to convert the model with ``.half()`` before export
        include_future_covariates: Whether ``future_covariates`` becomes an ONNX input
        device: Device to use ('cuda' or 'cpu'); auto-detected when ``None``

    Returns:
        Path to exported ONNX model

    Raises:
        Exception: Any exporter failure is logged and re-raised unchanged.
    """
    # Fall back to CUDA when available, CPU otherwise.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info(f"Loading Chronos-2 model from {model_id}")

    # Official model is now available at: https://huggingface.co/amazon/chronos-2
    pipeline = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
    model = pipeline.model
    hf_config = model.config
    forecast_config = model.chronos_config

    logger.info(
        f"Model config: {hf_config.model_type}, d_model={hf_config.d_model}, "
        f"num_layers={hf_config.num_layers}, num_heads={hf_config.num_heads}"
    )
    logger.info(
        f"Chronos config: context_length={forecast_config.context_length}, "
        f"output_patch_size={forecast_config.output_patch_size}, "
        f"quantiles={forecast_config.quantiles}"
    )

    model.eval()

    if use_fp16:
        logger.info("Converting model to FP16")
        model = model.half()
        # NOTE(review): the dummy inputs below are still float32; confirm the
        # model accepts float32 inputs after .half() or the FP16 export may fail.

    wrapped_model = Chronos2ONNXWrapper(model)
    wrapped_model.eval()

    # Export with num_output_patches=4 to support up to 64-step predictions (4 * 16 = 64).
    # The patch count is a Python int, so it is fixed in the exported graph;
    # transformers.js truncates to the requested prediction_length.
    num_output_patches = 4
    dummy_inputs = create_dummy_inputs(
        batch_size=2,
        # Use a smaller context for export; the context axis is dynamic anyway.
        context_length=min(512, forecast_config.context_length),
        num_output_patches=num_output_patches,
        include_future_covariates=include_future_covariates,
        output_patch_size=forecast_config.output_patch_size,
        device=device,
    )

    # Batch size and context length vary at runtime; prediction_length (dim 2
    # of the output) stays fixed by num_output_patches.
    dynamic_axes = {
        "context": {0: "batch_size", 1: "context_length"},
        "group_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size", 1: "context_length"},
        "quantile_preds": {0: "batch_size"},  # prediction_length (dim 2) is fixed at 64
    }
    input_names = ["context", "group_ids", "attention_mask"]
    if include_future_covariates:
        dynamic_axes["future_covariates"] = {0: "batch_size", 1: "future_length"}
        input_names.append("future_covariates")

    # Positional args for the wrapper. The covariate slot is None when
    # covariates are excluded (the key is absent from dummy_inputs in that
    # case); num_output_patches is passed to the wrapper but is not an ONNX input.
    args = (
        dummy_inputs["context"],
        dummy_inputs["group_ids"],
        dummy_inputs["attention_mask"],
        dummy_inputs.get("future_covariates"),
        dummy_inputs["num_output_patches"],
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    onnx_path = output_dir / "model.onnx"

    logger.info(f"Exporting model to ONNX format at {onnx_path}")
    logger.info(f"Dynamic axes: {dynamic_axes}")

    # Keyword arguments common to both exporter flavours.
    shared_export_kwargs = {
        "input_names": input_names,
        "output_names": ["quantile_preds"],
        "dynamic_axes": dynamic_axes,
        "verbose": False,
    }

    try:
        with torch.no_grad():
            # The dynamo exporter is disabled due to dtype issues with Gather
            # ops in embeddings; the legacy exporter is always used for now.
            use_dynamo = False

            if use_dynamo and not include_future_covariates:
                # Try new dynamo-based exporter first (supports more ops like nanmean)
                try:
                    torch.onnx.export(
                        wrapped_model,
                        args,
                        str(onnx_path),
                        dynamo=True,  # Use new PyTorch 2.x+ exporter
                        **shared_export_kwargs,
                    )
                    logger.info("Used dynamo-based ONNX exporter")
                except Exception as dynamo_error:
                    logger.warning(f"Dynamo exporter failed ({dynamo_error}), trying legacy exporter...")
                    use_dynamo = False

            if not use_dynamo:
                # Legacy TorchScript-based exporter (more reliable for embeddings)
                logger.info("Using legacy TorchScript-based ONNX exporter")
                torch.onnx.export(
                    wrapped_model,
                    args,
                    str(onnx_path),
                    opset_version=opset_version,
                    do_constant_folding=True,
                    export_params=True,
                    **shared_export_kwargs,
                )
                logger.info("Used legacy TorchScript-based ONNX exporter")

        logger.info(f"Successfully exported model to {onnx_path}")
    except Exception as e:
        logger.error(f"Failed to export model to ONNX: {e}")
        raise

    # Persist the model config next to the ONNX file.
    hf_config.save_pretrained(output_dir)
    logger.info(f"Saved config to {output_dir / 'config.json'}")

    # Persist the generation config only when the pipeline exposes one.
    if hasattr(pipeline, "generation_config"):
        pipeline.generation_config.save_pretrained(output_dir)
        logger.info(f"Saved generation config to {output_dir / 'generation_config.json'}")

    return onnx_path
|
||||
|
||||
|
||||
def quantize_model(onnx_path: Path) -> Path:
    """
    Produce an INT8 dynamically-quantized sibling of an ONNX model.

    The result is written next to the input as ``model_quantized.onnx`` and the
    before/after file sizes are logged.

    Args:
        onnx_path: Path to the FP32 ONNX model

    Returns:
        Path to the quantized model

    Raises:
        ImportError: If ``onnxruntime`` is not installed (logged, then re-raised).
    """
    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType
    except ImportError:
        logger.error("onnxruntime not installed. Install with: pip install onnxruntime")
        raise

    def _size_mb(path: Path) -> float:
        # File size on disk, expressed in MiB.
        return path.stat().st_size / (1024**2)

    quantized_path = onnx_path.with_name("model_quantized.onnx")

    logger.info("Quantizing model to INT8...")
    logger.info(f" Input: {onnx_path}")
    logger.info(f" Output: {quantized_path}")

    quantize_dynamic(
        model_input=str(onnx_path),
        model_output=str(quantized_path),
        weight_type=QuantType.QInt8,
    )

    # Report how much smaller the quantized file is.
    size_before = _size_mb(onnx_path)
    size_after = _size_mb(quantized_path)
    saved_percent = (1 - size_after / size_before) * 100

    logger.info(f" Original: {size_before:.1f} MB")
    logger.info(f" Quantized: {size_after:.1f} MB")
    logger.info(f" Reduction: {saved_percent:.1f}%")

    return quantized_path
|
||||
|
||||
|
||||
def setup_transformersjs_structure(output_dir: Path):
    """
    Create transformers.js-compatible directory structure.

    Creates:
    - onnx/ directory containing ``encoder_model.onnx`` and
      ``decoder_model_merged.onnx``, both relative symlinks to ``../model.onnx``
      (transformers.js expects a T5-style encoder/decoder split)
    - generation_config.json if missing

    The function is idempotent: existing links (including dangling ones) are
    replaced, and an existing generation_config.json is left untouched.

    Args:
        output_dir: Export directory that holds (or will hold) ``model.onnx``.
    """
    import json

    logger.info("Setting up transformers.js directory structure...")

    # parents=True so the call also works when output_dir itself is missing.
    onnx_dir = output_dir / "onnx"
    onnx_dir.mkdir(parents=True, exist_ok=True)

    for link_name in ("encoder_model.onnx", "decoder_model_merged.onnx"):
        link = onnx_dir / link_name
        # missing_ok covers "nothing there yet"; unlink itself removes a
        # dangling symlink, which Path.exists() alone would not detect.
        link.unlink(missing_ok=True)
        # Relative target keeps the export directory relocatable.
        link.symlink_to("../model.onnx")
        logger.info(f" Created {link}")

    # Create minimal generation_config.json if missing
    generation_config_path = output_dir / "generation_config.json"
    if not generation_config_path.exists():
        generation_config = {"_from_model_config": True, "transformers_version": "4.36.0"}
        with open(generation_config_path, "w", encoding="utf-8") as f:
            json.dump(generation_config, f, indent=2)
        logger.info(f" Created {generation_config_path}")
|
||||
|
||||
|
||||
def generate_readme(output_dir: Path, model_id: str, quantized: bool = False):
    """
    Generate README.md with model card for Hub.

    Reads ``config.json`` from ``output_dir`` for the model details and the
    on-disk size of ``model.onnx`` (and ``model_quantized.onnx`` when present)
    for the file list, then writes ``output_dir / "README.md"``.

    Args:
        output_dir: Output directory; must already contain ``config.json`` and ``model.onnx``
        model_id: Original model ID (used as the link target for "Chronos-2")
        quantized: Whether quantized model is included

    Raises:
        FileNotFoundError: If ``config.json`` or ``model.onnx`` is missing.
    """
    import json

    # Load config to get model details
    config_path = output_dir / "config.json"
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    chronos_config = config.get("chronos_config", {})

    # Build the "Files" list up front so an absent quantized model does not
    # leave a blank line in the middle of the markdown list.
    model_bytes = (output_dir / "model.onnx").stat().st_size
    file_lines = [f"- `model.onnx` - FP32 ONNX model ({model_bytes / (1024**2):.1f} MB)"]

    quantized_model_path = output_dir / "model_quantized.onnx"
    if quantized and quantized_model_path.exists():
        quantized_bytes = quantized_model_path.stat().st_size
        # Report the measured reduction instead of a fixed figure; guard the
        # division for a (degenerate) empty model file.
        reduction = (1 - quantized_bytes / model_bytes) * 100 if model_bytes else 0.0
        file_lines.append(
            f"- `model_quantized.onnx` - INT8 quantized model "
            f"({quantized_bytes / (1024**2):.1f} MB, {reduction:.0f}% size reduction)"
        )

    file_lines += [
        "- `config.json` - Model configuration",
        "- `generation_config.json` - Generation parameters",
        "- `onnx/` - transformers.js-compatible directory structure",
    ]
    files_section = "\n".join(file_lines)

    # The template lines sit at column 0 on purpose: their leading whitespace
    # would otherwise end up in the generated markdown.
    readme_content = f"""---
library_name: transformers.js
tags:
- time-series
- forecasting
- chronos
- onnx
pipeline_tag: time-series-forecasting
---

# Chronos-2 ONNX

This is an ONNX export of the [Chronos-2]({model_id}) time series forecasting model, optimized for use with [transformers.js](https://huggingface.co/docs/transformers.js).

## Model Details

- **Model Type:** Time Series Forecasting
- **Architecture:** T5-based encoder-decoder with patching
- **Context Length:** {chronos_config.get("context_length", 8192)} timesteps
- **Output Patch Size:** {chronos_config.get("output_patch_size", 16)} timesteps
- **Quantile Levels:** {len(chronos_config.get("quantiles", []))} levels (0.01, 0.05, ..., 0.95, 0.99)
- **Model Dimension:** {config.get("d_model", 768)}
- **Layers:** {config.get("num_layers", 12)}
- **Attention Heads:** {config.get("num_heads", 12)}

## Files

{files_section}

## Usage

### JavaScript (transformers.js)

```javascript
import {{ pipeline }} from '@huggingface/transformers';

// Load the forecasting pipeline
const forecaster = await pipeline('time-series-forecasting', 'kashif/chronos-2-onnx');

// Your historical time series data
const timeSeries = [605, 586, 586, 559, 511, 487, 484, 458, ...]; // 100+ timesteps

// Generate 16-step forecast with quantiles
const output = await forecaster(timeSeries, {{
prediction_length: 16,
quantile_levels: [0.1, 0.5, 0.9], // 10th, 50th (median), 90th percentiles
}});

// Output format: {{ forecast: [[t1_q1, t1_q2, t1_q3], ...], quantile_levels: [...] }}
console.log('Median forecast:', output.forecast.map(row => row[1])); // Extract median

// Clean up
await forecaster.dispose();
```

### Batch Forecasting

```javascript
const batch = [
[100, 110, 105, 115, 120, ...], // Series 1
[50, 55, 52, 58, 60, ...], // Series 2
];

const outputs = await forecaster(batch);
// Returns array of forecasts, one per input series
```

## Performance

- **Inference Time:** ~35-80ms per series (CPU, Node.js)
- **Speedup vs PyTorch:** 3-8x faster
- **Accuracy:** <1% error vs PyTorch reference

## Technical Details

### Preprocessing

Chronos-2 uses automatic preprocessing:
1. **Repeat-padding:** Input is padded to be divisible by patch_size (16)
2. **Instance normalization:** Per-series z-score normalization
3. **arcsinh transformation:** Nonlinear transformation for better modeling

All preprocessing is handled automatically by the pipeline.

### Output Format

The model outputs quantile forecasts:

```typescript
interface Chronos2Output {{
forecast: number[][]; // [prediction_length, num_quantiles]
quantile_levels: number[]; // The quantile levels for each column
}}
```

Extract specific quantiles:
```javascript
const median = output.forecast.map(row => row[1]); // 50th percentile
const lower = output.forecast.map(row => row[0]); // 10th percentile (lower bound)
const upper = output.forecast.map(row => row[2]); // 90th percentile (upper bound)
```

## Limitations

- **Maximum context:** {chronos_config.get("context_length", 8192)} timesteps
- **Fixed prediction length:** 16 timesteps (for now; autoregressive unrolling coming soon)
- **Univariate only:** Single time series per input (multivariate support coming)

## Citation

```bibtex
@article{{ansari2024chronos,
title={{Chronos: Learning the Language of Time Series}},
author={{Ansari, Abdul Fatir and others}},
journal={{arXiv preprint arXiv:2403.07815}},
year={{2024}}
}}
```

## License

Apache 2.0

## Links

- [Chronos-2 Paper](https://arxiv.org/abs/2403.07815)
- [Chronos GitHub](https://github.com/amazon-science/chronos-forecasting)
- [transformers.js Documentation](https://huggingface.co/docs/transformers.js)
"""

    readme_path = output_dir / "README.md"
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(readme_content)

    logger.info(f" Generated {readme_path}")
|
||||
|
||||
|
||||
def push_to_hub(output_dir: Path, repo_id: str, private: bool = False):
    """
    Upload the exported model files to a HuggingFace Hub model repository.

    Files that do not exist locally are skipped silently. A failure to create
    the repository is logged as a warning and the upload is still attempted.

    Args:
        output_dir: Directory containing the model files
        repo_id: Hub repository ID (e.g., 'username/chronos-2-onnx')
        private: Whether to make the repository private

    Raises:
        ImportError: If ``huggingface_hub`` is not installed (logged, then re-raised).
    """
    try:
        from huggingface_hub import HfApi, create_repo
    except ImportError:
        logger.error("huggingface_hub not installed. Install with: pip install huggingface-hub")
        raise

    logger.info(f"\nPushing to HuggingFace Hub: {repo_id}")

    api = HfApi()

    # Best-effort repository creation; an existing repo is fine.
    try:
        create_repo(repo_id, private=private, exist_ok=True)
        logger.info(f" Repository created/verified: https://huggingface.co/{repo_id}")
    except Exception as e:
        logger.warning(f" Could not create repo: {e}")

    logger.info(" Uploading files...")

    def _upload(local_path: Path, remote_name: str) -> None:
        # Push one local file to the given path inside the model repo.
        api.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=remote_name,
            repo_id=repo_id,
            repo_type="model",
        )
        logger.info(f" ✓ {remote_name}")

    # Top-level artifacts; the quantized model is optional.
    top_level_files = ["model.onnx", "config.json", "generation_config.json", "README.md"]
    if (output_dir / "model_quantized.onnx").exists():
        top_level_files.append("model_quantized.onnx")

    for filename in top_level_files:
        candidate = output_dir / filename
        if candidate.exists():
            _upload(candidate, filename)

    # The local onnx/ entries are symlinks to model.onnx, so the model file
    # itself is uploaded under each alias (as actual files).
    model_file = output_dir / "model.onnx"
    if (output_dir / "onnx").exists():
        for alias in ("encoder_model.onnx", "decoder_model_merged.onnx"):
            if model_file.exists():
                _upload(model_file, f"onnx/{alias}")

    logger.info(f"\n✓ Successfully pushed to: https://huggingface.co/{repo_id}")
|
||||
|
||||
|
||||
def validate_onnx_export(
    onnx_path: Path,
    model_id: str,
    device: str | None = None,
    rtol: float = 1e-3,
    atol: float = 1e-3,
) -> bool:
    """
    Validate ONNX export by comparing outputs with PyTorch model.

    The test inputs are shaped to match the exported graph: the number of
    output patches is derived from the ONNX output's static prediction length,
    and ``future_covariates`` is supplied to both runtimes whenever the ONNX
    graph declares it as an input. Without this, the comparison would fail on
    shape (patch count is fixed in the graph) or on a missing required input.

    Args:
        onnx_path: Path to ONNX model
        model_id: Original model ID
        device: Device to use; auto-detected when ``None``
        rtol: Relative tolerance for comparison
        atol: Absolute tolerance for comparison

    Returns:
        True if validation passes
    """
    logger.info("Validating ONNX export...")

    # Auto-detect device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load PyTorch model
    # Official model is now available at: https://huggingface.co/amazon/chronos-2
    pipeline = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
    model = pipeline.model
    model.eval()

    # Load ONNX model
    import onnxruntime as ort

    logger.info(f"Loading ONNX model from {onnx_path}")
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
    ort_session = ort.InferenceSession(str(onnx_path), providers=providers)

    onnx_input_names = {graph_input.name for graph_input in ort_session.get_inputs()}
    include_future_covariates = "future_covariates" in onnx_input_names

    # num_output_patches is a constant inside the exported graph, so PyTorch
    # must be run with the same value. Recover it from the static last output
    # dimension (assumed to be num_output_patches * output_patch_size — TODO
    # confirm against the model); fall back to the value export_to_onnx uses
    # when that dimension is symbolic or unknown.
    output_patch_size = model.chronos_config.output_patch_size
    exported_prediction_length = ort_session.get_outputs()[0].shape[-1]
    if isinstance(exported_prediction_length, int) and exported_prediction_length >= output_patch_size:
        num_output_patches = exported_prediction_length // output_patch_size
    else:
        num_output_patches = 4

    # Create test inputs (deliberately different batch/context sizes from the
    # export-time dummies, to exercise the dynamic axes).
    batch_size = 4
    context_length = 256

    dummy_inputs = create_dummy_inputs(
        batch_size=batch_size,
        context_length=context_length,
        num_output_patches=num_output_patches,
        include_future_covariates=include_future_covariates,
        output_patch_size=output_patch_size,
        device=device,
    )

    # Run PyTorch inference
    logger.info("Running PyTorch inference...")
    with torch.no_grad():
        wrapped_model = Chronos2ONNXWrapper(model)
        pytorch_output = wrapped_model(
            context=dummy_inputs["context"],
            group_ids=dummy_inputs["group_ids"],
            attention_mask=dummy_inputs["attention_mask"],
            # None when the graph has no covariate input (key absent).
            future_covariates=dummy_inputs.get("future_covariates"),
            num_output_patches=dummy_inputs["num_output_patches"],
        )

    # Run ONNX inference (num_output_patches is fixed in the model, not an input).
    # Only tensors the graph actually declares are fed.
    logger.info("Running ONNX inference...")
    ort_inputs = {
        name: dummy_inputs[name].cpu().numpy()
        for name in ("context", "group_ids", "attention_mask", "future_covariates")
        if name in onnx_input_names and name in dummy_inputs
    }

    onnx_output = ort_session.run(None, ort_inputs)[0]

    # Compare outputs
    pytorch_output_np = pytorch_output.cpu().numpy()

    logger.info(f"PyTorch output shape: {pytorch_output_np.shape}")
    logger.info(f"ONNX output shape: {onnx_output.shape}")

    # Check shapes match
    if pytorch_output_np.shape != onnx_output.shape:
        logger.error(f"Output shapes don't match! PyTorch: {pytorch_output_np.shape}, ONNX: {onnx_output.shape}")
        return False

    # Check values match (absolute difference computed once, reused for both stats)
    abs_diff = np.abs(pytorch_output_np - onnx_output)
    max_diff = abs_diff.max()
    mean_diff = abs_diff.mean()

    logger.info(f"Max absolute difference: {max_diff:.6f}")
    logger.info(f"Mean absolute difference: {mean_diff:.6f}")

    if np.allclose(pytorch_output_np, onnx_output, rtol=rtol, atol=atol):
        logger.info("✓ Validation PASSED: ONNX output matches PyTorch output")
        return True

    logger.error("✗ Validation FAILED: ONNX output doesn't match PyTorch output")
    logger.error(f"Relative tolerance: {rtol}, Absolute tolerance: {atol}")
    return False
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point: export, then optionally validate, quantize and upload.

    Returns:
        0 on success; 1 when validation fails or any step raises.
    """
    rule = "=" * 60

    def _section(title: str) -> None:
        # Three-line banner that introduces every stage after the first.
        logger.info("\n" + rule)
        logger.info(title)
        logger.info(rule + "\n")

    # (flags, keyword arguments) in the order they should appear in --help.
    argument_specs = [
        (
            "--model_id",
            {
                "type": str,
                "default": "amazon/chronos-2-small",
                "help": "HuggingFace model ID or local path (e.g., 'amazon/chronos-2-small')",
            },
        ),
        ("--output_dir", {"type": str, "default": "./chronos2-onnx", "help": "Output directory for ONNX model"}),
        ("--opset_version", {"type": int, "default": 17, "help": "ONNX opset version (default: 17)"}),
        ("--fp16", {"action": "store_true", "help": "Export model in FP16 precision"}),
        (
            "--validate",
            {"action": "store_true", "help": "Validate ONNX export by comparing with PyTorch outputs"},
        ),
        (
            "--no_future_covariates",
            {"action": "store_true", "help": "Don't include future covariates support in export"},
        ),
        (
            "--device",
            {"type": str, "default": None, "choices": ["cpu", "cuda"], "help": "Device to use (default: auto-detect)"},
        ),
        ("--quantize", {"action": "store_true", "help": "Quantize the model to INT8 after export"}),
        (
            "--push_to_hub",
            {
                "type": str,
                "default": None,
                "help": "Push the exported model to HuggingFace Hub (e.g., 'username/chronos-2-onnx')",
            },
        ),
        ("--private", {"action": "store_true", "help": "Make the Hub repository private"}),
    ]

    parser = argparse.ArgumentParser(description="Export Chronos-2 model to ONNX format")
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    output_dir = Path(args.output_dir)

    try:
        # Stage 1: export (always).
        logger.info(rule)
        logger.info("Chronos-2 ONNX Export Pipeline")
        logger.info(rule + "\n")

        onnx_path = export_to_onnx(
            model_id=args.model_id,
            output_dir=output_dir,
            opset_version=args.opset_version,
            use_fp16=args.fp16,
            include_future_covariates=not args.no_future_covariates,
            device=args.device,
        )

        # Stage 2: validation (optional). A failed validation stops the
        # pipeline with a non-zero status; the exported file is left in place.
        if args.validate:
            _section("Validation")
            if not validate_onnx_export(onnx_path=onnx_path, model_id=args.model_id, device=args.device):
                logger.warning("Validation failed, but ONNX model was still exported")
                return 1

        # Stage 3: quantization (optional).
        quantized_path = None
        if args.quantize:
            _section("Quantization")
            quantized_path = quantize_model(onnx_path)

        # Stage 4: transformers.js layout.
        _section("transformers.js Setup")
        setup_transformersjs_structure(output_dir)

        # Stage 5: model card.
        _section("README Generation")
        generate_readme(output_dir, args.model_id, quantized=args.quantize)

        # Stage 6: upload (optional).
        if args.push_to_hub:
            _section("Hub Upload")
            push_to_hub(output_dir, args.push_to_hub, private=args.private)

        # Final summary of everything that was produced.
        logger.info("\n" + rule)
        logger.info("Export Complete!")
        logger.info(rule)
        logger.info(f" ONNX model: {onnx_path}")
        if quantized_path:
            logger.info(f" Quantized: {quantized_path}")
        logger.info(f" Config: {output_dir / 'config.json'}")
        logger.info(f" README: {output_dir / 'README.md'}")
        if args.push_to_hub:
            logger.info(f" Hub URL: https://huggingface.co/{args.push_to_hub}")
        logger.info(rule + "\n")

        return 0

    except Exception as e:
        # Top-level boundary: log with traceback and turn into an exit status.
        logger.error(f"Export failed with error: {e}", exc_info=True)
        return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status. SystemExit is
    # raised directly because the bare ``exit`` name is a helper installed by
    # the ``site`` module and is not guaranteed to exist (e.g. under ``python -S``).
    raise SystemExit(main())
|
||||
196
scripts/onnx/fix_onnx_model.py
Normal file
196
scripts/onnx/fix_onnx_model.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fix ONNX model type issues, particularly for Gather operations.
|
||||
|
||||
This script fixes dtype mismatches where float tensors are used as indices
|
||||
for Gather operations, which require int64 indices.
|
||||
"""
|
||||
|
||||
import onnx
|
||||
from onnx import helper, TensorProto
|
||||
import sys
|
||||
|
||||
|
||||
def make_prediction_length_dynamic(model: onnx.ModelProto, dim_name: str = "prediction_length"):
    """
    Replace a fixed size in dimension 2 of every rank-3 graph output with a symbolic name.

    Each graph output that declares a shape with exactly three dimensions is
    printed and, if its dimension at index 2 carries a concrete ``dim_value``,
    that dimension is cleared and given the symbolic name ``dim_name``.
    Outputs with any other rank, or without a declared shape, are left alone.
    The model is modified in place.

    Args:
        model: ONNX model whose graph outputs are rewritten.
        dim_name: Symbolic name assigned to dimension 2.

    Returns:
        The same ``model`` object that was passed in.
    """
    print("\nMaking prediction_length dimension dynamic...")

    for graph_output in model.graph.output:
        tensor_type = graph_output.type.tensor_type
        if not tensor_type.HasField("shape"):
            continue

        dims = tensor_type.shape.dim
        # Only 3D outputs are treated as the quantile predictions
        # ([batch, quantiles, pred_len]); skip everything else.
        if len(dims) != 3:
            continue

        print(f" Output '{graph_output.name}' shape before:")
        for position, axis in enumerate(dims):
            if axis.HasField("dim_value"):
                print(f" Dim {position}: {axis.dim_value}")
            elif axis.HasField("dim_param"):
                print(f" Dim {position}: {axis.dim_param} (symbolic)")

        # Swap the concrete size for a symbolic name; an already-symbolic
        # dimension is not touched.
        pred_len_dim = dims[2]
        if pred_len_dim.HasField("dim_value"):
            previous_value = pred_len_dim.dim_value
            pred_len_dim.Clear()
            pred_len_dim.dim_param = dim_name
            print(f" Changed dim 2 from {previous_value} to '{dim_name}' (dynamic)")

    return model
|
||||
|
||||
|
||||
def fix_gather_indices(model_path: str, output_path: str, make_dynamic: bool = True):
    """
    Fix Gather index dtype issues in an ONNX model and optionally make prediction_length dynamic.

    The indices may be represented as float tensors in the graph but Gather
    requires int64. This function routes the indices input of every Gather
    node through a Cast-to-int64 node. No dtype inspection is performed: the
    Cast is inserted for every Gather that has an indices input.

    When several Gather nodes share the same indices tensor, a single Cast is
    created and reused, so that no two nodes in the graph produce an output
    with the same name.

    Args:
        model_path: Path to input ONNX model
        output_path: Path to save fixed ONNX model
        make_dynamic: If True, also make the prediction_length dimension dynamic

    Returns:
        True once the model has been saved. Checker failures are printed as
        warnings and do not prevent saving.
    """
    print(f"Loading ONNX model from {model_path}")
    model = onnx.load(model_path)

    cast_count = _cast_gather_indices_to_int64(model)
    print(f"Added {cast_count} Cast operations before Gather nodes")

    # Cast any "_int64_cast" tensor that feeds a Concat back to float32.
    print("\nFixing Concat operations with dtype mismatches...")
    concat_cast_count = _cast_concat_inputs_to_float(model)
    print(f"Added {concat_cast_count} Cast operations before Concat nodes")

    # Make prediction_length dimension dynamic
    if make_dynamic:
        model = make_prediction_length_dynamic(model)

    # Validate and save
    print("\nValidating fixed model...")
    try:
        onnx.checker.check_model(model)
        print("✓ Model validation passed!")
    except Exception as e:
        # Deliberate best effort: a checker failure is reported but the model
        # is still written so it can be inspected.
        print(f"⚠ Validation warnings: {e}")
        print(" Attempting to save anyway...")

    print(f"\nSaving fixed model to {output_path}")
    onnx.save(model, output_path)
    print("✓ Model saved successfully!")

    return True


def _cast_gather_indices_to_int64(model: onnx.ModelProto) -> int:
    """Route every Gather's indices input through a Cast to int64; return the number of Cast nodes added."""
    gather_nodes = [(idx, node) for idx, node in enumerate(model.graph.node) if node.op_type == "Gather"]
    for _, node in gather_nodes:
        if len(node.input) >= 2:
            print(f"Gather node {node.name or 'unnamed'} uses indices: {node.input[1]}")

    print(f"\nFound {len(gather_nodes)} Gather operations")

    print("\nInserting Cast operations for float->int64 conversion...")
    cast_count = 0
    # Maps an indices tensor name to the output of the Cast already created
    # for it, so a tensor shared by several Gather nodes is cast only once.
    cast_output_by_input = {}

    for idx, gather_node in gather_nodes:
        if len(gather_node.input) < 2:
            continue

        index_input = gather_node.input[1]
        cast_output_name = cast_output_by_input.get(index_input)
        is_new_cast = cast_output_name is None
        if is_new_cast:
            cast_output_name = f"{index_input}_int64_cast"

        # Point the Gather at the int64 tensor (inputs beyond the second, if
        # any, are left untouched).
        gather_node.input[1] = cast_output_name

        if not is_new_cast:
            print(f" Reusing existing Cast for {gather_node.name or 'unnamed'}")
            continue

        cast_node = helper.make_node(
            "Cast",
            inputs=[index_input],
            outputs=[cast_output_name],
            to=TensorProto.INT64,
            name=f"cast_{index_input}_to_int64",
        )
        # idx was recorded before any insertion; every Cast added so far sits
        # earlier in the node list, so shift by cast_count to land directly
        # before this Gather. Gathers are visited in graph order, so the Cast
        # also precedes any later Gather that reuses it.
        model.graph.node.insert(idx + cast_count, cast_node)
        cast_count += 1
        cast_output_by_input[index_input] = cast_output_name

        print(f" Added Cast node before {gather_node.name or 'unnamed'}")

    return cast_count


def _cast_concat_inputs_to_float(model: onnx.ModelProto) -> int:
    """Cast Concat inputs whose name contains "_int64_cast" to float32; return the number of Cast nodes added."""
    # NOTE(review): the Gather pass only rewires Gather inputs, so a Concat
    # input carries this marker only if the loaded model already contained
    # such a name - confirm this pass is still needed.
    concat_nodes = [(idx, node) for idx, node in enumerate(model.graph.node) if node.op_type == "Concat"]
    print(f"Found {len(concat_nodes)} Concat operations")

    concat_cast_count = 0
    # One float32 Cast per distinct source tensor, shared by all Concat
    # consumers, to keep node output names unique.
    float_output_by_input = {}

    for idx, concat_node in concat_nodes:
        new_inputs = []
        for i, input_name in enumerate(concat_node.input):
            if "_int64_cast" not in input_name:
                new_inputs.append(input_name)
                continue

            cast_output_name = float_output_by_input.get(input_name)
            if cast_output_name is None:
                cast_output_name = f"{input_name}_back_to_float32"
                cast_node = helper.make_node(
                    "Cast",
                    inputs=[input_name],
                    outputs=[cast_output_name],
                    to=TensorProto.FLOAT,
                    name=f"cast_{input_name}_back_to_float",
                )
                # Same index bookkeeping as for Gather: shift the recorded
                # position by the number of nodes inserted in this pass.
                model.graph.node.insert(idx + concat_cast_count, cast_node)
                concat_cast_count += 1
                float_output_by_input[input_name] = cast_output_name
                print(f" Adding Cast int64→float32 before Concat {concat_node.name or 'unnamed'} input {i}")

            new_inputs.append(cast_output_name)

        # Rewrite the input list only when something actually changed.
        if new_inputs != list(concat_node.input):
            del concat_node.input[:]
            concat_node.input.extend(new_inputs)

    return concat_cast_count
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # Command line: two positional paths, input model and output model.
    parser = argparse.ArgumentParser(description="Fix ONNX model type issues")
    for arg_name, arg_help in (("input", "Input ONNX model path"), ("output", "Output ONNX model path")):
        parser.add_argument(arg_name, help=arg_help)

    args = parser.parse_args()

    # Exit status: 0 when the model was fixed and saved, 1 on any error.
    exit_status = 0
    try:
        fix_gather_indices(args.input, args.output)
        print("\n✓ Model fixed successfully!")
    except Exception as e:
        print(f"\n✗ Error: {e}", file=sys.stderr)
        import traceback

        traceback.print_exc()
        exit_status = 1

    sys.exit(exit_status)
|
||||
253
scripts/onnx/quantize_chronos2.py
Normal file
253
scripts/onnx/quantize_chronos2.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quantize Chronos-2 ONNX model to reduce size and improve inference speed.
|
||||
|
||||
This script quantizes the ONNX model from FP32 to INT8, reducing model size
|
||||
by approximately 75% while maintaining good accuracy.
|
||||
|
||||
Usage:
|
||||
python quantize_chronos2.py \
|
||||
--input chronos2-onnx/model.onnx \
|
||||
--output chronos2-onnx/model_quantized.onnx \
|
||||
--mode dynamic
|
||||
|
||||
Quantization Modes:
|
||||
- dynamic: Dynamic quantization (fastest, best compatibility)
|
||||
- static: Static quantization (requires calibration data, best accuracy)
|
||||
- qat: Quantization-aware training (requires retraining; not implemented by this script)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def dynamic_quantization(model_path: str, output_path: str):
    """
    Quantize an ONNX model with ONNX Runtime's dynamic quantization.

    Calls ``onnxruntime.quantization.quantize_dynamic`` with signed INT8
    weights (``QuantType.QInt8``) and writes the result to ``output_path``.
    No calibration data is involved.

    Args:
        model_path: Path of the ONNX model to quantize.
        output_path: Path where the quantized model is written.
    """
    # Imported lazily so the module can be loaded without onnxruntime.
    from onnxruntime.quantization import quantize_dynamic, QuantType

    logger.info(f"Loading model from {model_path}")

    for banner_line in (
        "Applying dynamic quantization...",
        " - Weight type: INT8",
        " - Activation type: INT8 (dynamic)",
    ):
        logger.info(banner_line)

    quantize_dynamic(model_input=model_path, model_output=output_path, weight_type=QuantType.QInt8)

    logger.info(f"Quantized model saved to {output_path}")
|
||||
|
||||
|
||||
def static_quantization(model_path: str, output_path: str, calibration_data_path: str = None):
    """
    Apply static quantization to the ONNX model.

    Static quantization uses calibration data to determine quantization
    parameters for both weights and activations. Only synthetic calibration
    data (random normal series of length 512, batch size 1) is currently
    supported.

    The model is written in QDQ format with INT8 weights and INT8 activations.

    Args:
        model_path: Path of the ONNX model to quantize.
        output_path: Path where the quantized model is written.
        calibration_data_path: Path to user-supplied calibration data. Not
            supported yet; any truthy value raises ``NotImplementedError``.

    Raises:
        NotImplementedError: If ``calibration_data_path`` is provided.
    """
    # Imported lazily so the module can be loaded without onnxruntime.
    from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

    logger.info(f"Loading model from {model_path}")

    if calibration_data_path:
        logger.info(f"Loading calibration data from {calibration_data_path}")
        # Custom calibration data reader would go here
        raise NotImplementedError("Custom calibration data reader not implemented yet")

    logger.info("Generating synthetic calibration data...")

    class SyntheticCalibrationDataReader(CalibrationDataReader):
        """Yield ``num_samples`` random inputs, then ``None`` to signal exhaustion."""

        def __init__(self, num_samples=100):
            self.num_samples = num_samples
            self.current_sample = 0
            self.batch_size = 1
            self.context_length = 512

        def get_next(self):
            if self.current_sample >= self.num_samples:
                return None

            # Generate synthetic time series data
            context = np.random.randn(self.batch_size, self.context_length).astype(np.float32)
            group_ids = np.array([0], dtype=np.int64)
            attention_mask = np.ones((self.batch_size, self.context_length), dtype=np.float32)

            self.current_sample += 1

            return {
                "context": context,
                "group_ids": group_ids,
                "attention_mask": attention_mask,
            }

    calibration_data_reader = SyntheticCalibrationDataReader()

    logger.info("Applying static quantization...")
    logger.info(" - Weight type: INT8")
    logger.info(" - Activation type: INT8 (static)")
    # Report the reader's real sample count rather than a duplicated literal.
    logger.info(f" - Calibration samples: {calibration_data_reader.num_samples}")

    # quant_format selects the graph representation and expects a QuantFormat
    # member; the INT8 element types belong in activation_type / weight_type.
    # They are spelled out so the configuration matches what is logged above.
    quantize_static(
        model_input=model_path,
        model_output=output_path,
        calibration_data_reader=calibration_data_reader,
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
    )

    logger.info(f"Quantized model saved to {output_path}")
|
||||
|
||||
|
||||
def compare_models(original_path: str, quantized_path: str):
    """
    Log the on-disk size of both models and the relative size reduction.

    Args:
        original_path: Path of the original model file.
        quantized_path: Path of the quantized model file.
    """
    bytes_per_mb = 1024**2
    original_size = Path(original_path).stat().st_size / bytes_per_mb
    quantized_size = Path(quantized_path).stat().st_size / bytes_per_mb

    # Percentage by which the quantized file is smaller than the original.
    reduction = (1 - quantized_size / original_size) * 100

    rule = "=" * 60
    report_lines = (
        f"\n{rule}",
        "Model Size Comparison:",
        f" Original: {original_size:.1f} MB",
        f" Quantized: {quantized_size:.1f} MB",
        f" Reduction: {reduction:.1f}%",
        f"{rule}\n",
    )
    for report_line in report_lines:
        logger.info(report_line)
|
||||
|
||||
|
||||
def validate_quantized_model(model_path: str):
    """
    Smoke-test a quantized model by running one inference on random input.

    Loads the model with the CPU execution provider and feeds it a random
    float32 ``context`` of shape (1, 256), a ``group_ids`` array ``[0]`` and
    an all-ones ``attention_mask``.

    Args:
        model_path: Path of the ONNX model to load.

    Returns:
        True if the session was created and inference ran; False if any
        exception was raised along the way (the error is logged).
    """
    logger.info("Validating quantized model...")

    try:
        import onnxruntime as ort

        session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

        # Single synthetic series used as the test input.
        series_shape = (1, 256)
        feed = {
            "context": np.random.randn(*series_shape).astype(np.float32),
            "group_ids": np.array([0], dtype=np.int64),
            "attention_mask": np.ones(series_shape, dtype=np.float32),
        }

        logger.info(" Running test inference...")
        results = session.run(None, feed)

        logger.info(" ✓ Inference successful!")
        first_output = results[0]
        logger.info(f" Output shape: {first_output.shape}")
        logger.info(f" Output dtype: {first_output.dtype}")
    except Exception as e:
        logger.error(f" ✗ Validation failed: {e}")
        return False

    return True
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point: quantize a Chronos-2 ONNX model.

    Parses the arguments, checks that onnxruntime can be imported, runs the
    selected quantization mode, logs a size comparison and, with
    ``--validate``, runs a test inference on the result.

    Returns:
        0 on success; 1 if onnxruntime is missing, quantization raised, or
        the requested validation failed.
    """
    parser = argparse.ArgumentParser(description="Quantize Chronos-2 ONNX model")
    parser.add_argument("--input", type=str, default="chronos2-onnx/model.onnx", help="Input ONNX model path")
    parser.add_argument(
        "--output", type=str, default="chronos2-onnx/model_quantized.onnx", help="Output quantized model path"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="dynamic",
        choices=["dynamic", "static"],
        help="Quantization mode (dynamic or static)",
    )
    parser.add_argument(
        "--calibration_data", type=str, default=None, help="Path to calibration data (for static quantization)"
    )
    parser.add_argument("--validate", action="store_true", help="Validate quantized model after export")

    args = parser.parse_args()

    rule = "=" * 60
    logger.info(rule)
    logger.info("Chronos-2 ONNX Model Quantization")
    logger.info(rule)

    # Bail out early with a helpful message when onnxruntime is unavailable.
    try:
        import onnxruntime

        logger.info(f"ONNX Runtime version: {onnxruntime.__version__}")
    except ImportError:
        logger.error("onnxruntime not installed. Install with: pip install onnxruntime")
        return 1

    try:
        if args.mode == "dynamic":
            dynamic_quantization(args.input, args.output)
        elif args.mode == "static":
            static_quantization(args.input, args.output, args.calibration_data)

        compare_models(args.input, args.output)

        # Without --validate the run is considered successful as soon as the
        # quantized file has been written and measured.
        if not args.validate:
            logger.info("✓ Quantization completed successfully!")
            logger.info(" (Use --validate to test the quantized model)")
            return 0

        if validate_quantized_model(args.output):
            logger.info("✓ Quantization completed successfully!")
            return 0

        logger.warning("⚠ Quantization completed but validation failed")
        return 1

    except Exception as e:
        logger.error(f"✗ Quantization failed: {e}")
        import traceback

        traceback.print_exc()
        return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|
||||
Loading…
Reference in a new issue