Skip to content

Commit 85d85d8

Browse files
authored
feat(budget_checker):refactor budget checker (#276)
* feat(budget_checker):refactor budget checker * feat(budget_checker): fix pre commit * feat(budget_checker): refactor budget checker as main class
1 parent 62c765a commit 85d85d8

File tree

10 files changed

+45
-39
lines changed

10 files changed

+45
-39
lines changed

tiny_scientist/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
from .budget_checker import BudgetChecker
12
from .coder import Coder
23
from .reviewer import Reviewer
34
from .safety_checker import SafetyChecker
45
from .scientist import TinyScientist
56
from .thinker import Thinker
6-
from .utils.checker import Checker
77
from .writer import Writer
88

99
__all__ = [
@@ -13,5 +13,5 @@
1313
"Thinker",
1414
"Writer",
1515
"TinyScientist",
16-
"Checker",
16+
"BudgetChecker",
1717
]

tiny_scientist/utils/checker.py renamed to tiny_scientist/budget_checker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class BudgetExceededError(Exception):
99
"""Raised when a call would exceed the configured budget."""
1010

1111

12-
class Checker:
12+
class BudgetChecker:
1313
"""Track API usage cost and enforce a spending budget."""
1414

1515
def __init__(self, budget: Optional[float] = None) -> None:
@@ -48,7 +48,7 @@ def add_cost(
4848
if task_name not in self.per_task_cost:
4949
self.per_task_cost[task_name] = 0.0
5050
self.per_task_cost[task_name] += cost
51-
return cost
51+
return float(cost)
5252

5353
def report(self) -> None:
5454
print(f"Total cost: ${self.total_cost:.4f}")

tiny_scientist/coder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from aider.models import Model
1313
from rich import print
1414

15+
from .budget_checker import BudgetChecker
1516
from .configs import Config
16-
from .utils.checker import Checker
1717
from .utils.llm import create_client, get_response_from_llm
1818

1919

@@ -28,7 +28,7 @@ def __init__(
2828
prompt_template_dir: Optional[str] = None,
2929
chat_history: Optional[str] = None,
3030
auto_install: bool = True,
31-
cost_tracker: Optional[Checker] = None,
31+
cost_tracker: Optional[BudgetChecker] = None,
3232
):
3333
"""Initialize the ExperimentCoder with configuration and Aider setup."""
3434
self.client, self.model = create_client(model)
@@ -38,7 +38,7 @@ def __init__(
3838
self.max_stderr_output = max_stderr_output
3939
self.auto_install = auto_install
4040
self.config = Config()
41-
self.cost_tracker = cost_tracker or Checker()
41+
self.cost_tracker = cost_tracker or BudgetChecker()
4242

4343
# Load prompts
4444
self.prompts = self.config.prompt_template.coder_prompt

tiny_scientist/reviewer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
from rich import print
55

6+
from .budget_checker import BudgetChecker
67
from .configs import Config
78
from .tool import BaseTool, PaperSearchTool
8-
from .utils.checker import Checker
99
from .utils.error_handler import api_calling_error_exponential_backoff
1010
from .utils.input_formatter import InputFormatter
1111
from .utils.llm import (
@@ -24,7 +24,7 @@ def __init__(
2424
num_reflections: int = 2,
2525
temperature: float = 0.75,
2626
prompt_template_dir: Optional[str] = None,
27-
cost_tracker: Optional[Checker] = None,
27+
cost_tracker: Optional[BudgetChecker] = None,
2828
pre_reflection_threshold: float = 0.5,
2929
post_reflection_threshold: float = 0.8,
3030
s2_api_key: Optional[str] = None,
@@ -38,7 +38,7 @@ def __init__(
3838
self.searcher: BaseTool = PaperSearchTool(s2_api_key=s2_api_key)
3939
self._query_cache: Dict[str, List[Dict[str, Any]]] = {}
4040
self.last_related_works_string = ""
41-
self.cost_tracker = cost_tracker or Checker()
41+
self.cost_tracker = cost_tracker or BudgetChecker()
4242
self.pre_reflection_threshold = pre_reflection_threshold
4343
self.post_reflection_threshold = post_reflection_threshold
4444

tiny_scientist/safety_checker.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
import yaml
77

8+
from .budget_checker import BudgetChecker
89
from .configs import Config
9-
from .utils.cost_tracker import CostTracker
1010
from .utils.error_handler import api_calling_error_exponential_backoff
1111
from .utils.llm import create_client, get_response_from_llm
1212

@@ -44,10 +44,10 @@ def is_rejection_response(response: str) -> bool:
4444

4545
class PromptAttackDetector:
4646
def __init__(
47-
self, model: str = "gpt-4o", cost_tracker: Optional[CostTracker] = None
47+
self, model: str = "gpt-4o", cost_tracker: Optional[BudgetChecker] = None
4848
) -> None:
4949
self.client, self.model = create_client(model)
50-
self.cost_tracker = cost_tracker or CostTracker()
50+
self.cost_tracker = cost_tracker or BudgetChecker()
5151
self.config = Config()
5252
self.prompts = self.config.prompt_template.safety_prompt
5353

@@ -148,10 +148,10 @@ class SafetyChecker:
148148
"""
149149

150150
def __init__(
151-
self, model: str = "gpt-4o", cost_tracker: Optional[CostTracker] = None
151+
self, model: str = "gpt-4o", cost_tracker: Optional[BudgetChecker] = None
152152
) -> None:
153153
self.model = model
154-
self.cost_tracker = cost_tracker or CostTracker()
154+
self.cost_tracker = cost_tracker or BudgetChecker()
155155
self.detector = PromptAttackDetector(
156156
model=model, cost_tracker=self.cost_tracker
157157
)
@@ -199,14 +199,16 @@ def check_safety(self, intent: str) -> Tuple[bool, Dict[str, Any]]:
199199
self.cost_tracker.report()
200200
return is_safe, safety_report
201201

202-
def _load_ethics_prompts(self) -> Dict[str, str]:
202+
def _load_ethics_prompts(self) -> Dict[str, Any]:
203203
"""Load ethics prompts from the YAML file."""
204204
prompt_path = os.path.join(
205205
os.path.dirname(__file__), "prompts", "safetychecker_prompt.yaml"
206206
)
207207
try:
208208
with open(prompt_path, "r", encoding="utf-8") as file:
209-
return yaml.safe_load(file)
209+
import typing
210+
211+
return typing.cast(Dict[str, Any], yaml.safe_load(file))
210212
except FileNotFoundError:
211213
print(f"Warning: Ethics prompts file not found at {prompt_path}")
212214
return {}
@@ -399,7 +401,9 @@ def comprehensive_safety_check(
399401
# Overall safety determination
400402
overall_safe = is_intent_safe
401403
if idea is not None:
402-
overall_safe = overall_safe and result["idea_ethics"]["is_ethically_sound"]
404+
overall_safe = overall_safe and bool(
405+
result["idea_ethics"]["is_ethically_sound"]
406+
)
403407

404408
result["overall_safety"] = {
405409
"is_safe": overall_safe,

tiny_scientist/scientist.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
import toml
55
from rich import print
66

7+
from .budget_checker import BudgetChecker
78
from .coder import Coder
89
from .reviewer import Reviewer
910
from .safety_checker import SafetyChecker
1011
from .thinker import Thinker
11-
from .utils.checker import Checker
12-
from .utils.cost_tracker import CostTracker
1312
from .utils.input_formatter import InputFormatter
1413
from .writer import Writer
1514

@@ -80,7 +79,7 @@ def __init__(
8079

8180
self.safety_checker = (
8281
SafetyChecker(
83-
model=model, cost_tracker=CostTracker(budget=per_module_budget)
82+
model=model, cost_tracker=BudgetChecker(budget=per_module_budget)
8483
)
8584
if enable_safety_check
8685
else None
@@ -94,8 +93,9 @@ def __init__(
9493
iter_num=3,
9594
search_papers=True,
9695
generate_exp_plan=True,
96+
enable_ethical_defense=False,
9797
enable_safety_check=enable_safety_check,
98-
cost_tracker=Checker(budget=allocation.get("thinker")),
98+
cost_tracker=BudgetChecker(budget=allocation.get("thinker")),
9999
)
100100

101101
self.coder = Coder(
@@ -104,22 +104,22 @@ def __init__(
104104
prompt_template_dir=prompt_template_dir,
105105
max_iters=4,
106106
max_runs=3,
107-
cost_tracker=Checker(budget=allocation.get("coder")),
107+
cost_tracker=BudgetChecker(budget=allocation.get("coder")),
108108
)
109109

110110
self.writer = Writer(
111111
model=model,
112112
output_dir=output_dir,
113113
prompt_template_dir=prompt_template_dir,
114114
template=template,
115-
cost_tracker=Checker(budget=allocation.get("writer")),
115+
cost_tracker=BudgetChecker(budget=allocation.get("writer")),
116116
)
117117

118118
self.reviewer = Reviewer(
119119
model=model,
120120
prompt_template_dir=prompt_template_dir,
121121
tools=[],
122-
cost_tracker=Checker(budget=allocation.get("reviewer")),
122+
cost_tracker=BudgetChecker(budget=allocation.get("reviewer")),
123123
)
124124

125125
def think(

tiny_scientist/thinker.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
from rich import print
66

7+
from .budget_checker import BudgetChecker
78
from .configs import Config
89
from .safety_checker import SafetyChecker
910
from .tool import PaperSearchTool
10-
from .utils.checker import Checker
1111
from .utils.error_handler import api_calling_error_exponential_backoff
1212
from .utils.llm import (
1313
create_client,
@@ -27,10 +27,11 @@ def __init__(
2727
output_dir: str = "",
2828
temperature: float = 0.75,
2929
prompt_template_dir: Optional[str] = None,
30-
cost_tracker: Optional[Checker] = None,
30+
cost_tracker: Optional[BudgetChecker] = None,
3131
enable_safety_check: bool = False,
3232
pre_reflection_threshold: float = 0.5,
3333
post_reflection_threshold: float = 0.8,
34+
enable_ethical_defense: bool = False,
3435
):
3536
self.tools = tools
3637
self.iter_num = iter_num
@@ -70,12 +71,13 @@ def __init__(
7071
3. Novelty: How original is the idea compared to existing work?
7172
4. Feasibility: How practical is implementation within reasonable resource constraints?
7273
5. Impact: What is the potential impact of this research on the field and broader applications?"""
73-
self.cost_tracker = cost_tracker or Checker()
74+
self.cost_tracker = cost_tracker or BudgetChecker()
7475
self.pre_reflection_threshold = pre_reflection_threshold
7576
self.post_reflection_threshold = post_reflection_threshold
7677

7778
self.enable_safety_check = enable_safety_check
78-
# Initialize SafetyChecker for comprehensive safety checks
79+
self.enable_ethical_defense = enable_ethical_defense
80+
self.safety_checker: Optional[SafetyChecker]
7981
if self.enable_safety_check:
8082
self.safety_checker = SafetyChecker(
8183
model=self.model, cost_tracker=self.cost_tracker

tiny_scientist/tool.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import toml
1212
from rich import print
1313

14+
from .budget_checker import BudgetChecker
1415
from .configs import Config
15-
from .utils.checker import Checker
1616
from .utils.error_handler import api_calling_error_exponential_backoff
1717
from .utils.llm import create_client, get_response_from_llm
1818

@@ -22,8 +22,8 @@
2222

2323

2424
class BaseTool(abc.ABC):
25-
def __init__(self, cost_tracker: Optional[Checker] = None) -> None:
26-
self.cost_tracker = cost_tracker or Checker()
25+
def __init__(self, cost_tracker: Optional[BudgetChecker] = None) -> None:
26+
self.cost_tracker = cost_tracker or BudgetChecker()
2727
self.github_token = config["core"].get("github_token", None)
2828

2929
@abc.abstractmethod

tiny_scientist/utils/llm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import toml
1010
from google.generativeai.types import GenerationConfig
1111

12-
from tiny_scientist.utils.checker import Checker
12+
from tiny_scientist.budget_checker import BudgetChecker
1313

1414
# Load config
1515
config_path = os.path.join(
@@ -84,7 +84,7 @@ def get_batch_responses_from_llm(
8484
msg_history: Any = None,
8585
temperature: float = 0.75,
8686
n_responses: int = 1,
87-
cost_tracker: Optional[Checker] = None,
87+
cost_tracker: Optional[BudgetChecker] = None,
8888
task_name: Optional[str] = None,
8989
) -> Tuple[List[str], List[List[Dict[str, str]]]]:
9090
if msg_history is None:
@@ -214,7 +214,7 @@ def get_response_from_llm(
214214
print_debug: bool = False,
215215
msg_history: Any = None,
216216
temperature: float = 0.75,
217-
cost_tracker: Optional[Checker] = None,
217+
cost_tracker: Optional[BudgetChecker] = None,
218218
task_name: Optional[str] = None,
219219
) -> Tuple[str, List[Dict[str, Any]]]:
220220
if msg_history is None:
@@ -442,7 +442,7 @@ def get_batch_responses_from_llm_with_tools(
442442
msg_history: Optional[List[Dict[str, str]]] = None,
443443
temperature: float = 0.75,
444444
n_responses: int = 1,
445-
cost_tracker: Optional[Checker] = None,
445+
cost_tracker: Optional[BudgetChecker] = None,
446446
task_name: Optional[str] = None,
447447
) -> Tuple[List[Union[str, Dict[str, Any]]], List[List[Dict[str, str]]]]:
448448
"""

tiny_scientist/writer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import cairosvg
1111
from rich import print
1212

13+
from .budget_checker import BudgetChecker
1314
from .configs import Config
1415
from .tool import BaseTool, DrawerTool, PaperSearchTool
15-
from .utils.checker import Checker
1616
from .utils.llm import (
1717
create_client,
1818
extract_json_between_markers,
@@ -33,7 +33,7 @@ def __init__(
3333
template: str,
3434
temperature: float = 0.75,
3535
prompt_template_dir: Optional[str] = None,
36-
cost_tracker: Optional[Checker] = None,
36+
cost_tracker: Optional[BudgetChecker] = None,
3737
s2_api_key: Optional[str] = None,
3838
) -> None:
3939
self.client, self.model = create_client(model)
@@ -50,7 +50,7 @@ def __init__(
5050
self.formatter = ICLROutputFormatter(model=self.model, client=self.client)
5151

5252
self.prompts = self.config.prompt_template.writer_prompt
53-
self.cost_tracker = cost_tracker or Checker()
53+
self.cost_tracker = cost_tracker or BudgetChecker()
5454

5555
with resources.files("tiny_scientist.fewshot_sample").joinpath(
5656
"automated_relational.txt"

0 commit comments

Comments
 (0)