mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Remove float32 casting for cumsum (#53)
*Description of changes:* This PR removes casting to `fp32` for the `cumsum` operation and upgrades `mlx` to `~=0.10.0` which adds `bf16` support for `cumsum`. Related: https://github.com/ml-explore/mlx/issues/959 By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>
This commit is contained in:
parent
159ea36f7f
commit
5242d986f4
3 changed files with 13 additions and 9 deletions
|
|
@ -7,7 +7,7 @@ dependencies = [
|
|||
"torch~=2.1", # package was tested on 2.2
|
||||
"transformers~=4.31",
|
||||
"accelerate",
|
||||
"mlx~=0.9.0"
|
||||
"mlx~=0.10.0"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
|
|
@ -264,11 +264,6 @@ class OutputHead(nn.Module):
|
|||
|
||||
def apply_top_p(logits: mx.array, top_p: float, min_tokens_to_keep=1):
|
||||
assert min_tokens_to_keep <= logits.shape[-1]
|
||||
logits_dtype = logits.dtype
|
||||
# FIXME: The following is needed because mlx doesn't have the cumsum
|
||||
# kernel for bfloat16. Once that is supported natively, this casting
|
||||
# should be removed. @abdulfatir
|
||||
logits = logits.astype(mx.float32)
|
||||
sorted_indices = mx.argsort(logits, axis=-1)
|
||||
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
|
||||
cumulative_probs = mx.softmax(sorted_logits, axis=-1).cumsum(axis=-1, reverse=True)
|
||||
|
|
@ -276,9 +271,7 @@ def apply_top_p(logits: mx.array, top_p: float, min_tokens_to_keep=1):
|
|||
sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
|
||||
masked_sorted_logits = mx.where(sorted_indices_to_remove, -mx.inf, sorted_logits)
|
||||
unsorted_indices = mx.argsort(sorted_indices, axis=-1)
|
||||
return mx.take_along_axis(masked_sorted_logits, unsorted_indices, axis=-1).astype(
|
||||
logits_dtype
|
||||
)
|
||||
return mx.take_along_axis(masked_sorted_logits, unsorted_indices, axis=-1)
|
||||
|
||||
|
||||
def sample(logits, top_k=1, top_p=1.0, temperature=1.0):
|
||||
|
|
|
|||
|
|
@ -70,6 +70,17 @@ def test_pipeline_predict(dtype: str):
|
|||
)
|
||||
validate_array(samples, (1, 7, 65))
|
||||
|
||||
# test non-default inference params
|
||||
samples = pipeline.predict(
|
||||
context,
|
||||
num_samples=12,
|
||||
prediction_length=3,
|
||||
top_p=0.7,
|
||||
top_k=32,
|
||||
temperature=0.9,
|
||||
)
|
||||
validate_array(samples, (4, 12, 3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
|
||||
def test_pipeline_embed(dtype: str):
|
||||
|
|
|
|||
Loading…
Reference in a new issue