Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ numpy = ">=2.2.3,<3.0.0"
scikit-learn = ">=1.6.1,<2.0.0"
scipy = ">=1.15.2,<2.0.0"
seaborn = ">=0.13.2,<0.14.0"
joblib = "^1.5.2"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.4"
Expand Down

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a bug with joblib on Windows where running something with joblib outside of the if __name__ == "__main__" block causes the program to crash.

To solve this problem, I suggest explicitly stating in the documentation that a special launch method is required on Windows, and providing sample code.

Additionally, within run, you could check for the operating system, the number of processes, and whether it's inside the required if block. I'm not sure this is possible, but if possible, it would be a good idea.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget add your name in __author__

Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@
Expectation-step.
"""

__author__ = "Danil Totmyanin"
__author__ = "Danil Totmyanin, Aleksandra Ri"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"


import os
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Callable, ClassVar
from typing import Callable, ClassVar, Optional

import numpy as np
from joblib import Parallel, delayed, parallel_backend

from ....distributions import ContinuousDistribution
from ....optimizers import Optimizer
Expand Down Expand Up @@ -46,6 +48,11 @@ class MaximizationStep(PipelineStep):
optimizer : Optimizer
A numerical optimizer instance used to find the optimal parameters
when an analytical solution is not available for a given strategy.
n_jobs : Optional[int], default=None
The number of jobs to run in parallel for the optimization tasks.
- ``None`` (default): The number of jobs is determined automatically. It will
be the minimum of the number of optimization blocks and the number of
available CPUs.

Attributes
----------
Expand All @@ -66,10 +73,18 @@ class MaximizationStep(PipelineStep):
{MaximizationStrategy.QFUNCTION: q_function_strategy}
)

def __init__(self, blocks: Sequence[OptimizationBlock], optimizer: Optimizer):
def __init__(self, blocks: Sequence[OptimizationBlock], optimizer: Optimizer, n_jobs: Optional[int] = None):
self.blocks = list(blocks)
self.optimizer = optimizer

if n_jobs is not None:
self._n_jobs = n_jobs
else:
cpu_count = os.cpu_count() or 1
default_jobs = min(len(self.blocks), cpu_count)

self._n_jobs = default_jobs if default_jobs > 0 else 1

@property
def available_next_steps(self) -> list[type[PipelineStep]]:
"""list[type[PipelineStep]]: Defines the valid subsequent steps.
Expand Down Expand Up @@ -98,6 +113,33 @@ def _update_components_params(self, component: ContinuousDistribution, params: d
param_values = list(params.values())
component.set_params_from_vector(param_names, param_values)

def _optimization_worker(
self,
state: PipelineState,
block: OptimizationBlock,
optimizer: Optimizer,
) -> tuple[int, dict[str, float]]:
"""Helper method to execute the optimization strategy for a single block.

Parameters
----------
state : PipelineState
The current state of the estimation pipeline.
block : OptimizationBlock
The configuration block defining which component and parameters to optimize.
optimizer : Optimizer
The optimizer instance passed to the strategy function.

Returns
-------
tuple[int, dict[str, float]]
A tuple containing the component ID and a dictionary of its newly optimized parameters.
"""
component = state.curr_mixture[block.component_id]
component_id, new_params = self._strategies[block.maximization_strategy](component, state, block, optimizer)

return component_id, new_params

def run(self, state: PipelineState) -> PipelineState:
"""Executes the M-step.

Expand Down Expand Up @@ -125,13 +167,13 @@ def run(self, state: PipelineState) -> PipelineState:
state.error = error
return state

results = []
curr_mixture = state.curr_mixture

for block in self.blocks:
strategy = self._strategies[block.maximization_strategy]
component_id, new_params = strategy(curr_mixture[block.component_id], state, block, self.optimizer)
results.append((component_id, new_params))
# Use threading backend: NumPy/SciPy release the GIL, enabling true parallelism without data copying overhead.
with parallel_backend("threading", n_jobs=self._n_jobs):
results = Parallel()(
delayed(self._optimization_worker)(state, block, self.optimizer) for block in self.blocks
)

for result in results:
component_id, params = result
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for MaximizationStep"""

__author__ = "Danil Totmyanin"
__author__ = "Danil Totmyanin, Aleksandra Ri"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

Expand All @@ -21,6 +21,11 @@
from rework_pysatl_mpest.optimizers import Optimizer


def serial_executor(generator):
""" "A mock executor for joblib that runs tasks sequentially in a single process."""
return [func(*args, **kwargs) for func, args, kwargs in generator]


@pytest.fixture
def mock_optimizer(mocker: MockerFixture) -> Optimizer:
"""Fixture to create a mock Optimizer."""
Expand Down Expand Up @@ -138,6 +143,12 @@ def test_run_calls_correct_strategy_and_updates_params(

target_component = mock_components[0]

# Mock parallel execution to run tasks sequentially for testing.
mocker.patch(
"rework_pysatl_mpest.estimators.iterative.steps.maximization_step.Parallel",
return_value=serial_executor,
)

step.run(pipeline_state)

# mock_strategy was called once
Expand Down Expand Up @@ -219,6 +230,12 @@ def test_run_processes_blocks_sequentially(
new_strategies = {MaximizationStrategy.QFUNCTION: mock_strategy}
mocker.patch.object(MaximizationStep, "_strategies", new_strategies)

# Mock parallel execution to run tasks sequentially for testing.
mocker.patch(
"rework_pysatl_mpest.estimators.iterative.steps.maximization_step.Parallel",
return_value=serial_executor,
)

# Act
step.run(pipeline_state)

Expand Down