加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 5.25 KB
一键复制 编辑 原始数据 按行查看 历史
from itertools import chain
import os
from pathlib import Path
import dill
import anndata as ad
import networkx as nx
import scanpy as sc
import numpy as np
import pandas as pd
from matplotlib import rcParams
import mindspore.context as context
import multiprocessing as mp
from models.scglue import SCGLUEModel
# Use the "spawn" start method for worker processes. The guard matters: a
# spawned child sets the start method itself and then re-executes this module
# as ``__mp_main__``, where an unconditional ``set_start_method`` raises
# ``RuntimeError: context has already been set``. A different, already-set
# method still raises, exactly as before.
if mp.get_start_method(allow_none=True) != "spawn":
    mp.set_start_method('spawn')
# MindSpore runtime: PyNative execution mode on GPU device 0, capped at 16GB.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0, max_device_memory="16GB")
# Key under ``adata.uns`` where ``configure_dataset`` stores its configuration.
ANNDATA_KEY = "__scglue__"
def set_publication_params() -> None:
    r"""
    Apply figure settings intended for publication output.

    Configures scanpy figures to be saved as 600-dpi vector-friendly PDFs
    with a white, zero-alpha face colour (``transparent=False``), and makes
    matplotlib crop saved figures tightly to their content.
    """
    figure_settings = dict(
        scanpy=True,
        dpi_save=600,
        vector_friendly=True,
        format="pdf",
        facecolor=(1.0, 1.0, 1.0, 0.0),
        transparent=False,
    )
    sc.set_figure_params(**figure_settings)
    rcParams["savefig.bbox"] = "tight"
def _sorted_unique_labels(labels):
    r"""Return the distinct non-null values of ``labels``, sorted, as a numpy array."""
    # AnnData does not support saving pd.Index in uns, hence ``to_numpy``.
    return pd.Index(labels).dropna().drop_duplicates().sort_values().to_numpy()


def configure_dataset(adata, prob_model, use_highly_variable=True, use_layer=None, use_rep=None, use_batch=None, use_cell_type=None, use_dsc_weight=None, use_obs_names=False):
    r"""
    Record how a dataset should be consumed, storing it in ``adata.uns[ANNDATA_KEY]``.

    A configuration from an earlier call is overwritten, with a logged warning.

    Parameters
    ----------
    adata
        AnnData object to configure; modified in place.
    prob_model
        Name of the probabilistic data model, stored verbatim
        (e.g. ``"NB"`` in this file's ``__main__`` section).
    use_highly_variable
        If true, use only features flagged in ``adata.var["highly_variable"]``;
        otherwise use all ``adata.var_names``.
    use_layer
        Key in ``adata.layers`` to read data from, or ``None``.
    use_rep
        Key in ``adata.obsm`` with a precomputed representation, or ``None``.
        Its second dimension is stored as ``rep_dim``.
    use_batch
        Column of ``adata.obs`` with batch labels, or ``None``. Its sorted
        unique non-null values are stored as ``batches``.
    use_cell_type
        Column of ``adata.obs`` with cell-type labels, or ``None``. Its sorted
        unique non-null values are stored as ``cell_types``.
    use_dsc_weight
        Column of ``adata.obs`` with per-cell weights (presumably for the
        discriminator, judging by the name -- confirm), or ``None``.
    use_obs_names
        Stored verbatim in the configuration.

    Raises
    ------
    ValueError
        If highly variable features are requested but not marked, or if a
        given layer / obsm / obs key is missing from ``adata``.
    """
    import logging

    if ANNDATA_KEY in adata.uns:
        # No ``logger`` attribute is attached to this function, so use a
        # module logger; ``configure_dataset.logger`` raised AttributeError.
        logging.getLogger(__name__).warning(
            "`configure_dataset` has already been called. "
            "Previous configuration will be overwritten!"
        )

    data_config = {"prob_model": prob_model}

    if use_highly_variable:
        if "highly_variable" not in adata.var:
            raise ValueError("Please mark highly variable features first!")
        data_config["use_highly_variable"] = True
        data_config["features"] = adata.var.query("highly_variable").index.to_numpy().tolist()
    else:
        data_config["use_highly_variable"] = False
        data_config["features"] = adata.var_names.to_numpy().tolist()

    if use_layer:
        if use_layer not in adata.layers:
            raise ValueError("Invalid `use_layer`!")
        data_config["use_layer"] = use_layer
    else:
        data_config["use_layer"] = None

    if use_rep:
        if use_rep not in adata.obsm:
            raise ValueError("Invalid `use_rep`!")
        data_config["use_rep"] = use_rep
        data_config["rep_dim"] = adata.obsm[use_rep].shape[1]
    else:
        data_config["use_rep"] = None
        data_config["rep_dim"] = None

    if use_batch:
        if use_batch not in adata.obs:
            raise ValueError("Invalid `use_batch`!")
        data_config["use_batch"] = use_batch
        data_config["batches"] = _sorted_unique_labels(adata.obs[use_batch])
    else:
        data_config["use_batch"] = None
        data_config["batches"] = None

    if use_cell_type:
        if use_cell_type not in adata.obs:
            raise ValueError("Invalid `use_cell_type`!")
        data_config["use_cell_type"] = use_cell_type
        data_config["cell_types"] = _sorted_unique_labels(adata.obs[use_cell_type])
    else:
        data_config["use_cell_type"] = None
        data_config["cell_types"] = None

    if use_dsc_weight:
        if use_dsc_weight not in adata.obs:
            raise ValueError("Invalid `use_dsc_weight`!")
        data_config["use_dsc_weight"] = use_dsc_weight
    else:
        data_config["use_dsc_weight"] = None

    data_config["use_obs_names"] = use_obs_names
    adata.uns[ANNDATA_KEY] = data_config
def load_model(fname):
    r"""
    Load and return the object serialized in file ``fname`` via ``dill.load``.

    Security note: dill, like pickle, can execute arbitrary code while
    deserializing. Only load files that come from a trusted source.
    """
    with open(Path(fname), "rb") as handle:
        return dill.load(handle)
def fit_SCGLUE(adatas, graph, model=SCGLUEModel, init_kws=None, compile_kws=None, fit_kws=None, balance_kws=None, infer=False):
    r"""
    Build a pretraining model and, unless ``infer`` is set, compile and fit it.

    Parameters
    ----------
    adatas
        Datasets keyed by modality name; passed unchanged to ``model`` and
        ``model.fit``.
    graph
        Guidance graph. Its sorted node list is passed to ``model`` and the
        graph itself to ``model.fit``.
    model
        Model class to instantiate (default ``SCGLUEModel``).
    init_kws
        Extra constructor keyword arguments. ``shared_batches`` is always
        forced to ``False`` for pretraining.
    compile_kws
        Keyword arguments for ``model.compile``.
    fit_kws
        Keyword arguments for ``model.fit``. ``safe_burnin`` is always forced
        to ``False``, and a ``directory`` entry is redirected to its
        ``pretrain`` subdirectory.
    balance_kws
        Accepted for interface compatibility; currently unused.
    infer
        If true, return the freshly constructed model without compiling or
        fitting it.

    Returns
    -------
    The pretraining model instance.
    """
    print("start fit scglue")
    print("Pretraining SCGLUE model...")

    # Work on copies so the overrides below never mutate the caller's dicts.
    pretrain_init_kws = dict(init_kws or {})
    pretrain_init_kws["shared_batches"] = False

    pretrain_fit_kws = dict(fit_kws or {})
    # NOTE(review): only ``safe_burnin`` is overridden here; ``align_burnin``
    # is left as supplied by the caller -- confirm this is intended.
    pretrain_fit_kws["safe_burnin"] = False
    if "directory" in pretrain_fit_kws:
        pretrain_fit_kws["directory"] = os.path.join(pretrain_fit_kws["directory"], "pretrain")

    pretrain = model(adatas, sorted(graph.nodes), **pretrain_init_kws)
    if infer:
        return pretrain

    print("start compile")
    pretrain.compile(**(compile_kws or {}))
    print("start fit")
    pretrain.fit(adatas, graph, **pretrain_fit_kws)
    return pretrain
if __name__ == "__main__":
    # Figure defaults for any plots produced in this session.
    set_publication_params()
    rcParams["figure.figsize"] = (4, 4)

    # Preprocessed RNA and ATAC datasets plus the feature guidance graph.
    rna, atac = (ad.read_h5ad(path) for path in ("data/rna-pp.h5ad", "data/atac-pp.h5ad"))
    guidance = nx.read_graphml("data/guidance.graphml.gz")
    print("successfully read the data")

    configure_dataset(rna, "NB", use_highly_variable=True, use_layer="counts", use_rep="X_pca")
    configure_dataset(atac, "NB", use_highly_variable=True, use_rep="X_lsi")
    print("successfully configure dataset")

    # Restrict the guidance graph to the highly variable features of both modalities.
    hvf_nodes = chain(
        rna.var.query("highly_variable").index,
        atac.var.query("highly_variable").index,
    )
    guidance_hvf = guidance.subgraph(hvf_nodes).copy()

    datasets = {"rna": rna, "atac": atac}
    glue = fit_SCGLUE(datasets, guidance_hvf, fit_kws={"directory": "glue"})
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化