scPRINT use case on BPH¶
In this use case, also presented in Figure 5 of our manuscript, we perform an extensive analysis of a multi-study dataset of benign prostatic hyperplasia (BPH).
Our biological question is whether there exist pre-cancerous cells that exhibit behaviors of mature cancer cells at this early stage of the disease.
In those cells, we want to know which genes might be implicated in cell state changes, and to explore potentially novel targets for the treatment of prostate cancer and BPH.
We will start with a fresh dataset coming from the cellxgene database and representing 2 studies of BPH.
We will first explore this dataset to understand:
- what cell types are present in the data
- what the cell distributions are
- what sequencers were used, etc.
We also want to confirm existing targets in prostate cancer through the analysis of precancerous lesions, and find potentially novel ones that could serve as less invasive BPH treatments than current ones.
Finally, we want to know how these targets interact and how they are involved in biological pathways.
We now showcase how to use scPRINT across its different functionalities to answer some of these questions.
table of contents:¶
- Downloading and preprocessing
- Embedding and annotations
- Annotation cleanup
- Clustering and differential expression
- Denoising and differential expression
- Gene network inference
In the notebook cancer_usecase_part2.ipynb you will see how to analyse cell type specific gene regulatory networks.
These notebooks can be found and run from docs/notebooks.
🎴 YOU NEED TO FIRST HAVE A POPULATED LAMINDB INSTANCE (see README) 🎴¶
! lamin connect anonymous/testdb3
! calling anonymously, will miss private instances → connected lamindb: anonymous/testdb3 • to map a local dev directory, call: lamin settings set dev-dir .
# ! lamin connect [MY_INSTANCE]
from huggingface_hub import hf_hub_download
from scprint import scPrint
from scdataloader import Preprocessor, utils
from scprint.tasks import GNInfer, Embedder, Denoiser, withknn
from scdataloader.utils import load_genes
from bengrn import BenGRN, get_sroy_gt, compute_genie3
from bengrn.base import train_classifier
from grnndata import utils as grnutils
from grnndata import read_h5ad
from anndata.utils import make_index_unique
from anndata import concat
import scanpy as sc
from matplotlib import pyplot as plt
import numpy as np
import lamindb as ln
%load_ext autoreload
%autoreload 2
import torch
torch.set_float32_matmul_precision("medium")
→ connected lamindb: anonymous/testdb3
Downloading and preprocessing¶
We now use lamindb to access cellxgene and download a dataset of normal and benign prostatic hyperplasia tissues.
The data is available here: https://cellxgene.cziscience.com/e/574e9f9e-f8b4-41ef-bf19-89a9964fd9c7.cxg/ .
We then use scDataloader's preprocessing method. This method is quite extensive and performs several steps; find out more about it in its documentation.
On our end, we use the preprocessor to make sure that the gene expression values are raw counts and that we have enough information to use scPRINT (i.e., enough genes expressed and enough counts per cell across the dataset).
Finally, the preprocessor also expands the expression matrix to a fixed set of genes defined by the latest version of Ensembl.
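As a rough illustration of the two ideas described above (checking for raw counts and expanding to a fixed gene set), here is a minimal sketch in plain Python on a toy matrix. It is not scDataloader's implementation: the function names `looks_like_raw_counts` and `pad_to_reference`, the gene IDs and the reference list are all hypothetical.

```python
# Toy sketch, NOT scDataloader's code: all names and gene IDs are made up.

def looks_like_raw_counts(matrix):
    """Return True if every value is a non-negative whole number."""
    return all(v >= 0 and float(v).is_integer() for row in matrix for v in row)


def pad_to_reference(matrix, genes, reference_genes):
    """Reorder columns to follow reference_genes.

    Genes absent from the data are filled with 0; genes present in the
    data but absent from the reference are dropped.
    """
    col = {g: i for i, g in enumerate(genes)}
    return [
        [row[col[g]] if g in col else 0 for g in reference_genes]
        for row in matrix
    ]


# hypothetical toy data: 2 cells x 3 genes
genes = ["ENSG_A", "ENSG_B", "ENSG_X"]
counts = [[3, 0, 1], [0, 5, 2]]
reference = ["ENSG_A", "ENSG_B", "ENSG_C", "ENSG_D"]

assert looks_like_raw_counts(counts)
padded = pad_to_reference(counts, genes, reference)
print(padded)
```

In this toy case `ENSG_X` is dropped and `ENSG_C`, `ENSG_D` are added as all-zero columns, which mirrors the "Removed ... genes not known to the ontology" and "Added ... genes in the ontology but not present in the dataset" messages printed by the preprocessor.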
cx_dataset = (
    ln.DB(instance="laminlabs/cellxgene")
    .Collection.filter(version="2023-12-15", key="cellxgene-census")
    .one()
)
→ the database (2.2a3) is ahead of your installed lamindb package (2.1.1) → consider updating lamindb: pip install lamindb>=2.2
prostate_adata = (
    cx_dataset.artifacts.filter(
        key="cell-census/2023-12-15/h5ads/574e9f9e-f8b4-41ef-bf19-89a9964fd9c7.h5ad"
    )
    .one()
    .load()
)
sc.pl.umap(prostate_adata)
! run input wasn't tracked, call `ln.track()` and re-run
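The exploration questions listed in the introduction (cell types, their distribution, sequencers) come down to counting values in the relevant `obs` columns. Below is a minimal sketch on made-up rows, using only the standard library. The rows are hypothetical; only the column names `cell_type`, `assay` and `disease` come from the dataset.

```python
from collections import Counter

# hypothetical stand-in for a few rows of prostate_adata.obs
obs = [
    {"cell_type": "basal cell", "assay": "10x 3' v3", "disease": "benign prostatic hyperplasia"},
    {"cell_type": "basal cell", "assay": "10x 3' v3", "disease": "benign prostatic hyperplasia"},
    {"cell_type": "luminal cell", "assay": "10x 3' v2", "disease": "benign prostatic hyperplasia"},
    {"cell_type": "fibroblast", "assay": "10x 3' v2", "disease": "normal"},
]


def summarize(rows, column):
    """Return (value, count, fraction) tuples sorted by decreasing count."""
    counts = Counter(r[column] for r in rows)
    total = sum(counts.values())
    return [(v, n, n / total) for v, n in counts.most_common()]


for column in ("cell_type", "assay", "disease"):
    print(column)
    for value, n, frac in summarize(obs, column):
        print(f"  {value}: {n} ({frac:.0%})")
```

On the real object, `prostate_adata.obs["assay"].value_counts()` gives the same kind of summary.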
#from scdataloader.utils import populate_my_ontology
#populate_my_ontology()
... synchronizing df_all__cl__2025-12-17__CellType.parquet: 100.0% ... synchronizing df_vertebrates__ensembl__release-114__Organism.parquet: 100.0% ! ambiguous validation in Bionty for 1 record: 'sheep' ! ambiguous validation in Bionty for 1 record: 'sheep' ! ontology ID BFO:0000020 not found in DataFramequet: 100.0% ... synchronizing df_human__hancestro__2025-10-14__Ethnicity.parquet: 100.0% ... synchronizing df_all__efo__3.85.0__ExperimentalFactor.parquet: 100.0% → starting creation of 18326 ExperimentalFactor records in batches of 10000 → starting creation of 16430 ExperimentalFactor_parents records in batches of 10000 ! you are trying to create a record with name='unknown' but a record with similar name exists: 'fever of unknown origin'. Did you mean to load it? ... synchronizing df_all__uberon__2025-12-04__Tissue.parquet: 100.0% → starting creation of 15770 Tissue records in batches of 10000 → starting creation of 42140 Tissue_parents records in batches of 10000 ... synchronizing df_human__hsapdv__2025-01-23__DevelopmentalStage.parquet: 100.0% ... synchronizing df_mouse__mmusdv__2025-01-23__DevelopmentalStage.parquet: 100.0% ... synchronizing df_all__mondo__2026-01-06__Disease.parquet: 100.0% → starting creation of 30371 Disease records in batches of 10000 → starting creation of 39555 Disease_parents records in batches of 10000 ! you are trying to create a record with name='normal' but records with similar names exist: 'normocytic anemia', 'normal breast-like subtype of breast carcinoma', 'normokalemic periodic paralysis'. Did you mean to load one of them? ! you are trying to create a record with name='unknown' but records with similar names exist: 'unknown leukodystrophy', 'hemoglobin--variants for which the chain carrying the mutation 1S unknown or uncertain', 'sporadic adult-onset ataxia of unknown etiology'. Did you mean to load one of them? ... synchronizing df_mouse__ensembl__release-114__Gene.parquet: 100.0% ... 
synchronizing df_human__ensembl__release-114__Gene.parquet: 100.0%
# preprocessing using scDataloader
prostate_adata.obs.drop(columns="is_primary_data", inplace=True)
preprocessor = Preprocessor(do_postp=False)
prostate_adata = preprocessor(prostate_adata)
Dropping layers: KeysView(Layers with keys: )
checking raw counts
removed 0 non primary cells, 83451 renamining
filtered out 0 cells, 83451 renamining
Removed 586 genes not known to the ontology
Removed 0 duplicate genes
Added 58936 genes in the ontology but not present in the dataset
starting QC
Seeing 13064 outliers (15.65% of total dataset):
done
AnnData object with n_obs × n_vars = 83451 × 85720
obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier'
var: 'feature_biotype', 'feature_is_filtered', 'feature_name', 'feature_reference', 'uid', 'symbol', 'biotype', 'is_locked', 'branch_id', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
uns: 'unseen_genes'
loading the model:¶
We now load the medium version (v1.5) of scPRINT from a specific checkpoint. Here it is fetched with hf_hub_download; the README gives further instructions on downloading checkpoints.
Note that the other notebooks in this repo may not take as much care in how they load the model. Refer to the part below for a good understanding of the different parameters involved in loading pretrained model weights.
model_checkpoint_file = hf_hub_download(
    repo_id="jkobject/scPRINT", filename="medium-v1.5.ckpt"
)
# model_checkpoint_file = "../data/temp/o2uniqsx/epoch=18-step=133000.ckpt"
# "/pasteur/zeus/projets/p02/ml4ig_hot/Users/jkalfon/scprint_scale/o2uniqsx/checkpoints/last.ckpt"
#'../data/temp/o2uniqsx/epoch=18-step=133000.ckpt',
# "../../xzp23r4p.ckpt"
model_checkpoint_file
'/pasteur/appa/homes/jkalfon/.cache/huggingface/hub/models--jkobject--scPRINT/snapshots/d661d779e3614c77305e2fbe0a19312eda2794b8/medium-v1.5.ckpt'
# make sure that you check if you have a GPU with flashattention or not (see README)
try:
    m = torch.load(model_checkpoint_file)
# if that fails, load on CPU instead, since the model weights are mapped to GPU by default
except RuntimeError:
    m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
# the model was trained with flash attention by default, so if you do not have a GPU you need to replace the attention mechanism with regular attention
transformer = "flash" if torch.cuda.is_available() else "normal"
# the two checks below handle compatibility with different versions of the pretrained model
if "prenorm" in m["hyper_parameters"]:
    m["hyper_parameters"].pop("prenorm")
    torch.save(m, model_checkpoint_file)
if "label_counts" in m["hyper_parameters"]:
    # set precpt_gene_emb=None, otherwise the model will look for its precomputed gene embedding files;
    # these were already converted into model weights, so a pretrained model does not need them
    model = scPrint.load_from_checkpoint(
        model_checkpoint_file,
        precpt_gene_emb=None,
        classes=m["hyper_parameters"]["label_counts"],
        transformer=transformer,
    )
else:
    model = scPrint.load_from_checkpoint(
        model_checkpoint_file, precpt_gene_emb=None, transformer=transformer
    )
# this can happen if the model was trained with a different set of genes than the one in your ontology (e.g. newer ontologies).
# Having genes in the ontology that are not in the model is fine, but the opposite is not,
# so we need to remove the genes that are in the model but not in the ontology
missing = set(model.genes) - set(load_genes(model.organisms).index)
if len(missing) > 0:
    print(
        "Warning: some genes missmatch exist between model and ontology: solving...",
    )
    model._rm_genes(missing)
# if not on GPU, make sure the model is in float32
if not torch.cuda.is_available():
    model = model.to(torch.float32)
# you can perform your inference in float16 if you have a GPU, otherwise use float32
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# models are often loaded with some parts on "cuda" and some on "cpu", so we make sure that the model is fully on the right device
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
RuntimeError caught: scPrint is not attached to a `Trainer`. Warning: some genes missmatch exist between model and ontology: solving...
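The gene reconciliation step in the loading code is a plain set difference. The sketch below shows the asymmetry on toy data: genes known only to the model must be dropped, while genes known only to the ontology are harmless. The gene IDs are made up, and the lists stand in for `model.genes` and the index returned by `load_genes`.

```python
# Toy sketch of the gene reconciliation step; gene IDs are hypothetical.
model_genes = ["ENSG01", "ENSG02", "ENSG03", "ENSG04"]
ontology_genes = ["ENSG01", "ENSG02", "ENSG04", "ENSG05"]

# genes the model knows but the ontology does not: these must be removed
missing = set(model_genes) - set(ontology_genes)
# genes the ontology knows but the model does not: harmless, simply unused
unused = set(ontology_genes) - set(model_genes)

# keep the model's gene order for the genes that remain
kept = [g for g in model_genes if g not in missing]
print(sorted(missing), sorted(unused), kept)
```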
Embedding and annotations¶
We will then use our Embedder class to embed the data and annotate the cells.
These classes are how we parametrize and access the different functions of
scPRINT. Find out more about their parameters in our
documentation.
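To give an intuition for the `how="random expr"` and `max_len` parameters used below, here is a minimal sketch that assumes "random expr" means "sample at random among the genes expressed in a cell, up to `max_len` of them". That reading is our assumption for illustration; the function `sample_expressed` is hypothetical and is not scPRINT's code.

```python
import math
import random


def sample_expressed(expression, max_len, rng):
    """Pick at most max_len gene indices with non-zero expression, at random."""
    expressed = [i for i, v in enumerate(expression) if v > 0]
    if len(expressed) <= max_len:
        return expressed
    return sorted(rng.sample(expressed, max_len))


rng = random.Random(0)
cell = [0, 4, 0, 1, 7, 0, 2, 9]  # toy counts for 8 genes
picked = sample_expressed(cell, max_len=3, rng=rng)
print(picked)  # 3 indices, all pointing at non-zero counts

# number of batches needed for 83451 cells with batch_size=32
n_batches = math.ceil(83451 / 32)
print(n_batches)
```

The batch count computed at the end is the total that the progress bar of the embedding cell reports (2608).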
embedder = Embedder(
    # which genes to use: random genes, the most variable genes, etc.
    how="random expr",
    # number of genes to use
    max_len=4000,
    # the model was trained on minibatches of 64 cells but you can choose any size
    batch_size=32,
    # number of workers for the data loading
    num_workers=8,
    # we will only use the cell type embedding here.
    pred_embedding=["cell_type_ontology_term_id"],
    # , "disease_ontology_term_id"],
    # maximum number of cells whose predictions are kept in memory before being written out
    save_every=40_000,
    dtype=dtype,
)
# create the embedding
prostate_adata, metrics = embedder(model, prostate_adata, cache=False)
0%| | 0/2608 [00:00<?, ?it/s]/local/scratch/tmp/tmptntrr2uq/main.c:4:10: fatal error: Python.h: No such file or directory
#include <Python.h>
^~~~~~~~~~
compilation terminated.
0%| | 0/2608 [00:03<?, ?it/s]
---------------------------------------------------------------------------
CalledProcessError                        Traceback (most recent call last)
Cell In[9], line 2
      1 # create the embedding
----> 2 prostate_adata, metrics = embedder(model, prostate_adata, cache=False)

File ~/scPRINT/scprint/tasks/cell_emb.py:142, in Embedder.__call__(self, model, adata, cache)
--> 142 model._predict(

[... intermediate frames in scprint/model/model.py, torch, simpler_flash and triton omitted ...]

File ~/scPRINT/scprint1/lib64/python3.12/site-packages/triton/common/build.py:106, in _build(name, src, srcdir)
--> 106 ret = subprocess.check_call(cc_cmd)

File /usr/lib64/python3.12/subprocess.py:413, in check_call(*popenargs, **kwargs)
--> 413 raise CalledProcessError(retcode, cmd)

CalledProcessError: Command '['/usr/bin/gcc', '/local/scratch/tmp/tmptntrr2uq/main.c', '-O3', '-I/pasteur/appa/homes/jkalfon/scPRINT/scprint1/lib64/python3.12/site-packages/triton/common/../third_party/cuda/include', '-I/usr/include/python3.12', '-I/local/scratch/tmp/tmptntrr2uq', '-shared', '-fPIC', '-lcuda', '-o', '/local/scratch/tmp/tmptntrr2uq/_fwd_kernel.cpython-312-x86_64-linux-gnu.so', '-L/lib64', '-L/lib64']' returned non-zero exit status 1.
prostate_adata
AnnData object with n_obs × n_vars = 83451 × 70704
obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier'
var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'synonyms', 'description', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
uns: 'unseen_genes'
Annotation cleanup¶
scPRINT generates predictions over hundreds of possible labels for each cell.
It is often advised to "clean up" the predictions, e.g. by removing low-frequency labels and mislabellings.
Here, we use the most straightforward approach, which is to relabel as "other" any annotation that appears only a small number of times (400 cells or fewer for cell types, 1000 or fewer for diseases).
A better approach would be majority voting over cell clusters, as it would aggregate and smooth out the predictions over multiple cells. It would also remove most of the low-frequency mistakes in the predictions.
We will also have a look at the embeddings of scPRINT by plotting their UMAP
visualization.
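The majority-voting alternative mentioned above is not what this notebook runs, but it can be sketched in a few lines. This is a minimal, standard-library illustration on toy data; the function `majority_vote` and the toy labels are hypothetical.

```python
from collections import Counter, defaultdict


def majority_vote(clusters, labels):
    """Replace each cell's label by the most frequent label in its cluster."""
    per_cluster = defaultdict(Counter)
    for cluster, label in zip(clusters, labels):
        per_cluster[cluster][label] += 1
    winner = {c: counts.most_common(1)[0][0] for c, counts in per_cluster.items()}
    return [winner[c] for c in clusters]


# toy data: 7 cells in 2 clusters, with one stray label in each cluster
clusters = ["0", "0", "0", "1", "1", "1", "1"]
labels = ["basal", "basal", "mast", "luminal", "luminal", "basal", "luminal"]
smoothed = majority_vote(clusters, labels)
print(smoothed)
```

With the real data, the clusters could be one of the clustering columns in `obs` and the labels the `conv_pred_cell_type_ontology_term_id` column. Note that ties within a cluster are broken by first occurrence here, which is an arbitrary choice.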
# cleaning up the cell types
prostate_adata.obs["cleaned_pred_cell_type_ontology_term_id"] = prostate_adata.obs[
    "conv_pred_cell_type_ontology_term_id"
].astype(str)
prostate_adata.obs.loc[
    ~prostate_adata.obs["conv_pred_cell_type_ontology_term_id"].isin(
        [
            k
            for k, v in prostate_adata.obs["conv_pred_cell_type_ontology_term_id"]
            .value_counts()
            .items()
            if v > 400
        ]
    ),
    "cleaned_pred_cell_type_ontology_term_id",
] = "other"
prostate_adata.obs["cleaned_pred_cell_type_ontology_term_id"].value_counts().plot.pie()
<Axes: ylabel='count'>
# cleaning up the diseases
prostate_adata.obs["cleaned_pred_disease_ontology_term_id"] = prostate_adata.obs[
    "conv_pred_disease_ontology_term_id"
].astype(str)
prostate_adata.obs.loc[
    ~prostate_adata.obs["conv_pred_disease_ontology_term_id"].isin(
        [
            k
            for k, v in prostate_adata.obs["conv_pred_disease_ontology_term_id"]
            .value_counts()
            .items()
            if v > 1000
        ]
    ),
    "cleaned_pred_disease_ontology_term_id",
] = "other"
prostate_adata.obs["cleaned_pred_disease_ontology_term_id"].value_counts().plot.pie()
<Axes: ylabel='count'>
prostate_adata
AnnData object with n_obs × n_vars = 83451 × 70704
obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier', 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id', 'sprint_leiden', 'cleaned_pred_cell_type_ontology_term_id', 'cleaned_pred_disease_ontology_term_id'
var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'synonyms', 'description', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
uns: 'unseen_genes'
obsm: 'X_scprint_umap', 'scprint_emb'
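The two cleanup cells above apply the same rule with different thresholds (400 for cell types, 1000 for diseases). As a sketch, the rule can be factored into one small helper. The version below works on plain Python lists so that it is self-contained; `collapse_rare` is a hypothetical name, not part of scPRINT.

```python
from collections import Counter


def collapse_rare(labels, min_count, other="other"):
    """Keep labels seen more than min_count times; map the rest to `other`."""
    counts = Counter(labels)
    return [label if counts[label] > min_count else other for label in labels]


# toy data: 5 basal, 3 luminal, 1 basophil
labels = ["basal"] * 5 + ["luminal"] * 3 + ["basophil"]
cleaned = collapse_rare(labels, min_count=2)
print(Counter(cleaned))
```

The strict comparison (`>`) matches the `if v > 400` and `if v > 1000` conditions used in the cells above.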
# run louvain clustering multiple times
sc.pp.neighbors(prostate_adata, n_neighbors=10, use_rep="scprint_emb")
# using multiple resolutions can help spot smaller clusters
sc.tl.louvain(prostate_adata, resolution=0.5, key_added="louvain_0.5")
sc.tl.louvain(prostate_adata, resolution=1.0, key_added="louvain_1.0")
# compute the UMAP
sc.tl.umap(prostate_adata, n_components=2, random_state=42)
# inspect the updated object
prostate_adata
AnnData object with n_obs × n_vars = 83451 × 70704
obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier', 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id', 'sprint_leiden', 'cleaned_pred_cell_type_ontology_term_id', 'cleaned_pred_disease_ontology_term_id', 'louvain_0.5', 'louvain_1.0'
var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'synonyms', 'description', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
uns: 'unseen_genes', 'neighbors', 'louvain_0.5', 'louvain_1.0', 'umap', 'cleaned_pred_cell_type_ontology_term_id_colors', 'cleaned_pred_disease_ontology_term_id_colors', 'assay_colors', 'disease_colors', 'development_stage_colors'
obsm: 'X_scprint_umap', 'scprint_emb', 'X_umap'
obsp: 'distances', 'connectivities'
sc.pl.embedding(
prostate_adata, basis="X_umap", color=["cleaned_pred_cell_type_ontology_term_id"]
)
sc.pl.embedding(
prostate_adata, basis="X_umap", color=["cleaned_pred_disease_ontology_term_id"]
)
sc.pl.embedding(prostate_adata, basis="X_umap", color=["assay"])
sc.pl.embedding(prostate_adata, basis="X_umap", color=["disease"])
sc.pl.embedding(prostate_adata, basis="X_umap", color=["development_stage"])
prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id.value_counts().head(20)
cleaned_pred_cell_type_ontology_term_id
basal epithelial cell of prostatic duct    25801
urethra urothelial cell    11050
prostate gland microvascular endothelial cell    10321
CD1c-positive myeloid dendritic cell    6427
other    4567
fibroblast of connective tissue of nonglandular part of prostate    3995
lung pericyte    3045
luminal cell of prostate epithelium    2960
IgG-negative class switched memory B cell    2701
effector CD8-positive, alpha-beta T cell    2499
acinar cell of salivary gland    1689
basophil    1340
glomerular capillary endothelial cell    1066
smooth muscle cell of prostate    1055
mast cell    993
fibroblast of connective tissue of glandular part of prostate    858
effector memory CD8-positive, alpha-beta T cell    775
CD4-positive, alpha-beta thymocyte    698
smooth muscle cell of the pulmonary artery    611
effector CD4-positive, alpha-beta T cell    506
Name: count, dtype: int64
# we save for next time
prostate_adata.write_h5ad("../../data/temp/prostate_combined_latest.h5ad")
prostate_adata = sc.read_h5ad("../../data/temp/prostate_combined_latest.h5ad")
Clustering and differential expression¶
We will now cluster the cells using the Louvain algorithm on a kNN graph.
Once we detect a cluster of interest, we will perform differential expression analysis on it. Taking some B-cell clusters as an example, we will use scanpy's implementation of rank_genes_groups for our differential expression.
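To make the idea of "clustering on a kNN graph" concrete, here is a minimal toy sketch of how such a graph can be built from an embedding. It is a brute-force illustration with hypothetical names (`knn_graph`, `emb`), not what scanpy does internally, which relies on approximate neighbor search and weighted edges.

```python
import numpy as np

def knn_graph(points, k):
    """Return a symmetric boolean adjacency matrix linking each point to its k nearest neighbors."""
    points = np.asarray(points, dtype=float)
    diff = points[:, None, :] - points[None, :, :]
    dist = np.sqrt((diff ** 2).sum(-1))      # pairwise Euclidean distances
    np.fill_diagonal(dist, np.inf)           # a cell is not its own neighbor
    neighbors = np.argsort(dist, axis=1)[:, :k]
    adj = np.zeros(dist.shape, dtype=bool)
    adj[np.arange(len(points))[:, None], neighbors] = True
    return adj | adj.T                       # make the graph undirected

# toy embedding: two well separated groups of 10 "cells" each
rng = np.random.default_rng(0)
emb = np.vstack([rng.normal(0, 0.1, (10, 2)), rng.normal(5, 0.1, (10, 2))])
adj = knn_graph(emb, k=3)
```

In this toy case no edge links the two groups, so a community detection algorithm such as Louvain would recover them as separate clusters.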
# check clusters
sc.pl.embedding(
prostate_adata,
basis="X_umap",
color="louvain_0.5",
show=False,
legend_loc="on data",
)
<Axes: title={'center': 'louvain_0.5'}, xlabel='X_umap1', ylabel='X_umap2'>
# check cluster 3
i = 3
loc = prostate_adata.obs["louvain_0.5"] == str(i)
(
prostate_adata.obs[loc].conv_pred_disease_ontology_term_id.value_counts().head(2),
prostate_adata.obs[loc].conv_pred_cell_type_ontology_term_id.value_counts().head(),
)
(conv_pred_disease_ontology_term_id
 normal    3359
 benign prostatic hyperplasia    1441
 Name: count, dtype: int64,
 conv_pred_cell_type_ontology_term_id
 basal epithelial cell of prostatic duct    2785
 urethra urothelial cell    1487
 luminal cell of prostate epithelium    260
 acinar cell of salivary gland    198
 nasal mucosa goblet cell    26
 Name: count, dtype: int64)
# check cluster 26
i = 26
loc = prostate_adata.obs["louvain_0.5"] == str(i)
(
prostate_adata.obs[loc]
.cleaned_pred_disease_ontology_term_id.value_counts()
.head(2),
prostate_adata.obs[loc]
.cleaned_pred_cell_type_ontology_term_id.value_counts()
.head(5),
)
(cleaned_pred_disease_ontology_term_id
 benign prostatic hyperplasia    856
 other    1
 Name: count, dtype: int64,
 cleaned_pred_cell_type_ontology_term_id
 IgG-negative class switched memory B cell    777
 basal epithelial cell of prostatic duct    49
 other    25
 effector CD8-positive, alpha-beta T cell    3
 glomerular capillary endothelial cell    2
 Name: count, dtype: int64)
prostate_adata.obs[
prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id
== "IgG-negative class switched memory B cell"
]["louvain_0.5"].value_counts()
louvain_0.5
21    1202
26    777
27    708
8    5
10    4
2    2
9    1
15    1
12    1
3    0
0    0
1    0
11    0
5    0
13    0
14    0
16    0
6    0
7    0
4    0
19    0
18    0
17    0
20    0
23    0
22    0
25    0
24    0
28    0
29    0
30    0
31    0
32    0
33    0
34    0
35    0
Name: count, dtype: int64
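The per-cluster counts above are used to pick a cluster that is dominated by one cell type. As a rough sketch of that reasoning, the toy example below tabulates cell types per cluster and computes a simple "purity" score. The labels and the `purity` helper are hypothetical and only illustrate the selection logic.

```python
from collections import Counter, defaultdict

# toy annotations: one entry per cell
cell_types = ["B cell", "B cell", "T cell", "B cell", "basal", "basal", "B cell"]
clusters = ["21", "21", "21", "26", "3", "3", "26"]

# count cell types inside each cluster
composition = defaultdict(Counter)
for cluster, cell_type in zip(clusters, cell_types):
    composition[cluster][cell_type] += 1

def purity(cluster):
    """Return the dominant cell type of a cluster and the fraction of cells carrying it."""
    counts = composition[cluster]
    label, n = counts.most_common(1)[0]
    return label, n / sum(counts.values())

label_21, purity_21 = purity("21")
label_26, purity_26 = purity("26")
```

A cluster with a purity close to 1 for the cell type of interest is a good candidate for a focused analysis.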
# We have found a nice IgG-negative class switched memory B cell cluster. Let's use it and define a clean annotation for a plot
loc = loc & (
prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id
== "IgG-negative class switched memory B cell"
)
prostate_adata.obs[loc].cleaned_pred_disease_ontology_term_id.value_counts().head(2)
cleaned_pred_disease_ontology_term_id
benign prostatic hyperplasia    776
other    1
Name: count, dtype: int64
# making a "focus" annotation for the B-cell to generate a nice plot of the B-cell cluster only
prostate_adata.obs["focus"] = "other"
prostate_adata.obs.loc[loc, "focus"] = "memory B cell"
prostate_adata.obs.loc[
loc
& (
prostate_adata.obs["cleaned_pred_disease_ontology_term_id"]
== "benign prostatic hyperplasia"
),
"focus",
] = "BPH associated memory B cell"
prostate_adata.obs["focus"].value_counts()
focus
other    82674
BPH associated memory B cell    776
memory B cell    1
Name: count, dtype: int64
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
color = sns.color_palette()[1]
fig, ax = plt.subplots(figsize=(2, 2))
rect = patches.Rectangle((0, 0), 1, 1, facecolor=color)
ax.add_patch(rect)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis("off")
plt.show()
# looking at the B-cell cluster. We can see some normal and BPH-associated memory B-cells
sc.pl.embedding(
prostate_adata[
(prostate_adata.obs["louvain_1.0"] == str(i))
& (prostate_adata.obsm["scprint_umap"][:, 0] > 4)
],
basis="scprint_umap",
color="focus",
show=False,
size=8,
title="Switched memory B-cell cluster",
legend_loc="right margin",
)
... storing 'focus' as categorical
<Axes: title={'center': 'focus'}, xlabel='scprint_umap_rot1', ylabel='scprint_umap_rot2'>
# now the diff expression between B-cells and the rest
sc.tl.rank_genes_groups(
prostate_adata,
groupby="cleaned_pred_cell_type_ontology_term_id",
groups=["IgG-negative class switched memory B cell"],
reference="other",
method="t-test",
)
# Plot the most differentially expressed genes
sc.pl.rank_genes_groups(prostate_adata, n_genes=25, sharey=False, gene_symbols="symbol")
# super strong B cell markers
... storing 'focus' as categorical
WARNING: It seems you use rank_genes_groups on the raw count data. Please logarithmize your data before calling rank_genes_groups.
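The warning above is raised because the t-test is applied to raw counts. A common remedy is to normalize each cell to a fixed total and log-transform before testing; in scanpy this is typically done with sc.pp.normalize_total followed by sc.pp.log1p. The numpy sketch below only illustrates the transformation itself (the function name and the target of 10,000 counts are assumptions); it is not the code that was run in this notebook.

```python
import numpy as np

def normalize_log1p(counts, target_sum=1e4):
    """Scale each cell (row) to `target_sum` total counts, then apply log(1 + x)."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # avoid dividing by zero for empty cells
    return np.log1p(counts / totals * target_sum)

# toy matrix: 3 cells x 4 genes, the second cell is the first one sequenced 10x deeper
toy_counts = np.array([[1, 0, 3, 6], [10, 0, 30, 60], [0, 0, 0, 0]])
toy_log = normalize_log1p(toy_counts)
```

After this transformation the first two toy cells have identical profiles, which shows that differences in sequencing depth no longer drive the comparison.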
Denoising and differential expression¶
Our previous analysis shows that there are very few normal (i.e. healthy) B-cells in our cluster; most of them are BPH associated. If we wanted to compare BPH B-cells to normal B-cells, we would be severely underpowered.
Instead of looking for another dataset, let's use scPRINT to increase the depth of the expression profile of the cells, virtually adding more signal to our dataset.
We will use the Denoiser class (see more about the class in our documentation), in a similar way to how Trainer is used in pytorch lightning, to denoise the expression profile of the cells.
We will then show the results of differential expression analysis before and after denoising.
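To see why such an unbalanced comparison is underpowered, consider a plain Welch t statistic: with a single reference cell, the sample variance of the reference group is undefined, so no meaningful statistic can be computed. The sketch below uses simulated values for one hypothetical gene; the group sizes mirror the "focus" annotation (776 versus 1), everything else is made up for illustration.

```python
import numpy as np

def welch_t(a, b):
    """Welch's t statistic for two 1-D samples (unequal variances)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    if len(a) < 2 or len(b) < 2:
        return float("nan")  # sample variance is undefined with a single observation
    var_a, var_b = a.var(ddof=1), b.var(ddof=1)
    return (a.mean() - b.mean()) / np.sqrt(var_a / len(a) + var_b / len(b))

rng = np.random.default_rng(0)
bph = rng.normal(loc=2.0, scale=1.0, size=776)         # simulated expression in BPH cells
normal_one = np.array([1.0])                           # a single reference cell
normal_many = rng.normal(loc=1.0, scale=1.0, size=50)  # what a larger reference would give

t_single = welch_t(bph, normal_one)
t_many = welch_t(bph, normal_many)
```

With one reference cell the statistic is undefined, while a reference group of 50 cells gives a finite and clearly positive value for the same effect size.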
# here we compare memory B-cell in BPH to normal memory B cells before denoising
sc.tl.rank_genes_groups(
prostate_adata,
groupby="focus",
groups=["BPH associated memory B cell"],
reference="memory B cell",
method="t-test",
)
# Plot the most differentially expressed genes
sc.pl.rank_genes_groups(prostate_adata, n_genes=25, sharey=False, gene_symbols="symbol")
# super strong B cell markers
WARNING: It seems you use rank_genes_groups on the raw count data. Please logarithmize your data before calling rank_genes_groups.
# perform denoising
denoise = Denoiser(
# could be on top differentially expressed genes or on random expressed genes in the cell (completed by random non expressed gene)
how="random expr",
# the size of the minibatch (need to fit in memory)
batch_size=20,
num_workers=8,
# the number of genes to use
max_len=5000,
# the number of cells to use (here more than what we will use so we will use everything)
max_cells=10_000,
doplot=False,
# how much do we want to increase the depth / counts of the cells (here, 7x)
predict_depth_mult=7,
dtype=dtype,
)
metrics, idx, denoised_adata = denoise(
model, prostate_adata[prostate_adata.obs["focus"] != "other"]
)
100%|██████████| 2/2 [00:58<00:00, 29.10s/it]
AnnData object with n_obs × n_vars = 32 × 44741
obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
obsm: 'scprint_emb'
layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
denoised_adata
AnnData object with n_obs × n_vars = 777 × 44756
obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
obsm: 'scprint_emb'
layers: 'scprint_mu', 'scprint_pi', 'scprint_theta'
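The denoised object carries three layers named scprint_mu, scprint_theta and scprint_pi. Assuming these hold the parameters of a zero-inflated negative binomial (ZINB) distribution, with mu the mean, theta the inverse dispersion and pi the zero-inflation probability (an assumption based on the layer names, not something confirmed here), the expected count and the probability of observing a zero can be derived as in this sketch.

```python
import numpy as np

def zinb_mean(mu, pi):
    """Expected count of a ZINB: the NB mean scaled by the non-dropout probability."""
    return (1.0 - pi) * mu

def zinb_prob_zero(mu, theta, pi):
    """Probability of a zero: structural zeros plus zeros from the NB itself."""
    nb_zero = (theta / (theta + mu)) ** theta
    return pi + (1.0 - pi) * nb_zero

# toy parameters for three genes in one cell
mu = np.array([0.5, 2.0, 10.0])
theta = np.array([1.0, 1.0, 5.0])
pi = np.array([0.0, 0.5, 0.1])

expected = zinb_mean(mu, pi)
p_zero = zinb_prob_zero(mu, theta, pi)
```

Under this parameterization, a higher mu at fixed pi and theta lowers the probability of a zero, which is one way to picture what "increasing the depth" of a profile means.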
# now we complete the expression profile with the denoised values (this is not done by default for now).
# two boolean masks cannot index a sparse matrix together, so we use explicit positions
# (this assumes denoised_adata keeps the cells in the same order as the "focus" subset)
import numpy as np
rows = np.where(prostate_adata.obs["focus"] != "other")[0]
shared = denoised_adata.var_names.intersection(prostate_adata.var_names)
cols = prostate_adata.var_names.get_indexer(shared)
X = prostate_adata.X.astype("float32").tolil()  # LIL makes block assignment fast
X[np.ix_(rows, cols)] = denoised_adata[:, shared].X.toarray()
prostate_adata.X = X.tocsr()
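Writing a block of values into a matrix with one row mask and one column mask is easy to get wrong: passing both masks in the same indexing expression pairs their positions element-wise instead of selecting a rectangular block. The toy numpy sketch below shows the failure and the np.ix_ alternative. It uses a dense array for simplicity; the variable names are illustrative only.

```python
import numpy as np

X = np.zeros((5, 6))
row_mask = np.array([True, False, True, False, False])         # 2 selected cells
col_mask = np.array([True, True, False, True, False, False])   # 3 selected genes
block = np.arange(6, dtype=float).reshape(2, 3) + 1             # values to write

try:
    X[row_mask, col_mask] = block  # positions are paired element-wise: 2 rows vs 3 columns
    paired_ok = True
except IndexError:
    paired_ok = False

rows, cols = np.where(row_mask)[0], np.where(col_mask)[0]
X[np.ix_(rows, cols)] = block      # outer product of positions: a 2 x 3 block
```

The first assignment fails because the two index arrays have different lengths, while np.ix_ builds an open mesh that addresses every (row, column) combination.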
# now we compare memory B-cell in BPH to normal memory B cells after denoising
sc.tl.rank_genes_groups(
prostate_adata,
groupby="focus",
groups=["BPH associated memory B cell"],
reference="memory B cell",
method="t-test",
)
# Plot the most differentially expressed genes
sc.pl.rank_genes_groups(prostate_adata, n_genes=25, sharey=False, gene_symbols="symbol")
# super strong B cell markers
prostate_adata.write_h5ad("../../data/temp/prostate_combined_denoised.h5ad")
prostate_adata = sc.read_h5ad("../../data/temp/prostate_combined_denoised.h5ad")
prostate_adata
AnnData object with n_obs × n_vars = 83451 × 70116
obs: 'Sample', 'Lineage', 'Population', 'resolution_0.1', 'resolution_0.2', 'resolution_0.3', 'resolution_0.4', 'resolution_0.5', 'resolution_0.75', 'resolution_1', 'resolution_2', 'resolution_3', 'resolution_4', 'resolution_5', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'percent.ribo', 'Stress1', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nnz', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier', 'pred_cell_type_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id', 'sprint_leiden', 'cleaned_pred_cell_type_ontology_term_id', 'cleaned_pred_disease_ontology_term_id'
var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'description', 'synonyms', 'organism_id', 'public_source_id', 'created_by_id', 'mt', 'ribo', 'hb', 'organism', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
uns: 'assay_colors', 'cleaned_pred_cell_type_ontology_term_id_colors', 'cleaned_pred_disease_ontology_term_id_colors', 'development_stage_colors', 'disease_colors', 'unseen_genes'
obsm: 'X_pca', 'X_umap', 'scprint', 'scprint_umap'
Gene network inference¶
Finally, we will use scPRINT to infer gene networks on another cell type of interest, the fibroblasts, in both normal and BPH conditions.
We will use the GNInfer class to infer gene networks (see cancer_usecase_part2.ipynb for more details on how to analyse the gene networks).
prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id.value_counts().head(20)
cleaned_pred_cell_type_ontology_term_id
basal epithelial cell of prostatic duct    25740
prostate gland microvascular endothelial cell    12165
urethra urothelial cell    10952
CD1c-positive myeloid dendritic cell    6898
other    5869
aortic smooth muscle cell    3970
effector CD8-positive, alpha-beta T cell    3925
pancreatic acinar cell    3711
luminal cell of prostate epithelium    3435
fibroblast of connective tissue of nonglandular part of prostate    3334
mucous neck cell    3260
IgG-negative class switched memory B cell    2932
basophil    1862
fibroblast of connective tissue of glandular part of prostate    1815
pancreatic ductal cell    1651
CD4-positive, alpha-beta thymocyte    1516
club cell    1226
smooth muscle cell of prostate    1000
effector memory CD8-positive, alpha-beta T cell    991
peptic cell    872
mature conventional dendritic cell    721
effector CD4-positive, alpha-beta T cell    686
mast cell    451
renal interstitial pericyte    440
retinal blood vessel endothelial cell    404
Name: count, dtype: int64
loc = (
prostate_adata.obs.cleaned_pred_cell_type_ontology_term_id
== "fibroblast of connective tissue of glandular part of prostate"
)
prostate_adata.obs[loc]["cleaned_pred_disease_ontology_term_id"].value_counts()
cleaned_pred_disease_ontology_term_id
benign prostatic hyperplasia    1482
normal    673
other    4
Name: count, dtype: int64
prostate_adata.obs[loc]["louvain_1.0"].value_counts().head(10)
louvain_1.0
5    1458
4    671
17    10
8    8
7    7
0    3
9    2
6    0
3    0
1    0
10    0
11    0
12    0
13    0
14    0
15    0
16    0
2    0
18    0
Name: count, dtype: int64
loc = loc & (prostate_adata.obs["louvain_1.0"] == str(5))
prostate_adata.obs[loc]["cleaned_pred_disease_ontology_term_id"].value_counts()
cleaned_pred_disease_ontology_term_id
benign prostatic hyperplasia    790
normal    664
other    4
Name: count, dtype: int64
prostate_adata.obs["fibro"] = None
prostate_adata.obs.loc[loc, "fibro"] = "fibroblasts"
prostate_adata.obs.loc[
loc
& (
prostate_adata.obs.cleaned_pred_disease_ontology_term_id
== "benign prostatic hyperplasia"
),
"fibro",
] = "BPH associated fibroblasts"
sc.pl.embedding(
prostate_adata[
prostate_adata.obs["louvain_1.0"].isin(["5", "4"])
& (prostate_adata.obsm["scprint_umap"][:, 1] > 10)
],
basis="scprint_umap_rot",
color="fibro",
show=False,
size=8,
title="Fibroblasts cluster",
legend_loc="right margin",
)
... storing 'focus' as categorical ... storing 'fibro' as categorical
grn_inferer = GNInfer(
# here we use the most variable genes across the fibroblasts vs the rest to find the num_genes genes we want to use to compute our gene network
how="most var across",
# how="random expr",
# we will preprocess the attention matrix with softmax
# this means that instead of being directly the QK^T matrix of the attention operation we apply softmax to the rows (see https://medium.com/@mayanksultania/transformers-101-tokens-attention-and-beyond-b080a900ca6c)
preprocess="softmax",
# we don't aggregate the different gene networks coming from the different heads over the different layers here, we will do it manually afterward
head_agg="none",
# we will not use any filtration here ("none"), but you can use "topk" for the top k connections per gene, "thresh" for a defined threshold, and more
filtration="none",
# whether we generate the attention matrices by performing a task, like denoising, or by just passing the expression profile through the model
forward_mode="none",
# the number of genes to use (here the 4000 most variable genes)
num_genes=4000,
# the max number of cells to use per cell type
max_cells=300,
doplot=False,
batch_size=16,
# the column in the anndata that defines the cell type
cell_type_col="fibro",
dtype=dtype,
# list of transformer layers to extract the Gene Network from
layer=list(range(model.nlayers))[:],
)
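The preprocess="softmax" and "topk" filtration options mentioned in the comments above can be pictured with a small numpy sketch. It builds a toy query/key product, applies a row-wise softmax, and keeps the k strongest connections per gene. The function names, the toy sizes and the 1/sqrt(d) scaling are assumptions for illustration; this is not scPRINT's internal implementation.

```python
import numpy as np

def row_softmax(scores):
    """Softmax over the last axis, so that each row sums to 1."""
    scores = np.asarray(scores, dtype=float)
    shifted = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def topk_per_row(mat, k):
    """Keep the k largest entries of each row and set the others to zero."""
    out = np.zeros_like(mat)
    idx = np.argsort(mat, axis=-1)[:, -k:]
    np.put_along_axis(out, idx, np.take_along_axis(mat, idx, axis=-1), axis=-1)
    return out

rng = np.random.default_rng(0)
Q = rng.normal(size=(5, 8))   # toy queries: 5 genes, embedding size 8
K = rng.normal(size=(5, 8))   # toy keys
raw = Q @ K.T / np.sqrt(Q.shape[1])
attn = row_softmax(raw)       # gene x gene weights, each row sums to 1
sparse_attn = topk_per_row(attn, k=2)
```

The softmax turns raw scores into comparable per-gene weights, and the top-k step turns the fully connected matrix into a sparser network.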
# make sure the "fibro" labels are strings, then check which original disease labels the BPH associated fibroblasts carry
prostate_adata.obs["fibro"] = prostate_adata.obs["fibro"].astype(str)
prostate_adata.obs[
prostate_adata.obs["fibro"] == "BPH associated fibroblasts"
].disease.value_counts()
output¶
The output of the function is a new AnnData in the form of a GRNAnnData object (see cantinilab/grnndata).
Its .varp['GRN'] attribute contains the full gene network, in this case a fully connected, directed, weighted graph between all 4000 selected genes.
Because we set head_agg to "none", it is actually not of size 4000x4000 but 4000x4000x(n_layers*n_heads). By taking the mean over this last dimension, we aggregate the gene networks coming from the different heads and layers into a single 4000x4000 matrix.
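The aggregation over heads described above amounts to a mean over the last axis of a three-dimensional array. The toy sketch below uses small, made-up sizes instead of the 4000 genes of the notebook.

```python
import numpy as np

n_genes, n_layers, n_heads = 4, 2, 3  # toy sizes for illustration
rng = np.random.default_rng(0)
# one gene x gene matrix per (layer, head), stacked on the last axis
per_head = rng.random((n_genes, n_genes, n_layers * n_heads))

grn_mean = per_head.mean(axis=-1)  # a single gene x gene network averaged over all heads
```

Keeping the un-aggregated array around makes it possible to try other aggregations later, such as a maximum or a subset of heads.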
# compute GRNs on fibroblasts, we use all the attention layers
grn = grn_inferer(model, prostate_adata, cell_type="fibroblasts")
# keep a copy of the un-aggregated (per head) networks
grn.varp["all"] = grn.varp["GRN"].copy()
# now we aggregate the heads by taking their average
grn.varp["GRN"] = grn.varp["GRN"].mean(-1)
grn.write_h5ad("../../data/temp/prostate_fibro_grn_all.h5ad")
# same on the BPH associated fibroblasts
# I wanted to use only the ones that the labellers had defined as coming from BPH
prostate_adata.obs.loc[
(prostate_adata.obs.fibro == "BPH associated fibroblasts")
& (prostate_adata.obs.disease == "benign prostatic hyperplasia"),
"fibro",
] = "true BPH associated fibroblasts"
grn_c = grn_inferer(
model, prostate_adata, cell_type="true BPH associated fibroblasts"
)
# keep a copy of the un-aggregated (per head) networks
grn_c.varp["all"] = grn_c.varp["GRN"].copy()
grn_c.varp["GRN"] = grn_c.varp["GRN"].mean(-1)
grn_c.write_h5ad("../../data/temp/prostate_BPH_fibro_grn_all.h5ad")