diff --git a/c/examples/Makefile b/c/examples/Makefile
index b289c2c253..03639ad165 100644
--- a/c/examples/Makefile
+++ b/c/examples/Makefile
@@ -23,7 +23,7 @@ TSKIT_SOURCE=../tskit/*.c ../subprojects/kastore/kastore.c
 targets = api_structure error_handling \
 	haploid_wright_fisher streaming \
 	tree_iteration tree_traversal \
-	take_ownership
+	take_ownership haplotype_benchmark
 
 all: $(targets)
 
@@ -32,4 +32,3 @@ $(targets): %: %.c
 
 clean:
 	rm -f $(targets)
-
diff --git a/c/examples/haplotype_benchmark.c b/c/examples/haplotype_benchmark.c
new file mode 100644
index 0000000000..6f78456480
--- /dev/null
+++ b/c/examples/haplotype_benchmark.c
@@ -0,0 +1,90 @@
+/*
+ * Micro-benchmark for the haplotype decoder: loads a tree sequence, decodes the
+ * haplotype of the first MAX_BENCHMARK_NODES nodes over all sites and reports the
+ * CPU time used. A checksum of the decoded alleles is printed so the decoding work
+ * has an observable result.
+ *
+ * Usage: haplotype_benchmark [path-to-tree-sequence-file]
+ */
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <time.h>
+
+/* NOTE(review): the header names in this file were stripped when the patch was
+ * extracted. The four standard headers above follow from the symbols used
+ * (uint64_t, printf, malloc/exit, clock). The three tskit headers below are
+ * inferred from the tsk_* symbols used -- confirm against the original commit. */
+#include <tskit/core.h>
+#include <tskit/genotypes.h>
+#include <tskit/trees.h>
+
+/* Abort with the tskit error message when a library call returns a negative code. */
+#define CHECK_TSK(err)                                                                  \
+    do {                                                                                \
+        if ((err) < 0) {                                                                \
+            fprintf(stderr, "Error: line %d: %s\n", __LINE__, tsk_strerror(err));       \
+            exit(EXIT_FAILURE);                                                         \
+        }                                                                               \
+    } while (0)
+
+/* Number of times the whole node loop is repeated. */
+#define NUM_ITERATIONS 1
+/* Upper bound on the number of nodes decoded per iteration. */
+#define MAX_BENCHMARK_NODES 500
+
+int
+main(int argc, char **argv)
+{
+    int ret;
+    tsk_table_collection_t tables;
+    tsk_treeseq_t treeseq;
+    tsk_haplotype_t haplotype_decoder;
+    int8_t *haplotype = NULL;
+    double elapsed_seconds;
+    clock_t start_clock, end_clock;
+    uint64_t checksum = 0;
+
+    /* Default input; overridden by the first command line argument. */
+    const char *filename = "../../simulated_chrom_21_100k.ts";
+    if (argc > 1) {
+        filename = argv[1];
+    }
+
+    /* tsk_table_collection_load() initialises `tables` itself unless TSK_NO_INIT
+     * is passed, so calling tsk_table_collection_init() first would leak the
+     * memory allocated by that first initialisation. */
+    ret = tsk_table_collection_load(&tables, filename, 0);
+    CHECK_TSK(ret);
+
+    ret = tsk_treeseq_init(&treeseq, &tables, 0);
+    CHECK_TSK(ret);
+
+    tsk_size_t num_nodes = tsk_treeseq_get_num_nodes(&treeseq);
+    tsk_size_t num_sites = tsk_treeseq_get_num_sites(&treeseq);
+    if (num_sites == 0) {
+        fprintf(stderr, "Tree sequence has no sites\n");
+        exit(EXIT_FAILURE);
+    }
+
+    /* Decode nodes 0 .. node_limit-1 only, to bound the run time. */
+    tsk_id_t node_limit
+        = (tsk_id_t) (num_nodes < MAX_BENCHMARK_NODES ? num_nodes : MAX_BENCHMARK_NODES);
+
+    /* Decoder over the full site range [0, num_sites). */
+    ret = tsk_haplotype_init(&haplotype_decoder, &treeseq, 0, (tsk_id_t) num_sites);
+    CHECK_TSK(ret);
+
+    /* One output byte per site. */
+    haplotype = malloc(num_sites * sizeof(*haplotype));
+    if (haplotype == NULL) {
+        fprintf(stderr, "Failed to allocate haplotype buffer\n");
+        exit(EXIT_FAILURE);
+    }
+
+    /* clock() measures processor time, not wall-clock time. */
+    start_clock = clock();
+    for (int iter = 0; iter < NUM_ITERATIONS; iter++) {
+        for (tsk_id_t node = 0; node < node_limit; node++) {
+            ret = tsk_haplotype_decode(&haplotype_decoder, node, haplotype);
+            CHECK_TSK(ret);
+            /* Fold the result into the checksum so the decoded buffer is consumed. */
+            for (tsk_id_t site = 0; site < (tsk_id_t) num_sites; site++) {
+                checksum += (uint64_t) haplotype[site];
+            }
+        }
+    }
+    end_clock = clock();
+
+    elapsed_seconds = (double) (end_clock - start_clock) / CLOCKS_PER_SEC;
+
+    printf("Loaded tree sequence from %s\n", filename);
+    printf("Decoded %d iterations over %lld nodes × %lld sites in %.3f seconds\n",
+        NUM_ITERATIONS, (long long) node_limit, (long long) num_sites, elapsed_seconds);
+    printf("Checksummed haplotypes: %llu\n", (unsigned long long) checksum);
+
+    /* Release resources in reverse order of acquisition. */
+    free(haplotype);
+    tsk_haplotype_free(&haplotype_decoder);
+    tsk_treeseq_free(&treeseq);
+    tsk_table_collection_free(&tables);
+    return EXIT_SUCCESS;
+}
diff --git a/c/meson.build b/c/meson.build
index f5c1a0f585..6d0e8c66b5 100644
--- a/c/meson.build
+++ b/c/meson.build
@@ -113,6 +113,9 @@ if not meson.is_subproject()
     executable('tree_traversal',
         sources: ['examples/tree_traversal.c'],
         link_with: [tskit_lib], dependencies: lib_deps)
+    executable('haplotype_benchmark',
+        sources: ['examples/haplotype_benchmark.c'],
+        link_with: [tskit_lib], dependencies: lib_deps)
     executable('streaming',
         sources: ['examples/streaming.c'],
         link_with: [tskit_lib], dependencies: lib_deps)
diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c
index c2385281bd..dd85e17bf2 100644
--- a/c/tskit/genotypes.c
+++ b/c/tskit/genotypes.c
@@ -23,13 +23,804 @@
  * SOFTWARE.
*/
 
+/* NOTE(review): the #include targets in this hunk were stripped when the patch was
+ * extracted. <intrin.h> is what declares the MSVC intrinsics used below, and
+ * <stdbool.h>/<stdint.h> follow from the symbols used (bool, UINT64_C, INT32_MAX).
+ * <assert.h> and the unchanged context includes are a best guess -- confirm all of
+ * them against the original commit. */
+#include <assert.h>
 #include <stdlib.h>
 #include <string.h>
 #include <stdio.h>
 #include <math.h>
+#include <stdbool.h>
+#include <stdint.h>
+
+#if defined(_MSC_VER)
+#include <intrin.h>
+#endif
 
 #include <tskit/genotypes.h>
 
+// FIXME Tskit already has a bitset implementation that maybe we could use
+
+/* Index of the lowest set bit of x. x must be non-zero: neither
+ * __builtin_ctzll(0) nor the output of _BitScanForward64 for 0 is defined. */
+static inline uint32_t
+tsk_haplotype_ctz64(uint64_t x)
+{
+#if defined(_MSC_VER)
+    unsigned long index;
+    _BitScanForward64(&index, x);
+    return (uint32_t) index;
+#else
+    return (uint32_t) __builtin_ctzll(x);
+#endif
+}
+
+/* Number of set bits in value. */
+static inline uint32_t
+tsk_haplotype_popcount64(uint64_t value)
+{
+#if defined(_MSC_VER)
+    return (uint32_t) __popcnt64(value);
+#else
+    return (uint32_t) __builtin_popcountll(value);
+#endif
+}
+
+/* Clear bit idx of the working bitset and keep the per-word set-bit count in
+ * step. Does nothing if the bit is already clear. */
+static inline void
+tsk_haplotype_bitset_clear(tsk_haplotype_t *self, tsk_size_t idx)
+{
+    tsk_size_t word = idx >> 6;
+    uint64_t mask = UINT64_C(1) << (idx & 63);
+    if ((self->unresolved_bits[word] & mask) == 0) {
+        return;
+    }
+    self->unresolved_bits[word] &= ~mask;
+    if (self->unresolved_counts[word] > 0) {
+        self->unresolved_counts[word]--;
+    }
+}
+
+/* As tsk_haplotype_bitset_clear, for a caller that already has the word index
+ * and the single-bit mask. */
+static inline void
+tsk_haplotype_clear_word_bit(tsk_haplotype_t *self, tsk_size_t word, uint64_t mask)
+{
+    if ((self->unresolved_bits[word] & mask) != 0) {
+        self->unresolved_bits[word] &= ~mask;
+        if (self->unresolved_counts[word] > 0) {
+            self->unresolved_counts[word]--;
+        }
+    }
+}
+
+/* True if bit idx is set in bits. */
+static inline bool
+tsk_haplotype_bitset_test(const uint64_t *bits, tsk_size_t idx)
+{
+    tsk_size_t word = idx >> 6;
+    uint64_t mask = UINT64_C(1) << (idx & 63);
+    return (bits[word] & mask) != 0;
+}
+
+/* Index of the first set bit of the working bitset in [start, limit), or limit
+ * if there is none. Words whose set-bit count is zero are skipped without being
+ * read. */
+static inline tsk_size_t
+tsk_haplotype_bitset_next(
+    const tsk_haplotype_t *self, tsk_size_t start, tsk_size_t limit)
+{
+    tsk_size_t word = start >> 6;
+    tsk_size_t word_limit = (limit + 63) >> 6;
+    uint64_t mask, value;
+
+    if (start >= limit || word >= self->num_bit_words) {
+        return limit;
+    }
+    /* Ignore the bits below start in the first word. */
+    mask = UINT64_MAX << (start & 63);
+    value = self->unresolved_bits[word] & mask;
+    while (value == 0) {
+        word++;
+        if (word >= word_limit || word >= self->num_bit_words) {
+            return limit;
+        }
+        while (word < word_limit && word < self->num_bit_words
+               && self->unresolved_counts[word] == 0) {
+            word++;
+        }
+        if (word >= word_limit || word >= self->num_bit_words) {
+            return limit;
+        }
+        value = self->unresolved_bits[word];
+    }
+    start = (word << 6) + tsk_haplotype_ctz64(value);
+    return start < limit ? start : limit;
+}
+
+/* Mask with bits [start_offset, end_offset) of a 64-bit word set. Returns 0 for
+ * an empty range; an end_offset of 64 or more means "up to the top bit". The
+ * special cases keep every shift count below 64. */
+static inline uint64_t
+tsk_haplotype_mask_from_offsets(uint32_t start_offset, uint32_t end_offset)
+{
+    if (start_offset >= end_offset) {
+        return 0;
+    }
+    if (start_offset == 0 && end_offset >= 64) {
+        return UINT64_MAX;
+    }
+    uint64_t high_mask
+        = end_offset >= 64 ? UINT64_MAX : ((UINT64_C(1) << end_offset) - 1);
+    uint64_t low_mask = start_offset == 0 ? 0 : ((UINT64_C(1) << start_offset) - 1);
+    return high_mask & ~low_mask;
+}
+
+/*
+ * Find the first site index in [start, end) that is still set in the working
+ * bitset and is not inside any of the interval_count half-open intervals
+ * [interval_start[p], interval_end[p]).
+ *
+ * Returns true and stores the index in *out_index if one exists; returns false
+ * and leaves *out_index untouched otherwise. The bitset is not modified.
+ */
+static bool
+tsk_haplotype_find_next_uncovered(tsk_haplotype_t *self, tsk_size_t start,
+    tsk_size_t end, const int32_t *interval_start, const int32_t *interval_end,
+    tsk_size_t interval_count, tsk_size_t *out_index)
+{
+    if (start >= end) {
+        return false;
+    }
+    const tsk_size_t first_word = start >> 6;
+    const tsk_size_t last_word = (end - 1) >> 6;
+    if (first_word >= self->num_bit_words) {
+        return false;
+    }
+    // FIXME Horrendous logic here, needs jeromeifying.
+    for (tsk_size_t word = first_word; word <= last_word && word < self->num_bit_words;
+         word++) {
+        /* Cheap skip of words with nothing left to resolve. */
+        if (self->unresolved_counts[word] == 0) {
+            continue;
+        }
+        uint64_t word_bits = self->unresolved_bits[word];
+        if (word == first_word) {
+            /* Drop the bits below start. */
+            word_bits &= UINT64_MAX << (start & 63);
+        }
+        if (word == last_word) {
+            /* Drop the bits at or above end. */
+            word_bits &= UINT64_MAX >> (63 - ((end - 1) & 63));
+        }
+        if (word_bits == 0) {
+            continue;
+        }
+        if (interval_count > 0) {
+            /* Build a mask of the bits of this word covered by any interval,
+             * clipped to both the word and [start, end). */
+            int32_t word_left = (int32_t)(word << 6);
+            int32_t word_right = word_left + 64;
+            uint64_t coverage_mask = 0;
+            for (tsk_size_t p = 0; p < interval_count; p++) {
+                int32_t interval_left = interval_start[p];
+                int32_t interval_right = interval_end[p];
+                if (interval_left >= interval_right) {
+                    continue;
+                }
+                if (interval_right <= word_left || interval_left >= word_right) {
+                    continue;
+                }
+                int32_t clipped_left
+                    = interval_left > word_left ? interval_left : word_left;
+                int32_t clipped_right
+                    = interval_right < word_right ? interval_right : word_right;
+                if ((int32_t) start > clipped_left) {
+                    clipped_left = (int32_t) start;
+                }
+                if ((int32_t) end < clipped_right) {
+                    clipped_right = (int32_t) end;
+                }
+                if (clipped_left >= clipped_right) {
+                    continue;
+                }
+                uint32_t start_offset = (uint32_t)(clipped_left - word_left);
+                uint32_t end_offset = (uint32_t)(clipped_right - word_left);
+                coverage_mask
+                    |= tsk_haplotype_mask_from_offsets(start_offset, end_offset);
+                if (coverage_mask == UINT64_MAX) {
+                    break;
+                }
+            }
+            word_bits &= ~coverage_mask;
+        }
+        if (word_bits != 0) {
+            /* The masks above removed every bit outside [start, end), so the
+             * lowest remaining bit is the answer. */
+            *out_index = (word << 6) + tsk_haplotype_ctz64(word_bits);
+            return true;
+        }
+    }
+    return false;
+}
+
+/* Restore the working bitset and its per-word counts from the initial copies. */
+static void
+tsk_haplotype_reset_bitset(tsk_haplotype_t *self)
+{
+    if (self->num_bit_words > 0) {
+        tsk_memcpy(self->unresolved_bits, self->initial_bits,
+            self->num_bit_words * sizeof(*self->unresolved_bits));
+        tsk_memcpy(self->unresolved_counts, self->initial_counts,
+            self->num_bit_words * sizeof(*self->unresolved_counts));
+    }
+}
+
+// FIXME We're building the whole index here, which is a bit sad when we're clipping a
+// region.
+/*
+ * Group the edge IDs by child node (counting sort on edges.child).
+ *
+ * After success, for node u the edges that have u as child are
+ * parent_edge_index[parent_index_range[2 * u] .. parent_index_range[2 * u + 1] - 1],
+ * in edge-table order. Edges whose child is outside [0, num_nodes) are ignored.
+ *
+ * Returns 0, TSK_ERR_NO_MEMORY, or TSK_ERR_UNSUPPORTED_OPERATION when a single
+ * node has INT32_MAX child edges. On error both index arrays are freed.
+ */
+static int
+tsk_haplotype_build_parent_index(tsk_haplotype_t *self)
+{
+    int ret = 0;
+    const tsk_table_collection_t *tables = self->tree_sequence->tables;
+    const tsk_edge_table_t *edges = &tables->edges;
+    const tsk_id_t *edges_child = edges->child;
+    tsk_size_t num_edges = edges->num_rows;
+    int32_t *child_counts = NULL;
+
+    if (num_edges == 0) {
+        /* No edges: every node gets the empty range [0, 0). */
+        self->parent_edge_index = NULL;
+        if (self->num_nodes > 0) {
+            self->parent_index_range
+                = tsk_calloc(self->num_nodes * 2, sizeof(*self->parent_index_range));
+            if (self->parent_index_range == NULL) {
+                ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
+                goto out;
+            }
+        } else {
+            self->parent_index_range = NULL;
+        }
+        goto out;
+    }
+
+    self->parent_edge_index = tsk_malloc(num_edges * sizeof(*self->parent_edge_index));
+    self->parent_index_range
+        = tsk_malloc(self->num_nodes * 2 * sizeof(*self->parent_index_range));
+    child_counts = tsk_calloc(self->num_nodes, sizeof(*child_counts));
+    if (self->parent_edge_index == NULL
+        || (self->num_nodes > 0 && self->parent_index_range == NULL)
+        || child_counts == NULL) {
+        ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
+        goto out;
+    }
+
+    /* Pass 1: number of edges per child node. */
+    for (tsk_size_t j = 0; j < num_edges; j++) {
+        tsk_id_t child = edges_child[j];
+        if (child >= 0 && child < (tsk_id_t) self->num_nodes) {
+            if (child_counts[child] == INT32_MAX) {
+                ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION);
+                goto out;
+            }
+            child_counts[child]++;
+        }
+    }
+
+    /* Prefix sums: start and end of each node's range both begin at its first
+     * slot; the end is advanced as slots are filled in pass 2. */
+    int32_t current_start = 0;
+    for (tsk_size_t u = 0; u < (tsk_size_t) self->num_nodes; u++) {
+        int32_t offset = (int32_t)(u * 2);
+        self->parent_index_range[offset] = current_start;
+        self->parent_index_range[offset + 1] = current_start;
+        current_start += child_counts[u];
+    }
+
+    /* Pass 2: drop each edge into the next free slot of its child's range. */
+    for (tsk_size_t j = 0; j < num_edges; j++) {
+        tsk_id_t child = edges_child[j];
+        if (child >= 0 && child < (tsk_id_t) self->num_nodes) {
+            int32_t end_offset = (int32_t)(child * 2 + 1);
+            int32_t pos = self->parent_index_range[end_offset];
+            self->parent_edge_index[pos] = (tsk_id_t) j;
+            self->parent_index_range[end_offset] = pos + 1;
+        }
+    }
+
+    /* Recompute each range start from its final end and count. */
+    for (tsk_size_t u = 0; u < (tsk_size_t) self->num_nodes; u++) {
+        int32_t offset = (int32_t)(u * 2);
+        int32_t end = self->parent_index_range[offset + 1];
+        self->parent_index_range[offset] = end - child_counts[u];
+    }
+
+out:
+    if (ret != 0) {
+        tsk_safe_free(self->parent_edge_index);
+        self->parent_edge_index = NULL;
+        tsk_safe_free(self->parent_index_range);
+        self->parent_index_range = NULL;
+    }
+    tsk_safe_free(child_counts);
+    return ret;
+}
+
+// FIXME No point adding mutations who are above nodes we have no interest in.
+/*
+ * Group the mutations at sites in [site_start, site_stop) by node.
+ *
+ * After success, the mutations on node u occupy entries
+ * node_mutation_offsets[u] .. node_mutation_offsets[u + 1] - 1 of
+ * node_mutation_sites (site index relative to site_start) and
+ * node_mutation_states (derived state byte). Within a node they are stored in
+ * reverse mutation-table order.
+ *
+ * Returns 0, TSK_ERR_NO_MEMORY, or TSK_ERR_UNSUPPORTED_OPERATION when a derived
+ * state is not exactly one byte <= 0x7F or a count exceeds INT32_MAX. Arrays
+ * allocated before an error are left for the caller to release with
+ * tsk_haplotype_free().
+ */
+static int
+tsk_haplotype_build_mutation_index(tsk_haplotype_t *self)
+{
+    int ret = 0;
+    tsk_size_t j;
+    const tsk_table_collection_t *tables = self->tree_sequence->tables;
+    const tsk_mutation_table_t *mutations = &tables->mutations;
+    int32_t *counts = NULL;
+    tsk_size_t total_mutations = 0;
+    tsk_id_t site_start = self->site_start;
+    tsk_id_t site_stop = self->site_stop;
+
+    counts = tsk_calloc(self->num_nodes, sizeof(*counts));
+    if (self->num_nodes > 0 && counts == NULL) {
+        ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
+        goto out;
+    }
+
+    /* Pass 1: number of in-range mutations per node. */
+    for (j = 0; j < mutations->num_rows; j++) {
+        tsk_id_t node = mutations->node[j];
+        tsk_id_t site = mutations->site[j];
+        if (site < site_start || site >= site_stop) {
+            continue;
+        }
+        if (node >= 0 && node < (tsk_id_t) self->num_nodes) {
+            if (counts[node] == INT32_MAX) {
+                ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION);
+                goto out;
+            }
+            counts[node]++;
+        }
+    }
+
+    self->node_mutation_offsets
+        = tsk_malloc((self->num_nodes + 1) * sizeof(*self->node_mutation_offsets));
+    if (self->node_mutation_offsets == NULL) {
+        ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
+        goto out;
+    }
+    /* Prefix sums of the counts; offsets are stored as int32_t. */
+    self->node_mutation_offsets[0] = 0;
+    for (j = 0; j < self->num_nodes; j++) {
+        total_mutations += (tsk_size_t) counts[j];
+        if (total_mutations > INT32_MAX) {
+            ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION);
+            goto out;
+        }
+        self->node_mutation_offsets[j + 1] = (int32_t) total_mutations;
+    }
+
+    self->node_mutation_sites
+        = tsk_malloc(total_mutations * sizeof(*self->node_mutation_sites));
+    self->node_mutation_states
+        = tsk_malloc(total_mutations * sizeof(*self->node_mutation_states));
+    if ((total_mutations > 0)
+        && (self->node_mutation_sites == NULL || self->node_mutation_states == NULL)) {
+        ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
+        goto out;
+    }
+
+    /* Reuse counts[] as the next free slot of each node. */
+    for (j = 0; j < self->num_nodes; j++) {
+        counts[j] = self->node_mutation_offsets[j];
+    }
+    /* Pass 2 walks the table backwards, so for a given node and site the
+     * mutation with the highest row index is stored first. Decoding keeps only
+     * the first mutation it meets for a site. */
+    for (j = mutations->num_rows; j > 0; j--) {
+        tsk_size_t mut_index = j - 1;
+        tsk_id_t node = mutations->node[mut_index];
+        tsk_id_t site = mutations->site[mut_index];
+        if (site < site_start || site >= site_stop) {
+            continue;
+        }
+        if (node >= 0 && node < (tsk_id_t) self->num_nodes) {
+            tsk_size_t start = mutations->derived_state_offset[mut_index];
+            tsk_size_t stop = mutations->derived_state_offset[mut_index + 1];
+            tsk_size_t length = stop - start;
+            uint8_t allele;
+
+            if (length != 1) {
+                ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION);
+                goto out;
+            }
+            allele = (uint8_t) mutations->derived_state[start];
+            /* The output buffer is int8_t, so only 7-bit states are accepted. */
+            if (allele > 0x7F) {
+                ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION);
+                goto out;
+            }
+            self->node_mutation_sites[counts[node]] = (int32_t)(site - site_start);
+            self->node_mutation_states[counts[node]] = allele;
+            counts[node]++;
+        }
+    }
+
+out:
+    tsk_safe_free(counts);
+    return ret;
+}
+
+// FIXME Not sure this is even needed
+/*
+ * Copy the ancestral state byte of each site in [site_start, site_stop) into
+ * ancestral_states (indexed relative to site_start).
+ *
+ * Returns 0, TSK_ERR_NO_MEMORY, or TSK_ERR_UNSUPPORTED_OPERATION when an
+ * ancestral state is not exactly one byte <= 0x7F. On error the array is freed.
+ */
+static int
+tsk_haplotype_build_ancestral_states(tsk_haplotype_t *self)
+{
+    int ret = 0;
+    const tsk_table_collection_t *tables = self->tree_sequence->tables;
+    const tsk_site_table_t *sites = &tables->sites;
+    tsk_id_t site_start = self->site_start;
+    tsk_size_t j;
+
+    if (self->num_sites == 0) {
+        self->ancestral_states = NULL;
+        return 0;
+    }
+
+    self->ancestral_states
+        = tsk_malloc((tsk_size_t) self->num_sites * sizeof(*self->ancestral_states));
+    if (self->ancestral_states == NULL) {
+        return tsk_trace_error(TSK_ERR_NO_MEMORY);
+    }
+
+    for (j = 0; j < (tsk_size_t) self->num_sites; j++) {
+        tsk_id_t site = site_start + (tsk_id_t) j;
+        tsk_size_t start = sites->ancestral_state_offset[site];
+        tsk_size_t stop = sites->ancestral_state_offset[site + 1];
+        tsk_size_t length = stop - start;
+        uint8_t allele;
+        if (length != 1) {
+            ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION);
+            goto out;
+        }
+        allele = (uint8_t) sites->ancestral_state[start];
+        if (allele > 0x7F) {
+            ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION);
+            goto out;
+        }
+        self->ancestral_states[j] = allele;
+    }
+
+out:
+    if (ret != 0) {
+        tsk_safe_free(self->ancestral_states);
+        self->ancestral_states = NULL;
+    }
+    return ret;
+}
+
+/*
+ * Convert each edge's [left, right) coordinates into a half-open range of site
+ * indexes relative to site_start, stored in edge_start_index / edge_end_index.
+ * The indexes come from tsk_search_sorted() on the site positions and are
+ * clamped to num_sites. With no sites every edge gets the empty range [0, 0).
+ *
+ * Returns 0 or TSK_ERR_NO_MEMORY; on error both arrays are freed.
+ */
+static int
+tsk_haplotype_build_edge_intervals(tsk_haplotype_t *self)
+{
+    int ret = 0;
+    const tsk_table_collection_t *tables = self->tree_sequence->tables;
+    const tsk_edge_table_t *edges = &tables->edges;
+    const double *positions = tables->sites.position + self->site_start;
+    tsk_size_t num_edges = edges->num_rows;
+    tsk_size_t j;
+
+    if (num_edges == 0) {
+        self->edge_start_index = NULL;
+        self->edge_end_index = NULL;
+        return 0;
+    }
+
+    self->edge_start_index = tsk_malloc(num_edges * sizeof(*self->edge_start_index));
+    self->edge_end_index = tsk_malloc(num_edges * sizeof(*self->edge_end_index));
+    if (self->edge_start_index == NULL || self->edge_end_index == NULL) {
+        ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
+        goto out;
+    }
+
+    if (self->num_sites == 0) {
+        for (j = 0; j < num_edges; j++) {
+            self->edge_start_index[j] = 0;
+            self->edge_end_index[j] = 0;
+        }
+        goto out;
+    }
+
+    for (j = 0; j < num_edges; j++) {
+        double left = edges->left[j];
+        double right = edges->right[j];
+        tsk_size_t start
+            = tsk_search_sorted(positions, (tsk_size_t) self->num_sites, left);
+        tsk_size_t end
+            = tsk_search_sorted(positions, (tsk_size_t) self->num_sites, right);
+        if (start > (tsk_size_t) self->num_sites) {
+            start = (tsk_size_t) self->num_sites;
+        }
+        if (end > (tsk_size_t) self->num_sites) {
+            end = (tsk_size_t) self->num_sites;
+        }
+        self->edge_start_index[j] = (int32_t) start;
+        self->edge_end_index[j] = (int32_t) end;
+    }
+
+out:
+    if (ret != 0) {
+        tsk_safe_free(self->edge_start_index);
+        tsk_safe_free(self->edge_end_index);
+        self->edge_start_index = NULL;
+        self->edge_end_index = NULL;
+    }
+    return ret;
+}
+
+/*
+ * Allocate the working bitset (one bit per site, 64 sites per word), its
+ * per-word set-bit counts, and the "initial" copies that have exactly the first
+ * num_sites bits set. The working copies are then reset from the initial ones.
+ *
+ * Returns 0 or TSK_ERR_NO_MEMORY. Arrays allocated before an error are left for
+ * the caller to release with tsk_haplotype_free().
+ */
+static int
+tsk_haplotype_alloc_bitset(tsk_haplotype_t *self)
+{
+    tsk_size_t j;
+
+    self->num_bit_words = ((tsk_size_t) self->num_sites + 63) >> 6;
+    if (self->num_bit_words == 0) {
+        self->unresolved_bits = NULL;
+        self->initial_bits = NULL;
+        self->unresolved_counts = NULL;
+        self->initial_counts = NULL;
+        return 0;
+    }
+    self->unresolved_bits
+        = tsk_malloc(self->num_bit_words * sizeof(*self->unresolved_bits));
+    self->initial_bits = tsk_malloc(self->num_bit_words * sizeof(*self->initial_bits));
+    self->unresolved_counts
+        = tsk_malloc(self->num_bit_words * sizeof(*self->unresolved_counts));
+    self->initial_counts
+        = tsk_malloc(self->num_bit_words * sizeof(*self->initial_counts));
+    if (self->unresolved_bits == NULL || self->initial_bits == NULL
+        || self->unresolved_counts == NULL || self->initial_counts == NULL) {
+        return tsk_trace_error(TSK_ERR_NO_MEMORY);
+    }
+    for (j = 0; j < self->num_bit_words; j++) {
+        uint64_t word = UINT64_MAX;
+        /* The last word is only partly used when num_sites is not a multiple
+         * of 64. */
+        if (j == self->num_bit_words - 1 && (tsk_size_t) self->num_sites % 64 != 0) {
+            uint32_t bits = (uint32_t)((tsk_size_t) self->num_sites & 63);
+            word = (UINT64_C(1) << bits) - 1;
+        }
+        self->initial_bits[j] = word;
+        self->initial_counts[j] = tsk_haplotype_popcount64(word);
+    }
+    tsk_haplotype_reset_bitset(self);
+    return 0;
+}
+
+/*
+ * Initialise a haplotype decoder for the sites with IDs in
+ * [site_start, site_stop) of tree_sequence.
+ *
+ * The decoder keeps a pointer to tree_sequence and to its site positions, so
+ * the tree sequence must outlive the decoder. *self is zeroed before anything
+ * else happens, so tsk_haplotype_free() is safe after any return from this
+ * function.
+ *
+ * Returns 0 on success, or:
+ *  - TSK_ERR_BAD_PARAM_VALUE if tree_sequence is NULL or the site range is not
+ *    0 <= site_start <= site_stop <= number of sites;
+ *  - TSK_ERR_UNSUPPORTED_OPERATION if an ancestral or derived state in range is
+ *    not a single byte <= 0x7F, or an index exceeds INT32_MAX;
+ *  - TSK_ERR_NO_MEMORY.
+ * On any error after argument validation the decoder is freed here.
+ */
+int
+tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence,
+    tsk_id_t site_start, tsk_id_t site_stop)
+{
+    int ret = 0;
+    const tsk_table_collection_t *tables;
+    const tsk_site_table_t *sites;
+    tsk_size_t total_sites;
+
+    /* Zero first: callers free the decoder after a failed init, and that must
+     * not see uninitialised pointers. */
+    tsk_memset(self, 0, sizeof(*self));
+    if (tree_sequence == NULL) {
+        return tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE);
+    }
+
+    self->tree_sequence = tree_sequence;
+
+    tables = tree_sequence->tables;
+    sites = &tables->sites;
+    total_sites = sites->num_rows;
+
+    if (site_start < 0 || site_stop < site_start || site_stop > (tsk_id_t) total_sites) {
+        ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE);
+        goto out;
+    }
+
+    self->site_start = (int32_t) site_start;
+    self->site_stop = (int32_t) site_stop;
+    self->num_sites = (int32_t)(site_stop - site_start);
+    self->num_nodes = tables->nodes.num_rows;
+    self->num_edges = tables->edges.num_rows;
+    self->site_positions = sites->position + site_start;
+
+    ret = tsk_haplotype_build_parent_index(self);
+    if (ret != 0) {
+        goto out;
+    }
+    ret = tsk_haplotype_build_mutation_index(self);
+    if (ret != 0) {
+        goto out;
+    }
+    ret = tsk_haplotype_build_ancestral_states(self);
+    if (ret != 0) {
+        goto out;
+    }
+    ret = tsk_haplotype_build_edge_intervals(self);
+    if (ret != 0) {
+        goto out;
+    }
+    ret = tsk_haplotype_alloc_bitset(self);
+    if (ret != 0) {
+        goto out;
+    }
+    if (self->num_edges > 0) {
+        /* Scratch space for decode: every array has one slot per edge. */
+        self->edge_stack = tsk_malloc(self->num_edges * sizeof(*self->edge_stack));
+        self->stack_interval_start
+            = tsk_malloc(self->num_edges * sizeof(*self->stack_interval_start));
+        self->stack_interval_end
+            = tsk_malloc(self->num_edges * sizeof(*self->stack_interval_end));
+        self->parent_interval_start
+            = tsk_malloc(self->num_edges * sizeof(*self->parent_interval_start));
+        self->parent_interval_end
+            = tsk_malloc(self->num_edges * sizeof(*self->parent_interval_end));
+        if (self->edge_stack == NULL || self->stack_interval_start == NULL
+            || self->stack_interval_end == NULL || self->parent_interval_start == NULL
+            || self->parent_interval_end == NULL) {
+            ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
+            goto out;
+        }
+    }
+
+    self->initialised = true;
+
+out:
+    if (ret != 0) {
+        tsk_haplotype_free(self);
+    }
+    return ret;
+}
+
+/*
+ * Decode the haplotype of `node` over the decoder's site range.
+ *
+ * haplotype must have room for num_sites entries; entry i receives the allele
+ * byte for site site_start + i. Every entry starts as the ancestral state and
+ * is overwritten by the first mutation found for that site while walking from
+ * `node` up through its parent edges.
+ *
+ * Returns 0 on success, TSK_ERR_BAD_PARAM_VALUE if self or haplotype is NULL or
+ * the decoder is not initialised, or TSK_ERR_NODE_OUT_OF_BOUNDS if node is not
+ * in [0, num_nodes). The function uses the scratch buffers in *self.
+ */
+int
+tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype)
+{
+    tsk_size_t stack_top = 0;
+    const tsk_table_collection_t *tables;
+    const tsk_edge_table_t *edges;
+    const tsk_id_t *edge_parent;
+    int32_t interval_start, interval_end;
+    int32_t mut_start, mut_end;
+    tsk_size_t idx;
+    tsk_size_t parent_count;
+    uint64_t *bits;
+
+    if (self == NULL || haplotype == NULL) {
+        return tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE);
+    }
+    if (!self->initialised) {
+        return tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE);
+    }
+    if (node < 0 || node >= (tsk_id_t) self->num_nodes) {
+        return tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS);
+    }
+    if (self->num_sites == 0) {
+        return 0;
+    }
+
+    tables = self->tree_sequence->tables;
+    edges = &tables->edges;
+    edge_parent = edges->parent;
+    bits = self->unresolved_bits;
+
+    // Start every site at its ancestral state
+    for (idx = 0; idx < (tsk_size_t) self->num_sites; idx++) {
+        haplotype[idx] = (int8_t) self->ancestral_states[idx];
+    }
+    // Mark every site as unresolved in the working bitset
+    tsk_haplotype_reset_bitset(self);
+
+    mut_start = self->node_mutation_offsets[node];
+    mut_end = self->node_mutation_offsets[node + 1];
+    // Apply mutations above this node
+    for (int32_t m = mut_start; m < mut_end; m++) {
+        int32_t site = self->node_mutation_sites[m];
+        if (site >= 0 && site < self->num_sites
+            && tsk_haplotype_bitset_test(bits, (tsk_size_t) site)) {
+            haplotype[site] = (int8_t) self->node_mutation_states[m];
+            tsk_haplotype_bitset_clear(self, (tsk_size_t) site);
+        }
+    }
+
+    int32_t child_start = 0;
+    int32_t child_stop = 0;
+    if (self->parent_index_range != NULL) {
+        int32_t range_offset = node * 2;
+        child_start = self->parent_index_range[range_offset];
+        child_stop = self->parent_index_range[range_offset + 1];
+    }
+    // Push all edges from this node (that are still relevant to resolving sites) onto
+    // the stack
+    for (int32_t i = child_start; i < child_stop; i++) {
+        tsk_id_t edge = self->parent_edge_index[i];
+        int32_t start = self->edge_start_index[edge];
+        int32_t end = self->edge_end_index[edge];
+        if (start >= end) {
+            continue;
+        }
+        tsk_size_t uncovered_idx;
+        // interval_count is 0 here, so this only asks whether any site in
+        // [start, end) is still unresolved
+        if (tsk_haplotype_find_next_uncovered(self, (tsk_size_t) start, (tsk_size_t) end,
+                self->parent_interval_start, self->parent_interval_end, 0,
+                &uncovered_idx)) {
+            self->edge_stack[stack_top] = edge;
+            self->stack_interval_start[stack_top] = start;
+            self->stack_interval_end[stack_top] = end;
+            stack_top++;
+        }
+    }
+
+    // Now process the stack until we run out of edges or have resolved all sites
+    while (stack_top > 0) {
+        stack_top--;
+        tsk_id_t edge = self->edge_stack[stack_top];
+        tsk_id_t ancestor = edge_parent[edge];
+        interval_start = self->stack_interval_start[stack_top];
+        interval_end = self->stack_interval_end[stack_top];
+
+        // Apply mutations above this ancestor
+        if (ancestor >= 0) {
+            mut_start = self->node_mutation_offsets[ancestor];
+            mut_end = self->node_mutation_offsets[ancestor + 1];
+            for (int32_t m = mut_start; m < mut_end; m++) {
+                int32_t site = self->node_mutation_sites[m];
+                if (site >= interval_start && site < interval_end
+                    && tsk_haplotype_bitset_test(bits, (tsk_size_t) site)) {
+                    haplotype[site] = (int8_t) self->node_mutation_states[m];
+                    tsk_haplotype_bitset_clear(self, (tsk_size_t) site);
+                }
+            }
+        }
+
+        // Going up the tree push all edges from this ancestor (that are still relevant
+        // to resolving sites)
+        parent_count = 0;
+        if (ancestor >= 0 && self->parent_index_range != NULL) {
+            int32_t range_offset = ancestor * 2;
+            child_start = self->parent_index_range[range_offset];
+            child_stop = self->parent_index_range[range_offset + 1];
+            for (int32_t i = child_start; i < child_stop; i++) {
+                tsk_id_t parent_edge = self->parent_edge_index[i];
+                int32_t parent_start = self->edge_start_index[parent_edge];
+                int32_t parent_end = self->edge_end_index[parent_edge];
+                // Clip the parent edge to the interval we arrived through
+                if (parent_start < interval_start) {
+                    parent_start = interval_start;
+                }
+                if (parent_end > interval_end) {
+                    parent_end = interval_end;
+                }
+                if (parent_start >= parent_end) {
+                    continue;
+                }
+                tsk_size_t uncovered_idx;
+                if (tsk_haplotype_find_next_uncovered(self, (tsk_size_t) parent_start,
+                        (tsk_size_t) parent_end, self->parent_interval_start,
+                        self->parent_interval_end, parent_count, &uncovered_idx)) {
+                    self->edge_stack[stack_top] = parent_edge;
+                    self->stack_interval_start[stack_top] = parent_start;
+                    self->stack_interval_end[stack_top] = parent_end;
+                    stack_top++;
+                    self->parent_interval_start[parent_count] = parent_start;
+                    self->parent_interval_end[parent_count] = parent_end;
+                    parent_count++;
+                }
+            }
+        } else {
+            child_start = 0;
+            child_stop = 0;
+        }
+
+        // Clear out any sites that are still unresolved in this interval but not covered
+        // by any parent edges
+        tsk_size_t uncovered_idx;
+        while (tsk_haplotype_find_next_uncovered(self, (tsk_size_t) interval_start,
+            (tsk_size_t) interval_end, self->parent_interval_start,
+            self->parent_interval_end, parent_count, &uncovered_idx)) {
+            tsk_size_t word_index = uncovered_idx >> 6;
+            uint64_t mask = UINT64_C(1) << (uncovered_idx & 63);
+            tsk_haplotype_clear_word_bit(self, word_index, mask);
+        }
+    }
+
+    // Leave the working bitset empty; the next decode refills it from the
+    // initial copy before use
+    for (tsk_size_t w = 0; w < self->num_bit_words; w++) {
+        self->unresolved_bits[w] = 0;
+        self->unresolved_counts[w] = 0;
+    }
+
+    return 0;
+}
+
+/*
+ * Release everything owned by the decoder and mark it uninitialised. The tree
+ * sequence and the site position array are borrowed and are not freed.
+ * Accepts NULL. Always returns 0.
+ */
+int
+tsk_haplotype_free(tsk_haplotype_t *self)
+{
+    if (self == NULL) {
+        return 0;
+    }
+    tsk_safe_free(self->ancestral_states);
+    tsk_safe_free(self->node_mutation_offsets);
+    tsk_safe_free(self->node_mutation_sites);
+    tsk_safe_free(self->node_mutation_states);
+    tsk_safe_free(self->parent_edge_index);
+    tsk_safe_free(self->parent_index_range);
+    tsk_safe_free(self->edge_start_index);
+    tsk_safe_free(self->edge_end_index);
+    tsk_safe_free(self->edge_stack);
+    tsk_safe_free(self->stack_interval_start);
+    tsk_safe_free(self->stack_interval_end);
+    tsk_safe_free(self->parent_interval_start);
+    tsk_safe_free(self->parent_interval_end);
+    tsk_safe_free(self->unresolved_bits);
+    tsk_safe_free(self->initial_bits);
+    tsk_safe_free(self->unresolved_counts);
+    tsk_safe_free(self->initial_counts);
+    self->tree_sequence = NULL;
+    self->site_positions = NULL;
+    self->initialised = false;
+    return 0;
+}
+
 /* ======================================================== *
  * Variant generator
  * ======================================================== */
diff --git a/c/tskit/genotypes.h b/c/tskit/genotypes.h
index 8c3b769e5a..993dbe2f08 100644
--- a/c/tskit/genotypes.h
+++ b/c/tskit/genotypes.h
@@ -86,6 +86,35 @@ typedef struct {
     tsk_variant_t variant;
 } tsk_vargen_t;
 
+/* State of a haplotype decoder; set up by tsk_haplotype_init() and released by
+ * tsk_haplotype_free(). All site indexes stored here are relative to site_start. */
+typedef struct {
+    /* Borrowed; must outlive the decoder. */
+    const tsk_treeseq_t *tree_sequence;
+    tsk_size_t num_nodes;
+    tsk_size_t num_edges;
+    /* Decoded site IDs are [site_start, site_stop); num_sites is their count. */
+    int32_t site_start;
+    int32_t site_stop;
+    int32_t num_sites;
+    /* Borrowed pointer into the site table's positions, offset by site_start. */
+    const double *site_positions;
+    /* One ancestral state byte per site. */
+    uint8_t *ancestral_states;
+    /* Mutations grouped by node: node u owns entries
+     * node_mutation_offsets[u] .. node_mutation_offsets[u + 1] - 1 of the
+     * sites/states arrays. */
+    int32_t *node_mutation_offsets;
+    int32_t *node_mutation_sites;
+    uint8_t *node_mutation_states;
+    /* Edge IDs grouped by child node; node u owns entries
+     * parent_index_range[2 * u] .. parent_index_range[2 * u + 1] - 1. */
+    tsk_id_t *parent_edge_index;
+    int32_t *parent_index_range;
+    /* Per edge: half-open range of site indexes derived from its left/right. */
+    int32_t *edge_start_index;
+    int32_t *edge_end_index;
+    /* Decode scratch: stack of pending edges with their clipped intervals. */
+    tsk_id_t *edge_stack;
+    int32_t *stack_interval_start;
+    int32_t *stack_interval_end;
+    /* Decode scratch: intervals of the parent edges pushed for one ancestor. */
+    int32_t *parent_interval_start;
+    int32_t *parent_interval_end;
+    /* Working bitset of unresolved sites and the template it is reset from. */
+    uint64_t *unresolved_bits;
+    uint64_t *initial_bits;
+    tsk_size_t num_bit_words;
+    /* Set-bit count of each word of the two bitsets above. */
+    uint32_t *unresolved_counts;
+    uint32_t *initial_counts;
+    /* True once tsk_haplotype_init() has succeeded. */
+    bool initialised;
+} tsk_haplotype_t;
+
 /**
 @defgroup VARIANT_API_GROUP Variant API for obtaining genotypes.
@{ @@ -179,6 +208,10 @@ void tsk_variant_print_state(const tsk_variant_t *self, FILE *out); /** @} */ /* Deprecated vargen methods (since C API v1.0) */ +int tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence, + tsk_id_t site_start, tsk_id_t site_stop); +int tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype); +int tsk_haplotype_free(tsk_haplotype_t *self); int tsk_vargen_init(tsk_vargen_t *self, const tsk_treeseq_t *tree_sequence, const tsk_id_t *samples, tsk_size_t num_samples, const char **alleles, tsk_flags_t options); diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 23ab663538..10983613fb 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -154,6 +154,12 @@ typedef struct { tsk_tree_t *tree; } Tree; +typedef struct { + PyObject_HEAD + TreeSequence *tree_sequence; + tsk_haplotype_t *haplotype; +} Haplotype; + typedef struct { PyObject_HEAD TreeSequence *tree_sequence; @@ -10594,6 +10600,132 @@ static PyTypeObject TreeType = { // clang-format on }; +/*=================================================================== + * Haplotype + *=================================================================== + */ + +/* Forward declaration */ +static PyTypeObject HaplotypeType; + +static int +Haplotype_check_state(Haplotype *self) +{ + int ret = 0; + if (self->haplotype == NULL) { + PyErr_SetString(PyExc_SystemError, "haplotype not initialised"); + ret = -1; + } + return ret; +} + +static void +Haplotype_dealloc(Haplotype *self) +{ + if (self->haplotype != NULL) { + tsk_haplotype_free(self->haplotype); + PyMem_Free(self->haplotype); + self->haplotype = NULL; + } + Py_XDECREF(self->tree_sequence); + Py_TYPE(self)->tp_free((PyObject *) self); +} + +static int +Haplotype_init(Haplotype *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = { "tree_sequence", "site_start", "site_stop", NULL }; + TreeSequence *tree_sequence = NULL; + Py_ssize_t 
site_start; + Py_ssize_t site_stop; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!nn", kwlist, &TreeSequenceType, + &tree_sequence, &site_start, &site_stop)) { + goto out; + } + + self->haplotype = PyMem_Malloc(sizeof(*self->haplotype)); + if (self->haplotype == NULL) { + PyErr_NoMemory(); + goto out; + } + + self->tree_sequence = tree_sequence; + Py_INCREF(tree_sequence); + + err = tsk_haplotype_init(self->haplotype, tree_sequence->tree_sequence, + (tsk_id_t) site_start, (tsk_id_t) site_stop); + if (err != 0) { + handle_library_error(err); + goto out; + } + + ret = 0; +out: + if (ret != 0) { + if (self->haplotype != NULL) { + tsk_haplotype_free(self->haplotype); + PyMem_Free(self->haplotype); + self->haplotype = NULL; + } + Py_XDECREF(self->tree_sequence); + self->tree_sequence = NULL; + } + return ret; +} + +static PyObject * +Haplotype_decode(Haplotype *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + tsk_id_t node; + + if (Haplotype_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O&", &tsk_id_converter, &node)) { + goto out; + } + + ret = PyBytes_FromStringAndSize(NULL, (Py_ssize_t) self->haplotype->num_sites); + if (ret == NULL) { + goto out; + } + err = tsk_haplotype_decode(self->haplotype, node, (int8_t *) PyBytes_AS_STRING(ret)); + if (err != 0) { + handle_library_error(err); + Py_CLEAR(ret); + goto out; + } +out: + return ret; +} + +static PyMethodDef Haplotype_methods[] = { + { .ml_name = "decode", + .ml_meth = (PyCFunction) Haplotype_decode, + .ml_flags = METH_VARARGS, + .ml_doc = "Decode the haplotype for the specified node." 
}, + { NULL }, +}; + +static PyTypeObject HaplotypeType = { + // clang-format off + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "_tskit.Haplotype", + .tp_basicsize = sizeof(Haplotype), + .tp_dealloc = (destructor) Haplotype_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Low-level haplotype decoder", + .tp_methods = Haplotype_methods, + .tp_init = (initproc) Haplotype_init, + .tp_new = PyType_GenericNew, + // clang-format on +}; + /*=================================================================== * Variant *=================================================================== @@ -11924,6 +12056,13 @@ PyInit__tskit(void) Py_INCREF(&TreeType); PyModule_AddObject(module, "Tree", (PyObject *) &TreeType); + /* Haplotype type */ + if (PyType_Ready(&HaplotypeType) < 0) { + return NULL; + } + Py_INCREF(&HaplotypeType); + PyModule_AddObject(module, "Haplotype", (PyObject *) &HaplotypeType); + /* Variant type */ if (PyType_Ready(&VariantType) < 0) { return NULL; diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py index 38aede1894..4e629afa7f 100644 --- a/python/tests/test_jit.py +++ b/python/tests/test_jit.py @@ -671,3 +671,163 @@ def ancestral_edges_tskit(ts, start_node): a1 = ancestral_edges(numba_ts, u) a2 = ancestral_edges_tskit(ts, u) nt.assert_array_equal(a1, a2) + + +def build_alignment_example(): + tables = tskit.TableCollection(sequence_length=3) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # 2 + tables.edges.add_row(0, 3, parent=2, child=0) + tables.edges.add_row(0, 3, parent=2, child=1) + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=2, derived_state="G") + tables.sites.add_row(1, "A") + tables.mutations.add_row(site=1, node=0, derived_state="C") + tables.sites.add_row(2, "A") + tables.mutations.add_row(site=2, node=1, derived_state="T") + tables.sort() + return tables.tree_sequence() + + +def 
build_missing_alignment_example(): + tables = tskit.TableCollection(sequence_length=3) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 isolated + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # ancestor for sample 1 + tables.edges.add_row(0, 3, parent=2, child=1) + tables.sites.add_row(0, "A") + tables.sites.add_row(1, "A") + tables.mutations.add_row(site=1, node=2, derived_state="T") + tables.sites.add_row(2, "A") + tables.sort() + return tables.tree_sequence() + + +def build_internal_sample_example(): + tables = tskit.TableCollection(sequence_length=3) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=1) # 1 internal sample + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 2 + tables.nodes.add_row(flags=0, time=2) # 3 root + tables.edges.add_row(0, 3, parent=1, child=0) + tables.edges.add_row(0, 3, parent=3, child=1) + tables.edges.add_row(0, 3, parent=3, child=2) + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=3, derived_state="G") + tables.sites.add_row(1, "A") + tables.mutations.add_row(site=1, node=1, derived_state="C") + tables.sites.add_row(2, "A") + tables.mutations.add_row(site=2, node=0, derived_state="T") + tables.sort() + return tables.tree_sequence() + + +def build_overlapping_edges_example(): + tables = tskit.TableCollection(sequence_length=4) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # 2 + tables.nodes.add_row(flags=0, time=1) # 3 + tables.nodes.add_row(flags=0, time=2) # 4 root + tables.edges.add_row(0, 2, parent=2, child=0) + tables.edges.add_row(2, 4, parent=3, child=0) + tables.edges.add_row(0, 4, parent=3, child=1) + tables.edges.add_row(0, 4, parent=4, child=2) + tables.edges.add_row(0, 4, parent=4, child=3) + tables.sites.add_row(1, "A") + 
tables.mutations.add_row(site=0, node=2, derived_state="G") + tables.sites.add_row(3, "A") + tables.mutations.add_row(site=1, node=3, derived_state="T") + tables.sort() + return tables.tree_sequence() + + +def build_deep_mutation_example(): + tables = tskit.TableCollection(sequence_length=2) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # 2 + tables.nodes.add_row(flags=0, time=2) # 3 + tables.nodes.add_row(flags=0, time=3) # 4 root + tables.edges.add_row(0, 2, parent=2, child=0) + tables.edges.add_row(0, 2, parent=4, child=1) + tables.edges.add_row(0, 2, parent=3, child=2) + tables.edges.add_row(0, 2, parent=4, child=3) + tables.sites.add_row(0, "A") + m0 = tables.mutations.add_row(site=0, node=4, derived_state="C") + m1 = tables.mutations.add_row(site=0, node=3, derived_state="G", parent=m0) + tables.mutations.add_row(site=0, node=2, derived_state="T", parent=m1) + tables.sort() + return tables.tree_sequence() + + +def build_multiple_roots_example(): + tables = tskit.TableCollection(sequence_length=3) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # 2 root A + tables.nodes.add_row(flags=0, time=1) # 3 root B + tables.edges.add_row(0, 3, parent=2, child=0) + tables.edges.add_row(0, 3, parent=3, child=1) + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=2, derived_state="G") + tables.sites.add_row(2, "A") + tables.mutations.add_row(site=1, node=3, derived_state="T") + tables.sort() + return tables.tree_sequence() + + +def _check_alignments(ts): + expected = list(ts.haplotypes()) + sites = list(ts.sites()) + adjusted = [] + for hap in expected: + chars = list(hap) + for j, c in enumerate(chars): + if c == "N": + chars[j] = sites[j].ancestral_state + adjusted.append("".join(chars)) + numba_ts = 
jit_numba.jitwrap(ts) + observed = list(jit_numba.alignments(numba_ts)) + samples = [node for node, _ in observed] + haplotypes = [hap for _, hap in observed] + assert samples == list(ts.samples()) + assert haplotypes == adjusted + + +def test_jit_alignments_basic(): + ts = build_alignment_example() + _check_alignments(ts) + + +def test_jit_alignments_missing_data(): + ts = build_missing_alignment_example() + _check_alignments(ts) + + +def test_jit_alignments_internal_sample(): + ts = build_internal_sample_example() + _check_alignments(ts) + + +def test_jit_alignments_overlapping_edges(): + ts = build_overlapping_edges_example() + _check_alignments(ts) + + +def test_jit_alignments_deep_mutations(): + ts = build_deep_mutation_example() + _check_alignments(ts) + + +def test_jit_alignments_multiple_roots(): + ts = build_multiple_roots_example() + _check_alignments(ts) + + +def test_jit_alignments_msprime_example(): + ts = msprime.sim_ancestry(5, sequence_length=8, ploidy=1, random_seed=5) + ts = msprime.sim_mutations(ts, rate=0.5, random_seed=13) + assert ts.discrete_genome + _check_alignments(ts) diff --git a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py index 40534431d1..c5c0059183 100644 --- a/python/tskit/jit/numba.py +++ b/python/tskit/jit/numba.py @@ -672,3 +672,298 @@ def jitwrap(ts): ) return numba_ts + + +@numba.njit +def _bitset_init(bitset, num_sites): + # Initialise all bits to 1 (meaning "unresolved") and mask any unused bits + # in the final word when the number of sites is not a multiple of 64. 
+ n_words = bitset.shape[0] + all_bits = np.uint64((1 << 64) - 1) + for w in range(n_words): + bitset[w] = all_bits + if n_words > 0: + excess = n_words * 64 - num_sites + if excess > 0: + mask = all_bits >> excess + bitset[n_words - 1] = mask + + +@numba.njit +def _bitset_clear(bitset, idx): + word = idx >> 6 + bit = np.uint64(1) << (idx & 63) + bitset[word] &= ~bit + + +@numba.njit +def _bitset_is_set(bitset, idx): + word = idx >> 6 + bit = np.uint64(1) << (idx & 63) + return (bitset[word] & bit) != 0 + + +@numba.njit +def _ctz64(x): + count = 0 + while (x & 1) == 0: + x >>= 1 + count += 1 + return count + + +@numba.njit +def _bitset_next(bitset, start, num_sites): + # Return the index of the first set bit >= start, or num_sites if none. + if start >= num_sites: + return num_sites + n_words = bitset.shape[0] + word = start >> 6 + offset = start & 63 + if word >= n_words: + return num_sites + mask = np.uint64(-1) << offset + value = bitset[word] & mask + while value == 0: + word += 1 + if word >= n_words: + return num_sites + value = bitset[word] + return (word << 6) + _ctz64(value) + + +def _build_node_mutation_index(numba_ts): + num_nodes = numba_ts.num_nodes + num_mutations = numba_ts.num_mutations + + counts = np.zeros(num_nodes, dtype=np.int32) + mutations_node = numba_ts.mutations_node + mutations_site = numba_ts.mutations_site + mutations_derived_state = numba_ts.mutations_derived_state + + for mut_id in range(num_mutations - 1, -1, -1): + node = mutations_node[mut_id] + if 0 <= node < num_nodes: + counts[node] += 1 + + offsets = np.zeros(num_nodes + 1, dtype=np.int32) + total = 0 + for u in range(num_nodes): + offsets[u] = total + total += counts[u] + offsets[num_nodes] = total + + node_sites = np.empty(total, dtype=np.int32) + node_alleles = np.empty(total, dtype=np.uint8) + insert_pos = offsets.copy() + + for mut_id in range(num_mutations - 1, -1, -1): + node = mutations_node[mut_id] + if 0 <= node < num_nodes: + site = mutations_site[mut_id] + allele = 
mutations_derived_state[mut_id] + if len(allele) != 1: + raise ValueError("Expected single-character derived alleles") + pos = insert_pos[node] + node_sites[pos] = site + node_alleles[pos] = ord(allele[0]) + insert_pos[node] += 1 + + return node_sites, node_alleles, offsets + + +def _compute_next_site_index(numba_ts, sequence_length): + num_sites = numba_ts.num_sites + next_site = np.empty(sequence_length + 1, dtype=np.int32) + site_positions = numba_ts.sites_position.astype(np.int64) + j = 0 + for pos in range(sequence_length + 1): + while j < num_sites and site_positions[j] < pos: + j += 1 + next_site[pos] = j + return next_site + + +@numba.njit +def _node_haplotype( + numba_ts, + parent_index, + edge_start_index, + edge_end_index, + next_site_index, + node_mut_sites, + node_mut_alleles, + node_mut_offsets, + ancestral_codes, + node, + hap, + unresolved_bits, + stack_edges, + stack_start, + stack_end, + parent_interval_start, + parent_interval_end, +): + num_sites = ancestral_codes.shape[0] + if num_sites == 0: + return + for j in range(num_sites): + hap[j] = ancestral_codes[j] + _bitset_init(unresolved_bits, num_sites) + + edges_parent = numba_ts.edges_parent + edge_index = parent_index.edge_index + index_range = parent_index.index_range + + stack_top = 0 + + mut_start = node_mut_offsets[node] + mut_stop = node_mut_offsets[node + 1] + for m in range(mut_start, mut_stop): + site_idx = node_mut_sites[m] + if site_idx >= num_sites: + continue + if _bitset_is_set(unresolved_bits, site_idx): + hap[site_idx] = node_mut_alleles[m] + _bitset_clear(unresolved_bits, site_idx) + + start_edge, stop_edge = index_range[node, 0], index_range[node, 1] + for i in range(start_edge, stop_edge): + edge = edge_index[i] + start_idx = edge_start_index[edge] + end_idx = edge_end_index[edge] + if start_idx >= end_idx: + continue + unresolved = _bitset_next(unresolved_bits, start_idx, num_sites) + if unresolved < end_idx: + stack_edges[stack_top] = edge + stack_start[stack_top] = 
start_idx + stack_end[stack_top] = end_idx + stack_top += 1 + + while stack_top > 0: + stack_top -= 1 + edge = stack_edges[stack_top] + interval_start = stack_start[stack_top] + interval_end = stack_end[stack_top] + ancestor = edges_parent[edge] + + mut_start = node_mut_offsets[ancestor] + mut_stop = node_mut_offsets[ancestor + 1] + for m in range(mut_start, mut_stop): + site_idx = node_mut_sites[m] + if interval_start <= site_idx < interval_end: + if _bitset_is_set(unresolved_bits, site_idx): + hap[site_idx] = node_mut_alleles[m] + _bitset_clear(unresolved_bits, site_idx) + + parent_count = 0 + start_edge, stop_edge = index_range[ancestor, 0], index_range[ancestor, 1] + for idx in range(start_edge, stop_edge): + parent_edge = edge_index[idx] + parent_start = edge_start_index[parent_edge] + parent_end = edge_end_index[parent_edge] + if parent_start < interval_start: + parent_start = interval_start + if parent_end > interval_end: + parent_end = interval_end + if parent_start >= parent_end: + continue + unresolved = _bitset_next(unresolved_bits, parent_start, num_sites) + if unresolved < parent_end: + # Push this parent edge because it still covers unresolved sites. + stack_edges[stack_top] = parent_edge + stack_start[stack_top] = parent_start + stack_end[stack_top] = parent_end + stack_top += 1 + parent_interval_start[parent_count] = parent_start + parent_interval_end[parent_count] = parent_end + parent_count += 1 + + idx = _bitset_next(unresolved_bits, interval_start, num_sites) + while idx < interval_end: + needs_parent = False + for j in range(parent_count): + if parent_interval_start[j] <= idx < parent_interval_end[j]: + needs_parent = True + break + if needs_parent: + # This site is still covered by a parent edge that hasn't been + # processed yet, so leave it for that ancestor. + idx = _bitset_next(unresolved_bits, idx + 1, num_sites) + else: + # No higher ancestor will supply a mutation, so the ancestral + # allele stands and we can mark the site resolved. 
+ _bitset_clear(unresolved_bits, idx) + idx = _bitset_next(unresolved_bits, idx, num_sites) + + +def alignments(numba_ts): + num_sites = numba_ts.num_sites + sequence_length = numba_ts.sequence_length + if not float(sequence_length).is_integer(): + raise ValueError("This prototype requires discrete genomic coordinates") + sequence_length = int(sequence_length) + + ancestral_codes = np.empty(num_sites, dtype=np.uint8) + for site_id in range(num_sites): + allele = numba_ts.sites_ancestral_state[site_id] + if len(allele) != 1: + raise ValueError("Expected single-character ancestral alleles") + ancestral_codes[site_id] = ord(allele[0]) + + node_mut_sites, node_mut_alleles, node_mut_offsets = _build_node_mutation_index( + numba_ts + ) + next_site_index = _compute_next_site_index(numba_ts, sequence_length) + parent_index = numba_ts.parent_index() + edge_start_index = np.empty(numba_ts.num_edges, dtype=np.int32) + edge_end_index = np.empty(numba_ts.num_edges, dtype=np.int32) + seq_len = int(sequence_length) + for e in range(numba_ts.num_edges): + left = int(numba_ts.edges_left[e]) + if left < 0: + left = 0 + if left > seq_len: + left = seq_len + right = int(numba_ts.edges_right[e]) + if right < 0: + right = 0 + if right > seq_len: + right = seq_len + edge_start_index[e] = next_site_index[left] + edge_end_index[e] = next_site_index[right] + + hap = np.empty(num_sites, dtype=np.uint8) + bitset_size = (num_sites + 63) // 64 + unresolved_bits = np.empty(bitset_size, dtype=np.uint64) + stack_edges = np.empty(numba_ts.num_edges, dtype=np.int32) + stack_start = np.empty(numba_ts.num_edges, dtype=np.int32) + stack_end = np.empty(numba_ts.num_edges, dtype=np.int32) + parent_interval_start = np.empty(numba_ts.num_edges, dtype=np.int32) + parent_interval_end = np.empty(numba_ts.num_edges, dtype=np.int32) + + nodes_flags = numba_ts.nodes_flags + for node in range(numba_ts.num_nodes): + if nodes_flags[node] & NODE_IS_SAMPLE: + # Decode the haplotype for this sample node using the 
ancestor walk. + _node_haplotype( + numba_ts, + parent_index, + edge_start_index, + edge_end_index, + next_site_index, + node_mut_sites, + node_mut_alleles, + node_mut_offsets, + ancestral_codes, + node, + hap, + unresolved_bits, + stack_edges, + stack_start, + stack_end, + parent_interval_start, + parent_interval_end, + ) + yield int(node), hap.tobytes().decode("ascii") diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 177d9187aa..89d3436cb1 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -46,6 +46,7 @@ import tskit import tskit.combinatorics as combinatorics import tskit.drawing as drawing +import tskit.exceptions as exceptions import tskit.metadata as metadata_module import tskit.provenance as provenance import tskit.tables as tables @@ -5262,49 +5263,109 @@ def _haplotypes_array( # return an array of haplotypes and the first and last site positions if missing_data_character is None: missing_data_character = "N" + if len(missing_data_character) != 1: + raise ValueError("missing_data_character must be a single character") + try: + missing_data_character.encode("ascii") + except UnicodeEncodeError: + raise TypeError("missing_data_character must be ASCII") start_site, stop_site = np.searchsorted(self.sites_position, interval) - H = np.empty( - ( - self.num_samples if samples is None else len(samples), - stop_site - start_site, - ), - dtype=np.int8, + num_sites = stop_site - start_site + missing_int8 = ord(missing_data_character) + + # FIXME! 
The low-level code doesn't support isolated_as_missing + # yet so we do this ugly check here + want_missing = ( + True if isolated_as_missing is None else bool(isolated_as_missing) ) - missing_int8 = ord(missing_data_character.encode("ascii")) - for var in self.variants( - samples=samples, - isolated_as_missing=isolated_as_missing, - left=interval.left, - right=interval.right, - ): - alleles = np.full(len(var.alleles), missing_int8, dtype=np.int8) - for i, allele in enumerate(var.alleles): - if allele is not None: - if len(allele) != 1: - raise TypeError( - "Multi-letter allele or deletion detected at site {}".format( - var.site.id - ) - ) - try: - ascii_allele = allele.encode("ascii") - except UnicodeEncodeError: - raise TypeError( - "Non-ascii character in allele at site {}".format( - var.site.id - ) - ) - allele_int8 = ord(ascii_allele) - if allele_int8 == missing_int8: + + if want_missing and num_sites > 0: + ll_ts = self._ll_tree_sequence + anc_offsets = ll_ts.sites_ancestral_state_offset + anc_data = ll_ts.sites_ancestral_state + anc_slice = anc_offsets[start_site : stop_site + 1] + anc_lengths = np.diff(anc_slice) + if np.any(anc_lengths > 0): + anc_index = anc_slice[:-1][anc_lengths > 0] + if np.any(anc_data[anc_index] == missing_int8): + raise ValueError( + "missing_data_character must differ from existing allele states" + ) + mut_sites = ll_ts.mutations_site + if mut_sites.size > 0: + mut_offsets = ll_ts.mutations_derived_state_offset + mut_lengths = np.diff(mut_offsets) + mask = (mut_sites >= start_site) & (mut_sites < stop_site) + valid = mask & (mut_lengths > 0) + if np.any(valid): + mut_start = mut_offsets[:-1][valid] + derived_chars = ll_ts.mutations_derived_state[mut_start] + if np.any(derived_chars == missing_int8): raise ValueError( - "The missing data character '{}' clashes with an " - "existing allele at site {}".format( - missing_data_character, var.site.id - ) + "missing_data_character must differ from existing allele " + "states" ) - 
alleles[i] = allele_int8 - H[:, var.site.id - start_site] = alleles[var.genotypes] + + if samples is None: + sample_nodes = self.samples() + else: + sample_nodes = np.array(samples, dtype=np.int64) + num_samples = len(sample_nodes) + + if want_missing and samples is not None and num_samples > 0: + flags = self.nodes_flags[sample_nodes] + if np.any((flags & NODE_IS_SAMPLE) == 0): + raise exceptions.LibraryError( + "Cannot generate genotypes for non-samples when isolated nodes " + "are considered as missing. (TSK_ERR_MUST_IMPUTE_NON_SAMPLES)" + ) + + H = np.empty((num_samples, num_sites), dtype=np.int8) + if num_samples == 0 or num_sites == 0: + return H, (start_site, stop_site - 1) + + # FIXME! The low-level code doesn't support isolated_as_missing + # yet so we do this ugly thing of using the variants code to find + # sites with missing data + missing_mask = None + if want_missing: + for var in self.variants( + samples=samples, + isolated_as_missing=isolated_as_missing, + left=interval.left, + right=interval.right, + copy=False, + ): + if not var.has_missing_data: + continue + if missing_mask is None: + missing_mask = np.zeros((num_samples, num_sites), dtype=bool) + genotypes = np.asarray(var.genotypes, dtype=np.int32) + missing_mask[:, var.site.id - start_site] = ( + genotypes == tskit.MISSING_DATA + ) + + try: + hap = _tskit.Haplotype( + self._ll_tree_sequence, + int(start_site), + int(stop_site), + ) + except exceptions.LibraryError as err: + if "TSK_ERR_UNSUPPORTED_OPERATION" in str(err): + raise TypeError(str(err)) from err + if "TSK_ERR_BAD_PARAM_VALUE" in str(err): + raise ValueError(str(err)) from err + raise + + for row, node in enumerate(sample_nodes): + data = hap.decode(int(node)) + H[row, :] = np.frombuffer(data, dtype=np.int8, count=num_sites) + + if missing_mask is not None: + H[missing_mask] = missing_int8 + return H, (start_site, stop_site - 1) def haplotypes( @@ -5702,9 +5763,10 @@ def alignments( missing_data_character = ( "N" if 
missing_data_character is None else missing_data_character ) + if len(missing_data_character) != 1: + raise ValueError("missing_data_character must be length 1") L = interval.span - a = np.empty(L, dtype=np.int8) if reference_sequence is None: if self.has_reference_sequence(): # This may be inefficient - see #1989. However, since we're @@ -5728,7 +5790,7 @@ def alignments( "The reference sequence ends before the requested stop position" ) ref_bytes = reference_sequence.encode("ascii") - a[:] = np.frombuffer(ref_bytes, dtype=np.int8) + a = np.frombuffer(ref_bytes, dtype=np.int8).copy() # To do this properly we'll have to detect the missing data as # part of a full implementation of alignments in C. The current @@ -5746,17 +5808,32 @@ def alignments( "The current implementation may also incorrectly identify an " "input tree sequence has having missing data." ) + if samples is not None: + samples = np.array(samples, dtype=np.int64) + if samples.size > 0: + flags = self.nodes_flags[samples] + if np.any((flags & NODE_IS_SAMPLE) == 0): + raise exceptions.LibraryError( + "Cannot generate genotypes for non-samples when isolated nodes " + "are considered as missing. (TSK_ERR_MUST_IMPUTE_NON_SAMPLES)" + ) H, (first_site_id, last_site_id) = self._haplotypes_array( interval=interval, + isolated_as_missing=False, missing_data_character=missing_data_character, samples=samples, ) - site_pos = self.sites_position.astype(np.int64)[ - first_site_id : last_site_id + 1 - ] - for h in H: - a[site_pos - interval.left] = h - yield a.tobytes().decode("ascii") + if first_site_id <= last_site_id: + site_pos = self.sites_position.astype(np.int64)[ + first_site_id : last_site_id + 1 + ] + else: + site_pos = np.array([], dtype=np.int64) + for hap in H: + a_copy = a.copy() + if site_pos.size > 0: + a_copy[site_pos - interval.left] = hap + yield a_copy.tobytes().decode("ascii") @property def individuals_population(self):