- import os
+ import logging
from dataclasses import dataclass
- from typing import Dict, Optional
+ from typing import Optional, Dict, Any

import torch
from torch import nn, Tensor

from tevatron.reranker.arguments import ModelArguments

- import logging
-
logger = logging.getLogger(__name__)


@@ -27,6 +25,7 @@ class RerankerModel(nn.Module):
    def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
        super().__init__()
+        logger.info(f"Initializing RerankerModel with train_batch_size: {train_batch_size}")
        self.config = hf_model.config
        self.hf_model = hf_model
        self.train_batch_size = train_batch_size
@@ -36,31 +35,26 @@ def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
            'target_label',
            torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device)
        )
-        for name, param in self.hf_model.named_parameters():
-            # for some reason, ds zero 3 left some weights empty
-            if 'modules_to_save' in name and param.numel() == 0:
-                logger.warning(f'parameter {name}, shape {param.shape} is empty')
-                param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
-                logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))
-
-    def forward(self, pair: Dict[str, Tensor] = None):
-        ranker_logits = self.hf_model(**pair, return_dict=True).logits
-        if self.train_batch_size:
-            grouped_logits = ranker_logits.view(self.train_batch_size, -1)
-            loss = self.cross_entropy(grouped_logits, self.target_label)
-            return RerankerOutput(
-                loss=loss,
-                scores=ranker_logits
-            )
+        logger.info(f"RerankerModel initialized with config: {self.config}")
+
+    def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs):
+        logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
+        outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
+
+        if labels is not None:
+            loss = self.cross_entropy(outputs.logits.view(self.train_batch_size, -1), labels)
+            logger.debug(f"Computed loss: {loss.item()}")
+        else:
+            loss = None
+            logger.debug("No labels provided, skipping loss computation")

        return RerankerOutput(
-            loss=None,
-            scores=ranker_logits
+            loss=loss,
+            scores=outputs.logits
        )

-    def gradient_checkpointing_enable(self, **kwargs):
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None):
        return False
-        # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)

    @classmethod
    def build(
@@ -69,21 +63,27 @@ def build(
            train_args: TrainingArguments,
            **hf_kwargs,
    ):
+        logger.info(f"Building RerankerModel with args: {model_args}")
        base_model = cls.TRANSFORMER_CLS.from_pretrained(
            model_args.model_name_or_path,
            **hf_kwargs,
        )
        if base_model.config.pad_token_id is None:
            base_model.config.pad_token_id = 0
+            logger.info("Set pad_token_id to 0")
+
        if model_args.lora or model_args.lora_name_or_path:
+            logger.info("Applying LoRA")
            if train_args.gradient_checkpointing:
                base_model.enable_input_require_grads()
            if model_args.lora_name_or_path:
+                logger.info(f"Loading LoRA from {model_args.lora_name_or_path}")
                lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
                lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path,
                                                       torch_dtype=torch.bfloat16,
                                                       attn_implementation="flash_attention_2")
            else:
+                logger.info("Initializing new LoRA")
                lora_config = LoraConfig(
                    base_model_name_or_path=model_args.model_name_or_path,
                    task_type=TaskType.SEQ_CLS,
@@ -99,6 +99,7 @@ def build(
                train_batch_size=train_args.per_device_train_batch_size,
            )
        else:
+            logger.info("Building model without LoRA")
            model = cls(
                hf_model=base_model,
                train_batch_size=train_args.per_device_train_batch_size,
@@ -110,23 +111,28 @@ def load(cls,
             model_name_or_path: str,
             lora_name_or_path: str = None,
             **hf_kwargs):
+        logger.info(f"Loading RerankerModel from {model_name_or_path}")
        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs,
                                                         torch_dtype=torch.bfloat16,
                                                         attn_implementation="flash_attention_2")
        if base_model.config.pad_token_id is None:
            base_model.config.pad_token_id = 0
+            logger.info("Set pad_token_id to 0")
        if lora_name_or_path:
+            logger.info(f"Loading LoRA from {lora_name_or_path}")
            lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
            lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
            lora_model = lora_model.merge_and_unload()
            model = cls(
                hf_model=lora_model,
            )
        else:
+            logger.info("Loading model without LoRA")
            model = cls(
                hf_model=base_model,
            )
        return model

    def save(self, output_dir: str):
-        self.hf_model.save_pretrained(output_dir)
+        logger.info(f"Saving model to {output_dir}")
+        self.hf_model.save_pretrained(output_dir)
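For reference, a minimal usage sketch of the refactored forward signature. This is not part of the diff: the import path for RerankerModel, the base checkpoint, and the batch construction are illustrative assumptions.

# Illustrative sketch only. Assumes RerankerModel is importable from
# tevatron.reranker.modeling and that a sequence-classification checkpoint
# with num_labels=1 is used; names below are placeholders, not from this PR.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tevatron.reranker.modeling import RerankerModel

model_name = "bert-base-uncased"  # hypothetical base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

# One query group with four candidate passages; train_batch_size must equal the
# number of query groups so forward() can reshape logits to (batch, group_size).
reranker = RerankerModel(hf_model=hf_model, train_batch_size=1)

query = "what is dense retrieval"
passages = [
    "Dense retrieval encodes queries and passages into vectors.",  # positive
    "The weather is sunny today.",
    "A recipe for sourdough bread.",
    "A list of capital cities.",
]
inputs = tokenizer([query] * len(passages), passages,
                   padding=True, truncation=True, return_tensors="pt")

# labels holds the index of the positive passage within each group of candidates.
out = reranker(input_ids=inputs["input_ids"],
               attention_mask=inputs["attention_mask"],
               labels=torch.zeros(1, dtype=torch.long))
print(out.loss, out.scores.shape)  # scores are raw relevance logits, shape (4, 1)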