Skip to content

Commit 52a1bcc

Browse files
petrelharphyanwong
authored andcommitted
c stuff for disjoint options to union
Add some python tests And fix concatenate() all_mutations implies really all_sites
1 parent e4ca469 commit 52a1bcc

File tree

11 files changed

+418
-30
lines changed

11 files changed

+418
-30
lines changed

c/tests/test_tables.c

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11240,6 +11240,109 @@ test_table_collection_union(void)
1124011240
tsk_table_collection_free(&tables);
1124111241
}
1124211242

11243+
static void
11244+
test_table_collection_disjoint_union(void)
11245+
{
11246+
int ret;
11247+
tsk_id_t ret_id;
11248+
tsk_table_collection_t tables;
11249+
tsk_table_collection_t tables1;
11250+
tsk_table_collection_t tables2;
11251+
tsk_table_collection_t tables12;
11252+
tsk_id_t node_mapping[4];
11253+
11254+
tsk_memset(node_mapping, 0xff, sizeof(node_mapping));
11255+
11256+
ret = tsk_table_collection_init(&tables1, 0);
11257+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11258+
tables1.sequence_length = 2;
11259+
11260+
// set up nodes, which will be shared
11261+
// flags, time, pop, ind, metadata, metadata_length
11262+
ret_id = tsk_node_table_add_row(
11263+
&tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0);
11264+
CU_ASSERT_FATAL(ret_id >= 0);
11265+
ret_id = tsk_node_table_add_row(
11266+
&tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0);
11267+
CU_ASSERT_FATAL(ret_id >= 0);
11268+
ret_id = tsk_node_table_add_row(&tables1.nodes, 0, 0.5, TSK_NULL, TSK_NULL, NULL, 0);
11269+
CU_ASSERT_FATAL(ret_id >= 0);
11270+
ret_id = tsk_node_table_add_row(&tables1.nodes, 0, 1.5, TSK_NULL, TSK_NULL, NULL, 0);
11271+
CU_ASSERT_FATAL(ret_id >= 0);
11272+
ret = tsk_table_collection_copy(&tables1, &tables2, 0);
11273+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11274+
11275+
// for tables1:
11276+
// on [0, 1] we have 0, 1 inherit from 2
11277+
// left, right, parent, child, metadata, metadata_length
11278+
ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 0, NULL, 0);
11279+
CU_ASSERT_FATAL(ret_id >= 0);
11280+
ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 1, NULL, 0);
11281+
CU_ASSERT_FATAL(ret_id >= 0);
11282+
ret_id = tsk_site_table_add_row(&tables1.sites, 0.4, "T", 1, NULL, 0);
11283+
CU_ASSERT_FATAL(ret_id >= 0);
11284+
ret_id = tsk_mutation_table_add_row(
11285+
&tables1.mutations, ret_id, 0, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0);
11286+
CU_ASSERT_FATAL(ret_id >= 0);
11287+
ret = tsk_table_collection_build_index(&tables1, 0);
11288+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11289+
ret = tsk_table_collection_sort(&tables1, NULL, 0);
11290+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11291+
11292+
// all this goes in tables12 so far
11293+
ret = tsk_table_collection_copy(&tables1, &tables12, 0);
11294+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11295+
11296+
// for tables2; and need to add to tables12 also:
11297+
// on [1, 2] we have 0, 1 inherit from 3
11298+
// left, right, parent, child, metadata, metadata_length
11299+
ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 0, NULL, 0);
11300+
CU_ASSERT_FATAL(ret_id >= 0);
11301+
ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 1, NULL, 0);
11302+
CU_ASSERT_FATAL(ret_id >= 0);
11303+
ret_id = tsk_site_table_add_row(&tables2.sites, 1.4, "A", 1, NULL, 0);
11304+
CU_ASSERT_FATAL(ret_id >= 0);
11305+
ret_id = tsk_mutation_table_add_row(
11306+
&tables2.mutations, ret_id, 1, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0);
11307+
CU_ASSERT_FATAL(ret_id >= 0);
11308+
ret = tsk_table_collection_build_index(&tables2, 0);
11309+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11310+
ret = tsk_table_collection_sort(&tables2, NULL, 0);
11311+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11312+
// also tables12
11313+
ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 0, NULL, 0);
11314+
CU_ASSERT_FATAL(ret_id >= 0);
11315+
ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 1, NULL, 0);
11316+
CU_ASSERT_FATAL(ret_id >= 0);
11317+
ret_id = tsk_site_table_add_row(&tables12.sites, 1.4, "A", 1, NULL, 0);
11318+
CU_ASSERT_FATAL(ret_id >= 0);
11319+
ret_id = tsk_mutation_table_add_row(
11320+
&tables12.mutations, ret_id, 1, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0);
11321+
CU_ASSERT_FATAL(ret_id >= 0);
11322+
ret = tsk_table_collection_build_index(&tables12, 0);
11323+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11324+
ret = tsk_table_collection_sort(&tables12, NULL, 0);
11325+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11326+
11327+
// now disjoint union-ing tables1 and tables2 should get tables12
11328+
ret = tsk_table_collection_copy(&tables1, &tables, 0);
11329+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11330+
node_mapping[0] = 0;
11331+
node_mapping[1] = 1;
11332+
node_mapping[2] = 2;
11333+
node_mapping[3] = 3;
11334+
ret = tsk_table_collection_union(&tables, &tables2, node_mapping,
11335+
TSK_UNION_NO_CHECK_SHARED | TSK_UNION_ALL_EDGES | TSK_UNION_ALL_MUTATIONS);
11336+
CU_ASSERT_EQUAL_FATAL(ret, 0);
11337+
CU_ASSERT_FATAL(
11338+
tsk_table_collection_equals(&tables, &tables12, TSK_CMP_IGNORE_PROVENANCE));
11339+
11340+
tsk_table_collection_free(&tables12);
11341+
tsk_table_collection_free(&tables2);
11342+
tsk_table_collection_free(&tables1);
11343+
tsk_table_collection_free(&tables);
11344+
}
11345+
1124311346
static void
1124411347
test_table_collection_union_middle_merge(void)
1124511348
{
@@ -11836,6 +11939,7 @@ main(int argc, char **argv)
1183611939
test_table_collection_subset_unsorted },
1183711940
{ "test_table_collection_subset_errors", test_table_collection_subset_errors },
1183811941
{ "test_table_collection_union", test_table_collection_union },
11942+
{ "test_table_collection_disjoint_union", test_table_collection_disjoint_union },
1183911943
{ "test_table_collection_union_middle_merge",
1184011944
test_table_collection_union_middle_merge },
1184111945
{ "test_table_collection_union_errors", test_table_collection_union_errors },

c/tskit/tables.c

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13202,6 +13202,8 @@ tsk_table_collection_union(tsk_table_collection_t *self,
1320213202
tsk_id_t *site_map = NULL;
1320313203
bool add_populations = !(options & TSK_UNION_NO_ADD_POP);
1320413204
bool check_shared_portion = !(options & TSK_UNION_NO_CHECK_SHARED);
13205+
bool all_edges = !!(options & TSK_UNION_ALL_EDGES);
13206+
bool all_mutations = !!(options & TSK_UNION_ALL_MUTATIONS);
1320513207

1320613208
/* Not calling TSK_CHECK_TREES so casting to int is safe */
1320713209
ret = (int) tsk_table_collection_check_integrity(self, 0);
@@ -13285,7 +13287,7 @@ tsk_table_collection_union(tsk_table_collection_t *self,
1328513287
// edges
1328613288
for (k = 0; k < (tsk_id_t) other->edges.num_rows; k++) {
1328713289
tsk_edge_table_get_row_unsafe(&other->edges, k, &edge);
13288-
if ((other_node_mapping[edge.parent] == TSK_NULL)
13290+
if (all_edges || (other_node_mapping[edge.parent] == TSK_NULL)
1328913291
|| (other_node_mapping[edge.child] == TSK_NULL)) {
1329013292
new_parent = node_map[edge.parent];
1329113293
new_child = node_map[edge.child];
@@ -13298,14 +13300,31 @@ tsk_table_collection_union(tsk_table_collection_t *self,
1329813300
}
1329913301
}
1330013302

13301-
// mutations and sites
13303+
// sites
13304+
// first do the "disjoint" (all_mutations) case, where we just add all sites;
13305+
// otherwise we want to just add sites for new mutations
13306+
if (all_mutations) {
13307+
for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) {
13308+
tsk_site_table_get_row_unsafe(&other->sites, k, &site);
13309+
ret_id = tsk_site_table_add_row(&self->sites, site.position,
13310+
site.ancestral_state, site.ancestral_state_length, site.metadata,
13311+
site.metadata_length);
13312+
if (ret_id < 0) {
13313+
ret = (int) ret_id;
13314+
goto out;
13315+
}
13316+
site_map[site.id] = ret_id;
13317+
}
13318+
}
13319+
13320+
// mutations (and maybe sites)
1330213321
i = 0;
1330313322
for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) {
1330413323
tsk_site_table_get_row_unsafe(&other->sites, k, &site);
1330513324
while ((i < (tsk_id_t) other->mutations.num_rows)
1330613325
&& (other->mutations.site[i] == site.id)) {
1330713326
tsk_mutation_table_get_row_unsafe(&other->mutations, i, &mut);
13308-
if (other_node_mapping[mut.node] == TSK_NULL) {
13327+
if (all_mutations || (other_node_mapping[mut.node] == TSK_NULL)) {
1330913328
if (site_map[site.id] == TSK_NULL) {
1331013329
ret_id = tsk_site_table_add_row(&self->sites, site.position,
1331113330
site.ancestral_state, site.ancestral_state_length, site.metadata,

c/tskit/tables.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -858,11 +858,21 @@ equality of the subsets.
858858
*/
859859
#define TSK_UNION_NO_CHECK_SHARED (1 << 0)
860860
/**
861-
By default, all nodes new to ``self`` are assigned new populations. If this
861+
By default, all nodes new to ``self`` are assigned new populations. If this
862862
option is specified, nodes that are added to ``self`` will retain the
863863
population IDs they have in ``other``.
864864
*/
865865
#define TSK_UNION_NO_ADD_POP (1 << 1)
866+
/**
867+
By default, union only adds edges adjacent to a newly added node;
868+
this option adds all edges.
869+
*/
870+
#define TSK_UNION_ALL_EDGES (1 << 2)
871+
/**
872+
By default, union only adds only mutations on newly added edges, and
873+
sites for those mutations; this option adds all mutations and all sites.
874+
*/
875+
#define TSK_UNION_ALL_MUTATIONS (1 << 3)
866876
/** @} */
867877

868878
/**
@@ -4414,6 +4424,10 @@ that are exclusive ``other`` are added to ``self``, along with:
44144424
By default, populations of newly added nodes are assumed to be new populations,
44154425
and added to the population table as well.
44164426
4427+
The behavior can be changed by the flags ``TSK_UNION_ALL_EDGES`` and
4428+
``TSK_UNION_ALL_MUTATIONS``, which will (respectively) add *all* edges
4429+
or *all* sites and mutations instead.
4430+
44174431
This operation will also sort the resulting tables, so the tables may change
44184432
even if nothing new is added, if the original tables were not sorted.
44194433

python/CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@
5050
- Add ``Mutation.inherited_state`` property which returns the inherited state
5151
for a single mutation. (:user:`benjeffery`, :pr:`3277`, :issue:`2631`)
5252

53+
- Add ``all_mutations`` and ``all_edges`` options to ``TreeSequence.union``,
54+
allowing greater flexibility in "disjoint union" situations.
55+
(:user:`hyanwong`, :user:`petrelharp`, :issue:`3181`)
56+
5357
**Bugfixes**
5458

5559
- In some tables with mutations out-of-order ``TableCollection.sort`` did not re-order
@@ -84,6 +88,9 @@
8488
- Prevent iterating over a ``TopologyCounter``
8589
(:user:`benjeffery` , :pr:`3202`, :issue:`1462`)
8690

91+
- Fix ``TreeSequence.concatenate()`` to work with internal samples by using the
92+
``all_mutations`` and ``all_edges`` parameters in ``union()``
93+
(:user:`hyanwong`, :pr:`3283`, :issue:`3181`)
8794

8895
**Breaking changes**
8996

python/_tskitmodule.c

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4347,15 +4347,18 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds)
43474347
npy_intp *shape;
43484348
tsk_flags_t options = 0;
43494349
int check_shared = true;
4350+
int all_edges = false;
4351+
int all_mutations = false;
43504352
int add_populations = true;
43514353
static char *kwlist[] = { "other", "other_node_mapping", "check_shared_equality",
4352-
"add_populations", NULL };
4354+
"add_populations", "all_edges", "all_mutations", NULL };
43534355

43544356
if (TableCollection_check_state(self) != 0) {
43554357
goto out;
43564358
}
4357-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|ii", kwlist, &TableCollectionType,
4358-
&other, &other_node_mapping, &check_shared, &add_populations)) {
4359+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|iiii", kwlist,
4360+
&TableCollectionType, &other, &other_node_mapping, &check_shared,
4361+
&add_populations, &all_edges, &all_mutations)) {
43594362
goto out;
43604363
}
43614364
nmap_array = (PyArrayObject *) PyArray_FROMANY(
@@ -4370,6 +4373,12 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds)
43704373
" number of nodes in the other tree sequence.");
43714374
goto out;
43724375
}
4376+
if (all_edges) {
4377+
options |= TSK_UNION_ALL_EDGES;
4378+
}
4379+
if (all_mutations) {
4380+
options |= TSK_UNION_ALL_MUTATIONS;
4381+
}
43734382
if (!check_shared) {
43744383
options |= TSK_UNION_NO_CHECK_SHARED;
43754384
}

python/tests/test_highlevel.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2536,6 +2536,28 @@ def test_mutation_parent_errors(self, mutations, error):
25362536
else:
25372537
tables.tree_sequence()
25382538

2539+
def test_union(self, ts_fixture):
2540+
# most of the union tests are in test_tables.py, here we just sanity check
2541+
tables = ts_fixture.dump_tables()
2542+
tables.migrations.clear() # migrations not supported in union()
2543+
ts = tables.tree_sequence()
2544+
tables = tskit.TableCollection(ts.sequence_length)
2545+
tables.time_units = ts.time_units
2546+
empty = tables.tree_sequence()
2547+
union_ts = empty.union(
2548+
ts,
2549+
node_mapping=np.full(ts.num_nodes, tskit.NULL, dtype=int),
2550+
all_edges=True,
2551+
all_mutations=True,
2552+
check_shared_equality=False,
2553+
)
2554+
union_ts.tables.assert_equals(
2555+
ts.tables,
2556+
ignore_metadata=True,
2557+
ignore_reference_sequence=True,
2558+
ignore_provenance=True,
2559+
)
2560+
25392561

25402562
class TestSimplify:
25412563
# This class was factored out of the old TestHighlevel class 2022-12-13,

python/tests/test_lowlevel.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,25 @@ def test_union_bad_args(self):
433433
with pytest.raises(ValueError):
434434
tc.union(tc2, np.array([[1], [2]], dtype="int32"))
435435

436+
@pytest.mark.parametrize("value", [True, False])
437+
@pytest.mark.parametrize(
438+
"flag",
439+
[
440+
"all_edges",
441+
"all_mutations",
442+
"check_shared_equality",
443+
"add_populations",
444+
],
445+
)
446+
def test_union_options(self, flag, value):
447+
ts = msprime.simulate(10, random_seed=1)
448+
tc = ts.dump_tables()._ll_tables
449+
empty_tables = ts.dump_tables()
450+
for table in empty_tables.table_name_map.keys():
451+
getattr(empty_tables, table).clear()
452+
tc2 = empty_tables._ll_tables
453+
tc.union(tc2, np.arange(0, dtype="int32"), **{flag: value})
454+
436455
def test_equals_bad_args(self):
437456
ts = msprime.simulate(10, random_seed=1242)
438457
tc = ts.dump_tables()._ll_tables

python/tests/test_tables.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5271,6 +5271,82 @@ def test_examples(self):
52715271
ts = tables.tree_sequence()
52725272
self.verify_union(*self.split_example(ts, T))
52735273

5274+
def test_split_and_rejoin(self):
5275+
ts = self.get_msprime_example(5, T=2, seed=928)
5276+
cutpoints = np.array([0, 0.25, 0.5, 0.75, 1]) * ts.sequence_length
5277+
tables1 = ts.dump_tables()
5278+
tables1.delete_intervals([cutpoints[0:2], cutpoints[2:4]], simplify=False)
5279+
tables2 = ts.dump_tables()
5280+
tables2.delete_intervals([cutpoints[1:3], cutpoints[3:]], simplify=False)
5281+
tables1.union(
5282+
tables2,
5283+
all_edges=True,
5284+
all_mutations=True,
5285+
node_mapping=np.arange(ts.num_nodes),
5286+
check_shared_equality=False,
5287+
)
5288+
tables1.edges.squash()
5289+
tables1.sort()
5290+
tables1.assert_equals(ts.tables, ignore_provenance=True)
5291+
5292+
def test_both_empty(self):
5293+
tables = tskit.TableCollection(sequence_length=1)
5294+
t1 = tables.copy()
5295+
t2 = tables.copy()
5296+
t1.union(t2, node_mapping=np.arange(0), all_edges=True, all_mutations=True)
5297+
t1.assert_equals(tables, ignore_provenance=True)
5298+
5299+
def test_one_empty(self):
5300+
ts = self.get_msprime_example(5, T=2, seed=928)
5301+
ts = ts.simplify() # the example has a load of unreferenced individuals
5302+
tables = ts.dump_tables()
5303+
empty = tskit.TableCollection(sequence_length=tables.sequence_length)
5304+
empty.time_units = tables.time_units
5305+
5306+
# union with empty should be no-op
5307+
tables.union(
5308+
empty, node_mapping=np.arange(0), all_edges=True, all_mutations=True
5309+
)
5310+
tables.assert_equals(ts.dump_tables(), ignore_provenance=True)
5311+
5312+
# empty union with tables should be tables
5313+
empty.union(
5314+
tables,
5315+
node_mapping=np.full(tables.nodes.num_rows, tskit.NULL),
5316+
all_edges=True,
5317+
all_mutations=True,
5318+
check_shared_equality=False,
5319+
)
5320+
empty.assert_equals(tables, ignore_provenance=True)
5321+
5322+
def test_reciprocal_empty(self):
5323+
# reciprocally add mutations from one table and edges from the other
5324+
edges_table = tskit.Tree.generate_comb(6, span=6).tree_sequence.dump_tables()
5325+
muts_table = tskit.TableCollection(sequence_length=6)
5326+
muts_table.nodes.replace_with(edges_table.nodes) # same nodes, no edges
5327+
for j in range(0, 6):
5328+
site_id = muts_table.sites.add_row(position=j, ancestral_state="0")
5329+
if j % 2 == 0:
5330+
# Some sites empty
5331+
muts_table.mutations.add_row(site=site_id, node=j, derived_state="1")
5332+
identity_map = np.arange(len(muts_table.nodes), dtype="int32")
5333+
params = {"node_mapping": identity_map, "check_shared_equality": False}
5334+
5335+
test_table = edges_table.copy()
5336+
test_table.union(muts_table, **params, all_edges=True) # null op
5337+
assert len(test_table.sites) == 0
5338+
assert len(test_table.mutations) == 0
5339+
test_table.union(muts_table, **params, all_mutations=True)
5340+
assert test_table.sites == muts_table.sites
5341+
assert test_table.mutations == muts_table.mutations
5342+
5343+
muts_table.union(edges_table, **params, all_mutations=True) # null op
5344+
assert len(muts_table.edges) == 0
5345+
muts_table.union(edges_table, **params, all_edges=True)
5346+
assert muts_table.edges == edges_table.edges
5347+
5348+
muts_table.assert_equals(test_table, ignore_provenance=True)
5349+
52745350

52755351
class TestTableSetitemMetadata:
52765352
@pytest.mark.parametrize("table_name", tskit.TABLE_NAMES)

0 commit comments

Comments
 (0)