geneGATer.tl.learn_model

geneGATer.tl.learn_model(adata, gene_list, model_type, loss, epochs=1000, lr=0.05, weight_decay=0.0005, n_rings=1, tissue=None, seed=1337, heads=1, norm=False, top=10, library_key=None, gene_ids='gene_ids', cluster_key='cluster', data_name='Adata Dataset', project='my_model')

Train a model with the given parameters to rank the genes in the input list by importance.

Parameters
  • adata – The AnnData object.

  • gene_list – List of genes to rank, e.g. from getComGenes.

  • model_type – Type of model to use. (GAT, GAT_linear, GAT_linear_negbin, GAT_negbin)

  • loss – Loss function to use. (negbin, mse, poisson)

  • epochs (default: 1000) – Number of epochs to train the model.

  • lr (default: 0.05) – Learning rate for the optimizer.

  • weight_decay (default: 0.0005) – Weight decay for the optimizer.

  • n_rings (default: 1) – Number of rings to use for the spatial graph.

  • tissue (default: None) – Tissue to use for the spatial graph. (None if the data come from a single donor; with multiple donors, 1 is the first donor, 2 the second donor, etc., and 0 is all donors.)

  • seed (default: 1337) – Seed for the random number generator.

  • heads (default: 1) – Number of heads to use for the GAT model. (currently not supported)

  • norm (default: False) – Whether to normalize the data.

  • top (default: 10) – Number of top-ranked genes (k), as extracted from the model, to plot.

  • library_key (default: None) – Key where donor names are saved in adata.obs.

  • gene_ids (default: 'gene_ids') – Key where gene names are saved in adata.var.

  • cluster_key (default: 'cluster') – Key where clustering should be saved in adata.obs.

  • data_name (default: 'Adata Dataset') – Name of the dataset.

  • project (default: 'my_model') – Name of the project, when uploaded to wandb.

  • compare_gene_list – List of genes to compare against, e.g. top k ranked genes are marked with an asterisk if they appear in this list.

Returns

model

The trained model.

data

The data splits used for training. (soon)
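
The sketch below shows one way the documented arguments might be assembled and checked before calling learn_model. The helper build_learn_model_kwargs and the gene names are hypothetical and not part of geneGATer; the allowed values and defaults are copied from the signature and parameter list above. The call itself is left commented out because it needs geneGATer installed and an AnnData object, and the import alias and the shape of the return value are assumptions.

```python
# Allowed values, copied from the parameter descriptions above.
MODEL_TYPES = ("GAT", "GAT_linear", "GAT_linear_negbin", "GAT_negbin")
LOSSES = ("negbin", "mse", "poisson")

# Keyword defaults, copied from the documented signature.
DEFAULTS = {
    "epochs": 1000, "lr": 0.05, "weight_decay": 0.0005, "n_rings": 1,
    "tissue": None, "seed": 1337, "heads": 1, "norm": False, "top": 10,
    "library_key": None, "gene_ids": "gene_ids", "cluster_key": "cluster",
    "data_name": "Adata Dataset", "project": "my_model",
}


def build_learn_model_kwargs(gene_list, model_type, loss, **overrides):
    """Hypothetical helper: validate arguments and merge them with the defaults."""
    if not gene_list:
        raise ValueError("gene_list must contain at least one gene")
    if model_type not in MODEL_TYPES:
        raise ValueError(f"model_type must be one of {MODEL_TYPES}")
    if loss not in LOSSES:
        raise ValueError(f"loss must be one of {LOSSES}")
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        raise ValueError(f"unknown keyword(s): {sorted(unknown)}")
    kwargs = {"gene_list": list(gene_list), "model_type": model_type, "loss": loss}
    kwargs.update(DEFAULTS)    # start from the documented defaults
    kwargs.update(overrides)   # then apply the caller's changes
    return kwargs


# Placeholder gene names; in practice the list might come from getComGenes.
kwargs = build_learn_model_kwargs(["GeneA", "GeneB"], "GAT", "mse", epochs=200, top=5)

# Assumed usage, requires geneGATer and an AnnData object named `adata`:
# import geneGATer as gg
# result = gg.tl.learn_model(adata, **kwargs)
```

Validating model_type and loss up front surfaces a typo immediately rather than partway through a long training run.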