@@ -145,67 +145,95 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
145145
146146
147147def naive_branch_general_stat (
148- ts , w , f , windows = None , polarised = False , span_normalise = True
148+ ts , w , f , windows = None , time_windows = None , polarised = False , span_normalise = True
149149):
150150 # NOTE: does not behave correctly for unpolarised stats
151151 # with non-ancestral material.
152152 if windows is None :
153153 windows = [0.0 , ts .sequence_length ]
154+ drop_time_windows = time_windows is None
155+ if time_windows is None :
156+ time_windows = [0.0 , np .inf ]
157+ else :
158+ if time_windows [0 ] != 0 :
159+ time_windows = [0 ] + time_windows
154160 n , k = w .shape
161+ tw = len (time_windows ) - 1
155162 # hack to determine m
156163 m = len (f (w [0 ]))
157164 total = np .sum (w , axis = 0 )
158165
159- sigma = np .zeros ((ts .num_trees , m ))
160- for tree in ts .trees ():
161- x = np .zeros ((ts .num_nodes , k ))
162- x [ts .samples ()] = w
163- for u in tree .nodes (order = "postorder" ):
164- for v in tree .children (u ):
165- x [u ] += x [v ]
166- if polarised :
167- s = sum (tree .branch_length (u ) * f (x [u ]) for u in tree .nodes ())
166+ sigma = np .zeros ((ts .num_trees , tw , m ))
167+ for j , upper_time in enumerate (time_windows [1 :]):
168+ if np .isfinite (upper_time ):
169+ decap_ts = ts .decapitate (upper_time )
168170 else :
169- s = sum (
170- tree .branch_length (u ) * (f (x [u ]) + f (total - x [u ]))
171- for u in tree .nodes ()
172- )
173- sigma [tree .index ] = s * tree .span
171+ decap_ts = ts
172+ assert np .all (list (ts .samples ()) == list (decap_ts .samples ()))
173+ for tree in decap_ts .trees ():
174+ x = np .zeros ((decap_ts .num_nodes , k ))
175+ x [decap_ts .samples ()] = w
176+ for u in tree .nodes (order = "postorder" ):
177+ for v in tree .children (u ):
178+ x [u ] += x [v ]
179+ if polarised :
180+ s = sum (tree .branch_length (u ) * f (x [u ]) for u in tree .nodes ())
181+ else :
182+ s = sum (
183+ tree .branch_length (u ) * (f (x [u ]) + f (total - x [u ]))
184+ for u in tree .nodes ()
185+ )
186+ sigma [tree .index , j , :] = s * tree .span
187+ for j in range (1 , tw ):
188+ sigma [:, j , :] = sigma [:, j , :] - sigma [:, j - 1 , :]
174189 if isinstance (windows , str ) and windows == "trees" :
175190 # need to average across the windows
176191 if span_normalise :
177192 for j , tree in enumerate (ts .trees ()):
178193 sigma [j ] /= tree .span
179- return sigma
194+ out = sigma
180195 else :
181- return windowed_tree_stat (ts , sigma , windows , span_normalise = span_normalise )
196+ out = windowed_tree_stat (ts , sigma , windows , span_normalise = span_normalise )
197+ if drop_time_windows :
198+ assert out .ndim == 3
199+ out = out [:, 0 ]
200+ return out
182201
183202
184203def branch_general_stat (
185- ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
204+ ts ,
205+ sample_weights ,
206+ summary_func ,
207+ windows = None ,
208+ time_windows = None ,
209+ polarised = False ,
210+ span_normalise = True ,
186211):
187212 """
188213 Efficient implementation of the algorithm used as the basis for the
189214 underlying C version.
190215 """
191216 n , state_dim = sample_weights .shape
192217 windows = ts .parse_windows (windows )
218+ drop_time_windows = time_windows is None
219+ time_windows = ts .parse_time_windows (time_windows )
193220 num_windows = windows .shape [0 ] - 1
221+ num_time_windows = time_windows .shape [0 ] - 1
194222
195223 # Determine result_dim
196224 result_dim = len (summary_func (sample_weights [0 ]))
197- result = np .zeros ((num_windows , result_dim ))
225+ result = np .zeros ((num_windows , num_time_windows , result_dim ))
198226 state = np .zeros ((ts .num_nodes , state_dim ))
199227 state [ts .samples ()] = sample_weights
200228 total_weight = np .sum (sample_weights , axis = 0 )
201229
202230 time = ts .tables .nodes .time
203231 parent = np .zeros (ts .num_nodes , dtype = np .int32 ) - 1
204- branch_length = np .zeros (ts .num_nodes )
232+ branch_length = np .zeros (( num_time_windows , ts .num_nodes ) )
205233 # The value of summary_func(u) for every node.
206234 summary = np .zeros ((ts .num_nodes , result_dim ))
207235 # The result for the current tree *not* weighted by span.
208- running_sum = np .zeros (result_dim )
236+ running_sum = np .zeros (( num_time_windows , result_dim ) )
209237
210238 def polarised_summary (u ):
211239 s = summary_func (state [u ])
@@ -217,31 +245,48 @@ def polarised_summary(u):
217245 summary [u ] = polarised_summary (u )
218246
219247 window_index = 0
248+
249+ def update_sum (u , sign ):
250+ time_window_index = 0
251+ if parent [u ] != - 1 :
252+ while (
253+ time_window_index < num_time_windows
254+ and time_windows [time_window_index ] < time [parent [u ]]
255+ ):
256+ running_sum [time_window_index ] += sign * (
257+ branch_length [time_window_index , u ] * summary [u ]
258+ )
259+ time_window_index += 1
260+
220261 for (t_left , t_right ), edges_out , edges_in in ts .edge_diffs ():
221262 for edge in edges_out :
222263 u = edge .child
223- running_sum -= branch_length [ u ] * summary [ u ]
264+ update_sum ( u , sign = - 1 )
224265 u = edge .parent
225266 while u != - 1 :
226- running_sum -= branch_length [ u ] * summary [ u ]
267+ update_sum ( u , sign = - 1 )
227268 state [u ] -= state [edge .child ]
228269 summary [u ] = polarised_summary (u )
229- running_sum += branch_length [ u ] * summary [ u ]
270+ update_sum ( u , sign = + 1 )
230271 u = parent [u ]
231272 parent [edge .child ] = - 1
232- branch_length [edge .child ] = 0
273+ for tw in range (num_time_windows ):
274+ branch_length [tw , edge .child ] = 0
233275
234276 for edge in edges_in :
235277 parent [edge .child ] = edge .parent
236- branch_length [edge .child ] = time [edge .parent ] - time [edge .child ]
278+ for tw in range (num_time_windows ):
279+ branch_length [tw , edge .child ] = min (
280+ time [edge .parent ], time_windows [tw + 1 ]
281+ ) - max (time [edge .child ], time_windows [tw ])
237282 u = edge .child
238- running_sum += branch_length [ u ] * summary [ u ]
283+ update_sum ( u , sign = + 1 )
239284 u = edge .parent
240285 while u != - 1 :
241- running_sum -= branch_length [ u ] * summary [ u ]
286+ update_sum ( u , sign = - 1 )
242287 state [u ] += state [edge .child ]
243288 summary [u ] = polarised_summary (u )
244- running_sum += branch_length [ u ] * summary [ u ]
289+ update_sum ( u , sign = + 1 )
245290 u = parent [u ]
246291
247292 # Update the windows
@@ -253,7 +298,12 @@ def polarised_summary(u):
253298 right = min (t_right , w_right )
254299 span = right - left
255300 assert span > 0
256- result [window_index ] += running_sum * span
301+ time_window_index = 0
302+ while time_window_index < num_time_windows :
303+ result [window_index , time_window_index ] += (
304+ running_sum [time_window_index ] * span
305+ )
306+ time_window_index += 1
257307 if w_right <= t_right :
258308 window_index += 1
259309 else :
@@ -263,6 +313,9 @@ def polarised_summary(u):
263313
264314 # print("window_index:", window_index, windows.shape)
265315 assert window_index == windows .shape [0 ] - 1
316+ if drop_time_windows :
317+ assert result .ndim == 3
318+ result = result [:, 0 ]
266319 if span_normalise :
267320 for j in range (num_windows ):
268321 result [j ] /= windows [j + 1 ] - windows [j ]
@@ -322,13 +375,20 @@ def naive_site_general_stat(
322375
323376
324377def site_general_stat (
325- ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
378+ ts ,
379+ sample_weights ,
380+ summary_func ,
381+ windows = None ,
382+ time_windows = None ,
383+ polarised = False ,
384+ span_normalise = True ,
326385):
327386 """
328387 Problem: 'sites' is different that the other windowing options
329388 because if we output by site we don't want to normalize by length of the window.
330389 Solution: we pass an argument "normalize", to the windowing function.
331390 """
391+ assert time_windows is None
332392 windows = ts .parse_windows (windows )
333393 num_windows = windows .shape [0 ] - 1
334394 n , state_dim = sample_weights .shape
@@ -425,12 +485,19 @@ def naive_node_general_stat(
425485
426486
427487def node_general_stat (
428- ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
488+ ts ,
489+ sample_weights ,
490+ summary_func ,
491+ windows = None ,
492+ time_windows = None ,
493+ polarised = False ,
494+ span_normalise = True ,
429495):
430496 """
431497 Efficient implementation of the algorithm used as the basis for the
432498 underlying C version.
433499 """
500+ assert time_windows is None
434501 n , state_dim = sample_weights .shape
435502 windows = ts .parse_windows (windows )
436503 num_windows = windows .shape [0 ] - 1
@@ -500,6 +567,7 @@ def general_stat(
500567 sample_weights ,
501568 summary_func ,
502569 windows = None ,
570+ time_windows = None ,
503571 polarised = False ,
504572 mode = "site" ,
505573 span_normalise = True ,
@@ -518,6 +586,7 @@ def general_stat(
518586 sample_weights ,
519587 summary_func ,
520588 windows = windows ,
589+ time_windows = time_windows ,
521590 polarised = polarised ,
522591 span_normalise = span_normalise ,
523592 )
@@ -3534,7 +3603,9 @@ class TestSitef3(Testf3, MutatedTopologyExamplesMixin):
35343603############################################
35353604
35363605
3537- def branch_f4 (ts , sample_sets , indexes , windows = None , span_normalise = True ):
3606+ def branch_f4 (
3607+ ts , sample_sets , indexes , windows = None , time_windows = None , span_normalise = True
3608+ ):
35383609 windows = ts .parse_windows (windows )
35393610 out = np .zeros ((len (windows ) - 1 , len (indexes )))
35403611 for j in range (len (windows ) - 1 ):
@@ -3674,7 +3745,15 @@ def node_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
36743745 return out
36753746
36763747
3677- def f4 (ts , sample_sets , indexes = None , windows = None , mode = "site" , span_normalise = True ):
3748+ def f4 (
3749+ ts ,
3750+ sample_sets ,
3751+ indexes = None ,
3752+ windows = None ,
3753+ time_windows = None ,
3754+ mode = "site" ,
3755+ span_normalise = True ,
3756+ ):
36783757 """
36793758 Patterson's f4 statistic definitions.
36803759 """
@@ -6994,3 +7073,53 @@ def f_too_long(_):
69947073 output_dim = 1 ,
69957074 strict = False ,
69967075 )
7076+
7077+
7078+ class TestTimeWindows :
7079+
7080+ def test_general_stat (self , four_taxa_test_case ):
7081+ # 1.00┊ 7 ┊ ┊ ┊
7082+ # ┊ ┏━┻━┓ ┊ ┊ ┊
7083+ # 0.70┊ ┃ ┃ ┊ ┊ 6 ┊
7084+ # ┊ ┃ ┃ ┊ ┊ ┏━┻━┓ ┊
7085+ # 0.50┊ ┃ 5 ┊ 5 ┊ ┃ 5 ┊
7086+ # ┊ ┃ ┏┻━┓ ┊ ┏━┻━┓ ┊ ┃ ┏┻━┓ ┊
7087+ # 0.40┊ ┃ 8 ┃ ┊ 4 8 ┊ ┃ 8 ┃ ┊
7088+ # ┊ ┃ ┏┻┓ ┃ ┊ ┏┻┓ ┏┻┓ ┊ ┃ ┏┻┓ ┃ ┊
7089+ # 0.00┊ 0 1 3 2 ┊ 0 2 1 3 ┊ 0 1 3 2 ┊
7090+ # 0.00 0.20 0.80 2.50
7091+ ts = four_taxa_test_case
7092+ true_x = np .array (
7093+ [
7094+ [
7095+ [
7096+ 0.2 * (1 + 0.5 + 0.4 )
7097+ + (0.8 - 0.2 ) * (1 + 0.8 )
7098+ + (2.5 - 0.8 ) * (1.0 + 0.5 + 0.4 )
7099+ ],
7100+ [0.2 * 1.0 + 0 + (2.5 - 0.8 ) * 0.4 ],
7101+ ]
7102+ ]
7103+ )
7104+
7105+ n = ts .num_samples
7106+
7107+ def f (x ):
7108+ return (x > 0 ) * (1 - x / n )
7109+
7110+ W = np .ones ((ts .num_samples , 1 ))
7111+ x = naive_branch_general_stat (
7112+ ts , W , f , time_windows = [0 , 0.5 , 2.0 ], span_normalise = False
7113+ )
7114+ np .testing .assert_allclose (x , true_x )
7115+
7116+ x0 = branch_general_stat (ts , W , f , time_windows = None , span_normalise = False )
7117+ x1 = naive_branch_general_stat (
7118+ ts , W , f , time_windows = None , span_normalise = False
7119+ )
7120+ np .testing .assert_allclose (x0 , x1 )
7121+ x_tw = branch_general_stat (
7122+ ts , W , f , time_windows = [0 , 0.5 , 2.0 ], span_normalise = False
7123+ )
7124+
7125+ np .testing .assert_allclose (x , x_tw )
0 commit comments