mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
Align code formatting
Signed-off-by: xymli <xymli@tencent.com>
This commit is contained in:
parent
49edd9b876
commit
9844366902
2 changed files with 90 additions and 337 deletions
|
|
@ -28,35 +28,22 @@ class RoPE(nn.Module):
|
|||
|
||||
self.dim = dim
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (
|
||||
self.base
|
||||
** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
|
||||
)
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
|
||||
self.inv_freq: torch.Tensor # type hint for type checker
|
||||
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self, x: torch.Tensor, position_ids: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
self.inv_freq.to(x.device)
|
||||
inv_freq_expanded = (
|
||||
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
)
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 since bfloat16 loses precision on long contexts
|
||||
# See https://github.com/huggingface/transformers/pull/29285
|
||||
device_type = x.device.type
|
||||
device_type = (
|
||||
device_type
|
||||
if isinstance(device_type, str) and device_type != "mps"
|
||||
else "cpu"
|
||||
)
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (
|
||||
inv_freq_expanded.float() @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
|
@ -71,11 +58,7 @@ class RoPE(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
unsqueeze_dim: int = 1,
|
||||
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
|
|
@ -147,9 +130,7 @@ class FeedForward(nn.Module):
|
|||
|
||||
assert not config.is_gated_act, "gated activations are unsupported"
|
||||
self.mlp: nn.Module = MLP(config)
|
||||
self.layer_norm = Chronos2LayerNorm(
|
||||
config.d_model, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
|
@ -206,14 +187,10 @@ class MHA(nn.Module):
|
|||
attn_weights: [batch, n_heads, q_len, kv_len]
|
||||
"""
|
||||
# Compute attention weights (no scaling - this is the original Chronos-2 implementation)
|
||||
scores = torch.matmul(
|
||||
query_states, key_states.transpose(3, 2)
|
||||
) # "bnqd,bnkd->bnqk"
|
||||
scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk"
|
||||
scores += mask
|
||||
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
|
@ -271,9 +248,7 @@ class MHA(nn.Module):
|
|||
- attn_weights : Attention weights if output_attentions=True
|
||||
"""
|
||||
if self.use_rope:
|
||||
assert (
|
||||
position_ids is not None
|
||||
), "position_ids must be provided when self.use_rope=True"
|
||||
assert position_ids is not None, "position_ids must be provided when self.use_rope=True"
|
||||
|
||||
# Force eager attention if output_attentions is True (only eager returns weights)
|
||||
attn_implementation = self.config._attn_implementation
|
||||
|
|
@ -284,23 +259,11 @@ class MHA(nn.Module):
|
|||
|
||||
def shape(states: torch.Tensor) -> torch.Tensor:
|
||||
"""(batch, seq_len, inner_dim) -> (batch, n_heads, seq_len, kv_proj_dim)"""
|
||||
return rearrange(
|
||||
states,
|
||||
"b s (h d) -> b h s d",
|
||||
h=self.n_heads,
|
||||
s=seq_length,
|
||||
d=self.kv_proj_dim,
|
||||
)
|
||||
return rearrange(states, "b s (h d) -> b h s d", h=self.n_heads, s=seq_length, d=self.kv_proj_dim)
|
||||
|
||||
def unshape(states: torch.Tensor) -> torch.Tensor:
|
||||
"""(batch, n_heads, seq_len, kv_proj_dim) -> (batch, seq_len, inner_dim)"""
|
||||
return rearrange(
|
||||
states,
|
||||
"b h s d -> b s (h d)",
|
||||
h=self.n_heads,
|
||||
s=seq_length,
|
||||
d=self.kv_proj_dim,
|
||||
)
|
||||
return rearrange(states, "b h s d -> b s (h d)", h=self.n_heads, s=seq_length, d=self.kv_proj_dim)
|
||||
|
||||
# Construct query states
|
||||
query_states = shape(self.q(hidden_states))
|
||||
|
|
@ -315,36 +278,25 @@ class MHA(nn.Module):
|
|||
value_states = shape(self.v(hidden_states))
|
||||
if self.use_rope:
|
||||
cos, sin = self.rope_embed(value_states, position_ids)
|
||||
query_states, key_states = RoPE.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if attn_implementation == "sdpa":
|
||||
attn_output, attn_weights = self._sdpa_attention(
|
||||
query_states, key_states, value_states, mask
|
||||
)
|
||||
attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)
|
||||
else: # eager
|
||||
attn_output, attn_weights = self._eager_attention(
|
||||
query_states, key_states, value_states, mask
|
||||
)
|
||||
attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask)
|
||||
|
||||
# Project attention output
|
||||
attn_output = unshape(attn_output)
|
||||
attn_output = self.o(attn_output)
|
||||
|
||||
return AttentionOutput(
|
||||
hidden_states=attn_output,
|
||||
attn_weights=attn_weights if output_attentions else None,
|
||||
)
|
||||
return AttentionOutput(hidden_states=attn_output, attn_weights=attn_weights if output_attentions else None)
|
||||
|
||||
|
||||
class TimeSelfAttention(nn.Module):
|
||||
def __init__(self, config: Chronos2CoreConfig):
|
||||
super().__init__()
|
||||
self.self_attention = MHA(config, use_rope=True)
|
||||
self.layer_norm = Chronos2LayerNorm(
|
||||
config.d_model, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(
|
||||
|
|
@ -356,25 +308,18 @@ class TimeSelfAttention(nn.Module):
|
|||
) -> AttentionOutput:
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
attention_output: AttentionOutput = self.self_attention(
|
||||
normed_hidden_states,
|
||||
position_ids=position_ids,
|
||||
mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
normed_hidden_states, position_ids=position_ids, mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = hidden_states + self.dropout(attention_output[0])
|
||||
|
||||
return AttentionOutput(
|
||||
hidden_states=hidden_states, attn_weights=attention_output.attn_weights
|
||||
)
|
||||
return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights)
|
||||
|
||||
|
||||
class TimeCrossAttention(nn.Module):
|
||||
def __init__(self, config: Chronos2CoreConfig):
|
||||
super().__init__()
|
||||
self.cross_attention = MHA(config, use_rope=False)
|
||||
self.layer_norm = Chronos2LayerNorm(
|
||||
config.d_model, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(
|
||||
|
|
@ -393,9 +338,7 @@ class TimeCrossAttention(nn.Module):
|
|||
)
|
||||
hidden_states = hidden_states + self.dropout(attention_output[0])
|
||||
|
||||
return AttentionOutput(
|
||||
hidden_states=hidden_states, attn_weights=attention_output.attn_weights
|
||||
)
|
||||
return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights)
|
||||
|
||||
|
||||
class GroupSelfAttention(nn.Module):
|
||||
|
|
@ -405,35 +348,20 @@ class GroupSelfAttention(nn.Module):
|
|||
super().__init__()
|
||||
# we don't use RoPE here because there's no natural ordering along the batch axis
|
||||
self.self_attention = MHA(config, use_rope=False)
|
||||
self.layer_norm = Chronos2LayerNorm(
|
||||
config.d_model, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | tuple,
|
||||
output_attentions: bool = False,
|
||||
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | tuple, output_attentions: bool = False
|
||||
) -> AttentionOutput:
|
||||
# flip time and batch axes because attention operates along dim=-2
|
||||
hidden_states = rearrange(hidden_states, "batch time d -> time batch d")
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
|
||||
(
|
||||
flast_group_time_mask,
|
||||
group_start_index,
|
||||
counts,
|
||||
max_group_len,
|
||||
group_num,
|
||||
) = attention_mask
|
||||
flast_group_time_mask, group_start_index, counts, max_group_len, group_num = attention_mask
|
||||
fast_normed_hidden_states = torch.zeros(
|
||||
(
|
||||
group_num * normed_hidden_states.size(0),
|
||||
max_group_len,
|
||||
normed_hidden_states.size(2),
|
||||
),
|
||||
(group_num * normed_hidden_states.size(0), max_group_len, normed_hidden_states.size(2)),
|
||||
dtype=normed_hidden_states.dtype,
|
||||
device=normed_hidden_states.device,
|
||||
)
|
||||
|
|
@ -441,17 +369,11 @@ class GroupSelfAttention(nn.Module):
|
|||
start_index = group_start_index[i].item()
|
||||
group_len = counts[i].item()
|
||||
end_index = start_index + group_len
|
||||
fast_normed_hidden_states[
|
||||
i
|
||||
* normed_hidden_states.size(0) : (i + 1)
|
||||
* normed_hidden_states.size(0),
|
||||
:group_len,
|
||||
:,
|
||||
] = normed_hidden_states[:, start_index:end_index, :]
|
||||
fast_normed_hidden_states[i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :] = normed_hidden_states[
|
||||
:, start_index:end_index, :
|
||||
]
|
||||
fast_attention_output: AttentionOutput = self.self_attention(
|
||||
fast_normed_hidden_states,
|
||||
mask=flast_group_time_mask,
|
||||
output_attentions=output_attentions,
|
||||
fast_normed_hidden_states, mask=flast_group_time_mask, output_attentions=output_attentions
|
||||
)
|
||||
attn_weights = fast_attention_output.attn_weights
|
||||
attention_output = torch.empty_like(normed_hidden_states)
|
||||
|
|
@ -459,20 +381,12 @@ class GroupSelfAttention(nn.Module):
|
|||
start_index = group_start_index[i].item()
|
||||
group_len = counts[i].item()
|
||||
end_index = start_index + group_len
|
||||
attention_output[:, start_index:end_index, :] = fast_attention_output[
|
||||
0
|
||||
][
|
||||
i
|
||||
* normed_hidden_states.size(0) : (i + 1)
|
||||
* normed_hidden_states.size(0),
|
||||
:group_len,
|
||||
:,
|
||||
attention_output[:, start_index:end_index, :] = fast_attention_output[0][
|
||||
i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :
|
||||
]
|
||||
else:
|
||||
attention_output: AttentionOutput = self.self_attention(
|
||||
normed_hidden_states,
|
||||
mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
normed_hidden_states, mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
attn_weights = attention_output.attn_weights
|
||||
attention_output = attention_output[0]
|
||||
|
|
|
|||
|
|
@ -66,9 +66,7 @@ class Chronos2EncoderBlock(nn.Module):
|
|||
|
||||
# apply group attention
|
||||
group_self_attn_outputs: AttentionOutput = self.layer[1](
|
||||
hidden_states,
|
||||
attention_mask=group_time_mask,
|
||||
output_attentions=output_attentions,
|
||||
hidden_states, attention_mask=group_time_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = group_self_attn_outputs[0]
|
||||
|
||||
|
|
@ -94,21 +92,15 @@ class Chronos2Encoder(nn.Module):
|
|||
super().__init__()
|
||||
assert not config.is_decoder
|
||||
|
||||
self.block = nn.ModuleList(
|
||||
[Chronos2EncoderBlock(config) for i in range(config.num_layers)]
|
||||
)
|
||||
self.final_layer_norm = Chronos2LayerNorm(
|
||||
config.d_model, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.block = nn.ModuleList([Chronos2EncoderBlock(config) for i in range(config.num_layers)])
|
||||
self.final_layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
@staticmethod
|
||||
def _expand_and_invert_time_attention_mask(
|
||||
attention_mask: torch.Tensor, floating_type: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
attention_mask.ndim == 2
|
||||
), "attention_mask must have shape (batch, seq_len)"
|
||||
assert attention_mask.ndim == 2, "attention_mask must have shape (batch, seq_len)"
|
||||
|
||||
# Add new dims for attention heads and q_len
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
|
@ -120,9 +112,7 @@ class Chronos2Encoder(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _construct_and_invert_group_time_mask(
|
||||
group_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
floating_type: torch.dtype,
|
||||
group_ids: torch.Tensor, attention_mask: torch.Tensor, floating_type: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
# construct group_mask (batch, batch) from group ids
|
||||
# a cell is True if both row and col had the same group id
|
||||
|
|
@ -146,9 +136,7 @@ class Chronos2Encoder(nn.Module):
|
|||
max_group_len = counts.max().item()
|
||||
group_num = unique_groups.size(0)
|
||||
fast_group_time_mask = torch.full(
|
||||
(group_num * group_time_mask.size(0), 1, max_group_len, max_group_len),
|
||||
torch.finfo(floating_type).min,
|
||||
device=group_time_mask.device,
|
||||
(group_num * group_time_mask.size(0), 1, max_group_len, max_group_len), torch.finfo(floating_type).min, device=group_time_mask.device
|
||||
)
|
||||
group_start_index = counts.cumsum(dim=0) - counts
|
||||
for i in range(group_num):
|
||||
|
|
@ -156,18 +144,9 @@ class Chronos2Encoder(nn.Module):
|
|||
start_index = group_start_index[i].item()
|
||||
end_index = start_index + group_len
|
||||
fast_group_time_mask[
|
||||
i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0),
|
||||
:,
|
||||
:group_len,
|
||||
:group_len,
|
||||
] = group_time_mask[:, :, start_index:end_index, start_index:end_index]
|
||||
return (
|
||||
fast_group_time_mask,
|
||||
group_start_index,
|
||||
counts,
|
||||
max_group_len,
|
||||
group_num,
|
||||
)
|
||||
i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0), :, :group_len, :group_len
|
||||
] = group_time_mask[:, :, start_index : end_index, start_index : end_index]
|
||||
return fast_group_time_mask, group_start_index, counts, max_group_len, group_num
|
||||
|
||||
return group_time_mask
|
||||
|
||||
|
|
@ -183,45 +162,22 @@ class Chronos2Encoder(nn.Module):
|
|||
batch_size, seq_length = inputs_embeds.size()[:-1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
0, seq_length, dtype=torch.long, device=inputs_embeds.device
|
||||
).unsqueeze(0)
|
||||
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=inputs_embeds.device).unsqueeze(0)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
batch_size,
|
||||
seq_length,
|
||||
device=inputs_embeds.device,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
|
||||
|
||||
# make the time attention mask broadcastable to attention scores (batch, n_heads, q_len, kv_len) and invert
|
||||
extended_attention_mask = self._expand_and_invert_time_attention_mask(
|
||||
attention_mask, inputs_embeds.dtype
|
||||
)
|
||||
extended_attention_mask = self._expand_and_invert_time_attention_mask(attention_mask, inputs_embeds.dtype)
|
||||
|
||||
# construct group time mask
|
||||
if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
|
||||
(
|
||||
flast_group_time_mask,
|
||||
group_start_index,
|
||||
counts,
|
||||
max_group_len,
|
||||
group_num,
|
||||
) = self._construct_and_invert_group_time_mask(
|
||||
flast_group_time_mask, group_start_index, counts, max_group_len, group_num = self._construct_and_invert_group_time_mask(
|
||||
group_ids, attention_mask, inputs_embeds.dtype
|
||||
)
|
||||
group_time_mask = (
|
||||
flast_group_time_mask,
|
||||
group_start_index,
|
||||
counts,
|
||||
max_group_len,
|
||||
group_num,
|
||||
)
|
||||
group_time_mask = (flast_group_time_mask, group_start_index, counts, max_group_len, group_num)
|
||||
else:
|
||||
group_time_mask = self._construct_and_invert_group_time_mask(
|
||||
group_ids, attention_mask, inputs_embeds.dtype
|
||||
)
|
||||
group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)
|
||||
|
||||
all_time_self_attentions: tuple[torch.Tensor, ...] = ()
|
||||
all_group_self_attentions: tuple[torch.Tensor, ...] = ()
|
||||
|
|
@ -243,14 +199,8 @@ class Chronos2Encoder(nn.Module):
|
|||
assert layer_outputs.time_self_attn_weights is not None
|
||||
assert layer_outputs.group_self_attn_weights is not None
|
||||
|
||||
all_time_self_attentions = (
|
||||
*all_time_self_attentions,
|
||||
layer_outputs.time_self_attn_weights,
|
||||
)
|
||||
all_group_self_attentions = (
|
||||
*all_group_self_attentions,
|
||||
layer_outputs.group_self_attn_weights,
|
||||
)
|
||||
all_time_self_attentions = (*all_time_self_attentions, layer_outputs.time_self_attn_weights)
|
||||
all_group_self_attentions = (*all_group_self_attentions, layer_outputs.group_self_attn_weights)
|
||||
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
|
@ -288,10 +238,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
)
|
||||
self.chronos_config = Chronos2ForecastingConfig(**config.chronos_config)
|
||||
|
||||
assert (
|
||||
self.chronos_config.input_patch_size
|
||||
== self.chronos_config.output_patch_size
|
||||
), (
|
||||
assert self.chronos_config.input_patch_size == self.chronos_config.output_patch_size, (
|
||||
"input_patch_size and output_patch_size sizes must be equal, "
|
||||
f"but found {self.chronos_config.input_patch_size} and {self.chronos_config.output_patch_size}"
|
||||
)
|
||||
|
|
@ -315,8 +262,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
|
||||
# patching layer
|
||||
self.patch = Patch(
|
||||
patch_size=self.chronos_config.input_patch_size,
|
||||
patch_stride=self.chronos_config.input_patch_stride,
|
||||
patch_size=self.chronos_config.input_patch_size, patch_stride=self.chronos_config.input_patch_stride
|
||||
)
|
||||
|
||||
# instance normalization, also referred to as "scaling" in Chronos and GluonTS
|
||||
|
|
@ -352,14 +298,10 @@ class Chronos2Model(PreTrainedModel):
|
|||
# Mesh TensorFlow FF initialization
|
||||
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
|
||||
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
|
||||
module.wi.weight.data.normal_(
|
||||
mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
|
||||
)
|
||||
module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.wi, "bias") and module.wi.bias is not None:
|
||||
module.wi.bias.data.zero_()
|
||||
module.wo.weight.data.normal_(
|
||||
mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
|
||||
)
|
||||
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
|
||||
if hasattr(module.wo, "bias") and module.wo.bias is not None:
|
||||
module.wo.bias.data.zero_()
|
||||
elif isinstance(module, MHA):
|
||||
|
|
@ -368,14 +310,10 @@ class Chronos2Model(PreTrainedModel):
|
|||
d_model = self.config.d_model
|
||||
kv_proj_dim = self.config.d_kv
|
||||
n_heads = self.config.num_heads
|
||||
module.q.weight.data.normal_(
|
||||
mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5)
|
||||
)
|
||||
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
|
||||
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
||||
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
||||
module.o.weight.data.normal_(
|
||||
mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5)
|
||||
)
|
||||
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
|
||||
elif isinstance(module, (Chronos2Model)):
|
||||
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
||||
elif isinstance(module, ResidualBlock):
|
||||
|
|
@ -383,29 +321,20 @@ class Chronos2Model(PreTrainedModel):
|
|||
mean=0.0,
|
||||
std=factor * (module.hidden_layer.weight.size(-1) ** -0.5),
|
||||
)
|
||||
if (
|
||||
hasattr(module.hidden_layer, "bias")
|
||||
and module.hidden_layer.bias is not None
|
||||
):
|
||||
if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
|
||||
module.hidden_layer.bias.data.zero_()
|
||||
|
||||
module.residual_layer.weight.data.normal_(
|
||||
mean=0.0,
|
||||
std=factor * (module.residual_layer.weight.size(-1) ** -0.5),
|
||||
)
|
||||
if (
|
||||
hasattr(module.residual_layer, "bias")
|
||||
and module.residual_layer.bias is not None
|
||||
):
|
||||
if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
|
||||
module.residual_layer.bias.data.zero_()
|
||||
|
||||
module.output_layer.weight.data.normal_(
|
||||
mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5)
|
||||
)
|
||||
if (
|
||||
hasattr(module.output_layer, "bias")
|
||||
and module.output_layer.bias is not None
|
||||
):
|
||||
if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
|
||||
module.output_layer.bias.data.zero_()
|
||||
|
||||
def _validate_input(
|
||||
|
|
@ -421,18 +350,11 @@ class Chronos2Model(PreTrainedModel):
|
|||
):
|
||||
output_patch_size = self.chronos_config.output_patch_size
|
||||
if context.ndim != 2:
|
||||
raise ValueError(
|
||||
f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}"
|
||||
)
|
||||
raise ValueError(f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}")
|
||||
if context_mask is not None and context_mask.shape != context.shape:
|
||||
raise ValueError(
|
||||
f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}"
|
||||
)
|
||||
raise ValueError(f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}")
|
||||
if future_covariates is not None:
|
||||
if (
|
||||
future_covariates.shape[0] != context.shape[0]
|
||||
or future_covariates.ndim != 2
|
||||
):
|
||||
if future_covariates.shape[0] != context.shape[0] or future_covariates.ndim != 2:
|
||||
raise ValueError(
|
||||
f"future_covariates must have shape (batch_size={context.shape[0]}, future_length), found: {tuple(future_covariates.shape)}"
|
||||
)
|
||||
|
|
@ -441,27 +363,20 @@ class Chronos2Model(PreTrainedModel):
|
|||
f"{num_output_patches=} must be large enough to accommodate the length of future_covariates, "
|
||||
f"found: {future_covariates.shape[-1]} > {num_output_patches} * {output_patch_size}"
|
||||
)
|
||||
if (
|
||||
future_target is not None
|
||||
and future_target.shape != future_covariates.shape
|
||||
):
|
||||
if future_target is not None and future_target.shape != future_covariates.shape:
|
||||
raise ValueError(
|
||||
f"future_target must have the same shape as future_covariates, found: {tuple(future_target.shape)} and {tuple(future_covariates.shape)}"
|
||||
)
|
||||
if future_covariates_mask is not None:
|
||||
if future_covariates is None:
|
||||
raise ValueError(
|
||||
"future_covariates must be provided if future_covariates_mask is provided"
|
||||
)
|
||||
raise ValueError("future_covariates must be provided if future_covariates_mask is provided")
|
||||
if future_covariates_mask.shape != future_covariates.shape:
|
||||
raise ValueError(
|
||||
f"future_covariates_mask must have the same shape as future_covariates, "
|
||||
f"found: {tuple(future_covariates_mask.shape)} and {tuple(future_covariates.shape)}"
|
||||
)
|
||||
if group_ids is not None and group_ids.shape != (context.shape[0],):
|
||||
raise ValueError(
|
||||
f"group_ids must have shape (batch_size,), found: {tuple(group_ids.shape)}"
|
||||
)
|
||||
raise ValueError(f"group_ids must have shape (batch_size,), found: {tuple(group_ids.shape)}")
|
||||
if future_target is not None:
|
||||
if future_target.shape[0] != context.shape[0] or future_target.ndim != 2:
|
||||
raise ValueError(
|
||||
|
|
@ -474,9 +389,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
)
|
||||
if future_target_mask is not None:
|
||||
if future_target is None:
|
||||
raise ValueError(
|
||||
"future_target must be provided if future_target_mask is provided"
|
||||
)
|
||||
raise ValueError("future_target must be provided if future_target_mask is provided")
|
||||
if future_target_mask.shape != future_target.shape:
|
||||
raise ValueError(
|
||||
f"future_target_mask must have the same shape as future_target, found: {tuple(future_target_mask.shape)} and {tuple(future_target.shape)}"
|
||||
|
|
@ -515,12 +428,8 @@ class Chronos2Model(PreTrainedModel):
|
|||
|
||||
# context time encoding: every observation is assigned a sequential time index,
|
||||
# scaled by model's context length = [-C, -(C-1), ..., -1] / context_length
|
||||
final_context_length = (
|
||||
num_context_patches * self.chronos_config.input_patch_size
|
||||
)
|
||||
context_time_enc = torch.arange(
|
||||
start=-final_context_length, end=0, device=self.device, dtype=torch.float32
|
||||
)
|
||||
final_context_length = num_context_patches * self.chronos_config.input_patch_size
|
||||
context_time_enc = torch.arange(start=-final_context_length, end=0, device=self.device, dtype=torch.float32)
|
||||
context_time_enc = (
|
||||
repeat(
|
||||
context_time_enc,
|
||||
|
|
@ -534,9 +443,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
)
|
||||
|
||||
# concat time encoding, context and mask along the last (feature) dim
|
||||
patched_context = torch.cat(
|
||||
[context_time_enc, patched_context, patched_mask], dim=-1
|
||||
)
|
||||
patched_context = torch.cat([context_time_enc, patched_context, patched_mask], dim=-1)
|
||||
|
||||
return patched_context, attention_mask, loc_scale
|
||||
|
||||
|
|
@ -555,15 +462,9 @@ class Chronos2Model(PreTrainedModel):
|
|||
future_covariates = future_covariates.to(self.dtype)
|
||||
|
||||
if future_covariates_mask is None:
|
||||
future_covariates_mask = (
|
||||
torch.isnan(future_covariates)
|
||||
.logical_not()
|
||||
.to(future_covariates.dtype)
|
||||
)
|
||||
future_covariates_mask = torch.isnan(future_covariates).logical_not().to(future_covariates.dtype)
|
||||
|
||||
future_covariates = torch.where(
|
||||
future_covariates_mask > 0.0, future_covariates, 0.0
|
||||
)
|
||||
future_covariates = torch.where(future_covariates_mask > 0.0, future_covariates, 0.0)
|
||||
|
||||
if torch.isnan(future_covariates).any():
|
||||
raise ValueError(
|
||||
|
|
@ -575,58 +476,33 @@ class Chronos2Model(PreTrainedModel):
|
|||
if num_output_patches * output_patch_size > future_covariates.shape[-1]:
|
||||
padding_shape = (
|
||||
*future_covariates.shape[:-1],
|
||||
num_output_patches * output_patch_size
|
||||
- future_covariates.shape[-1],
|
||||
num_output_patches * output_patch_size - future_covariates.shape[-1],
|
||||
)
|
||||
future_covariates = torch.cat(
|
||||
[
|
||||
future_covariates,
|
||||
torch.zeros(padding_shape).to(future_covariates),
|
||||
],
|
||||
dim=-1,
|
||||
[future_covariates, torch.zeros(padding_shape).to(future_covariates)], dim=-1
|
||||
)
|
||||
future_covariates_mask = torch.cat(
|
||||
[
|
||||
future_covariates_mask,
|
||||
torch.zeros(padding_shape).to(future_covariates_mask),
|
||||
],
|
||||
dim=-1,
|
||||
[future_covariates_mask, torch.zeros(padding_shape).to(future_covariates_mask)], dim=-1
|
||||
)
|
||||
|
||||
patched_future_covariates = rearrange(
|
||||
future_covariates,
|
||||
"b (n p) -> b n p",
|
||||
n=num_output_patches,
|
||||
p=output_patch_size,
|
||||
future_covariates, "b (n p) -> b n p", n=num_output_patches, p=output_patch_size
|
||||
)
|
||||
patched_future_covariates_mask = rearrange(
|
||||
future_covariates_mask,
|
||||
"b (n p) -> b n p",
|
||||
n=num_output_patches,
|
||||
p=output_patch_size,
|
||||
future_covariates_mask, "b (n p) -> b n p", n=num_output_patches, p=output_patch_size
|
||||
)
|
||||
else:
|
||||
patched_future_covariates = torch.zeros(
|
||||
batch_size,
|
||||
num_output_patches,
|
||||
output_patch_size,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
batch_size, num_output_patches, output_patch_size, device=self.device, dtype=self.dtype
|
||||
)
|
||||
patched_future_covariates_mask = torch.zeros(
|
||||
batch_size,
|
||||
num_output_patches,
|
||||
output_patch_size,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
batch_size, num_output_patches, output_patch_size, device=self.device, dtype=self.dtype
|
||||
)
|
||||
|
||||
# future time encoding: every future timestep is assigned a sequential time index,
|
||||
# scaled by model's context length = [0, 1, ..., h-1] / context_length
|
||||
final_future_length = num_output_patches * output_patch_size
|
||||
future_time_enc = torch.arange(
|
||||
start=0, end=final_future_length, device=self.device, dtype=torch.float32
|
||||
)
|
||||
future_time_enc = torch.arange(start=0, end=final_future_length, device=self.device, dtype=torch.float32)
|
||||
future_time_enc = (
|
||||
repeat(
|
||||
future_time_enc,
|
||||
|
|
@ -640,12 +516,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
)
|
||||
|
||||
patched_future = torch.cat(
|
||||
[
|
||||
future_time_enc,
|
||||
patched_future_covariates,
|
||||
patched_future_covariates_mask,
|
||||
],
|
||||
dim=-1,
|
||||
[future_time_enc, patched_future_covariates, patched_future_covariates_mask], dim=-1
|
||||
)
|
||||
|
||||
return patched_future, patched_future_covariates_mask
|
||||
|
|
@ -661,10 +532,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
) -> torch.Tensor:
|
||||
batch_size = future_target.shape[0]
|
||||
output_patch_size = self.chronos_config.output_patch_size
|
||||
assert (
|
||||
quantile_preds.shape[0] == batch_size
|
||||
and quantile_preds.shape[-1] >= future_target.shape[-1]
|
||||
)
|
||||
assert quantile_preds.shape[0] == batch_size and quantile_preds.shape[-1] >= future_target.shape[-1]
|
||||
|
||||
# normalize target and mask
|
||||
future_target, _ = self.instance_norm(future_target, loc_scale)
|
||||
|
|
@ -679,22 +547,15 @@ class Chronos2Model(PreTrainedModel):
|
|||
|
||||
# pad target and target_mask if they are shorter than model's prediction
|
||||
if quantile_preds.shape[-1] > future_target.shape[-1]:
|
||||
padding_shape = (
|
||||
*future_target.shape[:-1],
|
||||
quantile_preds.shape[-1] - future_target.shape[-1],
|
||||
)
|
||||
future_target = torch.cat(
|
||||
[future_target, torch.zeros(padding_shape).to(future_target)], dim=-1
|
||||
)
|
||||
padding_shape = (*future_target.shape[:-1], quantile_preds.shape[-1] - future_target.shape[-1])
|
||||
future_target = torch.cat([future_target, torch.zeros(padding_shape).to(future_target)], dim=-1)
|
||||
future_target_mask = torch.cat(
|
||||
[future_target_mask, torch.zeros(padding_shape).to(future_target_mask)],
|
||||
dim=-1,
|
||||
[future_target_mask, torch.zeros(padding_shape).to(future_target_mask)], dim=-1
|
||||
)
|
||||
|
||||
quantiles = rearrange(self.quantiles, "num_quantiles -> 1 num_quantiles 1")
|
||||
quantile_loss = 2 * torch.abs(
|
||||
(future_target - quantile_preds)
|
||||
* ((future_target <= quantile_preds).float() - quantiles)
|
||||
(future_target - quantile_preds) * ((future_target <= quantile_preds).float() - quantiles)
|
||||
)
|
||||
inv_future_covariate_mask = 1 - rearrange(
|
||||
patched_future_covariates_mask,
|
||||
|
|
@ -744,17 +605,11 @@ class Chronos2Model(PreTrainedModel):
|
|||
input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
|
||||
# append [REG] special token embedding, if needed
|
||||
if self.chronos_config.use_reg_token:
|
||||
reg_input_ids = torch.full(
|
||||
(batch_size, 1), self.config.reg_token_id, device=input_embeds.device
|
||||
)
|
||||
reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
|
||||
reg_embeds = self.shared(reg_input_ids)
|
||||
input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
attention_mask.to(self.dtype),
|
||||
torch.ones_like(reg_input_ids).to(self.dtype),
|
||||
],
|
||||
dim=-1,
|
||||
[attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
|
||||
)
|
||||
|
||||
patched_future, patched_future_covariates_mask = self._prepare_patched_future(
|
||||
|
|
@ -764,9 +619,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
num_output_patches=num_output_patches,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
future_attention_mask = torch.ones(
|
||||
batch_size, num_output_patches, dtype=self.dtype, device=self.device
|
||||
)
|
||||
future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
|
||||
|
||||
# get future embeddings of shape (batch, num_output_patches, d_model)
|
||||
future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
|
||||
|
|
@ -785,12 +638,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
group_ids=group_ids,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
return (
|
||||
encoder_outputs,
|
||||
loc_scale,
|
||||
patched_future_covariates_mask,
|
||||
num_context_patches,
|
||||
)
|
||||
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -871,12 +719,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
- enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
|
||||
"""
|
||||
batch_size = context.shape[0]
|
||||
(
|
||||
encoder_outputs,
|
||||
loc_scale,
|
||||
patched_future_covariates_mask,
|
||||
num_context_patches,
|
||||
) = self.encode(
|
||||
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
|
||||
context=context,
|
||||
context_mask=context_mask,
|
||||
group_ids=group_ids,
|
||||
|
|
@ -888,11 +731,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states: torch.Tensor = encoder_outputs[0]
|
||||
assert hidden_states.shape == (
|
||||
batch_size,
|
||||
num_context_patches + 1 + num_output_patches,
|
||||
self.model_dim,
|
||||
)
|
||||
assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
|
||||
|
||||
# slice the last num_output_patches hidden states to be input into the output_patch_embedding
|
||||
forecast_embeds = hidden_states[:, -num_output_patches:]
|
||||
|
|
|
|||
Loading…
Reference in a new issue