Aligning Frozen LLMs by Reinforcement Learning: An Iterative Reweight-then-Optimize Approach


A reinforcement learning framework for aligning frozen LLMs without parameter updates


📖 Table of Contents

  • 🔍 Overview
  • 🎯 Key Features
  • 📊 Experimental Results
  • 🚀 Quick Start
  • 💡 Usage Examples
  • 🙏 Acknowledgement
  • 📑 Citation

πŸ” Overview

This repository contains the official implementation of IRO (Iterative Reweight-then-Optimize), introduced in our paper: "Aligning Frozen LLMs by Reinforcement Learning: An Iterative Reweight-then-Optimize Approach".

🎯 What is IRO?

IRO is a novel reinforcement learning framework that aligns frozen base models without modifying their parameters. Each iteration of the method has three steps:

  1. Sample candidates from the base model
  2. Resample using current value functions
  3. Train a new lightweight value function for the next iteration

At test time, the trained value functions guide the base model's generation via a search-based optimization process.

Illustration of the IRO algorithm
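
In code form, the loop looks roughly as follows. This is a minimal sketch only: the sampling, reward, and value-function components below are hypothetical stand-ins, not this repository's actual API.

import random

def sample_candidates(base_model, value_fns, prompt, n=8):
    # Stand-in: in IRO, generation from the frozen base model is guided
    # (reweighted) by the value functions trained in earlier iterations.
    return [f"{prompt} -> candidate {i}" for i in range(n)]

def reward(prompt, response):
    # Stand-in for a reward-model score.
    return random.random()

class ValueFunction:
    def fit(self, triples):
        # Stand-in: regress V(prompt, partial response) toward reward-to-go.
        self.data = triples

def iro(base_model, prompts, num_iters=3):
    value_fns = []
    for _ in range(num_iters):
        triples = []
        for p in prompts:
            # Steps 1-2: sample candidates from the guided frozen model,
            # then reweight them by reward score.
            for c in sample_candidates(base_model, value_fns, p):
                triples.append((p, c, reward(p, c)))
        # Step 3: train a new lightweight value function for the next iteration.
        v = ValueFunction()
        v.fit(triples)
        value_fns.append(v)
    return value_fns  # used at test time to guide generation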

🚀 Key Advantages

  • Parameter-efficient: No updates to the frozen base model
  • Sample-efficient: More efficient than Best-of-N methods (see the sketch below)
  • Iterative improvement: Progressive policy enhancement through value function training

IRO demonstrates superior sample efficiency compared to Best-of-N methods
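
For reference, the Best-of-N baseline that IRO is compared against can be stated in a few lines (a generic sketch, with sample_fn and reward_fn as placeholders):

def best_of_n(sample_fn, reward_fn, prompt, n=16):
    # Draw n full completions and keep the one the reward model scores highest.
    # Unlike IRO, nothing learned is carried over between prompts or rounds.
    candidates = [sample_fn(prompt) for _ in range(n)]
    return max(candidates, key=lambda c: reward_fn(prompt, c))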


🎯 Key Features

  • ✅ Frozen Model Alignment: Align LLMs without parameter updates
  • ✅ Iterative Training: Progressive improvement through multiple iterations
  • ✅ Value Function Guidance: Lightweight value functions guide generation
  • ✅ Sample Efficiency: Better performance with fewer samples than baseline methods
  • ✅ Multiple Tasks: Support for summarization and instruction-following tasks

📊 Experimental Results

TL;DR Summarization Task

Performance comparison on the TL;DR summarization task

Instruction-Following Task

Results on instruction-following evaluation

Test-Time Scaling

Test-time scaling behavior with value function guidance
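
The core of test-time guidance is reweighting the frozen model's next-token distribution with value estimates. A minimal sketch of one decoding step, assuming hypothetical per-token value scores (the repo's actual search procedure may differ):

import math
import random

def value_guided_step(base_logprobs, value_scores, beta=1.0):
    # Reweight base-model token log-probs with value estimates:
    #   log p'(a) proportional to log p_base(a) + beta * V(s, a)
    scores = [lp + beta * v for lp, v in zip(base_logprobs, value_scores)]
    z = max(scores)  # subtract max for numerical stability
    weights = [math.exp(s - z) for s in scores]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Sample a token index from the reweighted distribution.
    return random.choices(range(len(probs)), weights=probs, k=1)[0]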


🚀 Quick Start

Prerequisites

  • Python: 3.10+
  • CUDA: A CUDA-compatible GPU (CUDA 11.8+)
  • Conda: For environment management

🔧 Installation

Environment for Decoding

# Create and activate decoding environment
conda create -n decoding python=3.10
conda activate decoding

# Install PyTorch with CUDA support
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118

# Install dependencies
pip install -r requirements.txt

# Optional: Install flash attention for better performance
pip install flash-attn==2.3.2 --no-build-isolation
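
A quick sanity check that the CUDA build of PyTorch was picked up (plain PyTorch, nothing repo-specific):

import torch

print(torch.__version__)          # expect 2.1.0+cu118
print(torch.cuda.is_available())  # expect True on a CUDA 11.8+ machine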

Environment for Value Function Training

# Create environment from YAML file
conda env create -f env_value_train.yaml

💡 Usage Examples

πŸ“ TL;DR Summarization Task

Model Configuration

The default setup uses a 1B value function with a 1B frozen policy. To use the 6.9B base model instead, update this line in all scripts:

base_model=vwxyzjn/EleutherAI_pythia-6.9b-deduped__sft__tldr

Training Pipeline

Step 1: Generate Initial Data

# Generate data from frozen policy π_base
bash scripts/tldr/rollout_iter0.sh

Step 2: Train Value Function

# Train the first value function
bash scripts/tldr/value_train_iter1.sh
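
Conceptually, this step fits a lightweight value model to predict reward from generations. A schematic sketch of the regression objective, with made-up shapes and a hypothetical value head (the real training logic lives in the script above):

import torch
import torch.nn as nn

class ValueHead(nn.Module):
    # Hypothetical scalar head on top of a language model's hidden states.
    def __init__(self, hidden_size=2048):
        super().__init__()
        self.v = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):                    # (batch, seq, hidden)
        return self.v(hidden_states[:, -1]).squeeze(-1)  # value at last token

head = ValueHead()
opt = torch.optim.AdamW(head.parameters(), lr=1e-5)

hidden = torch.randn(4, 16, 2048)  # stand-in for frozen-model hidden states
rewards = torch.randn(4)           # stand-in for reward-model scores
loss = nn.functional.mse_loss(head(hidden), rewards)
loss.backward()
opt.step()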

Step 3: Evaluate with Guidance

# Evaluate π_1 generation on test data
bash scripts/tldr/run_guiding_iter1.sh

Step 4: Continue Iterations

# Iteration 2
bash scripts/tldr/rollout_iter1.sh
bash scripts/tldr/value_train_iter2.sh
bash scripts/tldr/run_guiding_iter2.sh

# Iteration 3
bash scripts/tldr/rollout_iter2.sh
bash scripts/tldr/value_train_iter3.sh
bash scripts/tldr/run_guiding_iter3.sh
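
The three iterations above can also be driven from a single Python script via subprocess (script names taken verbatim from the steps above):

import subprocess

# Rollout for iteration t feeds value training for iteration t+1,
# whose value function then guides decoding.
for t in range(3):
    subprocess.run(["bash", f"scripts/tldr/rollout_iter{t}.sh"], check=True)
    subprocess.run(["bash", f"scripts/tldr/value_train_iter{t + 1}.sh"], check=True)
    subprocess.run(["bash", f"scripts/tldr/run_guiding_iter{t + 1}.sh"], check=True)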

Evaluation with GPT

Add the following to the scripts/tldr/run_guiding_iter{1,2,3}.sh scripts for win-rate evaluation:

export OPENAI_API_KEY='your_openai_api_key'

python scripts/tldr/gpt_evaluate.py \
    --input_path ${save_dir}
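
Under the hood, win-rate evaluation asks GPT to pick the better of two summaries. A rough sketch of such a pairwise judge using the OpenAI client (the prompt and parsing here are illustrative; gpt_evaluate.py defines its own):

from openai import OpenAI  # reads OPENAI_API_KEY from the environment

client = OpenAI()

def judge(post, summary_a, summary_b, model="gpt-4"):
    # Hypothetical pairwise comparison prompt.
    prompt = (
        f"Which summary of the following post is better?\n\n"
        f"Post: {post}\n\nSummary A: {summary_a}\n\nSummary B: {summary_b}\n\n"
        "Answer with exactly one letter: A or B."
    )
    resp = client.chat.completions.create(
        model=model, messages=[{"role": "user", "content": prompt}]
    )
    return resp.choices[0].message.content.strip()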

Baseline Comparisons

# Run baseline methods
bash scripts/tldr/run_args.sh    # ARGS (reward-guided search)
bash scripts/tldr/run_base.sh    # Base model
bash scripts/tldr/run_bon.sh     # Best-of-N
bash scripts/tldr/run_dpo.sh     # DPO

🎓 Instruction Following Task

Training with 7B Value Function + 8B Frozen Policy

Step 1: Generate Initial Data

# Generate data from frozen policy π_base
bash scripts/instruction_following/7b_scripts/roll_out_8b_iter1.sh

Step 2: Train Value Function

# Train the first value function
bash scripts/instruction_following/7b_scripts/value_train_iter1.sh

Step 3: Evaluate with Guidance

# Evaluate π_1 generation on test data
bash scripts/instruction_following/7b_scripts/run_valueguding_iter1.sh

Step 4: Continue Iterations

# Iteration 2
bash scripts/instruction_following/7b_scripts/roll_out_8b_iter2.sh
bash scripts/instruction_following/7b_scripts/value_train_iter2.sh
bash scripts/instruction_following/7b_scripts/run_valueguding_iter2.sh

# Iteration 3
bash scripts/instruction_following/7b_scripts/roll_out_8b_iter3.sh
bash scripts/instruction_following/7b_scripts/value_train_iter3.sh
bash scripts/instruction_following/7b_scripts/run_valueguding_iter3.sh

AlpacaEval Evaluation

# Activate evaluation environment
conda activate alpacaEval

# Run AlpacaEval comparison against GPT-4
alpaca_eval --model_outputs ${save_dir}/train_merge_reward.json \
    --output_path ${save_dir} \
    --reference_outputs ${reference_outputs}

Baseline Comparisons

# Run baseline methods
bash scripts/instruction_following/7b_scripts/run_args.sh      # ARGS (reward-guided search)
bash scripts/instruction_following/7b_scripts/run_base_8b.sh   # Base model
bash scripts/instruction_following/7b_scripts/run_BoN_8b.sh    # Best-of-N
bash scripts/instruction_following/7b_scripts/run_cbs_8b.sh    # CBS

πŸ™ Acknowledgement

Our implementation builds upon excellent open-source projects, and we sincerely appreciate the contributions of these teams to the open-source research community.


📑 Citation

If you find this work helpful, please consider citing our paper:

@article{zhang2025aligning,
  title={Aligning Frozen LLMs by Reinforcement Learning: An Iterative Reweight-then-Optimize Approach},
  author={Zhang, Xinnan and Li, Chenliang and Zeng, Siliang and Li, Jiaxiang and Wang, Zhongruo and Lin, Kaixiang and Lu, Songtao and Garcia, Alfredo and Hong, Mingyi},
  journal={arXiv preprint arXiv:2506.17828},
  year={2025}
}
