Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from mellea.backends import ModelOption
from mellea.stdlib.requirement import check, req, simple_validate
from mellea.stdlib.sampling.prefix_cached import RejectionSamplingStrategyWithPrefix

def _is_all_lowercase(text):
    """Return True when lower-casing ``text`` leaves it unchanged."""
    return text.lower() == text


# Requirements applied to every generated email. The second one is checked
# by a plain Python predicate over the model output; the others carry no
# validation function of their own.
requirements = [
    req("The email should have a salutation"),  # == r1
    req(
        "Use only lower-case letters",
        validation_fn=simple_validate(_is_all_lowercase),
    ),  # == r2
    check("Do not mention purple elephants."),  # == r3
    req("The email should be funny."),
]

import mellea # noqa: E402


def write_email(m: mellea.MelleaSession, name: str, notes: str) -> str:
    """Generate an email for ``name`` from ``notes`` using prefix-cached rejection sampling.

    Args:
        m: Session used to issue the ``instruct`` call.
        name: Recipient name substituted into the prompt template.
        notes: Notes substituted into the prompt template.

    Returns:
        ``str(result)`` of the sampling run when it reports success; otherwise
        the ``value`` of the first entry in ``sample_generations``.
    """
    sampling = m.instruct(
        "Write an email to {{name}} using the notes following: {{notes}}.",
        requirements=requirements,
        strategy=RejectionSamplingStrategyWithPrefix(loop_budget=5),
        user_variables={"name": name, "notes": notes},
        return_sampling_results=True,
    )

    # No candidate passed within the loop budget: fall back to the first attempt.
    if not sampling.success:
        return sampling.sample_generations[0].value
    return str(sampling.result)
2 changes: 1 addition & 1 deletion docs/examples/mcp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ uv pip install "mcp[cli]"

and run the example in the MCP debug UI:
```bash
uv run mcp dev docs/examples/tutorial/mcp_example.py
uv run mcp dev docs/examples/mcp/mcp_example.py
```


Expand Down
9 changes: 7 additions & 2 deletions mellea/stdlib/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,9 @@ def chat(

def validate(
reqs: Requirement | list[Requirement],
context: Context,
backend: Backend,
*,
context: Context | None = None,
output: CBlock | None = None,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
Expand Down Expand Up @@ -701,9 +701,9 @@ async def achat(

async def avalidate(
reqs: Requirement | list[Requirement],
context: Context,
backend: Backend,
*,
context: Context | None = None,
output: CBlock | None = None,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
Expand All @@ -715,7 +715,12 @@ async def avalidate(
reqs = [reqs] if not isinstance(reqs, list) else reqs
reqs = [Requirement(req) if type(req) is str else req for req in reqs]

assert (context is not None) != (output is not None), (
"Either context or output must be provided. Not both."
)

if output is None:
assert context is not None
validation_target_ctx = context
else:
validation_target_ctx = SimpleContext()
Expand Down
15 changes: 7 additions & 8 deletions mellea/stdlib/reqlib/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import mistletoe

from mellea.stdlib.base import Context
from mellea.stdlib.requirement import Requirement
from mellea.stdlib.requirement import Requirement, ValidationResult, simple_validate

# region lists

Expand All @@ -25,8 +25,9 @@ def as_markdown_list(ctx: Context) -> list[str] | None:
return None


def _md_list(ctx: Context):
return as_markdown_list(ctx) is not None
async def _md_list(ctx: Context) -> ValidationResult:
    """Validate that ``as_markdown_list`` can extract a list from ``ctx``.

    The result is positive exactly when ``as_markdown_list(ctx)`` does not
    return ``None``.
    """
    parsed_items = as_markdown_list(ctx)
    return ValidationResult(parsed_items is not None)


is_markdown_list = Requirement(
Expand All @@ -40,11 +41,10 @@ def _md_list(ctx: Context):
# region tables


def _md_table(ctx: Context):
raw_output = ctx.last_output()
def _md_table(raw_output: str) -> bool:
assert raw_output is not None
try:
parsed = mistletoe.Document(raw_output.value)
parsed = mistletoe.Document(raw_output)
if len(parsed.children) != 1:
return False
return type(parsed.children[0]) is mistletoe.block_token.Table
Expand All @@ -54,8 +54,7 @@ def _md_table(ctx: Context):

is_markdown_table = Requirement(
description="The output should be formatted as a Markdown table.",
validation_fn=_md_table,
validation_fn=simple_validate(_md_table),
)


# endregion
18 changes: 9 additions & 9 deletions mellea/stdlib/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import inspect
import re
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from copy import copy
from typing import Any, overload

Expand Down Expand Up @@ -93,7 +93,7 @@ class Requirement(Component):
def __init__(
self,
description: str | None = None,
validation_fn: Callable[[Context], ValidationResult] | None = None,
validation_fn: Callable[[Context], Awaitable[ValidationResult]] | None = None,
*,
output_to_bool: Callable[[CBlock | str], bool] | None = default_output_to_bool,
check_only: bool = False,
Expand Down Expand Up @@ -127,7 +127,7 @@ async def validate(
"""Chooses the appropriate validation strategy and applies that strategy."""
if self.validation_fn is not None:
# Python validation strategy
return self.validation_fn(ctx)
return await self.validation_fn(ctx)
else:
# LLMaJ validation strategy. This includes ALora because the backend generate call will appropriately dispatch.
assert self.output_to_bool is not None
Expand Down Expand Up @@ -197,7 +197,7 @@ class ScorerRequirement(Requirement):
def __init__(
self,
description: str | None = None,
validation_fn: Callable[[Context], ValidationResult] | None = None,
validation_fn: Callable[[Context], Awaitable[ValidationResult]] | None = None,
preference_ordering: str = "max",
*,
output_to_bool: Callable[[CBlock | str], bool] | None = default_output_to_bool,
Expand Down Expand Up @@ -234,7 +234,7 @@ async def validate(
"""Chooses the appropriate validation strategy and applies that strategy. Asserts that the returned ValidationResult has a valid score."""
if self.validation_fn is not None:
# Python validation strategy
validation_result = self.validation_fn(ctx)
validation_result = await self.validation_fn(ctx)
assert validation_result._score is not None, (
"ScorerRequirement must have a score that is not None"
)
Expand Down Expand Up @@ -292,18 +292,18 @@ def check(*args, **kwargs) -> Requirement:
@overload
def simple_validate(
fn: Callable[[str], tuple[bool, str]],
) -> Callable[[Context], ValidationResult]: ...
) -> Callable[[Context], Awaitable[ValidationResult]]: ...


@overload
def simple_validate(
fn: Callable[[str], bool], *, reason: str | None = None
) -> Callable[[Context], ValidationResult]: ...
) -> Callable[[Context], Awaitable[ValidationResult]]: ...


def simple_validate(
fn: Callable[[str], Any], *, reason: str | None = None
) -> Callable[[Context], ValidationResult]:
) -> Callable[[Context], Awaitable[ValidationResult]]:
"""Syntactic sugar for writing validation functions that only operate over the last output from the model (interpreted as a string).

This is useful when your validation logic only depends upon the most recent model output. For example:
Expand All @@ -321,7 +321,7 @@ def simple_validate(
reason: only used if the provided function returns a bool; if the validation function fails, a static reason for that failure to give to the llm when repairing
"""

def validate(ctx: Context) -> ValidationResult:
async def validate(ctx: Context) -> ValidationResult:
o = ctx.last_output()
if o is None or o.value is None:
FancyLogger.get_logger().warn(
Expand Down
1 change: 0 additions & 1 deletion mellea/stdlib/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ async def sample(
reqs=reqs,
context=result_ctx,
backend=backend,
output=result,
format=format,
model_options=model_options,
# tool_calls=tool_calls # Don't support using tool calls in validation strategies.
Expand Down
1 change: 0 additions & 1 deletion mellea/stdlib/sampling/best_of_n.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ async def sample(
reqs=reqs,
context=result_ctx,
backend=backend,
output=result,
format=format,
model_options=model_options,
input=next_action._description, # type: ignore
Expand Down
85 changes: 85 additions & 0 deletions mellea/stdlib/sampling/prefix_cached.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Sampling Strategy that uses prefix caching idea based on two turn chats."""

from collections.abc import Awaitable, Callable

from mellea.backends import Backend, BaseModelSubclass, ModelOption
from mellea.stdlib.base import ChatContext, Component, Context, ContextTurn
from mellea.stdlib.chat import Message
from mellea.stdlib.requirement import Requirement, ValidationResult
from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult


class RejectionSamplingStrategyWithPrefix(RejectionSamplingStrategy):
    """Rejection Sampling class that uses the last turn as prefix cache for requirement checking.

    Every requirement that has no ``validation_fn`` of its own is given one
    that rebuilds a chat context from the last turn (model input + output) and
    appends a short yes/no question about the requirement. All such checks
    therefore share the same leading turn and differ only in the final message.
    """

    async def sample(
        self,
        action: Component,
        context: Context,
        backend: Backend,
        requirements: list[Requirement] | None,
        *,
        validation_ctx: Context | None = None,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        tool_calls: bool = False,
        show_progress: bool = True,
    ) -> SamplingResult:
        """Install judge validators on the requirements, then delegate to the parent ``sample``.

        Args:
            action: Component to generate from; forwarded to the parent.
            context: Generation context; forwarded to the parent.
            backend: Backend used by the parent and by the judge validators.
            requirements: Requirements used when the strategy has none of its own.
            validation_ctx: Forwarded to the parent unchanged.
            format: Forwarded to the parent and to the judge generation call.
            model_options: Forwarded to the parent. The judge call does not use
                these; it only sets ``ModelOption.MAX_NEW_TOKENS`` to 10.
            tool_calls: Forwarded to the parent unchanged.
            show_progress: Forwarded to the parent unchanged.

        Returns:
            The ``SamplingResult`` produced by the parent ``sample``.
        """
        reqs: list[Requirement] = []
        # Strategy-level requirements take precedence over per-call ones.
        if self.requirements is not None:
            reqs += self.requirements
        elif requirements is not None:
            reqs += requirements
        # De-duplicate while keeping first-seen order so the requirement order
        # is deterministic (a plain set() does not guarantee any order).
        reqs = list(dict.fromkeys(reqs))

        def make_val(
            req_string: str,
        ) -> Callable[[Context], Awaitable[ValidationResult]]:
            """Build an async validator that asks the backend whether ``req_string`` is met."""

            async def validate_agentic(ctx: Context) -> ValidationResult:
                last_turn = ctx.last_turn()
                assert isinstance(last_turn, ContextTurn)
                assert last_turn.model_input is not None
                assert last_turn.output is not None

                # Fresh two-message context: only the last turn is replayed,
                # followed by the judge question below.
                judge_ctx = ChatContext()
                judge_ctx = judge_ctx.add(last_turn.model_input)
                judge_ctx = judge_ctx.add(last_turn.output)

                judge_prompt = Message(
                    role="user",
                    content=f"Does the output fulfill the requirement? Answer only with yes or no. Requirement: '{req_string}'",
                )

                llm_as_a_judge_result, _ = backend.generate_from_context(
                    judge_prompt,
                    judge_ctx,
                    format=format,
                    model_options={ModelOption.MAX_NEW_TOKENS: 10},
                )
                value = await llm_as_a_judge_result.avalue()

                # Normalise before matching: answers such as " Yes" or "\nyes"
                # must count as a pass, and a missing value must not raise.
                answer = (value or "").strip().lower()

                return ValidationResult(
                    result=answer.startswith("yes"),
                    reason=value,
                    thunk=llm_as_a_judge_result,
                )

            return validate_agentic

        # NOTE(review): this assigns onto the caller's Requirement objects, and
        # the installed closure keeps this call's ``backend`` and ``format``.
        # Because only requirements with ``validation_fn is None`` are touched,
        # a later call will not replace a validator installed here.
        for req in reqs:
            if req.validation_fn is None:
                req.validation_fn = make_val(str(req.description))

        return await super().sample(
            action=action,
            context=context,
            backend=backend,
            requirements=reqs,
            validation_ctx=validation_ctx,
            format=format,
            model_options=model_options,
            tool_calls=tool_calls,
            show_progress=show_progress,
        )
4 changes: 2 additions & 2 deletions mellea/stdlib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,8 @@ def validate(
"""Validates a set of requirements over the output (if provided) or the current context (if the output is not provided)."""
return mfuncs.validate(
reqs=reqs,
context=self.ctx,
backend=self.backend,
context=self.ctx if output is None else None,
output=output,
format=format,
model_options=model_options,
Expand Down Expand Up @@ -730,7 +730,7 @@ async def avalidate(
"""Validates a set of requirements over the output (if provided) or the current context (if the output is not provided)."""
return await mfuncs.avalidate(
reqs=reqs,
context=self.ctx,
context=self.ctx if output is None else None,
backend=self.backend,
output=output,
format=format,
Expand Down
4 changes: 1 addition & 3 deletions test/stdlib_basics/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ async def test_ainstruct(m_session):
assert ctx._data is out

async def test_avalidate(m_session):
initial_ctx = m_session.ctx
backend = m_session.backend

val_result = await avalidate(
reqs=[req("Be formal."), req("Avoid telling jokes.")],
context=initial_ctx,
backend=backend,
output=ModelOutputThunk("Here is an output.")
)
Expand All @@ -77,4 +75,4 @@ async def test_avalidate(m_session):


if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__])
Loading