Commit c44f152

Merge pull request #23 from junuMoon/refact/eval_script

Revise eval script

2 parents a0ea1f3 + 96d896b commit c44f152

File tree: 1 file changed (+130 −56 lines)

v2/eval/evaluate_reranker.py

Lines changed: 130 additions & 56 deletions
@@ -1,8 +1,12 @@
 import os
 import logging
-from multiprocessing import Process, current_process
+from multiprocessing import Process, current_process, Queue
 import torch
 import json
+import queue
+from pathlib import Path
+import argparse
+from typing import List

 import mteb
 from mteb import MTEB
@@ -168,35 +172,7 @@ def patched_load_results_file(self):
 logger = logging.getLogger("main")


-# Per-GPU task mapping - adjust the GPU numbers as needed
-TASK_LIST_RERANKER_GPU_MAPPING = {
-    7: [
-        "Ko-StrategyQA",
-        "AutoRAGRetrieval",
-        "PublicHealthQA",
-        "BelebeleRetrieval",
-        "XPQARetrieval",
-        "MultiLongDocRetrieval",
-        "MIRACLRetrieval",
-        "MrTidyRetrieval"
-    ]
-}
-
-model_names = [
-    # "BAAI/bge-reranker-v2-m3",
-    # "dragonkue/bge-reranker-v2-m3-ko",
-    # "sigridjineth/ko-reranker-v1.1",
-    # "sigridjineth/ko-reranker-v1.2-preview",
-    "Alibaba-NLP/gte-multilingual-reranker-base",
-    "upskyy/ko-reranker-8k",
-    "Dongjin-kr/ko-reranker",
-    # "jinaai/jina-reranker-v2-base-multilingual",
-    # add other models here
-]
-
-previous_results_dir = "./results/stage1/top_1k_qrels"
-
-def evaluate_reranker_model(model_name, gpu_id, tasks):
+def evaluate_reranker_model(model_name: str, gpu_id: int, tasks: List[str], previous_results_dir: Path, output_base_dir: Path, top_k: int, verbosity: int):
     try:
         device = torch.device(f"cuda:{str(gpu_id)}")
         torch.cuda.set_device(device)
@@ -205,16 +181,17 @@ def evaluate_reranker_model(model_name, gpu_id, tasks):
         setproctitle(f"{model_name}-reranker-{gpu_id}")
         print(f"Running tasks: {tasks} / {model_name} on GPU {gpu_id} in process {current_process().name}")

+        model_path = Path(model_name)
+        output_dir = output_base_dir / model_path.parent.name / model_path.name
+        output_dir.mkdir(parents=True, exist_ok=True)
+
         cross_encoder = CrossEncoder(
             model_name,
             trust_remote_code=True,
             model_kwargs={"torch_dtype": torch.bfloat16},
             device=device
         )

-        output_dir = os.path.join("./results/stage2", model_name)
-
-        # TODO adjust batch size per model
         batch_size = 2048

         for task in tasks:
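
Aside (not part of the diff): the new `output_dir` lines keep only the last two components of the model identifier. A quick illustration, using a model id that appears in the defaults below and a hypothetical base directory:

from pathlib import Path

model_path = Path("Dongjin-kr/ko-reranker")       # example model id from DEFAULT_MODEL_NAMES
output_base_dir = Path("results/stage2")          # example base directory
output_dir = output_base_dir / model_path.parent.name / model_path.name
print(output_dir)  # results/stage2/Dongjin-kr/ko-reranker

For a longer local path such as /models/foo/bar (an example path, not from the commit), only the last two components (foo/bar) would be used, and the directory is now created up front with mkdir(parents=True, exist_ok=True).
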
@@ -227,45 +204,142 @@ def evaluate_reranker_model(model_name, gpu_id, tasks):
             )
             evaluation = MTEB(tasks=tasks_mteb)

-            if os.path.exists(os.path.join(previous_results_dir, task + "_id.jsonl")):
+            previous_results_path = previous_results_dir / (task + "_id.jsonl")
+            if previous_results_path.exists():
                 print(f"Previous results found: {task}")
-                previous_results = os.path.join(previous_results_dir, task + "_id.jsonl")
+                previous_results = str(previous_results_path)

                 evaluation.run(
                     cross_encoder,
-                    top_k=50,
+                    top_k=top_k,
                     save_predictions=True,
-                    output_folder=output_dir,
+                    output_folder=str(output_dir),
                     previous_results=previous_results,
-                    batch_size=batch_size
+                    batch_size=batch_size,
+                    verbosity=verbosity,
                 )
             else:
                 print(f"Previous results not found: {task}")
                 evaluation.run(
                     cross_encoder,
-                    top_k=50,
+                    top_k=top_k,
                     save_predictions=True,
-                    output_folder=output_dir,
-                    batch_size=batch_size
+                    output_folder=str(output_dir),
+                    batch_size=batch_size,
+                    verbosity=verbosity,
                 )

     except Exception as ex:
         print(f"Error in GPU {gpu_id} with model {model_name}: {ex}")
         traceback.print_exc()

-if __name__ == "__main__":
-    torch.multiprocessing.set_start_method('spawn')
-
-    for model_name in model_names:
-        print(f"Starting evaluation for model: {model_name}")
-        processes = []
+def worker(job_queue: Queue, gpu_queue: Queue, previous_results_dir: Path, output_base_dir: Path, top_k: int, verbosity: int):
+    """Worker that pulls jobs from the job queue and GPUs from the GPU queue and runs them."""
+    while True:
+        try:
+            model_name, task = job_queue.get(timeout=1)
+        except queue.Empty:
+            break

-        for gpu_id, tasks in TASK_LIST_RERANKER_GPU_MAPPING.items():
-            p = Process(target=evaluate_reranker_model, args=(model_name, gpu_id, tasks))
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join()
-
-        print(f"Completed evaluation for model: {model_name}")
+        gpu_id = None
+        try:
+            gpu_id = gpu_queue.get()
+            print(f"Process {current_process().name}: Starting task: {task} / {model_name} on GPU {gpu_id}")
+            evaluate_reranker_model(model_name, gpu_id, [task], previous_results_dir, output_base_dir, top_k, verbosity)
+            print(f"Process {current_process().name}: Finished task: {task} / {model_name} on GPU {gpu_id}")
+        except Exception:
+            print(f"!!!!!!!!!! Process {current_process().name}: Error during task: {task} / {model_name} on GPU {gpu_id} !!!!!!!!!!!")
+            traceback.print_exc()
+        finally:
+            if gpu_id is not None:
+                gpu_queue.put(gpu_id)
+
+
+# --- Default settings (can be overridden by command-line arguments) ---
+DEFAULT_MODEL_NAMES = [
+    "BAAI/bge-reranker-v2-m3",
+    "dragonkue/bge-reranker-v2-m3-ko",
+    "sigridjineth/ko-reranker-v1.1",
+    "sigridjineth/ko-reranker-v1.2-preview",
+    "Alibaba-NLP/gte-multilingual-reranker-base",
+    "upskyy/ko-reranker-8k",
+    "Dongjin-kr/ko-reranker",
+    "jinaai/jina-reranker-v2-base-multilingual",
+]
+DEFAULT_TASKS = [
+    "Ko-StrategyQA", "AutoRAGRetrieval", "PublicHealthQA", "BelebeleRetrieval",
+    "XPQARetrieval", "MultiLongDocRetrieval", "MIRACLRetrieval", "MrTidyRetrieval"
+]
+DEFAULT_GPU_IDS = [0, 1, 2, 3, 4, 6, 7]
+V2_ROOT = Path(__file__).resolve().parents[1]
+DEFAULT_PREVIOUS_RESULTS_DIR = V2_ROOT / "eval/results/stage1/top_1k_qrels"
+DEFAULT_OUTPUT_DIR = V2_ROOT / "eval/results/stage2"
+
+assert V2_ROOT.exists(), f"V2_ROOT does not exist: {V2_ROOT}"
+assert DEFAULT_PREVIOUS_RESULTS_DIR.exists(), f"DEFAULT_PREVIOUS_RESULTS_DIR does not exist: {DEFAULT_PREVIOUS_RESULTS_DIR}"
+assert DEFAULT_OUTPUT_DIR.exists(), f"DEFAULT_OUTPUT_DIR does not exist: {DEFAULT_OUTPUT_DIR}"
+# -----------------------------------------------------
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Run the MTEB reranker benchmark in parallel.")
+    parser.add_argument(
+        "--model_names", nargs="+", default=DEFAULT_MODEL_NAMES, help="List of reranker model names or paths to evaluate"
+    )
+    parser.add_argument(
+        "--tasks", nargs="+", default=DEFAULT_TASKS, help="List of MTEB tasks to evaluate"
+    )
+    parser.add_argument(
+        "--gpu_ids", nargs="+", type=int, default=DEFAULT_GPU_IDS, help="List of GPU IDs to use"
+    )
+    parser.add_argument(
+        "--previous_results_dir", type=str, default=str(DEFAULT_PREVIOUS_RESULTS_DIR), help="Directory containing the stage-1 (BM25) results"
+    )
+    parser.add_argument(
+        "--output_dir", type=str, default=str(DEFAULT_OUTPUT_DIR), help="Directory to save the final stage-2 (reranking) results"
+    )
+    parser.add_argument(
+        "--model_dir", type=str, default=None, help="Directory containing local models to evaluate. Each subdirectory is treated as a model."
+    )
+    parser.add_argument(
+        "--top_k", type=int, default=50, help="Number of top-K documents to use for reranking"
+    )
+    parser.add_argument(
+        "--verbosity", type=int, default=0, help="MTEB logging verbosity (0: progress bars only, 1: show scores, 2: detailed, 3: debug)"
+    )
+    args = parser.parse_args()
+
+    torch.multiprocessing.set_start_method('spawn', force=True)
+
+    previous_results_dir = Path(args.previous_results_dir)
+    output_dir = Path(args.output_dir)
+
+    job_queue = Queue()
+    gpu_queue = Queue()
+
+    total_jobs = 0
+    for model_name in args.model_names:
+        for task in args.tasks:
+            job_queue.put((model_name, task))
+            total_jobs += 1
+
+    for gpu_id in args.gpu_ids:
+        gpu_queue.put(gpu_id)
+
+    processes = []
+    num_workers = len(args.gpu_ids)
+    print(f"Starting {num_workers} workers on GPUs: {args.gpu_ids}")
+    print(f"Total jobs to process: {total_jobs}")
+
+    for _ in range(num_workers):
+        p = Process(target=worker, args=(job_queue, gpu_queue, previous_results_dir, output_dir, args.top_k, args.verbosity))
+        p.start()
+        processes.append(p)
+
+    for p in processes:
+        p.join()
+
+    print("All evaluation tasks completed.")
+
+if __name__ == "__main__":
+    main()
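
To make the new scheduling easier to follow, here is a minimal, self-contained sketch (not part of the commit) of the pattern the refactor introduces: one queue holds (model, task) jobs, a second queue holds GPU ids and acts as a token pool, and one worker process per GPU loops until the job queue is empty. The do_work function, job names, and two-token pool below are placeholder assumptions; only the queue handling mirrors the committed worker.

import queue
import time
from multiprocessing import Process, Queue, current_process

def do_work(job: str, resource: int) -> None:
    # Stand-in for evaluate_reranker_model(model_name, gpu_id, [task], ...)
    time.sleep(0.1)
    print(f"{current_process().name}: ran {job} on resource {resource}")

def worker(job_queue: Queue, resource_queue: Queue) -> None:
    while True:
        try:
            job = job_queue.get(timeout=1)        # exit once no jobs remain
        except queue.Empty:
            break
        resource = None
        try:
            resource = resource_queue.get()       # block until a GPU token is free
            do_work(job, resource)
        finally:
            if resource is not None:
                resource_queue.put(resource)      # always return the token to the pool

if __name__ == "__main__":
    job_queue, resource_queue = Queue(), Queue()
    for i in range(6):
        job_queue.put(f"job-{i}")
    for token in (0, 1):                          # pretend we have two GPUs
        resource_queue.put(token)

    workers = [Process(target=worker, args=(job_queue, resource_queue)) for _ in (0, 1)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()
    print("all jobs done")

Returning the token in a finally block is why the committed worker re-queues gpu_id even after a failed task, so one bad model/task pair cannot permanently remove a GPU from the pool. With the CLI added in this commit, a run might look like `python v2/eval/evaluate_reranker.py --model_names Dongjin-kr/ko-reranker --gpu_ids 0 1` (an illustrative invocation; unspecified options fall back to the DEFAULT_* constants).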
