Align code formatting

Signed-off-by: xymli <xymli@tencent.com>
This commit is contained in:
xymli 2025-12-22 21:05:50 +08:00
parent 49edd9b876
commit 9844366902
2 changed files with 90 additions and 337 deletions

View file

@ -28,35 +28,22 @@ class RoPE(nn.Module):
self.dim = dim
self.base = base
inv_freq = 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
)
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.inv_freq: torch.Tensor # type hint for type checker
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
def forward(
self, x: torch.Tensor, position_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (
inv_freq_expanded.float() @ position_ids_expanded.float()
).transpose(1, 2)
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
@ -71,11 +58,7 @@ class RoPE(nn.Module):
@staticmethod
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim: int = 1,
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> tuple[torch.Tensor, torch.Tensor]:
"""Applies Rotary Position Embedding to the query and key tensors.
@ -147,9 +130,7 @@ class FeedForward(nn.Module):
assert not config.is_gated_act, "gated activations are unsupported"
self.mlp: nn.Module = MLP(config)
self.layer_norm = Chronos2LayerNorm(
config.d_model, eps=config.layer_norm_epsilon
)
self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -206,14 +187,10 @@ class MHA(nn.Module):
attn_weights: [batch, n_heads, q_len, kv_len]
"""
# Compute attention weights (no scaling - this is the original Chronos-2 implementation)
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # "bnqd,bnkd->bnqk"
scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk"
scores += mask
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
return attn_output, attn_weights
@ -271,9 +248,7 @@ class MHA(nn.Module):
- attn_weights : Attention weights if output_attentions=True
"""
if self.use_rope:
assert (
position_ids is not None
), "position_ids must be provided when self.use_rope=True"
assert position_ids is not None, "position_ids must be provided when self.use_rope=True"
# Force eager attention if output_attentions is True (only eager returns weights)
attn_implementation = self.config._attn_implementation
@ -284,23 +259,11 @@ class MHA(nn.Module):
def shape(states: torch.Tensor) -> torch.Tensor:
"""(batch, seq_len, inner_dim) -> (batch, n_heads, seq_len, kv_proj_dim)"""
return rearrange(
states,
"b s (h d) -> b h s d",
h=self.n_heads,
s=seq_length,
d=self.kv_proj_dim,
)
return rearrange(states, "b s (h d) -> b h s d", h=self.n_heads, s=seq_length, d=self.kv_proj_dim)
def unshape(states: torch.Tensor) -> torch.Tensor:
"""(batch, n_heads, seq_len, kv_proj_dim) -> (batch, seq_len, inner_dim)"""
return rearrange(
states,
"b h s d -> b s (h d)",
h=self.n_heads,
s=seq_length,
d=self.kv_proj_dim,
)
return rearrange(states, "b h s d -> b s (h d)", h=self.n_heads, s=seq_length, d=self.kv_proj_dim)
# Construct query states
query_states = shape(self.q(hidden_states))
@ -315,36 +278,25 @@ class MHA(nn.Module):
value_states = shape(self.v(hidden_states))
if self.use_rope:
cos, sin = self.rope_embed(value_states, position_ids)
query_states, key_states = RoPE.apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
if attn_implementation == "sdpa":
attn_output, attn_weights = self._sdpa_attention(
query_states, key_states, value_states, mask
)
attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)
else: # eager
attn_output, attn_weights = self._eager_attention(
query_states, key_states, value_states, mask
)
attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask)
# Project attention output
attn_output = unshape(attn_output)
attn_output = self.o(attn_output)
return AttentionOutput(
hidden_states=attn_output,
attn_weights=attn_weights if output_attentions else None,
)
return AttentionOutput(hidden_states=attn_output, attn_weights=attn_weights if output_attentions else None)
class TimeSelfAttention(nn.Module):
def __init__(self, config: Chronos2CoreConfig):
super().__init__()
self.self_attention = MHA(config, use_rope=True)
self.layer_norm = Chronos2LayerNorm(
config.d_model, eps=config.layer_norm_epsilon
)
self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
@ -356,25 +308,18 @@ class TimeSelfAttention(nn.Module):
) -> AttentionOutput:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output: AttentionOutput = self.self_attention(
normed_hidden_states,
position_ids=position_ids,
mask=attention_mask,
output_attentions=output_attentions,
normed_hidden_states, position_ids=position_ids, mask=attention_mask, output_attentions=output_attentions
)
hidden_states = hidden_states + self.dropout(attention_output[0])
return AttentionOutput(
hidden_states=hidden_states, attn_weights=attention_output.attn_weights
)
return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights)
class TimeCrossAttention(nn.Module):
def __init__(self, config: Chronos2CoreConfig):
super().__init__()
self.cross_attention = MHA(config, use_rope=False)
self.layer_norm = Chronos2LayerNorm(
config.d_model, eps=config.layer_norm_epsilon
)
self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
@ -393,9 +338,7 @@ class TimeCrossAttention(nn.Module):
)
hidden_states = hidden_states + self.dropout(attention_output[0])
return AttentionOutput(
hidden_states=hidden_states, attn_weights=attention_output.attn_weights
)
return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights)
class GroupSelfAttention(nn.Module):
@ -405,35 +348,20 @@ class GroupSelfAttention(nn.Module):
super().__init__()
# we don't use RoPE here because there's no natural ordering along the batch axis
self.self_attention = MHA(config, use_rope=False)
self.layer_norm = Chronos2LayerNorm(
config.d_model, eps=config.layer_norm_epsilon
)
self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | tuple,
output_attentions: bool = False,
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | tuple, output_attentions: bool = False
) -> AttentionOutput:
# flip time and batch axes because attention operates along dim=-2
hidden_states = rearrange(hidden_states, "batch time d -> time batch d")
normed_hidden_states = self.layer_norm(hidden_states)
if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
(
flast_group_time_mask,
group_start_index,
counts,
max_group_len,
group_num,
) = attention_mask
flast_group_time_mask, group_start_index, counts, max_group_len, group_num = attention_mask
fast_normed_hidden_states = torch.zeros(
(
group_num * normed_hidden_states.size(0),
max_group_len,
normed_hidden_states.size(2),
),
(group_num * normed_hidden_states.size(0), max_group_len, normed_hidden_states.size(2)),
dtype=normed_hidden_states.dtype,
device=normed_hidden_states.device,
)
@ -441,17 +369,11 @@ class GroupSelfAttention(nn.Module):
start_index = group_start_index[i].item()
group_len = counts[i].item()
end_index = start_index + group_len
fast_normed_hidden_states[
i
* normed_hidden_states.size(0) : (i + 1)
* normed_hidden_states.size(0),
:group_len,
:,
] = normed_hidden_states[:, start_index:end_index, :]
fast_normed_hidden_states[i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :] = normed_hidden_states[
:, start_index:end_index, :
]
fast_attention_output: AttentionOutput = self.self_attention(
fast_normed_hidden_states,
mask=flast_group_time_mask,
output_attentions=output_attentions,
fast_normed_hidden_states, mask=flast_group_time_mask, output_attentions=output_attentions
)
attn_weights = fast_attention_output.attn_weights
attention_output = torch.empty_like(normed_hidden_states)
@ -459,20 +381,12 @@ class GroupSelfAttention(nn.Module):
start_index = group_start_index[i].item()
group_len = counts[i].item()
end_index = start_index + group_len
attention_output[:, start_index:end_index, :] = fast_attention_output[
0
][
i
* normed_hidden_states.size(0) : (i + 1)
* normed_hidden_states.size(0),
:group_len,
:,
attention_output[:, start_index:end_index, :] = fast_attention_output[0][
i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :
]
else:
attention_output: AttentionOutput = self.self_attention(
normed_hidden_states,
mask=attention_mask,
output_attentions=output_attentions,
normed_hidden_states, mask=attention_mask, output_attentions=output_attentions
)
attn_weights = attention_output.attn_weights
attention_output = attention_output[0]

View file

@ -66,9 +66,7 @@ class Chronos2EncoderBlock(nn.Module):
# apply group attention
group_self_attn_outputs: AttentionOutput = self.layer[1](
hidden_states,
attention_mask=group_time_mask,
output_attentions=output_attentions,
hidden_states, attention_mask=group_time_mask, output_attentions=output_attentions
)
hidden_states = group_self_attn_outputs[0]
@ -94,21 +92,15 @@ class Chronos2Encoder(nn.Module):
super().__init__()
assert not config.is_decoder
self.block = nn.ModuleList(
[Chronos2EncoderBlock(config) for i in range(config.num_layers)]
)
self.final_layer_norm = Chronos2LayerNorm(
config.d_model, eps=config.layer_norm_epsilon
)
self.block = nn.ModuleList([Chronos2EncoderBlock(config) for i in range(config.num_layers)])
self.final_layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
@staticmethod
def _expand_and_invert_time_attention_mask(
attention_mask: torch.Tensor, floating_type: torch.dtype
) -> torch.Tensor:
assert (
attention_mask.ndim == 2
), "attention_mask must have shape (batch, seq_len)"
assert attention_mask.ndim == 2, "attention_mask must have shape (batch, seq_len)"
# Add new dims for attention heads and q_len
attention_mask = attention_mask[:, None, None, :]
@ -120,9 +112,7 @@ class Chronos2Encoder(nn.Module):
@staticmethod
def _construct_and_invert_group_time_mask(
group_ids: torch.Tensor,
attention_mask: torch.Tensor,
floating_type: torch.dtype,
group_ids: torch.Tensor, attention_mask: torch.Tensor, floating_type: torch.dtype
) -> torch.Tensor:
# construct group_mask (batch, batch) from group ids
# a cell is True if both row and col had the same group id
@ -146,9 +136,7 @@ class Chronos2Encoder(nn.Module):
max_group_len = counts.max().item()
group_num = unique_groups.size(0)
fast_group_time_mask = torch.full(
(group_num * group_time_mask.size(0), 1, max_group_len, max_group_len),
torch.finfo(floating_type).min,
device=group_time_mask.device,
(group_num * group_time_mask.size(0), 1, max_group_len, max_group_len), torch.finfo(floating_type).min, device=group_time_mask.device
)
group_start_index = counts.cumsum(dim=0) - counts
for i in range(group_num):
@ -156,18 +144,9 @@ class Chronos2Encoder(nn.Module):
start_index = group_start_index[i].item()
end_index = start_index + group_len
fast_group_time_mask[
i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0),
:,
:group_len,
:group_len,
] = group_time_mask[:, :, start_index:end_index, start_index:end_index]
return (
fast_group_time_mask,
group_start_index,
counts,
max_group_len,
group_num,
)
i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0), :, :group_len, :group_len
] = group_time_mask[:, :, start_index : end_index, start_index : end_index]
return fast_group_time_mask, group_start_index, counts, max_group_len, group_num
return group_time_mask
@ -183,45 +162,22 @@ class Chronos2Encoder(nn.Module):
batch_size, seq_length = inputs_embeds.size()[:-1]
if position_ids is None:
position_ids = torch.arange(
0, seq_length, dtype=torch.long, device=inputs_embeds.device
).unsqueeze(0)
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=inputs_embeds.device).unsqueeze(0)
if attention_mask is None:
attention_mask = torch.ones(
batch_size,
seq_length,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype,
)
attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
# make the time attention mask broadcastable to attention scores (batch, n_heads, q_len, kv_len) and invert
extended_attention_mask = self._expand_and_invert_time_attention_mask(
attention_mask, inputs_embeds.dtype
)
extended_attention_mask = self._expand_and_invert_time_attention_mask(attention_mask, inputs_embeds.dtype)
# construct group time mask
if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
(
flast_group_time_mask,
group_start_index,
counts,
max_group_len,
group_num,
) = self._construct_and_invert_group_time_mask(
flast_group_time_mask, group_start_index, counts, max_group_len, group_num = self._construct_and_invert_group_time_mask(
group_ids, attention_mask, inputs_embeds.dtype
)
group_time_mask = (
flast_group_time_mask,
group_start_index,
counts,
max_group_len,
group_num,
)
group_time_mask = (flast_group_time_mask, group_start_index, counts, max_group_len, group_num)
else:
group_time_mask = self._construct_and_invert_group_time_mask(
group_ids, attention_mask, inputs_embeds.dtype
)
group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)
all_time_self_attentions: tuple[torch.Tensor, ...] = ()
all_group_self_attentions: tuple[torch.Tensor, ...] = ()
@ -243,14 +199,8 @@ class Chronos2Encoder(nn.Module):
assert layer_outputs.time_self_attn_weights is not None
assert layer_outputs.group_self_attn_weights is not None
all_time_self_attentions = (
*all_time_self_attentions,
layer_outputs.time_self_attn_weights,
)
all_group_self_attentions = (
*all_group_self_attentions,
layer_outputs.group_self_attn_weights,
)
all_time_self_attentions = (*all_time_self_attentions, layer_outputs.time_self_attn_weights)
all_group_self_attentions = (*all_group_self_attentions, layer_outputs.group_self_attn_weights)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
@ -288,10 +238,7 @@ class Chronos2Model(PreTrainedModel):
)
self.chronos_config = Chronos2ForecastingConfig(**config.chronos_config)
assert (
self.chronos_config.input_patch_size
== self.chronos_config.output_patch_size
), (
assert self.chronos_config.input_patch_size == self.chronos_config.output_patch_size, (
"input_patch_size and output_patch_size sizes must be equal, "
f"but found {self.chronos_config.input_patch_size} and {self.chronos_config.output_patch_size}"
)
@ -315,8 +262,7 @@ class Chronos2Model(PreTrainedModel):
# patching layer
self.patch = Patch(
patch_size=self.chronos_config.input_patch_size,
patch_stride=self.chronos_config.input_patch_stride,
patch_size=self.chronos_config.input_patch_size, patch_stride=self.chronos_config.input_patch_stride
)
# instance normalization, also referred to as "scaling" in Chronos and GluonTS
@ -352,14 +298,10 @@ class Chronos2Model(PreTrainedModel):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
module.wi.weight.data.normal_(
mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
)
module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi, "bias") and module.wi.bias is not None:
module.wi.bias.data.zero_()
module.wo.weight.data.normal_(
mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
)
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
elif isinstance(module, MHA):
@ -368,14 +310,10 @@ class Chronos2Model(PreTrainedModel):
d_model = self.config.d_model
kv_proj_dim = self.config.d_kv
n_heads = self.config.num_heads
module.q.weight.data.normal_(
mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5)
)
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.o.weight.data.normal_(
mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5)
)
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
elif isinstance(module, (Chronos2Model)):
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, ResidualBlock):
@ -383,29 +321,20 @@ class Chronos2Model(PreTrainedModel):
mean=0.0,
std=factor * (module.hidden_layer.weight.size(-1) ** -0.5),
)
if (
hasattr(module.hidden_layer, "bias")
and module.hidden_layer.bias is not None
):
if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
module.hidden_layer.bias.data.zero_()
module.residual_layer.weight.data.normal_(
mean=0.0,
std=factor * (module.residual_layer.weight.size(-1) ** -0.5),
)
if (
hasattr(module.residual_layer, "bias")
and module.residual_layer.bias is not None
):
if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
module.residual_layer.bias.data.zero_()
module.output_layer.weight.data.normal_(
mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5)
)
if (
hasattr(module.output_layer, "bias")
and module.output_layer.bias is not None
):
if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
module.output_layer.bias.data.zero_()
def _validate_input(
@ -421,18 +350,11 @@ class Chronos2Model(PreTrainedModel):
):
output_patch_size = self.chronos_config.output_patch_size
if context.ndim != 2:
raise ValueError(
f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}"
)
raise ValueError(f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}")
if context_mask is not None and context_mask.shape != context.shape:
raise ValueError(
f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}"
)
raise ValueError(f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}")
if future_covariates is not None:
if (
future_covariates.shape[0] != context.shape[0]
or future_covariates.ndim != 2
):
if future_covariates.shape[0] != context.shape[0] or future_covariates.ndim != 2:
raise ValueError(
f"future_covariates must have shape (batch_size={context.shape[0]}, future_length), found: {tuple(future_covariates.shape)}"
)
@ -441,27 +363,20 @@ class Chronos2Model(PreTrainedModel):
f"{num_output_patches=} must be large enough to accommodate the length of future_covariates, "
f"found: {future_covariates.shape[-1]} > {num_output_patches} * {output_patch_size}"
)
if (
future_target is not None
and future_target.shape != future_covariates.shape
):
if future_target is not None and future_target.shape != future_covariates.shape:
raise ValueError(
f"future_target must have the same shape as future_covariates, found: {tuple(future_target.shape)} and {tuple(future_covariates.shape)}"
)
if future_covariates_mask is not None:
if future_covariates is None:
raise ValueError(
"future_covariates must be provided if future_covariates_mask is provided"
)
raise ValueError("future_covariates must be provided if future_covariates_mask is provided")
if future_covariates_mask.shape != future_covariates.shape:
raise ValueError(
f"future_covariates_mask must have the same shape as future_covariates, "
f"found: {tuple(future_covariates_mask.shape)} and {tuple(future_covariates.shape)}"
)
if group_ids is not None and group_ids.shape != (context.shape[0],):
raise ValueError(
f"group_ids must have shape (batch_size,), found: {tuple(group_ids.shape)}"
)
raise ValueError(f"group_ids must have shape (batch_size,), found: {tuple(group_ids.shape)}")
if future_target is not None:
if future_target.shape[0] != context.shape[0] or future_target.ndim != 2:
raise ValueError(
@ -474,9 +389,7 @@ class Chronos2Model(PreTrainedModel):
)
if future_target_mask is not None:
if future_target is None:
raise ValueError(
"future_target must be provided if future_target_mask is provided"
)
raise ValueError("future_target must be provided if future_target_mask is provided")
if future_target_mask.shape != future_target.shape:
raise ValueError(
f"future_target_mask must have the same shape as future_target, found: {tuple(future_target_mask.shape)} and {tuple(future_target.shape)}"
@ -515,12 +428,8 @@ class Chronos2Model(PreTrainedModel):
# context time encoding: every observation is assigned a sequential time index,
# scaled by model's context length = [-C, -(C-1), ..., -1] / context_length
final_context_length = (
num_context_patches * self.chronos_config.input_patch_size
)
context_time_enc = torch.arange(
start=-final_context_length, end=0, device=self.device, dtype=torch.float32
)
final_context_length = num_context_patches * self.chronos_config.input_patch_size
context_time_enc = torch.arange(start=-final_context_length, end=0, device=self.device, dtype=torch.float32)
context_time_enc = (
repeat(
context_time_enc,
@ -534,9 +443,7 @@ class Chronos2Model(PreTrainedModel):
)
# concat time encoding, context and mask along the last (feature) dim
patched_context = torch.cat(
[context_time_enc, patched_context, patched_mask], dim=-1
)
patched_context = torch.cat([context_time_enc, patched_context, patched_mask], dim=-1)
return patched_context, attention_mask, loc_scale
@ -555,15 +462,9 @@ class Chronos2Model(PreTrainedModel):
future_covariates = future_covariates.to(self.dtype)
if future_covariates_mask is None:
future_covariates_mask = (
torch.isnan(future_covariates)
.logical_not()
.to(future_covariates.dtype)
)
future_covariates_mask = torch.isnan(future_covariates).logical_not().to(future_covariates.dtype)
future_covariates = torch.where(
future_covariates_mask > 0.0, future_covariates, 0.0
)
future_covariates = torch.where(future_covariates_mask > 0.0, future_covariates, 0.0)
if torch.isnan(future_covariates).any():
raise ValueError(
@ -575,58 +476,33 @@ class Chronos2Model(PreTrainedModel):
if num_output_patches * output_patch_size > future_covariates.shape[-1]:
padding_shape = (
*future_covariates.shape[:-1],
num_output_patches * output_patch_size
- future_covariates.shape[-1],
num_output_patches * output_patch_size - future_covariates.shape[-1],
)
future_covariates = torch.cat(
[
future_covariates,
torch.zeros(padding_shape).to(future_covariates),
],
dim=-1,
[future_covariates, torch.zeros(padding_shape).to(future_covariates)], dim=-1
)
future_covariates_mask = torch.cat(
[
future_covariates_mask,
torch.zeros(padding_shape).to(future_covariates_mask),
],
dim=-1,
[future_covariates_mask, torch.zeros(padding_shape).to(future_covariates_mask)], dim=-1
)
patched_future_covariates = rearrange(
future_covariates,
"b (n p) -> b n p",
n=num_output_patches,
p=output_patch_size,
future_covariates, "b (n p) -> b n p", n=num_output_patches, p=output_patch_size
)
patched_future_covariates_mask = rearrange(
future_covariates_mask,
"b (n p) -> b n p",
n=num_output_patches,
p=output_patch_size,
future_covariates_mask, "b (n p) -> b n p", n=num_output_patches, p=output_patch_size
)
else:
patched_future_covariates = torch.zeros(
batch_size,
num_output_patches,
output_patch_size,
device=self.device,
dtype=self.dtype,
batch_size, num_output_patches, output_patch_size, device=self.device, dtype=self.dtype
)
patched_future_covariates_mask = torch.zeros(
batch_size,
num_output_patches,
output_patch_size,
device=self.device,
dtype=self.dtype,
batch_size, num_output_patches, output_patch_size, device=self.device, dtype=self.dtype
)
# future time encoding: every future timestep is assigned a sequential time index,
# scaled by model's context length = [0, 1, ..., h-1] / context_length
final_future_length = num_output_patches * output_patch_size
future_time_enc = torch.arange(
start=0, end=final_future_length, device=self.device, dtype=torch.float32
)
future_time_enc = torch.arange(start=0, end=final_future_length, device=self.device, dtype=torch.float32)
future_time_enc = (
repeat(
future_time_enc,
@ -640,12 +516,7 @@ class Chronos2Model(PreTrainedModel):
)
patched_future = torch.cat(
[
future_time_enc,
patched_future_covariates,
patched_future_covariates_mask,
],
dim=-1,
[future_time_enc, patched_future_covariates, patched_future_covariates_mask], dim=-1
)
return patched_future, patched_future_covariates_mask
@ -661,10 +532,7 @@ class Chronos2Model(PreTrainedModel):
) -> torch.Tensor:
batch_size = future_target.shape[0]
output_patch_size = self.chronos_config.output_patch_size
assert (
quantile_preds.shape[0] == batch_size
and quantile_preds.shape[-1] >= future_target.shape[-1]
)
assert quantile_preds.shape[0] == batch_size and quantile_preds.shape[-1] >= future_target.shape[-1]
# normalize target and mask
future_target, _ = self.instance_norm(future_target, loc_scale)
@ -679,22 +547,15 @@ class Chronos2Model(PreTrainedModel):
# pad target and target_mask if they are shorter than model's prediction
if quantile_preds.shape[-1] > future_target.shape[-1]:
padding_shape = (
*future_target.shape[:-1],
quantile_preds.shape[-1] - future_target.shape[-1],
)
future_target = torch.cat(
[future_target, torch.zeros(padding_shape).to(future_target)], dim=-1
)
padding_shape = (*future_target.shape[:-1], quantile_preds.shape[-1] - future_target.shape[-1])
future_target = torch.cat([future_target, torch.zeros(padding_shape).to(future_target)], dim=-1)
future_target_mask = torch.cat(
[future_target_mask, torch.zeros(padding_shape).to(future_target_mask)],
dim=-1,
[future_target_mask, torch.zeros(padding_shape).to(future_target_mask)], dim=-1
)
quantiles = rearrange(self.quantiles, "num_quantiles -> 1 num_quantiles 1")
quantile_loss = 2 * torch.abs(
(future_target - quantile_preds)
* ((future_target <= quantile_preds).float() - quantiles)
(future_target - quantile_preds) * ((future_target <= quantile_preds).float() - quantiles)
)
inv_future_covariate_mask = 1 - rearrange(
patched_future_covariates_mask,
@ -744,17 +605,11 @@ class Chronos2Model(PreTrainedModel):
input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
# append [REG] special token embedding, if needed
if self.chronos_config.use_reg_token:
reg_input_ids = torch.full(
(batch_size, 1), self.config.reg_token_id, device=input_embeds.device
)
reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
reg_embeds = self.shared(reg_input_ids)
input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
attention_mask = torch.cat(
[
attention_mask.to(self.dtype),
torch.ones_like(reg_input_ids).to(self.dtype),
],
dim=-1,
[attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
)
patched_future, patched_future_covariates_mask = self._prepare_patched_future(
@ -764,9 +619,7 @@ class Chronos2Model(PreTrainedModel):
num_output_patches=num_output_patches,
batch_size=batch_size,
)
future_attention_mask = torch.ones(
batch_size, num_output_patches, dtype=self.dtype, device=self.device
)
future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
# get future embeddings of shape (batch, num_output_patches, d_model)
future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
@ -785,12 +638,7 @@ class Chronos2Model(PreTrainedModel):
group_ids=group_ids,
output_attentions=output_attentions,
)
return (
encoder_outputs,
loc_scale,
patched_future_covariates_mask,
num_context_patches,
)
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
def forward(
self,
@ -871,12 +719,7 @@ class Chronos2Model(PreTrainedModel):
- enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
"""
batch_size = context.shape[0]
(
encoder_outputs,
loc_scale,
patched_future_covariates_mask,
num_context_patches,
) = self.encode(
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
context=context,
context_mask=context_mask,
group_ids=group_ids,
@ -888,11 +731,7 @@ class Chronos2Model(PreTrainedModel):
output_attentions=output_attentions,
)
hidden_states: torch.Tensor = encoder_outputs[0]
assert hidden_states.shape == (
batch_size,
num_context_patches + 1 + num_output_patches,
self.model_dim,
)
assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
# slice the last num_output_patches hidden states to be input into the output_patch_embedding
forecast_embeds = hidden_states[:, -num_output_patches:]