Skip to content

Add tskit CLI #374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bio2zarr/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def bio2zarr():
bio2zarr.add_command(cli.vcf2zarr_main)
bio2zarr.add_command(cli.plink2zarr)
bio2zarr.add_command(cli.vcfpartition)
bio2zarr.add_command(cli.tskit2zarr)

if __name__ == "__main__":
bio2zarr()
50 changes: 50 additions & 0 deletions bio2zarr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tabulate

from . import plink, provenance, vcf_utils
from . import tskit as tskit_mod
from . import vcf as vcf_mod

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -630,3 +631,52 @@ def vcfpartition(vcfs, verbose, num_partitions, partition_size):
)
for region in regions:
click.echo(f"{region}\t{vcf_path}")


@click.command(name="convert")
@click.argument("ts_path", type=click.Path(exists=True))
@click.argument("zarr_path", type=click.Path())
@click.option("--contig-id", type=str, help="Contig/chromosome ID (default: '1')")
@click.option(
"--isolated-as-missing", is_flag=True, help="Treat isolated nodes as missing"
)
@variants_chunk_size
@samples_chunk_size
@verbose
@progress
@worker_processes
@force
def convert_tskit(
ts_path,
zarr_path,
contig_id,
isolated_as_missing,
variants_chunk_size,
samples_chunk_size,
verbose,
progress,
worker_processes,
force,
):
setup_logging(verbose)
check_overwrite_dir(zarr_path, force)

tskit_mod.convert(
ts_path,
zarr_path,
contig_id=contig_id,
isolated_as_missing=isolated_as_missing,
variants_chunk_size=variants_chunk_size,
samples_chunk_size=samples_chunk_size,
worker_processes=worker_processes,
show_progress=progress,
)


@version
@click.group()
def tskit2zarr():
pass


tskit2zarr.add_command(convert_tskit)
9 changes: 6 additions & 3 deletions bio2zarr/tskit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TskitFormat(vcz.Source):
def __init__(
self,
ts_path,
individuals_nodes,
individuals_nodes=None,
sample_ids=None,
contig_id=None,
isolated_as_missing=False,
Expand All @@ -25,6 +25,9 @@ def __init__(

self.positions = self.ts.sites_position

if individuals_nodes is None:
individuals_nodes = self.ts.individuals_nodes

self._num_samples = individuals_nodes.shape[0]
if self._num_samples < 1:
raise ValueError("individuals_nodes must have at least one sample")
Expand Down Expand Up @@ -213,8 +216,8 @@ def generate_schema(
def convert(
ts_path,
zarr_path,
individuals_nodes,
*,
individuals_nodes=None,
sample_ids=None,
contig_id=None,
isolated_as_missing=False,
Expand All @@ -225,7 +228,7 @@ def convert(
):
tskit_format = TskitFormat(
ts_path,
individuals_nodes,
individuals_nodes=individuals_nodes,
sample_ids=sample_ids,
contig_id=contig_id,
isolated_as_missing=isolated_as_missing,
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ dependencies = [
# colouredlogs pulls in humanfriendly",
"cyvcf2",
"bed_reader",
# TODO Using dev version of tskit for CI, FIXME before release
"tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python",
]
requires-python = ">=3.10"
classifiers = [
Expand All @@ -51,6 +53,7 @@ documentation = "https://sgkit-dev.github.io/bio2zarr/"
[project.scripts]
vcf2zarr = "bio2zarr.cli:vcf2zarr_main"
vcfpartition = "bio2zarr.cli:vcfpartition"
tskit2zarr = "bio2zarr.cli:tskit2zarr_main"

[project.optional-dependencies]
dev = [
Expand Down
Binary file added tests/data/ts/example.trees
Binary file not shown.
147 changes: 146 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@
local_alleles=False,
)

DEFAULT_TSKIT_CONVERT_ARGS = dict(
contig_id=None,
isolated_as_missing=False,
variants_chunk_size=None,
samples_chunk_size=None,
show_progress=True,
worker_processes=1,
)

DEFAULT_PLINK_CONVERT_ARGS = dict(
variants_chunk_size=None,
samples_chunk_size=None,
Expand Down Expand Up @@ -635,6 +644,116 @@ def test_vcf_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response
(self.vcf_path,), str(zarr_path), **DEFAULT_CONVERT_ARGS
)

@pytest.mark.parametrize(("progress", "flag"), [(True, "-P"), (False, "-Q")])
@mock.patch("bio2zarr.tskit.convert")
def test_convert_tskit(self, mocked, tmp_path, progress, flag):
ts_path = "tests/data/ts/example.trees"
zarr_path = tmp_path / "zarr"
runner = ct.CliRunner()
result = runner.invoke(
cli.tskit2zarr,
f"convert {ts_path} {zarr_path} {flag}",
catch_exceptions=False,
)
assert result.exit_code == 0
assert len(result.stdout) == 0
assert len(result.stderr) == 0
args = dict(DEFAULT_TSKIT_CONVERT_ARGS)
args["show_progress"] = progress
mocked.assert_called_once_with(
ts_path,
str(zarr_path),
**args,
)

@pytest.mark.parametrize("response", ["y", "Y", "yes"])
@mock.patch("bio2zarr.tskit.convert")
def test_tskit_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response):
ts_path = "tests/data/ts/example.trees"
zarr_path = tmp_path / "zarr"
zarr_path.mkdir()
runner = ct.CliRunner()
result = runner.invoke(
cli.tskit2zarr,
f"convert {ts_path} {zarr_path}",
catch_exceptions=False,
input=response,
)
assert result.exit_code == 0
assert f"Do you want to overwrite {zarr_path}" in result.stdout
assert len(result.stderr) == 0
mocked.assert_called_once_with(
ts_path,
str(zarr_path),
**DEFAULT_TSKIT_CONVERT_ARGS,
)

@pytest.mark.parametrize("response", ["n", "N", "No"])
@mock.patch("bio2zarr.tskit.convert")
def test_tskit_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, response):
ts_path = "tests/data/ts/example.trees"
zarr_path = tmp_path / "zarr"
zarr_path.mkdir()
runner = ct.CliRunner()
result = runner.invoke(
cli.tskit2zarr,
f"convert {ts_path} {zarr_path}",
catch_exceptions=False,
input=response,
)
assert result.exit_code == 1
assert "Aborted" in result.stderr
mocked.assert_not_called()

@pytest.mark.parametrize("force_arg", ["-f", "--force"])
@mock.patch("bio2zarr.tskit.convert")
def test_tskit_convert_overwrite_zarr_force(self, mocked, tmp_path, force_arg):
ts_path = "tests/data/ts/example.trees"
zarr_path = tmp_path / "zarr"
zarr_path.mkdir()
runner = ct.CliRunner()
result = runner.invoke(
cli.tskit2zarr,
f"convert {ts_path} {zarr_path} {force_arg}",
catch_exceptions=False,
)
assert result.exit_code == 0
assert len(result.stdout) == 0
assert len(result.stderr) == 0
mocked.assert_called_once_with(
ts_path,
str(zarr_path),
**DEFAULT_TSKIT_CONVERT_ARGS,
)

@mock.patch("bio2zarr.tskit.convert")
def test_tskit_convert_with_options(self, mocked, tmp_path):
ts_path = "tests/data/ts/example.trees"
zarr_path = tmp_path / "zarr"
runner = ct.CliRunner()
result = runner.invoke(
cli.tskit2zarr,
f"convert {ts_path} {zarr_path} --contig-id chr1 "
"--isolated-as-missing -l 100 -w 50 -p 4",
catch_exceptions=False,
)
assert result.exit_code == 0
assert len(result.stdout) == 0
assert len(result.stderr) == 0

expected_args = dict(DEFAULT_TSKIT_CONVERT_ARGS)
expected_args["contig_id"] = "chr1"
expected_args["isolated_as_missing"] = True
expected_args["variants_chunk_size"] = 100
expected_args["samples_chunk_size"] = 50
expected_args["worker_processes"] = 4

mocked.assert_called_once_with(
ts_path,
str(zarr_path),
**expected_args,
)


class TestVcfEndToEnd:
vcf_path = "tests/data/vcf/sample.vcf.gz"
Expand Down Expand Up @@ -908,10 +1027,36 @@ def test_part_size_multiple_vcfs(self):


@pytest.mark.parametrize(
"cmd", [main.bio2zarr, cli.vcf2zarr_main, cli.plink2zarr, cli.vcfpartition]
"cmd",
[
main.bio2zarr,
cli.vcf2zarr_main,
cli.plink2zarr,
cli.vcfpartition,
cli.tskit2zarr,
],
)
def test_version(cmd):
runner = ct.CliRunner()
result = runner.invoke(cmd, ["--version"], catch_exceptions=False)
s = f"version {provenance.__version__}\n"
assert result.stdout.endswith(s)


class TestTskitEndToEnd:
def test_convert(self, tmp_path):
ts_path = "tests/data/ts/example.trees"
zarr_path = tmp_path / "zarr"
runner = ct.CliRunner()
result = runner.invoke(
cli.tskit2zarr,
f"convert {ts_path} {zarr_path}",
catch_exceptions=False,
)
assert result.exit_code == 0
result = runner.invoke(
cli.vcf2zarr_main, f"inspect {zarr_path}", catch_exceptions=False
)
assert result.exit_code == 0
# Arbitrary check
assert "variant_position" in result.stdout
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_examples(self, chunk_size, size, start, stop):
# It works in CI on Linux, but it'll probably break at some point.
# It's also necessary to update these numbers each time a new data
# file gets added
("tests/data", 5030777),
("tests/data", 5045029),
("tests/data/vcf", 5018640),
("tests/data/vcf/sample.vcf.gz", 1089),
],
Expand Down
Loading