77import traceback
88import typing as t
99import uuid
10+ from collections import namedtuple
1011from functools import partial
1112from http import HTTPStatus
1213
2930 # optional dependencies
3031 ...
3132
33+ PendingInput = namedtuple ("PendingInput" , ["request_id" , "content" ])
34+
3235
3336class ExecutionStack :
3437 """Execution request stack.
@@ -39,7 +42,7 @@ class ExecutionStack:
3942 """
4043
4144 def __init__ (self ):
42- self .__pending_inputs : dict [str , dict ] = {}
45+ self .__pending_inputs : dict [str , PendingInput ] = {}
4346 self .__tasks : dict [str , asyncio .Task ] = {}
4447
4548 def __del__ (self ):
@@ -78,7 +81,12 @@ def get(self, kernel_id: str, uid: str) -> t.Any:
7881 raise ValueError (f"Request { uid } does not exists." )
7982
8083 if kernel_id in self .__pending_inputs :
81- return self .__pending_inputs .pop (kernel_id )
84+ get_logger ().info (f"Kernel '{ kernel_id } ' has a pending input." )
85+ # Check the request id is the one matching the appearance of the input
86+ # Otherwise another cell still looking for its results may capture the
87+ # pending input
88+ if uid == self .__pending_inputs [kernel_id ].request_id :
89+ return self .__pending_inputs .pop (kernel_id ).content
8290
8391 if self .__tasks [uid ].done ():
8492 task = self .__tasks .pop (uid )
@@ -102,11 +110,11 @@ def put(
102110 uid = str (uuid .uuid4 ())
103111
104112 self .__tasks [uid ] = asyncio .create_task (
105- _execute_task (uid , km , snippet , ycell , partial (self ._stdin_hook , km .kernel_id ))
113+ _execute_task (uid , km , snippet , ycell , partial (self ._stdin_hook , km .kernel_id , uid ))
106114 )
107115 return uid
108116
109- def _stdin_hook (self , kernel_id : str , msg : dict ) -> None :
117+ def _stdin_hook (self , kernel_id : str , request_id : str , msg : dict ) -> None :
110118 """Callback on stdin message.
111119
112120 It will register the pending input as temporary answer to the execution request.
@@ -119,10 +127,13 @@ def _stdin_hook(self, kernel_id: str, msg: dict) -> None:
119127
120128 header = msg ["header" ].copy ()
121129 header ["date" ] = header ["date" ].isoformat ()
122- self .__pending_inputs [kernel_id ] = {
123- "parent_header" : header ,
124- "input_request" : msg ["content" ],
125- }
130+ self .__pending_inputs [kernel_id ] = PendingInput (
131+ request_id ,
132+ {
133+ "parent_header" : header ,
134+ "input_request" : msg ["content" ],
135+ },
136+ )
126137
127138
128139async def _execute_task (
@@ -159,7 +170,6 @@ async def _execute_snippet(
159170 ycell : y .Map | None ,
160171 stdin_hook : t .Callable [[dict ], None ] | None ,
161172) -> dict [str , t .Any ]:
162-
163173 if ycell is not None :
164174 # Reset cell
165175 with ycell .doc .transaction ():
@@ -294,7 +304,7 @@ async def post(self, kernel_id: str) -> None:
294304 status_code = HTTPStatus .INTERNAL_SERVER_ERROR , reason = msg
295305 )
296306
297- notebook : YNotebook = await self ._ydoc .get_document (document_id = document_id , copy = False )
307+ notebook : YNotebook = await self ._ydoc .get_document (room_id = document_id , copy = False )
298308
299309 if notebook is None :
300310 msg = f"Document with ID { document_id } not found."
@@ -310,9 +320,7 @@ async def post(self, kernel_id: str) -> None:
310320 raise tornado .web .HTTPError (status_code = HTTPStatus .NOT_FOUND , reason = msg ) # noqa: B904
311321 else :
312322 # Check if there is more than one cell
313- try :
314- next (ycells )
315- except StopIteration :
323+ if next (ycells , None ) is not None :
316324 get_logger ().warning ("Multiple cells have the same ID '%s'." , cell_id )
317325
318326 if ycell ["cell_type" ] != "code" :
0 commit comments