Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
728456c
Fix: Auto-increment seed across batch_run iterations
EwoutH Oct 7, 2025
c03c6fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2025
37a0839
Merge remote-tracking branch 'upstream/main' into batch_seed
quaquel Nov 12, 2025
7f456af
add rng as kwarg and deprecate iterations
quaquel Nov 12, 2025
10136f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2025
904e796
fix for typo in value error message
quaquel Nov 12, 2025
7972638
Update batchrunner.py
quaquel Nov 12, 2025
529f3ac
Update migration_guide.md
quaquel Nov 16, 2025
7b6eaef
Update migration_guide.md
quaquel Nov 16, 2025
e53e16b
Update migration_guide.md
quaquel Nov 16, 2025
4113a11
Update 9_batch_run.ipynb
quaquel Nov 16, 2025
94c44e3
add support for both seed and rng
quaquel Nov 16, 2025
0778ff4
Update migration_guide.md
quaquel Nov 16, 2025
5db0058
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 16, 2025
27a777d
Update test_batch_run.py
quaquel Nov 17, 2025
0a4ea99
Update test_batch_run.py
quaquel Nov 17, 2025
4726920
Update docs/migration_guide.md
quaquel Nov 17, 2025
7ba67df
Update mesa/batchrunner.py
quaquel Nov 17, 2025
088425a
Update mesa/batchrunner.py
quaquel Nov 17, 2025
435399c
Update batchrunner.py
quaquel Nov 17, 2025
f5b017a
Merge branch 'main' into batch_seed
quaquel Nov 25, 2025
f0167e6
remove forced newlines from migration guide
EwoutH Nov 26, 2025
f2007e3
Merge remote-tracking branch 'upstream/main' into batch_seed
quaquel Dec 5, 2025
9341cce
Update batchrunner.py
quaquel Dec 5, 2025
4ebe7d8
Update test_batch_run.py
quaquel Dec 5, 2025
8b338a9
Update test_batch_run.py
quaquel Dec 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions mesa/batchrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,32 @@

import itertools
import multiprocessing
from collections.abc import Iterable, Mapping
import warnings
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from multiprocessing import Pool
from typing import Any

import numpy as np
from tqdm.auto import tqdm

from mesa.model import Model

multiprocessing.set_start_method("spawn", force=True)

SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence


def batch_run(
model_cls: type[Model],
parameters: Mapping[str, Any | Iterable[Any]],
# We still retain the Optional[int] because users may set it to None (i.e. use all CPUs)
number_processes: int | None = 1,
iterations: int = 1,
iterations: int | None = None,
data_collection_period: int = -1,
max_steps: int = 1000,
display_progress: bool = True,
rng: SeedLike | Iterable[SeedLike] | None = None,
) -> list[dict[str, Any]]:
"""Batch run a mesa model with a set of parameter values.

Expand All @@ -62,6 +67,7 @@ def batch_run(
data_collection_period (int, optional): Number of steps after which data gets collected, by default -1 (end of episode)
max_steps (int, optional): Maximum number of model steps after which the model halts, by default 1000
display_progress (bool, optional): Display batch run process, by default True
rng : a valid value or iterable of values for seeding the random number generator in the model

Returns:
List[Dict[str, Any]]
Expand All @@ -70,11 +76,28 @@ def batch_run(
batch_run assumes the model has a `datacollector` attribute that has a DataCollector object initialized.

"""
if iterations is not None and rng is not None:
raise ValueError(
"you cannot use both iterations and rng at the same time. Please only use rng."
)
if iterations is not None:
warnings.warn(
"iterations is deprecated, please use rng instead",
DeprecationWarning,
stacklevel=2,
)
rng = [
None,
] * iterations
if not isinstance(rng, Iterable):
rng = [rng]

runs_list = []
run_id = 0
for iteration in range(iterations):
for i, rng_i in enumerate(rng):
for kwargs in _make_model_kwargs(parameters):
runs_list.append((run_id, iteration, kwargs))
kwargs["rng"] = rng_i
runs_list.append((run_id, i, kwargs))
run_id += 1

process_func = partial(
Expand Down Expand Up @@ -170,6 +193,7 @@ def _model_run_func(
Return model_data, agent_data from the reporters
"""
run_id, iteration, kwargs = run

model = model_cls(**kwargs)
while model.running and model.steps <= max_steps:
model.step()
Expand Down
110 changes: 109 additions & 1 deletion tests/test_batch_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test Batchrunner."""

import pytest

import mesa
from mesa.agent import Agent
from mesa.batchrunner import _make_model_kwargs
Expand Down Expand Up @@ -130,7 +132,7 @@ def step(self): # noqa: D102


def test_batch_run(): # noqa: D103
result = mesa.batch_run(MockModel, {}, number_processes=2)
result = mesa.batch_run(MockModel, {}, number_processes=2, rng=42)
assert result == [
{
"RunId": 0,
Expand All @@ -140,6 +142,7 @@ def test_batch_run(): # noqa: D103
"AgentID": 1,
"agent_id": 1,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 0,
Expand All @@ -149,6 +152,7 @@ def test_batch_run(): # noqa: D103
"AgentID": 2,
"agent_id": 2,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 0,
Expand All @@ -158,9 +162,111 @@ def test_batch_run(): # noqa: D103
"AgentID": 3,
"agent_id": 3,
"agent_local": 250.0,
"rng": 42,
},
]

result = mesa.batch_run(MockModel, {}, number_processes=2, iterations=1)
assert result == [
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 1,
"agent_id": 1,
"agent_local": 250.0,
"rng": None,
},
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 2,
"agent_id": 2,
"agent_local": 250.0,
"rng": None,
},
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 3,
"agent_id": 3,
"agent_local": 250.0,
"rng": None,
},
]

result = mesa.batch_run(MockModel, {}, number_processes=2, rng=[42, 31415])
assert result == [
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 1,
"agent_id": 1,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 2,
"agent_id": 2,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 3,
"agent_id": 3,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 1,
"iteration": 1,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 1,
"agent_id": 1,
"agent_local": 250.0,
"rng": 31415,
},
{
"RunId": 1,
"iteration": 1,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 2,
"agent_id": 2,
"agent_local": 250.0,
"rng": 31415,
},
{
"RunId": 1,
"iteration": 1,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 3,
"agent_id": 3,
"agent_local": 250.0,
"rng": 31415,
},
]

with pytest.raises(ValueError):
mesa.batch_run(MockModel, {}, number_processes=2, rng=42, iterations=1)


def test_batch_run_with_params(): # noqa: D103
mesa.batch_run(
Expand All @@ -185,6 +291,7 @@ def test_batch_run_no_agent_reporters(): # noqa: D103
"Step": 1000,
"enable_agent_reporters": False,
"reported_model_param": 42,
"rng": None,
}
]

Expand All @@ -208,6 +315,7 @@ def test_batch_run_unhashable_param(): # noqa: D103
"agent_local": 250.0,
"n_agents": 2,
"variable_model_params": {"key": "value"},
"rng": None,
}

assert result == [
Expand Down
Loading