From 88e555eb4316c755e180e80959613e0f4ad79f11 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 31 Jul 2025 15:37:15 +0100 Subject: [PATCH 01/18] add test for total ITI duration --- iblrig/test/test_choice_world.py | 39 +++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index ab797d7d4..dde5e81fe 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -6,20 +6,24 @@ import shutil import tempfile import unittest +from itertools import count from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import numpy as np import pandas as pd +import yaml import iblrig.choiceworld from iblrig import session_creator +from iblrig.base_choice_world import ChoiceWorldSession from iblrig.path_helper import iterate_previous_sessions from iblrig.raw_data_loaders import load_task_jsonable from iblrig.test.base import BaseTestCases from iblrig_tasks._iblrig_tasks_passiveChoiceWorld.task import Session as PassiveChoiceWorldSession from iblrig_tasks._iblrig_tasks_spontaneous.task import Session as SpontaneousSession from iblrig_tasks._iblrig_tasks_trainingChoiceWorld.task import Session as TrainingChoiceWorldSession +from iblutil.util import Bunch class TestGetPreviousSession(BaseTestCases.CommonTestTask): @@ -225,3 +229,36 @@ def test_training_phase_from_contrast_set(self): self.assertEqual(iblrig.choiceworld.training_phase_from_contrast_set(contrasts3), phase) with self.assertRaises(ValueError): iblrig.choiceworld.training_phase_from_contrast_set([0.666]) + + +class TestITI(unittest.TestCase): + def test_iti(self): + with ChoiceWorldSession.base_parameters_file.open() as f: + params = Bunch(yaml.safe_load(f)) + params['NTRIALS'] = 2 + + sma = MagicMock() + session = MagicMock().return_value + session.task_params = params + session._run = ChoiceWorldSession._run.__get__(session, ChoiceWorldSession) + session.get_state_machine_trial = ChoiceWorldSession.get_state_machine_trial.__get__(session, ChoiceWorldSession) + session._instantiate_state_machine.return_value = sma + session.paused = False + session.stopped = False + + with patch('iblrig.base_choice_world.time.sleep'): + session._run() + + last_state_timer = sma.add_state.call_args_list[-1].kwargs['state_timer'] + inter_sma_delay = 0.05 # the delay inbetween two state machines + + counter = count(0, last_state_timer + inter_sma_delay) + with ( + patch('iblrig.base_choice_world.time.time', side_effect=lambda: next(counter)) as mock_time, + patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, + ): + session._run() + self.assertEqual(mock_time.call_count, 4, 'expecting time.time() to have been called 4 times.') + self.assertEqual(mock_sleep.call_count, 1, 'expecting time.sleep() to have been called once.') + sleep_time = mock_sleep.call_args[0][0] + self.assertAlmostEqual(last_state_timer + inter_sma_delay + sleep_time, 1.0, msg='Total ITI should be 1 second') From 5ab238d6ee62678ce3b33e2de14020ddade84ae4 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 31 Jul 2025 16:44:32 +0100 Subject: [PATCH 02/18] Update test_choice_world.py --- iblrig/test/test_choice_world.py | 35 ++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index dde5e81fe..b3d47f906 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -232,11 +232,12 @@ def test_training_phase_from_contrast_set(self): class TestITI(unittest.TestCase): - def test_iti(self): + @staticmethod + def get_mock_session(n_trials: int) -> tuple[MagicMock, MagicMock]: + """Mock ChoiceWorldSession and StateMachine""" with ChoiceWorldSession.base_parameters_file.open() as f: params = Bunch(yaml.safe_load(f)) - params['NTRIALS'] = 2 - + params['NTRIALS'] = n_trials sma = MagicMock() session = MagicMock().return_value session.task_params = params @@ -245,20 +246,32 @@ def test_iti(self): session._instantiate_state_machine.return_value = sma session.paused = False session.stopped = False + return session, sma + def test_iti(self): + # the fraction of the ITI handled by the state machine's last state + session, sma = self.get_mock_session(1) with patch('iblrig.base_choice_world.time.sleep'): session._run() + iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer'] + + # the last state of the state machine needs to contain a BNC1 high of a certain duration - for extraction + self.assertGreater(iti_delay_sma, 0.2, 'Part of the ITI should be handled by the state machine.') + self.assertIn(('BNC1', 255), sma.add_state.call_args_list[-1].kwargs['output_actions'], 'Expecting BNC1 high.') - last_state_timer = sma.add_state.call_args_list[-1].kwargs['state_timer'] - inter_sma_delay = 0.05 # the delay inbetween two state machines + # the assumed fraction of the ITI defined by processing delays + iti_delay_processing = 0.031231234234234 - counter = count(0, last_state_timer + inter_sma_delay) + # the fraction of the ITI handled by time.sleep() making up for processing delays + session, sma = self.get_mock_session(2) + counter = count(0, iti_delay_sma + iti_delay_processing) with ( - patch('iblrig.base_choice_world.time.time', side_effect=lambda: next(counter)) as mock_time, + patch('iblrig.base_choice_world.time.time', side_effect=lambda: next(counter)), patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, ): session._run() - self.assertEqual(mock_time.call_count, 4, 'expecting time.time() to have been called 4 times.') - self.assertEqual(mock_sleep.call_count, 1, 'expecting time.sleep() to have been called once.') - sleep_time = mock_sleep.call_args[0][0] - self.assertAlmostEqual(last_state_timer + inter_sma_delay + sleep_time, 1.0, msg='Total ITI should be 1 second') + self.assertEqual(session.bpod.run_state_machine.call_count, 2, 'expecting run_state_machine() to have been called twice.') + iti_delay_sleep = mock_sleep.call_args[0][0] if mock_sleep.call_args else 0.0 + + # the total ITI should be 1 second + self.assertAlmostEqual(iti_delay_sma + iti_delay_processing + iti_delay_sleep, 1.0, msg='Total ITI should be 1 second') From d01ab76bfd78a2fd9c2e6358d667dc93e8357e38 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 31 Jul 2025 16:47:42 +0100 Subject: [PATCH 03/18] Update test_choice_world.py --- iblrig/test/test_choice_world.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index b3d47f906..e8e3fa1fe 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -256,7 +256,7 @@ def test_iti(self): iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer'] # the last state of the state machine needs to contain a BNC1 high of a certain duration - for extraction - self.assertGreater(iti_delay_sma, 0.2, 'Part of the ITI should be handled by the state machine.') + self.assertGreaterEqual(iti_delay_sma, 0.5, 'Part of the ITI should be handled by the state machine.') self.assertIn(('BNC1', 255), sma.add_state.call_args_list[-1].kwargs['output_actions'], 'Expecting BNC1 high.') # the assumed fraction of the ITI defined by processing delays From a936df6f7c9bdaab6d05c20034aa888db61b29fd Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 31 Jul 2025 18:48:29 +0100 Subject: [PATCH 04/18] Add warning about ITI < 0.5 --- iblrig/base_choice_world.py | 6 +++++- iblrig/test/test_choice_world.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index 5ad0dfdc3..bd15b6bd9 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -250,10 +250,15 @@ def _run(self) -> None: # obtain state machine definition self.next_trial() sma = self.get_state_machine_trial(trial_number) + last_state_duration = sma.state_timers[sma.total_states_added - 1] # Waiting for camera / initial delay will be handled just prior to the first trial # This is done here to allow for backward compatibility with unadapted tasks if trial_number == 0: + # warn if the duration of the last state is not sufficiently long + if last_state_duration < 0.5: + log.warning(f'The last state has a duration of only {last_state_duration} s. It should be 0.5 s or longer.') + # warn if state machine uses deprecated way of waiting for camera / initial delay if (5, SOFTCODE.TRIGGER_CAMERA) in sma.output_matrix[0] and sma.state_names[1] == 'delay_initiation': log.warning('') @@ -290,7 +295,6 @@ def _run(self) -> None: # run state machine log.info('-----------------------') log.info(f'Starting Trial #{trial_number}') - log.debug('running state machine') self.bpod.run_state_machine(sma) # Locks until state machine 'exit' is reached time_last_trial_end = time.time() diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index e8e3fa1fe..d77f6d4aa 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -8,7 +8,7 @@ import unittest from itertools import count from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import numpy as np import pandas as pd @@ -239,6 +239,10 @@ def get_mock_session(n_trials: int) -> tuple[MagicMock, MagicMock]: params = Bunch(yaml.safe_load(f)) params['NTRIALS'] = n_trials sma = MagicMock() + type(sma).total_states_added = PropertyMock(side_effect=lambda: sma.add_state.call_count) + type(sma).state_timers = PropertyMock( + side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list] + ) session = MagicMock().return_value session.task_params = params session._run = ChoiceWorldSession._run.__get__(session, ChoiceWorldSession) From 6c52913e4f992282037892d4cf6ca8026b53544a Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Mon, 18 Aug 2025 16:41:51 +0100 Subject: [PATCH 05/18] add test for iti warning --- iblrig/test/test_choice_world.py | 7 +++++++ pyproject.toml | 1 + uv.lock | 24 ++++++++++++++++++++++++ 3 files changed, 32 insertions(+) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index d77f6d4aa..c41a7a979 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -252,6 +252,13 @@ def get_mock_session(n_trials: int) -> tuple[MagicMock, MagicMock]: session.stopped = False return session, sma + def test_iti_warning(self): + # test that the ITI warning is raised when the last state does not handle the ITI + session, sma = self.get_mock_session(1) + type(sma).state_timers = [0.0] * 100 + with patch('iblrig.base_choice_world.time.sleep'), self.assertLogs('iblrig', level='WARNING'): + session._run() + def test_iti(self): # the fraction of the ITI handled by the state machine's last state session, sma = self.get_mock_session(1) diff --git a/pyproject.toml b/pyproject.toml index 23cc19462..27a1b6b44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "PyQt5-Qt5>=5.15.2", "PyQtWebEngine>=5.15.6", "PyQtWebEngine-Qt5>=5.15.2", + "pandas-stubs==2.3.0.250703", ] [project.optional-dependencies] project-extraction = [ diff --git a/uv.lock b/uv.lock index d51f90b21..a9890d864 100644 --- a/uv.lock +++ b/uv.lock @@ -786,6 +786,7 @@ dependencies = [ { name = "one-api" }, { name = "packaging" }, { name = "pandas" }, + { name = "pandas-stubs" }, { name = "pandera" }, { name = "psutil" }, { name = "pydantic" }, @@ -877,6 +878,7 @@ requires-dist = [ { name = "one-api", specifier = ">=3.3.0" }, { name = "packaging", specifier = ">=25.0" }, { name = "pandas", specifier = ">=2.3.0" }, + { name = "pandas-stubs", specifier = "==2.3.0.250703" }, { name = "pandera", specifier = ">=0.24.0" }, { name = "project-extraction", marker = "extra == 'project-extraction'", git = "https://github.com/int-brain-lab/project_extraction.git" }, { name = "psutil", specifier = ">=7.0.0" }, @@ -1607,6 +1609,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/d6/d7f5777162aa9b48ec3910bca5a58c9b5927cfd9cfde3aa64322f5ba4b9f/pandas-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:2eb789ae0274672acbd3c575b0598d213345660120a257b47b5dafdc618aec83", size = 11336561, upload-time = "2025-07-07T19:18:31.211Z" }, ] +[[package]] +name = "pandas-stubs" +version = "2.3.0.250703" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "types-pytz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ec/df/c1c51c5cec087b8f4d04669308b700e9648745a77cdd0c8c5e16520703ca/pandas_stubs-2.3.0.250703.tar.gz", hash = "sha256:fb6a8478327b16ed65c46b1541de74f5c5947f3601850caf3e885e0140584717", size = 103910, upload-time = "2025-07-02T17:49:11.667Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/75/cb/09d5f9bf7c8659af134ae0ffc1a349038a5d0ff93e45aedc225bde2872a3/pandas_stubs-2.3.0.250703-py3-none-any.whl", hash = "sha256:a9265fc69909f0f7a9cabc5f596d86c9d531499fed86b7838fd3278285d76b81", size = 154719, upload-time = "2025-07-02T17:49:10.697Z" }, +] + [[package]] name = "pandera" version = "0.25.0" @@ -3047,6 +3062,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/72/52/43e70a8e57fefb172c22a21000b03ebcc15e47e97f5cb8495b9c2832efb4/types_python_dateutil-2.9.0.20250708-py3-none-any.whl", hash = "sha256:4d6d0cc1cc4d24a2dc3816024e502564094497b713f7befda4d5bc7a8e3fd21f", size = 17724, upload-time = "2025-07-08T03:14:02.593Z" }, ] +[[package]] +name = "types-pytz" +version = "2025.2.0.20250516" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/72/b0e711fd90409f5a76c75349055d3eb19992c110f0d2d6aabbd6cfbc14bf/types_pytz-2025.2.0.20250516.tar.gz", hash = "sha256:e1216306f8c0d5da6dafd6492e72eb080c9a166171fa80dd7a1990fd8be7a7b3", size = 10940, upload-time = "2025-05-16T03:07:01.91Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl", hash = "sha256:e0e0c8a57e2791c19f718ed99ab2ba623856b11620cb6b637e5f62ce285a7451", size = 10136, upload-time = "2025-05-16T03:07:01.075Z" }, +] + [[package]] name = "types-pyyaml" version = "6.0.12.20250516" From a18189b01218536683c1f5d1f7388dcb5869ccd5 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 14:11:27 +0100 Subject: [PATCH 06/18] fixes for habituation CW (#795) * add extra state assuring go-cue will be played after onset of visual stimulus * add `hide_stim` state, untangle ITI * BNC1 high in first state * Update CHANGELOG.md --- CHANGELOG.md | 4 ++++ iblrig/base_choice_world.py | 47 +++++++++++++++++++++++++------------ 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e1bd0117..82060c068 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ Changelog ========= +8.30.0 +------ +* disentangle order of states in `HabituationChoiceWorldSession` + 8.29.0 ------ * added: GUI settings for changing MAIN_SYNC diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index bd15b6bd9..232b785c5 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -738,22 +738,22 @@ def draw_next_trial_info(self, *args, **kwargs): def get_state_machine_trial(self, i): sma = StateMachine(self.bpod) - # NB: This state actually the inter-trial interval, i.e. the period of grey screen between stim off and stim on. - # During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is trial end; - # the offset of this state is trial start! + # Show the visual stimulus. + # Move to next state if Frame2TTL event is detected. + # Use the state-timer as a backup to prevent a stall. sma.add_state( - state_name='iti', - state_timer=1, # Stim off for 1 sec - state_change_conditions={'Tup': 'stim_on'}, - output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)], + state_name='stim_on', + state_timer=0.1, + state_change_conditions={'Tup': 'stim_center', 'BNC1High': 'play_tone', 'BNC1Low': 'play_tone'}, + output_actions=[self.bpod.actions.bonsai_show_stim, ('BNC1', 255)], ) - # This stim_on state is considered the actual trial start + # Play tone and wait for `delay_to_stim_center`. sma.add_state( - state_name='stim_on', + state_name='play_tone', state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'], + output_actions=[self.bpod.actions.play_tone], state_change_conditions={'Tup': 'stim_center'}, - output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone], ) sma.add_state( @@ -769,15 +769,32 @@ def get_state_machine_trial(self, i): state_change_conditions={'Tup': 'post_reward'}, output_actions=[('Valve1', 255), ('BNC1', 255)], ) - # This state defines the period after reward where Bpod TTL is LOW. - # NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit. - # The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds. + sma.add_state( state_name='post_reward', - state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time, - state_change_conditions={'Tup': 'exit'}, + state_timer=0.5 - self.reward_time, + state_change_conditions={'Tup': 'hide_stim'}, output_actions=[], ) + + # Hide the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary + # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e., + # when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall. + sma.add_state( + state_name='hide_stim', + state_timer=0.1, + output_actions=[self.bpod.actions.bonsai_hide_stim], + state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'}, + ) + + # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event. + sma.add_state( + state_name='exit_state', + state_timer=self.task_params.ITI_DELAY_SECS, + output_actions=[('BNC1', 255)], + state_change_conditions={'Tup': 'exit'}, + ) + return sma From 99859ac09441d994890b26619743e234b6087b08 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 15:10:40 +0100 Subject: [PATCH 07/18] fix dead_time in `ChoiceWorldSession._run()` --- iblrig/base_choice_world.py | 30 +++++++++++++++------------- iblrig/base_choice_world_params.yaml | 3 +-- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index 232b785c5..0c9e463f3 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -246,19 +246,15 @@ def _run(self) -> None: This method orchestrates the execution of the task by running a state machine for a specified number of trials. """ time_last_trial_end = time.time() + iti_last_trial = np.nan for trial_number in range(self.task_params.NTRIALS): # Main loop # obtain state machine definition self.next_trial() sma = self.get_state_machine_trial(trial_number) - last_state_duration = sma.state_timers[sma.total_states_added - 1] # Waiting for camera / initial delay will be handled just prior to the first trial # This is done here to allow for backward compatibility with unadapted tasks if trial_number == 0: - # warn if the duration of the last state is not sufficiently long - if last_state_duration < 0.5: - log.warning(f'The last state has a duration of only {last_state_duration} s. It should be 0.5 s or longer.') - # warn if state machine uses deprecated way of waiting for camera / initial delay if (5, SOFTCODE.TRIGGER_CAMERA) in sma.output_matrix[0] and sma.state_names[1] == 'delay_initiation': log.warning('') @@ -271,33 +267,39 @@ def _run(self) -> None: log.warning("'Deprecation Notes' in IBLRIG's documentation.") log.warning('**********************************************') log.warning('') - log.info('Waiting for 10s so you actually read this message ;-)') + log.warning('Waiting for 10s so you actually read this message ;-)') time.sleep(10) else: self._wait_for_camera_and_initial_delay() # send state machine description to Bpod device - log.debug('Sending state machine to bpod') self.bpod.send_state_machine(sma) # handle ITI durations if trial_number > 0: - # The ITI_DELAY_SECS defines the grey screen period within the state machine, where the - # Bpod TTL is HIGH. The DEAD_TIME param defines the time between last trial and the next - dead_time = self.task_params.get('DEAD_TIME', 0.5) - dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end) + # ITI_DELAY_SECS defines the period between hiding the stimulus and start of the next trial's quiescent + # period. The state machine handles 0.5 seconds of this period (in order to deliver a BNC1High event + # required for extraction of the task data). The remaining time is handled here by `time.sleep` to make + # up for processing delays inbetween state-machine runs. + dt = self.task_params.ITI_DELAY_SECS - iti_last_trial - (time.time() - time_last_trial_end) # wait to achieve the desired ITI duration if dt > 0: - log.debug(f'Waiting {dt} s to achieve an ITI duration of {self.task_params.ITI_DELAY_SECS} s') + log.debug('Waiting %0.3f s to achieve an ITI duration of %0.1f s', dt, self.task_params.ITI_DELAY_SECS) time.sleep(dt) # run state machine log.info('-----------------------') - log.info(f'Starting Trial #{trial_number}') + log.info('Starting Trial #%d', trial_number) self.bpod.run_state_machine(sma) # Locks until state machine 'exit' is reached time_last_trial_end = time.time() + # The ITI duration is partially handled by Bpod within the last state of the state machine. + # This state should have a duration of 0.5 seconds (see explanation below). + iti_last_trial = sma.state_timers[sma.total_states_added - 1] + if iti_last_trial != 0.5: + log.warning('ATTENTION: The last state had a duration of %0.1f s. It should be exactly 0.5 s.', iti_last_trial) + # handle pause event if self.paused and trial_number < (self.task_params.NTRIALS - 1): log.info(f'Pausing session inbetween trials {trial_number} and {trial_number + 1}') @@ -541,7 +543,7 @@ def get_state_machine_trial(self, i): # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event. sma.add_state( state_name='exit_state', - state_timer=self.task_params.ITI_DELAY_SECS, + state_timer=min(0.5, self.task_params.ITI_DELAY_SECS), output_actions=[('BNC1', 255)], state_change_conditions={'Tup': 'exit'}, ) diff --git a/iblrig/base_choice_world_params.yaml b/iblrig/base_choice_world_params.yaml index 053716b27..f5c7346c9 100644 --- a/iblrig/base_choice_world_params.yaml +++ b/iblrig/base_choice_world_params.yaml @@ -12,8 +12,7 @@ 'FEEDBACK_ERROR_DELAY_SECS': 2 'FEEDBACK_NOGO_DELAY_SECS': 2 'INTERACTIVE_DELAY': 0.0 -'DEAD_TIME': 0.5 # the length of time before entering the next trial. This plus ITI_DELAY_SECS define period of closed-loop grey screen -'ITI_DELAY_SECS': 0.5 # this is the length of the ITI state at the end of the session. 0.5 seconds are added to it until the next trial start +'ITI_DELAY_SECS': 1.0 # 0.5 seconds will be handled by the state-machine, the rest by python inbetween state-machine runs 'NTRIALS': 2000 'PROBABILITY_LEFT': 0.5 'QUIESCENCE_THRESHOLDS': [-2, 2] From afaf5156032afcaabcdf38711a052cf4e567f833 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 15:59:09 +0100 Subject: [PATCH 08/18] fix test, expand logging --- iblrig/base_choice_world.py | 8 +++++++- iblrig/test/test_choice_world.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index 0c9e463f3..aa51b0d2f 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -281,12 +281,18 @@ def _run(self) -> None: # period. The state machine handles 0.5 seconds of this period (in order to deliver a BNC1High event # required for extraction of the task data). The remaining time is handled here by `time.sleep` to make # up for processing delays inbetween state-machine runs. - dt = self.task_params.ITI_DELAY_SECS - iti_last_trial - (time.time() - time_last_trial_end) + processing_delays = time.time() - time_last_trial_end + dt = self.task_params.ITI_DELAY_SECS - iti_last_trial - processing_delays # wait to achieve the desired ITI duration if dt > 0: log.debug('Waiting %0.3f s to achieve an ITI duration of %0.1f s', dt, self.task_params.ITI_DELAY_SECS) time.sleep(dt) + elif dt < 0: + iti_actual = self.task_params.ITI_DELAY_SECS - dt + log.warning('Inter-trial processing delays of ~%0.3f s could not be corrected for.', processing_delays) + log.warning('Targeted ITI: %0.1f s', self.task_params.ITI_DELAY_SECS) + log.warning('Actual ITI: %0.3f s', iti_actual) # run state machine log.info('-----------------------') diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index c41a7a979..8ac8e24f1 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -275,7 +275,7 @@ def test_iti(self): # the fraction of the ITI handled by time.sleep() making up for processing delays session, sma = self.get_mock_session(2) - counter = count(0, iti_delay_sma + iti_delay_processing) + counter = count(0, iti_delay_processing) with ( patch('iblrig.base_choice_world.time.time', side_effect=lambda: next(counter)), patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, From 05821db62baa686b2b1ff4bc4d95823a5ab43753 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 16:03:23 +0100 Subject: [PATCH 09/18] Update base_choice_world.py --- iblrig/base_choice_world.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index aa51b0d2f..ccf39dea2 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -289,10 +289,8 @@ def _run(self) -> None: log.debug('Waiting %0.3f s to achieve an ITI duration of %0.1f s', dt, self.task_params.ITI_DELAY_SECS) time.sleep(dt) elif dt < 0: - iti_actual = self.task_params.ITI_DELAY_SECS - dt - log.warning('Inter-trial processing delays of ~%0.3f s could not be corrected for.', processing_delays) log.warning('Targeted ITI: %0.1f s', self.task_params.ITI_DELAY_SECS) - log.warning('Actual ITI: %0.3f s', iti_actual) + log.warning('Actual ITI: %0.3f s', self.task_params.ITI_DELAY_SECS - dt) # run state machine log.info('-----------------------') From a76bfe7ab8118f041009ebc027ceb3c49b46de75 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 16:12:42 +0100 Subject: [PATCH 10/18] switch to perf_counter --- iblrig/base_choice_world.py | 8 ++++---- iblrig/test/test_choice_world.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index ccf39dea2..fc4687123 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -245,7 +245,7 @@ def _run(self) -> None: This method orchestrates the execution of the task by running a state machine for a specified number of trials. """ - time_last_trial_end = time.time() + time_last_trial_end = np.nan iti_last_trial = np.nan for trial_number in range(self.task_params.NTRIALS): # Main loop # obtain state machine definition @@ -281,12 +281,12 @@ def _run(self) -> None: # period. The state machine handles 0.5 seconds of this period (in order to deliver a BNC1High event # required for extraction of the task data). The remaining time is handled here by `time.sleep` to make # up for processing delays inbetween state-machine runs. - processing_delays = time.time() - time_last_trial_end + processing_delays = time.perf_counter() - time_last_trial_end dt = self.task_params.ITI_DELAY_SECS - iti_last_trial - processing_delays # wait to achieve the desired ITI duration if dt > 0: - log.debug('Waiting %0.3f s to achieve an ITI duration of %0.1f s', dt, self.task_params.ITI_DELAY_SECS) + log.debug('Sleeping %0.3f s to achieve an ITI duration of %0.1f s', dt, self.task_params.ITI_DELAY_SECS) time.sleep(dt) elif dt < 0: log.warning('Targeted ITI: %0.1f s', self.task_params.ITI_DELAY_SECS) @@ -296,7 +296,7 @@ def _run(self) -> None: log.info('-----------------------') log.info('Starting Trial #%d', trial_number) self.bpod.run_state_machine(sma) # Locks until state machine 'exit' is reached - time_last_trial_end = time.time() + time_last_trial_end = time.perf_counter() # The ITI duration is partially handled by Bpod within the last state of the state machine. # This state should have a duration of 0.5 seconds (see explanation below). diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index 8ac8e24f1..83548ea20 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -277,7 +277,7 @@ def test_iti(self): session, sma = self.get_mock_session(2) counter = count(0, iti_delay_processing) with ( - patch('iblrig.base_choice_world.time.time', side_effect=lambda: next(counter)), + patch('iblrig.base_choice_world.time.perf_counter', side_effect=lambda: next(counter)), patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, ): session._run() From 8d3000f2622677fb845c33ca9123388f307047ab Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 17:48:14 +0100 Subject: [PATCH 11/18] extend tests to `HabituationChoiceWorldSession` --- CHANGELOG.md | 1 + iblrig/test/test_choice_world.py | 82 ++++++++++++++++++-------------- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82060c068..9a609db2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ Changelog 8.30.0 ------ * disentangle order of states in `HabituationChoiceWorldSession` +* fix ITI durations in `ChoiceWorldSession` and `HabituationChoiceWorldSession` 8.29.0 ------ diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index 83548ea20..deef3a83f 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -13,10 +13,12 @@ import numpy as np import pandas as pd import yaml +from pybpodapi.state_machine import StateMachine +from pybpodapi.state_machine.state_machine_base import StateMachineBase import iblrig.choiceworld from iblrig import session_creator -from iblrig.base_choice_world import ChoiceWorldSession +from iblrig.base_choice_world import ChoiceWorldSession, HabituationChoiceWorldSession from iblrig.path_helper import iterate_previous_sessions from iblrig.raw_data_loaders import load_task_jsonable from iblrig.test.base import BaseTestCases @@ -233,20 +235,19 @@ def test_training_phase_from_contrast_set(self): class TestITI(unittest.TestCase): @staticmethod - def get_mock_session(n_trials: int) -> tuple[MagicMock, MagicMock]: + def get_mock_session(session_class: ChoiceWorldSession, n_trials: int) -> tuple[MagicMock, MagicMock]: """Mock ChoiceWorldSession and StateMachine""" - with ChoiceWorldSession.base_parameters_file.open() as f: - params = Bunch(yaml.safe_load(f)) + params = session_class.read_task_parameter_files() params['NTRIALS'] = n_trials sma = MagicMock() type(sma).total_states_added = PropertyMock(side_effect=lambda: sma.add_state.call_count) type(sma).state_timers = PropertyMock( side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list] ) - session = MagicMock().return_value + session = MagicMock(spec=session_class).return_value session.task_params = params - session._run = ChoiceWorldSession._run.__get__(session, ChoiceWorldSession) - session.get_state_machine_trial = ChoiceWorldSession.get_state_machine_trial.__get__(session, ChoiceWorldSession) + session._run = session_class._run.__get__(session, session_class) + session.get_state_machine_trial = session_class.get_state_machine_trial.__get__(session, session_class) session._instantiate_state_machine.return_value = sma session.paused = False session.stopped = False @@ -254,35 +255,42 @@ def get_mock_session(n_trials: int) -> tuple[MagicMock, MagicMock]: def test_iti_warning(self): # test that the ITI warning is raised when the last state does not handle the ITI - session, sma = self.get_mock_session(1) - type(sma).state_timers = [0.0] * 100 - with patch('iblrig.base_choice_world.time.sleep'), self.assertLogs('iblrig', level='WARNING'): - session._run() + for session_class in [ChoiceWorldSession, HabituationChoiceWorldSession]: + session, sma = self.get_mock_session(session_class, 1) + type(sma).state_timers = [0.0] * 100 + with ( + patch('iblrig.base_choice_world.time.sleep'), + patch('iblrig.base_choice_world.StateMachine', return_value=sma), + self.assertLogs('iblrig', level='WARNING'), + ): + session._run() def test_iti(self): - # the fraction of the ITI handled by the state machine's last state - session, sma = self.get_mock_session(1) - with patch('iblrig.base_choice_world.time.sleep'): - session._run() - iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer'] - - # the last state of the state machine needs to contain a BNC1 high of a certain duration - for extraction - self.assertGreaterEqual(iti_delay_sma, 0.5, 'Part of the ITI should be handled by the state machine.') - self.assertIn(('BNC1', 255), sma.add_state.call_args_list[-1].kwargs['output_actions'], 'Expecting BNC1 high.') - - # the assumed fraction of the ITI defined by processing delays - iti_delay_processing = 0.031231234234234 - - # the fraction of the ITI handled by time.sleep() making up for processing delays - session, sma = self.get_mock_session(2) - counter = count(0, iti_delay_processing) - with ( - patch('iblrig.base_choice_world.time.perf_counter', side_effect=lambda: next(counter)), - patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, - ): - session._run() - self.assertEqual(session.bpod.run_state_machine.call_count, 2, 'expecting run_state_machine() to have been called twice.') - iti_delay_sleep = mock_sleep.call_args[0][0] if mock_sleep.call_args else 0.0 - - # the total ITI should be 1 second - self.assertAlmostEqual(iti_delay_sma + iti_delay_processing + iti_delay_sleep, 1.0, msg='Total ITI should be 1 second') + for session_class in [ChoiceWorldSession, HabituationChoiceWorldSession]: + # the fraction of the ITI handled by the state machine's last state + session, sma = self.get_mock_session(session_class, 1) + with patch('iblrig.base_choice_world.StateMachine', return_value=sma): + session._run() + iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer'] + + # the last state of the state machine needs to contain a BNC1 high of a certain duration - for extraction + self.assertGreaterEqual(iti_delay_sma, 0.5, 'Part of the ITI should be handled by the state machine.') + self.assertIn(('BNC1', 255), sma.add_state.call_args_list[-1].kwargs['output_actions'], 'Expecting BNC1 high.') + + # the assumed fraction of the ITI defined by processing delays + iti_delay_processing = 0.031231234234234 + + # the fraction of the ITI handled by time.sleep() making up for processing delays + session, sma = self.get_mock_session(session_class, 2) + counter = count(0, iti_delay_processing) + with ( + patch('iblrig.base_choice_world.time.perf_counter', side_effect=lambda c=counter: next(c)), + patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, + patch('iblrig.base_choice_world.StateMachine', return_value=sma), + ): + session._run() + self.assertEqual(session.bpod.run_state_machine.call_count, 2) + iti_delay_sleep = mock_sleep.call_args[0][0] if mock_sleep.call_args else 0.0 + + # the total ITI should be 1 second + self.assertAlmostEqual(iti_delay_sma + iti_delay_processing + iti_delay_sleep, 1.0, msg='Total ITI should be 1 s') From a8b63b5e17f6ca5b68f9f1f2dae80caaf512f417 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 17:49:08 +0100 Subject: [PATCH 12/18] Update test_choice_world.py --- iblrig/test/test_choice_world.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index deef3a83f..5551e9d9c 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -12,9 +12,6 @@ import numpy as np import pandas as pd -import yaml -from pybpodapi.state_machine import StateMachine -from pybpodapi.state_machine.state_machine_base import StateMachineBase import iblrig.choiceworld from iblrig import session_creator @@ -25,7 +22,6 @@ from iblrig_tasks._iblrig_tasks_passiveChoiceWorld.task import Session as PassiveChoiceWorldSession from iblrig_tasks._iblrig_tasks_spontaneous.task import Session as SpontaneousSession from iblrig_tasks._iblrig_tasks_trainingChoiceWorld.task import Session as TrainingChoiceWorldSession -from iblutil.util import Bunch class TestGetPreviousSession(BaseTestCases.CommonTestTask): From a6a92bbd33754451ee068d46e50dc09470bb35e4 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 18:31:48 +0100 Subject: [PATCH 13/18] work on tests, prepare move to pytest --- iblrig/test/test_choice_world.py | 62 ++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index 5551e9d9c..b9f858f6a 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -231,17 +231,16 @@ def test_training_phase_from_contrast_set(self): class TestITI(unittest.TestCase): @staticmethod - def get_mock_session(session_class: ChoiceWorldSession, n_trials: int) -> tuple[MagicMock, MagicMock]: + def get_mock_session(session_class: ChoiceWorldSession) -> tuple[MagicMock, MagicMock]: """Mock ChoiceWorldSession and StateMachine""" - params = session_class.read_task_parameter_files() - params['NTRIALS'] = n_trials sma = MagicMock() type(sma).total_states_added = PropertyMock(side_effect=lambda: sma.add_state.call_count) type(sma).state_timers = PropertyMock( side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list] ) session = MagicMock(spec=session_class).return_value - session.task_params = params + session.task_params = session_class.read_task_parameter_files() + session.task_params['NTRIALS'] = 1 session._run = session_class._run.__get__(session, session_class) session.get_state_machine_trial = session_class.get_state_machine_trial.__get__(session, session_class) session._instantiate_state_machine.return_value = sma @@ -250,21 +249,31 @@ def get_mock_session(session_class: ChoiceWorldSession, n_trials: int) -> tuple[ return session, sma def test_iti_warning(self): - # test that the ITI warning is raised when the last state does not handle the ITI + # test that the ITI warning is raised when the last state does not correctly handle the ITI for session_class in [ChoiceWorldSession, HabituationChoiceWorldSession]: - session, sma = self.get_mock_session(session_class, 1) - type(sma).state_timers = [0.0] * 100 + session, sma = self.get_mock_session(session_class) + + type(sma).state_timers = [0.0] * 100 # all states have duration of 0.0 with ( patch('iblrig.base_choice_world.time.sleep'), patch('iblrig.base_choice_world.StateMachine', return_value=sma), - self.assertLogs('iblrig', level='WARNING'), + self.assertLogs('iblrig', level='WARNING') as cm, + ): + session._run() + self.assertTrue(any('It should be exactly 0.5 s.' in log for log in cm.output)) + + type(sma).state_timers = [0.5] * 100 # all states have duration of 0.5 + with ( + patch('iblrig.base_choice_world.time.sleep'), + patch('iblrig.base_choice_world.StateMachine', return_value=sma), + self.assertNoLogs('iblrig', level='WARNING'), ): session._run() def test_iti(self): for session_class in [ChoiceWorldSession, HabituationChoiceWorldSession]: # the fraction of the ITI handled by the state machine's last state - session, sma = self.get_mock_session(session_class, 1) + session, sma = self.get_mock_session(session_class) with patch('iblrig.base_choice_world.StateMachine', return_value=sma): session._run() iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer'] @@ -277,7 +286,8 @@ def test_iti(self): iti_delay_processing = 0.031231234234234 # the fraction of the ITI handled by time.sleep() making up for processing delays - session, sma = self.get_mock_session(session_class, 2) + session, sma = self.get_mock_session(session_class) + session.task_params['NTRIALS'] = 2 counter = count(0, iti_delay_processing) with ( patch('iblrig.base_choice_world.time.perf_counter', side_effect=lambda c=counter: next(c)), @@ -290,3 +300,35 @@ def test_iti(self): # the total ITI should be 1 second self.assertAlmostEqual(iti_delay_sma + iti_delay_processing + iti_delay_sleep, 1.0, msg='Total ITI should be 1 s') + + +# class TestITI2: +# @pytest.fixture(params=[ChoiceWorldSession, HabituationChoiceWorldSession]) +# def mock_session_and_sma(self, request): +# """Fixture that yields (session, sma) for each session_class.""" +# session_class = request.param +# sma = MagicMock() +# type(sma).total_states_added = PropertyMock(side_effect=lambda: sma.add_state.call_count) +# type(sma).state_timers = PropertyMock( +# side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list] +# ) +# session = MagicMock(spec=session_class).return_value +# session._instantiate_state_machine.return_value = sma +# session._run = session_class._run.__get__(session, session_class) +# session.task_params = session_class.read_task_parameter_files() +# session.get_state_machine_trial = session_class.get_state_machine_trial.__get__(session, session_class) +# session.paused = False +# session.stopped = False +# return session, sma +# +# def test_iti_warning(self, mock_session_and_sma, caplog): +# session, sma = mock_session_and_sma +# session.task_params['NTRIALS'] = 1 +# type(sma).state_timers = [0.0] * 100 +# with ( +# patch('iblrig.base_choice_world.time.sleep'), +# patch('iblrig.base_choice_world.StateMachine', return_value=sma), +# caplog.at_level('WARNING'), +# ): +# session._run() +# assert any('ITI' in rec.message for rec in caplog.records) From 04973bdb2c9285a511f6df82b321855b77a99d2f Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 18:38:20 +0100 Subject: [PATCH 14/18] fix ITI in HabituationChoiceWorld --- iblrig/base_choice_world.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index fc4687123..305428df9 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -544,7 +544,7 @@ def get_state_machine_trial(self, i): state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'}, ) - # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event. + # Wait for 0.5 s before ending the trial. Raise BNC1 to mark this event. sma.add_state( state_name='exit_state', state_timer=min(0.5, self.task_params.ITI_DELAY_SECS), @@ -793,10 +793,10 @@ def get_state_machine_trial(self, i): state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'}, ) - # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event. + # Wait for 0.5 s before ending the trial. Raise BNC1 to mark this event. sma.add_state( state_name='exit_state', - state_timer=self.task_params.ITI_DELAY_SECS, + state_timer=min(0.5, self.task_params.ITI_DELAY_SECS), output_actions=[('BNC1', 255)], state_change_conditions={'Tup': 'exit'}, ) From e47f98f032e8bfb720dd45e8fa98079d0a1676f7 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 26 Sep 2025 19:56:33 +0100 Subject: [PATCH 15/18] Update test_choice_world.py --- iblrig/test/test_choice_world.py | 47 ++++++++++---------------------- 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index b9f858f6a..4c07300b7 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -279,7 +279,7 @@ def test_iti(self): iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer'] # the last state of the state machine needs to contain a BNC1 high of a certain duration - for extraction - self.assertGreaterEqual(iti_delay_sma, 0.5, 'Part of the ITI should be handled by the state machine.') + self.assertEqual(iti_delay_sma, 0.5, 'The fraction of the ITI handled by the state machine should be 0.5 s.') self.assertIn(('BNC1', 255), sma.add_state.call_args_list[-1].kwargs['output_actions'], 'Expecting BNC1 high.') # the assumed fraction of the ITI defined by processing delays @@ -293,42 +293,25 @@ def test_iti(self): patch('iblrig.base_choice_world.time.perf_counter', side_effect=lambda c=counter: next(c)), patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, patch('iblrig.base_choice_world.StateMachine', return_value=sma), + self.assertNoLogs('iblrig', level='WARNING'), ): session._run() self.assertEqual(session.bpod.run_state_machine.call_count, 2) + mock_sleep.assert_called_once() iti_delay_sleep = mock_sleep.call_args[0][0] if mock_sleep.call_args else 0.0 # the total ITI should be 1 second self.assertAlmostEqual(iti_delay_sma + iti_delay_processing + iti_delay_sleep, 1.0, msg='Total ITI should be 1 s') - -# class TestITI2: -# @pytest.fixture(params=[ChoiceWorldSession, HabituationChoiceWorldSession]) -# def mock_session_and_sma(self, request): -# """Fixture that yields (session, sma) for each session_class.""" -# session_class = request.param -# sma = MagicMock() -# type(sma).total_states_added = PropertyMock(side_effect=lambda: sma.add_state.call_count) -# type(sma).state_timers = PropertyMock( -# side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list] -# ) -# session = MagicMock(spec=session_class).return_value -# session._instantiate_state_machine.return_value = sma -# session._run = session_class._run.__get__(session, session_class) -# session.task_params = session_class.read_task_parameter_files() -# session.get_state_machine_trial = session_class.get_state_machine_trial.__get__(session, session_class) -# session.paused = False -# session.stopped = False -# return session, sma -# -# def test_iti_warning(self, mock_session_and_sma, caplog): -# session, sma = mock_session_and_sma -# session.task_params['NTRIALS'] = 1 -# type(sma).state_timers = [0.0] * 100 -# with ( -# patch('iblrig.base_choice_world.time.sleep'), -# patch('iblrig.base_choice_world.StateMachine', return_value=sma), -# caplog.at_level('WARNING'), -# ): -# session._run() -# assert any('ITI' in rec.message for rec in caplog.records) + # if the processing delay is so high that the targeted ITI can't be met log a warning + iti_delay_processing = 0.6 + counter = count(0, iti_delay_processing) + with ( + patch('iblrig.base_choice_world.time.perf_counter', side_effect=lambda c=counter: next(c)), + patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, + patch('iblrig.base_choice_world.StateMachine', return_value=sma), + self.assertLogs('iblrig', level='WARNING') as cm, + ): + session._run() + self.assertTrue(any('Actual ITI: 1.1' in log for log in cm.output)) + mock_sleep.assert_not_called() From ee2a892e1354fd5cbda52b38a5d24c3b8750048e Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Sat, 27 Sep 2025 00:01:16 +0100 Subject: [PATCH 16/18] pytest rocks --- iblrig/test/test_choice_world.py | 154 ++++++++++++++----------------- pyproject.toml | 1 + uv.lock | 18 +++- 3 files changed, 86 insertions(+), 87 deletions(-) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index 4c07300b7..4e2edbffe 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd +import pytest import iblrig.choiceworld from iblrig import session_creator @@ -229,89 +230,70 @@ def test_training_phase_from_contrast_set(self): iblrig.choiceworld.training_phase_from_contrast_set([0.666]) -class TestITI(unittest.TestCase): - @staticmethod - def get_mock_session(session_class: ChoiceWorldSession) -> tuple[MagicMock, MagicMock]: - """Mock ChoiceWorldSession and StateMachine""" - sma = MagicMock() - type(sma).total_states_added = PropertyMock(side_effect=lambda: sma.add_state.call_count) - type(sma).state_timers = PropertyMock( - side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list] - ) - session = MagicMock(spec=session_class).return_value - session.task_params = session_class.read_task_parameter_files() - session.task_params['NTRIALS'] = 1 - session._run = session_class._run.__get__(session, session_class) - session.get_state_machine_trial = session_class.get_state_machine_trial.__get__(session, session_class) - session._instantiate_state_machine.return_value = sma - session.paused = False - session.stopped = False - return session, sma - - def test_iti_warning(self): - # test that the ITI warning is raised when the last state does not correctly handle the ITI - for session_class in [ChoiceWorldSession, HabituationChoiceWorldSession]: - session, sma = self.get_mock_session(session_class) - - type(sma).state_timers = [0.0] * 100 # all states have duration of 0.0 - with ( - patch('iblrig.base_choice_world.time.sleep'), - patch('iblrig.base_choice_world.StateMachine', return_value=sma), - self.assertLogs('iblrig', level='WARNING') as cm, - ): - session._run() - self.assertTrue(any('It should be exactly 0.5 s.' in log for log in cm.output)) - - type(sma).state_timers = [0.5] * 100 # all states have duration of 0.5 - with ( - patch('iblrig.base_choice_world.time.sleep'), - patch('iblrig.base_choice_world.StateMachine', return_value=sma), - self.assertNoLogs('iblrig', level='WARNING'), - ): - session._run() - - def test_iti(self): - for session_class in [ChoiceWorldSession, HabituationChoiceWorldSession]: - # the fraction of the ITI handled by the state machine's last state - session, sma = self.get_mock_session(session_class) - with patch('iblrig.base_choice_world.StateMachine', return_value=sma): - session._run() - iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer'] - - # the last state of the state machine needs to contain a BNC1 high of a certain duration - for extraction - self.assertEqual(iti_delay_sma, 0.5, 'The fraction of the ITI handled by the state machine should be 0.5 s.') - self.assertIn(('BNC1', 255), sma.add_state.call_args_list[-1].kwargs['output_actions'], 'Expecting BNC1 high.') - - # the assumed fraction of the ITI defined by processing delays - iti_delay_processing = 0.031231234234234 - - # the fraction of the ITI handled by time.sleep() making up for processing delays - session, sma = self.get_mock_session(session_class) - session.task_params['NTRIALS'] = 2 - counter = count(0, iti_delay_processing) - with ( - patch('iblrig.base_choice_world.time.perf_counter', side_effect=lambda c=counter: next(c)), - patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, - patch('iblrig.base_choice_world.StateMachine', return_value=sma), - self.assertNoLogs('iblrig', level='WARNING'), - ): - session._run() - self.assertEqual(session.bpod.run_state_machine.call_count, 2) - mock_sleep.assert_called_once() - iti_delay_sleep = mock_sleep.call_args[0][0] if mock_sleep.call_args else 0.0 - - # the total ITI should be 1 second - self.assertAlmostEqual(iti_delay_sma + iti_delay_processing + iti_delay_sleep, 1.0, msg='Total ITI should be 1 s') - - # if the processing delay is so high that the targeted ITI can't be met log a warning - iti_delay_processing = 0.6 - counter = count(0, iti_delay_processing) - with ( - patch('iblrig.base_choice_world.time.perf_counter', side_effect=lambda c=counter: next(c)), - patch('iblrig.base_choice_world.time.sleep', return_value=None) as mock_sleep, - patch('iblrig.base_choice_world.StateMachine', return_value=sma), - self.assertLogs('iblrig', level='WARNING') as cm, - ): - session._run() - self.assertTrue(any('Actual ITI: 1.1' in log for log in cm.output)) - mock_sleep.assert_not_called() +class TestITI: + @pytest.fixture(params=[ChoiceWorldSession, HabituationChoiceWorldSession]) + def session_and_sma(self, request, mocker): + def _factory(n_trials: int): + session_class = request.param + sma = MagicMock() + type(sma).total_states_added = PropertyMock(side_effect=lambda: sma.add_state.call_count) + type(sma).state_timers = PropertyMock( + side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list] + ) + session = MagicMock(spec=session_class).return_value + session.task_params = session_class.read_task_parameter_files() + session.task_params['NTRIALS'] = n_trials + session._run = session_class._run.__get__(session, session_class) + session.get_state_machine_trial = session_class.get_state_machine_trial.__get__(session, session_class) + session._instantiate_state_machine.return_value = sma + session.paused = False + session.stopped = False + mocker.patch('iblrig.base_choice_world.StateMachine', return_value=sma) + return session, sma + + return _factory + + @pytest.fixture + def mock_sleep(self, mocker): + return mocker.patch('iblrig.base_choice_world.time.sleep') + + @pytest.fixture + def mock_perf_counter(self, mocker): + def _factory(period: float): + return mocker.patch('iblrig.base_choice_world.time.perf_counter', side_effect=count(0, period)) + + return _factory + + def test_last_state_duration(self, session_and_sma, mock_sleep, caplog): + """The last state should be 0.5 s in duration.""" + session, sma = session_and_sma(n_trials=1) + session._run() + last_state_duration = sma.add_state.call_args_list[-1].kwargs['state_timer'] + assert last_state_duration == 0.5, 'Last state should be 0.5 s in length' + + def test_last_state_duration_warning(self, session_and_sma, mock_sleep, caplog): + """If the last state is not 0.5 s in duration, a warning should be logged.""" + session, sma = session_and_sma(n_trials=1) + type(sma).state_timers = [0.0] * 100 + session._run() + assert 'It should be exactly 0.5 s.' in caplog.text + + def test_iti_components(self, session_and_sma, mock_sleep, mock_perf_counter, caplog, mocker): + """Test if ITI components are computed correctly.""" + session, sma = session_and_sma(n_trials=2) + iti_delay_processing = 0.4321 + mock_perf_counter(period=iti_delay_processing) + session._run() + iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer'] + iti_delay_sleep = mock_sleep.call_args[0][0] if mock_sleep.call_args else 0.0 + assert ('BNC1', 255) in sma.add_state.call_args_list[-1].kwargs['output_actions'], 'Last state should raise BNC1' + assert mock_sleep.called, 'Sleep should be called' + assert pytest.approx(iti_delay_sma + iti_delay_processing + iti_delay_sleep, rel=1e-6) == 1.0, 'Total ITI should be 1.0 s' + + def test_warning_when_iti_too_high(self, session_and_sma, mock_sleep, mock_perf_counter, caplog, mocker): + """Too high processing delay should log warning.""" + session, sma = session_and_sma(n_trials=2) + mock_perf_counter(period=0.6) + session._run() + assert 'Actual ITI: 1.1' in caplog.text + assert not mock_sleep.called, 'Sleep should not be called' diff --git a/pyproject.toml b/pyproject.toml index 27a1b6b44..be4502872 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ dev = [ test = [ "pytest>=8.4.1", "pytest-cov>=6.2.1", + "pytest-mock>=3.15.1", "pytest-qt>=4.4.0", "pytest-xvfb>=3.1.1", "ruff==0.12.0", diff --git a/uv.lock b/uv.lock index a9890d864..4a56d79d1 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = "==3.10.*" resolution-markers = [ "sys_platform == 'darwin'", @@ -823,6 +823,7 @@ dev = [ { name = "pyqt5-stubs" }, { name = "pytest" }, { name = "pytest-cov" }, + { name = "pytest-mock" }, { name = "pytest-qt" }, { name = "pytest-xvfb" }, { name = "ruff" }, @@ -849,6 +850,7 @@ doc = [ test = [ { name = "pytest" }, { name = "pytest-cov" }, + { name = "pytest-mock" }, { name = "pytest-qt" }, { name = "pytest-xvfb" }, { name = "ruff" }, @@ -908,6 +910,7 @@ dev = [ { name = "pyqt5-stubs", specifier = ">=5.15.6.0" }, { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-cov", specifier = ">=6.2.1" }, + { name = "pytest-mock", specifier = ">=3.15.1" }, { name = "pytest-qt", specifier = ">=4.4.0" }, { name = "pytest-xvfb", specifier = ">=3.1.1" }, { name = "ruff", specifier = "==0.12.0" }, @@ -934,6 +937,7 @@ doc = [ test = [ { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-cov", specifier = ">=6.2.1" }, + { name = "pytest-mock", specifier = ">=3.15.1" }, { name = "pytest-qt", specifier = ">=4.4.0" }, { name = "pytest-xvfb", specifier = ">=3.1.1" }, { name = "ruff", specifier = "==0.12.0" }, @@ -2165,6 +2169,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6d/73/7b0b15cb8605ee967b34aa1d949737ab664f94e6b0f1534e8339d9e64ab2/pytest_github_actions_annotate_failures-0.3.0-py3-none-any.whl", hash = "sha256:41ea558ba10c332c0bfc053daeee0c85187507b2034e990f21e4f7e5fef044cf", size = 6030, upload-time = "2025-01-17T22:39:31.701Z" }, ] +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "pytest-qt" version = "4.5.0" From 6361216258ef5d45b20693abea4f71f2db64353f Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Sat, 27 Sep 2025 00:22:14 +0100 Subject: [PATCH 17/18] Update test_choice_world.py --- iblrig/test/test_choice_world.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index 4e2edbffe..b3e8be1a8 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -278,7 +278,7 @@ def test_last_state_duration_warning(self, session_and_sma, mock_sleep, caplog): session._run() assert 'It should be exactly 0.5 s.' in caplog.text - def test_iti_components(self, session_and_sma, mock_sleep, mock_perf_counter, caplog, mocker): + def test_iti_components(self, session_and_sma, mock_sleep, mock_perf_counter, caplog): """Test if ITI components are computed correctly.""" session, sma = session_and_sma(n_trials=2) iti_delay_processing = 0.4321 @@ -290,8 +290,8 @@ def test_iti_components(self, session_and_sma, mock_sleep, mock_perf_counter, ca assert mock_sleep.called, 'Sleep should be called' assert pytest.approx(iti_delay_sma + iti_delay_processing + iti_delay_sleep, rel=1e-6) == 1.0, 'Total ITI should be 1.0 s' - def test_warning_when_iti_too_high(self, session_and_sma, mock_sleep, mock_perf_counter, caplog, mocker): - """Too high processing delay should log warning.""" + def test_warning_when_iti_too_high(self, session_and_sma, mock_sleep, mock_perf_counter, caplog): + """Test if larger than intended ITI is logged with a warning.""" session, sma = session_and_sma(n_trials=2) mock_perf_counter(period=0.6) session._run() From 7705b438df3bbc28e80a0d9d5ae11004dd32b5e4 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Sat, 27 Sep 2025 18:56:29 +0100 Subject: [PATCH 18/18] Update test_choice_world.py --- iblrig/test/test_choice_world.py | 51 +++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index b3e8be1a8..57bd8eae1 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -8,7 +8,7 @@ import unittest from itertools import count from pathlib import Path -from unittest.mock import MagicMock, PropertyMock, patch +from unittest.mock import patch import numpy as np import pandas as pd @@ -16,7 +16,12 @@ import iblrig.choiceworld from iblrig import session_creator -from iblrig.base_choice_world import ChoiceWorldSession, HabituationChoiceWorldSession +from iblrig.base_choice_world import ( + ActiveChoiceWorldSession, + BiasedChoiceWorldSession, + ChoiceWorldSession, + HabituationChoiceWorldSession, +) from iblrig.path_helper import iterate_previous_sessions from iblrig.raw_data_loaders import load_task_jsonable from iblrig.test.base import BaseTestCases @@ -231,24 +236,50 @@ def test_training_phase_from_contrast_set(self): class TestITI: - @pytest.fixture(params=[ChoiceWorldSession, HabituationChoiceWorldSession]) + @pytest.fixture( + params=[ + ChoiceWorldSession, + HabituationChoiceWorldSession, + ActiveChoiceWorldSession, + BiasedChoiceWorldSession, + TrainingChoiceWorldSession, + ] + ) def session_and_sma(self, request, mocker): def _factory(n_trials: int): session_class = request.param - sma = MagicMock() - type(sma).total_states_added = PropertyMock(side_effect=lambda: sma.add_state.call_count) - type(sma).state_timers = PropertyMock( + + # Mocked StateMachine + sma = mocker.MagicMock() + type(sma).total_states_added = mocker.PropertyMock(side_effect=lambda: sma.add_state.call_count) + type(sma).state_timers = mocker.PropertyMock( side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list] ) - session = MagicMock(spec=session_class).return_value - session.task_params = session_class.read_task_parameter_files() - session.task_params['NTRIALS'] = n_trials + + # Create autospec instance + session = mocker.create_autospec(session_class, instance=True) + session.bpod = mocker.MagicMock() + session.trials_table = mocker.MagicMock() + session.trial_num = mocker.MagicMock() + session.movement_left = mocker.MagicMock() + session.movement_right = mocker.MagicMock() + session.interactive = mocker.MagicMock() + session.paths = mocker.MagicMock() + + # Restore real methods session._run = session_class._run.__get__(session, session_class) session.get_state_machine_trial = session_class.get_state_machine_trial.__get__(session, session_class) + + # Patch returned StateMachine session._instantiate_state_machine.return_value = sma + mocker.patch('iblrig.base_choice_world.StateMachine', return_value=sma) + + # Minimal task parameters + session.task_params = session_class.read_task_parameter_files() + session.task_params['NTRIALS'] = n_trials session.paused = False session.stopped = False - mocker.patch('iblrig.base_choice_world.StateMachine', return_value=sma) + return session, sma return _factory