Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -5344,6 +5344,69 @@ TreeSequence_dump_tables(TreeSequence *self, PyObject *args, PyObject *kwds)
return ret;
}

static PyObject *
TreeSequence_link_ancestors(TreeSequence *self, PyObject *args, PyObject *kwds)
{
int err;
PyObject *ret = NULL;
PyObject *samples = NULL;
PyObject *ancestors = NULL;
PyArrayObject *samples_array = NULL;
PyArrayObject *ancestors_array = NULL;
npy_intp *shape;
tsk_size_t num_samples, num_ancestors;
EdgeTable *result = NULL;
PyObject *result_args = NULL;
static char *kwlist[] = { "samples", "ancestors", NULL };

if (TreeSequence_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &samples, &ancestors)) {
goto out;
}

samples_array = (PyArrayObject *) PyArray_FROMANY(
samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
if (samples_array == NULL) {
goto out;
}
shape = PyArray_DIMS(samples_array);
num_samples = (tsk_size_t) shape[0];

ancestors_array = (PyArrayObject *) PyArray_FROMANY(
ancestors, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
if (ancestors_array == NULL) {
goto out;
}
shape = PyArray_DIMS(ancestors_array);
num_ancestors = (tsk_size_t) shape[0];

result_args = PyTuple_New(0);
if (result_args == NULL) {
goto out;
}
result = (EdgeTable *) PyObject_CallObject((PyObject *) &EdgeTableType, result_args);
if (result == NULL) {
goto out;
}
err = tsk_table_collection_link_ancestors(self->tree_sequence->tables,
PyArray_DATA(samples_array), num_samples, PyArray_DATA(ancestors_array),
num_ancestors, 0, result->table);
if (err != 0) {
handle_library_error(err);
goto out;
}
ret = (PyObject *) result;
result = NULL;
out:
Py_XDECREF(samples_array);
Py_XDECREF(ancestors_array);
Py_XDECREF(result);
Py_XDECREF(result_args);
return ret;
}

static PyObject *
TreeSequence_load(TreeSequence *self, PyObject *args, PyObject *kwds)
{
Expand Down Expand Up @@ -8528,6 +8591,10 @@ static PyMethodDef TreeSequence_methods[] = {
.ml_meth = (PyCFunction) TreeSequence_dump_tables,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Dumps the tree sequence to the specified set of tables" },
{ .ml_name = "link_ancestors",
.ml_meth = (PyCFunction) TreeSequence_link_ancestors,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Returns an EdgeTable linking the specified samples and ancestors." },
{ .ml_name = "get_node",
.ml_meth = (PyCFunction) TreeSequence_get_node,
.ml_flags = METH_VARARGS,
Expand Down
10 changes: 10 additions & 0 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -4901,6 +4901,11 @@ def do_map(self, ts, ancestors, samples=None, compare_lib=True):
if compare_lib:
lib_result = ts.dump_tables().link_ancestors(samples, ancestors)
assert ancestor_table == lib_result
ts_result = ts.link_ancestors(samples, ancestors)
assert ancestor_table == ts_result
if _tskit.HAS_NUMPY_2:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the HAS_NUMPY_2 thing is needed, should work either way right?

tables_result = ts.tables.link_ancestors(samples, ancestors)
assert ancestor_table == tables_result
return ancestor_table

def test_deprecated_name(self):
Expand All @@ -4914,6 +4919,11 @@ def test_deprecated_name(self):
tss = s.link_ancestors()
lib_result = ts.dump_tables().map_ancestors(samples, ancestors)
assert tss == lib_result
ts_result = ts.link_ancestors(samples, ancestors)
assert tss == ts_result
if _tskit.HAS_NUMPY_2:
immutable_result = ts.tables.map_ancestors(samples, ancestors)
assert tss == immutable_result
assert list(tss.parent) == [8, 8, 8, 8, 8]
assert list(tss.child) == [0, 1, 2, 3, 4]
assert all(tss.left) == 0
Expand Down
17 changes: 15 additions & 2 deletions python/tskit/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4780,6 +4780,21 @@ def __str__(self):
]
)

def link_ancestors(self, samples, ancestors):
"""
See :meth:`TableCollection.link_ancestors`.
"""
samples = util.safe_np_int_cast(samples, np.int32)
ancestors = util.safe_np_int_cast(ancestors, np.int32)
ll_edge_table = self._llts.link_ancestors(samples, ancestors)
return EdgeTable(ll_table=ll_edge_table)

def map_ancestors(self, *args, **kwargs):
"""
Deprecated alias for :meth:`link_ancestors`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this is our opportunity to remove map_ancestors?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

although that would doubtless break someone's code somewhere...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not bother - it's not doing any harm and no point in breaking code if we don't have to

"""
return self.link_ancestors(*args, **kwargs)

_MUTATOR_METHODS = {
"clear",
"sort",
Expand All @@ -4803,8 +4818,6 @@ def __str__(self):
"ibd_segments",
"fromdict",
"simplify",
"link_ancestors",
"map_ancestors",
}

def copy(self):
Expand Down
16 changes: 16 additions & 0 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -4372,6 +4372,22 @@ def dump_tables(self):
self._ll_tree_sequence.dump_tables(ll_tables)
return tables.TableCollection(ll_tables=ll_tables)

def link_ancestors(self, samples, ancestors):
"""
Equivalent to :meth:`TableCollection.link_ancestors`; see that method for full
documentation and parameter semantics.

:param list[int] samples: Node IDs to retain as samples.
:param list[int] ancestors: Node IDs to treat as ancestors.
:return: An :class:`tables.EdgeTable` containing the genealogical links between
the supplied ``samples`` and ``ancestors``.
:rtype: tables.EdgeTable
"""
samples = util.safe_np_int_cast(samples, np.int32)
ancestors = util.safe_np_int_cast(ancestors, np.int32)
ll_edge_table = self._ll_tree_sequence.link_ancestors(samples, ancestors)
return tables.EdgeTable(ll_table=ll_edge_table)

def dump_text(
self,
nodes=None,
Expand Down