@@ -128,6 +128,7 @@ def __init__(
128128 if graph_assigner_cls is None :
129129 graph_assigner_cls = GraphAssigner
130130 self ._graph_assigner_cls = graph_assigner_cls
131+ self ._chunk_to_copied = dict ()
131132 self ._logic_key_generator = LogicKeyGenerator ()
132133
133134 @classmethod
@@ -226,6 +227,7 @@ def _gen_subtask_info(
226227 result_chunks_set = set ()
227228 chunk_graph = ChunkGraph (result_chunks )
228229 out_of_scope_chunks = []
230+ chunk_to_copied = self ._chunk_to_copied
229231 update_meta_chunks = []
230232 # subtask properties
231233 band = None
@@ -271,11 +273,13 @@ def _gen_subtask_info(
271273 chunk_priority = chunk .op .priority
272274 # process input chunks
273275 inp_chunks = []
276+ input_changed = False
274277 build_fetch_index_to_chunks = dict ()
275278 for i , inp_chunk in enumerate (chunk .inputs ):
276279 if inp_chunk in chunks_set :
277- inp_chunks .append (inp_chunk )
280+ inp_chunks .append (chunk_to_copied [ inp_chunk ] )
278281 else :
282+ input_changed = True
279283 build_fetch_index_to_chunks [i ] = inp_chunk
280284 inp_chunks .append (None )
281285 if not isinstance (inp_chunk .op , Fetch ):
@@ -285,14 +289,31 @@ def _gen_subtask_info(
285289 )
286290 for i , fetch_chunk in zip (build_fetch_index_to_chunks , fetch_chunks ):
287291 inp_chunks [i ] = fetch_chunk
288- for out_chunk in chunk .op .outputs :
292+
293+ if input_changed :
294+ copied_op = chunk .op .copy ()
295+ copied_op ._key = chunk .op .key
296+ out_chunks = [
297+ c .data
298+ for c in copied_op .new_chunks (
299+ inp_chunks , kws = [c .params .copy () for c in chunk .op .outputs ]
300+ )
301+ ]
302+ else :
303+ out_chunks = chunk .op .outputs
289304 # Note: `dtypes`, `index_value`, and `columns_value` are lazily
290305 # initialized, so we should call property `params` to initialize
291306 # these fields.
292- out_chunk .params
293- processed .add (out_chunk )
307+ [o .params for o in out_chunks ]
308+
309+ for src_chunk , out_chunk in zip (chunk .op .outputs , out_chunks ):
310+ processed .add (src_chunk )
311+ out_chunk ._key = src_chunk .key
294312 chunk_graph .add_node (out_chunk )
295- if out_chunk in self ._final_result_chunks_set :
313+ # cannot be copied twice
314+ assert src_chunk not in chunk_to_copied
315+ chunk_to_copied [src_chunk ] = out_chunk
316+ if src_chunk in self ._final_result_chunks_set :
296317 if out_chunk not in result_chunks_set :
297318 # add to result chunks
298319 result_chunks .append (out_chunk )
@@ -320,12 +341,18 @@ def _gen_subtask_info(
320341 if out_of_scope_chunks :
321342 inp_subtasks = []
322343 for out_of_scope_chunk in out_of_scope_chunks :
344+ copied_out_of_scope_chunk = chunk_to_copied [out_of_scope_chunk ]
323345 inp_subtask = chunk_to_subtask [out_of_scope_chunk ]
324- if out_of_scope_chunk not in inp_subtask .chunk_graph .result_chunks :
346+ if (
347+ copied_out_of_scope_chunk
348+ not in inp_subtask .chunk_graph .result_chunks
349+ ):
325350 # make sure the chunk that out of scope
326351 # is in the input subtask's results,
327352 # or the meta may be lost
328- inp_subtask .chunk_graph .result_chunks .append (out_of_scope_chunk )
353+ inp_subtask .chunk_graph .result_chunks .append (
354+ copied_out_of_scope_chunk
355+ )
329356 inp_subtasks .append (inp_subtask )
330357 depth = max (st .priority [0 ] for st in inp_subtasks ) + 1
331358 else :
@@ -383,9 +410,10 @@ def _gen_map_reduce_info(
383410 # record analyzer map reduce id for mapper op
384411 # copied chunk exists because map chunk must have
385412 # been processed before shuffle proxy
386- if not hasattr (map_chunk , "extra_params" ): # pragma: no cover
387- map_chunk .extra_params = dict ()
388- map_chunk .extra_params ["analyzer_map_reduce_id" ] = map_reduce_id
413+ copied_map_chunk = self ._chunk_to_copied [map_chunk ]
414+ if not hasattr (copied_map_chunk , "extra_params" ): # pragma: no cover
415+ copied_map_chunk .extra_params = dict ()
416+ copied_map_chunk .extra_params ["analyzer_map_reduce_id" ] = map_reduce_id
389417 reducer_bands = [assign_results [r .outputs [0 ]] for r in reducer_ops ]
390418 map_reduce_info = MapReduceInfo (
391419 map_reduce_id = map_reduce_id ,
0 commit comments