diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e1bd0117..9a609db2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ Changelog ========= +8.30.0 +------ +* disentangle order of states in `HabituationChoiceWorldSession` +* fix ITI durations in `ChoiceWorldSession` and `HabituationChoiceWorldSession` + 8.29.0 ------ * added: GUI settings for changing MAIN_SYNC diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index 5ad0dfdc3..5af298f4c 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -245,7 +245,8 @@ def _run(self) -> None: This method orchestrates the execution of the task by running a state machine for a specified number of trials. """ - time_last_trial_end = time.time() + time_last_trial_end = np.nan + iti_last_trial = np.nan for trial_number in range(self.task_params.NTRIALS): # Main loop # obtain state machine definition self.next_trial() @@ -266,33 +267,42 @@ def _run(self) -> None: log.warning("'Deprecation Notes' in IBLRIG's documentation.") log.warning('**********************************************') log.warning('') - log.info('Waiting for 10s so you actually read this message ;-)') + log.warning('Waiting for 10s so you actually read this message ;-)') time.sleep(10) else: self._wait_for_camera_and_initial_delay() # send state machine description to Bpod device - log.debug('Sending state machine to bpod') self.bpod.send_state_machine(sma) # handle ITI durations if trial_number > 0: - # The ITI_DELAY_SECS defines the grey screen period within the state machine, where the - # Bpod TTL is HIGH. The DEAD_TIME param defines the time between last trial and the next - dead_time = self.task_params.get('DEAD_TIME', 0.5) - dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end) + # ITI_DELAY_SECS defines the period between hiding the stimulus and start of the next trial's quiescent + # period. The state machine handles 0.5 seconds of this period (in order to deliver a BNC1High event + # required for extraction of the task data). The remaining time is handled here by `time.sleep` to make + # up for processing delays inbetween state-machine runs. + processing_delays = time.perf_counter() - time_last_trial_end + dt = self.task_params.ITI_DELAY_SECS - iti_last_trial - processing_delays # wait to achieve the desired ITI duration if dt > 0: - log.debug(f'Waiting {dt} s to achieve an ITI duration of {self.task_params.ITI_DELAY_SECS} s') + log.debug('Sleeping %0.3f s to achieve an ITI duration of %0.1f s', dt, self.task_params.ITI_DELAY_SECS) time.sleep(dt) + elif dt < 0: + log.warning('Targeted ITI: %0.1f s', self.task_params.ITI_DELAY_SECS) + log.warning('Actual ITI: %0.3f s', self.task_params.ITI_DELAY_SECS - dt) # run state machine log.info('-----------------------') - log.info(f'Starting Trial #{trial_number}') - log.debug('running state machine') + log.info('Starting Trial #%d', trial_number) self.bpod.run_state_machine(sma) # Locks until state machine 'exit' is reached - time_last_trial_end = time.time() + time_last_trial_end = time.perf_counter() + + # The ITI duration is partially handled by Bpod within the last state of the state machine. + # This state should have a duration of 0.5 seconds (see explanation below). + iti_last_trial = sma.state_timers[sma.total_states_added - 1] + if iti_last_trial != 0.5: + log.warning('ATTENTION: The last state had a duration of %0.1f s. It should be exactly 0.5 s.', iti_last_trial) # handle pause event if self.paused and trial_number < (self.task_params.NTRIALS - 1): @@ -534,10 +544,10 @@ def get_state_machine_trial(self, i): state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'}, ) - # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event. + # Wait for 0.5 s before ending the trial. Raise BNC1 to mark this event. sma.add_state( state_name='exit_state', - state_timer=self.task_params.ITI_DELAY_SECS, + state_timer=min(0.5, self.task_params.ITI_DELAY_SECS), output_actions=[('BNC1', 255)], state_change_conditions={'Tup': 'exit'}, ) @@ -734,22 +744,22 @@ def draw_next_trial_info(self, *args, **kwargs): def get_state_machine_trial(self, i): sma = StateMachine(self.bpod) - # NB: This state actually the inter-trial interval, i.e. the period of grey screen between stim off and stim on. - # During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is trial end; - # the offset of this state is trial start! + # Show the visual stimulus. + # Move to next state if Frame2TTL event is detected. + # Use the state-timer as a backup to prevent a stall. sma.add_state( - state_name='iti', - state_timer=1, # Stim off for 1 sec - state_change_conditions={'Tup': 'stim_on'}, - output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)], + state_name='stim_on', + state_timer=0.1, + state_change_conditions={'Tup': 'play_tone', 'BNC1High': 'play_tone', 'BNC1Low': 'play_tone'}, + output_actions=[self.bpod.actions.bonsai_show_stim, ('BNC1', 255)], ) - # This stim_on state is considered the actual trial start + # Play tone and wait for `delay_to_stim_center`. sma.add_state( - state_name='stim_on', + state_name='play_tone', state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'], + output_actions=[self.bpod.actions.play_tone], state_change_conditions={'Tup': 'stim_center'}, - output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone], ) sma.add_state( @@ -765,15 +775,32 @@ def get_state_machine_trial(self, i): state_change_conditions={'Tup': 'post_reward'}, output_actions=[('Valve1', 255), ('BNC1', 255)], ) - # This state defines the period after reward where Bpod TTL is LOW. - # NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit. - # The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds. + sma.add_state( state_name='post_reward', - state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time, - state_change_conditions={'Tup': 'exit'}, + state_timer=0.5 - self.reward_time, + state_change_conditions={'Tup': 'hide_stim'}, output_actions=[], ) + + # Hide the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary + # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e., + # when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall. + sma.add_state( + state_name='hide_stim', + state_timer=0.1, + output_actions=[self.bpod.actions.bonsai_hide_stim], + state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'}, + ) + + # Wait for 0.5 s before ending the trial. Raise BNC1 to mark this event. + sma.add_state( + state_name='exit_state', + state_timer=min(0.5, self.task_params.ITI_DELAY_SECS), + output_actions=[('BNC1', 255)], + state_change_conditions={'Tup': 'exit'}, + ) + return sma diff --git a/iblrig/base_choice_world_params.yaml b/iblrig/base_choice_world_params.yaml index 053716b27..f5c7346c9 100644 --- a/iblrig/base_choice_world_params.yaml +++ b/iblrig/base_choice_world_params.yaml @@ -12,8 +12,7 @@ 'FEEDBACK_ERROR_DELAY_SECS': 2 'FEEDBACK_NOGO_DELAY_SECS': 2 'INTERACTIVE_DELAY': 0.0 -'DEAD_TIME': 0.5 # the length of time before entering the next trial. This plus ITI_DELAY_SECS define period of closed-loop grey screen -'ITI_DELAY_SECS': 0.5 # this is the length of the ITI state at the end of the session. 0.5 seconds are added to it until the next trial start +'ITI_DELAY_SECS': 1.0 # 0.5 seconds will be handled by the state-machine, the rest by python inbetween state-machine runs 'NTRIALS': 2000 'PROBABILITY_LEFT': 0.5 'QUIESCENCE_THRESHOLDS': [-2, 2] diff --git a/iblrig/test/test_choice_world.py b/iblrig/test/test_choice_world.py index ab797d7d4..57bd8eae1 100644 --- a/iblrig/test/test_choice_world.py +++ b/iblrig/test/test_choice_world.py @@ -6,14 +6,22 @@ import shutil import tempfile import unittest +from itertools import count from pathlib import Path from unittest.mock import patch import numpy as np import pandas as pd +import pytest import iblrig.choiceworld from iblrig import session_creator +from iblrig.base_choice_world import ( + ActiveChoiceWorldSession, + BiasedChoiceWorldSession, + ChoiceWorldSession, + HabituationChoiceWorldSession, +) from iblrig.path_helper import iterate_previous_sessions from iblrig.raw_data_loaders import load_task_jsonable from iblrig.test.base import BaseTestCases @@ -225,3 +233,98 @@ def test_training_phase_from_contrast_set(self): self.assertEqual(iblrig.choiceworld.training_phase_from_contrast_set(contrasts3), phase) with self.assertRaises(ValueError): iblrig.choiceworld.training_phase_from_contrast_set([0.666]) + + +class TestITI: + @pytest.fixture( + params=[ + ChoiceWorldSession, + HabituationChoiceWorldSession, + ActiveChoiceWorldSession, + BiasedChoiceWorldSession, + TrainingChoiceWorldSession, + ] + ) + def session_and_sma(self, request, mocker): + def _factory(n_trials: int): + session_class = request.param + + # Mocked StateMachine + sma = mocker.MagicMock() + type(sma).total_states_added = mocker.PropertyMock(side_effect=lambda: sma.add_state.call_count) + type(sma).state_timers = mocker.PropertyMock( + side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list] + ) + + # Create autospec instance + session = mocker.create_autospec(session_class, instance=True) + session.bpod = mocker.MagicMock() + session.trials_table = mocker.MagicMock() + session.trial_num = mocker.MagicMock() + session.movement_left = mocker.MagicMock() + session.movement_right = mocker.MagicMock() + session.interactive = mocker.MagicMock() + session.paths = mocker.MagicMock() + + # Restore real methods + session._run = session_class._run.__get__(session, session_class) + session.get_state_machine_trial = session_class.get_state_machine_trial.__get__(session, session_class) + + # Patch returned StateMachine + session._instantiate_state_machine.return_value = sma + mocker.patch('iblrig.base_choice_world.StateMachine', return_value=sma) + + # Minimal task parameters + session.task_params = session_class.read_task_parameter_files() + session.task_params['NTRIALS'] = n_trials + session.paused = False + session.stopped = False + + return session, sma + + return _factory + + @pytest.fixture + def mock_sleep(self, mocker): + return mocker.patch('iblrig.base_choice_world.time.sleep') + + @pytest.fixture + def mock_perf_counter(self, mocker): + def _factory(period: float): + return mocker.patch('iblrig.base_choice_world.time.perf_counter', side_effect=count(0, period)) + + return _factory + + def test_last_state_duration(self, session_and_sma, mock_sleep, caplog): + """The last state should be 0.5 s in duration.""" + session, sma = session_and_sma(n_trials=1) + session._run() + last_state_duration = sma.add_state.call_args_list[-1].kwargs['state_timer'] + assert last_state_duration == 0.5, 'Last state should be 0.5 s in length' + + def test_last_state_duration_warning(self, session_and_sma, mock_sleep, caplog): + """If the last state is not 0.5 s in duration, a warning should be logged.""" + session, sma = session_and_sma(n_trials=1) + type(sma).state_timers = [0.0] * 100 + session._run() + assert 'It should be exactly 0.5 s.' in caplog.text + + def test_iti_components(self, session_and_sma, mock_sleep, mock_perf_counter, caplog): + """Test if ITI components are computed correctly.""" + session, sma = session_and_sma(n_trials=2) + iti_delay_processing = 0.4321 + mock_perf_counter(period=iti_delay_processing) + session._run() + iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer'] + iti_delay_sleep = mock_sleep.call_args[0][0] if mock_sleep.call_args else 0.0 + assert ('BNC1', 255) in sma.add_state.call_args_list[-1].kwargs['output_actions'], 'Last state should raise BNC1' + assert mock_sleep.called, 'Sleep should be called' + assert pytest.approx(iti_delay_sma + iti_delay_processing + iti_delay_sleep, rel=1e-6) == 1.0, 'Total ITI should be 1.0 s' + + def test_warning_when_iti_too_high(self, session_and_sma, mock_sleep, mock_perf_counter, caplog): + """Test if larger than intended ITI is logged with a warning.""" + session, sma = session_and_sma(n_trials=2) + mock_perf_counter(period=0.6) + session._run() + assert 'Actual ITI: 1.1' in caplog.text + assert not mock_sleep.called, 'Sleep should not be called' diff --git a/pyproject.toml b/pyproject.toml index 23cc19462..be4502872 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "PyQt5-Qt5>=5.15.2", "PyQtWebEngine>=5.15.6", "PyQtWebEngine-Qt5>=5.15.2", + "pandas-stubs==2.3.0.250703", ] [project.optional-dependencies] project-extraction = [ @@ -92,6 +93,7 @@ dev = [ test = [ "pytest>=8.4.1", "pytest-cov>=6.2.1", + "pytest-mock>=3.15.1", "pytest-qt>=4.4.0", "pytest-xvfb>=3.1.1", "ruff==0.12.0", diff --git a/uv.lock b/uv.lock index d51f90b21..4a56d79d1 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = "==3.10.*" resolution-markers = [ "sys_platform == 'darwin'", @@ -786,6 +786,7 @@ dependencies = [ { name = "one-api" }, { name = "packaging" }, { name = "pandas" }, + { name = "pandas-stubs" }, { name = "pandera" }, { name = "psutil" }, { name = "pydantic" }, @@ -822,6 +823,7 @@ dev = [ { name = "pyqt5-stubs" }, { name = "pytest" }, { name = "pytest-cov" }, + { name = "pytest-mock" }, { name = "pytest-qt" }, { name = "pytest-xvfb" }, { name = "ruff" }, @@ -848,6 +850,7 @@ doc = [ test = [ { name = "pytest" }, { name = "pytest-cov" }, + { name = "pytest-mock" }, { name = "pytest-qt" }, { name = "pytest-xvfb" }, { name = "ruff" }, @@ -877,6 +880,7 @@ requires-dist = [ { name = "one-api", specifier = ">=3.3.0" }, { name = "packaging", specifier = ">=25.0" }, { name = "pandas", specifier = ">=2.3.0" }, + { name = "pandas-stubs", specifier = "==2.3.0.250703" }, { name = "pandera", specifier = ">=0.24.0" }, { name = "project-extraction", marker = "extra == 'project-extraction'", git = "https://github.com/int-brain-lab/project_extraction.git" }, { name = "psutil", specifier = ">=7.0.0" }, @@ -906,6 +910,7 @@ dev = [ { name = "pyqt5-stubs", specifier = ">=5.15.6.0" }, { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-cov", specifier = ">=6.2.1" }, + { name = "pytest-mock", specifier = ">=3.15.1" }, { name = "pytest-qt", specifier = ">=4.4.0" }, { name = "pytest-xvfb", specifier = ">=3.1.1" }, { name = "ruff", specifier = "==0.12.0" }, @@ -932,6 +937,7 @@ doc = [ test = [ { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-cov", specifier = ">=6.2.1" }, + { name = "pytest-mock", specifier = ">=3.15.1" }, { name = "pytest-qt", specifier = ">=4.4.0" }, { name = "pytest-xvfb", specifier = ">=3.1.1" }, { name = "ruff", specifier = "==0.12.0" }, @@ -1607,6 +1613,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/d6/d7f5777162aa9b48ec3910bca5a58c9b5927cfd9cfde3aa64322f5ba4b9f/pandas-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:2eb789ae0274672acbd3c575b0598d213345660120a257b47b5dafdc618aec83", size = 11336561, upload-time = "2025-07-07T19:18:31.211Z" }, ] +[[package]] +name = "pandas-stubs" +version = "2.3.0.250703" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "types-pytz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ec/df/c1c51c5cec087b8f4d04669308b700e9648745a77cdd0c8c5e16520703ca/pandas_stubs-2.3.0.250703.tar.gz", hash = "sha256:fb6a8478327b16ed65c46b1541de74f5c5947f3601850caf3e885e0140584717", size = 103910, upload-time = "2025-07-02T17:49:11.667Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/75/cb/09d5f9bf7c8659af134ae0ffc1a349038a5d0ff93e45aedc225bde2872a3/pandas_stubs-2.3.0.250703-py3-none-any.whl", hash = "sha256:a9265fc69909f0f7a9cabc5f596d86c9d531499fed86b7838fd3278285d76b81", size = 154719, upload-time = "2025-07-02T17:49:10.697Z" }, +] + [[package]] name = "pandera" version = "0.25.0" @@ -2150,6 +2169,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6d/73/7b0b15cb8605ee967b34aa1d949737ab664f94e6b0f1534e8339d9e64ab2/pytest_github_actions_annotate_failures-0.3.0-py3-none-any.whl", hash = "sha256:41ea558ba10c332c0bfc053daeee0c85187507b2034e990f21e4f7e5fef044cf", size = 6030, upload-time = "2025-01-17T22:39:31.701Z" }, ] +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "pytest-qt" version = "4.5.0" @@ -3047,6 +3078,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/72/52/43e70a8e57fefb172c22a21000b03ebcc15e47e97f5cb8495b9c2832efb4/types_python_dateutil-2.9.0.20250708-py3-none-any.whl", hash = "sha256:4d6d0cc1cc4d24a2dc3816024e502564094497b713f7befda4d5bc7a8e3fd21f", size = 17724, upload-time = "2025-07-08T03:14:02.593Z" }, ] +[[package]] +name = "types-pytz" +version = "2025.2.0.20250516" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/72/b0e711fd90409f5a76c75349055d3eb19992c110f0d2d6aabbd6cfbc14bf/types_pytz-2025.2.0.20250516.tar.gz", hash = "sha256:e1216306f8c0d5da6dafd6492e72eb080c9a166171fa80dd7a1990fd8be7a7b3", size = 10940, upload-time = "2025-05-16T03:07:01.91Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl", hash = "sha256:e0e0c8a57e2791c19f718ed99ab2ba623856b11620cb6b637e5f62ce285a7451", size = 10136, upload-time = "2025-05-16T03:07:01.075Z" }, +] + [[package]] name = "types-pyyaml" version = "6.0.12.20250516"