import os
import sys
from typing import List, Dict, Optional, Sequence, Tuple, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
import anndata
import numpy as np
import pandas as pd
import scipy.sparse as sp
import glob
import json
import time
from loguru import logger
logger.remove()
logger.add(sys.stderr, level="DEBUG")
class AnnDataCollection:
"""Wrapper representing a merged collection of AnnData files.
Example:
ds = AnnDataCollection.from_files(["a.h5ad", "b.h5ad"], out_path="merged.h5ad")
view = ds[[True]*100, ["GeneA", "GeneB"]]
adata = view.to_memory()
"""
def __init__(self, adata_paths: List[str],
adata: Optional[anndata.AnnData] = None,
source_info: Optional[List[Dict]] = None):
"""Initialize an `AnnDataCollection` wrapper.
Parameters
- `adata_paths`: list of filesystem paths to the individual `.h5ad` files.
- `adata`: optional `anndata.AnnData` that contains merged metadata
(typically `obs`, `var`, and `obsm`) for the combined dataset. When
present this `AnnData` usually contains an empty or placeholder
`X` matrix because the heavy expression matrices remain stored in
the original files and are loaded on demand by `AnnSubset.to_memory()`.
          If `adata` is `None`, the wrapper does not have merged metadata yet.
- `source_info`: optional list of dictionaries with per-source
metadata (e.g. path, n_obs, var_names). Populated by
`AnnDataCollection.from_files()` when creating a merged dataset.
"""
self.adata_paths = list(adata_paths)
self.adata = adata
self.source_info = source_info or []
@classmethod
def from_files(cls, paths: Sequence[str],
out_path: Optional[str] = None,
metadata_path: Optional[str] = None) -> "AnnDataCollection":
"""Create a merged AnnDataCollection from existing `.h5ad` files.
        All input files must share an identical set and ordering of
        `var_names`; their `obs` tables are stacked in file order. `X` is not
        merged: the saved on-disk AnnData contains an empty sparse matrix of
        shape (n_obs_total, n_vars).
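
        Example (a minimal usage sketch; file names are hypothetical)::

            ds = AnnDataCollection.from_files(
                ["sample1.h5ad", "sample2.h5ad"],
                out_path="merged.h5ad",
                metadata_path="metadata.tsv",  # optional per-cell metadata table
            )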
"""
# `paths` should be an iterable of paths to `.h5ad` files, all .h5ad files must
# have the same vars. When
# `out_path` is provided the merged AnnData (metadata-only) will be
# written to that file. The returned `AnnDataCollection` keeps the list of
# original file paths and stores per-source info in `source_info`.
        # Allow `paths` to be either:
        #     - a single wildcard string: "/path/*.h5ad"
        #     - a sequence of explicit paths (wildcards inside a sequence are
        #       not expanded)
if isinstance(paths, str):
paths = sorted(glob.glob(os.path.expanduser(paths)))
else:
# convert to list of str
paths = [str(p) for p in paths]
if not paths:
raise FileNotFoundError("No .h5ad files found for given paths")
        obs_list = []
        per_source_info: List[Dict[str, Any]] = []
        var_names = None
        # Single pass: collect obs frames and per-source info, and check that
        # all files share the same var_names. Read in backed mode so we don't
        # load `X` into memory.
for i, p in enumerate(paths):
a = anndata.read_h5ad(p, backed="r")
try:
obs = a.obs.copy()
obs["_orig_obs_name"] = a.obs_names.astype(str)
obs["_source_idx"] = i
obs_list.append(obs)
                if var_names is None:
                    var_names = list(a.var_names)
                assert var_names == list(a.var_names), (
                    f"Var names in file {p} differ from previous files; "
                    "this may lead to incorrect merging."
                )
per_source_info.append({
"path": os.path.abspath(p),
"n_obs": a.n_obs,
"n_vars": a.n_vars,
})
finally:
try:
a.file.close()
except Exception:
try:
# fallback for older anndata versions
a._file.close()
except Exception:
pass
# Concatenate obs
merged_obs = pd.concat(obs_list)
merged_obs.index = merged_obs.index.map(str)
        # If metadata_path is provided, read it and keep only the cells listed there.
        if metadata_path is not None:
            metadata_path = os.path.expanduser(metadata_path)
            sep = "\t" if metadata_path.endswith((".tsv", ".tsv.gz", ".txt")) else ","
            meta = pd.read_csv(metadata_path, index_col=0, sep=sep)
            meta_ids = set(meta.index.astype(str).tolist())
            # keep merged_obs order, restricted to cells present in the metadata
            keep_ids = [cell for cell in merged_obs.index.astype(str).tolist() if cell in meta_ids]
            merged_obs = merged_obs.loc[keep_ids]
            # copy over metadata columns that are not already present
            for col in meta.columns:
                if col not in merged_obs.columns:
                    merged_obs[col] = meta.loc[merged_obs.index.tolist(), col].tolist()
        # Build merged var DataFrame (var names shared by all files)
        merged_var = pd.DataFrame(index=pd.Index(var_names))
        # Create an empty sparse X with the proper shape (expression matrices
        # are not merged; they stay in the original files)
X_empty = sp.csr_matrix((merged_obs.shape[0], len(merged_var)), dtype=np.float32)
adata = anndata.AnnData(X_empty, obs=merged_obs, var=merged_var)
        # Note: we intentionally do not merge `obsm` arrays. If per-source
        # `obsm` data is needed, read the individual files listed in
        # `adata.uns["src_paths"]`.
# store metadata (short keys)
adata.uns["src_paths"] = [os.path.abspath(p) for p in paths]
# JSON-serialize per-source info so HDF5 can store it safely
try:
adata.uns["src_info"] = json.dumps(per_source_info)
except Exception:
# fallback: store as a list of strings
adata.uns["src_info"] = [str(x) for x in per_source_info]
if metadata_path is not None:
adata.uns["metadata_path"] = os.path.abspath(metadata_path)
instance = cls(list(paths), adata=adata, source_info=per_source_info)
# optionally save to out_path
if out_path:
adata.write_h5ad(out_path)
return instance
@classmethod
def read(cls, path: str) -> "AnnDataCollection":
"""Load a merged `.h5ad` file created by `from_files` and return an
`AnnDataCollection`.
Expects `src_paths` (or `individual_adata_paths`) in `adata.uns` and
`src_info` which may be JSON-serialized or a list.
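
        Example (illustrative; the file name is hypothetical)::

            ds = AnnDataCollection.read("merged.h5ad")
            print(len(ds), "cells from", len(ds.adata_paths), "source files")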
"""
path = os.path.expanduser(path)
adata = anndata.read_h5ad(path)
# Basic validation: ensure this AnnData was produced by `from_files`.
has_src_uns = ("src_paths" in adata.uns) or ("individual_adata_paths" in adata.uns)
has_obs_markers = ("_source_idx" in adata.obs.columns) or ("_orig_obs_name" in adata.obs.columns)
if not (has_src_uns and has_obs_markers):
raise ValueError(f"File {path} does not appear to be an AnnDataCollection (missing src_paths/src_info or obs markers)")
# Avoid relying on truthiness of containers (e.g., numpy arrays)
sp_raw = adata.uns.get("src_paths", None)
if sp_raw is None:
sp_raw = adata.uns.get("individual_adata_paths", None)
if sp_raw is None:
src_paths = []
else:
# normalize to a list of strings regardless of input type
if isinstance(sp_raw, (list, tuple, pd.Index)):
src_paths = [str(p) for p in sp_raw]
elif isinstance(sp_raw, np.ndarray):
src_paths = [str(p) for p in sp_raw.tolist()]
else:
src_paths = [str(sp_raw)]
raw_info = adata.uns.get("src_info", None)
source_info: List[Dict[str, Any]] = []
if raw_info is None:
source_info = []
elif isinstance(raw_info, str):
try:
source_info = json.loads(raw_info)
except Exception:
source_info = [raw_info]
elif isinstance(raw_info, list):
source_info = raw_info
else:
source_info = [raw_info]
return cls(src_paths, adata=adata, source_info=source_info)
def __len__(self) -> int:
return self.adata.n_obs if self.adata is not None else 0
def __getitem__(self, key: Tuple[Any, Any]) -> "AnnDataView":
"""Support `adatas[cells, genes]` slicing.
`cells` and `genes` may be:
- slice, list of booleans, list of indices, list of labels
Returns an `AnnDataView` with a `to_memory()` method that will load X.
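
        Example (illustrative; cell and gene names are hypothetical)::

            view = ds[["cell_0", "cell_1"], ["GeneA", "GeneB"]]
            sub = view.to_memory()  # anndata.AnnData with X loaded from the source files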
"""
if not isinstance(key, tuple) or len(key) != 2:
raise IndexError("Indexing must be adatas[cells, genes]")
cells, genes = key
obs_idx = self._parse_obs_indexer(cells)
var_names = self._parse_var_indexer(genes)
return AnnDataView(self, obs_idx, var_names)
def update_paths(self, mapping: Optional[Dict[str, str]] = None,
search_dirs: Optional[Sequence[str]] = None,
recursive: bool = True) -> Dict[str, Optional[str]]:
"""Update stored source paths after files have been moved.
Parameters
- `mapping`: optional dict mapping old_path -> new_path. If provided,
entries in `adata_paths` will be replaced according to this map.
- `search_dirs`: optional sequence of directories to search for moved
files by matching basenames (first match wins). Used when `mapping`
is not provided or doesn't contain an entry for a given source.
- `recursive`: if True, search directories recursively when using
`search_dirs`.
Returns a dict mapping the original absolute path -> new absolute
path (or `None` if not found).
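
        Example (illustrative; the directory is hypothetical)::

            moved = ds.update_paths(search_dirs=["/new/storage/h5ad"])
            missing = [old for old, new in moved.items() if new is None]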
"""
result: Dict[str, Optional[str]] = {}
norm_map: Dict[str, str] = {}
if mapping:
for k, v in mapping.items():
try:
norm_map[os.path.abspath(k)] = os.path.abspath(v)
except Exception:
norm_map[str(k)] = str(v)
search_dirs = [os.path.expanduser(d) for d in (search_dirs or [])]
for i, orig in enumerate(list(self.adata_paths)):
orig_abs = os.path.abspath(orig)
new_path: Optional[str] = None
# 1) explicit mapping provided by user
if orig_abs in norm_map:
cand = norm_map[orig_abs]
if os.path.exists(cand):
new_path = cand
# 2) search by basename in provided search_dirs
if new_path is None and search_dirs:
basename = os.path.basename(orig_abs)
for d in search_dirs:
if recursive:
pattern = os.path.join(d, "**", basename)
else:
pattern = os.path.join(d, basename)
matches = sorted(glob.glob(pattern, recursive=recursive))
if matches:
new_path = os.path.abspath(matches[0])
break
# 3) if original still exists, keep it
if new_path is None and os.path.exists(orig_abs):
new_path = orig_abs
result[orig_abs] = new_path
if new_path:
# update internal list
self.adata_paths[i] = new_path
# update adata.uns if present (accept old key as fallback)
if self.adata is not None:
try:
paths_list = list(self.adata.uns.get("src_paths", self.adata.uns.get("individual_adata_paths", [])))
if i < len(paths_list):
paths_list[i] = new_path
else:
# ensure list is long enough
while len(paths_list) < i:
paths_list.append("")
paths_list.append(new_path)
self.adata.uns["src_paths"] = [os.path.abspath(p) for p in paths_list]
except Exception:
# best-effort; don't raise on metadata update
pass
# update source_info if present
if i < len(self.source_info):
try:
self.source_info[i]["path"] = new_path
except Exception:
pass
# re-serialize per-source info into uns for safe HDF5 storage
if self.adata is not None and self.source_info:
try:
self.adata.uns["src_info"] = json.dumps(self.source_info)
except Exception:
self.adata.uns["src_info"] = [str(x) for x in self.source_info]
return result
def _parse_obs_indexer(self, idx) -> np.ndarray:
n = self.adata.n_obs
if idx is None:
return np.arange(n)
if isinstance(idx, slice):
return np.arange(n)[idx]
if isinstance(idx, (list, np.ndarray, pd.Series)):
arr = np.asarray(idx)
if arr.dtype == bool:
return np.nonzero(arr)[0] # get indexes where True
# If integer dtype, treat as indices. Otherwise treat as labels.
if np.issubdtype(arr.dtype, np.integer):
return arr.astype(int)
            # treat as obs names (strings or object); fail loudly on unknown labels
            pos = np.asarray(self.adata.obs_names.get_indexer(arr.astype(str)), dtype=int)
            if (pos < 0).any():
                raise KeyError("some requested obs names were not found in the collection")
            return pos
        raise IndexError("Unsupported obs indexer")
def _parse_var_indexer(self, idx) -> List[str]:
        if idx is None:
return list(self.adata.var_names)
if isinstance(idx, slice):
return list(self.adata.var_names[idx])
if isinstance(idx, (list, np.ndarray, pd.Series)):
arr = np.asarray(idx)
if arr.dtype == bool:
return list(np.asarray(self.adata.var_names)[arr])
# If integer dtype, interpret as indices into var_names
if np.issubdtype(arr.dtype, np.integer):
return list(np.asarray(self.adata.var_names)[arr.astype(int)])
# Otherwise treat as labels (string/object/unicode)
return list(arr.astype(str))
if isinstance(idx, str):
return [idx]
raise IndexError("Unsupported var indexer")
def is_annadatacollection(path: str) -> bool:
"""Return True if `path` points to an AnnData file produced by
`AnnDataCollection.from_files`.
Heuristic: the AnnData must have `uns` entries `src_paths` or
`individual_adata_paths` and its `obs` must contain either
`_source_idx` or `_orig_obs_name` markers.
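
    Example (illustrative; the file name is hypothetical)::

        if is_annadatacollection("merged.h5ad"):
            ds = AnnDataCollection.read("merged.h5ad")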
"""
if path is None:
return False
path = os.path.expanduser(path)
try:
ad = anndata.read_h5ad(path, backed="r")
except Exception:
return False
try:
has_src_uns = ("src_paths" in ad.uns) or ("individual_adata_paths" in ad.uns)
# obs may be empty; check columns safely
has_obs_markers = False
try:
has_obs_markers = ("_source_idx" in ad.obs.columns) or ("_orig_obs_name" in ad.obs.columns)
except Exception:
has_obs_markers = False
return bool(has_src_uns and has_obs_markers)
finally:
try:
ad.file.close()
except Exception:
try:
ad._file.close()
except Exception:
pass
class AnnDataView:
"""Helper representing a (cells, genes) subset of an `AnnDataCollection`.
Call `to_memory()` to read and assemble the actual `AnnData` with `X`.
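
    Example (illustrative; `ds` is an existing `AnnDataCollection` and the gene name is hypothetical)::

        view = AnnDataView(ds, obs_idx=np.arange(100), var_names=["GeneA"])
        adata = view.to_memory()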
"""
    def __init__(self, dataset: AnnDataCollection,
                 obs_idx: Optional[np.ndarray] = None,
                 var_names: Optional[List[str]] = None):
        self.dataset = dataset
        # `None` means "all obs" / "all vars"; resolved in `to_memory()`
        self.obs_idx = np.asarray(obs_idx, dtype=int) if obs_idx is not None else None
        self.var_names = list(var_names) if var_names is not None else None
def to_memory(self, thread=8) -> anndata.AnnData:
"""Load selected `X` chunks from underlying files and assemble AnnData.
Returns a new `anndata.AnnData` with concatenated `X` for the selected
cells and genes.
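
        `thread` caps the number of worker threads used to read the source
        files in parallel.

        Example (illustrative; gene names are hypothetical)::

            adata = ds[:, ["GeneA", "GeneB"]].to_memory(thread=4)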
"""
ds = self.dataset
merged = ds.adata
# Build target obs and var
        if self.obs_idx is not None:
            sel_obs = merged.obs.iloc[self.obs_idx].copy()
        else:
            sel_obs = merged.obs.copy()
        if self.var_names is not None:
            sel_var = merged.var.loc[self.var_names].copy()
        else:
            sel_var = merged.var.copy()
n_rows = len(sel_obs)
n_cols = len(sel_var)
# Prepare containers for COO assembly
rows_all: List[int] = []
cols_all: List[int] = []
data_all: List[float] = []
# Map merged obs rows to sources and cache some locals for speed
src_indices = sel_obs["_source_idx"].values
sel_orig_obs = sel_obs["_orig_obs_name"].to_numpy() # cell_ids in original source files, aligned with sel_obs rows
paths = ds.adata_paths
var_names = sel_var.index.astype(str).tolist()
# build mapping dict for fast lookup
mapping = {v: idx for idx, v in enumerate(merged.var_names)} # type: ignore
src_gene_pos_all = np.array([mapping.get(v, -1) for v in var_names], dtype=int)
# use precomputed mapping for this source
present_mask = src_gene_pos_all >= 0
if not np.any(present_mask):
return None
src_col_pos_in_global = np.nonzero(present_mask)[0]
src_gene_pos = src_gene_pos_all[present_mask]
        def _process_source(src_idx: int, path: str):
            # boolean mask over the selected cells that belong to this source file
            mask = src_indices == src_idx
            if not np.any(mask):
                return None
            # row positions of these cells in the assembled output matrix
            global_rows = np.nonzero(mask)[0]
            # original obs names in the source file, aligned with `global_rows`
            orig_obs_names = sel_orig_obs[mask]
a = anndata.read_h5ad(path, backed="r")
try:
                # backed label indexing raises KeyError if any obs name is missing
                rows_ann = a[orig_obs_names, :]
X_rows = rows_ann.X
if sp.issparse(X_rows):
X_rows = X_rows.tocsr()
else:
X_rows = sp.csr_matrix(X_rows)
X_sub = X_rows[:, src_gene_pos]
if not sp.issparse(X_sub):
X_sub = sp.csr_matrix(X_sub)
X_coo = X_sub.tocoo()
if X_coo.nnz == 0:
return None
global_rows_arr = np.asarray(global_rows)
src_col_pos_arr = np.asarray(src_col_pos_in_global)
rows_mapped = global_rows_arr[X_coo.row].astype(np.int32)
cols_mapped = src_col_pos_arr[X_coo.col].astype(np.int32)
data_arr = X_coo.data.astype(np.float32)
return rows_mapped, cols_mapped, data_arr
finally:
try:
a.file.close()
except Exception:
try:
a._file.close()
except Exception:
pass
# Process sources in parallel to utilize multiple cores / IO
max_workers = min(thread, max(1, len(paths)))
with ThreadPoolExecutor(max_workers=max_workers) as exe:
futures = {exe.submit(_process_source, i, p): (i, p) for i, p in enumerate(paths)}
for fut in as_completed(futures):
                try:
                    res = fut.result()
                except Exception as exc:
                    logger.error("Error processing source {}: {}", futures[fut][1], exc)
                    continue
if res is None:
continue
rows_mapped, cols_mapped, data = res
rows_all.append(rows_mapped)
cols_all.append(cols_mapped)
data_all.append(data)
if len(data_all) == 0:
X_final = sp.csr_matrix((n_rows, n_cols), dtype=np.float32)
else:
rows_cat = np.concatenate(rows_all).astype(int)
cols_cat = np.concatenate(cols_all).astype(int)
data_cat = np.concatenate(data_all).astype(np.float32)
X_final = sp.coo_matrix((data_cat, (rows_cat, cols_cat)), shape=(n_rows, n_cols)).tocsr()
out = anndata.AnnData(X_final, obs=sel_obs, var=sel_var)
return out
def test_subset():
    os.chdir(os.path.expanduser("/home/x-wding2/Projects/mouse_dev/testAnnDataCollection"))
    adata_path = "/anvil/projects/x-mcb130189/Wubin/mouse_dev/adata/100kb/*-CGN.h5ad"
    reference_path = "/anvil/projects/x-mcb130189/Wubin/mouse_dev/adata/mouse_dev.100kb-CGN.h5ad"
    metadata_path = os.path.expanduser("~/Projects/mouse_dev/metadata/metadata.filtered.tsv.gz")
    ds = AnnDataCollection.from_files(adata_path, out_path="mouse_dev.100kb-CGN.h5ad",
                                      metadata_path=metadata_path)
    ref = anndata.read_h5ad(reference_path, backed="r")
    use_cells = ref.obs.sample(5000).index.tolist()
    start = time.perf_counter()
    ref_data = ref[use_cells, :].to_df()
    end = time.perf_counter()
    ref.file.close()
    logger.info(f"Reference read time for 5000 cells: {end - start:.2f} seconds")
    start = time.perf_counter()
    data = ds[use_cells, :].to_memory(thread=10).to_df()
    end = time.perf_counter()
    logger.info(f"AnnDataCollection read time for 5000 cells: {end - start:.2f} seconds")
    ref_data.index.name = 'cell'
    assert ref_data.equals(data)
    pd.testing.assert_frame_equal(ref_data, data, check_dtype=True, rtol=1e-6, atol=0)
    # subset vars
    ref = anndata.read_h5ad(reference_path, backed="r")
    use_vars = ref.var.sample(50).index.tolist()
    start = time.perf_counter()
    ref_data = ref[:, use_vars].to_df()
    end = time.perf_counter()
    logger.info(f"Reference read time for 50 vars: {end - start:.2f} seconds")
    ref.file.close()
    start = time.perf_counter()
    data = ds[:, use_vars].to_memory(thread=10).to_df()
    end = time.perf_counter()
    logger.info(f"AnnDataCollection read time for 50 vars: {end - start:.2f} seconds")
    ref_data.index.name = 'cell'
    data = data.loc[ref_data.index.tolist()]
    assert ref_data.equals(data)
    pd.testing.assert_frame_equal(ref_data, data, check_dtype=True, rtol=1e-6, atol=0)
def main():
import fire
fire.core.Display = lambda lines, out: print(*lines, file=out) # type: ignore
fire.Fire()
if __name__ == "__main__":
main()