157 changes: 95 additions & 62 deletions ALLCools/count_matrix/dataset.py
@@ -1,15 +1,21 @@
import pathlib
import subprocess
import tempfile
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import lru_cache
from shutil import rmtree

import numpy as np
import pandas as pd
import pybedtools
import pysam
import xarray as xr
import zarr
import zarr.convenience
import zarr.creation
import zarr.hierarchy
import zarr.storage
from numcodecs import blosc
from scipy import stats

from ALLCools.utilities import parse_chrom_size, parse_mc_pattern
@@ -68,11 +74,9 @@ def summary(self):
return mc_type_data


def _determine_datasets(regions, quantifiers, chrom_size_path, tmp_dir):
def _determine_datasets(regions, quantifiers, chrom_size_path):
"""Determine datasets for each region."""
tmp_dir = pathlib.Path(tmp_dir).absolute()
tmp_dir.mkdir(exist_ok=True, parents=True)

tmpdir = tempfile.mkdtemp()
chrom_sizes = parse_chrom_size(chrom_size_path)
datasets = {}
for pair in regions:
@@ -94,7 +98,7 @@ def _determine_datasets(regions, quantifiers, chrom_size_path, tmp_dir):
"do not have index in its fourth column, adding it automatically. "
"If this is not desired, add a fourth column containing UNIQUE IDs to the BED file.",
)
region_bed_df[name] = [f"{name}_{i}" for i in range(region_bed_df.shape[0])]
region_bed_df[name] = (f"{name}_{i}" for i in range(region_bed_df.shape[0]))
# check if name is unique()
if region_bed_df.iloc[:, 3].duplicated().sum() > 0:
raise ValueError(f"Region IDs in {region_path} (fourth column) are not unique.")
Expand Down Expand Up @@ -122,7 +126,7 @@ def _id(i, c=chrom):

except ValueError:
raise ValueError(f"Can not understand region specification {region_path}")
region_path = f"{tmp_dir}/{name}.regions.csv"
region_path = f"{tmpdir}/{name}.regions.csv"
region_bed_df.to_csv(region_path)
datasets[name] = {"regions": region_path, "quant": []}

@@ -152,7 +156,7 @@ def _id(i, c=chrom):
if quant_type not in ALLOW_QUANT_TYPES:
raise ValueError(f"QUANT_TYPE need to be in {ALLOW_QUANT_TYPES}, got {quant_type} in {quantifier}.")
datasets[name]["quant"].append(_Quant(mc_types=mc_types, quant_type=quant_type, kwargs=kwargs))
return datasets
return datasets, tmpdir


def _count_single_region_set(allc_table, region_config, obs_dim, region_dim):
@@ -208,15 +212,13 @@ def _calculate_pv(data, reverse_value, obs_dim, var_dim, cutoff=0.9):


def _count_single_zarr(
allc_table, region_config, obs_dim, region_dim, output_path, obs_dim_dtype, count_dtype="uint32"
allc_table, region_config, obs_dim, obs_dim_dtype, region_dim, chunk_start, regiongroup, count_dtype="uint32"
):
"""Process single region set and its quantifiers."""
# count all ALLC and mC types that's needed for quantifiers if this region_dim
count_ds = _count_single_region_set(
allc_table=allc_table, region_config=region_config, obs_dim=obs_dim, region_dim=region_dim
)

total_ds = {}
# deal with count quantifiers
count_mc_types = []
for quant in region_config["quant"]:
@@ -227,8 +229,9 @@ def _count_single_zarr(
count_da = count_ds.sel(mc_type=count_mc_types)[f"{region_dim}_da"]
max_int = np.iinfo(count_dtype).max
count_da = xr.where(count_da > max_int, max_int, count_da)
total_ds[f"{region_dim}_da"] = count_da.astype(count_dtype)

regiongroup[f"{region_dim}_da"][chunk_start : chunk_start + allc_table.index.size, :, :, :] = count_da.astype(
count_dtype
).data
# deal with hypo-score, hyper-score quantifiers
for quant in region_config["quant"]:
if quant.quant_type == "hypo-score":
@@ -240,7 +243,9 @@ def _count_single_zarr(
var_dim=region_dim,
**quant.kwargs,
)
total_ds[f"{region_dim}_da_{mc_type}-hypo-score"] = data
regiongroup[f"{region_dim}_da_{mc_type}-hypo-score"][
chunk_start : chunk_start + allc_table.index.size, :
] = data.data
elif quant.quant_type == "hyper-score":
for mc_type in quant.mc_types:
data = _calculate_pv(
@@ -250,11 +255,13 @@ def _count_single_zarr(
var_dim=region_dim,
**quant.kwargs,
)
total_ds[f"{region_dim}_da_{mc_type}-hyper-score"] = data
total_ds = xr.Dataset(total_ds)
total_ds.coords[obs_dim] = total_ds.coords[obs_dim].astype(obs_dim_dtype)
total_ds.to_zarr(output_path, mode="w")
return output_path
regiongroup[f"{region_dim}_da_{mc_type}-hyper-score"][
chunk_start : chunk_start + allc_table.index.size, :
] = data.data
regiongroup[obs_dim][chunk_start : chunk_start + allc_table.index.size] = (
count_ds.coords[obs_dim].astype(obs_dim_dtype).data
)
return True
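
Note on the pattern above: the rewritten _count_single_zarr no longer assembles an xr.Dataset per chunk and writes it to a temporary store; each worker now writes its block of samples directly into a preallocated zarr array by slice assignment and simply returns True. A minimal sketch of that pattern, with hypothetical names and toy shapes (not ALLCools API):

import numpy as np
import zarr

# preallocate an array chunked along the sample axis
root = zarr.hierarchy.group(store=zarr.storage.DirectoryStore("demo.zarr"), overwrite=True)
da = root.empty(name="region_da", shape=(100, 50), chunks=(10, 50), dtype="uint32")

def write_block(group, chunk_start, n_rows):
    # stand-in for the per-chunk counts from _count_single_region_set
    block = np.ones((n_rows, 50), dtype="uint32")
    # rows [chunk_start, chunk_start + n_rows) align with whole zarr chunks,
    # so concurrent workers never touch the same chunk
    group["region_da"][chunk_start : chunk_start + n_rows, :] = block
    return True

write_block(root, 0, 10)
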


@doc_params(
@@ -311,68 +318,92 @@ def generate_dataset(

# prepare regions and determine quantifiers
pathlib.Path(output_path).mkdir(exist_ok=True)
tmp_dir = f"{output_path}_tmp"
datasets = _determine_datasets(regions, quantifiers, chrom_size_path, tmp_dir)

z = zarr.storage.DirectoryStore(path=output_path)
root = zarr.hierarchy.group(store=z, overwrite=True)
datasets, tmpdir = _determine_datasets(regions, quantifiers, chrom_size_path)
# copy chrom_size_path to output_path
subprocess.run(["cp", "-f", chrom_size_path, f"{output_path}/chrom_sizes.txt"], check=True)

chunk_records = defaultdict(dict)
rgs = {}
for region_dim, region_config in datasets.items():
rgs[region_dim] = root.create_group(region_dim)
# save region coords to the ds
bed = pd.read_csv(f"{tmpdir}/{region_dim}.regions.csv", index_col=0)
bed.columns = [f"{region_dim}_chrom", f"{region_dim}_start", f"{region_dim}_end"]
bed.index.name = region_dim
region_size = bed.index.size
# append region bed to the saved ds
ds = xr.Dataset()
for col, data in bed.items():
ds.coords[col] = data
ds.coords[region_dim] = bed.index.values
# change object dtype to string
for k in ds.coords.keys():
if ds.coords[k].dtype == "O":
ds.coords[k] = ds.coords[k].astype(str)
ds.to_zarr(f"{output_path}/{region_dim}", mode="w", consolidated=False)
dsobs = rgs[region_dim].empty(
name=obs_dim, shape=allc_table.index.size, chunks=(chunk_size), dtype=f"<U{max_length}"
)
dsobs.attrs["_ARRAY_DIMENSIONS"] = [obs_dim]
count_mc_types = []
for quant in region_config["quant"]:
if quant.quant_type == "count":
count_mc_types += quant.mc_types
count_mc_types = list(set(count_mc_types))
if len(count_mc_types) > 0:
DA = rgs[region_dim].empty(
name=f"{region_dim}_da",
shape=(n_sample, region_size, len(count_mc_types), 2),
chunks=(chunk_size, region_size, len(count_mc_types), 2),
dtype="uint32",
)
DA.attrs["_ARRAY_DIMENSIONS"] = [obs_dim, region_dim, "mc_type", "count_type"]
count = rgs[region_dim].array(name="count_type", data=(["mc", "cov"]), dtype="<U3")
count.attrs["_ARRAY_DIMENSIONS"] = ["count_type"]
mc = rgs[region_dim].array(name="mc_type", data=count_mc_types, dtype="<U3")
mc.attrs["_ARRAY_DIMENSIONS"] = ["mc_type"]
# deal with hypo-score, hyper-score quantifiers
for quant in region_config["quant"]:
if quant.quant_type == "hypo-score":
for mc_type in quant.mc_types:
hypo = rgs[region_dim].empty(
name=f"{region_dim}_da_{mc_type}-hypo-score",
shape=(allc_table.size, region_size),
chunks=(chunk_size, region_size),
dtype="float16",
)
hypo.attrs["_ARRAY_DIMENSIONS"] = [obs_dim, region_dim]
elif quant.quant_type == "hyper-score":
for mc_type in quant.mc_types:
hyper = rgs[region_dim].empty(
name=f"{region_dim}_da_{mc_type}-hyper-score",
shape=(allc_table.size, region_size),
chunks=(chunk_size, region_size),
dtype="float16",
)
hyper.attrs["_ARRAY_DIMENSIONS"] = [obs_dim, region_dim]
blosc.use_threads = False
with ProcessPoolExecutor(cpu) as exe:
futures = {}
# parallel on allc chunks and region_sets levels
for i, chunk_start in enumerate(range(0, n_sample, chunk_size)):
allc_chunk = allc_table[chunk_start : chunk_start + chunk_size]
for region_dim, region_config in datasets.items():
chunk_path = f"{tmp_dir}/chunk_{region_dim}_{chunk_start}.zarr"
f = exe.submit(
_count_single_zarr,
allc_table=allc_chunk,
region_config=region_config,
obs_dim=obs_dim,
region_dim=region_dim,
output_path=chunk_path,
obs_dim_dtype=obs_dim_dtype,
region_dim=region_dim,
chunk_start=chunk_start,
regiongroup=rgs[region_dim],
)
futures[f] = (region_dim, i)

for f in as_completed(futures):
region_dim, i = futures[f]
chunk_path = f.result()
print(f"Chunk {i} of {region_dim} returned")
chunk_records[region_dim][i] = chunk_path

for region_dim, chunks in chunk_records.items():
# write chunk in order
chunk_paths = pd.Series(chunks).sort_index().tolist()
for i, chunk_path in enumerate(chunk_paths):
ds = xr.open_zarr(chunk_path).load()
# dump chunk to final place
if i == 0:
# first chunk
ds.to_zarr(f"{output_path}/{region_dim}", mode="w")
else:
# append
ds.to_zarr(f"{output_path}/{region_dim}", append_dim=obs_dim)
rmtree(chunk_path)

# save region coords to the ds
bed = pd.read_csv(f"{tmp_dir}/{region_dim}.regions.csv", index_col=0)
bed.columns = [f"{region_dim}_chrom", f"{region_dim}_start", f"{region_dim}_end"]
bed.index.name = region_dim
# append region bed to the saved ds
ds = xr.Dataset()
for col, data in bed.items():
ds.coords[col] = data
# change object dtype to string
for k in ds.coords.keys():
if ds.coords[k].dtype == "O":
ds.coords[k] = ds.coords[k].astype(str)
ds.to_zarr(f"{output_path}/{region_dim}", mode="a")

# delete tmp
rmtree(tmp_dir)

blosc.use_threads = None
from ..mcds.utilities import update_dataset_config

update_dataset_config(
@@ -383,4 +414,6 @@ def generate_dataset(
"ds_sample_dim": {region_dim: obs_dim for region_dim in datasets.keys()},
},
)
for region_dim in datasets.keys():
zarr.convenience.consolidate_metadata(f"{output_path}/{region_dim}")
return output_path
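
Because each region group now gets consolidated metadata, the generated store can be read back directly with xarray (or through ALLCools' higher-level MCDS wrapper). A usage sketch, assuming an output path mydata.mcds with a region set named chrom100k, obs_dim "cell", and a counted mC type "CGN"; all of these names are hypothetical:

import xarray as xr

ds = xr.open_zarr("mydata.mcds/chrom100k", consolidated=True)
# count quantifiers land in "chrom100k_da" with dims
# (cell, chrom100k, mc_type, count_type); pick mC and coverage counts
mc = ds["chrom100k_da"].sel(mc_type="CGN", count_type="mc")
cov = ds["chrom100k_da"].sel(mc_type="CGN", count_type="cov")
frac = (mc / cov.where(cov > 0)).compute()  # methylation fraction per cell and region
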
2 changes: 1 addition & 1 deletion environment.yml
@@ -30,7 +30,7 @@ dependencies:
- statsmodels
- xarray
- yaml
- zarr
- zarr < 3
- pip:
- papermill
- imblearn
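
The zarr < 3 pin (added here and in pyproject.toml below) matches the writer's reliance on zarr-python 2.x entry points such as zarr.storage.DirectoryStore, zarr.hierarchy.group, and zarr.convenience.consolidate_metadata, which were removed or renamed in zarr-python 3 (DirectoryStore, for example, was superseded by LocalStore). A minimal runtime guard, as a sketch:

import zarr

# dataset.py above assumes the zarr-python 2.x module layout
assert int(zarr.__version__.split(".")[0]) < 3, "zarr-python 2.x required"
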
1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
'seaborn',
"xarray",
"pyyaml",
"zarr < 3"
]

[project.optional-dependencies]