@@ -547,6 +547,74 @@ def _compute_loss(
 
         return loss
 
+    def encode(
+        self,
+        context: torch.Tensor,
+        context_mask: torch.Tensor | None = None,
+        group_ids: torch.Tensor | None = None,
+        future_covariates: torch.Tensor | None = None,
+        future_covariates_mask: torch.Tensor | None = None,
+        num_output_patches: int = 1,
+        future_target: torch.Tensor | None = None,
+        future_target_mask: torch.Tensor | None = None,
+        output_attentions: bool = False,
+    ):
+        self._validate_input(
+            context=context,
+            context_mask=context_mask,
+            future_covariates=future_covariates,
+            future_covariates_mask=future_covariates_mask,
+            group_ids=group_ids,
+            num_output_patches=num_output_patches,
+            future_target=future_target,
+            future_target_mask=future_target_mask,
+        )
+
+        batch_size = context.shape[0]
+        patched_context, attention_mask, loc_scale = self._prepare_patched_context(
+            context=context, context_mask=context_mask
+        )
+        num_context_patches = attention_mask.shape[-1]
+
+        # get input embeddings of shape (batch, num_context_patches, d_model)
+        input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
+        # append [REG] special token embedding, if needed
+        if self.chronos_config.use_reg_token:
+            reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
+            reg_embeds = self.shared(reg_input_ids)
+            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
+            attention_mask = torch.cat(
+                [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
+            )
+
+        patched_future, patched_future_covariates_mask = self._prepare_patched_future(
+            future_covariates=future_covariates,
+            future_covariates_mask=future_covariates_mask,
+            loc_scale=loc_scale,
+            num_output_patches=num_output_patches,
+            batch_size=batch_size,
+        )
+        future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
+
+        # get future embeddings of shape (batch, num_output_patches, d_model)
+        future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
+
+        # concatenate context and future embeddings and masks
+        input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
+        attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
+
+        if group_ids is None:
+            # by default, each time series is treated independently, i.e., no mixing across the batch
+            group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
+
+        encoder_outputs: Chronos2EncoderOutput = self.encoder(
+            attention_mask=attention_mask,
+            inputs_embeds=input_embeds,
+            group_ids=group_ids,
+            output_attentions=output_attentions,
+        )
+        return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
+
     def forward(
         self,
         context: torch.Tensor,
@@ -625,63 +693,19 @@ def forward(
         - enc_time_self_attn_weights: Time self attention weights, if output_attentions=True
         - enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
         """
-
-        self._validate_input(
+        batch_size = context.shape[0]
+        encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
             context=context,
             context_mask=context_mask,
+            group_ids=group_ids,
             future_covariates=future_covariates,
             future_covariates_mask=future_covariates_mask,
-            group_ids=group_ids,
             num_output_patches=num_output_patches,
             future_target=future_target,
             future_target_mask=future_target_mask,
-        )
-
-        batch_size = context.shape[0]
-        patched_context, attention_mask, loc_scale = self._prepare_patched_context(
-            context=context, context_mask=context_mask
-        )
-        num_context_patches = attention_mask.shape[-1]
-
-        # get input embeddings of shape (batch, num_context_patches, d_model)
-        input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
-        # append [REG] special token embedding, if needed
-        if self.chronos_config.use_reg_token:
-            reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
-            reg_embeds = self.shared(reg_input_ids)
-            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
-            attention_mask = torch.cat(
-                [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
-            )
-
-        patched_future, patched_future_covariates_mask = self._prepare_patched_future(
-            future_covariates=future_covariates,
-            future_covariates_mask=future_covariates_mask,
-            loc_scale=loc_scale,
-            num_output_patches=num_output_patches,
-            batch_size=batch_size,
-        )
-        future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
-
-        # get future embeddings of shape (batch, num_output_patches, d_model)
-        future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
-
-        # concatenate context and future embeddings and masks
-        input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
-        attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
-
-        if group_ids is None:
-            # by default, each time series is treated independently, i.e., no mixing across the batch
-            group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
-
-        encoder_outputs: Chronos2EncoderOutput = self.encoder(
-            attention_mask=attention_mask,
-            inputs_embeds=input_embeds,
-            group_ids=group_ids,
             output_attentions=output_attentions,
         )
         hidden_states: torch.Tensor = encoder_outputs[0]
-
         assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
 
         # slice the last num_output_patches hidden states to be input into the output_patch_embedding
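For orientation, here is a minimal sketch of how the extracted encode() method could be driven on its own. This is not part of the commit: `model` is assumed to be an instance of the Chronos2 model class shown above, and the tensor shapes, mask values, and group assignments are illustrative assumptions.

import torch

# Hypothetical inputs: a batch of 4 series with 512 observed timesteps each.
batch_size, context_length = 4, 512
context = torch.randn(batch_size, context_length)
context_mask = torch.ones_like(context, dtype=torch.bool)  # every timestep observed

# Giving two series the same group id (0, 0) lets the encoder mix information
# between them; by default each series gets its own id, i.e., no mixing.
group_ids = torch.tensor([0, 0, 1, 2])

# encode() validates inputs, patches and normalizes the context, embeds the
# context and future patches (plus the [REG] token), and runs the encoder.
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = model.encode(
    context=context,
    context_mask=context_mask,
    group_ids=group_ids,
    num_output_patches=2,
)

hidden_states = encoder_outputs[0]
# Per the assert in forward(), hidden_states has shape
# (batch_size, num_context_patches + 1 + num_output_patches, d_model),
# where the +1 accounts for the appended [REG] token embedding.

Since forward() now only consumes the tuple returned by encode(), other inference or decoding paths can reuse the validation, patching, and encoder logic by calling encode() directly rather than duplicating it.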