scPRINT use case on BPH

In this use case, also presented in Figure 5 of our manuscript, we perform an extensive analysis of a multi-study dataset of benign prostatic hyperplasia (BPH).

Our biological question is whether there exist pre-cancerous cells that exhibit behaviors of mature cancer cells at this early stage of the disease.

In those cells, we want to know which genes might be implicated in cell state changes, and explore potentially novel targets in the treatment of prostate cancer and BPH.

We will start with a fresh dataset from the CELLxGENE database that represents 2 studies of BPH.

We will first explore this dataset to understand:

  • which cell types are present in the data
  • how the cells are distributed
  • which sequencers were used, etc.

We also want to confirm existing targets in prostate cancer through the analysis of precancerous lesions, and find potentially novel ones that could serve as less invasive BPH treatments than the current ones.

Finally, we want to know how these targets interact and how they are involved in biological pathways.

We now showcase how to use scPRINT across its different functionalities to answer some of these questions.

table of contents:

  1. Downloading and preprocessing
  2. Embedding and annotations
  3. Annotation cleanup
  4. Clustering and differential expression
  5. Denoising and differential expression
  6. Gene network inference

In the notebook cancer_usecase_part2.ipynb you will see how to analyse cell type specific gene regulatory networks.

These notebooks can be found and run from docs/notebooks.

🎴 YOU NEED TO FIRST HAVE A POPULATED LAMINDB INSTANCE (see README) 🎴

In [1]:
! lamin connect anonymous/testdb3
! calling anonymously, will miss private instances
→ connected lamindb: anonymous/testdb3
• to map a local dev directory, call: lamin settings set dev-dir .
In [1]:
# ! lamin connect [MY_INSTANCE]

from huggingface_hub import hf_hub_download

from scprint import scPrint
from scdataloader import Preprocessor, utils
from scprint.tasks import GNInfer, Embedder, Denoiser, withknn

from scdataloader.utils import load_genes

from bengrn import BenGRN, get_sroy_gt, compute_genie3
from bengrn.base import train_classifier

from grnndata import utils as grnutils
from grnndata import read_h5ad

from anndata.utils import make_index_unique
from anndata import concat
import scanpy as sc

from matplotlib import pyplot as plt
import numpy as np
import lamindb as ln

%load_ext autoreload
%autoreload 2 

import torch

torch.set_float32_matmul_precision("medium")
→ connected lamindb: anonymous/testdb3

Downloading and preprocessing

We now use lamindb to easily access CELLxGENE and download a dataset of normal and benign prostatic hyperplasia tissues.

The data is available at https://cellxgene.cziscience.com/e/574e9f9e-f8b4-41ef-bf19-89a9964fd9c7.cxg/

We then use scDataloader's preprocessing method. This method is quite extensive and does several things; find out more about it in its documentation.

On our end, we use the preprocessor to make sure that the gene expression values are raw counts and that we have enough information to use scPRINT (i.e., enough expressed genes and enough counts per cell across the dataset).
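As a rough illustration of these two checks, here is a minimal NumPy sketch on a dense cells × genes matrix. It is not scDataloader's implementation: the helper names and the min_genes / min_counts thresholds are hypothetical and chosen only for the example.

```python
import numpy as np


def looks_like_raw_counts(x: np.ndarray) -> bool:
    """True if every entry is a non-negative integer, as expected for raw UMI counts."""
    return bool(np.all(x >= 0) and np.all(np.mod(x, 1) == 0))


def cells_with_enough_signal(x: np.ndarray, min_genes: int, min_counts: int) -> np.ndarray:
    """Boolean mask over cells (rows) with at least `min_genes` expressed genes
    and at least `min_counts` total counts (illustrative thresholds)."""
    n_expressed = (x > 0).sum(axis=1)
    total_counts = x.sum(axis=1)
    return (n_expressed >= min_genes) & (total_counts >= min_counts)


# toy matrix: 3 cells x 4 genes
counts = np.array([[0, 3, 1, 0],
                   [2, 0, 0, 0],
                   [1, 1, 4, 2]], dtype=float)

print(looks_like_raw_counts(counts))            # True
print(looks_like_raw_counts(np.log1p(counts)))  # False: log-transformed values
print(cells_with_enough_signal(counts, min_genes=2, min_counts=4).tolist())  # [True, False, True]
```

Log-normalized data fails the first check, which is why the preprocessor needs the original counts rather than a transformed matrix.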

Finally, the preprocessor also expands the expression matrix to a fixed set of genes defined by the latest version of Ensembl.
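The gene expansion step can be pictured as reindexing the matrix columns onto a reference gene list, as in the minimal sketch below. The function align_to_gene_set and the gene names are hypothetical; scDataloader works on AnnData objects with Ensembl IDs, and its exact behavior may differ.

```python
import numpy as np


def align_to_gene_set(x: np.ndarray, genes: list, reference_genes: list) -> np.ndarray:
    """Reindex the columns of a cells x genes matrix onto `reference_genes`.

    Reference genes missing from the dataset become all-zero columns;
    dataset genes missing from the reference are dropped.
    """
    position = {gene: i for i, gene in enumerate(genes)}
    out = np.zeros((x.shape[0], len(reference_genes)), dtype=x.dtype)
    for j, gene in enumerate(reference_genes):
        if gene in position:
            out[:, j] = x[:, position[gene]]
    return out


x = np.array([[1, 2, 9],
              [3, 4, 9]])
genes = ["gene_1", "gene_2", "gene_x"]      # gene_x is unknown to the reference
reference = ["gene_2", "gene_0", "gene_1"]  # gene_0 is absent from the dataset

print(align_to_gene_set(x, genes, reference).tolist())  # [[2, 0, 1], [4, 0, 3]]
```

This mirrors the "Removed ... genes not known to the ontology" and "Added ... genes in the ontology but not present in the dataset" messages in the preprocessing output. On a real matrix with tens of thousands of genes, a sparse-aware version would be preferable to this dense copy.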

In [2]:
cx_dataset = (
    ln.DB(instance="laminlabs/cellxgene")
    .Collection.filter(version="2023-12-15", key="cellxgene-census")
    .one()
)
→ the database (2.2a3) is ahead of your installed lamindb package (2.1.1)
→ consider updating lamindb: pip install lamindb>=2.2
In [3]:
prostate_adata = (
    cx_dataset.artifacts.filter(
        key="cell-census/2023-12-15/h5ads/574e9f9e-f8b4-41ef-bf19-89a9964fd9c7.h5ad"
    )
    .one()
    .load()
)
sc.pl.umap(prostate_adata)
! run input wasn't tracked, call `ln.track()` and re-run
[UMAP plot of prostate_adata]
In [ ]:
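# uncomment and run to populate your lamindb instance with the ontologies used by scDataloader (see README)
#from scdataloader.utils import populate_my_ontology
#populate_my_ontology()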
... synchronizing df_all__cl__2025-12-17__CellType.parquet: 100.0%
... synchronizing df_vertebrates__ensembl__release-114__Organism.parquet: 100.0%
! ambiguous validation in Bionty for 1 record: 'sheep'
! ambiguous validation in Bionty for 1 record: 'sheep'
! ontology ID BFO:0000020 not found in DataFrame
... synchronizing df_human__hancestro__2025-10-14__Ethnicity.parquet: 100.0%
... synchronizing df_all__efo__3.85.0__ExperimentalFactor.parquet: 100.0%
→ starting creation of 18326 ExperimentalFactor records in batches of 10000
→ starting creation of 16430 ExperimentalFactor_parents records in batches of 10000
! you are trying to create a record with name='unknown' but a record with similar name exists: 'fever of unknown origin'. Did you mean to load it?
... synchronizing df_all__uberon__2025-12-04__Tissue.parquet: 100.0%
→ starting creation of 15770 Tissue records in batches of 10000
→ starting creation of 42140 Tissue_parents records in batches of 10000
... synchronizing df_human__hsapdv__2025-01-23__DevelopmentalStage.parquet: 100.0%
... synchronizing df_mouse__mmusdv__2025-01-23__DevelopmentalStage.parquet: 100.0%
... synchronizing df_all__mondo__2026-01-06__Disease.parquet: 100.0%
→ starting creation of 30371 Disease records in batches of 10000
→ starting creation of 39555 Disease_parents records in batches of 10000
! you are trying to create a record with name='normal' but records with similar names exist: 'normocytic anemia', 'normal breast-like subtype of breast carcinoma', 'normokalemic periodic paralysis'. Did you mean to load one of them?
! you are trying to create a record with name='unknown' but records with similar names exist: 'unknown leukodystrophy', 'hemoglobin--variants for which the chain carrying the mutation 1S unknown or uncertain', 'sporadic adult-onset ataxia of unknown etiology'. Did you mean to load one of them?
... synchronizing df_mouse__ensembl__release-114__Gene.parquet: 100.0%
... synchronizing df_human__ensembl__release-114__Gene.parquet: 100.0%
In [4]:
# preprocessing using scDataloader
prostate_adata.obs.drop(columns="is_primary_data", inplace=True)
preprocessor = Preprocessor(do_postp=False)
prostate_adata = preprocessor(prostate_adata)
Dropping layers:  KeysView(Layers with keys: )
checking raw counts
removed 0 non primary cells, 83451 renamining
filtered out 0 cells, 83451 renamining
Removed 586 genes not known to the ontology
Removed 0 duplicate genes
Added 58936 genes in the ontology but not present in the dataset
starting QC
Seeing 13064 outliers (15.65% of total dataset):
done
AnnData object with n_obs × n_vars = 83451 × 85720
    obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier'
    var: 'feature_biotype', 'feature_is_filtered', 'feature_name', 'feature_reference', 'uid', 'symbol', 'biotype', 'is_locked', 'branch_id', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'unseen_genes'

loading the model:

We now load the medium version of scPRINT (v1.5) from a specific checkpoint, downloaded here from the Hugging Face Hub with hf_hub_download. See the README for more information on the available checkpoints.

Note that the other notebooks in this repo may not load the model with the same care. Refer to the part below for a good understanding of the different parameters involved in loading pretrained model weights.
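The device-dependent choices made in the loading cell, and the gene check at its end, boil down to the simple logic sketched below. This is a plain-Python summary with hypothetical helper names (runtime_settings, genes_to_remove) and made-up gene IDs, not part of the scPRINT API; dtypes are written as strings so the sketch runs without PyTorch.

```python
def runtime_settings(cuda_available: bool) -> dict:
    """Summarize the device-dependent choices of the loading cell."""
    return {
        # flash attention needs a GPU; fall back to regular attention otherwise
        "transformer": "flash" if cuda_available else "normal",
        # half precision on GPU, single precision on CPU
        "dtype": "float16" if cuda_available else "float32",
        "device": "cuda" if cuda_available else "cpu",
    }


def genes_to_remove(model_genes, ontology_genes) -> set:
    """Genes known to the model but absent from the ontology must be removed;
    ontology genes unknown to the model are harmless."""
    return set(model_genes) - set(ontology_genes)


print(runtime_settings(cuda_available=False))
# {'transformer': 'normal', 'dtype': 'float32', 'device': 'cpu'}
print(genes_to_remove(["ENSG_A", "ENSG_B"], ["ENSG_B", "ENSG_C"]))  # {'ENSG_A'}
```

The set difference is one-directional on purpose: only genes the model expects but the ontology lacks cause problems, which is why the loading cell removes exactly those.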

In [5]:
model_checkpoint_file = hf_hub_download(
    repo_id="jkobject/scPRINT", filename="medium-v1.5.ckpt"
)
# alternatively, point model_checkpoint_file to a local .ckpt file
model_checkpoint_file
Out[5]:
'/pasteur/appa/homes/jkalfon/.cache/huggingface/hub/models--jkobject--scPRINT/snapshots/d661d779e3614c77305e2fbe0a19312eda2794b8/medium-v1.5.ckpt'
In [6]:
# make sure that you check if you have a GPU with flashattention or not (see README)
try:
    m = torch.load(model_checkpoint_file)
# if that fails, map the weights to the CPU instead, since they are saved on GPU devices by default
except RuntimeError:
    m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))

# again here by default the model was trained with flash attention, so if you do not have a GPU you will need to replace the attention mechanism with regular attention
transformer = "flash" if torch.cuda.is_available() else "normal"

# the two checks below handle compatibility with different versions of the pretrained checkpoints
if "prenorm" in m["hyper_parameters"]:
    m["hyper_parameters"].pop("prenorm")
    torch.save(m, model_checkpoint_file)
if "label_counts" in m["hyper_parameters"]:
    # you need to set precpt_gene_emb=None otherwise the model will look for its precomputed gene embeddings files although they were already converted into model weights, so you don't need this file for a pretrained model
    model = scPrint.load_from_checkpoint(
        model_checkpoint_file,
        precpt_gene_emb=None,
        classes=m["hyper_parameters"]["label_counts"],
        transformer=transformer,
    )
else:
    model = scPrint.load_from_checkpoint(
        model_checkpoint_file, precpt_gene_emb=None, transformer=transformer
    )

# this can happen if the model was trained on a different set of genes than the one in your ontology (e.g. with newer ontologies).
# Having genes in the ontology that are not in the model is fine, but the opposite is not,
# so we remove the genes that are in the model but not in the ontology
missing = set(model.genes) - set(load_genes(model.organisms).index)
if len(missing) > 0:
    print(
        "Warning: some genes missmatch exist between model and ontology: solving...",
    )
    model._rm_genes(missing)

# again, if not on GPU, you need to convert the model to float32
if not torch.cuda.is_available():
    model = model.to(torch.float32)

# you can run inference in float16 if you have a GPU, otherwise use float32
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# the models are often loaded with some parts still displayed as "cuda" and some as "cpu", so we need to make sure that the model is fully on the right device
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
RuntimeError caught: scPrint is not attached to a `Trainer`.
Warning: some genes missmatch exist between model and ontology: solving...

Embedding and annotations

We will then use our Embedder class to embed the data and annotate the cells. These task classes are how we parametrize and access the different functions of scPRINT. Find out more about their parameters in our documentation.
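To give an intuition for the how="random expr" and max_len=4000 settings, the sketch below picks a fixed number of genes for one cell by sampling among its expressed genes and topping up with unexpressed ones when needed. This is an assumption about what "random expr" does, written for illustration with a hypothetical helper; check the scPRINT and scDataloader documentation for the actual behavior.

```python
import numpy as np


def sample_gene_indices(expr: np.ndarray, max_len: int, rng: np.random.Generator) -> np.ndarray:
    """Choose up to `max_len` gene indices for one cell (illustrative only).

    Expressed genes are sampled at random first; if there are fewer than
    `max_len`, randomly chosen unexpressed genes fill the remaining slots.
    """
    expressed = np.flatnonzero(expr > 0)
    unexpressed = np.flatnonzero(expr == 0)
    if len(expressed) >= max_len:
        return rng.choice(expressed, size=max_len, replace=False)
    n_fill = min(max_len - len(expressed), len(unexpressed))
    filler = rng.choice(unexpressed, size=n_fill, replace=False)
    return np.concatenate([rng.permutation(expressed), filler])


rng = np.random.default_rng(0)
cell = np.array([0, 5, 0, 2, 0, 0, 1, 0])  # 3 expressed genes out of 8

print(sorted(sample_gene_indices(cell, max_len=2, rng=rng).tolist()))  # 2 of the indices 1, 3, 6
print(sorted(sample_gene_indices(cell, max_len=5, rng=rng).tolist()))  # 1, 3, 6 plus 2 unexpressed genes
```

Capping the context at max_len keeps the attention cost bounded per cell, while random sampling means each pass over the data can see a different subset of a cell's expressed genes.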

In [7]:
embedder = Embedder(
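    # how genes are selected for each cell (e.g. random expressed genes, most variable genes, ...)
    how="random expr",
    # maximum number of genes given to the model per cell
    max_len=4000,
    # the model was trained on minibatches of 64 cells, but any batch size that fits in memory works
    batch_size=32,
    # number of workers for data loading
    num_workers=8,
    # we will only use the cell type embedding here
    # (other labels such as "disease_ontology_term_id" could be added to this list)
    pred_embedding=["cell_type_ontology_term_id"],
    # maximum number of cells kept in memory before a partial result is saved
    save_every=40_000,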
    dtype=dtype,
)
In [9]:
# create the embedding
prostate_adata, metrics = embedder(model, prostate_adata, cache=False)
  0%|          | 0/2608 [00:00<?, ?it/s]/local/scratch/tmp/tmptntrr2uq/main.c:4:10: fatal error: Python.h: No such file or directory
 #include <Python.h>
          ^~~~~~~~~~
compilation terminated.
  0%|          | 0/2608 [00:03<?, ?it/s]
---------------------------------------------------------------------------
CalledProcessError                        Traceback (most recent call last)
Cell In[9], line 2
      1 # create the embedding
----> 2 prostate_adata, metrics = embedder(model, prostate_adata, cache=False)

(...) intermediate frames in scprint, torch, simpler_flash and triton omitted

CalledProcessError: Command '['/usr/bin/gcc', '/local/scratch/tmp/tmptntrr2uq/main.c', '-O3', '-I/pasteur/appa/homes/jkalfon/scPRINT/scprint1/lib64/python3.12/site-packages/triton/common/../third_party/cuda/include', '-I/usr/include/python3.12', '-I/local/scratch/tmp/tmptntrr2uq', '-shared', '-fPIC', '-lcuda', '-o', '/local/scratch/tmp/tmptntrr2uq/_fwd_kernel.cpython-312-x86_64-linux-gnu.so', '-L/lib64', '-L/lib64']' returned non-zero exit status 1.

In this run, Triton could not compile the flash-attention kernel because the Python development headers (Python.h) were missing on the machine. If you hit the same error, install the development headers for your Python version, or load the model with transformer="normal" (as done above when no GPU is available) to use regular attention instead of the flash-attention kernel.
In [8]:
prostate_adata
Out[8]:
AnnData object with n_obs × n_vars = 83451 × 70704
    obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'synonyms', 'description', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'unseen_genes'

Annotation cleanup

scPRINT generates predictions over hundreds of possible labels for each cell.

It is often advisable to clean up the predictions, e.g. by removing low-frequency labels and mislabelings.

Here, we use the most straightforward approach: any label predicted for too few cells is replaced by "other" (the code cells below keep only cell types predicted more than 400 times and diseases predicted more than 1,000 times).

A better approach would be majority voting over cell clusters, as it aggregates and smooths the predictions over multiple cells. It would also remove most of the low-frequency mistakes in the predictions.
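Majority voting over clusters can be sketched with pandas. This is a minimal illustration, not an scPRINT function: the helper name majority_vote is hypothetical, and ties inside a cluster are broken arbitrarily by the order returned by value_counts.

```python
import pandas as pd


def majority_vote(obs: pd.DataFrame, cluster_key: str, label_key: str) -> pd.Series:
    """Give every cell the most frequent label of its cluster."""
    # most frequent label within each cluster (observed=True skips empty categories)
    top_label = obs.groupby(cluster_key, observed=True)[label_key].agg(
        lambda labels: labels.value_counts().index[0]
    )
    # broadcast each cluster's winning label back to its cells (index is preserved)
    return obs[cluster_key].map(top_label).astype(str)


# toy example: cluster "0" contains one minority "T" label
toy = pd.DataFrame(
    {"cluster": ["0", "0", "0", "1", "1"], "label": ["B", "B", "T", "T", "T"]}
)
print(majority_vote(toy, "cluster", "label").tolist())  # ['B', 'B', 'B', 'T', 'T']
```

Once the Louvain clusters have been computed, it could be called as majority_vote(prostate_adata.obs, "louvain_1.0", "conv_pred_cell_type_ontology_term_id").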

We will also have a look at the embeddings of scPRINT by plotting its UMAP visualization.

In [ ]:
# cleaning up the cell types
prostate_adata.obs["cleaned_pred_cell_type_ontology_term_id"] = prostate_adata.obs[
    "conv_pred_cell_type_ontology_term_id"
].astype(str)
prostate_adata.obs.loc[
    ~prostate_adata.obs["conv_pred_cell_type_ontology_term_id"].isin(
        [
            k
            for k, v in prostate_adata.obs["conv_pred_cell_type_ontology_term_id"]
            .value_counts()
            .items()
            if v > 400
        ]
    ),
    "cleaned_pred_cell_type_ontology_term_id",
] = "other"
prostate_adata.obs["cleaned_pred_cell_type_ontology_term_id"].value_counts().plot.pie()
Out[ ]:
<Axes: ylabel='count'>
[Figure: pie chart of the cleaned predicted cell-type labels]
In [ ]:
# cleaning up the diseases
prostate_adata.obs["cleaned_pred_disease_ontology_term_id"] = prostate_adata.obs[
    "conv_pred_disease_ontology_term_id"
].astype(str)
prostate_adata.obs.loc[
    ~prostate_adata.obs["conv_pred_disease_ontology_term_id"].isin(
        [
            k
            for k, v in prostate_adata.obs["conv_pred_disease_ontology_term_id"]
            .value_counts()
            .items()
            if v > 1000
        ]
    ),
    "cleaned_pred_disease_ontology_term_id",
] = "other"
prostate_adata.obs["cleaned_pred_disease_ontology_term_id"].value_counts().plot.pie()
Out[ ]:
<Axes: ylabel='count'>
[Figure: pie chart of the cleaned predicted disease labels]
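The two cleanup cells repeat the same logic with different thresholds. A hypothetical helper, collapse_rare_labels, sketches how it could be factored out with pandas; it keeps a label only if it is predicted strictly more than min_count times, matching the `if v > 400` and `if v > 1000` filters.

```python
import pandas as pd


def collapse_rare_labels(labels: pd.Series, min_count: int, other: str = "other") -> pd.Series:
    """Replace every label predicted min_count times or fewer by `other`."""
    labels = labels.astype(str)
    counts = labels.value_counts()
    # keep labels strictly more frequent than the threshold
    frequent = counts[counts > min_count].index
    return labels.where(labels.isin(frequent), other)


toy = pd.Series(["a", "a", "a", "b", "b", "c"])
print(collapse_rare_labels(toy, 2).tolist())  # ['a', 'a', 'a', 'other', 'other', 'other']
```

For instance, prostate_adata.obs["cleaned_pred_disease_ontology_term_id"] = collapse_rare_labels(prostate_adata.obs["conv_pred_disease_ontology_term_id"], 1000) would reproduce the disease cleanup, since the returned Series keeps the original index.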
In [13]:
prostate_adata
Out[13]:
AnnData object with n_obs × n_vars = 83451 × 70704
    obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier', 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id', 'sprint_leiden', 'cleaned_pred_cell_type_ontology_term_id', 'cleaned_pred_disease_ontology_term_id'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'synonyms', 'description', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'unseen_genes'
    obsm: 'X_scprint_umap', 'scprint_emb'
In [14]:
# run Louvain clustering at several resolutions on the kNN graph of scPRINT embeddings
sc.pp.neighbors(prostate_adata, n_neighbors=10, use_rep="scprint_emb")
# using multiple resolutions can help spot smaller clusters
sc.tl.louvain(prostate_adata, resolution=0.5, key_added="louvain_0.5")
sc.tl.louvain(prostate_adata, resolution=1.0, key_added="louvain_1.0")
# compute a 2D UMAP from the same kNN graph
sc.tl.umap(prostate_adata, n_components=2, random_state=42)
In [16]:
prostate_adata
Out[16]:
AnnData object with n_obs × n_vars = 83451 × 70704
    obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier', 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id', 'sprint_leiden', 'cleaned_pred_cell_type_ontology_term_id', 'cleaned_pred_disease_ontology_term_id', 'louvain_0.5', 'louvain_1.0'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'synonyms', 'description', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'unseen_genes', 'neighbors', 'louvain_0.5', 'louvain_1.0', 'umap', 'cleaned_pred_cell_type_ontology_term_id_colors', 'cleaned_pred_disease_ontology_term_id_colors', 'assay_colors', 'disease_colors', 'development_stage_colors'
    obsm: 'X_scprint_umap', 'scprint_emb', 'X_umap'
    obsp: 'distances', 'connectivities'
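Clustering at two resolutions helps spot smaller clusters. One way to see how each louvain_0.5 cluster splits at louvain_1.0 is a contingency table. Below is a minimal pandas sketch. The helper resolution_splits is hypothetical, and the toy labels stand in for the two obs columns.

```python
import pandas as pd


def resolution_splits(coarse: pd.Series, fine: pd.Series) -> dict:
    """Map each coarse cluster to the fine clusters that share cells with it."""
    table = pd.crosstab(coarse, fine)  # rows: coarse clusters, columns: fine clusters
    return {
        str(c): sorted(str(f) for f in table.columns[(table.loc[c] > 0).to_numpy()])
        for c in table.index
    }


# toy labels: coarse cluster "0" splits into fine clusters "0" and "2"
coarse = pd.Series(["0", "0", "0", "1", "1"])
fine = pd.Series(["0", "2", "2", "1", "1"])
print(resolution_splits(coarse, fine))  # {'0': ['0', '2'], '1': ['1']}
```

On the real data this would be resolution_splits(prostate_adata.obs["louvain_0.5"], prostate_adata.obs["louvain_1.0"]). It is also a quick way to check which cluster IDs refer to the same cells across the two resolutions.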
In [ ]:
sc.pl.embedding(
    prostate_adata, basis="X_umap", color=["cleaned_pred_cell_type_ontology_term_id"]
)
sc.pl.embedding(
    prostate_adata, basis="X_umap", color=["cleaned_pred_disease_ontology_term_id"]
)
sc.pl.embedding(prostate_adata, basis="X_umap", color=["assay"])
sc.pl.embedding(prostate_adata, basis="X_umap", color=["disease"])
sc.pl.embedding(prostate_adata, basis="X_umap", color=["development_stage"])
[Figures: UMAPs colored by cleaned predicted cell type, cleaned predicted disease, assay, disease, and development stage]
In [19]:
prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id.value_counts().head(20)
Out[19]:
cleaned_pred_cell_type_ontology_term_id
basal epithelial cell of prostatic duct                             25801
urethra urothelial cell                                             11050
prostate gland microvascular endothelial cell                       10321
CD1c-positive myeloid dendritic cell                                 6427
other                                                                4567
fibroblast of connective tissue of nonglandular part of prostate     3995
lung pericyte                                                        3045
luminal cell of prostate epithelium                                  2960
IgG-negative class switched memory B cell                            2701
effector CD8-positive, alpha-beta T cell                             2499
acinar cell of salivary gland                                        1689
basophil                                                             1340
glomerular capillary endothelial cell                                1066
smooth muscle cell of prostate                                       1055
mast cell                                                             993
fibroblast of connective tissue of glandular part of prostate         858
effector memory CD8-positive, alpha-beta T cell                       775
CD4-positive, alpha-beta thymocyte                                    698
smooth muscle cell of the pulmonary artery                            611
effector CD4-positive, alpha-beta T cell                              506
Name: count, dtype: int64
In [20]:
# we save for next time
prostate_adata.write_h5ad("../../data/temp/prostate_combined_latest.h5ad")
In [6]:
prostate_adata = sc.read_h5ad("../../data/temp/prostate_combined_latest.h5ad")

Clustering and differential expression

We now look at the Louvain clusters computed on the kNN graph of the scPRINT embeddings.

Once we detect a cluster of interest, we perform differential expression analysis on it. Taking a B-cell cluster as an example, we use scanpy's implementation of rank_genes_groups for the differential expression.
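To give some intuition for what rank_genes_groups(method="t-test") computes: as we understand it, scanpy's "t-test" method is a Welch (unequal-variance) t-test applied gene by gene. The minimal NumPy sketch below computes only the per-gene statistic. It is not scanpy's code, and it leaves out p-values and multiple-testing correction.

```python
import numpy as np


def welch_t(group: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Per-gene Welch t-statistic between two cells x genes matrices."""
    n1, n2 = group.shape[0], reference.shape[0]
    mean_diff = group.mean(axis=0) - reference.mean(axis=0)
    # unbiased per-gene variances; undefined (NaN) for a one-cell group
    var1 = group.var(axis=0, ddof=1)
    var2 = reference.var(axis=0, ddof=1)
    return mean_diff / np.sqrt(var1 / n1 + var2 / n2)


# toy data: one gene, three cells per group
group = np.array([[1.0], [2.0], [3.0]])
reference = np.array([[4.0], [5.0], [6.0]])
print(welch_t(group, reference))
```

Because the variance uses ddof=1, a group with a single cell gives NaN (NumPy also emits a RuntimeWarning). This is one concrete reason very small comparison groups are a problem.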

In [ ]:
# check clusters
sc.pl.embedding(
    prostate_adata,
    basis="X_umap",
    color="louvain_0.5",
    show=False,
    legend_loc="on data",
)
Out[ ]:
<Axes: title={'center': 'louvain_0.5'}, xlabel='X_umap1', ylabel='X_umap2'>
[Figure: UMAP colored by louvain_0.5 cluster, labels on data]
In [ ]:
# check cluster 3
i = 3
loc = prostate_adata.obs["louvain_0.5"] == str(i)
(
    prostate_adata.obs[loc].conv_pred_disease_ontology_term_id.value_counts().head(2),
    prostate_adata.obs[loc].conv_pred_cell_type_ontology_term_id.value_counts().head(),
)
Out[ ]:
(conv_pred_disease_ontology_term_id
 normal                          3359
 benign prostatic hyperplasia    1441
 Name: count, dtype: int64,
 conv_pred_cell_type_ontology_term_id
 basal epithelial cell of prostatic duct    2785
 urethra urothelial cell                    1487
 luminal cell of prostate epithelium         260
 acinar cell of salivary gland               198
 nasal mucosa goblet cell                     26
 Name: count, dtype: int64)
In [ ]:
# check cluster 26
i = 26
loc = prostate_adata.obs["louvain_0.5"] == str(i)
(
    prostate_adata.obs[loc]
    .cleaned_pred_disease_ontology_term_id.value_counts()
    .head(2),
    prostate_adata.obs[loc]
    .cleaned_pred_cell_type_ontology_term_id.value_counts()
    .head(5),
)
Out[ ]:
(cleaned_pred_disease_ontology_term_id
 benign prostatic hyperplasia    856
 other                             1
 Name: count, dtype: int64,
 cleaned_pred_cell_type_ontology_term_id
 IgG-negative class switched memory B cell    777
 basal epithelial cell of prostatic duct       49
 other                                         25
 effector CD8-positive, alpha-beta T cell       3
 glomerular capillary endothelial cell          2
 Name: count, dtype: int64)
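The two "check cluster" cells follow the same pattern. A hypothetical helper, cluster_composition, wraps it. It also drops zero counts, which value_counts reports for unused categories when a column is categorical (as in the louvain_0.5 counts in the next cell).

```python
import pandas as pd


def cluster_composition(obs, cluster_key, cluster, label_key, n=5):
    """Top-n label counts among the cells of one cluster, without zero counts."""
    in_cluster = obs[cluster_key].astype(str) == str(cluster)
    counts = obs.loc[in_cluster, label_key].value_counts()
    # categorical columns report unused categories with a count of 0
    return counts[counts > 0].head(n)


toy = pd.DataFrame(
    {
        "louvain": ["3", "3", "3", "26"],
        "label": pd.Categorical(["basal", "basal", "luminal", "B cell"]),
    }
)
print(cluster_composition(toy, "louvain", 3, "label"))
```

On the notebook data, cluster_composition(prostate_adata.obs, "louvain_0.5", 26, "cleaned_pred_cell_type_ontology_term_id") would summarize the B-cell cluster.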
In [ ]:
prostate_adata.obs[
    prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id
    == "IgG-negative class switched memory B cell"
]["louvain_0.5"].value_counts()
Out[ ]:
louvain_0.5
21    1202
26     777
27     708
8        5
10       4
2        2
9        1
15       1
12       1
3        0
0        0
1        0
11       0
5        0
13       0
14       0
16       0
6        0
7        0
4        0
19       0
18       0
17       0
20       0
23       0
22       0
25       0
24       0
28       0
29       0
30       0
31       0
32       0
33       0
34       0
35       0
Name: count, dtype: int64
In [ ]:
# We found a clean IgG-negative class switched memory B cell cluster; keep only its cells with that label and use them to define a clean annotation for a plot
loc = loc & (
    prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id
    == "IgG-negative class switched memory B cell"
)
prostate_adata.obs[loc].cleaned_pred_disease_ontology_term_id.value_counts().head(2)
Out[ ]:
cleaned_pred_disease_ontology_term_id
benign prostatic hyperplasia    776
other                             1
Name: count, dtype: int64
In [ ]:
# making a "focus" annotation for the B-cell to generate a nice plot of the B-cell cluster only
prostate_adata.obs["focus"] = "other"
prostate_adata.obs.loc[loc, "focus"] = "memory B cell"
prostate_adata.obs.loc[
    loc
    & (
        prostate_adata.obs["cleaned_pred_disease_ontology_term_id"]
        == "benign prostatic hyperplasia"
    ),
    "focus",
] = "BPH associated memory B cell"
prostate_adata.obs["focus"].value_counts()
Out[ ]:
focus
other                           82674
BPH associated memory B cell      776
memory B cell                       1
Name: count, dtype: int64
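The focus counts show 776 BPH-associated cells against a single "memory B cell". Any comparison between these two groups therefore has a one-cell reference. The sketch below checks group sizes before calling rank_genes_groups. The helper small_groups is hypothetical, and its min_cells=3 default is an arbitrary illustrative threshold.

```python
import pandas as pd


def small_groups(labels: pd.Series, groups, min_cells: int = 3) -> dict:
    """Return {group: n_cells} for the requested groups with fewer than min_cells cells."""
    sizes = labels.value_counts()
    return {g: int(sizes.get(g, 0)) for g in groups if sizes.get(g, 0) < min_cells}


# toy "focus" column with a single reference cell, as in the counts above
focus = pd.Series(["other"] * 5 + ["BPH associated memory B cell"] * 4 + ["memory B cell"])
print(small_groups(focus, ["BPH associated memory B cell", "memory B cell"]))  # {'memory B cell': 1}
```

On the notebook data, small_groups(prostate_adata.obs["focus"], ["BPH associated memory B cell", "memory B cell"]) would flag the reference group.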
In [ ]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import seaborn as sns

# draw the second color of seaborn's default palette as a square swatch
color = sns.color_palette()[1]
fig, ax = plt.subplots(figsize=(2, 2))
rect = patches.Rectangle((0, 0), 1, 1, facecolor=color)
ax.add_patch(rect)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis("off")
plt.show()
[Figure: color swatch of the second seaborn palette color]
In [ ]:
# looking at the B-cell cluster. We can see some normal and BPH-associated memory B-cells
sc.pl.embedding(
    prostate_adata[
        (prostate_adata.obs["louvain_1.0"] == str(i))
        & (prostate_adata.obsm["X_scprint_umap"][:, 0] > 4)
    ],
    basis="scprint_umap",
    color="focus",
    show=False,
    size=8,
    title="Switched memory B-cell cluster",
    legend_loc="right margin",
)
... storing 'focus' as categorical
Out[ ]:
<Axes: title={'center': 'focus'}, xlabel='scprint_umap_rot1', ylabel='scprint_umap_rot2'>
[Figure: UMAP of the switched memory B-cell cluster colored by focus]
In [ ]:
# now the diff expression between B-cells and the rest
sc.tl.rank_genes_groups(
    prostate_adata,
    groupby="cleaned_pred_cell_type_ontology_term_id",
    groups=["IgG-negative class switched memory B cell"],
    reference="other",
    method="t-test",
)
# Plot the most differentially expressed genes
sc.pl.rank_genes_groups(prostate_adata, n_genes=25, sharey=False, gene_symbols="symbol")
# super strong B cell markers
... storing 'focus' as categorical
WARNING: It seems you use rank_genes_groups on the raw count data. Please logarithmize your data before calling rank_genes_groups.
[Figure: top 25 ranked genes, IgG-negative class switched memory B cell vs. other]
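Scanpy warns here that the t-test ran on raw counts. In scanpy, the usual remedy is sc.pp.normalize_total(adata, target_sum=1e4) followed by sc.pp.log1p(adata), applied to a copy so that the raw counts used as scPRINT input stay untouched. The NumPy sketch below shows the arithmetic those two calls perform on a dense matrix. The guard for empty cells is our own addition.

```python
import numpy as np


def normalize_log1p(counts: np.ndarray, target_sum: float = 1e4) -> np.ndarray:
    """Scale each cell (row) to target_sum total counts, then apply log(1 + x)."""
    totals = counts.sum(axis=1, keepdims=True).astype(float)
    totals[totals == 0] = 1.0  # leave empty cells at zero instead of dividing by zero
    return np.log1p(counts / totals * target_sum)


counts = np.array([[1.0, 3.0], [0.0, 0.0]])
print(normalize_log1p(counts))
```

The first cell has 4 counts, so its genes are scaled to 2,500 and 7,500 before the log.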

Denoising and differential expression

Our previous analysis shows that there are very few normal (i.e. healthy) B-cells in our cluster: 776 of its 777 memory B-cells are predicted as BPH-associated. If we wanted to compare BPH B-cells to normal B-cells, we would be severely underpowered.

Instead of looking for another dataset, let's use scPRINT to increase the depth of the cells' expression profiles, virtually adding more signal to our dataset.

We will use the Denoiser class (see our documentation for more about it) to denoise the expression profiles of the cells. It is used much like a Trainer in PyTorch Lightning: it is configured once, then called on a model and an AnnData object.

We will then compare the results of differential expression analysis before and after denoising.

In [ ]:
# here we compare memory B-cell in BPH to normal memory B cells before denoising
sc.tl.rank_genes_groups(
    prostate_adata,
    groupby="focus",
    groups=["BPH associated memory B cell"],
    reference="memory B cell",
    method="t-test",
)
# Plot the most differentially expressed genes
sc.pl.rank_genes_groups(prostate_adata, n_genes=25, sharey=False, gene_symbols="symbol")
WARNING: It seems you use rank_genes_groups on the raw count data. Please logarithmize your data before calling rank_genes_groups.
[Figure: top 25 ranked genes, BPH associated memory B cell vs. memory B cell]
In [ ]:
# perform denoising
denoise = Denoiser(
    # could be on top differentially expressed genes or on random expressed genes in the cell (completed by random non expressed gene)
    how="random expr",
    # the size of the minibatch (need to fit in memory)
    batch_size=20,
    num_workers=8,
    # the number of genes to use
    max_len=5000,
    # the number of cells to use (here more than what we will use so we will use everything)
    max_cells=10_000,
    doplot=False,
    # how much we want to increase the depth / counts of the cells (here, 7x)
    predict_depth_mult=7,
    dtype=dtype,
)
metrics, idx, denoised_adata = denoise(
    model, prostate_adata[prostate_adata.obs["focus"] != "other"]
)
100%|██████████| 2/2 [00:58<00:00, 29.10s/it]
AnnData object with n_obs × n_vars = 32 × 44741
    obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb'
    layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
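The denoised object has three layers: scprint_mu, scprint_theta and scprint_pi. Based only on their names, we assume here that they hold, for each cell and gene, the mean, inverse dispersion and zero-inflation probability of a zero-inflated negative binomial (ZINB). Please check the scPRINT documentation before relying on this. Under that assumption, the expected count and its variance follow from the standard ZINB moments, sketched below with NumPy.

```python
import numpy as np


def zinb_moments(mu, theta, pi):
    """Mean and variance of a zero-inflated negative binomial (assumed parameterization).

    With probability pi the count is an extra zero; otherwise it is drawn from a
    negative binomial with mean mu and variance mu + mu**2 / theta.
    """
    mu, theta, pi = (np.asarray(a, dtype=float) for a in (mu, theta, pi))
    mean = (1.0 - pi) * mu
    var = (1.0 - pi) * mu * (1.0 + mu / theta + pi * mu)
    return mean, var


mean, var = zinb_moments(mu=2.0, theta=1.0, pi=0.5)
print(mean, var)  # 1.0 4.0
```

If the layers are dense arrays of the same shape, zinb_moments(denoised_adata.layers["scprint_mu"], denoised_adata.layers["scprint_theta"], denoised_adata.layers["scprint_pi"]) would give an expected-count matrix under this assumption.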
In [ ]:
# write the denoised expression back into prostate_adata for the "focus" cells
# assumption: denoised_adata rows are in the same order as the cells passed to denoise()
row_mask = prostate_adata.obs_names[prostate_adata.obs["focus"] != "other"]
# select the shared genes by name, in prostate_adata's order, so that columns line up
common_genes = prostate_adata.var_names[
    prostate_adata.var_names.isin(denoised_adata.var_names)
]
# note: assigning into a large sparse matrix through a view can be very slow
prostate_adata[row_mask, common_genes].X = denoised_adata[:, common_genes].X
In [83]:
denoised_adata
Out[83]:
AnnData object with n_obs × n_vars = 777 × 44756
    obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb'
    layers: 'scprint_mu', 'scprint_pi', 'scprint_theta'
In [ ]:
# here we overwrite the expression profile with the denoised values (this is not done by default for now)
prostate_adata.X[
    prostate_adata.obs["focus"] != "other",
    prostate_adata.var.index.isin(denoised_adata.var.index),
] = denoised_adata.X[:, denoised_adata.var.index.isin(prostate_adata.var.index)]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[82], line 2
      1 # now what we are doing here it to complete the expression profile with the denoised values. this is not done by default for now
----> 2 prostate_adata.X[prostate_adata.obs['focus']!="other", prostate_adata.var.index.isin(denoised_adata.var.index)] = denoised_adata.X[:, denoised_adata.var.index.isin(prostate_adata.var.index)]

File ~/miniconda3/envs/scprnt2/lib/python3.10/site-packages/scipy/sparse/_index.py:115, in IndexMixin.__setitem__(self, key, x)
    112 else:
    113     col = np.atleast_1d(col)
--> 115 i, j = _broadcast_arrays(row, col)
    116 if i.shape != j.shape:
    117     raise IndexError('number of row and column indices differ')

File ~/miniconda3/envs/scprnt2/lib/python3.10/site-packages/scipy/sparse/_index.py:19, in _broadcast_arrays(a, b)
     10 def _broadcast_arrays(a, b):
     11     """
     12     Same as np.broadcast_arrays(a, b) but old writeability rules.
     13 
   (...)
     17     the old behavior.
     18     """
---> 19     x, y = np.broadcast_arrays(a, b)
     20     x.flags.writeable = a.flags.writeable
     21     y.flags.writeable = b.flags.writeable

File ~/miniconda3/envs/scprnt2/lib/python3.10/site-packages/numpy/lib/stride_tricks.py:540, in broadcast_arrays(subok, *args)
    533 # nditer is not used here to avoid the limit of 32 arrays.
    534 # Otherwise, something like the following one-liner would suffice:
    535 # return np.nditer(args, flags=['multi_index', 'zerosize_ok'],
    536 #                  order='C').itviews
    538 args = [np.array(_m, copy=False, subok=subok) for _m in args]
--> 540 shape = _broadcast_shape(*args)
    542 if all(array.shape == shape for array in args):
    543     # Common case where nothing needs to be broadcasted.
    544     return args

File ~/miniconda3/envs/scprnt2/lib/python3.10/site-packages/numpy/lib/stride_tricks.py:422, in _broadcast_shape(*args)
    417 """Returns the shape of the arrays that would result from broadcasting the
    418 supplied arrays against each other.
    419 """
    420 # use the old-iterator because np.nditer does not handle size 0 arrays
    421 # consistently
--> 422 b = np.broadcast(*args[:32])
    423 # unfortunately, it cannot handle 32 or more arguments directly
    424 for pos in range(32, len(args), 31):
    425     # ironically, np.broadcast does not properly handle np.broadcast
    426     # objects (it treats them as scalars)
    427     # use broadcasting to avoid allocating the full array

ValueError: shape mismatch: objects cannot be broadcast to a single shape.  Mismatch is between arg 0 with shape (777,) and arg 1 with shape (23142,).
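The ValueError above comes from how SciPy sparse matrices handle two boolean masks: each mask is turned into an array of integer positions, and the row and column arrays are then broadcast against each other and paired element by element, so 777 row positions cannot be matched with 23142 column positions. Below is a minimal sketch of one way around it, not the scPRINT API. The helper `write_back` is hypothetical, and it assumes the rows of the denoised matrix follow the same order as the selected rows of `prostate_adata`. It aligns genes by name (two independent `isin` masks can silently pair genes that appear in different orders), uses `np.ix_` to select the whole rows × genes block, and assigns into a LIL copy, since element-wise assignment into a large CSR matrix is very slow (which is likely why the earlier name-based attempt had to be interrupted).

```python
import numpy as np
import scipy.sparse as sp


def write_back(X, row_idx, var_names, new_X, new_var_names):
    """Hypothetical helper: overwrite X[row_idx, shared genes] with new_X.

    X: 2D dense array or sparse matrix (cells x genes) to update (not modified in place).
    row_idx: integer positions in X of the rows of new_X, in new_X's row order.
    var_names / new_var_names: gene names of the columns of X / new_X.
    """
    var_pos = {g: i for i, g in enumerate(var_names)}
    # genes present in both matrices, with their column position in each
    shared = [(var_pos[g], j) for j, g in enumerate(new_var_names) if g in var_pos]
    cols_X = np.array([c for c, _ in shared], dtype=int)
    cols_new = np.array([c for _, c in shared], dtype=int)
    row_idx = np.asarray(row_idx, dtype=int)

    # LIL is the sparse format meant for element-wise assignment
    out = X.tolil() if sp.issparse(X) else np.array(X, copy=True)
    block = new_X[:, cols_new]
    if sp.issparse(block):
        block = block.toarray()
    # np.ix_ builds an outer (rows x cols) selection instead of element-wise pairs
    out[np.ix_(row_idx, cols_X)] = block
    return out.tocsr() if sp.issparse(X) else out


# Possible use in this notebook (assumes denoised_adata rows follow the order of the
# cells that were passed to the Denoiser):
# rows = np.flatnonzero((prostate_adata.obs["focus"] != "other").to_numpy())
# prostate_adata.X = write_back(prostate_adata.X, rows, prostate_adata.var_names,
#                               denoised_adata.X, denoised_adata.var_names)
```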
In [ ]:
# now we compare BPH associated memory B cells to normal memory B cells after denoising
sc.tl.rank_genes_groups(
    prostate_adata,
    groupby="focus",
    groups=["BPH associated memory B cell"],
    reference="memory B cell",
    method="t-test",
)
# Plot the most differentially expressed genes
sc.pl.rank_genes_groups(prostate_adata, n_genes=25, sharey=False, gene_symbols="symbol")
# super strong B cell markers
In [9]:
prostate_adata.write_h5ad("../../data/temp/prostate_combined_denoised.h5ad")
In [10]:
prostate_adata = sc.read_h5ad("../../data/temp/prostate_combined_denoised.h5ad")
prostate_adata
Out[10]:
AnnData object with n_obs × n_vars = 83451 × 70116
    obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier', 'pred_cell_type_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id', 'sprint_leiden', 'cleaned_pred_cell_type_ontology_term_id', 'cleaned_pred_disease_ontology_term_id'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'description', 'synonyms', 'organism_id', 'public_source_id', 'created_by_id', 'mt', 'ribo', 'hb', 'organism', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'assay_colors', 'cleaned_pred_cell_type_ontology_term_id_colors', 'cleaned_pred_disease_ontology_term_id_colors', 'development_stage_colors', 'disease_colors', 'unseen_genes'
    obsm: 'X_pca', 'X_umap', 'scprint', 'scprint_umap'

Gene network inference

Finally, we will use scPRINT to infer gene networks in another cell type of interest, the fibroblasts, in both normal and BPH conditions.

We will use the GNInfer class to infer gene networks (see cancer_usecase_part2.ipynb for more details on how to analyse them).

In [38]:
prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id.value_counts().head(20)
Out[38]:
cleaned_pred_cell_type_ontology_term_id
basal epithelial cell of prostatic duct                             25740
prostate gland microvascular endothelial cell                       12165
urethra urothelial cell                                             10952
CD1c-positive myeloid dendritic cell                                 6898
other                                                                5869
aortic smooth muscle cell                                            3970
effector CD8-positive, alpha-beta T cell                             3925
pancreatic acinar cell                                               3711
luminal cell of prostate epithelium                                  3435
fibroblast of connective tissue of nonglandular part of prostate     3334
mucous neck cell                                                     3260
IgG-negative class switched memory B cell                            2932
basophil                                                             1862
fibroblast of connective tissue of glandular part of prostate        1815
pancreatic ductal cell                                               1651
CD4-positive, alpha-beta thymocyte                                   1516
club cell                                                            1226
smooth muscle cell of prostate                                       1000
effector memory CD8-positive, alpha-beta T cell                       991
peptic cell                                                           872
mature conventional dendritic cell                                    721
effector CD4-positive, alpha-beta T cell                              686
mast cell                                                             451
renal interstitial pericyte                                           440
retinal blood vessel endothelial cell                                 404
Name: count, dtype: int64
In [ ]:
loc = (
    prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id
    == "fibroblast of connective tissue of glandular part of prostate"
)
prostate_adata.obs[loc]["cleaned_pred_disease_ontology_term_id"].value_counts()
Out[ ]:
cleaned_pred_disease_ontology_term_id
benign prostatic hyperplasia    1482
normal                           673
other                              4
Name: count, dtype: int64
In [ ]:
prostate_adata.obs[loc]["louvain_1.0"].value_counts().head(10)
Out[ ]:
louvain_1.0
5     1458
4      671
17      10
8        8
7        7
0        3
9        2
6        0
3        0
1        0
10       0
11       0
12       0
13       0
14       0
15       0
16       0
2        0
18       0
Name: count, dtype: int64
In [ ]:
loc = loc & (prostate_adata.obs["louvain_1.0"] == str(5))
prostate_adata.obs[loc]["cleaned_pred_disease_ontology_term_id"].value_counts()
Out[ ]:
cleaned_pred_disease_ontology_term_id
benign prostatic hyperplasia    790
normal                          664
other                             4
Name: count, dtype: int64
In [ ]:
prostate_adata.obs["fibro"] = None
prostate_adata.obs.loc[loc, "fibro"] = "fibroblasts"
prostate_adata.obs.loc[
    loc
    & (
        prostate_adata.obs.cleaned_pred_disease_ontology_term_id
        == "benign prostatic hyperplasia"
    ),
    "fibro",
] = "BPH associated fibroblasts"
In [ ]:
sc.pl.embedding(
    prostate_adata[
        prostate_adata.obs["louvain_1.0"].isin(["5", "4"])
        & (prostate_adata.obsm["scprint_umap"][:, 1] > 10)
    ],
    basis="scprint_umap_rot",
    color="fibro",
    show=False,
    size=8,
    title="Fibroblasts cluster",
    legend_loc="right margin",
)
... storing 'focus' as categorical
... storing 'fibro' as categorical
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning
  color_vector = pd.Categorical(values.map(color_map))
In [ ]:
grn_inferer = GNInfer(
    # use the genes that are most variable across the fibroblasts vs. the rest of the cells
    # to select the num_genes genes used to compute our gene network
    how="most var across",
    # how="random expr",
    # preprocess the attention matrix with softmax: instead of using the raw QK^T matrix
    # of the attention operation directly, we apply a softmax to each of its rows
    # (see https://medium.com/@mayanksultania/transformers-101-tokens-attention-and-beyond-b080a900ca6c)
    preprocess="softmax",
    # don't aggregate the gene networks coming from the different heads of the different
    # layers here, we will do it manually afterwards
    head_agg="none",
    # no filtration here; other options include "topk" (top k connections per gene)
    # or "thresh" (a defined threshold), among others
    filtration="none",
    # whether the attention matrices are generated while performing a task (e.g. denoising)
    # or by just passing the expression profile through the model ("none")
    forward_mode="none",
    # the number of genes to use (here the 4000 most variable genes)
    num_genes=4000,
    # the max number of cells to use per cell type
    max_cells=300,
    doplot=False,
    batch_size=16,
    # the column in the AnnData .obs that defines the cell type
    cell_type_col="fibro",
    dtype=dtype,
    # list of transformer layers to extract the gene network from (here: all of them)
    layer=list(range(model.nlayers)),
)
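The `preprocess="softmax"` comment above describes applying a softmax to each row of the QK^T attention matrix. A minimal NumPy sketch of that operation on toy random queries and keys (hypothetical data, not scPRINT internals): after the softmax, each row is a probability distribution over genes, so row i can be read as how gene i spreads its attention across all genes. Standard scaled dot-product attention also divides QK^T by the square root of the head dimension before the softmax; whether GNInfer does so is not stated here, so the sketch leaves that step out.

```python
import numpy as np


def row_softmax_attention(Q, K):
    """Softmax of Q @ K.T, normalised independently over each row."""
    scores = Q @ K.T
    # subtracting each row's max avoids overflow in exp and does not change the result
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
n_genes, d_head = 5, 8  # toy sizes
Q = rng.normal(size=(n_genes, d_head))  # one query vector per gene
K = rng.normal(size=(n_genes, d_head))  # one key vector per gene
A = row_softmax_attention(Q, K)  # (n_genes x n_genes) gene-gene weights
```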
In [ ]:
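# convert "fibro" from categorical to plain strings so that new labels can be assigned to it later
prostate_adata.obs["fibro"] = prostate_adata.obs["fibro"].astype(str)
# check the original disease annotation of the cells labelled as BPH associated fibroblasts
prostate_adata.obs[
    prostate_adata.obs["fibro"] == "BPH associated fibroblasts"
].disease.value_counts()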
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/local/scratch/tmp/ipykernel_3040762/465330363.py in ?()
      1 # I was missing this from the model (not really necessary)
      2 model.organisms# = ['NCBITaxon:9606', 'NCBITaxon:10090']
----> 3 prostate_adata.obs.fibro = prostate_adata.obs.fibro.astype(str)
      4 prostate_adata.obs[prostate_adata.obs.fibro=="BPH associated fibroblasts"].disease.value_counts()

~/test/test/lib/python3.10/site-packages/pandas/core/generic.py in ?(self, name)
   6314             and name not in self._accessors
   6315             and self._info_axis._can_hold_identifiers_and_holds_name(name)
   6316         ):
   6317             return self[name]
-> 6318         return object.__getattribute__(self, name)

AttributeError: 'DataFrame' object has no attribute 'fibro'

Output

The output of the function is a new AnnData object, in the form of a GRnnData object (see cantinilab/grnndata).

Its .varp['GRN'] attribute contains the full gene network, in this case a fully connected, directed, weighted graph between all 4000 selected genes.

Because we set head_agg to "none", this matrix is not of size 4000x4000 but 4000x4000x(n_layers*n_heads). Taking the mean over this last dimension aggregates the gene networks coming from the different heads of the different layers into a single 4000x4000 matrix.
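A minimal NumPy sketch of this aggregation on a toy array standing in for `grn.varp["GRN"]` (random values, hypothetical sizes). The per-layer view assumes that the last axis is ordered layer by layer (all heads of layer 0, then all heads of layer 1, and so on); that ordering is an assumption, not something stated in this notebook. Averaging the per-layer networks gives the same result as averaging over all heads at once, because every layer has the same number of heads.

```python
import numpy as np

n_genes, n_layers, n_heads = 6, 2, 4  # toy sizes
rng = np.random.default_rng(0)
# stand-in for grn.varp["GRN"] with head_agg="none": one network per (layer, head)
per_head = rng.random((n_genes, n_genes, n_layers * n_heads))

# mean over the last axis -> a single (n_genes x n_genes) network
aggregated = per_head.mean(-1)

# ASSUMPTION: last axis ordered layer-major, so it can be split into (layer, head)
per_layer = per_head.reshape(n_genes, n_genes, n_layers, n_heads).mean(-1)
```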

In [ ]:
# compute GRNs on fibroblasts, using all the attention layers
grn = grn_inferer(model, prostate_adata, cell_type="fibroblasts")

# keep the full, non-aggregated networks (one per head and layer)
grn.varp["all"] = grn.varp["GRN"].copy()
# now we aggregate the heads by taking their average
grn.varp["GRN"] = grn.varp["GRN"].mean(-1)
grn.write_h5ad("../../data/temp/prostate_fibro_grn_all.h5ad")
In [ ]:
# same on the BPH associated fibroblasts
# here we only keep the ones that the original annotators had labelled as coming from BPH
prostate_adata.obs.loc[
    (prostate_adata.obs["fibro"] == "BPH associated fibroblasts")
    & (prostate_adata.obs["disease"] == "benign prostatic hyperplasia"),
    "fibro",
] = "true BPH associated fibroblasts"

grn_c = grn_inferer(
    model, prostate_adata, cell_type="true BPH associated fibroblasts"
)
# keep the full, non-aggregated networks (one per head and layer)
grn_c.varp["all"] = grn_c.varp["GRN"].copy()
# aggregate the heads by taking their average
grn_c.varp["GRN"] = grn_c.varp["GRN"].mean(-1)
grn_c.write_h5ad("../../data/temp/prostate_BPH_fibro_grn_all.h5ad")
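With both networks computed, one simple way to highlight differential links on genes present in both is to align the two aggregated matrices on their shared genes and rank edges by the absolute difference. The sketch below is a hypothetical helper on plain NumPy arrays, not the analysis from cancer_usecase_part2.ipynb. It assumes each matrix's rows and columns follow its gene list (for example `grn.varp["GRN"]` with `grn.var_names`), and since the direction convention of the GRN matrix is not stated here, edges are reported as (row gene, column gene) rather than regulator and target.

```python
import numpy as np


def top_differential_edges(A, genes_a, B, genes_b, k=10):
    """Hypothetical helper: k (row_gene, col_gene, B - A) triples with the largest |B - A|.

    Only genes present in both networks are compared; self-links are ignored.
    """
    pos_a = {g: i for i, g in enumerate(genes_a)}
    pos_b = {g: i for i, g in enumerate(genes_b)}
    shared = [g for g in genes_a if g in pos_b]
    ia = np.array([pos_a[g] for g in shared], dtype=int)
    ib = np.array([pos_b[g] for g in shared], dtype=int)
    # reorder both matrices to the same shared-gene order before subtracting
    diff = B[np.ix_(ib, ib)] - A[np.ix_(ia, ia)]
    np.fill_diagonal(diff, 0.0)
    # indices of the largest absolute differences in the flattened matrix
    flat = np.argsort(-np.abs(diff), axis=None)[:k]
    rows, cols = np.unravel_index(flat, diff.shape)
    return [(shared[r], shared[c], float(diff[r, c])) for r, c in zip(rows, cols)]


# e.g. top_differential_edges(grn.varp["GRN"], list(grn.var_names),
#                             grn_c.varp["GRN"], list(grn_c.var_names), k=20)
```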

Built with MkDocs using a theme provided by Read the Docs.