diff --git a/pyproject.toml b/pyproject.toml index f6de45c..d25661d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ dependencies = [ "torch~=2.1", # package was tested on 2.2 "transformers~=4.31", "accelerate", - "mlx~=0.9.0" + "mlx~=0.10.0" ] [project.optional-dependencies] diff --git a/src/chronos_mlx/t5.py b/src/chronos_mlx/t5.py index 2ec7659..95ea4e9 100644 --- a/src/chronos_mlx/t5.py +++ b/src/chronos_mlx/t5.py @@ -264,11 +264,6 @@ class OutputHead(nn.Module): def apply_top_p(logits: mx.array, top_p: float, min_tokens_to_keep=1): assert min_tokens_to_keep <= logits.shape[-1] - logits_dtype = logits.dtype - # FIXME: The following is needed because mlx doesn't have the cumsum - # kernel for bfloat16. Once that is supported natively, this casting - # should be removed. @abdulfatir - logits = logits.astype(mx.float32) sorted_indices = mx.argsort(logits, axis=-1) sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1) cumulative_probs = mx.softmax(sorted_logits, axis=-1).cumsum(axis=-1, reverse=True) @@ -276,9 +271,7 @@ def apply_top_p(logits: mx.array, top_p: float, min_tokens_to_keep=1): sorted_indices_to_remove[..., -min_tokens_to_keep:] = False masked_sorted_logits = mx.where(sorted_indices_to_remove, -mx.inf, sorted_logits) unsorted_indices = mx.argsort(sorted_indices, axis=-1) - return mx.take_along_axis(masked_sorted_logits, unsorted_indices, axis=-1).astype( - logits_dtype - ) + return mx.take_along_axis(masked_sorted_logits, unsorted_indices, axis=-1) def sample(logits, top_k=1, top_p=1.0, temperature=1.0): diff --git a/test/test_chronos_mlx.py b/test/test_chronos_mlx.py index 4f718a5..603a18a 100644 --- a/test/test_chronos_mlx.py +++ b/test/test_chronos_mlx.py @@ -70,6 +70,17 @@ def test_pipeline_predict(dtype: str): ) validate_array(samples, (1, 7, 65)) + # test non-default inference params + samples = pipeline.predict( + context, + num_samples=12, + prediction_length=3, + top_p=0.7, + top_k=32, + temperature=0.9, + ) + validate_array(samples, (4, 12, 3)) + @pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) def test_pipeline_embed(dtype: str):