加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
training.py 2.07 KB
一键复制 编辑 原始数据 按行查看 历史
Louison5 提交于 2023-07-12 14:35 . Update format data
# from main import set_publication_params, configure_dataset, fit_SCGLUE
import scglue
import matplotlib.pyplot as plt
import scanpy as sc
import pandas as pd
from matplotlib import rcParams
import anndata as ad
from pathlib import Path
import networkx as nx
from itertools import chain
import mindspore.context as context
import multiprocessing as mp
# mp.set_start_method('spawn')
# Run MindSpore in static-graph mode on the GPU device.
context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
# NOTE(review): not referenced anywhere in this script; presumably mirrors the
# key scglue uses to stash its configuration in AnnData.uns — confirm or drop.
ANNDATA_KEY = "__scglue__"
# Apply scglue's publication plotting defaults first, then override the
# default figure size (matplotlib figure.figsize, in inches).
scglue.plot.set_publication_params()
rcParams.update({"figure.figsize": (4, 4)})
# Load the (presumably preprocessed, per the "-pp" suffix) RNA and ATAC
# AnnData objects plus the guidance graph from the local data/ directory.
rna, atac = ad.read_h5ad("data/rna-pp.h5ad"), ad.read_h5ad("data/atac-pp.h5ad")
guidance = nx.read_graphml("data/prior.graphml.gz")
print("successfully read the data")

# Register both modalities with scglue using the "NB" probability model,
# restricted to highly variable features. RNA reads the "counts" layer and
# the "X_pca" representation; ATAC uses the "X_lsi" representation.
scglue.models.configure_dataset(
    rna,
    "NB",
    use_highly_variable=True,
    use_layer="counts",
    use_rep="X_pca",
)
scglue.models.configure_dataset(
    atac,
    "NB",
    use_highly_variable=True,
    use_rep="X_lsi",
)
print("successfully configure dataset")
# Keep only the guidance-graph vertices that are highly variable features in
# either modality. networkx's subgraph() is a view, so copy() materialises an
# independent graph.
hvf_names = chain(
    rna.var.query("highly_variable").index,
    atac.var.query("highly_variable").index,
)
guidance_hvf = guidance.subgraph(hvf_names).copy()

# Fit the GLUE model on both modalities; training output is written under
# "glue" (checkpoints such as glue/pretrain/*.ckpt, judging by the load call
# that follows). Settings tried previously and left disabled:
#   fit_kws: "val_split": 1e-5, "max_epochs": 50
#   compile_kws={"lr": 5e-3}
glue = scglue.models.fit_SCGLUE(
    {"rna": rna, "atac": atac},
    guidance_hvf,
    fit_kws={"directory": "glue"},
)
# To reuse an existing checkpoint instead of the freshly trained model:
# glue.load("glue/pretrain/checkpoint185.ckpt")

# Encode the cells of each modality and store the result as obsm["X_glue"].
rna.obsm["X_glue"] = glue.encode_data("rna", rna)
atac.obsm["X_glue"] = glue.encode_data("atac", atac)

# Pool both modalities, then build a cosine-metric neighbour graph and a UMAP
# on the joint X_glue embedding.
# NOTE(review): assumes both inputs carry "cell_type" and "domain" obs
# columns (set during preprocessing) — TODO confirm.
combined = ad.concat([rna, atac])
sc.pp.neighbors(combined, use_rep="X_glue", metric="cosine")
sc.tl.umap(combined)

# show=False stops scanpy from calling plt.show() itself. With an interactive
# backend that call can close the figure, and savefig() would then write a
# blank image.
sc.pl.umap(combined, color=["cell_type", "domain"], wspace=0.65, show=False)
# scanpy puts categorical legends in the right margin, outside the axes.
# A tight bounding box keeps them in the saved file. Close the figure
# afterwards to release its memory.
plt.savefig("cluster.png", bbox_inches="tight")
plt.close()
# Encode the vertices of the HVF guidance graph and label the rows with the
# model's vertex names.
graph = guidance_hvf
feature_embeddings = pd.DataFrame(
    glue.encode_graph(graph),
    index=glue.vertices,
)
print(feature_embeddings.iloc[:5, :5])

# Attach a per-feature embedding to each modality. reindex() fills NaN rows
# for variables that are not in the embedding index.
for adata in (rna, atac):
    adata.varm["X_glue"] = feature_embeddings.reindex(adata.var_names).to_numpy()

# Write the annotated datasets and the guidance graph used for training.
for adata, out_path in ((rna, "rna-emb.h5ad"), (atac, "atac-emb.h5ad")):
    adata.write_h5ad(out_path)
nx.write_graphml(guidance_hvf, "guidance-hvf.graphml.gz")
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化