
Commit 3bc03cf

Merge pull request #35 from princeton-nlp/pypi
Pypi
2 parents 7382f24 + 43f1740

33 files changed: +1628 −1526 lines

MANIFEST.in

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+include src/tot/data/24/24.csv
+include src/tot/data/crosswords/mini0505_0_100_5.json
+include src/tot/data/crosswords/mini0505.json
+include src/tot/data/text/data_100_random_text.txt
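MANIFEST.in lists the data files to bundle into the distribution so they ship alongside the `tot` package discovered from `src/` (see pyproject.toml below). A minimal sketch (not part of the commit) of reading one of these files after installation, assuming Python 3.9+ for `importlib.resources.files`:

```python
# Minimal sketch (not in this commit): open a data file bundled via
# MANIFEST.in from the installed `tot` package, without hard-coding the
# install location. Assumes Python 3.9+ for importlib.resources.files.
from importlib import resources

with resources.files("tot").joinpath("data/24/24.csv").open() as f:
    print(f.readline())  # first row of the game-of-24 puzzle file
```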

fake.png renamed to pics/fake.png

File renamed without changes.

pyproject.toml

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+[build-system]
+requires = ["setuptools >= 61.0.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "tot"
+version = "0.1.0"
+description = 'Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"'
+readme = "README.md"
+requires-python = ">= 3.7"
+authors = [{ name = "Shunyu Yao", email = "shunyuyao.cs@gmail.com" }]
+license = { text = "MIT License" }
+keywords = ["tree-search", "large-language-models", "llm", "prompting", "tree-of-thoughts"]
+classifiers = [
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.7",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    'Intended Audience :: Science/Research',
+    'Topic :: Scientific/Engineering :: Artificial Intelligence',
+]
+dynamic = ["dependencies"]
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
+
+[tool.setuptools.packages.find]
+where = ["src"]  # list of folders that contain the packages (["."] by default)
+
+[project.urls]
+Homepage = "https://github.com/princeton-nlp/tree-of-thought-llm"
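Because `dependencies` is declared `dynamic` and read from requirements.txt at build time, the resulting metadata can be sanity-checked after `pip install -e .`; a short sketch (not in the commit) using only the standard library:

```python
# Sketch (not in this commit): verify that setuptools resolved the dynamic
# dependency list from requirements.txt into the built package metadata.
from importlib.metadata import version, requires

print(version("tot"))           # expected: 0.1.0
for req in requires("tot") or []:
    print(req)                  # pinned entries copied from requirements.txt
```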

readme.md

Lines changed: 51 additions & 26 deletions
@@ -2,44 +2,58 @@
 [![DOI](https://zenodo.org/badge/642099326.svg)](https://zenodo.org/badge/latestdoi/642099326)
 
 <details>
-<summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary>
-
-In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
+<summary>Note: https://github.com/kyegomez/tree-of-thoughts CANNOT replicate paper results. </summary>
+
+In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](pics/fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
 
 Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!**
 
 </details>
 
 
+![teaser](pics/teaser.png)
 
+Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
+Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.
 
 
-![teaser](teaser.png)
 
-Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
-Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.
 
 
-Please cite the paper and star this repo if you use ToT and find it interesting/useful. Thanks!
 
-```bibtex
-@misc{yao2023tree,
-      title={{Tree of Thoughts}: Deliberate Problem Solving with Large Language Models},
-      author={Shunyu Yao and Dian Yu and Jeffrey Zhao and Izhak Shafran and Thomas L. Griffiths and Yuan Cao and Karthik Narasimhan},
-      year={2023},
-      eprint={2305.10601},
-      archivePrefix={arXiv},
-      primaryClass={cs.CL}
-}
+## Setup
+- Set up OpenAI API key and store in environment variable ``OPENAI_API_KEY`` (see [here](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)).
+
+- Install dependencies and `tot` package (PyPI package coming soon):
+```bash
+git clone https://github.com/princeton-nlp/tree-of-thought-llm
+cd tree-of-thought-llm
+pip install -r requirements.txt
+pip install -e .  # install `tot` package
 ```
 
 
+## Quick Start
+The following minimal script will attempt to solve the game of 24 with `4 5 6 10` (might be a bit slow as it's using GPT-4):
+```python
+import argparse
+from tot.methods.bfs import solve
+from tot.tasks.game24 import Game24Task
 
-## Setup
-You need to first have an OpenAI API key and store it in the environment variable ``OPENAI_API_KEY`` (see [here](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)). If you use custom base url, set it by environment variable ``OPENAI_API_BASE`` (e.g. https://api.openai.com/v1).
+args = argparse.Namespace(backend='gpt-4', temperature=0.7, task='game24', naive_run=False, prompt_sample=None, method_generate='propose', method_evaluate='value', method_select='greedy', n_generate_sample=1, n_evaluate_sample=3, n_select_sample=5)
 
-Package requirement: ``pip install openai backoff sympy numpy``
+task = Game24Task()
+ys, infos = solve(args, task, 900)
+print(ys[0])
+```
 
+And the output would be something like (note it's not deterministic, and sometimes the output can be wrong):
+```
+10 - 4 = 6 (left: 5 6 6)
+5 * 6 = 30 (left: 6 30)
+30 - 6 = 24 (left: 24)
+Answer: (5 * (10 - 4)) - 6 = 24
+```
 
-## Experiments
+## Paper Experiments
 
 Run experiments via ``sh scripts/{game24, text, crosswords}/{standard_sampling, cot_sampling, bfs}.sh``, except in crosswords we use a DFS algorithm for ToT, which can be run via ``scripts/crosswords/search_crosswords-dfs.ipynb``.
 
@@ -55,13 +69,24 @@ The very simple ``run.py`` implements the ToT + BFS algorithm, as well as the na
 
 
 
-## Trajectories
+## Paper Trajectories
 ``logs/`` contains all the trajectories from the paper's experiments, except for ``logs/game24/gpt-4_0.7_propose1_value3_greedy5_start900_end1000.json`` which was reproduced after the paper (as the original experiment was done in a notebook) and achieved a 69\% score instead of the original 74\% score due to randomness in GPT decoding. We hope to aggregate multiple runs in the future to account for sampling randomness and update the paper, but this shouldn't affect the main conclusions of the paper.
 
+## How to Add A New Task
+Setting up a new task is easy, and mainly involves two steps.
+* Set up a new task class in ``tot/tasks/`` and task files in ``tot/data/``. See ``tot/tasks/game24.py`` for an example. Add the task to ``tot/tasks/__init__.py``.
+* Set up task-specific prompts in ``tot/prompts/``. See ``tot/prompts/game24.py`` for an example. Depending on the nature of the task, choose ``--method_generate`` (choices=[``sample``, ``propose``]) and ``--method_evaluate`` (choices=[``value``, ``vote``]) and their corresponding prompts.
 
+## Citations
+Please cite the paper and star this repo if you use ToT and find it interesting/useful, thanks! Feel free to contact shunyuyao.cs@gmail.com or open an issue if you have any questions.
 
-## Questions
-Feel free to contact shunyuyao.cs@gmail.com or open an issue if you have any questions.
-
-
-
+```bibtex
+@misc{yao2023tree,
+      title={{Tree of Thoughts}: Deliberate Problem Solving with Large Language Models},
+      author={Shunyu Yao and Dian Yu and Jeffrey Zhao and Izhak Shafran and Thomas L. Griffiths and Yuan Cao and Karthik Narasimhan},
+      year={2023},
+      eprint={2305.10601},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL}
+}
+```
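As a counterpart to the Quick Start added above, here is a hedged sketch (not part of this diff) of running the naive chain-of-thought baseline on the same puzzle, assuming `naive_solve` keeps the `(args, task, idx)` signature that `run.py` imports below:

```python
# Hedged sketch: naive CoT baseline on the same game-of-24 puzzle, for
# contrast with the ToT + BFS Quick Start. Assumes naive_solve(args, task,
# idx) matches the import in run.py below; unused search fields are None.
import argparse
from tot.methods.bfs import naive_solve
from tot.tasks.game24 import Game24Task

args = argparse.Namespace(backend='gpt-4', temperature=0.7, task='game24',
                          naive_run=True, prompt_sample='cot',
                          method_generate='sample', method_evaluate=None,
                          method_select=None, n_generate_sample=1,
                          n_evaluate_sample=None, n_select_sample=None)

task = Game24Task()
ys, infos = naive_solve(args, task, 900)  # one CoT sample, no tree search
print(ys[0])
```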

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@ sympy==1.12
 tqdm==4.65.0
 urllib3==2.0.2
 yarl==1.9.2
+pandas==2.0.3
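pandas presumably enters the list because the packaged game-of-24 task loads the bundled `tot/data/24/24.csv`; a hedged sketch of that load (the `'Puzzles'` column name is an assumption about the CSV):

```python
# Hedged sketch of the new pandas dependency at work: load the bundled
# game-of-24 puzzle CSV. The 'Puzzles' column name is an assumption.
import pandas as pd
from importlib import resources

with resources.files("tot").joinpath("data/24/24.csv").open() as f:
    puzzles = list(pd.read_csv(f)["Puzzles"])
print(len(puzzles), puzzles[900])  # index 900 matches the Quick Start
```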

run.py

Lines changed: 7 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,18 @@
11
import os
22
import json
3-
import itertools
43
import argparse
5-
import numpy as np
6-
from functools import partial
7-
from models import gpt, gpt_usage
8-
from tasks import get_task
94

10-
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
11-
value_prompt = task.value_prompt_wrap(x, y)
12-
if cache_value and value_prompt in task.value_cache:
13-
return task.value_cache[value_prompt]
14-
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
15-
value = task.value_outputs_unwrap(x, y, value_outputs)
16-
if cache_value:
17-
task.value_cache[value_prompt] = value
18-
return value
19-
20-
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
21-
values = []
22-
local_value_cache = {}
23-
for y in ys: # each partial output
24-
if y in local_value_cache: # avoid duplicate candidates
25-
value = 0
26-
else:
27-
value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
28-
local_value_cache[y] = value
29-
values.append(value)
30-
return values
31-
32-
def get_votes(task, x, ys, n_evaluate_sample):
33-
vote_prompt = task.vote_prompt_wrap(x, ys)
34-
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
35-
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
36-
return values
37-
38-
def get_proposals(task, x, y):
39-
propose_prompt = task.propose_prompt_wrap(x, y)
40-
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
41-
return [y + _ + '\n' for _ in proposals]
42-
43-
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
44-
if prompt_sample == 'standard':
45-
prompt = task.standard_prompt_wrap(x, y)
46-
elif prompt_sample == 'cot':
47-
prompt = task.cot_prompt_wrap(x, y)
48-
else:
49-
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
50-
samples = gpt(prompt, n=n_generate_sample, stop=stop)
51-
return [y + _ for _ in samples]
52-
53-
def solve(args, task, idx, to_print=True):
54-
print(gpt)
55-
x = task.get_input(idx) # input
56-
ys = [''] # current output candidates
57-
infos = []
58-
for step in range(task.steps):
59-
# generation
60-
if args.method_generate == 'sample':
61-
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
62-
elif args.method_generate == 'propose':
63-
new_ys = [get_proposals(task, x, y) for y in ys]
64-
new_ys = list(itertools.chain(*new_ys))
65-
ids = list(range(len(new_ys)))
66-
# evaluation
67-
if args.method_evaluate == 'vote':
68-
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
69-
elif args.method_evaluate == 'value':
70-
values = get_values(task, x, new_ys, args.n_evaluate_sample)
71-
72-
# selection
73-
if args.method_select == 'sample':
74-
ps = np.array(values) / sum(values)
75-
select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
76-
elif args.method_select == 'greedy':
77-
select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
78-
select_new_ys = [new_ys[select_id] for select_id in select_ids]
79-
80-
# log
81-
if to_print:
82-
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
83-
print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
84-
85-
infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
86-
ys = select_new_ys
87-
88-
if to_print:
89-
print(ys)
90-
return ys, {'steps': infos}
91-
92-
def naive_solve(args, task, idx, to_print=True):
93-
x = task.get_input(idx) # input
94-
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
95-
return ys, {}
5+
from tot.tasks import get_task
6+
from tot.methods.bfs import solve, naive_solve
7+
from tot.models import gpt_usage
968

979
def run(args):
98-
task = get_task(args.task, args.task_file_path)
10+
task = get_task(args.task)
9911
logs, cnt_avg, cnt_any = [], 0, 0
100-
global gpt
101-
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
10212
if args.naive_run:
103-
file = f'logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
13+
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
10414
else:
105-
file = f'logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
15+
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
10616
os.makedirs(os.path.dirname(file), exist_ok=True)
10717

10818
for i in range(args.task_start_index, args.task_end_index):
@@ -136,7 +46,6 @@ def parse_args():
13646
args.add_argument('--temperature', type=float, default=0.7)
13747

13848
args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
139-
args.add_argument('--task_file_path', type=str, required=True)
14049
args.add_argument('--task_start_index', type=int, default=900)
14150
args.add_argument('--task_end_index', type=int, default=1000)
14251

@@ -145,7 +54,7 @@ def parse_args():
14554

14655
args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
14756
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
148-
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'])
57+
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
14958
args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
15059
args.add_argument('--n_evaluate_sample', type=int, default=1)
15160
args.add_argument('--n_select_sample', type=int, default=1)
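The generate/evaluate/select loop deleted here lives on in `tot/methods/bfs.py`; as a standalone illustration of the selection step (a sketch mirroring the deleted code above, not the packaged module), the two `--method_select` policies reduce to:

```python
# Standalone sketch of the BFS selection step, mirroring the deleted
# solve() above rather than importing the packaged module: keep
# n_select_sample candidates, either greedily by value or by
# value-proportional sampling.
import numpy as np

def select(new_ys, values, n_select_sample, method_select='greedy'):
    ids = list(range(len(new_ys)))
    if method_select == 'sample':
        ps = np.array(values) / sum(values)  # values -> sampling probabilities
        return [new_ys[i] for i in np.random.choice(ids, size=n_select_sample, p=ps)]
    # greedy: highest-valued candidates first
    return [new_ys[i] for i in sorted(ids, key=lambda i: values[i], reverse=True)[:n_select_sample]]

print(select(['a', 'b', 'c'], [0.1, 3.0, 1.0], 2))  # ['b', 'c']
```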

scripts/crosswords/cot_sampling.sh

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 python run.py \
     --task crosswords \
-    --task_file_path mini0505_0_100_5.json \
     --task_start_index 0 \
     --task_end_index 20 \
     --naive_run \
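`--task_file_path` can be dropped because each task now resolves its own bundled data file; a hypothetical sketch of the registry that `tot.tasks.get_task(args.task)` implies (the real mapping lives in `tot/tasks/__init__.py` and may differ):

```python
# Hypothetical sketch of the name-to-class registry implied by
# tot.tasks.get_task(args.task). The actual tot/tasks/__init__.py may
# differ; only Game24Task is confirmed by this diff.
from tot.tasks.game24 import Game24Task

TASKS = {'game24': Game24Task}  # plus 'text' and 'crosswords' entries

def get_task(name):
    if name not in TASKS:
        raise NotImplementedError(f'task {name} is not implemented')
    return TASKS[name]()  # task locates its default file under tot/data/
```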

scripts/crosswords/search_crosswords-dfs.ipynb

Lines changed: 5 additions & 5 deletions
@@ -14,7 +14,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "cd ../.."
+    "cd .."
    ]
   },
   {
@@ -24,9 +24,9 @@
    "outputs": [],
    "source": [
     "import json\n",
-    "from prompts.crosswords import propose_prompt, value_prompt\n",
-    "from models import gpt\n",
-    "from tasks.crosswords import MiniCrosswordsEnv\n",
+    "from tot.prompts.crosswords import propose_prompt, value_prompt\n",
+    "from tot.models import gpt\n",
+    "from tot.tasks.crosswords import MiniCrosswordsEnv\n",
     "\n",
     "env = MiniCrosswordsEnv()"
    ]
@@ -61,7 +61,7 @@
    "source": [
     "import re\n",
     "import copy\n",
-    "from models import gpt\n",
+    "from tot.models import gpt\n",
     "\n",
     "def parse_line(input_str):\n",
     "    # regular expression pattern to match the input string format\n",

scripts/crosswords/standard_sampling.sh

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 python run.py \
     --task crosswords \
-    --task_file_path mini0505_0_100_5.json \
     --task_start_index 0 \
     --task_end_index 20 \
     --naive_run \
