This repository contains the official implementation of IRO (Iterative Reweight-then-Optimize), introduced in our paper: "Aligning Frozen LLMs by Reinforcement Learning: An Iterative Reweight-then-Optimize Approach".
IRO is a reinforcement learning framework that aligns a frozen base model without modifying its parameters. The method works through an iterative process:
- Sample candidates from the base model
- Reweight and resample those candidates using the current value function
- Train a new lightweight value function for the next iteration
At test time, the trained value functions guide the base model's generation via a search-based optimization process.
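The reweight-then-resample step above can be sketched in a few lines. This is a minimal toy illustration of the idea, not the repo's implementation; the function names, the softmax-style reweighting with temperature `beta`, and the numerical stabilization are all assumptions for illustration:

```python
import math
import random

random.seed(0)

def reweight_and_resample(candidates, value_fn, beta=1.0):
    """Softmax-reweight candidates by the current value function, then resample.

    Subtracting the max value before exponentiating keeps the weights
    numerically stable for large beta.
    """
    values = [value_fn(c) for c in candidates]
    vmax = max(values)
    weights = [math.exp(beta * (v - vmax)) for v in values]
    return random.choices(candidates, weights=weights, k=len(candidates))

def iro_iteration(sample_fn, value_fn, prompts, n=8, beta=1.0):
    """One IRO round: draw n candidates per prompt from the frozen policy,
    then reweight-and-resample them with the current value function.
    The resampled set becomes training data for the next value function."""
    dataset = []
    for p in prompts:
        candidates = [sample_fn(p) for _ in range(n)]
        dataset.extend(reweight_and_resample(candidates, value_fn, beta))
    return dataset
```

With a large `beta`, resampling concentrates on the highest-value candidates, which is what biases the next value function's training data toward better generations.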
- Parameter-efficient: No updates to the frozen base model
- Sample-efficient: More efficient than Best-of-N methods
- Iterative improvement: Progressive policy enhancement through value function training
- Frozen Model Alignment: Align LLMs without parameter updates
- Iterative Training: Progressive improvement through multiple iterations
- Value Function Guidance: Lightweight value functions guide generation
- Sample Efficiency: Better performance with fewer samples than baseline methods
- Multiple Tasks: Support for summarization and instruction-following tasks
- Python: 3.10+
- CUDA: Compatible GPU with CUDA 11.8+
- Conda: For environment management
# Create and activate decoding environment
conda create -n decoding python=3.10
conda activate decoding
# Install PyTorch with CUDA support
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118
# Install dependencies
pip install -r requirements.txt
# Optional: Install flash attention for better performance
pip install flash-attn==2.3.2 --no-build-isolation
# Create environment from YAML file
conda env create -f env_value_train.yaml
By default, a 1B value function is paired with a 1B frozen policy. To use the 6.9B base model instead, update this line in all scripts:
base_model=vwxyzjn/EleutherAI_pythia-6.9b-deduped__sft__tldr
Step 1: Generate Initial Data
# Generate data from frozen policy π_base
bash scripts/tldr/rollout_iter0.sh
Step 2: Train Value Function
# Train the first value function
bash scripts/tldr/value_train_iter1.sh
Step 3: Evaluate with Guidance
# Evaluate π_1 generation on test data
bash scripts/tldr/run_guiding_iter1.sh
Step 4: Continue Iterations
# Iteration 2
bash scripts/tldr/rollout_iter1.sh
bash scripts/tldr/value_train_iter2.sh
bash scripts/tldr/run_guiding_iter2.sh
# Iteration 3
bash scripts/tldr/rollout_iter2.sh
bash scripts/tldr/value_train_iter3.sh
bash scripts/tldr/run_guiding_iter3.sh
Add the following to scripts/tldr/run_guiding_iter.sh for win-rate evaluation:
export OPENAI_API_KEY='your_openai_api_key'
python scripts/tldr/gpt_evaluate.py \
--input_path ${save_dir}
# Run baseline methods
bash scripts/tldr/run_args.sh  # ARGS
bash scripts/tldr/run_base.sh # Base model
bash scripts/tldr/run_bon.sh # Best-of-N
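For context, the Best-of-N baseline invoked above simply draws N candidates per prompt and keeps the single highest-reward one, with nothing learned between queries. A minimal sketch with a toy sampler and reward (all names here are hypothetical, not the repo's API):

```python
import random

random.seed(0)

def best_of_n(sample_fn, reward_fn, prompt, n=16):
    """Draw n candidates from the base model and return the highest-reward one.

    Every prompt pays the full cost of n samples and nothing carries over,
    which is why a trained value function can be more sample-efficient.
    """
    candidates = [sample_fn(prompt) for _ in range(n)]
    return max(candidates, key=reward_fn)
```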
bash scripts/tldr/run_dpo.sh # DPO
Step 1: Generate Initial Data
# Generate data from frozen policy π_base
bash scripts/instruction_following/7b_scripts/roll_out_8b_iter1.sh
Step 2: Train Value Function
# Train the first value function
bash scripts/instruction_following/7b_scripts/value_train_iter1.sh
Step 3: Evaluate with Guidance
# Evaluate π_1 generation on test data
bash scripts/instruction_following/7b_scripts/run_valueguding_iter1.sh
Step 4: Continue Iterations
# Iteration 2
bash scripts/instruction_following/7b_scripts/roll_out_8b_iter2.sh
bash scripts/instruction_following/7b_scripts/value_train_iter2.sh
bash scripts/instruction_following/7b_scripts/run_valueguding_iter2.sh
# Iteration 3
bash scripts/instruction_following/7b_scripts/roll_out_8b_iter3.sh
bash scripts/instruction_following/7b_scripts/value_train_iter3.sh
bash scripts/instruction_following/7b_scripts/run_valueguding_iter3.sh
# Activate evaluation environment
conda activate alpacaEval
# Run AlpacaEval comparison against GPT-4
alpaca_eval --model_outputs ${save_dir}/train_merge_reward.json \
--output_path ${save_dir} \
--reference_outputs ${reference_outputs}
# Run baseline methods
bash scripts/instruction_following/7b_scripts/run_args.sh     # ARGS
bash scripts/instruction_following/7b_scripts/run_base_8b.sh # Base model
bash scripts/instruction_following/7b_scripts/run_BoN_8b.sh # Best-of-N
bash scripts/instruction_following/7b_scripts/run_cbs_8b.sh # CBS
Our implementation builds upon excellent open-source projects:
- TL;DR Summarization - Foundational summarization framework
- Weak-to-Strong Search - Search-based optimization techniques
We sincerely appreciate the contributions of these teams to the open-source research community.
If you find this work helpful, please consider citing our paper:
@article{zhang2025aligning,
  title={Aligning Frozen LLMs by Reinforcement Learning: An Iterative Reweight-then-Optimize Approach},
  author={Zhang, Xinnan and Li, Chenliang and Zeng, Siliang and Li, Jiaxiang and Wang, Zhongruo and Lin, Kaixiang and Lu, Songtao and Garcia, Alfredo and Hong, Mingyi},
  journal={arXiv preprint arXiv:2506.17828},
  year={2025}
}