Skip to content

Commit 6286fd6

Browse files
committed
Add tskit CLI
1 parent 6680178 commit 6286fd6

File tree

7 files changed

+228
-4
lines changed

7 files changed

+228
-4
lines changed

bio2zarr/__main__.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def bio2zarr():
1717
bio2zarr.add_command(cli.vcf2zarr_main)
1818
bio2zarr.add_command(cli.plink2zarr)
1919
bio2zarr.add_command(cli.vcfpartition)
20+
bio2zarr.add_command(cli.tskit2zarr)
2021

2122
if __name__ == "__main__":
2223
bio2zarr()

bio2zarr/cli.py

+50
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import tabulate
1010

1111
from . import plink, provenance, vcf_utils
12+
from . import tskit as tskit_mod
1213
from . import vcf as vcf_mod
1314

1415
logger = logging.getLogger(__name__)
@@ -630,3 +631,52 @@ def vcfpartition(vcfs, verbose, num_partitions, partition_size):
630631
)
631632
for region in regions:
632633
click.echo(f"{region}\t{vcf_path}")
634+
635+
636+
@click.command(name="convert")
@click.argument("ts_path", type=click.Path(exists=True))
@click.argument("zarr_path", type=click.Path())
# NOTE(review): the help text advertises a default of '1', but no click default
# is set, so None is forwarded when the option is omitted -- presumably
# tskit_mod.convert applies the default itself; confirm.
@click.option("--contig-id", type=str, help="Contig/chromosome ID (default: '1')")
@click.option(
    "--isolated-as-missing", is_flag=True, help="Treat isolated nodes as missing"
)
@variants_chunk_size
@samples_chunk_size
@verbose
@progress
@worker_processes
@force
def convert_tskit(
    ts_path,
    zarr_path,
    contig_id,
    isolated_as_missing,
    variants_chunk_size,
    samples_chunk_size,
    verbose,
    progress,
    worker_processes,
    force,
):
    """
    Convert the tskit tree sequence file at TS_PATH to Zarr, written to ZARR_PATH.

    If ZARR_PATH already exists you are asked to confirm before it is
    overwritten, unless --force is given.
    """
    setup_logging(verbose)
    # Ask before clobbering an existing output path; --force skips the prompt.
    check_overwrite_dir(zarr_path, force)

    # The CLI's "progress" flag is exposed to the library as "show_progress";
    # "verbose" and "force" are consumed above and are not forwarded.
    tskit_mod.convert(
        ts_path,
        zarr_path,
        contig_id=contig_id,
        isolated_as_missing=isolated_as_missing,
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
        worker_processes=worker_processes,
        show_progress=progress,
    )
674+
675+
676+
@version
@click.group()
def tskit2zarr():
    """
    Convert tskit tree sequence files to Zarr.
    """


# Subcommands available under the tskit2zarr group.
tskit2zarr.add_command(convert_tskit)

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ dependencies = [
2525
# colouredlogs pulls in humanfriendly
2626
"cyvcf2",
2727
"bed_reader",
28+
# TODO Using dev version of tskit for CI, FIXME before release
29+
"tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python",
2830
]
2931
requires-python = ">=3.10"
3032
classifiers = [
@@ -51,6 +53,7 @@ documentation = "https://sgkit-dev.github.io/bio2zarr/"
5153
[project.scripts]
5254
vcf2zarr = "bio2zarr.cli:vcf2zarr_main"
5355
vcfpartition = "bio2zarr.cli:vcfpartition"
56+
tskit2zarr = "bio2zarr.cli:tskit2zarr"
5457

5558
[project.optional-dependencies]
5659
dev = [

tests/data/ts/example.trees

9.92 KB
Binary file not shown.

tests/test_cli.py

+146-1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@
6161
local_alleles=False,
6262
)
6363

64+
# Keyword arguments the CLI is expected to pass to bio2zarr.tskit.convert when
# no options are given on the command line.
DEFAULT_TSKIT_CONVERT_ARGS = {
    "contig_id": None,
    "isolated_as_missing": False,
    "variants_chunk_size": None,
    "samples_chunk_size": None,
    "show_progress": True,
    "worker_processes": 1,
}
72+
6473
DEFAULT_PLINK_CONVERT_ARGS = dict(
6574
variants_chunk_size=None,
6675
samples_chunk_size=None,
@@ -635,6 +644,116 @@ def test_vcf_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response
635644
(self.vcf_path,), str(zarr_path), **DEFAULT_CONVERT_ARGS
636645
)
637646

647+
@pytest.mark.parametrize(("progress", "flag"), [(True, "-P"), (False, "-Q")])
@mock.patch("bio2zarr.tskit.convert")
def test_convert_tskit(self, mocked, tmp_path, progress, flag):
    """-P / -Q toggle show_progress; every other argument stays at its default."""
    ts_path = "tests/data/ts/example.trees"
    zarr_path = tmp_path / "zarr"
    command_line = f"convert {ts_path} {zarr_path} {flag}"

    outcome = ct.CliRunner().invoke(
        cli.tskit2zarr, command_line, catch_exceptions=False
    )

    assert outcome.exit_code == 0
    assert outcome.stdout == ""
    assert outcome.stderr == ""
    # Only the progress setting should deviate from the defaults.
    expected_kwargs = {**DEFAULT_TSKIT_CONVERT_ARGS, "show_progress": progress}
    mocked.assert_called_once_with(ts_path, str(zarr_path), **expected_kwargs)
668+
669+
@pytest.mark.parametrize("response", ["y", "Y", "yes"])
670+
@mock.patch("bio2zarr.tskit.convert")
671+
def test_tskit_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response):
672+
ts_path = "tests/data/ts/example.trees"
673+
zarr_path = tmp_path / "zarr"
674+
zarr_path.mkdir()
675+
runner = ct.CliRunner()
676+
result = runner.invoke(
677+
cli.tskit2zarr,
678+
f"convert {ts_path} {zarr_path}",
679+
catch_exceptions=False,
680+
input=response,
681+
)
682+
assert result.exit_code == 0
683+
assert f"Do you want to overwrite {zarr_path}" in result.stdout
684+
assert len(result.stderr) == 0
685+
mocked.assert_called_once_with(
686+
ts_path,
687+
str(zarr_path),
688+
**DEFAULT_TSKIT_CONVERT_ARGS,
689+
)
690+
691+
@pytest.mark.parametrize("response", ["n", "N", "No"])
@mock.patch("bio2zarr.tskit.convert")
def test_tskit_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, response):
    """Declining the overwrite prompt aborts before any conversion happens."""
    ts_path = "tests/data/ts/example.trees"
    zarr_path = tmp_path / "zarr"
    # An existing output directory is what triggers the confirmation prompt.
    zarr_path.mkdir()

    outcome = ct.CliRunner().invoke(
        cli.tskit2zarr,
        f"convert {ts_path} {zarr_path}",
        catch_exceptions=False,
        input=response,
    )

    assert outcome.exit_code == 1
    assert "Aborted" in outcome.stderr
    mocked.assert_not_called()
707+
708+
@pytest.mark.parametrize("force_arg", ["-f", "--force"])
@mock.patch("bio2zarr.tskit.convert")
def test_tskit_convert_overwrite_zarr_force(self, mocked, tmp_path, force_arg):
    """With -f/--force an existing output is overwritten without prompting."""
    ts_path = "tests/data/ts/example.trees"
    zarr_path = tmp_path / "zarr"
    zarr_path.mkdir()

    outcome = ct.CliRunner().invoke(
        cli.tskit2zarr,
        f"convert {ts_path} {zarr_path} {force_arg}",
        catch_exceptions=False,
    )

    assert outcome.exit_code == 0
    # No prompt and no diagnostics: both streams stay empty.
    assert outcome.stdout == ""
    assert outcome.stderr == ""
    mocked.assert_called_once_with(
        ts_path, str(zarr_path), **DEFAULT_TSKIT_CONVERT_ARGS
    )
728+
729+
@mock.patch("bio2zarr.tskit.convert")
def test_tskit_convert_with_options(self, mocked, tmp_path):
    """Each command-line option is forwarded to convert under its keyword name."""
    ts_path = "tests/data/ts/example.trees"
    zarr_path = tmp_path / "zarr"
    command_line = (
        f"convert {ts_path} {zarr_path} --contig-id chr1 "
        "--isolated-as-missing -l 100 -w 50 -p 4"
    )

    outcome = ct.CliRunner().invoke(
        cli.tskit2zarr, command_line, catch_exceptions=False
    )

    assert outcome.exit_code == 0
    assert outcome.stdout == ""
    assert outcome.stderr == ""

    # Defaults, overridden by the values given on the command line above.
    expected_kwargs = {
        **DEFAULT_TSKIT_CONVERT_ARGS,
        "contig_id": "chr1",
        "isolated_as_missing": True,
        "variants_chunk_size": 100,
        "samples_chunk_size": 50,
        "worker_processes": 4,
    }
    mocked.assert_called_once_with(ts_path, str(zarr_path), **expected_kwargs)
756+
638757

639758
class TestVcfEndToEnd:
640759
vcf_path = "tests/data/vcf/sample.vcf.gz"
@@ -908,10 +1027,36 @@ def test_part_size_multiple_vcfs(self):
9081027

9091028

9101029
@pytest.mark.parametrize(
    "cmd",
    [
        main.bio2zarr,
        cli.vcf2zarr_main,
        cli.plink2zarr,
        cli.vcfpartition,
        cli.tskit2zarr,
    ],
)
def test_version(cmd):
    """Every CLI entry point ends its --version output with the package version."""
    expected_suffix = f"version {provenance.__version__}\n"
    outcome = ct.CliRunner().invoke(cmd, ["--version"], catch_exceptions=False)
    assert outcome.stdout.endswith(expected_suffix)
1044+
1045+
1046+
class TestTskitEndToEnd:
    """Run the real (unmocked) tskit conversion, then inspect what it wrote."""

    def test_convert(self, tmp_path):
        ts_path = "tests/data/ts/example.trees"
        zarr_path = tmp_path / "zarr"
        cli_runner = ct.CliRunner()

        converted = cli_runner.invoke(
            cli.tskit2zarr, f"convert {ts_path} {zarr_path}", catch_exceptions=False
        )
        assert converted.exit_code == 0

        # The output must be readable by the vcf2zarr "inspect" subcommand.
        inspected = cli_runner.invoke(
            cli.vcf2zarr_main, f"inspect {zarr_path}", catch_exceptions=False
        )
        assert inspected.exit_code == 0
        # Arbitrary check
        assert "variant_position" in inspected.stdout

tests/test_core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_examples(self, chunk_size, size, start, stop):
237237
# It works in CI on Linux, but it'll probably break at some point.
238238
# It's also necessary to update these numbers each time a new data
239239
# file gets added
240-
("tests/data", 5030777),
240+
("tests/data", 5045029),
241241
("tests/data/vcf", 5018640),
242242
("tests/data/vcf/sample.vcf.gz", 1089),
243243
],

tests/test_ts.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -40,29 +40,54 @@ def test_simple_tree_sequence(self, tmp_path):
4040
tmp_path / "test.trees", zarr_path, ind_nodes, show_progress=False
4141
)
4242
zroot = zarr.open(zarr_path, mode="r")
43-
assert zroot["variant_position"].shape == (3,)
44-
assert list(zroot["variant_position"][:]) == [10, 20, 30]
43+
pos = zroot["variant_position"][:]
44+
assert pos.shape == (3,)
45+
assert pos.dtype == np.int8
46+
assert np.array_equal(pos, [10, 20, 30])
4547

4648
alleles = zroot["variant_allele"][:]
49+
assert alleles.shape == (3, 2)
50+
assert alleles.dtype == "O"
4751
assert np.array_equal(alleles, [["A", "T"], ["C", "G"], ["G", "A"]])
4852

4953
genotypes = zroot["call_genotype"][:]
54+
assert genotypes.shape == (3, 2, 2)
55+
assert genotypes.dtype == np.int8
5056
assert np.array_equal(
5157
genotypes, [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 0], [0, 0]]]
5258
)
5359

5460
phased = zroot["call_genotype_phased"][:]
61+
assert phased.shape == (3, 2)
62+
# np.bool_ exists on every NumPy release; the bare np.bool alias is missing
# from NumPy 1.24-1.26 and would raise AttributeError there.
assert phased.dtype == np.bool_
5563
assert np.all(phased)
5664

5765
contigs = zroot["contig_id"][:]
66+
assert contigs.shape == (1,)
67+
assert contigs.dtype == "O"
5868
assert np.array_equal(contigs, ["1"])
5969

6070
contig = zroot["variant_contig"][:]
71+
assert contig.shape == (3,)
72+
assert contig.dtype == np.int8
6173
assert np.array_equal(contig, [0, 0, 0])
6274

6375
samples = zroot["sample_id"][:]
76+
assert samples.shape == (2,)
77+
assert samples.dtype == "O"
6478
assert np.array_equal(samples, ["tsk_0", "tsk_1"])
6579

80+
assert set(zroot.array_keys()) == {
81+
"variant_position",
82+
"variant_allele",
83+
"call_genotype",
84+
"call_genotype_phased",
85+
"call_genotype_mask",
86+
"contig_id",
87+
"variant_contig",
88+
"sample_id",
89+
}
90+
6691

6792
class TestTskitFormat:
6893
"""Unit tests for TskitFormat without using full conversion."""

0 commit comments

Comments
 (0)