Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 40 additions & 32 deletions program_searcher/evolution_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@

from typing_extensions import Dict, override

from program_searcher.exceptions import InvalidMutationStrategiesError
from program_searcher.mutation_strategy import MutationStrategy
from program_searcher.program_model import Program


class EvolutionOperator(ABC):
@abstractmethod
def apply(
self,
population: deque[Program],
fitnesses: Dict[Program, float],
mutation_strategies: Dict[MutationStrategy, float],
):
def apply(self, population: deque[Program], fitnesses: Dict[Program, float]):
"""
Applies an evolution step to the population.

Expand All @@ -25,9 +21,6 @@ def apply(
Current population of programs.
fitnesses : Dict[Program, float]
Mapping from program to its fitness.
mutation_strategies : Dict[MutationStrategy, float]
Per-program mutation strategies with associated probabilities.

"""
raise NotImplementedError

Expand All @@ -47,30 +40,45 @@ class TournamentSelectionOperator(EvolutionOperator, ABC):
The number of programs to include in each tournament.
"""

def __init__(self, tournament_size: int):
def __init__(
self, tournament_size: int, mutation_strategies: Dict[MutationStrategy, float]
):
self.tournament_size = tournament_size
self.mutation_strategies = mutation_strategies

self._validate_mutation_strategies()

@override
def apply(
self,
population: deque[Program],
fitnesses: Dict[Program, float],
mutation_strategies: Dict[MutationStrategy, float],
):
def apply(self, population: deque[Program], fitnesses: Dict[Program, float]):
tournament_programs = random.choices(population, k=self.tournament_size)
best_program = max(tournament_programs, key=lambda prog: fitnesses[prog])
tournament_winner = best_program.copy()
tournament_winner = max(tournament_programs, key=lambda prog: fitnesses[prog])

program = population.popleft()
fitnesses.pop(program)

strategies = list(mutation_strategies.keys())
weights = list(mutation_strategies.values())
strategies = list(self.mutation_strategies.keys())
weights = list(self.mutation_strategies.values())

chosen_strategy = random.choices(strategies, weights=weights, k=1)[0]
chosen_strategy.mutate(tournament_winner)
mutated_program = chosen_strategy.mutate(tournament_winner)

population.append(mutated_program)

population.append(tournament_winner)
def _validate_mutation_strategies(self):
if abs(sum(self.mutation_strategies.values()) - 1.0) > 1e-6:
raise InvalidMutationStrategiesError(
f"sum of mutation_strategies values must be 1.0, but is {sum(self.mutation_strategies.values())}."
)

if any(value < 0 for value in self.mutation_strategies.values()):
raise InvalidMutationStrategiesError(
f"all mutation_strategies values must be >= 0. current values: {self.mutation_strategies}."
)

if any(value > 1 for value in self.mutation_strategies.values()):
raise InvalidMutationStrategiesError(
f"all mutation_strategies values must be <= 1. current values: {self.mutation_strategies}."
)


class FullPopulationMutationOperator(EvolutionOperator):
Expand All @@ -94,17 +102,17 @@ class FullPopulationMutationOperator(EvolutionOperator):
within this method.
"""

def __init__(self, mutation_strategies: Dict[MutationStrategy, float]):
self.mutation_strategies = mutation_strategies

@override
def apply(
self,
population: deque[Program],
fitnesses: Dict[Program, float],
mutation_strategies: Dict[MutationStrategy, float],
):
for program in population:
def apply(self, population: deque[Program], fitnesses: Dict[Program, float]):
for index, program in enumerate(population):
chosen_strategy = random.choices(
list(mutation_strategies.keys()),
weights=list(mutation_strategies.values()),
list(self.mutation_strategies.keys()),
weights=list(self.mutation_strategies.values()),
k=1,
)[0]
chosen_strategy.mutate(program)

mutated = chosen_strategy.mutate(program)
population[index] = mutated
6 changes: 5 additions & 1 deletion program_searcher/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ class ExecuteProgramError(Exception):
pass


class InvalidProgramSearchArgumentValue(Exception):
class InvalidProgramSearchArgumentValueError(Exception):
pass


class InvalidMutationStrategiesError(Exception):
pass
86 changes: 50 additions & 36 deletions program_searcher/mutation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,27 @@ class MutationStrategy(ABC):
"""
Abstract base class for mutation strategies.

Each implementation should modify the given `Program` object **in-place**
rather than returning a new object. This ensures that mutation is applied
directly without requiring extra copying or reassignment logic.
Each implementation should create a **new Program instance** that represents
the mutated version of the original program, rather than modifying the
input program in-place. This ensures that the original program remains
unchanged and allows safe caching of fitness values.
"""

@abstractmethod
def mutate(self, program: Program):
def mutate(self, program: Program) -> Program:
"""
Mutates the given program in-place.
Produces a mutated copy of the given program.

Parameters
----------
program : Program
The program object to be mutated.
The original program to base the mutation on. This program
**must not** be modified.

Returns
-------
None
The mutation is applied directly to `program`; no new object
should be returned.
Program
A new Program instance representing the mutated version of the input.
"""
raise NotImplementedError

Expand All @@ -50,7 +51,7 @@ class RemoveStatementMutationStrategy(MutationStrategy, ABC):
Defaults to 3.

Methods:
mutate(program: Program) -> None:
mutate(program: Program) -> Program:
Tries to remove one statement from the given program.
If no removable statements are available, or all retries fail,
the program remains unchanged.
Expand All @@ -60,19 +61,20 @@ def __init__(self, remove_retries: int = 3):
self.remove_retries = remove_retries

@override
def mutate(self, program: Program):
def mutate(self, program: Program) -> Program:
max_index = len(program) - (1 if program.has_return_statement() else 0)

if max_index <= 0:
return
return program

for _ in range(self.remove_retries):
try:
program_cp = program.copy()
statement_to_remove_idx = random.randrange(max_index)
program.remove_statement(statement_to_remove_idx)
return
program_cp.remove_statement(statement_to_remove_idx)
return program_cp
except RemoveStatementError:
pass
return program


class ReplaceStatementMutationStrategy(MutationStrategy, ABC):
Expand All @@ -92,7 +94,7 @@ class ReplaceStatementMutationStrategy(MutationStrategy, ABC):
that function expects.

Methods:
mutate(program: Program) -> None:
mutate(program: Program) -> Program:
Replaces one statement's function and arguments. If the program
has no eligible statements (only a `return` or is empty),
the program remains unchanged.
Expand All @@ -102,24 +104,28 @@ def __init__(self, available_functions: Dict[str, int]):
self.available_functions = available_functions

@override
def mutate(self, program: Program):
def mutate(self, program: Program) -> Program:
program_cp = program.copy()

func_name = random.choice(list(self.available_functions.keys()))
args_size = self.available_functions[func_name]

if not program.variables and args_size > 0:
return
if not program_cp.variables and args_size > 0:
return program

args = random.choices(program.variables, k=args_size)
args = random.choices(program_cp.variables, k=args_size)

max_index = len(program) - (1 if program.has_return_statement() else 0)
max_index = len(program_cp) - (1 if program_cp.has_return_statement() else 0)
if max_index <= 0:
return
return program

replace_index = random.randrange(max_index)
replace_statement = program.get_statement(replace_index)
replace_statement = program_cp.get_statement(replace_index)
replace_statement.func = func_name
replace_statement.args = args

return program_cp


class UpdateStatementArgsMutationStrategy(MutationStrategy, ABC):
"""
Expand All @@ -131,22 +137,29 @@ class UpdateStatementArgsMutationStrategy(MutationStrategy, ABC):
statement. The statement's function and target variable are not modified.

Methods:
mutate(program: Program) -> None:
mutate(program: Program) -> Program:
Updates the arguments of one statement in the program. If the program
is empty, no mutation is performed.
"""

@override
def mutate(self, program: Program):
def mutate(self, program: Program) -> Program:
if len(program) == 0:
return
return program

statement_idx = random.randrange(len(program))
statement = program.get_statement(statement_idx)
program_cp = program.copy()
statement_idx = random.randrange(len(program_cp))
statement = program_cp.get_statement(statement_idx)
statement_args_count = len(statement.args)

new_args = random.choices(program.variables, k=statement_args_count)
pr_vars = set(program_cp.variables)

if statement.func != Statement.RETURN_KEYWORD:
pr_vars.remove(statement.result_var_name)

new_args = random.choices(list(pr_vars), k=statement_args_count)
statement.args = new_args
return program_cp


class InsertStatementMutationStrategy(MutationStrategy, ABC):
Expand All @@ -167,18 +180,19 @@ def __init__(self, available_functions: Dict[str, int]):
self.available_functions = available_functions

@override
def mutate(self, program: Program):
if len(program) == 0:
return
def mutate(self, program: Program) -> Program:
program_cp = program.copy()

max_index = len(program) - (1 if program.has_return_statement() else 0)
max_index = len(program_cp) - (1 if program_cp.has_return_statement() else 0)
if max_index <= 0:
return
return program

insert_index = random.randrange(max_index)

statement = self._generate_random_statement(program.variables)
program.insert_statement(statement, insert_index)
statement = self._generate_random_statement(program_cp.variables)
program_cp.insert_statement(statement, insert_index)

return program_cp

def _generate_random_statement(self, program_vars: List[str]) -> Statement:
func_name = random.choice(list(self.available_functions.keys()))
Expand Down
Loading
Loading