mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Address PR comments
This commit is contained in:
parent
2907643528
commit
a8837cfd88
1 changed file with 5 additions and 12 deletions
|
|
@@ -117,7 +117,7 @@ def from_dataframe(
|
|||
...
|
||||
|
||||
|
||||
def from_dict_list(
|
||||
def from_list_of_dicts(
|
||||
data: list[dict],
|
||||
prediction_length: int,
|
||||
use_target_encoding: bool = True,
|
||||
|
|
@@ -256,13 +256,12 @@ def _target_encode(
|
|||
-----------
|
||||
- id_codes and cat_codes are non-negative integers in [0, n_items) and [0, n_categories)
|
||||
- future_id_codes (if provided) are valid item IDs that appear in id_codes
|
||||
- future_cat_codes may contain -1 for unseen categories (encoded as NaN)
|
||||
- future_cat_codes are non-negative integers in [0, n_categories)
|
||||
|
||||
Edge cases
|
||||
----------
|
||||
- NaN values in target are excluded from sum/count computations
|
||||
- Unseen (item, category) pairs get the item mean as fallback (via smoothing formula)
|
||||
- Completely unseen categories in future (cat_code=-1) get the item mean
|
||||
- Unseen (item, category) pairs naturally get item_mean via the smoothing formula
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@@ -280,7 +279,6 @@ def _target_encode(
|
|||
Item ID for each future row, shape: (n_future_rows,). Optional.
|
||||
future_cat_codes
|
||||
Category codes for future rows, shape: (n_future_rows,). Optional.
|
||||
Use -1 for categories not seen in past (will be encoded as NaN).
|
||||
smooth
|
||||
Smoothing parameter. Higher values give more weight to item mean vs category mean.
|
||||
|
||||
|
|
@@ -308,12 +306,7 @@ def _target_encode(
|
|||
|
||||
encoded_future = None
|
||||
if future_id_codes is not None and future_cat_codes is not None:
|
||||
valid_future = future_cat_codes >= 0
|
||||
future_combined = np.where(valid_future, future_id_codes * n_categories + future_cat_codes, 0)
|
||||
encoded_future = np.where(
|
||||
valid_future,
|
||||
lookup[future_combined],
|
||||
item_means[future_id_codes]
|
||||
).astype(np.float32)
|
||||
future_combined = future_id_codes * n_categories + future_cat_codes
|
||||
encoded_future = lookup[future_combined].astype(np.float32)
|
||||
|
||||
return encoded_past, encoded_future
|
||||
|
|
|
|||
Loading…
Reference in a new issue