Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changelog
=========

8.30.0
------
* disentangle order of states in `HabituationChoiceWorldSession`
* fix ITI durations in `ChoiceWorldSession` and `HabituationChoiceWorldSession`

8.29.0
------
* added: GUI settings for changing MAIN_SYNC
Expand Down
83 changes: 55 additions & 28 deletions iblrig/base_choice_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def _run(self) -> None:

This method orchestrates the execution of the task by running a state machine for a specified number of trials.
"""
time_last_trial_end = time.time()
time_last_trial_end = np.nan
iti_last_trial = np.nan
for trial_number in range(self.task_params.NTRIALS): # Main loop
# obtain state machine definition
self.next_trial()
Expand All @@ -266,33 +267,42 @@ def _run(self) -> None:
log.warning("'Deprecation Notes' in IBLRIG's documentation.")
log.warning('**********************************************')
log.warning('')
log.info('Waiting for 10s so you actually read this message ;-)')
log.warning('Waiting for 10s so you actually read this message ;-)')
time.sleep(10)
else:
self._wait_for_camera_and_initial_delay()

# send state machine description to Bpod device
log.debug('Sending state machine to bpod')
self.bpod.send_state_machine(sma)

# handle ITI durations
if trial_number > 0:
# The ITI_DELAY_SECS defines the grey screen period within the state machine, where the
# Bpod TTL is HIGH. The DEAD_TIME param defines the time between last trial and the next
dead_time = self.task_params.get('DEAD_TIME', 0.5)
dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end)
# ITI_DELAY_SECS defines the period between hiding the stimulus and start of the next trial's quiescent
# period. The state machine handles 0.5 seconds of this period (in order to deliver a BNC1High event
# required for extraction of the task data). The remaining time is handled here by `time.sleep` to make
# up for processing delays inbetween state-machine runs.
processing_delays = time.perf_counter() - time_last_trial_end
dt = self.task_params.ITI_DELAY_SECS - iti_last_trial - processing_delays

# wait to achieve the desired ITI duration
if dt > 0:
log.debug(f'Waiting {dt} s to achieve an ITI duration of {self.task_params.ITI_DELAY_SECS} s')
log.debug('Sleeping %0.3f s to achieve an ITI duration of %0.1f s', dt, self.task_params.ITI_DELAY_SECS)
time.sleep(dt)
elif dt < 0:
log.warning('Targeted ITI: %0.1f s', self.task_params.ITI_DELAY_SECS)
log.warning('Actual ITI: %0.3f s', self.task_params.ITI_DELAY_SECS - dt)

# run state machine
log.info('-----------------------')
log.info(f'Starting Trial #{trial_number}')
log.debug('running state machine')
log.info('Starting Trial #%d', trial_number)
self.bpod.run_state_machine(sma) # Locks until state machine 'exit' is reached
time_last_trial_end = time.time()
time_last_trial_end = time.perf_counter()

# The ITI duration is partially handled by Bpod within the last state of the state machine.
# This state should have a duration of 0.5 seconds (see explanation below).
iti_last_trial = sma.state_timers[sma.total_states_added - 1]
if iti_last_trial != 0.5:
log.warning('ATTENTION: The last state had a duration of %0.1f s. It should be exactly 0.5 s.', iti_last_trial)

# handle pause event
if self.paused and trial_number < (self.task_params.NTRIALS - 1):
Expand Down Expand Up @@ -534,10 +544,10 @@ def get_state_machine_trial(self, i):
state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'},
)

# Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event.
# Wait for 0.5 s before ending the trial. Raise BNC1 to mark this event.
sma.add_state(
state_name='exit_state',
state_timer=self.task_params.ITI_DELAY_SECS,
state_timer=min(0.5, self.task_params.ITI_DELAY_SECS),
output_actions=[('BNC1', 255)],
state_change_conditions={'Tup': 'exit'},
)
Expand Down Expand Up @@ -734,22 +744,22 @@ def draw_next_trial_info(self, *args, **kwargs):
def get_state_machine_trial(self, i):
sma = StateMachine(self.bpod)

# NB: This state actually the inter-trial interval, i.e. the period of grey screen between stim off and stim on.
# During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is trial end;
# the offset of this state is trial start!
# Show the visual stimulus.
# Move to next state if Frame2TTL event is detected.
# Use the state-timer as a backup to prevent a stall.
sma.add_state(
state_name='iti',
state_timer=1, # Stim off for 1 sec
state_change_conditions={'Tup': 'stim_on'},
output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)],
state_name='stim_on',
state_timer=0.1,
state_change_conditions={'Tup': 'stim_center', 'BNC1High': 'play_tone', 'BNC1Low': 'play_tone'},
output_actions=[self.bpod.actions.bonsai_show_stim, ('BNC1', 255)],
)

# This stim_on state is considered the actual trial start
# Play tone and wait for `delay_to_stim_center`.
sma.add_state(
state_name='stim_on',
state_name='play_tone',
state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'],
output_actions=[self.bpod.actions.play_tone],
state_change_conditions={'Tup': 'stim_center'},
output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone],
)

sma.add_state(
Expand All @@ -765,15 +775,32 @@ def get_state_machine_trial(self, i):
state_change_conditions={'Tup': 'post_reward'},
output_actions=[('Valve1', 255), ('BNC1', 255)],
)
# This state defines the period after reward where Bpod TTL is LOW.
# NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit.
# The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds.

sma.add_state(
state_name='post_reward',
state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time,
state_change_conditions={'Tup': 'exit'},
state_timer=0.5 - self.reward_time,
state_change_conditions={'Tup': 'hide_stim'},
output_actions=[],
)

# Hide the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
# Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
# when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall.
sma.add_state(
state_name='hide_stim',
state_timer=0.1,
output_actions=[self.bpod.actions.bonsai_hide_stim],
state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'},
)

# Wait for 0.5 s before ending the trial. Raise BNC1 to mark this event.
sma.add_state(
state_name='exit_state',
state_timer=min(0.5, self.task_params.ITI_DELAY_SECS),
output_actions=[('BNC1', 255)],
state_change_conditions={'Tup': 'exit'},
)

return sma


Expand Down
3 changes: 1 addition & 2 deletions iblrig/base_choice_world_params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
'FEEDBACK_ERROR_DELAY_SECS': 2
'FEEDBACK_NOGO_DELAY_SECS': 2
'INTERACTIVE_DELAY': 0.0
'DEAD_TIME': 0.5 # the length of time before entering the next trial. This plus ITI_DELAY_SECS define period of closed-loop grey screen
'ITI_DELAY_SECS': 0.5 # this is the length of the ITI state at the end of the session. 0.5 seconds are added to it until the next trial start
'ITI_DELAY_SECS': 1.0 # 0.5 seconds will be handled by the state-machine, the rest by python inbetween state-machine runs
'NTRIALS': 2000
'PROBABILITY_LEFT': 0.5
'QUIESCENCE_THRESHOLDS': [-2, 2]
Expand Down
103 changes: 103 additions & 0 deletions iblrig/test/test_choice_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,22 @@
import shutil
import tempfile
import unittest
from itertools import count
from pathlib import Path
from unittest.mock import patch

import numpy as np
import pandas as pd
import pytest

import iblrig.choiceworld
from iblrig import session_creator
from iblrig.base_choice_world import (
ActiveChoiceWorldSession,
BiasedChoiceWorldSession,
ChoiceWorldSession,
HabituationChoiceWorldSession,
)
from iblrig.path_helper import iterate_previous_sessions
from iblrig.raw_data_loaders import load_task_jsonable
from iblrig.test.base import BaseTestCases
Expand Down Expand Up @@ -225,3 +233,98 @@ def test_training_phase_from_contrast_set(self):
self.assertEqual(iblrig.choiceworld.training_phase_from_contrast_set(contrasts3), phase)
with self.assertRaises(ValueError):
iblrig.choiceworld.training_phase_from_contrast_set([0.666])


class TestITI:
@pytest.fixture(
params=[
ChoiceWorldSession,
HabituationChoiceWorldSession,
ActiveChoiceWorldSession,
BiasedChoiceWorldSession,
TrainingChoiceWorldSession,
]
)
def session_and_sma(self, request, mocker):
def _factory(n_trials: int):
session_class = request.param

# Mocked StateMachine
sma = mocker.MagicMock()
type(sma).total_states_added = mocker.PropertyMock(side_effect=lambda: sma.add_state.call_count)
type(sma).state_timers = mocker.PropertyMock(
side_effect=lambda: [float(x.kwargs['state_timer']) for x in sma.add_state.call_args_list]
)

# Create autospec instance
session = mocker.create_autospec(session_class, instance=True)
session.bpod = mocker.MagicMock()
session.trials_table = mocker.MagicMock()
session.trial_num = mocker.MagicMock()
session.movement_left = mocker.MagicMock()
session.movement_right = mocker.MagicMock()
session.interactive = mocker.MagicMock()
session.paths = mocker.MagicMock()

# Restore real methods
session._run = session_class._run.__get__(session, session_class)
session.get_state_machine_trial = session_class.get_state_machine_trial.__get__(session, session_class)

# Patch returned StateMachine
session._instantiate_state_machine.return_value = sma
mocker.patch('iblrig.base_choice_world.StateMachine', return_value=sma)

# Minimal task parameters
session.task_params = session_class.read_task_parameter_files()
session.task_params['NTRIALS'] = n_trials
session.paused = False
session.stopped = False

return session, sma

return _factory

@pytest.fixture
def mock_sleep(self, mocker):
return mocker.patch('iblrig.base_choice_world.time.sleep')

@pytest.fixture
def mock_perf_counter(self, mocker):
def _factory(period: float):
return mocker.patch('iblrig.base_choice_world.time.perf_counter', side_effect=count(0, period))

return _factory

def test_last_state_duration(self, session_and_sma, mock_sleep, caplog):
"""The last state should be 0.5 s in duration."""
session, sma = session_and_sma(n_trials=1)
session._run()
last_state_duration = sma.add_state.call_args_list[-1].kwargs['state_timer']
assert last_state_duration == 0.5, 'Last state should be 0.5 s in length'

def test_last_state_duration_warning(self, session_and_sma, mock_sleep, caplog):
"""If the last state is not 0.5 s in duration, a warning should be logged."""
session, sma = session_and_sma(n_trials=1)
type(sma).state_timers = [0.0] * 100
session._run()
assert 'It should be exactly 0.5 s.' in caplog.text

def test_iti_components(self, session_and_sma, mock_sleep, mock_perf_counter, caplog):
"""Test if ITI components are computed correctly."""
session, sma = session_and_sma(n_trials=2)
iti_delay_processing = 0.4321
mock_perf_counter(period=iti_delay_processing)
session._run()
iti_delay_sma = sma.add_state.call_args_list[-1].kwargs['state_timer']
iti_delay_sleep = mock_sleep.call_args[0][0] if mock_sleep.call_args else 0.0
assert ('BNC1', 255) in sma.add_state.call_args_list[-1].kwargs['output_actions'], 'Last state should raise BNC1'
assert mock_sleep.called, 'Sleep should be called'
assert pytest.approx(iti_delay_sma + iti_delay_processing + iti_delay_sleep, rel=1e-6) == 1.0, 'Total ITI should be 1.0 s'

def test_warning_when_iti_too_high(self, session_and_sma, mock_sleep, mock_perf_counter, caplog):
"""Test if larger than intended ITI is logged with a warning."""
session, sma = session_and_sma(n_trials=2)
mock_perf_counter(period=0.6)
session._run()
assert 'Actual ITI: 1.1' in caplog.text
assert not mock_sleep.called, 'Sleep should not be called'
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dependencies = [
"PyQt5-Qt5>=5.15.2",
"PyQtWebEngine>=5.15.6",
"PyQtWebEngine-Qt5>=5.15.2",
"pandas-stubs==2.3.0.250703",
]
[project.optional-dependencies]
project-extraction = [
Expand Down Expand Up @@ -92,6 +93,7 @@ dev = [
test = [
"pytest>=8.4.1",
"pytest-cov>=6.2.1",
"pytest-mock>=3.15.1",
"pytest-qt>=4.4.0",
"pytest-xvfb>=3.1.1",
"ruff==0.12.0",
Expand Down
Loading
Loading