mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Merge 9844366902 into 32111085d8
This commit is contained in:
commit
10e2dee6d5
2 changed files with 65 additions and 8 deletions
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import os
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
|
@ -351,19 +352,50 @@ class GroupSelfAttention(nn.Module):
|
|||
self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(
    self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | tuple, output_attentions: bool = False
) -> AttentionOutput:
    """Apply pre-norm self-attention along the batch axis and add the residual.

    Args:
        hidden_states: Tensor laid out as ``(batch, time, d)``; it is
            transposed to ``(time, batch, d)`` so attention runs over the
            batch (group) axis.
        attention_mask: Either a mask tensor handed directly to
            ``self.self_attention`` (regular path), or the packed 5-tuple
            ``(fast_group_time_mask, group_start_index, counts,
            max_group_len, group_num)`` produced when
            ``CHRONOS2_USE_FAST_GROUP_ATTENTION=1`` (fast path).
        output_attentions: Forwarded unchanged to ``self.self_attention``.

    Returns:
        ``AttentionOutput`` whose ``hidden_states`` is back in
        ``(batch, time, d)`` layout. On the fast path ``attn_weights`` are
        the weights of the packed per-group attention call, not of a
        full-batch attention.

    Notes:
        The fast path assumes group ``i`` occupies the contiguous batch
        range ``[group_start_index[i], group_start_index[i] + counts[i])``
        and that the groups together cover every batch index; positions not
        covered by any group would be left uninitialised in the output.
        NOTE(review): confirm the caller guarantees sorted, contiguous
        group ids.
    """
    # flip time and batch axes because attention operates along dim=-2
    hidden_states = rearrange(hidden_states, "batch time d -> time batch d")
    normed_hidden_states = self.layer_norm(hidden_states)

    # Dispatch on what was actually passed instead of re-reading the
    # environment variable here: the packed tuple is only ever built when
    # the fast path is enabled, and this keeps the layer usable with a
    # plain tensor mask regardless of the process environment.
    if isinstance(attention_mask, tuple):
        fast_group_time_mask, group_start_index, counts, max_group_len, group_num = attention_mask
        num_steps, _, d_model = normed_hidden_states.shape

        # One device-to-host transfer per tensor instead of two `.item()`
        # synchronisations per group in each of the two loops below.
        group_starts = group_start_index.tolist()
        group_lens = counts.tolist()

        # Pack every group into its own block of `num_steps` rows, padded
        # with zeros up to the largest group size.
        packed_hidden_states = normed_hidden_states.new_zeros((group_num * num_steps, max_group_len, d_model))
        for i, (start, group_len) in enumerate(zip(group_starts, group_lens)):
            rows = slice(i * num_steps, (i + 1) * num_steps)
            packed_hidden_states[rows, :group_len, :] = normed_hidden_states[:, start : start + group_len, :]

        fast_attention_output: AttentionOutput = self.self_attention(
            packed_hidden_states, mask=fast_group_time_mask, output_attentions=output_attentions
        )
        attn_weights = fast_attention_output.attn_weights
        packed_output = fast_attention_output[0]

        # Scatter the per-group results back to their original batch
        # positions, dropping the padding columns.
        attention_output = torch.empty_like(normed_hidden_states)
        for i, (start, group_len) in enumerate(zip(group_starts, group_lens)):
            rows = slice(i * num_steps, (i + 1) * num_steps)
            attention_output[:, start : start + group_len, :] = packed_output[rows, :group_len, :]
    else:
        full_attention_output: AttentionOutput = self.self_attention(
            normed_hidden_states, mask=attention_mask, output_attentions=output_attentions
        )
        attn_weights = full_attention_output.attn_weights
        attention_output = full_attention_output[0]

    hidden_states = hidden_states + self.dropout(attention_output)
    # flip time and batch axes back to their original position
    hidden_states = rearrange(hidden_states, "time batch d -> batch time d")

    return AttentionOutput(hidden_states=hidden_states, attn_weights=attn_weights)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>
|
||||
|
||||
import copy
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
|
|
@ -51,7 +52,7 @@ class Chronos2EncoderBlock(nn.Module):
|
|||
*,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
group_time_mask: torch.Tensor,
|
||||
group_time_mask: torch.Tensor | tuple,
|
||||
output_attentions: bool = False,
|
||||
) -> Chronos2EncoderBlockOutput:
|
||||
# apply time attention
|
||||
|
|
@ -129,6 +130,24 @@ class Chronos2Encoder(nn.Module):
|
|||
group_time_mask = rearrange(group_time_mask, "q b t -> t 1 q b")
|
||||
group_time_mask = (1.0 - group_time_mask) * torch.finfo(floating_type).min
|
||||
|
||||
if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
|
||||
# repack group_time_mask into one block per group, padded to the largest group size, so self-attention does not waste computation on cross-group positions
|
||||
unique_groups, counts = torch.unique(group_ids, return_counts=True)
|
||||
max_group_len = counts.max().item()
|
||||
group_num = unique_groups.size(0)
|
||||
fast_group_time_mask = torch.full(
|
||||
(group_num * group_time_mask.size(0), 1, max_group_len, max_group_len), torch.finfo(floating_type).min, device=group_time_mask.device
|
||||
)
|
||||
group_start_index = counts.cumsum(dim=0) - counts
|
||||
for i in range(group_num):
|
||||
group_len = counts[i].item()
|
||||
start_index = group_start_index[i].item()
|
||||
end_index = start_index + group_len
|
||||
fast_group_time_mask[
|
||||
i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0), :, :group_len, :group_len
|
||||
] = group_time_mask[:, :, start_index : end_index, start_index : end_index]
|
||||
return fast_group_time_mask, group_start_index, counts, max_group_len, group_num
|
||||
|
||||
return group_time_mask
|
||||
|
||||
def forward(
|
||||
|
|
@ -152,7 +171,13 @@ class Chronos2Encoder(nn.Module):
|
|||
extended_attention_mask = self._expand_and_invert_time_attention_mask(attention_mask, inputs_embeds.dtype)
|
||||
|
||||
# construct group time mask
|
||||
group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)
|
||||
if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
|
||||
flast_group_time_mask, group_start_index, counts, max_group_len, group_num = self._construct_and_invert_group_time_mask(
|
||||
group_ids, attention_mask, inputs_embeds.dtype
|
||||
)
|
||||
group_time_mask = (flast_group_time_mask, group_start_index, counts, max_group_len, group_num)
|
||||
else:
|
||||
group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)
|
||||
|
||||
all_time_self_attentions: tuple[torch.Tensor, ...] = ()
|
||||
all_group_self_attentions: tuple[torch.Tensor, ...] = ()
|
||||
|
|
|
|||
Loading…
Reference in a new issue