diff --git a/program_searcher/evolution_operator.py b/program_searcher/evolution_operator.py index 9784a51..9e046a7 100644 --- a/program_searcher/evolution_operator.py +++ b/program_searcher/evolution_operator.py @@ -4,18 +4,14 @@ from typing_extensions import Dict, override +from program_searcher.exceptions import InvalidMutationStrategiesError from program_searcher.mutation_strategy import MutationStrategy from program_searcher.program_model import Program class EvolutionOperator(ABC): @abstractmethod - def apply( - self, - population: deque[Program], - fitnesses: Dict[Program, float], - mutation_strategies: Dict[MutationStrategy, float], - ): + def apply(self, population: deque[Program], fitnesses: Dict[Program, float]): """ Applies an evolution step to the population. @@ -25,9 +21,6 @@ def apply( Current population of programs. fitnesses : Dict[Program, float] Mapping from program to its fitness. - mutation_strategies : Dict[MutationStrategy, float] - Per-program mutation strategies with associated probabilities. - """ raise NotImplementedError @@ -47,30 +40,45 @@ class TournamentSelectionOperator(EvolutionOperator, ABC): The number of programs to include in each tournament. """ - def __init__(self, tournament_size: int): + def __init__( + self, tournament_size: int, mutation_strategies: Dict[MutationStrategy, float] + ): self.tournament_size = tournament_size + self.mutation_strategies = mutation_strategies + + self._validate_mutation_strategies() @override - def apply( - self, - population: deque[Program], - fitnesses: Dict[Program, float], - mutation_strategies: Dict[MutationStrategy, float], - ): + def apply(self, population: deque[Program], fitnesses: Dict[Program, float]): tournament_programs = random.choices(population, k=self.tournament_size) - best_program = max(tournament_programs, key=lambda prog: fitnesses[prog]) - tournament_winner = best_program.copy() + tournament_winner = max(tournament_programs, key=lambda prog: fitnesses[prog]) program = population.popleft() fitnesses.pop(program) - strategies = list(mutation_strategies.keys()) - weights = list(mutation_strategies.values()) + strategies = list(self.mutation_strategies.keys()) + weights = list(self.mutation_strategies.values()) chosen_strategy = random.choices(strategies, weights=weights, k=1)[0] - chosen_strategy.mutate(tournament_winner) + mutated_program = chosen_strategy.mutate(tournament_winner) + + population.append(mutated_program) - population.append(tournament_winner) + def _validate_mutation_strategies(self): + if abs(sum(self.mutation_strategies.values()) - 1.0) > 1e-6: + raise InvalidMutationStrategiesError( + f"sum of mutation_strategies values must be 1.0, but is {sum(self.mutation_strategies.values())}." + ) + + if any(value < 0 for value in self.mutation_strategies.values()): + raise InvalidMutationStrategiesError( + f"all mutation_strategies values must be >= 0. current values: {self.mutation_strategies}." + ) + + if any(value > 1 for value in self.mutation_strategies.values()): + raise InvalidMutationStrategiesError( + f"all mutation_strategies values must be <= 1. current values: {self.mutation_strategies}." + ) class FullPopulationMutationOperator(EvolutionOperator): @@ -94,17 +102,17 @@ class FullPopulationMutationOperator(EvolutionOperator): within this method. """ + def __init__(self, mutation_strategies: Dict[MutationStrategy, float]): + self.mutation_strategies = mutation_strategies + @override - def apply( - self, - population: deque[Program], - fitnesses: Dict[Program, float], - mutation_strategies: Dict[MutationStrategy, float], - ): - for program in population: + def apply(self, population: deque[Program], fitnesses: Dict[Program, float]): + for index, program in enumerate(population): chosen_strategy = random.choices( - list(mutation_strategies.keys()), - weights=list(mutation_strategies.values()), + list(self.mutation_strategies.keys()), + weights=list(self.mutation_strategies.values()), k=1, )[0] - chosen_strategy.mutate(program) + + mutated = chosen_strategy.mutate(program) + population[index] = mutated diff --git a/program_searcher/exceptions.py b/program_searcher/exceptions.py index ba189b9..f75ce57 100644 --- a/program_searcher/exceptions.py +++ b/program_searcher/exceptions.py @@ -14,5 +14,9 @@ class ExecuteProgramError(Exception): pass -class InvalidProgramSearchArgumentValue(Exception): +class InvalidProgramSearchArgumentValueError(Exception): + pass + + +class InvalidMutationStrategiesError(Exception): pass diff --git a/program_searcher/mutation_strategy.py b/program_searcher/mutation_strategy.py index 0ec5837..a542c2c 100644 --- a/program_searcher/mutation_strategy.py +++ b/program_searcher/mutation_strategy.py @@ -11,26 +11,27 @@ class MutationStrategy(ABC): """ Abstract base class for mutation strategies. - Each implementation should modify the given `Program` object **in-place** - rather than returning a new object. This ensures that mutation is applied - directly without requiring extra copying or reassignment logic. + Each implementation should create a **new Program instance** that represents + the mutated version of the original program, rather than modifying the + input program in-place. This ensures that the original program remains + unchanged and allows safe caching of fitness values. """ @abstractmethod - def mutate(self, program: Program): + def mutate(self, program: Program) -> Program: """ - Mutates the given program in-place. + Produces a mutated copy of the given program. Parameters ---------- program : Program - The program object to be mutated. + The original program to base the mutation on. This program + **must not** be modified. Returns ------- - None - The mutation is applied directly to `program`; no new object - should be returned. + Program + A new Program instance representing the mutated version of the input. """ raise NotImplementedError @@ -50,7 +51,7 @@ class RemoveStatementMutationStrategy(MutationStrategy, ABC): Defaults to 3. Methods: - mutate(program: Program) -> None: + mutate(program: Program) -> Program: Tries to remove one statement from the given program. If no removable statements are available, or all retries fail, the program remains unchanged. @@ -60,19 +61,20 @@ def __init__(self, remove_retries: int = 3): self.remove_retries = remove_retries @override - def mutate(self, program: Program): + def mutate(self, program: Program) -> Program: max_index = len(program) - (1 if program.has_return_statement() else 0) if max_index <= 0: - return + return program for _ in range(self.remove_retries): try: + program_cp = program.copy() statement_to_remove_idx = random.randrange(max_index) - program.remove_statement(statement_to_remove_idx) - return + program_cp.remove_statement(statement_to_remove_idx) + return program_cp except RemoveStatementError: - pass + return program class ReplaceStatementMutationStrategy(MutationStrategy, ABC): @@ -92,7 +94,7 @@ class ReplaceStatementMutationStrategy(MutationStrategy, ABC): that function expects. Methods: - mutate(program: Program) -> None: + mutate(program: Program) -> Program: Replaces one statement's function and arguments. If the program has no eligible statements (only a `return` or is empty), the program remains unchanged. @@ -102,24 +104,28 @@ def __init__(self, available_functions: Dict[str, int]): self.available_functions = available_functions @override - def mutate(self, program: Program): + def mutate(self, program: Program) -> Program: + program_cp = program.copy() + func_name = random.choice(list(self.available_functions.keys())) args_size = self.available_functions[func_name] - if not program.variables and args_size > 0: - return + if not program_cp.variables and args_size > 0: + return program - args = random.choices(program.variables, k=args_size) + args = random.choices(program_cp.variables, k=args_size) - max_index = len(program) - (1 if program.has_return_statement() else 0) + max_index = len(program_cp) - (1 if program_cp.has_return_statement() else 0) if max_index <= 0: - return + return program replace_index = random.randrange(max_index) - replace_statement = program.get_statement(replace_index) + replace_statement = program_cp.get_statement(replace_index) replace_statement.func = func_name replace_statement.args = args + return program_cp + class UpdateStatementArgsMutationStrategy(MutationStrategy, ABC): """ @@ -131,22 +137,29 @@ class UpdateStatementArgsMutationStrategy(MutationStrategy, ABC): statement. The statement's function and target variable are not modified. Methods: - mutate(program: Program) -> None: + mutate(program: Program) -> Program: Updates the arguments of one statement in the program. If the program is empty, no mutation is performed. """ @override - def mutate(self, program: Program): + def mutate(self, program: Program) -> Program: if len(program) == 0: - return + return program - statement_idx = random.randrange(len(program)) - statement = program.get_statement(statement_idx) + program_cp = program.copy() + statement_idx = random.randrange(len(program_cp)) + statement = program_cp.get_statement(statement_idx) statement_args_count = len(statement.args) - new_args = random.choices(program.variables, k=statement_args_count) + pr_vars = set(program_cp.variables) + + if statement.func != Statement.RETURN_KEYWORD: + pr_vars.remove(statement.result_var_name) + + new_args = random.choices(list(pr_vars), k=statement_args_count) statement.args = new_args + return program_cp class InsertStatementMutationStrategy(MutationStrategy, ABC): @@ -167,18 +180,19 @@ def __init__(self, available_functions: Dict[str, int]): self.available_functions = available_functions @override - def mutate(self, program: Program): - if len(program) == 0: - return + def mutate(self, program: Program) -> Program: + program_cp = program.copy() - max_index = len(program) - (1 if program.has_return_statement() else 0) + max_index = len(program_cp) - (1 if program_cp.has_return_statement() else 0) if max_index <= 0: - return + return program insert_index = random.randrange(max_index) - statement = self._generate_random_statement(program.variables) - program.insert_statement(statement, insert_index) + statement = self._generate_random_statement(program_cp.variables) + program_cp.insert_statement(statement, insert_index) + + return program_cp def _generate_random_statement(self, program_vars: List[str]) -> Statement: func_name = random.choice(list(self.available_functions.keys())) diff --git a/program_searcher/program_model.py b/program_searcher/program_model.py index 44ea044..6b60b15 100644 --- a/program_searcher/program_model.py +++ b/program_searcher/program_model.py @@ -264,7 +264,11 @@ def create_func_node(stmt: Statement): return node_id def create_input_node(arg): - if arg not in self.program_arg_names: + if ( + arg not in self.program_arg_names + and not isinstance(arg, float) + and not isinstance(arg, int) + ): return node_id = f"{arg}_{arg_counts.get(arg, 0)}" G.add_node(node_id, label=arg, type="input") @@ -277,7 +281,11 @@ def create_input_node(arg): for stmt, node_id in stmt_nodes.items(): for idx, arg in enumerate(stmt.args): - if arg in self.program_arg_names: + if ( + arg in self.program_arg_names + or isinstance(arg, float) + or isinstance(arg, int) + ): arg_node_id = create_input_node(arg) G.add_edge(arg_node_id, node_id, arg_pos=idx) else: @@ -288,6 +296,34 @@ def create_input_node(arg): self.graph = G def to_hash(self): + self.generate_graph() + + if not nx.is_directed_acyclic_graph(self.graph): + return self._hash_linearly() + + return self._hash_by_dag() + + def copy(self): + new_program = Program(self.program_name, self.program_arg_names.copy()) + new_program._statements = [copy.deepcopy(stmt) for stmt in self._statements] + new_program.variables = self.variables.copy() + new_program.last_variable_index = self.last_variable_index + new_program.execution_error = self.execution_error + new_program.graph = self.graph.copy() if self.graph else None + new_program.program_str = self.program_str + return new_program + + def to_python_func(self, global_args: Dict[str, object] = {}) -> Callable: + local_ns = {} + exec(self.program_str, global_args, local_ns) + func = local_ns[self.program_name] + + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + def _hash_linearly(self): var_mapping = {} canonical_counter = 0 @@ -314,22 +350,32 @@ def to_hash(self): repr_str = str(canonical_repr).encode("utf-8") return hashlib.sha256(repr_str).hexdigest() - def copy(self): - new_program = Program(self.program_name, self.program_arg_names.copy()) - new_program._statements = [copy.deepcopy(stmt) for stmt in self._statements] - new_program.variables = self.variables.copy() - new_program.last_variable_index = self.last_variable_index - return new_program - - def to_python_func(self, global_args: Dict[str, object] = {}) -> Callable: - local_ns = {} - exec(self.program_str, global_args, local_ns) - func = local_ns[self.program_name] - - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return wrapper + def _hash_by_dag(self): + topo_nodes = list(nx.topological_sort(self.graph)) + + node_labels = {} + func_counter = {} + input_counter = {} + for node in topo_nodes: + data = self.graph.nodes[node] + if data["type"] == "func": + count = func_counter.get(data["label"], 0) + node_labels[node] = f"{data['label']}_{count}" + func_counter[data["label"]] = count + 1 + else: + count = input_counter.get(data["label"], 0) + node_labels[node] = f"in_{count}" + input_counter[data["label"]] = count + 1 + + edges = sorted( + [ + (node_labels[u], node_labels[v], d.get("arg_pos")) + for u, v, d in self.graph.edges(data=True) + ] + ) + + repr_str = str((sorted(node_labels.values()), edges)).encode("utf-8") + return hashlib.sha256(repr_str).hexdigest() def _add_return_statement_if_not_contained(self): if self.has_return_statement(): @@ -337,6 +383,7 @@ def _add_return_statement_if_not_contained(self): return_vars = self.variables[-self.return_vars_count :] return_stmt = Statement(func="return", args=return_vars) + self._statements.append(return_stmt) def _ensure_proper_stmt_index(self, index: int): diff --git a/program_searcher/program_search.py b/program_searcher/program_search.py index 7ba97cc..748e128 100644 --- a/program_searcher/program_search.py +++ b/program_searcher/program_search.py @@ -8,23 +8,17 @@ EvolutionOperator, TournamentSelectionOperator, ) -from program_searcher.exceptions import InvalidProgramSearchArgumentValue +from program_searcher.exceptions import InvalidProgramSearchArgumentValueError from program_searcher.history_tracker import Step, StepsTracker from program_searcher.mutation_strategy import ( - MutationStrategy, + InsertStatementMutationStrategy, RemoveStatementMutationStrategy, + ReplaceStatementMutationStrategy, UpdateStatementArgsMutationStrategy, ) from program_searcher.program_model import Program, Statement, WarmStartProgram from program_searcher.stop_condition import StopCondition -_DEFAULT_MUTATION_STRATEGIES = { - UpdateStatementArgsMutationStrategy(): 1 / 2, - RemoveStatementMutationStrategy(): 1 / 2, -} - -_DEFAULT_EVOLUTION_OPERATOR = TournamentSelectionOperator(tournament_size=2) - class ProgramSearch: def __init__( @@ -54,7 +48,6 @@ def __init__( config (dict, optional): Dictionary of optional parameters. Possible keys and their defaults: - pop_size (int, default=1000): Population size. - evolution_operator (EvolutionOperator, default=TournamentSelectionOperator): operator that performs operations and update population. - - mutation_strategies (Dict[MutationStrategy,float], default=_DEFAULT_MUTATION_STRATEGIES): Mutation strategies with probabilities. - restart_steps (int, default=None): Number of steps after which to restart search. - warm_start_program (WarmStartProgram, default=None): Program to initialize population with. - logger (logging.Logger, default=logging.getLogger(__name__)): Logger for informational and error messages. @@ -73,11 +66,9 @@ def __init__( config = config or {} self.pop_size: int = config.get("pop_size", 1000) self.evolution_operator: EvolutionOperator = config.get( - "evolution_operator", _DEFAULT_EVOLUTION_OPERATOR - ) - self.mutation_strategies: Dict[MutationStrategy, float] = config.get( - "mutation_strategies", _DEFAULT_MUTATION_STRATEGIES + "evolution_operator", self._create_default_evolution_operator() ) + self.restart_steps: int = config.get("restart_steps") self.warm_start_program: WarmStartProgram = config.get("warm_start_program") self.logger: logging.Logger = config.get("logger") or logging.getLogger( @@ -136,7 +127,6 @@ def search(self) -> Tuple[Program, float]: self.evolution_operator.apply( population=self.population, fitnesses=self.fitnesses, - mutation_strategies=self.mutation_strategies, ) self._replace_error_programs() self._replace_equivalent_programs() @@ -177,6 +167,9 @@ def _evaluate_population(self): ) for program in self.population: + if program in self.fitnesses and self.fitnesses[program] is not None: + continue + if warm_hash is not None and program.to_hash() == warm_hash: self.fitnesses[program] = self.warm_start_program.fitness else: @@ -189,7 +182,6 @@ def _replace_error_programs(self): f"Replacing program at index {index} failed execution: {program.execution_error}" ) self.population[index] = self._get_program_replacement() - continue def _replace_equivalent_programs(self): seen_program_hashes = set() @@ -285,27 +277,25 @@ def _init_seeds(self): def _validate_arguments(self): if self.min_program_statements > self.max_program_statements: - raise InvalidProgramSearchArgumentValue( + raise InvalidProgramSearchArgumentValueError( f"min_program_statements ({self.min_program_statements}) cannot be greater than " f"max_program_statements ({self.max_program_statements})." ) if self.pop_size < 0: - raise InvalidProgramSearchArgumentValue( + raise InvalidProgramSearchArgumentValueError( f"pop_size must be non-negative, got {self.pop_size}." ) - if abs(sum(self.mutation_strategies.values()) - 1.0) > 1e-6: - raise InvalidProgramSearchArgumentValue( - f"sum of mutation_strategies values must be 1.0, but is {sum(self.mutation_strategies.values())}." - ) - - if any(value < 0 for value in self.mutation_strategies.values()): - raise InvalidProgramSearchArgumentValue( - f"all mutation_strategies values must be >= 0. current values: {self.mutation_strategies}." - ) + def _create_default_evolution_operator(self): + mutation_strategies = { + UpdateStatementArgsMutationStrategy(): 1 / 4, + RemoveStatementMutationStrategy(): 1 / 4, + ReplaceStatementMutationStrategy(self.available_functions): 1 / 4, + InsertStatementMutationStrategy(self.available_functions): 1 / 4, + } - if any(value > 1 for value in self.mutation_strategies.values()): - raise InvalidProgramSearchArgumentValue( - f"all mutation_strategies values must be <= 1. current values: {self.mutation_strategies}." - ) + evolution_operator = TournamentSelectionOperator( + tournament_size=1, mutation_strategies=mutation_strategies + ) + return evolution_operator diff --git a/tests/test_evolution_operator.py b/tests/test_evolution_operator.py new file mode 100644 index 0000000..b8286bc --- /dev/null +++ b/tests/test_evolution_operator.py @@ -0,0 +1,44 @@ +import unittest + +from program_searcher.evolution_operator import TournamentSelectionOperator +from program_searcher.exceptions import InvalidMutationStrategiesError +from program_searcher.mutation_strategy import ( + RemoveStatementMutationStrategy, + ReplaceStatementMutationStrategy, + UpdateStatementArgsMutationStrategy, +) + + +class TestTorunamentSelectionEvolutionStrategy(unittest.TestCase): + def test_invalid_mutation_strategies_sum(self): + mutation_strategies = { + ReplaceStatementMutationStrategy: 0.5, + RemoveStatementMutationStrategy: 0.5, + UpdateStatementArgsMutationStrategy: 0.5, + } + with self.assertRaises(InvalidMutationStrategiesError): + TournamentSelectionOperator( + tournament_size=1, mutation_strategies=mutation_strategies + ) + + def test_mutation_strategies_negative_value(self): + mutation_strategies = { + RemoveStatementMutationStrategy: -0.1, + ReplaceStatementMutationStrategy: 0.6, + UpdateStatementArgsMutationStrategy: 0.5, + } + with self.assertRaises(InvalidMutationStrategiesError): + TournamentSelectionOperator( + tournament_size=1, mutation_strategies=mutation_strategies + ) + + def test_mutation_strategies_value_greater_than_one(self): + mutation_strategies = { + RemoveStatementMutationStrategy: 1.1, + ReplaceStatementMutationStrategy: -0.05, + UpdateStatementArgsMutationStrategy: -0.05, + } + with self.assertRaises(InvalidMutationStrategiesError): # noqa: F821 + TournamentSelectionOperator( + tournament_size=1, mutation_strategies=mutation_strategies + ) diff --git a/tests/test_mutation_strategy.py b/tests/test_mutation_strategy.py index 6a8d32b..75d05aa 100644 --- a/tests/test_mutation_strategy.py +++ b/tests/test_mutation_strategy.py @@ -8,6 +8,8 @@ ) from program_searcher.program_model import Program, Statement +random.seed(42) + def make_program_with_return(): prog = Program(program_name="dummy", program_arg_names=["X", "y"]) @@ -29,24 +31,22 @@ def make_program_no_return(): class TestRemoveStatementMutationStrategy(unittest.TestCase): def setUp(self): random.seed(42) - self.sut = RemoveStatementMutationStrategy(remove_retries=3) + self.strategy = RemoveStatementMutationStrategy(remove_retries=3) def test_remove_simple_statement(self): prog = make_program_no_return() len_pr_before = len(prog) - strat = RemoveStatementMutationStrategy(remove_retries=3) - strat.mutate(prog) + mutated = self.strategy.mutate(prog) - self.assertEqual(len(prog), len_pr_before - 1) + self.assertEqual(len(mutated), len_pr_before - 1) def test_does_not_remove_return_statement(self): prog = make_program_with_return() - strat = RemoveStatementMutationStrategy(remove_retries=3) - strat.mutate(prog) + mutated = self.strategy.mutate(prog) - self.assertTrue(any(stmt.func == "return" for stmt in prog._statements)) + self.assertTrue(any(stmt.func == "return" for stmt in mutated._statements)) def test_dependency_blocks_removal(self): prog = Program(program_name="dummy", program_arg_names=["X"]) @@ -54,27 +54,23 @@ def test_dependency_blocks_removal(self): prog.insert_statement(Statement(args=["x1"], func="square")) # x2 prog.insert_statement(Statement(args=["x2"], func="return")) # return - before = len(prog) - self.sut.mutate(prog) - after = len(prog) + mutated = self.strategy.mutate(prog) - self.assertEqual(before, after) + self.assertEqual(len(mutated), len(prog)) def test_empty_program(self): prog = Program(program_name="dummy", program_arg_names=["X", "y"]) - self.sut.mutate(prog) - self.assertEqual(len(prog), 0) + mutated = self.strategy.mutate(prog) + self.assertEqual(len(mutated), 0) def test_program_with_only_return(self): prog = Program(program_name="dummy", program_arg_names=["X"]) prog.insert_statement(Statement(args=["X"], func="return")) - before = len(prog) - self.sut.mutate(prog) - after = len(prog) + mutated = self.strategy.mutate(prog) - self.assertEqual(before, after) - self.assertTrue(all(stmt.func == "return" for stmt in prog._statements)) + self.assertEqual(len(prog), len(mutated)) + self.assertTrue(all(stmt.func == "return" for stmt in mutated._statements)) class TestReplaceStatementMutationStrategy(unittest.TestCase): @@ -91,40 +87,38 @@ def setUp(self): def test_replaces_function_and_args(self): prog = make_program_no_return() max_index = len(prog) - 1 - self.strategy.mutate(prog) + mutated = self.strategy.mutate(prog) - for stmt in prog._statements[:max_index]: + for stmt in mutated._statements[:max_index]: self.assertIn(stmt.func, self.available_functions) self.assertEqual(len(stmt.args), self.available_functions[stmt.func]) def test_replacement_does_not_remove_statements(self): prog = make_program_no_return() - before = len(prog) - self.strategy.mutate(prog) - after = len(prog) + mutated = self.strategy.mutate(prog) - self.assertEqual(before, after) + self.assertEqual(len(prog), len(mutated)) def test_return_statement_not_replaced(self): prog = make_program_with_return() - self.strategy.mutate(prog) + mutated = self.strategy.mutate(prog) - self.assertTrue(any(stmt.func == "return" for stmt in prog._statements)) + self.assertTrue(any(stmt.func == "return" for stmt in mutated._statements)) def test_only_return_program_remains_unchanged(self): prog = Program(program_name="dummy", program_arg_names=["X"]) prog.insert_statement(Statement(args=["X"], func="return")) before_funcs = [stmt.func for stmt in prog._statements] - self.strategy.mutate(prog) - after_funcs = [stmt.func for stmt in prog._statements] + mutated = self.strategy.mutate(prog) + after_funcs = [stmt.func for stmt in mutated._statements] self.assertEqual(before_funcs, after_funcs) def test_empty_program_remains_unchanged(self): prog = Program(program_name="dummy", program_arg_names=["X"]) - self.strategy.mutate(prog) - self.assertEqual(len(prog), 0) + mutated = self.strategy.mutate(prog) + self.assertEqual(len(mutated), 0) class TestUpdateStatementArgsMutationStrategy(unittest.TestCase): @@ -136,12 +130,12 @@ def test_arguments_are_replaced(self): prog = make_program_no_return() original_args = [stmt.args[:] for stmt in prog._statements] - self.strategy.mutate(prog) + mutated = self.strategy.mutate(prog) - for orig, stmt in zip(original_args, prog._statements): + for orig, stmt in zip(original_args, mutated._statements): self.assertEqual(len(orig), len(stmt.args)) - for stmt in prog._statements: + for stmt in mutated._statements: for arg in stmt.args: self.assertIn(arg, prog.variables) @@ -149,11 +143,12 @@ def test_single_statement_updated(self): prog = make_program_no_return() original_args = [stmt.args[:] for stmt in prog._statements] - self.strategy.mutate(prog) + mutated = self.strategy.mutate(prog) self.assertTrue( any( - orig != stmt.args for orig, stmt in zip(original_args, prog._statements) + orig != stmt.args + for orig, stmt in zip(original_args, mutated._statements) ) ) @@ -162,28 +157,30 @@ def test_return_statement_can_be_updated(self): return_stmt_index = len(prog._statements) - 1 original_args = prog.get_statement(return_stmt_index).args[:] - self.strategy.mutate(prog) + mutated = self.strategy.mutate(prog) self.assertEqual( - len(original_args), len(prog.get_statement(return_stmt_index).args) + len(original_args), len(mutated.get_statement(return_stmt_index).args) ) def test_empty_program_no_crash(self): prog = Program(program_name="dummy", program_arg_names=["X"]) - self.strategy.mutate(prog) + mutated = self.strategy.mutate(prog) + + self.assertEqual(len(mutated), 0) self.assertEqual(len(prog), 0) def test_at_least_one_statement_args_changed(self): prog = make_program_no_return() original_args = [stmt.args[:] for stmt in prog._statements] - self.strategy.mutate(prog) + mutated = self.strategy.mutate(prog) self.assertTrue( any( - orig != stmt.args for orig, stmt in zip(original_args, prog._statements) - ), - "Żaden statement nie zmienił argumentów", + orig != stmt.args + for orig, stmt in zip(original_args, mutated._statements) + ) ) diff --git a/tests/test_program_model.py b/tests/test_program_model.py index cc2230d..6040599 100644 --- a/tests/test_program_model.py +++ b/tests/test_program_model.py @@ -171,7 +171,13 @@ def test_to_hash_equivalence_for_isomorphic_programs(self): stmt4 = Statement([stmt3.result_var_name], "return") prog2.insert_statement(stmt4) + prog3 = Program("prog3", ["c", "b"]) + prog3.insert_statement(Statement(args=[1.0, 2.0], func="add")) + prog3.insert_statement(Statement(args=["x1"], func="return")) + self.assertEqual(prog1.to_hash(), prog2.to_hash()) + self.assertNotEqual(prog3.to_hash(), prog1.to_hash()) + self.assertNotEqual(prog3.to_hash(), prog2.to_hash()) def test_copy_creates_independent_program(self): stmt = Statement(["a", "b"], "add") @@ -279,6 +285,30 @@ def test_generate_program_graph_unused_parts(self): self.assertTrue(nx.is_isomorphic(graph, expected_graph)) + def test_generate_program_graph_with_consts(self): + program = Program("test_prog", program_arg_names=["a"]) + program.insert_statement(Statement(args=[0.001], func="const")) # x1 + program.insert_statement(Statement(args=["x1", "a"], func="add")) # x2 + program.insert_statement(Statement(args=["x2", 5], func="mult")) # x3 + program.insert_statement(Statement(args=["x3"], func="return")) + + program.generate_graph() + graph = program.graph + + expected_edges = [ + ("0.001_0", "const_0"), + ("const_0", "add_0"), + ("a_0", "add_0"), + ("0.5_0", "mult_0"), + ("add_0", "mult_0"), + ("mult_0", "return_0"), + ] + + expected_graph = nx.DiGraph() + expected_graph.add_edges_from(expected_edges) + + self.assertTrue(nx.is_isomorphic(graph, expected_graph)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_program_search.py b/tests/test_program_search.py index f0e795f..f0253af 100644 --- a/tests/test_program_search.py +++ b/tests/test_program_search.py @@ -6,7 +6,7 @@ import unittest from program_searcher.evolution_operator import FullPopulationMutationOperator -from program_searcher.exceptions import InvalidProgramSearchArgumentValue +from program_searcher.exceptions import InvalidProgramSearchArgumentValueError from program_searcher.history_tracker import CsvStepsTracker from program_searcher.mutation_strategy import ( InsertStatementMutationStrategy, @@ -19,7 +19,7 @@ from program_searcher.stop_condition import MaxStepsStopCondition -class TestProgramSearchValidation(unittest.TestCase): +class TestProgramSearch(unittest.TestCase): def setUp(self): self.correct_args = { "program_name": "test_program", @@ -32,11 +32,6 @@ def setUp(self): "evaluate_program_func": lambda p: 0.0, "config": { "pop_size": 10, - "mutation_strategies": { - RemoveStatementMutationStrategy: 0.3, - ReplaceStatementMutationStrategy: 0.3, - UpdateStatementArgsMutationStrategy: 0.4, - }, "logger": logging.getLogger("test_logger"), }, } @@ -45,43 +40,13 @@ def test_min_greater_than_max(self): args = self.correct_args.copy() args["min_program_statements"] = 6 args["max_program_statements"] = 5 - with self.assertRaises(InvalidProgramSearchArgumentValue): + with self.assertRaises(InvalidProgramSearchArgumentValueError): ProgramSearch(**args) def test_negative_pop_size(self): args = self.correct_args.copy() args["config"]["pop_size"] = -1 - with self.assertRaises(InvalidProgramSearchArgumentValue): - ProgramSearch(**args) - - def test_invalid_mutation_strategies_sum(self): - args = self.correct_args.copy() - args["config"]["mutation_strategies"] = { - ReplaceStatementMutationStrategy: 0.5, - RemoveStatementMutationStrategy: 0.5, - UpdateStatementArgsMutationStrategy: 0.5, - } - with self.assertRaises(InvalidProgramSearchArgumentValue): - ProgramSearch(**args) - - def test_mutation_strategies_negative_value(self): - args = self.correct_args.copy() - args["config"]["mutation_strategies"] = { - RemoveStatementMutationStrategy: -0.1, - ReplaceStatementMutationStrategy: 0.6, - UpdateStatementArgsMutationStrategy: 0.5, - } - with self.assertRaises(InvalidProgramSearchArgumentValue): - ProgramSearch(**args) - - def test_mutation_strategies_value_greater_than_one(self): - args = self.correct_args.copy() - args["config"]["mutation_strategies"] = { - RemoveStatementMutationStrategy: 1.1, - ReplaceStatementMutationStrategy: -0.05, - UpdateStatementArgsMutationStrategy: -0.05, - } - with self.assertRaises(InvalidProgramSearchArgumentValue): + with self.assertRaises(InvalidProgramSearchArgumentValueError): ProgramSearch(**args) def test_search_with_defaults_should_not_raise_any(self): @@ -121,6 +86,9 @@ def test_search_should_not_raise_any(self): logger = logging.getLogger("program_searcher") logger.setLevel(logging.DEBUG) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + file_handler = logging.FileHandler(log_file) file_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( @@ -128,6 +96,7 @@ def test_search_should_not_raise_any(self): ) file_handler.setFormatter(formatter) logger.addHandler(file_handler) + logger.addHandler(console_handler) warm_start_program = Program( program_name="test", program_arg_names=["a", "b"] @@ -138,6 +107,7 @@ def test_search_should_not_raise_any(self): warm_start_program.insert_statement( Statement(["a", "b"], func="op.multiply") ) + warm_start_program.insert_statement(Statement(["a"], func="return")) warm_start = WarmStartProgram(warm_start_program) available_functions_local = { "op.add": 2, @@ -153,7 +123,7 @@ def test_search_should_not_raise_any(self): ReplaceStatementMutationStrategy(available_functions_local): 2 / 5, } - evolution_operator = FullPopulationMutationOperator() + evolution_operator = FullPopulationMutationOperator(mutation_strategies) program_search = ProgramSearch( program_name="test", @@ -168,7 +138,6 @@ def test_search_should_not_raise_any(self): "pop_size": 50, "restart_steps": 15, "logger": logger, - "mutation_strategies": mutation_strategies, "warm_start_program": warm_start, "step_trackers": [ CsvStepsTracker(file_dir=csv_dir, save_batch_size=5)