88from .expression import QueryExpression , AndList
99from .errors import DataJointError , LostConnectionError
1010import signal
11+ import multiprocessing as mp
1112
1213# noinspection PyExceptionInherit,PyCallingNonCallable
1314
1415logger = logging .getLogger (__name__ )
1516
1617
18+ # --- helper functions for multiprocessing --
19+
20+ def _initialize_populate (table , jobs , populate_kwargs ):
21+ """
22+ Initialize the process for mulitprocessing.
23+ Saves the unpickled copy of the table to the current process and reconnects.
24+ """
25+ process = mp .current_process ()
26+ process .table = table
27+ process .jobs = jobs
28+ process .populate_kwargs = populate_kwargs
29+ table .connection .connect () # reconnect
30+
31+
32+ def _call_populate1 (key ):
33+ """
34+ Call current process' table._populate1()
35+ :key - a dict specifying job to compute
36+ :return: key, error if error, otherwise None
37+ """
38+ process = mp .current_process ()
39+ return process .table ._populate1 (key , process .jobs , ** process .populate_kwargs )
40+
41+
1742class AutoPopulate :
1843 """
1944 AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -28,8 +53,9 @@ def key_source(self):
2853 """
2954 :return: the query expression that yields primary key values to be passed,
3055 sequentially, to the ``make`` method when populate() is called.
31- The default value is the join of the parent relations.
32- Users may override to change the granularity or the scope of populate() calls.
56+ The default value is the join of the parent tables references from the primary key.
57+ Subclasses may override they key_source to change the scope or the granularity
58+ of the make calls.
3359 """
3460 def _rename_attributes (table , props ):
3561 return (table .proj (
@@ -96,29 +122,30 @@ def _jobs_to_do(self, restrictions):
96122
97123 def populate (self , * restrictions , suppress_errors = False , return_exception_objects = False ,
98124 reserve_jobs = False , order = "original" , limit = None , max_calls = None ,
99- display_progress = False ):
125+ display_progress = False , processes = 1 ):
100126 """
101- rel .populate() calls rel .make(key) for every primary key in self.key_source
102- for which there is not already a tuple in rel .
103- :param restrictions: a list of restrictions each restrict (rel .key_source - target.proj())
127+ table .populate() calls table .make(key) for every primary key in self.key_source
128+ for which there is not already a tuple in table .
129+ :param restrictions: a list of restrictions each restrict (table .key_source - target.proj())
104130 :param suppress_errors: if True, do not terminate execution.
105131 :param return_exception_objects: return error objects instead of just error messages
106- :param reserve_jobs: if true, reserves job to populate in asynchronous fashion
132+ :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
107133 :param order: "original"|"reverse"|"random" - the order of execution
134+ :param limit: if not None, check at most this many keys
135+ :param max_calls: if not None, populate at most this many keys
108136 :param display_progress: if True, report progress_bar
109- :param limit: if not None, checks at most that many keys
110- :param max_calls: if not None, populates at max that many keys
137+ :param processes: number of processes to use. When set to a large number, then
138+ uses as many as CPU cores
111139 """
112140 if self .connection .in_transaction :
113141 raise DataJointError ('Populate cannot be called during a transaction.' )
114142
115143 valid_order = ['original' , 'reverse' , 'random' ]
116144 if order not in valid_order :
117145 raise DataJointError ('The order argument must be one of %s' % str (valid_order ))
118- error_list = [] if suppress_errors else None
119146 jobs = self .connection .schemas [self .target .database ].jobs if reserve_jobs else None
120147
121- # define and setup signal handler for SIGTERM
148+ # define and set up signal handler for SIGTERM:
122149 if reserve_jobs :
123150 def handler (signum , frame ):
124151 logger .info ('Populate terminated by SIGTERM' )
@@ -131,60 +158,99 @@ def handler(signum, frame):
131158 elif order == "random" :
132159 random .shuffle (keys )
133160
134- call_count = 0
135161 logger .info ('Found %d keys to populate' % len (keys ))
136162
137- make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
163+ keys = keys [:max_calls ]
164+ nkeys = len (keys )
138165
139- for key in (tqdm (keys , desc = self .__class__ .__name__ ) if display_progress else keys ):
140- if max_calls is not None and call_count >= max_calls :
141- break
142- if not reserve_jobs or jobs .reserve (self .target .table_name , self ._job_key (key )):
143- self .connection .start_transaction ()
144- if key in self .target : # already populated
145- self .connection .cancel_transaction ()
146- if reserve_jobs :
147- jobs .complete (self .target .table_name , self ._job_key (key ))
166+ if processes > 1 :
167+ processes = min (processes , nkeys , mp .cpu_count ())
168+
169+ error_list = []
170+ populate_kwargs = dict (
171+ suppress_errors = suppress_errors ,
172+ return_exception_objects = return_exception_objects )
173+
174+ if processes == 1 :
175+ for key in tqdm (keys , desc = self .__class__ .__name__ ) if display_progress else keys :
176+ error = self ._populate1 (key , jobs , ** populate_kwargs )
177+ if error is not None :
178+ error_list .append (error )
179+ else :
180+ # spawn multiple processes
181+ self .connection .close () # disconnect parent process from MySQL server
182+ del self .connection ._conn .ctx # SSLContext is not pickleable
183+ with mp .Pool (processes , _initialize_populate , (self , populate_kwargs )) as pool :
184+ if display_progress :
185+ with tqdm (desc = "Processes: " , total = nkeys ) as pbar :
186+ for error in pool .imap (_call_populate1 , keys , chunksize = 1 ):
187+ if error is not None :
188+ error_list .append (error )
189+ pbar .update ()
148190 else :
149- logger .info ('Populating: ' + str (key ))
150- call_count += 1
151- self .__class__ ._allow_insert = True
152- try :
153- make (dict (key ))
154- except (KeyboardInterrupt , SystemExit , Exception ) as error :
155- try :
156- self .connection .cancel_transaction ()
157- except LostConnectionError :
158- pass
159- error_message = '{exception}{msg}' .format (
160- exception = error .__class__ .__name__ ,
161- msg = ': ' + str (error ) if str (error ) else '' )
162- if reserve_jobs :
163- # show error name and error message (if any)
164- jobs .error (
165- self .target .table_name , self ._job_key (key ),
166- error_message = error_message , error_stack = traceback .format_exc ())
167- if not suppress_errors or isinstance (error , SystemExit ):
168- raise
169- else :
170- logger .error (error )
171- error_list .append ((key , error if return_exception_objects else error_message ))
172- else :
173- self .connection .commit_transaction ()
174- if reserve_jobs :
175- jobs .complete (self .target .table_name , self ._job_key (key ))
176- finally :
177- self .__class__ ._allow_insert = False
191+ for error in pool .imap (_call_populate1 , keys ):
192+ if error is not None :
193+ error_list .append (error )
194+ self .connection .connect () # reconnect parent process to MySQL server
178195
179- # place back the original signal handler
196+ # restore original signal handler:
180197 if reserve_jobs :
181198 signal .signal (signal .SIGTERM , old_handler )
182- return error_list
199+
200+ if suppress_errors :
201+ return error_list
202+
203+ def _populate1 (self , key , jobs , suppress_errors , return_exception_objects ):
204+ """
205+ populates table for one source key, calling self.make inside a transaction.
206+ :param jobs: the jobs table or None if not reserve_jobs
207+ :param key: dict specifying job to populate
208+ :param suppress_errors: bool if errors should be suppressed and returned
209+ :param return_exception_objects: if True, errors must be returned as objects
210+ :return: (key, error) when suppress_errors=True, otherwise None
211+ """
212+ make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
213+
214+ if jobs is None or jobs .reserve (self .target .table_name , self ._job_key (key )):
215+ self .connection .start_transaction ()
216+ if key in self .target : # already populated
217+ self .connection .cancel_transaction ()
218+ if jobs is not None :
219+ jobs .complete (self .target .table_name , self ._job_key (key ))
220+ else :
221+ logger .info ('Populating: ' + str (key ))
222+ self .__class__ ._allow_insert = True
223+ try :
224+ make (dict (key ))
225+ except (KeyboardInterrupt , SystemExit , Exception ) as error :
226+ try :
227+ self .connection .cancel_transaction ()
228+ except LostConnectionError :
229+ pass
230+ error_message = '{exception}{msg}' .format (
231+ exception = error .__class__ .__name__ ,
232+ msg = ': ' + str (error ) if str (error ) else '' )
233+ if jobs is not None :
234+ # show error name and error message (if any)
235+ jobs .error (
236+ self .target .table_name , self ._job_key (key ),
237+ error_message = error_message , error_stack = traceback .format_exc ())
238+ if not suppress_errors or isinstance (error , SystemExit ):
239+ raise
240+ else :
241+ logger .error (error )
242+ return key , error if return_exception_objects else error_message
243+ else :
244+ self .connection .commit_transaction ()
245+ if jobs is not None :
246+ jobs .complete (self .target .table_name , self ._job_key (key ))
247+ finally :
248+ self .__class__ ._allow_insert = False
183249
184250 def progress (self , * restrictions , display = True ):
185251 """
186- report progress of populating the table
187- :return: remaining, total -- tuples to be populated
252+ Report the progress of populating the table.
253+ :return: ( remaining, total) -- numbers of tuples to be populated
188254 """
189255 todo = self ._jobs_to_do (restrictions )
190256 total = len (todo )
0 commit comments