-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathscript.py
More file actions
56 lines (48 loc) · 1.59 KB
/
script.py
File metadata and controls
56 lines (48 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import anndata as ad
from fadvi import FADVI
## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
  'input': 'resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad',
  'output': 'output.h5ad',
  'n_hvg': 2000,
  'n_latent_l': 30,
  'n_latent_b': 30,
  'n_layers': 2,
  'max_epochs': 30
}
meta = {
  'name': 'fadvi'
}
## VIASH END

# Pipeline: read the input AnnData, optionally subset to the top-scoring genes,
# train a FADVI model, and write its latent representation as `obsm["X_emb"]`.

print('Reading input files', flush=True)
adata = ad.read_h5ad(par['input'])

# A falsy n_hvg (0 / None) keeps all genes.
if par["n_hvg"]:
    print(f"Select top {par['n_hvg']} high variable genes", flush=True)
    # Rank genes by descending `hvg_score`. Negating the scores before an
    # ascending argsort (instead of reversing an ascending argsort) keeps NaN
    # scores at the END of the ranking: NumPy sorts NaN last, so reversing
    # would put unscored genes first. Casting to float avoids wrap-around when
    # negating an unsigned integer column; the stable sort makes tie-breaking
    # deterministic (lower column index wins).
    hvg_scores = adata.var["hvg_score"].to_numpy(dtype=float)
    idx = (-hvg_scores).argsort(kind="stable")[:par["n_hvg"]]
    adata = adata[:, idx].copy()

print('Preprocess data', flush=True)
# Register obs "batch" as the batch key and obs "cell_type" as the labels key
# (cells labelled 'Unknown' are treated as unlabeled); the model reads its
# input matrix from layers["counts"].
FADVI.setup_anndata(
    adata,
    batch_key="batch",
    labels_key="cell_type",
    unlabeled_category='Unknown',
    layer='counts',
)
model = FADVI(
    adata,
    n_latent_l=par["n_latent_l"],
    n_latent_b=par["n_latent_b"],
    n_layers=par["n_layers"],
)

print('Train model', flush=True)
model.train(max_epochs=par["max_epochs"])

print('Generate predictions', flush=True)
# NOTE(review): FADVI is configured with two latent sizes (n_latent_l and
# n_latent_b); confirm that get_latent_representation() returns the
# batch-corrected space intended for the integration embedding.
embedding = model.get_latent_representation()

print("Write output AnnData to file", flush=True)
# The output carries only the cell/gene indices (no columns), the embedding,
# and identifiers copied from the input plus this method's name.
output = ad.AnnData(
    obs=adata.obs[[]],
    var=adata.var[[]],
    obsm={
        "X_emb": embedding,
    },
    uns={
        "dataset_id": adata.uns["dataset_id"],
        "normalization_id": adata.uns["normalization_id"],
        "method_id": meta["name"],
    },
)
output.write_h5ad(par['output'], compression='gzip')