diff --git a/performance/compare.py b/performance/compare.py index dd7c3f8..77da636 100644 --- a/performance/compare.py +++ b/performance/compare.py @@ -27,8 +27,8 @@ def run_vcztools(command: str, dataset_name: str): if __name__ == "__main__": commands = [ - ("view", "sim_10k"), - ("view", "chr22"), + ("view -H", "sim_10k"), + ("view -H", "chr22"), ("view -s tsk_7068,tsk_8769,tsk_8820", "sim_10k"), (r"query -f '%CHROM %POS %REF %ALT{0}\n'", "sim_10k"), (r"query -f '%CHROM:%POS\n' -i 'POS=49887394 | POS=50816415'", "sim_10k"), diff --git a/pyproject.toml b/pyproject.toml index 9264da1..8c70e54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dev = [ "cyvcf2", "obstore", "pytest", + "pytest-asyncio", "pytest-cov", "msprime", "setuptools", diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py index da79882..991c885 100644 --- a/tests/test_retrieval.py +++ b/tests/test_retrieval.py @@ -4,7 +4,12 @@ import pytest import zarr -from vcztools.retrieval import variant_chunk_iter, variant_iter +from vcztools.retrieval import ( + AsyncVariantChunkReader, + async_variant_chunk_iter, + variant_chunk_iter, + variant_iter, +) from vcztools.samples import parse_samples from .utils import vcz_path_cache @@ -35,6 +40,31 @@ def test_variant_chunk_iter(): nt.assert_array_equal(chunk_data["call_mask"], [[True, False], [False, False]]) +@pytest.mark.asyncio() +async def test_variant_chunk_iter_async(): + from zarr.api.asynchronous import open_group + + original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz" + vcz = vcz_path_cache(original) + root = await open_group(vcz, mode="r") + + sample_id_arr = await root.getitem("sample_id") + sample_id = await sample_id_arr.getitem(slice(None)) + _, samples_selection = parse_samples("NA00002,NA00003", sample_id) + + chunk_reader = await AsyncVariantChunkReader.create(root) + + chunk_data = await chunk_reader.get_chunk_data( + 0, samples_selection=samples_selection + ) + print(chunk_data["call_DP"]) + + async for 
chunk_data in async_variant_chunk_iter( + root, samples_selection=samples_selection + ): + print(chunk_data["call_DP"]) + + def test_variant_chunk_iter_empty_fields(): original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz" vcz = vcz_path_cache(original) diff --git a/vcztools/query.py b/vcztools/query.py index 7cac33d..979c1fc 100644 --- a/vcztools/query.py +++ b/vcztools/query.py @@ -1,3 +1,4 @@ +import asyncio import functools import itertools import math @@ -5,10 +6,16 @@ import numpy as np import pyparsing as pp +from zarr.api.asynchronous import open_group -from vcztools import constants, retrieval +from vcztools import constants, retrieval_async from vcztools.samples import parse_samples -from vcztools.utils import missing, open_zarr, vcf_name_to_vcz_names +from vcztools.utils import ( + async_get_zarr_array, + missing, + open_zarr, + vcf_name_to_vcz_names, +) def list_samples(vcz_path, output, zarr_backend_storage=None): @@ -307,21 +314,52 @@ def write_query( disable_automatic_newline: bool = False, zarr_backend_storage: str | None = None, ): - root = open_zarr(vcz, mode="r", zarr_backend_storage=zarr_backend_storage) + asyncio.run( + async_write_query( + vcz, + output, + query_format=query_format, + regions=regions, + targets=targets, + samples=samples, + force_samples=force_samples, + include=include, + exclude=exclude, + disable_automatic_newline=disable_automatic_newline, + zarr_backend_storage=zarr_backend_storage, + ) + ) + + +async def async_write_query( + vcz, + output, + *, + query_format: str, + regions=None, + targets=None, + samples=None, + force_samples: bool = False, + include: str | None = None, + exclude: str | None = None, + disable_automatic_newline: bool = False, + zarr_backend_storage: str | None = None, +): + root = await open_group(vcz, mode="r") - all_samples = root["sample_id"][:] + all_samples = await async_get_zarr_array(root, "sample_id") sample_ids, samples_selection = parse_samples( samples, all_samples, 
force_samples=force_samples ) - contigs = root["contig_id"][:] - filters = root["filter_id"][:] + contigs = await async_get_zarr_array(root, "contig_id") + filters = await async_get_zarr_array(root, "filter_id") if "\\n" not in query_format and not disable_automatic_newline: query_format = query_format + "\\n" generator = QueryFormatGenerator(query_format, sample_ids, contigs, filters) - for chunk_data in retrieval.variant_chunk_iter( + async for chunk_data in retrieval_async.async_variant_chunk_iter( root, regions=regions, targets=targets, diff --git a/vcztools/retrieval_async.py b/vcztools/retrieval_async.py new file mode 100644 index 0000000..8ff24cb --- /dev/null +++ b/vcztools/retrieval_async.py @@ -0,0 +1,223 @@ +import asyncio + +import numpy as np + +from vcztools import filter as filter_mod +from vcztools.regions import ( + parse_regions, + parse_targets, + regions_to_chunk_indexes, + regions_to_selection, +) +from vcztools.utils import ( + _as_fixed_length_unicode, + async_get_zarr_array, + get_block_selection, +) + + +# NOTE: this class is just a skeleton for now. The idea is that this +# will provide readahead, caching etc, and will be the central location +# for fetching bulk Zarr data. 
+class AsyncVariantChunkReader: + def __init__(self, root, arrays, num_chunks): + self.root = root + self.arrays = arrays + self.num_chunks = num_chunks + + @classmethod + async def create(cls, root, *, fields=None): + if fields is None: + fields = [ + key + async for key in root.keys() + if key.startswith("variant_") or key.startswith("call_") + ] + + arrays = [root.getitem(key) for key in fields] + arrays = dict(zip(fields, await asyncio.gather(*arrays))) + + num_chunks = next(iter(arrays.values())).cdata_shape[0] + + return cls(root, arrays, num_chunks) + + async def getitem(self, chunk): + chunk_data = [ + get_block_selection(array, chunk) for array in self.arrays.values() + ] + chunk_data = await asyncio.gather(*chunk_data) + return dict(zip(self.arrays.keys(), chunk_data)) + + async def get_chunk_data(self, chunk, mask=None, samples_selection=None): + def get_vchunk_array(zarray, v_chunk, samples_selection=None): + v_chunksize = zarray.chunks[0] + start = v_chunksize * v_chunk + end = v_chunksize * (v_chunk + 1) + if samples_selection is None: + result = zarray.getitem(slice(start, end)) + else: + result = zarray.get_orthogonal_selection( + (slice(start, end), samples_selection) + ) + return result + + num_samples = len(samples_selection) if samples_selection is not None else 0 + chunk_data = [ + get_vchunk_array( + array, + chunk, + samples_selection=samples_selection + if (key.startswith("call_") and num_samples > 0) + else None, + ) + for key, array in self.arrays.items() + ] + chunk_data = await asyncio.gather(*chunk_data) + if mask is not None: + chunk_data = [arr[mask] for arr in chunk_data] + return dict(zip(self.arrays.keys(), chunk_data)) + + +async def async_variant_chunk_index_iter(root, regions=None, targets=None): + pos = await root.getitem("variant_position") + + if regions is None and targets is None: + num_chunks = pos.cdata_shape[0] + # no regions or targets selected + for v_chunk in range(num_chunks): + v_mask_chunk = None + yield v_chunk, 
v_mask_chunk + else: + contigs = await async_get_zarr_array(root, "contig_id") + contigs_u = _as_fixed_length_unicode(contigs).tolist() + regions_pyranges = parse_regions(regions, contigs_u) + targets_pyranges, complement = parse_targets(targets, contigs_u) + + # Use the region index to find the chunks that overlap specified regions or + # targets + region_index = await async_get_zarr_array(root, "region_index") + chunk_indexes = regions_to_chunk_indexes( + regions_pyranges, + targets_pyranges, + complement, + region_index, + ) + + if len(chunk_indexes) == 0: + # no chunks - no variants to write + return + + # Then only load required variant_contig/position chunks + region_variant_contig_arr = await root.getitem("variant_contig") + region_variant_position_arr = await root.getitem("variant_position") + region_variant_length_arr = await root.getitem("variant_length") + for chunk_index in chunk_indexes: + # TODO: get all three concurrently + region_variant_contig = await get_block_selection( + region_variant_contig_arr, chunk_index + ) + region_variant_position = await get_block_selection( + region_variant_position_arr, chunk_index + ) + region_variant_length = await get_block_selection( + region_variant_length_arr, chunk_index + ) + + # Find the variant selection for the chunk + variant_selection = regions_to_selection( + regions_pyranges, + targets_pyranges, + complement, + region_variant_contig, + region_variant_position, + region_variant_length, + ) + variant_mask = np.zeros(region_variant_position.shape[0], dtype=bool) + variant_mask[variant_selection] = 1 + + yield chunk_index, variant_mask + + +async def async_variant_chunk_index_iter_with_filtering( + root, + *, + regions=None, + targets=None, + include: str | None = None, + exclude: str | None = None, +): + """Iterate over variant chunk indexes that overlap the given regions or targets + and which match the include/exclude filter expression.
+ + Returns tuples of variant chunk indexes and (optional) variant masks. + + A variant mask of None indicates that all the variants in the chunk are included. + """ + + field_names = set([key async for key in root.keys()]) + filter_expr = filter_mod.FilterExpression( + field_names=field_names, include=include, exclude=exclude + ) + if filter_expr.parse_result is None: + filter_expr = None + else: + filter_fields = list(filter_expr.referenced_fields) + filter_fields_reader = await AsyncVariantChunkReader.create( + root, fields=filter_fields + ) + + async for v_chunk, v_mask_chunk in async_variant_chunk_index_iter( + root, regions, targets + ): + if filter_expr is not None: + chunk_data = await filter_fields_reader.getitem(v_chunk) + v_mask_chunk_filter = filter_expr.evaluate(chunk_data) + if v_mask_chunk is None: + v_mask_chunk = v_mask_chunk_filter + else: + if v_mask_chunk_filter.ndim == 2: + v_mask_chunk = np.expand_dims(v_mask_chunk, axis=1) + v_mask_chunk = np.logical_and(v_mask_chunk, v_mask_chunk_filter) + if v_mask_chunk is None or np.any(v_mask_chunk): + yield v_chunk, v_mask_chunk + + +async def async_variant_chunk_iter( + root, + *, + fields: list[str] | None = None, + regions=None, + targets=None, + include: str | None = None, + exclude: str | None = None, + samples_selection=None, +): + if fields is not None and len(fields) == 0: + return # empty iterator + query_fields_reader = await AsyncVariantChunkReader.create(root, fields=fields) + + async for v_chunk, v_mask_chunk in async_variant_chunk_index_iter_with_filtering( + root, + regions=regions, + targets=targets, + include=include, + exclude=exclude, + ): + # The variants_selection is used to subset variant chunks along + # the variants dimension. + # The call_mask is returned to the client to indicate which samples + # matched (for each variant) in the case of per-sample filtering. 
+ if v_mask_chunk is None or v_mask_chunk.ndim == 1: + variants_selection = v_mask_chunk + call_mask = None + else: + variants_selection = np.any(v_mask_chunk, axis=1) + call_mask = v_mask_chunk[variants_selection] + if samples_selection is not None: + call_mask = call_mask[:, samples_selection] + chunk_data = await query_fields_reader.get_chunk_data( + v_chunk, variants_selection, samples_selection=samples_selection + ) + if call_mask is not None: + chunk_data["call_mask"] = call_mask + yield chunk_data diff --git a/vcztools/utils.py b/vcztools/utils.py index 7f959e3..382e1a2 100644 --- a/vcztools/utils.py +++ b/vcztools/utils.py @@ -117,6 +117,33 @@ def vcf_name_to_vcz_names(vcz_names: set[str], vcf_name: str) -> list[str]: return matches +async def async_get_zarr_array(group, key): + arr = await group.getitem(key) + return await arr.getitem(slice(None)) + + +# TODO: contribute upstream +async def get_block_selection( + async_array, + selection, + *, + out=None, + fields=None, + prototype=None, +): + from zarr.core.buffer import default_buffer_prototype + from zarr.core.indexing import BlockIndexer + + if prototype is None: + prototype = default_buffer_prototype() + indexer = BlockIndexer( + selection, async_array.shape, async_array.metadata.chunk_grid + ) + return await async_array._get_selection( + indexer=indexer, out=out, fields=fields, prototype=prototype + ) + + # See https://numpy.org/devdocs/user/basics.strings.html#casting-to-and-from-fixed-width-strings diff --git a/vcztools/vcf_writer.py b/vcztools/vcf_writer.py index 8ced633..7841a29 100644 --- a/vcztools/vcf_writer.py +++ b/vcztools/vcf_writer.py @@ -1,19 +1,21 @@ +import asyncio import io import logging import sys from datetime import datetime import numpy as np +from zarr.api.asynchronous import open_group from vcztools.samples import parse_samples from vcztools.utils import ( _as_fixed_length_string, _as_fixed_length_unicode, + async_get_zarr_array, open_file_like, - open_zarr, ) -from . 
import _vcztools, constants, retrieval +from . import _vcztools, constants, retrieval_async from . import filter as filter_mod from .constants import FLOAT32_MISSING, RESERVED_VARIABLE_NAMES @@ -91,7 +93,44 @@ def write_vcf( exclude: str | None = None, zarr_backend_storage: str | None = None, ) -> None: - root = open_zarr(vcz, mode="r", zarr_backend_storage=zarr_backend_storage) + asyncio.run( + async_write_vcf( + vcz, + output, + header_only=header_only, + no_header=no_header, + no_version=no_version, + regions=regions, + targets=targets, + no_update=no_update, + samples=samples, + force_samples=force_samples, + drop_genotypes=drop_genotypes, + include=include, + exclude=exclude, + zarr_backend_storage=zarr_backend_storage, + ) + ) + + +async def async_write_vcf( + vcz, + output, + *, + header_only: bool = False, + no_header: bool = False, + no_version: bool = False, + regions=None, + targets=None, + no_update=None, + samples=None, + force_samples: bool = False, + drop_genotypes: bool = False, + include: str | None = None, + exclude: str | None = None, + zarr_backend_storage: str | None = None, +) -> None: + root = await open_group(vcz, mode="r") with open_file_like(output) as output: if samples and drop_genotypes: @@ -100,14 +139,15 @@ def write_vcf( sample_ids = [] samples_selection = np.array([]) else: - all_samples = root["sample_id"][:] + all_samples = await async_get_zarr_array(root, "sample_id") sample_ids, samples_selection = parse_samples( samples, all_samples, force_samples=force_samples ) # Need to try parsing filter expressions before writing header + field_names = set([key async for key in root.keys()]) filter_mod.FilterExpression( - field_names=set(root), include=include, exclude=exclude + field_names=field_names, include=include, exclude=exclude ) if not no_header: @@ -123,10 +163,10 @@ def write_vcf( if header_only: return - contigs = _as_fixed_length_string(root["contig_id"][:]) - filters = get_filter_ids(root) + contigs = 
_as_fixed_length_string(await async_get_zarr_array(root, "contig_id")) + filters = _as_fixed_length_string(await async_get_zarr_array(root, "filter_id")) - for chunk_data in retrieval.variant_chunk_iter( + async for chunk_data in retrieval_async.async_variant_chunk_iter( root, regions=regions, targets=targets,