Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions performance/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def run_vcztools(command: str, dataset_name: str):

if __name__ == "__main__":
commands = [
("view", "sim_10k"),
("view", "chr22"),
("view -H", "sim_10k"),
("view -H", "chr22"),
("view -s tsk_7068,tsk_8769,tsk_8820", "sim_10k"),
(r"query -f '%CHROM %POS %REF %ALT{0}\n'", "sim_10k"),
(r"query -f '%CHROM:%POS\n' -i 'POS=49887394 | POS=50816415'", "sim_10k"),
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dev = [
"cyvcf2",
"obstore",
"pytest",
"pytest-asyncio",
"pytest-cov",
"msprime",
"setuptools",
Expand Down
32 changes: 31 additions & 1 deletion tests/test_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import pytest
import zarr

from vcztools.retrieval import variant_chunk_iter, variant_iter
from vcztools.retrieval import (
AsyncVariantChunkReader,
async_variant_chunk_iter,
variant_chunk_iter,
variant_iter,
)
from vcztools.samples import parse_samples

from .utils import vcz_path_cache
Expand Down Expand Up @@ -35,6 +40,31 @@ def test_variant_chunk_iter():
nt.assert_array_equal(chunk_data["call_mask"], [[True, False], [False, False]])


@pytest.mark.asyncio()
async def test_variant_chunk_iter_async():
    from zarr.api.asynchronous import open_group

    original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz"
    vcz = vcz_path_cache(original)
    root = await open_group(vcz, mode="r")

    sample_id_arr = await root.getitem("sample_id")
    sample_id = await sample_id_arr.getitem(slice(None))
    _, samples_selection = parse_samples("NA00002,NA00003", sample_id)

    chunk_reader = await AsyncVariantChunkReader.create(root)

    chunk_data = await chunk_reader.get_chunk_data(
        0, samples_selection=samples_selection
    )
    # Per-call fields must be restricted to the selected samples.
    assert "call_DP" in chunk_data
    assert chunk_data["call_DP"].shape[1] == len(samples_selection)

    # With no region/target/filter arguments, the first chunk produced by the
    # iterator must match the data read directly from the chunk reader.
    async for iter_chunk_data in async_variant_chunk_iter(
        root, samples_selection=samples_selection
    ):
        nt.assert_array_equal(iter_chunk_data["call_DP"], chunk_data["call_DP"])
        break


def test_variant_chunk_iter_empty_fields():
original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz"
vcz = vcz_path_cache(original)
Expand Down
52 changes: 45 additions & 7 deletions vcztools/query.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import asyncio
import functools
import itertools
import math
from collections.abc import Callable

import numpy as np
import pyparsing as pp
from zarr.api.asynchronous import open_group

from vcztools import constants, retrieval
from vcztools import constants, retrieval_async
from vcztools.samples import parse_samples
from vcztools.utils import missing, open_zarr, vcf_name_to_vcz_names
from vcztools.utils import (
async_get_zarr_array,
missing,
open_zarr,
vcf_name_to_vcz_names,
)


def list_samples(vcz_path, output, zarr_backend_storage=None):
Expand Down Expand Up @@ -307,21 +314,52 @@ def write_query(
disable_automatic_newline: bool = False,
zarr_backend_storage: str | None = None,
):
root = open_zarr(vcz, mode="r", zarr_backend_storage=zarr_backend_storage)
asyncio.run(
async_write_query(
vcz,
output,
query_format=query_format,
regions=regions,
targets=targets,
samples=samples,
force_samples=force_samples,
include=include,
exclude=exclude,
disable_automatic_newline=disable_automatic_newline,
zarr_backend_storage=zarr_backend_storage,
)
)


async def async_write_query(
vcz,
output,
*,
query_format: str,
regions=None,
targets=None,
samples=None,
force_samples: bool = False,
include: str | None = None,
exclude: str | None = None,
disable_automatic_newline: bool = False,
zarr_backend_storage: str | None = None,
):
root = await open_group(vcz, mode="r")

all_samples = root["sample_id"][:]
all_samples = await async_get_zarr_array(root, "sample_id")
sample_ids, samples_selection = parse_samples(
samples, all_samples, force_samples=force_samples
)
contigs = root["contig_id"][:]
filters = root["filter_id"][:]
contigs = await async_get_zarr_array(root, "contig_id")
filters = await async_get_zarr_array(root, "filter_id")

if "\\n" not in query_format and not disable_automatic_newline:
query_format = query_format + "\\n"

generator = QueryFormatGenerator(query_format, sample_ids, contigs, filters)

for chunk_data in retrieval.variant_chunk_iter(
async for chunk_data in retrieval_async.async_variant_chunk_iter(
root,
regions=regions,
targets=targets,
Expand Down
223 changes: 223 additions & 0 deletions vcztools/retrieval_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import asyncio

import numpy as np

from vcztools import filter as filter_mod
from vcztools.regions import (
parse_regions,
parse_targets,
regions_to_chunk_indexes,
regions_to_selection,
)
from vcztools.utils import (
_as_fixed_length_unicode,
async_get_zarr_array,
get_block_selection,
)


# NOTE: this class is just a skeleton for now. The idea is that this
# will provide readahead, caching etc, and will be the central location
# for fetching bulk Zarr data.
class AsyncVariantChunkReader:
    """Fetches per-variant-chunk data from an async Zarr group.

    Instances must be built with the :meth:`create` classmethod, since
    obtaining the array handles requires awaiting.
    """

    def __init__(self, root, arrays, num_chunks):
        # root: async Zarr group.
        # arrays: dict mapping field name -> async Zarr array handle.
        # num_chunks: number of chunks along the variants dimension.
        self.root = root
        self.arrays = arrays
        self.num_chunks = num_chunks

    @classmethod
    async def create(cls, root, *, fields=None):
        """Create a reader for ``fields`` (default: all variant_*/call_* arrays).

        All array handles are fetched concurrently.
        """
        if fields is None:
            fields = [
                key
                async for key in root.keys()
                # tuple form of startswith checks both prefixes in one call
                if key.startswith(("variant_", "call_"))
            ]

        handles = await asyncio.gather(*(root.getitem(key) for key in fields))
        arrays = dict(zip(fields, handles))

        # NOTE(review): assumes at least one matching field exists; an empty
        # group would raise StopIteration here.
        num_chunks = next(iter(arrays.values())).cdata_shape[0]

        return cls(root, arrays, num_chunks)

    async def getitem(self, chunk):
        """Return {field name: full chunk data} for the given variant chunk."""
        chunk_data = await asyncio.gather(
            *(get_block_selection(array, chunk) for array in self.arrays.values())
        )
        return dict(zip(self.arrays.keys(), chunk_data))

    async def get_chunk_data(self, chunk, mask=None, samples_selection=None):
        """Return {field name: data} for a variant chunk, fetched concurrently.

        ``samples_selection`` (if non-empty) restricts call_* arrays along the
        samples dimension; ``mask`` (boolean, per-variant) subsets all arrays
        along the variants dimension after fetching.
        """

        def get_vchunk_array(zarray, v_chunk, samples_selection=None):
            # Compute the variants-dimension slice covering chunk v_chunk.
            v_chunksize = zarray.chunks[0]
            start = v_chunksize * v_chunk
            end = v_chunksize * (v_chunk + 1)
            if samples_selection is None:
                result = zarray.getitem(slice(start, end))
            else:
                result = zarray.get_orthogonal_selection(
                    (slice(start, end), samples_selection)
                )
            # Returns an awaitable; gathered by the caller.
            return result

        num_samples = len(samples_selection) if samples_selection is not None else 0
        # Only call_* arrays have a samples dimension; an empty selection
        # falls back to fetching all samples (matching previous behaviour).
        chunk_data = [
            get_vchunk_array(
                array,
                chunk,
                samples_selection=samples_selection
                if (key.startswith("call_") and num_samples > 0)
                else None,
            )
            for key, array in self.arrays.items()
        ]
        chunk_data = await asyncio.gather(*chunk_data)
        if mask is not None:
            chunk_data = [arr[mask] for arr in chunk_data]
        return dict(zip(self.arrays.keys(), chunk_data))


async def async_variant_chunk_index_iter(root, regions=None, targets=None):
    """Iterate over (chunk index, variant mask) pairs for chunks overlapping
    the given regions or targets.

    A variant mask of None indicates that all the variants in the chunk are
    included.
    """
    pos = await root.getitem("variant_position")

    if regions is None and targets is None:
        num_chunks = pos.cdata_shape[0]
        # no regions or targets selected
        for v_chunk in range(num_chunks):
            yield v_chunk, None
    else:
        contigs = await async_get_zarr_array(root, "contig_id")
        contigs_u = _as_fixed_length_unicode(contigs).tolist()
        regions_pyranges = parse_regions(regions, contigs_u)
        targets_pyranges, complement = parse_targets(targets, contigs_u)

        # Use the region index to find the chunks that overlap specified
        # regions or targets
        region_index = await async_get_zarr_array(root, "region_index")
        chunk_indexes = regions_to_chunk_indexes(
            regions_pyranges,
            targets_pyranges,
            complement,
            region_index,
        )

        if len(chunk_indexes) == 0:
            # no chunks - no variants to write
            return

        # Then only load required variant_contig/position/length chunks.
        # The variant_position handle fetched above is reused; the other two
        # handles are fetched concurrently.
        contig_arr, length_arr = await asyncio.gather(
            root.getitem("variant_contig"),
            root.getitem("variant_length"),
        )
        for chunk_index in chunk_indexes:
            # Fetch all three chunk selections concurrently.
            (
                region_variant_contig,
                region_variant_position,
                region_variant_length,
            ) = await asyncio.gather(
                get_block_selection(contig_arr, chunk_index),
                get_block_selection(pos, chunk_index),
                get_block_selection(length_arr, chunk_index),
            )

            # Find the variant selection for the chunk
            variant_selection = regions_to_selection(
                regions_pyranges,
                targets_pyranges,
                complement,
                region_variant_contig,
                region_variant_position,
                region_variant_length,
            )
            variant_mask = np.zeros(region_variant_position.shape[0], dtype=bool)
            variant_mask[variant_selection] = True

            yield chunk_index, variant_mask


async def async_variant_chunk_index_iter_with_filtering(
    root,
    *,
    regions=None,
    targets=None,
    include: str | None = None,
    exclude: str | None = None,
):
    """Iterate over variant chunk indexes that overlap the given regions or targets
    and which match the include/exclude filter expression.

    Returns tuples of variant chunk indexes and (optional) variant masks.

    A variant mask of None indicates that all the variants in the chunk are included.
    """

    field_names = {key async for key in root.keys()}
    filter_expr = filter_mod.FilterExpression(
        field_names=field_names, include=include, exclude=exclude
    )
    if filter_expr.parse_result is None:
        # No filter expression given - region/target masks pass through as-is.
        filter_expr = None
    else:
        # Only fetch the fields the filter expression actually references.
        filter_fields = list(filter_expr.referenced_fields)
        filter_fields_reader = await AsyncVariantChunkReader.create(
            root, fields=filter_fields
        )

    async for v_chunk, v_mask_chunk in async_variant_chunk_index_iter(
        root, regions, targets
    ):
        if filter_expr is not None:
            chunk_data = await filter_fields_reader.getitem(v_chunk)
            v_mask_chunk_filter = filter_expr.evaluate(chunk_data)
            if v_mask_chunk is None:
                v_mask_chunk = v_mask_chunk_filter
            else:
                # A 2D filter mask is per-(variant, sample); broadcast the 1D
                # region mask across the samples dimension before combining.
                if v_mask_chunk_filter.ndim == 2:
                    v_mask_chunk = np.expand_dims(v_mask_chunk, axis=1)
                v_mask_chunk = np.logical_and(v_mask_chunk, v_mask_chunk_filter)
        # Skip chunks where filtering excluded every variant.
        if v_mask_chunk is None or np.any(v_mask_chunk):
            yield v_chunk, v_mask_chunk


async def async_variant_chunk_iter(
    root,
    *,
    fields: list[str] | None = None,
    regions=None,
    targets=None,
    include: str | None = None,
    exclude: str | None = None,
    samples_selection=None,
):
    """Yield a dict of field data per variant chunk, for variants matching the
    given regions/targets and include/exclude filter expressions.

    When per-sample filtering applies, a 2D "call_mask" entry is added to each
    chunk dict indicating which samples matched for each variant.
    """
    # An explicitly empty field list yields nothing at all.
    if fields is not None and not fields:
        return

    reader = await AsyncVariantChunkReader.create(root, fields=fields)

    index_iter = async_variant_chunk_index_iter_with_filtering(
        root,
        regions=regions,
        targets=targets,
        include=include,
        exclude=exclude,
    )
    async for chunk_index, mask in index_iter:
        # A 2D mask carries per-(variant, sample) matches: collapse it to a
        # per-variant selection and report the sample matches via call_mask.
        # A 1D (or absent) mask selects whole variants with no call_mask.
        if mask is not None and mask.ndim == 2:
            variants_selection = np.any(mask, axis=1)
            call_mask = mask[variants_selection]
            if samples_selection is not None:
                call_mask = call_mask[:, samples_selection]
        else:
            variants_selection = mask
            call_mask = None

        chunk_data = await reader.get_chunk_data(
            chunk_index, variants_selection, samples_selection=samples_selection
        )
        if call_mask is not None:
            chunk_data["call_mask"] = call_mask
        yield chunk_data
Loading