2222"""
2323Python implementation of the Li and Stephens forwards and backwards algorithms.
2424"""
25+ import io
2526import warnings
2627
2728import lshmm as ls
3738MISSING = - 1
3839
3940
40- # np.set_printoptions(linewidth=1000, precision=3)
41+ # For debugging
42+ np .set_printoptions (linewidth = 1000 , precision = 3 )
4143
4244
4345def check_alleles (alleles , m ):
@@ -151,7 +153,7 @@ def node_values(self):
151153 def print_state (self ):
152154 print ("LsHMM state" )
153155 print ("match_all_nodes =" , self .match_all_nodes )
154- print ("Tree =" )
156+ print ("Tree = " , self . tree . index , self . tree . interval )
155157 node_labels = {}
156158 for u , value in self .node_values ().items ():
157159 label = f"{ u } "
@@ -434,11 +436,13 @@ def update_probabilities(self, site, haplotype_state):
434436 def process_site (self , site , haplotype_state ):
435437 self .update_probabilities (site , haplotype_state )
436438 # d1 = self.node_values()
439+ print ("PRE" )
440+ self .print_state ()
437441 self .compress ()
438442 # d2 = self.node_values()
439443 # assert d1 == d2
440- # print("AFTER COMPRESS")
441- # self.print_state()
444+ print ("AFTER COMPRESS" )
445+ self .print_state ()
442446 s = self .compute_normalisation_factor ()
443447 for st in self .T :
444448 assert st .tree_node != tskit .NULL
@@ -489,8 +493,13 @@ def run(self, h):
489493 self .initialise (1 / n )
490494 while self .tree .next ():
491495 self .update_tree ()
496+ if self .tree .index != 0 :
497+ print ("AFTER UPDATE TREE" )
498+ self .print_state ()
492499 for site in self .tree .sites ():
493500 self .process_site (site , h [site .id ])
501+ print ("BEFORE UPDATE TREE" )
502+ self .print_state ()
494503 return self .output
495504
496505 def compute_normalisation_factor (self ):
@@ -1182,7 +1191,6 @@ def verify(self, ts):
11821191 self .assertAllClose (ll , ll_check )
11831192
11841193
1185- # TODO add params to run the various checks
11861194def check_viterbi (
11871195 ts ,
11881196 h ,
@@ -1212,10 +1220,10 @@ def check_viterbi(
12121220 cm = ls_viterbi_tree (
12131221 h , ts , rho = recombination , mu = mutation , match_all_nodes = match_all_nodes
12141222 )
1223+ cm .print_state ()
12151224 path_tree = cm .traceback (match_all_nodes = match_all_nodes )
12161225 ll_tree = np .sum (np .log10 (cm .normalisation_factor ))
12171226 assert np .isscalar (ll_tree )
1218- # print(cm)
12191227 # print("path tree = ", path_tree)
12201228
12211229 if compare_lshmm :
@@ -1437,8 +1445,8 @@ def test_match_sample(self, j):
14371445 ts = self .ts ()
14381446 h = np .zeros (4 )
14391447 h [j ] = 1
1440- # path = check_viterbi(ts, h)
1441- # nt.assert_array_equal([j, j, j, j], path)
1448+ path = check_viterbi (ts , h )
1449+ nt .assert_array_equal ([j , j , j , j ], path )
14421450 cm = check_forward_matrix (ts , h )
14431451 check_backward_matrix (ts , h , cm )
14441452
@@ -1525,6 +1533,19 @@ def test_match_sample(self, u, h):
15251533 )
15261534
15271535
def validate_match_all_nodes(ts, h, expected_path):
    """
    Run the Viterbi, forward and backward checks for haplotype ``h`` on
    ``ts`` with ``match_all_nodes=True``, and assert that the Viterbi path
    equals ``expected_path``.

    All three checks are invoked with ``compare_lib=False`` and
    ``compare_lshmm=False``.
    """
    options = dict(match_all_nodes=True, compare_lib=False, compare_lshmm=False)
    path = check_viterbi(ts, h, **options)
    nt.assert_array_equal(expected_path, path)
    cm = check_forward_matrix(ts, h, **options)
    # Only the checks performed inside check_backward_matrix matter here, so
    # its return value is deliberately not bound.
    check_backward_matrix(ts, h, cm, **options)
1547+
1548+
15281549class TestSingleBalancedTreeAllNodesExample :
15291550 # 3.00┊ 6 ┊
15301551 # ┊ ┏━┻━┓ ┊
@@ -1540,7 +1561,6 @@ def ts():
15401561 tables .tree_sequence (), start = 1 , nodes = np .arange (len (tables .nodes ) - 1 )
15411562 )
15421563
1543- # def test_match_sample(self, u, h):
15441564 @pytest .mark .parametrize (
15451565 ("h" , "expected_path" ),
15461566 [
@@ -1558,20 +1578,99 @@ def ts():
15581578 ([0 , 0 , 0 , 0 , 0 , 0 ], [6 ] * 6 ),
15591579 ],
15601580 )
1561- def test_match_sample (self , h , expected_path ):
1562- ts = self .ts ()
1563- path = check_viterbi (
1564- ts , h , match_all_nodes = True , compare_lib = False , compare_lshmm = False
1581+ def test_exact_match (self , h , expected_path ):
1582+ validate_match_all_nodes (self .ts (), h , expected_path )
1583+
1584+
class TestMultiTreeExample:
    """
    Matching tests on a fixed two-tree, four-sample tree sequence.
    """

    # 0.84┊     7   ┊    7    ┊
    #     ┊   ┏━┻━┓ ┊  ┏━┻━┓  ┊
    # 0.42┊   ┃   ┃ ┊  6   ┃  ┊
    #     ┊   ┃   ┃ ┊ ┏┻┓  ┃  ┊
    # 0.05┊   5   ┃ ┊ ┃ ┃  ┃  ┊
    #     ┊ ┏━┻┓  ┃ ┊ ┃ ┃  ┃  ┊
    # 0.04┊ ┃  4  ┃ ┊ ┃ ┃  4  ┊
    #     ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊
    # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊
    #     0         6         7

    # Haplotypes that exactly match one sample across all seven sites,
    # paired with the expected copying path. Shared by both tests below.
    _SAMPLE_CASES = [
        ([1, 0, 0, 0, 0, 1, 1], [0] * 7),
        ([0, 1, 0, 0, 1, 1, 0], [1] * 7),
        ([0, 0, 1, 0, 1, 1, 0], [2] * 7),
        ([0, 0, 0, 1, 0, 0, 1], [3] * 7),
    ]

    @staticmethod
    def ts():
        """
        Build the tree sequence drawn above from text tables and return the
        result of ``add_unique_node_mutations`` for nodes 0 to 6.
        """
        # NOTE(review): presumably this adds one site/mutation per listed
        # node (the root, 7, is excluded) - confirm against the helper.
        nodes = """\
        is_sample time
        1 0.000000
        1 0.000000
        1 0.000000
        1 0.000000
        0 0.041304
        0 0.045967
        0 0.416719
        0 0.838075
        """
        edges = """\
        left right parent child
        0.000000 7.000000 4 1
        0.000000 7.000000 4 2
        0.000000 6.000000 5 0
        0.000000 6.000000 5 4
        6.000000 7.000000 6 0
        6.000000 7.000000 6 3
        0.000000 6.000000 7 3
        6.000000 7.000000 7 4
        0.000000 6.000000 7 5
        6.000000 7.000000 7 6
        """
        # strict=False parses the tables as whitespace-delimited text.
        ts = tskit.load_text(
            nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False
        )
        return add_unique_node_mutations(ts, nodes=range(7))

    @pytest.mark.parametrize(
        ("h", "expected_path"),
        [
            # Just samples
            *_SAMPLE_CASES,
            # Match root
            ([0, 0, 0, 0, 0, 0, 0], [7] * 7),
        ],
    )
    def test_match_all_nodes(self, h, expected_path):
        """Check each haplotype with ``validate_match_all_nodes``."""
        validate_match_all_nodes(self.ts(), h, expected_path)

    @pytest.mark.parametrize(
        ("h", "expected_path"),
        [
            *_SAMPLE_CASES,
            # Switch between each of the samples
            ([1, 1, 1, 1, 0, 0, 1], [0, 1, 2, 3, 3, 3, 3]),
        ],
    )
    def test_match_samples(self, h, expected_path):
        """Check each haplotype with the default check options."""
        ts = self.ts()
        path = check_viterbi(ts, h)
        nt.assert_array_equal(expected_path, path)
        cm = check_forward_matrix(ts, h)
        check_backward_matrix(ts, h, cm)
15751674
15761675
15771676class TestSimulationExamples :
0 commit comments