Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 140 additions & 2 deletions iblrig/ephys.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,43 @@
import argparse
import asyncio
import datetime
import logging
import string
from pathlib import Path

import numpy as np

from iblatlas import atlas
from iblrig.base_tasks import EmptySession
from iblrig.net import get_server_communicator, read_stdin, update_alyx_token
from iblrig.path_helper import load_pydantic_yaml
from iblrig.pydantic_definitions import RigSettings
from iblrig.transfer_experiments import EphysCopier
from iblutil.io import net
from iblutil.util import setup_logger
from one.alf.io import next_num_folder
from one.api import OneAlyx


def prepare_ephys_session_cmd():
parser = argparse.ArgumentParser(prog='start_video_session', description='Prepare video PC for video recording session.')
parser.add_argument('subject_name', help='name of subject')
parser.add_argument('nprobes', help='number of probes', type=int, default=2)
parser.add_argument('--debug', action='store_true', help='enable debugging mode')
parser.add_argument(
'--service-uri',
required=False,
nargs='?',
default=False,
type=str,
help='the service URI to listen to messages on. pass ":<port>" to specify port only.',
)
args = parser.parse_args()
setup_logger(name='iblrig', level='DEBUG' if args.debug else 'INFO')
prepare_ephys_session(args.subject_name, args.nprobes)
setup_logger(name=__name__, level='DEBUG' if args.debug else 'INFO')
if args.service_uri:
asyncio.run(main_v8_networked(args.subject_name, args.debug, args.nprobes, args.service_uri))
else:
prepare_ephys_session(args.subject_name, args.nprobes)


def prepare_ephys_session(subject_name: str, nprobes: int = 2):
Expand Down Expand Up @@ -74,3 +95,120 @@ def neuropixel24_micromanipulator_coordinates(ref_shank, pname, ba=None, shank_s
shank['depth'] = ref_shank['depth'] + (xyz_entry[2] - xyz_ref[2]) * 1e6
trajectories[f'{pname}{string.ascii_lowercase[i]}'] = shank
return trajectories


async def main_v8_networked(mouse, debug=False, n_probes=2, service_uri=None):
# from iblrig.base_tasks import EmptySession

log = logging.getLogger(__name__)

# if PARAMS.get('PROBE_TYPE_00', '3B') != '3B' or PARAMS.get('PROBE_TYPE_01', '3B') != '3B':
# raise NotImplementedError('Only 3B probes supported.')
# if n_probes is None:
# n_probes = sum(k.lower().startswith('probe_type_') for k in PARAMS)

# FIXME this isn't working!
# session = EmptySession(subject=mouse, interactive=False, iblrig_settings=iblrig_settings)
# session_path = session.paths.SESSION_FOLDER
# FIXME The following should be done by the EmptySession class
iblrig_settings = load_pydantic_yaml(RigSettings)
date = datetime.datetime.now().date().isoformat()
num = next_num_folder(iblrig_settings.iblrig_local_data_path / mouse / date)
session_path = iblrig_settings.iblrig_local_data_path / mouse / date / num
raw_data_folder = session_path.joinpath('raw_ephys_data')
raw_data_folder.mkdir(parents=True, exist_ok=True)

log.info('Created %s', raw_data_folder)
remote_subject_folder = iblrig_settings.iblrig_remote_subjects_path

for n in range(n_probes):
probe_folder = raw_data_folder / f'probe{n:02}'
probe_folder.mkdir(exist_ok=True)
log.info('Created %s', probe_folder)

# Save the stub files locally and in the remote repo for future copy script to use
copier = EphysCopier(session_path=session_path, remote_subjects_folder=remote_subject_folder)
communicator, _ = await get_server_communicator(service_uri, 'neuropixel')
copier.initialize_experiment(nprobes=n_probes)

one = OneAlyx(silent=True)
exp_ref = one.path2ref(session_path)
tasks = set()

log.info('Type "abort" to cancel or just press return to finalize')
while True:
# Ensure we are awaiting a message from the remote rig.
# This task must be re-added each time a message is received.
if not any(t.get_name() == 'remote' for t in tasks) and communicator and communicator.is_connected:
task = asyncio.create_task(communicator.on_event(net.base.ExpMessage.any()), name='remote')
tasks.add(task)
if not any(t.get_name() == 'keyboard' for t in tasks):
tasks.add(asyncio.create_task(anext(read_stdin()), name='keyboard'))
# Await the next task outcome
done, _ = await asyncio.wait(tasks, timeout=None, return_when=asyncio.FIRST_COMPLETED)
for task in done:
match task.get_name():
case 'keyboard':
if net.base.is_success(task):
line = task.result().strip().lower()
if line == 'abort' and not any(filter(Path.is_file, raw_data_folder.rglob('*'))):
log.warning('Removing %s', raw_data_folder)
for d in raw_data_folder.iterdir(): # Remove probe folders
d.rmdir()
raw_data_folder.rmdir() # remove collection
# Remove remote exp description file
log.debug('Removing %s', copier.file_remote_experiment_description)
copier.file_remote_experiment_description.unlink()
copier.file_remote_experiment_description.with_suffix('.status_pending').unlink()
# Delete whole session folder?
session_files = list(session_path.rglob('*'))
if len(session_files) == 1 and session_files[0].name.startswith('_ibl_experiment.description'):
ans = input(f'Remove empty session {"/".join(session_path.parts[-3:])}? [y/N]\n')
if (ans.strip().lower() or 'n')[0] == 'y':
log.warning('Removing %s', session_path)
log.debug('Removing %s', session_files[0])
session_files[0].unlink()
session_path.rmdir()
else:
session_path.joinpath('transfer_me.flag').touch()
communicator.close()
for t in tasks:
t.cancel()
tasks.clear()
return
case 'remote':
if task.cancelled():
log.debug('Remote com await cancelled')
log.error('Remote communicator closed')
else:
data, addr, event = task.result()
S = net.base.ExpMessage # noqa
match event:
case S.EXPINFO:
reponse_data = {'exp_ref': one.dict2ref(exp_ref), 'main_sync': True}
await communicator.info(net.base.ExpStatus.RUNNING, reponse_data, addr=addr)
case S.EXPSTATUS:
await communicator.status(net.base.ExpStatus.RUNNING, addr=addr)
case S.EXPINIT:
expected = one.dict2ref(exp_ref)
remote_ref = (data[0] or {}).get('exp_ref') if any(data) else None
if remote_ref and remote_ref != expected:
log.critical('Experiment reference mismatch! Expected %s, got %s', expected, remote_ref)
data = {'exp_ref': one.dict2ref(exp_ref), 'status': net.base.ExpStatus.RUNNING}
await communicator.init(data, addr=addr)
case S.EXPSTART:
await communicator.start(exp_ref, addr=addr)
case S.ALYX:
base_url, token = data
if base_url and token and next(iter(token)):
# Install alyx token
update_alyx_token(data, addr, one.alyx)
elif one.alyx.is_logged_in and (base_url or one.alyx.base_url) == one.alyx.base_url:
# Return alyx token
await communicator.alyx(one.alyx, addr=addr)
case _:
# Do nothing for the others # TODO Change iblrig mixin to not await on stop and cleanups
await communicator.confirmed_send((event, {'status': net.base.ExpStatus.RUNNING}), addr=addr)
case _:
raise NotImplementedError(f'Unexpected task "{task.get_name()}"')
tasks.remove(task)
186 changes: 186 additions & 0 deletions iblrig/test/test_ephys.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import tempfile
import unittest
from datetime import date
from pathlib import Path
from unittest.mock import AsyncMock, patch

import iblrig.ephys
from iblrig.test.base import TEST_DB
from iblutil.io import net
from iblutil.util import Bunch


class TestFinalizeEphysSession(unittest.TestCase):
Expand Down Expand Up @@ -38,3 +45,182 @@ def test_neuropixel24_micromanipulator(self):
},
}
assert trajectories == a


class TestPrepareEphysSessionNetworked(unittest.IsolatedAsyncioTestCase):
"""Test the main_v8_networked function."""

def setUp(self):
"""Set up keyboard input and settings mocks."""
# Set up keyboad input mock
# When we set self.keyboard to a non-empty string, the test function should interpret this as
# keyboard input and stop the session (i.e. close the communicator and return)
self.keyboard = ''
read_stdin = patch('iblrig.ephys.read_stdin')
self.addCleanup(read_stdin.stop)
read_stdin_mock = read_stdin.start()

async def _stdin():
if self.keyboard:
yield self.keyboard

read_stdin_mock.side_effect = _stdin

# Set up settings mock
tmp = tempfile.TemporaryDirectory()
self.tmpdir = Path(tmp.name)
(local := self.tmpdir.joinpath('local')).mkdir()
(remote := self.tmpdir.joinpath('remote')).mkdir()
(remote_subjects := remote.joinpath('subjects')).mkdir()
self.settings = Bunch(
iblrig_local_data_path=local, iblrig_remote_data_path=remote, iblrig_remote_subjects_path=remote_subjects
)
m = patch('iblrig.ephys.load_pydantic_yaml', return_value=self.settings)
m.start()
self.addCleanup(m.stop)
self.addr = '192.168.0.5:99998' # Fake address of the behaviour rig

async def asyncSetUp(self):
"""Set up communicator mock.

To side-step UDP communication, we mock the communicator and simulate the messages that
would be sent by the behaviour rig, then assert that the response methods are called with
the expected arguments.
"""
self.communicator = AsyncMock(spec=iblrig.ephys.net.app.EchoProtocol)
self.communicator.is_connected = True
m = patch('iblrig.ephys.get_server_communicator', return_value=(self.communicator, None))
m.start()
self.addCleanup(m.stop)

async def test_standard_message_sequence(self):
"""Test the main_v8_networked function with the usual sequence of behaviour rig messages."""
# Create some mock behaviour rig messages
ref = f'{date.today()}_1_foo'
info_msg = ((net.base.ExpStatus.CONNECTED, {'subject_name': 'foo'}), self.addr, net.base.ExpMessage.EXPINFO)
init_msg = ([{'exp_ref': ref}], self.addr, net.base.ExpMessage.EXPINIT)
start_msg = ((ref, {}), self.addr, net.base.ExpMessage.EXPSTART)
status_msg = (net.base.ExpStatus.RUNNING, self.addr, net.base.ExpMessage.EXPSTATUS)
# This is the order in which the messages are expected to be sent (excluding status)
self.messages = (info_msg, init_msg, start_msg, status_msg)

messages = self._iterate_messages()
self.communicator.on_event.side_effect = lambda evt: next(messages)
await iblrig.ephys.main_v8_networked('foo', debug=True)

# The on_event method is awaited at first then each time a message is received
self.communicator.on_event.assert_awaited_with(net.base.ExpMessage.any())
self.assertEqual(1 + len(self.messages), self.communicator.on_event.await_count)

# Check that the expected methods were called with the expected arguments
kwargs = dict(addr=self.addr)
expected_responses = [
('info', (net.base.ExpStatus.RUNNING, {'exp_ref': ref, 'main_sync': True}), kwargs),
('init', ({'exp_ref': ref, 'status': net.base.ExpStatus.RUNNING},), kwargs),
('start', ({'subject': 'foo', 'date': date.today(), 'sequence': 1},), kwargs),
('status', (net.base.ExpStatus.RUNNING,), kwargs),
('close', (), {}), # should be called after the last message (when keyboad input is simulated)
]
# Check odd method calls as even ones are the on_event calls
actual_reponses = map(tuple, self.communicator.method_calls[1::2])
for expected, actual in zip(expected_responses, actual_reponses, strict=False):
self.assertEqual(expected, actual)

# Check that the local and remote sessions were created
expected = [
f'local/foo/{date.today()}/001/transfer_me.flag',
f'local/foo/{date.today()}/001/_ibl_experiment.description_ephys.yaml',
f'remote/subjects/foo/{date.today()}/001/_devices/{date.today()}_1_foo@ephys.status_pending',
f'remote/subjects/foo/{date.today()}/001/_devices/{date.today()}_1_foo@ephys.yaml',
]
self.assertCountEqual(map(self.tmpdir.joinpath, expected), self.tmpdir.rglob('*.*'))
# Should have created the raw ephys folders
self.assertEqual(2, len(list(self.tmpdir.glob(f'local/foo/{date.today()}/001/raw_ephys_data/probe??'))))

async def test_abort_session(self):
"""Test the main_v8_networked function with misc events and user 'abort' input."""
# Create some mock behaviour rig messages where the behaviour rig runs subject 'bar' instead of 'foo'
# No exception should be raised (this happens at the behaviour rig) but this should be logged
ref = f'{date.today()}_1_bar'
info_msg = ((net.base.ExpStatus.CONNECTED, {'subject_name': 'bar'}), self.addr, net.base.ExpMessage.EXPINFO)
init_msg = ([{'exp_ref': ref}], self.addr, net.base.ExpMessage.EXPINIT)
start_msg = ((ref, {}), self.addr, net.base.ExpMessage.EXPSTART)
interrupt_msg = ((), self.addr, net.base.ExpMessage.EXPINTERRUPT)
cleanup_msg = ((), self.addr, net.base.ExpMessage.EXPCLEANUP)
self.messages = (info_msg, init_msg, start_msg, interrupt_msg, cleanup_msg)

messages = self._iterate_messages(keyboard_input='ABORT\n')
self.communicator.on_event.side_effect = lambda evt: next(messages)
# Should log exp ref mismatch
with self.assertLogs('iblrig.ephys', level='CRITICAL'), patch('builtins.input', return_value='y') as mock_input:
await iblrig.ephys.main_v8_networked('foo', debug=True)
mock_input.assert_called_once()

# The on_event method is awaited at first then each time a message is received
self.communicator.on_event.assert_awaited_with(net.base.ExpMessage.any())
self.assertEqual(1 + len(self.messages), self.communicator.on_event.await_count)

# Check that the expected methods were called with the expected arguments
kwargs = dict(addr=self.addr)
ref = f'{date.today()}_1_foo'
expected_responses = [
('info', (net.base.ExpStatus.RUNNING, {'exp_ref': ref, 'main_sync': True}), kwargs),
('init', ({'exp_ref': ref, 'status': net.base.ExpStatus.RUNNING},), kwargs),
('start', ({'subject': 'foo', 'date': date.today(), 'sequence': 1},), kwargs),
('confirmed_send', ((net.base.ExpMessage.EXPINTERRUPT, {'status': net.base.ExpStatus.RUNNING}),), kwargs),
('confirmed_send', ((net.base.ExpMessage.EXPCLEANUP, {'status': net.base.ExpStatus.RUNNING}),), kwargs),
('close', (), {}), # should be called after the last message (when keyboad input is simulated)
]
# Check odd method calls as even ones are the on_event calls
for expected, actual in zip(expected_responses, map(tuple, self.communicator.method_calls[1::2]), strict=False):
self.assertEqual(expected, actual)

# Check that the local and remote sessions were removed
self.assertFalse(any(self.tmpdir.rglob('*.*')))
self.assertFalse(self.tmpdir.joinpath(f'local/foo/{date.today()}/001').exists())

# Check behaviour when user does not confirm cleanup
self.communicator.reset_mock() # _iterate_messages asserts no methods were called yet
messages = self._iterate_messages(keyboard_input='ABORT\n')
self.communicator.on_event.side_effect = lambda evt: next(messages)
with patch('builtins.input', return_value='') as mock_input:
await iblrig.ephys.main_v8_networked('foo', debug=True)
mock_input.assert_called_once()
self.assertTrue(any(self.tmpdir.rglob('*.*')))
self.assertTrue(self.tmpdir.joinpath(f'local/foo/{date.today()}/001').exists())

async def test_alyx_request(self):
"""Test the main_v8_networked function alyx request message."""
# Create some mock behaviour rig messages that request and provide Alyx credentials
alyx_req = ((None, {}), self.addr, net.base.ExpMessage.ALYX)
alyx_mes = ((TEST_DB['base_url'], {'test_user': {'token': 't0k3n'}}), self.addr, net.base.ExpMessage.ALYX)
# Behaviour should be thus:
# 1. Request not processed as Alyx offline by default
# 2. Alyx object updated with remote token
# 3. Request processed with updated Alyx object (now logged in)
self.messages = (alyx_req, alyx_mes, alyx_req)

messages = self._iterate_messages()
self.communicator.on_event.side_effect = lambda evt: next(messages)
with patch('iblrig.ephys.update_alyx_token', wraps=iblrig.ephys.update_alyx_token) as m:
await iblrig.ephys.main_v8_networked('foo', debug=True)
m.assert_called_once()

# Check that the expected methods were called with the expected arguments
self.communicator.alyx.assert_awaited_once()
(alyx,), addr = self.communicator.alyx.call_args
self.assertTrue(alyx.is_logged_in)
self.assertEqual(TEST_DB['base_url'], alyx.base_url)
self.assertEqual({'token': 't0k3n'}, alyx._token)

def _iterate_messages(self, keyboard_input='\n'):
"""Yield behaviour rig UDP messages with added side effect simulating keyboard input after."""
# When first called we shouold not have awaited any methods on the communicator yet
for method in ('info', 'init', 'start', 'status', 'alyx', 'confirmed_send'):
getattr(self.communicator, method).assert_not_awaited()
# Yeild the messages in order
for msg in self.messages: # noqa: UP028
yield msg # ruff complains here but the suggested `yield from` does not work
# After the last message is processed, terminate by simulating keyboard input
self.keyboard = keyboard_input
yield self.messages[-1]
Loading