Commit a79b647

fix: forward method
1 parent ca4b04b commit a79b647

2 files changed: +42 −82 lines

src/tevatron/reranker/modeling.py

Lines changed: 31 additions & 25 deletions
@@ -1,6 +1,6 @@
-import os
+import logging
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Optional, Dict, Any
 
 import torch
 from torch import nn, Tensor
@@ -11,8 +11,6 @@
 
 from tevatron.reranker.arguments import ModelArguments
 
-import logging
-
 logger = logging.getLogger(__name__)
 
 
@@ -27,6 +25,7 @@ class RerankerModel(nn.Module):
 
     def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
         super().__init__()
+        logger.info(f"Initializing RerankerModel with train_batch_size: {train_batch_size}")
         self.config = hf_model.config
         self.hf_model = hf_model
         self.train_batch_size = train_batch_size
@@ -36,31 +35,26 @@ def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
                 'target_label',
                 torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device)
             )
-        for name, param in self.hf_model.named_parameters():
-            # for some reason, ds zero 3 left some weights empty
-            if 'modules_to_save' in name and param.numel() == 0:
-                logger.warning(f'parameter {name}, shape {param.shape} is empty')
-                param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
-                logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))
-
-    def forward(self, pair: Dict[str, Tensor] = None):
-        ranker_logits = self.hf_model(**pair, return_dict=True).logits
-        if self.train_batch_size:
-            grouped_logits = ranker_logits.view(self.train_batch_size, -1)
-            loss = self.cross_entropy(grouped_logits, self.target_label)
-            return RerankerOutput(
-                loss=loss,
-                scores=ranker_logits
-            )
+        logger.info(f"RerankerModel initialized with config: {self.config}")
+
+    def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs):
+        logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
+        outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
+
+        if labels is not None:
+            loss = self.cross_entropy(outputs.logits.view(self.train_batch_size, -1), labels)
+            logger.debug(f"Computed loss: {loss.item()}")
+        else:
+            loss = None
+            logger.debug("No labels provided, skipping loss computation")
 
         return RerankerOutput(
-            loss=None,
-            scores=ranker_logits
+            loss=loss,
+            scores=outputs.logits
        )
 
-    def gradient_checkpointing_enable(self, **kwargs):
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None):
         return False
-        # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
 
     @classmethod
     def build(
@@ -69,21 +63,27 @@ def build(
             train_args: TrainingArguments,
             **hf_kwargs,
     ):
+        logger.info(f"Building RerankerModel with args: {model_args}")
         base_model = cls.TRANSFORMER_CLS.from_pretrained(
             model_args.model_name_or_path,
             **hf_kwargs,
         )
         if base_model.config.pad_token_id is None:
             base_model.config.pad_token_id = 0
+            logger.info("Set pad_token_id to 0")
+
         if model_args.lora or model_args.lora_name_or_path:
+            logger.info("Applying LoRA")
             if train_args.gradient_checkpointing:
                 base_model.enable_input_require_grads()
             if model_args.lora_name_or_path:
+                logger.info(f"Loading LoRA from {model_args.lora_name_or_path}")
                 lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
                 lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path,
                                                        torch_dtype=torch.bfloat16,
                                                        attn_implementation="flash_attention_2")
             else:
+                logger.info("Initializing new LoRA")
                 lora_config = LoraConfig(
                     base_model_name_or_path=model_args.model_name_or_path,
                     task_type=TaskType.SEQ_CLS,
@@ -99,6 +99,7 @@ def build(
                 train_batch_size=train_args.per_device_train_batch_size,
             )
         else:
+            logger.info("Building model without LoRA")
             model = cls(
                 hf_model=base_model,
                 train_batch_size=train_args.per_device_train_batch_size,
@@ -110,23 +111,28 @@ def load(cls,
             model_name_or_path: str,
             lora_name_or_path: str = None,
             **hf_kwargs):
+        logger.info(f"Loading RerankerModel from {model_name_or_path}")
        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs,
                                                          torch_dtype=torch.bfloat16,
                                                          attn_implementation="flash_attention_2")
         if base_model.config.pad_token_id is None:
             base_model.config.pad_token_id = 0
+            logger.info("Set pad_token_id to 0")
         if lora_name_or_path:
+            logger.info(f"Loading LoRA from {lora_name_or_path}")
             lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
             lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
             lora_model = lora_model.merge_and_unload()
             model = cls(
                 hf_model=lora_model,
             )
         else:
+            logger.info("Loading model without LoRA")
             model = cls(
                 hf_model=base_model,
             )
         return model
 
     def save(self, output_dir: str):
-        self.hf_model.save_pretrained(output_dir)
+        logger.info(f"Saving model to {output_dir}")
+        self.hf_model.save_pretrained(output_dir)
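
For reference, a minimal sketch (not part of the commit) of how the reworked forward is called after this change: the model now takes separate input_ids / attention_mask / labels keyword tensors instead of a single `pair` dict, returns raw logits as scores at inference time, and only computes the grouped cross-entropy loss when labels are provided. The checkpoint name, group size, and query/passage strings below are placeholders chosen for illustration, not values used by the commit.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tevatron.reranker.modeling import RerankerModel

# Placeholder checkpoint and candidate-group size, assumed only for this sketch.
ckpt = "cross-encoder/ms-marco-MiniLM-L-6-v2"
group_size = 8

tokenizer = AutoTokenizer.from_pretrained(ckpt)
hf_model = AutoModelForSequenceClassification.from_pretrained(ckpt, num_labels=1)
model = RerankerModel(hf_model=hf_model, train_batch_size=1)

queries = ["what is a reranker"] * group_size
passages = [f"candidate passage {i}" for i in range(group_size)]
batch = tokenizer(queries, passages, padding=True, truncation=True, return_tensors="pt")

# Inference path: no labels, so loss is None and raw logits come back as scores.
out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(out.scores.shape)  # (group_size, 1)

# Training path: labels index the positive passage within each group of candidates.
labels = torch.zeros(1, dtype=torch.long)  # positive sits at position 0 in this toy group
out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=labels)
print(out.loss)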

src/tevatron/reranker/trainer.py

Lines changed: 11 additions & 57 deletions
@@ -1,81 +1,35 @@
-import os
-from typing import Optional
+from tevatron.reranker.modeling import RerankerOutput
+from tevatron.retriever.trainer import TevatronTrainer
+from grad_cache import GradCache
 
-import torch
-from torch import Tensor
-from torch.nn import functional as F
-
-from transformers.trainer import Trainer
-from transformers.deepspeed import is_deepspeed_zero3_enabled
-from peft import get_peft_model_state_dict
-
-import logging
-
-logger = logging.getLogger(__name__)
-
-try:
-    from grad_cache import GradCache
-
-    _grad_cache_available = True
-except ModuleNotFoundError:
-    _grad_cache_available = False
-
-
-def split_inputs(model_input: dict, chunk_size: int):
+def split_inputs(model_input, chunk_size):
     keys = list(model_input.keys())
     chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
     return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]
 
+def get_rep(x: RerankerOutput):
+    return x.scores
 
-def get_rep(x):
-    return x.logits
-
-
-class RerankerTrainer(Trainer):
+class RerankerTrainer(TevatronTrainer):
     def __init__(self, *args, **kwargs):
-        super(RerankerTrainer, self).__init__(*args, **kwargs)
-
-        if not _grad_cache_available:
-            raise ValueError(
-                'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.')
-
+        super().__init__(*args, **kwargs)
+        loss_fn = lambda x, y: self.compute_loss(self.model, {'input_ids': x, 'labels': y})
         self.gc = GradCache(
             models=[self.model],
             chunk_sizes=[self.args.gc_chunk_size],
-            loss_fn=self.compute_loss,
+            loss_fn=loss_fn,
             split_input_fn=split_inputs,
             get_rep_fn=get_rep,
             fp16=self.args.fp16,
             scaler=self.scaler if self.args.fp16 else None
         )
 
-    def _save(self, output_dir: Optional[str] = None, state_dict=None):
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        os.makedirs(output_dir, exist_ok=True)
-        logger.info("Saving model checkpoint to %s", output_dir)
-        self.model.save(output_dir)
-
-        if is_deepspeed_zero3_enabled():
-            if state_dict is None:
-                state_dict = self.model.state_dict()
-            prefix = 'hf_model.'
-            assert all(
-                k.startswith(prefix) or k == "target_label"
-                for k in state_dict.keys()
-            ), list(state_dict.keys())
-            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
-            lora_state_dict = get_peft_model_state_dict(self.model.hf_model, state_dict)
-            if self.args.process_index <= 0:
-                torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
-                print(f"Save adapter model at {output_dir}")
-
     def compute_loss(self, model, inputs, return_outputs=False):
-        outputs = model(inputs)
+        outputs = model(**inputs)
         loss = outputs.loss
         return (loss, outputs) if return_outputs else loss
 
     def training_step(self, model, inputs):
-        model.train()
         _distributed = self.args.local_rank > -1
         self.gc.models = [model]
         loss = self.gc(inputs, no_sync_except_last=_distributed)
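
For context, a small sketch (not from the commit) of the two GradCache hooks wired up above: split_inputs chunks a batched input dict along dim 0 so GradCache can run the model piecewise, and get_rep now pulls scores out of a RerankerOutput instead of reading .logits. Because the try/except guard around grad_cache was dropped in this commit, the imports below assume the GradCache package is installed; the tensor shapes and chunk size are made up for illustration.

import torch
from tevatron.reranker.modeling import RerankerOutput
from tevatron.reranker.trainer import split_inputs, get_rep

# A fake tokenized batch of 16 query-passage pairs, 12 tokens each.
batch = {
    "input_ids": torch.randint(0, 100, (16, 12)),
    "attention_mask": torch.ones(16, 12, dtype=torch.long),
}

chunks = split_inputs(batch, chunk_size=4)
print(len(chunks))                   # 4 sub-batches of 4 rows each
print(chunks[0]["input_ids"].shape)  # torch.Size([4, 12])

# get_rep extracts the representation GradCache caches between passes.
rep = get_rep(RerankerOutput(loss=None, scores=torch.randn(4, 1)))
print(rep.shape)                     # torch.Size([4, 1])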
