Remove float32 casting for cumsum (#53)

*Description of changes:* This PR removes casting to `fp32` for the
`cumsum` operation and upgrades `mlx` to `~=0.10.0`, which adds `bf16`
support for `cumsum`.

Related: https://github.com/ml-explore/mlx/issues/959

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>
This commit is contained in:
Abdul Fatir 2024-04-12 20:41:12 +02:00 committed by GitHub
parent 159ea36f7f
commit 5242d986f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 13 additions and 9 deletions

View file

@ -7,7 +7,7 @@ dependencies = [
"torch~=2.1", # package was tested on 2.2
"transformers~=4.31",
"accelerate",
"mlx~=0.9.0"
"mlx~=0.10.0"
]
[project.optional-dependencies]

View file

@ -264,11 +264,6 @@ class OutputHead(nn.Module):
def apply_top_p(logits: mx.array, top_p: float, min_tokens_to_keep=1):
assert min_tokens_to_keep <= logits.shape[-1]
logits_dtype = logits.dtype
# FIXME: The following is needed because mlx doesn't have the cumsum
# kernel for bfloat16. Once that is supported natively, this casting
# should be removed. @abdulfatir
logits = logits.astype(mx.float32)
sorted_indices = mx.argsort(logits, axis=-1)
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
cumulative_probs = mx.softmax(sorted_logits, axis=-1).cumsum(axis=-1, reverse=True)
@ -276,9 +271,7 @@ def apply_top_p(logits: mx.array, top_p: float, min_tokens_to_keep=1):
sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
masked_sorted_logits = mx.where(sorted_indices_to_remove, -mx.inf, sorted_logits)
unsorted_indices = mx.argsort(sorted_indices, axis=-1)
return mx.take_along_axis(masked_sorted_logits, unsorted_indices, axis=-1).astype(
logits_dtype
)
return mx.take_along_axis(masked_sorted_logits, unsorted_indices, axis=-1)
def sample(logits, top_k=1, top_p=1.0, temperature=1.0):

View file

@ -70,6 +70,17 @@ def test_pipeline_predict(dtype: str):
)
validate_array(samples, (1, 7, 65))
# test non-default inference params
samples = pipeline.predict(
context,
num_samples=12,
prediction_length=3,
top_p=0.7,
top_k=32,
temperature=0.9,
)
validate_array(samples, (4, 12, 3))
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
def test_pipeline_embed(dtype: str):