代码拉取完成,页面将自动刷新
from itertools import chain
import os
from pathlib import Path
import dill
import anndata as ad
import networkx as nx
import scanpy as sc
import numpy as np
import pandas as pd
from matplotlib import rcParams
import mindspore.context as context
import multiprocessing as mp
from models.scglue import SCGLUEModel
# Use the "spawn" start method for child processes. Under "spawn", every child
# re-imports this module, and multiprocessing has already set the start method
# inside the child before that import runs; an unconditional
# `mp.set_start_method('spawn')` would then raise
# "RuntimeError: context has already been set". Only set it when it is not
# already "spawn" (a different, already-chosen method still raises as before).
if mp.get_start_method(allow_none=True) != "spawn":
    mp.set_start_method('spawn')
# Configure MindSpore at import time: PYNATIVE_MODE on GPU device 0, with
# max_device_memory capped at "16GB".
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0, max_device_memory="16GB")
# Key under which `configure_dataset` stores each dataset's configuration in `adata.uns`.
ANNDATA_KEY = "__scglue__"
def set_publication_params() -> None:
    """Apply publication-quality figure defaults via scanpy and matplotlib.

    Figures are configured for PDF output at 600 dpi with vector-friendly
    rendering, a white face colour with zero alpha, and ``transparent=False``.
    Saved figures use a tight bounding box.
    """
    figure_kwargs = dict(
        scanpy=True,
        dpi_save=600,
        vector_friendly=True,
        format="pdf",
        facecolor=(1.0, 1.0, 1.0, 0.0),
        transparent=False,
    )
    sc.set_figure_params(**figure_kwargs)
    # Crop surrounding whitespace when figures are written to disk.
    rcParams["savefig.bbox"] = "tight"
def configure_dataset(adata, prob_model, use_highly_variable=True, use_layer=None, use_rep=None, use_batch=None, use_cell_type=None, use_dsc_weight=None, use_obs_names=False):
    """Store a SCGLUE data configuration in ``adata.uns[ANNDATA_KEY]``.

    The configuration records which features, layer, representation and
    ``obs`` columns should be used for this dataset. Calling this again on the
    same object logs a warning and overwrites the previous configuration.

    Parameters
    ----------
    adata
        Dataset to configure. Only ``adata.uns[ANNDATA_KEY]`` is written.
    prob_model
        Name of the probabilistic model; stored verbatim (e.g. ``"NB"``).
    use_highly_variable
        If truthy, use only the ``var`` rows whose ``highly_variable`` column
        is true as features; otherwise use all ``var_names``.
    use_layer
        Key in ``adata.layers``; falsy means no layer (stored as ``None``).
    use_rep
        Key in ``adata.obsm``; its column count is stored as ``rep_dim``.
    use_batch, use_cell_type
        Column in ``adata.obs``. Its sorted, de-duplicated, non-null values
        are stored as ``batches`` / ``cell_types``.
    use_dsc_weight
        Column in ``adata.obs``; only its name is stored (presumably
        per-cell discriminator weights — confirm against the model).
    use_obs_names
        Stored verbatim.

    Raises
    ------
    ValueError
        If ``highly_variable`` is requested but missing from ``adata.var``,
        or a requested layer / ``obsm`` / ``obs`` key does not exist.
    """
    if ANNDATA_KEY in adata.uns:
        # No logging decorator is applied to this function, so the former
        # `configure_dataset.logger` attribute does not exist and a repeated
        # call crashed with AttributeError instead of warning.
        import logging
        logging.getLogger(__name__).warning(
            "`configure_dataset` has already been called. "
            "Previous configuration will be overwritten!"
        )

    def _sorted_unique(column):
        # Sorted, de-duplicated, non-null values of an obs column, returned as
        # a numpy array: AnnData does not support saving pd.Index in uns.
        return pd.Index(adata.obs[column]).dropna().drop_duplicates().sort_values().to_numpy()

    data_config = {"prob_model": prob_model}

    if use_highly_variable:
        if "highly_variable" not in adata.var:
            raise ValueError("Please mark highly variable features first!")
        data_config["use_highly_variable"] = True
        data_config["features"] = adata.var.query("highly_variable").index.to_numpy().tolist()
    else:
        data_config["use_highly_variable"] = False
        data_config["features"] = adata.var_names.to_numpy().tolist()

    if use_layer:
        if use_layer not in adata.layers:
            raise ValueError("Invalid `use_layer`!")
        data_config["use_layer"] = use_layer
    else:
        data_config["use_layer"] = None

    if use_rep:
        if use_rep not in adata.obsm:
            raise ValueError("Invalid `use_rep`!")
        data_config["use_rep"] = use_rep
        data_config["rep_dim"] = adata.obsm[use_rep].shape[1]
    else:
        data_config["use_rep"] = None
        data_config["rep_dim"] = None

    if use_batch:
        if use_batch not in adata.obs:
            raise ValueError("Invalid `use_batch`!")
        data_config["use_batch"] = use_batch
        data_config["batches"] = _sorted_unique(use_batch)
    else:
        data_config["use_batch"] = None
        data_config["batches"] = None

    if use_cell_type:
        if use_cell_type not in adata.obs:
            raise ValueError("Invalid `use_cell_type`!")
        data_config["use_cell_type"] = use_cell_type
        data_config["cell_types"] = _sorted_unique(use_cell_type)
    else:
        data_config["use_cell_type"] = None
        data_config["cell_types"] = None

    if use_dsc_weight:
        if use_dsc_weight not in adata.obs:
            raise ValueError("Invalid `use_dsc_weight`!")
        data_config["use_dsc_weight"] = use_dsc_weight
    else:
        data_config["use_dsc_weight"] = None

    data_config["use_obs_names"] = use_obs_names
    adata.uns[ANNDATA_KEY] = data_config
def load_model(fname):
    """Deserialize and return a model object saved with ``dill``.

    ``fname`` may be a ``str`` or path-like. The file is opened in binary
    mode and closed before the function returns.

    Security: ``dill.load`` (like ``pickle.load``) can execute arbitrary code
    while deserializing, so only load files from trusted sources.
    """
    with open(Path(fname), "rb") as handle:
        return dill.load(handle)
def fit_SCGLUE(adatas, graph, model=SCGLUEModel, init_kws=None, compile_kws=None, fit_kws=None, balance_kws=None, infer=False):
    """Build the SCGLUE pretraining model; unless ``infer``, compile and fit it.

    Parameters
    ----------
    adatas
        Datasets passed to ``model`` and to ``fit``
        (e.g. ``{"rna": ..., "atac": ...}``).
    graph
        Guidance graph. Its sorted node list is passed as the model's second
        positional argument, and the graph itself is passed to ``fit``.
    model
        Model class to instantiate (default ``SCGLUEModel``).
    init_kws, compile_kws, fit_kws
        Optional keyword arguments for the constructor, ``compile`` and
        ``fit``. ``shared_batches=False`` is forced on the constructor and
        ``safe_burnin=False`` on ``fit``. A ``directory`` entry in ``fit_kws``
        has ``"pretrain"`` joined onto it. The caller's dicts are not mutated.
    balance_kws
        Accepted for interface compatibility; currently unused.
    infer
        If true, return the freshly constructed model without compiling or
        fitting it.

    Returns
    -------
    The model instance (untrained when ``infer`` is true).
    """
    print("start fit scglue")
    init_kws = init_kws or {}
    compile_kws = compile_kws or {}
    fit_kws = fit_kws or {}

    print("Pretraining SCGLUE model...")
    pretrain_init_kws = {**init_kws, "shared_batches": False}
    # `align_burnin` is deliberately not overridden here (an earlier version
    # of this code also forced it to np.inf).
    pretrain_fit_kws = {**fit_kws, "safe_burnin": False}
    if "directory" in pretrain_fit_kws:
        pretrain_fit_kws["directory"] = os.path.join(pretrain_fit_kws["directory"], "pretrain")

    pretrain = model(adatas, sorted(graph.nodes), **pretrain_init_kws)
    if not infer:
        print("start compile")
        pretrain.compile(**compile_kws)
        print("start fit")
        pretrain.fit(adatas, graph, **pretrain_fit_kws)
    return pretrain
if __name__ == "__main__":
    # Global figure styling: publication defaults, then a default figure size
    # of (4, 4) inches.
    set_publication_params()
    rcParams["figure.figsize"] = (4, 4)
    # Load the preprocessed RNA and ATAC datasets and the guidance graph
    # (paths are relative to the current working directory).
    rna = ad.read_h5ad("data/rna-pp.h5ad")
    atac = ad.read_h5ad("data/atac-pp.h5ad")
    guidance = nx.read_graphml("data/guidance.graphml.gz")
    print("successfully read the data")
    # Both modalities use the "NB" model on highly variable features. RNA reads
    # the "counts" layer with "X_pca" as representation; ATAC uses "X_lsi".
    configure_dataset(rna, "NB", use_highly_variable=True, use_layer="counts", use_rep="X_pca")
    configure_dataset(atac, "NB", use_highly_variable=True, use_rep="X_lsi")
    print("successfully configure dataset")
    # Restrict the guidance graph to the highly variable features of both
    # modalities; .copy() turns the subgraph view into an independent graph.
    guidance_hvf = guidance.subgraph(chain(rna.var.query("highly_variable").index, atac.var.query("highly_variable").index)).copy()
    # Pretrain; fit_SCGLUE joins "pretrain" onto the directory, so fit receives "glue/pretrain".
    # NOTE(review): `glue` is not used in the visible code; the script may
    # continue in content hidden after this point — confirm.
    glue = fit_SCGLUE({"rna": rna, "atac": atac}, guidance_hvf, fit_kws={"directory": "glue"})
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。