geneGATer.tl.learn_model
- geneGATer.tl.learn_model(adata, gene_list, model_type, loss, epochs=1000, lr=0.05, weight_decay=0.0005, n_rings=1, tissue=None, seed=1337, heads=1, norm=False, top=10, library_key=None, gene_ids='gene_ids', cluster_key='cluster', data_name='Adata Dataset', project='my_model')
Train a model with the given parameters to rank the input gene list by importance.
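How geneGATer derives per-gene importance scores from the trained model is not described on this page. Assuming such scores are available as a mapping from gene name to score, the ranking and marking described for the top and compare_gene_list parameters could look roughly like this. rank_genes is a hypothetical helper, not part of the geneGATer API, and the scores in the usage example are made up.

```python
def rank_genes(importance, top=10, compare_gene_list=None):
    """Hypothetical helper: return the `top` genes by descending importance.

    Genes that also appear in `compare_gene_list` get a trailing asterisk,
    mirroring the marking described for learn_model's top-k output.
    """
    compare = set(compare_gene_list or [])
    # Sort by score (highest first), breaking ties by gene name for determinism.
    ranked = sorted(importance.items(), key=lambda kv: (-kv[1], kv[0]))
    labels = []
    for gene, _score in ranked[:top]:
        labels.append(f"{gene}*" if gene in compare else gene)
    return labels


# Made-up scores, for illustration only.
scores = {"CD3E": 0.9, "MS4A1": 0.4, "LYZ": 0.7, "NKG7": 0.2}
print(rank_genes(scores, top=3, compare_gene_list=["LYZ"]))  # ['CD3E', 'LYZ*', 'MS4A1']
```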
- Parameters
adata – The AnnData object.
gene_list – List of genes to rank, e.g. from getComGenes.
model_type – Type of model to use: GAT, GAT_linear, GAT_linear_negbin, or GAT_negbin.
loss – Loss function to use: negbin, mse, or poisson.
epochs (default: 1000) – Number of epochs to train the model.
lr (default: 0.05) – Learning rate for the optimizer.
weight_decay (default: 0.0005) – Weight decay for the optimizer.
n_rings (default: 1) – Number of neighbor rings used to build the spatial graph.
tissue (default: None) – Donor (tissue) to use for the spatial graph. Use None if the data come from a single donor; with multiple donors, 1 selects the first donor, 2 the second, and so on, and 0 selects all donors.
seed (default: 1337) – Seed for the random number generator.
heads (default: 1) – Number of attention heads for the GAT model (currently not supported).
norm (default: False) – Whether to normalize the data.
top (default: 10) – Number of top-ranked (top k) genes extracted from the model to plot.
library_key (default: None) – Key in adata.obs where donor names are stored.
gene_ids (default: 'gene_ids') – Key in adata.var where gene names are stored.
cluster_key (default: 'cluster') – Key in adata.obs under which the clustering is saved.
data_name (default: 'Adata Dataset') – Name of the dataset.
project (default: 'my_model') – Name of the project when uploading to wandb (Weights & Biases).
compare_gene_list – List of genes to compare against; top k ranked genes that appear in this list are marked with an asterisk.
- Returns
- model
The trained model.
- data
The data splits used for training (coming soon).
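The n_rings parameter controls how far the spatial graph reaches. Assuming a hexagonal spot layout such as 10x Visium, with spots given in axial (q, r) hex coordinates, one plausible reading is that two spots are connected when they are at most n_rings hexagonal steps apart. This may resemble the n_rings option of Squidpy's spatial_neighbors for grid data, but that is not confirmed here; hex_distance and ring_edges are hypothetical helpers, not geneGATer functions.

```python
def hex_distance(a, b):
    """Number of hexagonal steps between two spots in axial (q, r) coordinates."""
    dq, dr = a[0] - b[0], a[1] - b[1]
    return (abs(dq) + abs(dr) + abs(dq + dr)) // 2


def ring_edges(spots, n_rings=1):
    """Return undirected edges (i, j), i < j, between spots at most n_rings steps apart.

    spots: list of (q, r) axial coordinates of the measured spots.
    Brute-force O(n^2) comparison, fine for a sketch but not for large slides.
    """
    spots = list(spots)
    edges = set()
    for i, a in enumerate(spots):
        for j in range(i + 1, len(spots)):
            if hex_distance(a, spots[j]) <= n_rings:
                edges.add((i, j))
    return edges


# A centre spot plus its six immediate neighbours.
hexagon = [(0, 0), (1, 0), (1, -1), (0, -1), (-1, 0), (-1, 1), (0, 1)]
print(len(ring_edges(hexagon, n_rings=1)))  # 12: six spokes plus six outer-ring edges
```

Converting Visium array row/column indices to axial coordinates is omitted. Under this definition, raising n_rings from 1 to 2 grows an interior spot's neighborhood from 6 to 18 spots.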