Skip to content

Commit 12892c3

Browse files
committed
fix: prediction step
1 parent a79b647 commit 12892c3

File tree

2 files changed

+56
-40
lines changed

2 files changed

+56
-40
lines changed

src/tevatron/reranker/modeling.py

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from dataclasses import dataclass
3-
from typing import Optional, Dict, Any
3+
from typing import Optional
44

55
import torch
66
from torch import nn, Tensor
@@ -23,33 +23,18 @@ class RerankerOutput(ModelOutput):
2323
class RerankerModel(nn.Module):
2424
TRANSFORMER_CLS = AutoModelForSequenceClassification
2525

26-
def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
26+
def __init__(self, hf_model: PreTrainedModel):
2727
super().__init__()
28-
logger.info(f"Initializing RerankerModel with train_batch_size: {train_batch_size}")
28+
logger.info("Initializing RerankerModel")
2929
self.config = hf_model.config
3030
self.hf_model = hf_model
31-
self.train_batch_size = train_batch_size
32-
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
33-
if train_batch_size:
34-
self.register_buffer(
35-
'target_label',
36-
torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device)
37-
)
3831
logger.info(f"RerankerModel initialized with config: {self.config}")
3932

40-
def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs):
33+
def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwargs):
4134
logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
4235
outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
4336

44-
if labels is not None:
45-
loss = self.cross_entropy(outputs.logits.view(self.train_batch_size, -1), labels)
46-
logger.debug(f"Computed loss: {loss.item()}")
47-
else:
48-
loss = None
49-
logger.debug("No labels provided, skipping loss computation")
50-
5137
return RerankerOutput(
52-
loss=loss,
5338
scores=outputs.logits
5439
)
5540

@@ -94,16 +79,10 @@ def build(
9479
inference_mode=False,
9580
)
9681
lora_model = get_peft_model(base_model, lora_config)
97-
model = cls(
98-
hf_model=lora_model,
99-
train_batch_size=train_args.per_device_train_batch_size,
100-
)
82+
model = cls(hf_model=lora_model)
10183
else:
10284
logger.info("Building model without LoRA")
103-
model = cls(
104-
hf_model=base_model,
105-
train_batch_size=train_args.per_device_train_batch_size,
106-
)
85+
model = cls(hf_model=base_model)
10786
return model
10887

10988
@classmethod
@@ -123,14 +102,10 @@ def load(cls,
123102
lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
124103
lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
125104
lora_model = lora_model.merge_and_unload()
126-
model = cls(
127-
hf_model=lora_model,
128-
)
105+
model = cls(hf_model=lora_model)
129106
else:
130107
logger.info("Loading model without LoRA")
131-
model = cls(
132-
hf_model=base_model,
133-
)
108+
model = cls(hf_model=base_model)
134109
return model
135110

136111
def save(self, output_dir: str):

src/tevatron/reranker/trainer.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,37 @@
1-
from tevatron.reranker.modeling import RerankerOutput
2-
from tevatron.retriever.trainer import TevatronTrainer
1+
import logging
2+
from typing import Dict, Union, Any
3+
4+
import torch
5+
from torch import nn
6+
from transformers import Trainer
7+
from transformers.trainer_utils import PredictionOutput
8+
9+
from tevatron.arguments import TevatronTrainingArguments
310
from grad_cache import GradCache
411

12+
logger = logging.getLogger(__name__)
13+
514
def split_inputs(model_input, chunk_size):
15+
logger.debug(f"Splitting inputs with chunk size: {chunk_size}")
616
keys = list(model_input.keys())
717
chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
818
return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]
919

10-
def get_rep(x: RerankerOutput):
11-
return x.scores
20+
def get_rep(model_output):
21+
logger.debug(f"Getting representation from model output: {type(model_output)}")
22+
return model_output.scores
1223

13-
class RerankerTrainer(TevatronTrainer):
24+
class RerankerTrainer(Trainer):
1425
def __init__(self, *args, **kwargs):
1526
super().__init__(*args, **kwargs)
16-
loss_fn = lambda x, y: self.compute_loss(self.model, {'input_ids': x, 'labels': y})
27+
logger.info("Initializing RerankerTrainer")
28+
self.args: TevatronTrainingArguments
29+
30+
def loss_fn(scores, labels):
31+
grouped_scores = scores.view(self.args.train_group_size, -1)
32+
labels = torch.zeros(self.args.train_group_size, dtype=torch.long, device=scores.device)
33+
return nn.CrossEntropyLoss()(grouped_scores, labels)
34+
1735
self.gc = GradCache(
1836
models=[self.model],
1937
chunk_sizes=[self.args.gc_chunk_size],
@@ -23,14 +41,37 @@ def __init__(self, *args, **kwargs):
2341
fp16=self.args.fp16,
2442
scaler=self.scaler if self.args.fp16 else None
2543
)
44+
logger.info(f"GradCache initialized with chunk size: {self.args.gc_chunk_size}")
2645

2746
def compute_loss(self, model, inputs, return_outputs=False):
47+
logger.debug(f"Computing loss with inputs: {inputs.keys()}")
2848
outputs = model(**inputs)
2949
loss = outputs.loss
50+
logger.debug(f"Computed loss: {loss.item()}")
3051
return (loss, outputs) if return_outputs else loss
3152

32-
def training_step(self, model, inputs):
53+
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
54+
logger.debug("Entering training step")
55+
model.train()
56+
inputs = self._prepare_inputs(inputs)
3357
_distributed = self.args.local_rank > -1
3458
self.gc.models = [model]
3559
loss = self.gc(inputs, no_sync_except_last=_distributed)
60+
logger.debug(f"Training step loss: {loss.item()}")
3661
return loss
62+
63+
def prediction_step(
64+
self,
65+
model: nn.Module,
66+
inputs: Dict[str, Union[torch.Tensor, Any]],
67+
prediction_loss_only: bool,
68+
ignore_keys: bool = None,
69+
) -> PredictionOutput:
70+
logger.debug("Entering prediction step")
71+
inputs = self._prepare_inputs(inputs)
72+
with torch.no_grad():
73+
outputs = model(**inputs)
74+
loss = outputs.loss
75+
logits = outputs.scores
76+
logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}")
77+
return PredictionOutput(predictions=logits, label_ids=inputs.get("labels"), metrics=None)

0 commit comments

Comments
 (0)