This commit is contained in:
li-jinpeng 2026-04-21 13:24:10 -04:00 committed by GitHub
commit 10e2dee6d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 65 additions and 8 deletions

View file

@ -5,6 +5,7 @@
from dataclasses import dataclass
import os
import torch
from einops import rearrange
from torch import nn
@ -351,19 +352,50 @@ class GroupSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | tuple, output_attentions: bool = False
) -> AttentionOutput:
# flip time and batch axes because attention operates along dim=-2
hidden_states = rearrange(hidden_states, "batch time d -> time batch d")
normed_hidden_states = self.layer_norm(hidden_states)
attention_output: AttentionOutput = self.self_attention(
normed_hidden_states, mask=attention_mask, output_attentions=output_attentions
)
hidden_states = hidden_states + self.dropout(attention_output[0])
if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
flast_group_time_mask, group_start_index, counts, max_group_len, group_num = attention_mask
fast_normed_hidden_states = torch.zeros(
(group_num * normed_hidden_states.size(0), max_group_len, normed_hidden_states.size(2)),
dtype=normed_hidden_states.dtype,
device=normed_hidden_states.device,
)
for i in range(group_num):
start_index = group_start_index[i].item()
group_len = counts[i].item()
end_index = start_index + group_len
fast_normed_hidden_states[i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :] = normed_hidden_states[
:, start_index:end_index, :
]
fast_attention_output: AttentionOutput = self.self_attention(
fast_normed_hidden_states, mask=flast_group_time_mask, output_attentions=output_attentions
)
attn_weights = fast_attention_output.attn_weights
attention_output = torch.empty_like(normed_hidden_states)
for i in range(group_num):
start_index = group_start_index[i].item()
group_len = counts[i].item()
end_index = start_index + group_len
attention_output[:, start_index:end_index, :] = fast_attention_output[0][
i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :
]
else:
attention_output: AttentionOutput = self.self_attention(
normed_hidden_states, mask=attention_mask, output_attentions=output_attentions
)
attn_weights = attention_output.attn_weights
attention_output = attention_output[0]
hidden_states = hidden_states + self.dropout(attention_output)
# flip time and batch axes back to their original position
hidden_states = rearrange(hidden_states, "time batch d -> batch time d")
return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights)
return AttentionOutput(hidden_states=hidden_states, attn_weights=attn_weights)
class ResidualBlock(nn.Module):

View file

@ -4,6 +4,7 @@
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>
import copy
import os
from dataclasses import dataclass
from typing import cast
@ -51,7 +52,7 @@ class Chronos2EncoderBlock(nn.Module):
*,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
group_time_mask: torch.Tensor,
group_time_mask: torch.Tensor | tuple,
output_attentions: bool = False,
) -> Chronos2EncoderBlockOutput:
# apply time attention
@ -129,6 +130,24 @@ class Chronos2Encoder(nn.Module):
group_time_mask = rearrange(group_time_mask, "q b t -> t 1 q b")
group_time_mask = (1.0 - group_time_mask) * torch.finfo(floating_type).min
if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
# pad group_time_mask to max group size for less useless computation in self-attention
unique_groups, counts = torch.unique(group_ids, return_counts=True)
max_group_len = counts.max().item()
group_num = unique_groups.size(0)
fast_group_time_mask = torch.full(
(group_num * group_time_mask.size(0), 1, max_group_len, max_group_len), torch.finfo(floating_type).min, device=group_time_mask.device
)
group_start_index = counts.cumsum(dim=0) - counts
for i in range(group_num):
group_len = counts[i].item()
start_index = group_start_index[i].item()
end_index = start_index + group_len
fast_group_time_mask[
i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0), :, :group_len, :group_len
] = group_time_mask[:, :, start_index : end_index, start_index : end_index]
return fast_group_time_mask, group_start_index, counts, max_group_len, group_num
return group_time_mask
def forward(
@ -152,7 +171,13 @@ class Chronos2Encoder(nn.Module):
extended_attention_mask = self._expand_and_invert_time_attention_mask(attention_mask, inputs_embeds.dtype)
# construct group time mask
group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)
if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
flast_group_time_mask, group_start_index, counts, max_group_len, group_num = self._construct_and_invert_group_time_mask(
group_ids, attention_mask, inputs_embeds.dtype
)
group_time_mask = (flast_group_time_mask, group_start_index, counts, max_group_len, group_num)
else:
group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)
all_time_self_attentions: tuple[torch.Tensor, ...] = ()
all_group_self_attentions: tuple[torch.Tensor, ...] = ()