Documentation for the model

model description

scprint2.model.model

Classes:

Name Description
scPRINT2

scPRINT2

Bases: LightningModule, PyTorchModelHubMixin

scPRINT-2: Single-Cell Pretrained Regulatory Inference Network Transformer.

A foundation model for single-cell biology that learns cell and gene representations through self-supervised learning on large-scale single-cell RNA-seq data. The model can be used for:
  • Cell type classification and annotation
  • Gene expression denoising and imputation
  • Cell embedding generation for downstream analysis
  • Gene regulatory network inference via attention patterns
  • Multi-species gene expression modeling

Architecture Overview
  1. Gene Encoder: Embeds gene identities (optionally with pretrained embeddings)
  2. Expression Encoder: Encodes expression values (continuous, binned, or metacell)
  3. Position Encoder: Optional genomic position encoding
  4. Transformer: Main attention-based encoder (various attention mechanisms)
  5. Cell Transformer: Optional separate transformer for cell embeddings
  6. Decoders: Expression reconstruction, classification, and MVC decoders
The model supports multiple training objectives:
  • Masked expression prediction (like BERT)
  • Denoising autoencoding
  • Cell embedding contrastive learning (ECS and CCE losses)
  • Multi-class cell type classification with hierarchical labels
  • Multi-view coding (MVC) for robust representations
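
The masked expression objective listed above can be illustrated with a small, library-free sketch. This is not the scPRINT-2 implementation: the function name `mask_expression`, the sentinel `mask_value`, and the uniform random choice of positions are all assumptions made for illustration.

```python
import random


def mask_expression(values, mask_ratio, mask_value=-1.0, seed=0):
    """Hide a fraction of expression values, BERT-style (illustrative only).

    Returns the masked copy and the sorted list of masked positions, which
    are the positions a model would be asked to reconstruct.
    """
    rng = random.Random(seed)
    n_masked = round(len(values) * mask_ratio)
    masked_idx = set(rng.sample(range(len(values)), n_masked))
    masked = [mask_value if i in masked_idx else v for i, v in enumerate(values)]
    return masked, sorted(masked_idx)


# Toy expression profile for one cell (made-up numbers).
expression = [0.0, 3.0, 1.0, 0.0, 7.0, 2.0, 0.0, 5.0, 1.0, 4.0]
masked, positions = mask_expression(expression, mask_ratio=0.3)
```

With `mask_ratio=0.3` and ten genes, three positions are replaced by the sentinel and the remaining values are left untouched.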
Parameters:
  • genes (list | dict) –

    Gene vocabulary. Either a list of gene names or a dict mapping organism names to lists of genes for multi-species models.

  • organisms (list[str]) –

    List of organism ontology term IDs the model supports.

  • d_model (int, default: 256 ) –

    Hidden dimension of the transformer. Defaults to 256.

  • nhead (int, default: 4 ) –

    Number of attention heads. Defaults to 4.

  • nlayers (int, default: 8 ) –

    Number of transformer layers. Defaults to 8.

  • precpt_gene_emb (str, default: None ) –

    Path to parquet file with pretrained gene embeddings. Index should match gene names. Defaults to None.

  • memmap_gene_emb (bool, default: False ) –

    Memory-map gene embeddings for large files. Defaults to False.

  • finetune_gene_emb (bool, default: False ) –

    Add trainable adapter layers on top of frozen pretrained embeddings. Defaults to False.

  • freeze_embeddings (bool, default: True ) –

    Freeze gene embeddings during training. Defaults to True.

  • gene_pos_file (str, default: None ) –

    Path to parquet file with genomic positions. Must have 'pos' column with integer positions. Defaults to None.

  • normalization (str, default: 'sum' ) –

    Expression normalization method. One of:
      - "sum": Divide by total counts (TPM-like)
      - "log": Log2(1 + x) transform
      - "both": Sum normalization then log transform
      - "raw": No normalization
    Defaults to "sum".

  • attn_bias (str, default: None ) –

    Path to sparse matrix (.npz) with attention biases (e.g., gene-gene regulatory priors). Defaults to None.

  • expr_encoder_layers (int, default: 3 ) –

    Number of layers in expression encoder MLP. Defaults to 3.

  • attention (str, default: 'normal' ) –

    Attention mechanism type. One of:
      - "normal": Standard PyTorch attention
      - "legacy-flash": Flash attention via simpler-flash
      - "performer": Performer linear attention
      - "hyper": Compressed hyperbolic attention
      - "criss-cross": Criss-cross attention
    Defaults to "normal".

  • expr_emb_style (str, default: 'continuous' ) –

    Expression embedding approach. One of:
      - "continuous": MLP on continuous expression values
      - "binned": Learned embeddings for discretized expression bins
      - "metacell": DeepSet encoder aggregating KNN neighbors
    Defaults to "continuous".

  • n_input_bins (int, default: 0 ) –

    Number of expression bins when using binned embedding. Required if expr_emb_style="binned". Defaults to 0.

  • mvc_decoder (str, default: None ) –

    Multi-view coding decoder architecture. One of:
      - None: No MVC decoder
      - "inner product": Dot product between cell and gene embeddings
      - "concat query": Concatenate cell embedding with gene queries
      - "sum query": Add cell embedding to gene queries
    Defaults to None.

  • pred_embedding (list[str], default: None ) –

    Class names to use for cell embeddings during prediction/logging. Defaults to None (use all).

  • layers_cls (list[int], default: [256, 128] ) –

    Hidden layer sizes for classification heads. Defaults to [256, 128].

  • classes (dict[str, int], default: None ) –

    Classification targets mapping class names to number of categories. E.g., {"cell_type_ontology_term_id": 100}. Defaults to None.

  • labels_hierarchy (dict[str, dict[int, list[int]]], default: {} ) –

    Hierarchical label structure for ontology-based classes. Maps parent indices to lists of children indices. Defaults to {}.

  • label_decoders (dict[str, dict[int, str]], default: None ) –

    Mapping from encoded integers back to label strings for each class. Used for logging/plotting. Defaults to None.

  • compress_class_dim (dict[str, int], default: None ) –

    Compressed embedding dimension for each class. Uses VAE or FSQ compression. Defaults to None.

  • cell_specific_blocks (bool, default: False ) –

    Use separate transformer for cell embeddings with cross-attention to gene transformer. Defaults to False.

  • zinb (bool, default: True ) –

    Use Zero-Inflated Negative Binomial distribution for expression reconstruction. If False, uses MSE loss. Defaults to True.

  • splicing_head (bool, default: False ) –

    Add separate decoder for spliced/unspliced expression. Defaults to False.

  • do_adv_cls (bool, default: False ) –

    Use adversarial classification to remove batch effects from cell type embeddings. Defaults to False.

  • dropout (float, default: 0.1 ) –

    Dropout rate throughout the model. Defaults to 0.1.

  • use_metacell_token (bool, default: False ) –

    Add learnable metacell token to distinguish single cells from metacells. Defaults to False.

  • lr (float, default: 0.0001 ) –

    Base learning rate. Defaults to 0.0001.

  • nb_features (int, default: None ) –

    Number of random features for Performer attention. Defaults to None.

  • sketcher_size (int, default: 200 ) –

    Sketch size for sparse attention methods. Defaults to 200.

  • feature_redraw_interval (int, default: None ) –

    Steps between random feature redraws for Performer. Defaults to None.

  • num_heads_kv (int, default: 4 ) –

    Number of key-value heads (for MQA/GQA). Defaults to 4.

  • d_model_cell (int, default: 128 ) –

    Hidden dim for cell transformer when using cell_specific_blocks. Defaults to 128.

  • nhead_cell (int, default: 4 ) –

    Attention heads for cell transformer. Defaults to 4.

  • nlayers_cell (int, default: 6 ) –

    Layers in cell transformer. Defaults to 6.

  • num_heads_kv_cell (int, default: 4 ) –

    KV heads for cell transformer. Defaults to 4.

  • drop_path_rate (float, default: 0.0 ) –

    Stochastic depth rate. Defaults to 0.0.

  • **attention_kwargs (dict, default: {} ) –

    Additional arguments passed to FlashTransformer.
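
The four `normalization` options described in the parameter list can be sketched on a single cell's count vector as follows. This is a minimal illustration under stated assumptions, not the library's code: the helper name `normalize` is hypothetical, "sum" is taken literally as division by the cell's total counts (whether the model rescales to a fixed target is not stated here), and "both" applies the log transform to the sum-normalized values.

```python
import math


def normalize(counts, method="sum"):
    """Apply one of the documented normalization modes to one cell (sketch)."""
    total = sum(counts)
    if method == "raw":
        # No normalization: return the counts unchanged.
        return list(counts)
    if method == "sum":
        # Divide by total counts; an all-zero cell stays all-zero.
        return [c / total if total else 0.0 for c in counts]
    if method == "log":
        # log2(1 + x) transform.
        return [math.log2(1 + c) for c in counts]
    if method == "both":
        # Sum normalization followed by the log transform.
        return [math.log2(1 + (c / total if total else 0.0)) for c in counts]
    raise ValueError(f"unknown normalization: {method!r}")


fractions = normalize([1, 3, 4], "sum")
logged = normalize([1, 3, 7], "log")
```

Note that, per the source listing further down, `expr_emb_style="binned"` asserts `normalization == "raw"`.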

Attributes:
  • Training Configuration (set these before training) –

      - noise (list[float]): Dropout rates for denoising task. E.g., [0.6].
      - mask_ratio (list[float]): Mask ratios for masked prediction. E.g., [0.15].
      - cce_temp (float): Temperature for contrastive loss.
      - cce_scale (float): Weight for contrastive cell embedding loss.
      - ecs_scale (float): Weight for elastic cell similarity loss.
      - ecs_threshold (float): Similarity threshold for ECS loss.
      - mvc_scale (float): Weight for MVC reconstruction loss.
      - class_scale (float): Weight for classification loss.
      - lr_reduce_patience (int): Epochs before reducing learning rate.
      - lr_reduce_factor (float): Factor to reduce learning rate by.
      - warmup_duration (int): Steps for learning rate warmup.

  • Prediction Configuration (set before predict) –

      - predict_mode (str): "none" or "generate" for expression generation.
      - pred_embedding (list[str]): Classes to include in cell embeddings.
      - get_attention_layer (list[int]): Layers to extract attention from.
      - predict_depth_mult (float): Multiplier for depth in generation.
      - pred_log_adata (bool): Whether to log predictions as AnnData.
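
The learning-rate attributes above (`warmup_duration`, `lr_reduce_patience`, `lr_reduce_factor`) suggest a warmup phase followed by a reduce-on-plateau rule. The sketch below shows one plausible reading; the actual schedule lives in `configure_optimizers`, which is not shown here, so the linear warmup shape and both function names are assumptions. The default values used (`1e-4`, `500`, `2`, `0.6`) are the ones set in `__init__`.

```python
def warmup_lr(step, base_lr=1e-4, warmup_duration=500):
    """Linearly ramp the learning rate from 0 to base_lr (assumed shape)."""
    if warmup_duration <= 0:
        return base_lr
    return base_lr * min(1.0, step / warmup_duration)


def reduce_on_plateau(val_losses, lr, patience=2, factor=0.6):
    """Simplified reduce-on-plateau rule, evaluated once per epoch.

    The rate is multiplied by `factor` once the validation loss has failed
    to improve for more than `patience` consecutive epochs.
    """
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs > patience:
                lr *= factor
                bad_epochs = 0
        history.append(lr)
    return history


# Made-up validation losses: improvement stalls for three epochs.
history = reduce_on_plateau([1.0, 0.9, 0.95, 0.95, 0.95, 0.8], lr=1e-4)
```

In this toy run the rate stays at `1e-4` for the first four epochs and drops to `6e-5` on the fifth, after the third consecutive epoch without improvement.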

Example

    >>> # Initialize model
    >>> model = scPRINT2(
    ...     genes=gene_list,
    ...     organisms=["NCBITaxon:9606"],
    ...     d_model=512,
    ...     nlayers=12,
    ...     classes={"cell_type_ontology_term_id": 100},
    ... )
    >>>
    >>> # Configure training
    >>> model.noise = [0.4, 0.6]
    >>> model.mask_ratio = [0.15, 0.3]
    >>>
    >>> # Train with PyTorch Lightning
    >>> trainer = L.Trainer(max_epochs=100)
    >>> trainer.fit(model, datamodule)
    >>>
    >>> # Generate embeddings
    >>> model.pred_embedding = ["cell_type_ontology_term_id"]
    >>> predictions = trainer.predict(model, datamodule)

Note

The model is designed to work with scDataLoader's DataModule and Collator. Gene order must match between model initialization and data loading.
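
Because gene order must match between the model and the data pipeline, a quick sanity check before training can catch silent misalignment. The helper below is a hypothetical sketch, not part of scPRINT-2 or scDataLoader: it compares two gene lists position by position. On the model side the list would come from the documented `genes` attribute; how the data loader exposes its gene list is not specified here.

```python
def first_order_mismatch(model_genes, loader_genes):
    """Return the first index where the two gene lists differ, else None.

    If one list is a strict prefix of the other, the index of the first
    missing position (the shorter length) is returned.
    """
    for i, (m, d) in enumerate(zip(model_genes, loader_genes)):
        if m != d:
            return i
    if len(model_genes) != len(loader_genes):
        return min(len(model_genes), len(loader_genes))
    return None


# Placeholder gene identifiers, for illustration only.
same = first_order_mismatch(["A", "B", "C"], ["A", "B", "C"])
swapped = first_order_mismatch(["A", "B", "C"], ["A", "C", "B"])
```

A return value of `None` means the lists agree in both content and order; any integer points at the first position to inspect.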

Methods:

Name Description
add_organism

Add a new organism to an existing model for transfer learning.

configure_optimizers

@see pl.LightningModule

forward

Complete forward pass through the scPRINT-2 model.

log_adata

Log an AnnData object built from predictions.

on_fit_start

@see pl.LightningModule

on_load_checkpoint

Handle checkpoint loading with backward compatibility.

on_predict_epoch_end

@see pl.LightningModule

on_predict_epoch_start

@see pl.LightningModule

on_test_start

@see pl.LightningModule

on_validation_epoch_end

@see pl.LightningModule

optimizer_step

@see pl.LightningModule

predict_step

Embed the given gene expression; encode the gene embedding and cell embedding.

training_step

training_step defines the train loop. It is independent of forward.

validation_step

validation_step defines the validation loop. It is independent of forward.

Attributes:
  • genes (list[str]) –

    Get flattened list of all genes in the model's vocabulary.
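
Since `genes` can be given either as a list or as a dict mapping organisms to gene lists, the "flattened" view can be pictured as below. This is a sketch of the idea only: the function name `flatten_genes` is hypothetical, and concatenating per-organism lists in dict insertion order is an assumption about the property's behavior.

```python
def flatten_genes(genes):
    """Flatten a gene vocabulary into one list (illustrative sketch).

    A dict of {organism: [genes]} is concatenated in insertion order;
    a plain list is returned as a copy.
    """
    if isinstance(genes, dict):
        return [g for organism_genes in genes.values() for g in organism_genes]
    return list(genes)


# Placeholder gene identifiers keyed by organism ontology term ID.
vocab = {
    "NCBITaxon:9606": ["ENSG_A", "ENSG_B"],
    "NCBITaxon:10090": ["ENSMUSG_A"],
}
flat = flatten_genes(vocab)
```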

Source code in scprint2/model/model.py
def __init__(
    self,
    genes,
    organisms: list[str],
    d_model: int = 256,
    nhead: int = 4,
    nlayers: int = 8,
    precpt_gene_emb: Optional[str] = None,
    memmap_gene_emb: bool = False,
    finetune_gene_emb: bool = False,
    freeze_embeddings: bool = True,
    gene_pos_file: Optional[str] = None,
    normalization: str = "sum",  # log, sum, both, raw
    attn_bias: Optional[str] = None,
    expr_encoder_layers: int = 3,
    attention: str = "normal",  # "performer", "legacy-flash", "normal", "criss-cross", "hyper", "adasplash", "softpick", "softpick-flash"
    expr_emb_style: str = "continuous",  # "binned", "continuous", "metacell"
    n_input_bins: int = 0,
    mvc_decoder: Optional[
        str
    ] = None,  # "inner product", "concat query", "sum query"
    pred_embedding: Optional[list[str]] = None,
    layers_cls: list[int] = [256, 128],
    classes: Optional[Dict[str, int]] = None,
    labels_hierarchy: Dict[str, Dict[int, list[int]]] = {},
    label_decoders: Optional[Dict[str, Dict[int, str]]] = None,
    compress_class_dim: Optional[Dict[str, int]] = None,
    cell_specific_blocks: bool = False,
    zinb: bool = True,
    splicing_head: bool = False,
    do_adv_cls: bool = False,
    dropout: float = 0.1,
    use_metacell_token: bool = False,
    lr: float = 0.0001,
    nb_features: Optional[int] = None,
    sketcher_size: int = 200,
    feature_redraw_interval: Optional[int] = None,
    num_heads_kv: int = 4,
    d_model_cell: int = 128,
    nhead_cell: int = 4,
    nlayers_cell: int = 6,
    num_heads_kv_cell: int = 4,
    transformer=None,
    drop_path_rate: float = 0.0,
    # unused args from older versions kept for loading old models
    gene_pos_enc=None,
    max_cont_len=None,
    residual_in_fp32=None,
    checkpointing=None,
    fused_dropout_add_ln=None,
    strict_loading=None,
    optim=None,
    weight_decay=None,
    prenorm=None,
    domain_spec_batchnorm=None,
    use_flash_attn=None,
    cell_emb_style=None,
    num_batch_labels=None,
    fused_mlp=None,
    fused_bias_fc=None,
    **attention_kwargs: dict,
):
    """
    scPRINT-2: Single-Cell Pretrained Regulatory Inference Network Transformer.

    A foundation model for single-cell biology that learns cell and gene representations
    through self-supervised learning on large-scale single-cell RNA-seq data. The model
    can be used for:
    - Cell type classification and annotation
    - Gene expression denoising and imputation
    - Cell embedding generation for downstream analysis
    - Gene regulatory network inference via attention patterns
    - Multi-species gene expression modeling

    Architecture Overview:
        1. Gene Encoder: Embeds gene identities (optionally with pretrained embeddings)
        2. Expression Encoder: Encodes expression values (continuous, binned, or metacell)
        3. Position Encoder: Optional genomic position encoding
        4. Transformer: Main attention-based encoder (various attention mechanisms)
        5. Cell Transformer: Optional separate transformer for cell embeddings
        6. Decoders: Expression reconstruction, classification, and MVC decoders

    The model supports multiple training objectives:
        - Masked expression prediction (like BERT)
        - Denoising autoencoding
        - Cell embedding contrastive learning (ECS and CCE losses)
        - Multi-class cell type classification with hierarchical labels
        - Multi-view coding (MVC) for robust representations

    Args:
        genes (list | dict): Gene vocabulary. Either a list of gene names or a dict
            mapping organism names to lists of genes for multi-species models.
        organisms (list[str]): List of organism ontology term IDs the model supports.
        d_model (int, optional): Hidden dimension of the transformer. Defaults to 256.
        nhead (int, optional): Number of attention heads. Defaults to 4.
        nlayers (int, optional): Number of transformer layers. Defaults to 8.
        precpt_gene_emb (str, optional): Path to parquet file with pretrained gene
            embeddings. Index should match gene names. Defaults to None.
        memmap_gene_emb (bool, optional): Memory-map gene embeddings for large files.
            Defaults to False.
        finetune_gene_emb (bool, optional): Add trainable adapter layers on top of
            frozen pretrained embeddings. Defaults to False.
        freeze_embeddings (bool, optional): Freeze gene embeddings during training.
            Defaults to True.
        gene_pos_file (str, optional): Path to parquet file with genomic positions.
            Must have 'pos' column with integer positions. Defaults to None.
        normalization (str, optional): Expression normalization method. One of:
            - "sum": Divide by total counts (TPM-like)
            - "log": Log2(1 + x) transform
            - "both": Sum normalization then log transform
            - "raw": No normalization
            Defaults to "sum".
        attn_bias (str, optional): Path to sparse matrix (.npz) with attention biases
            (e.g., gene-gene regulatory priors). Defaults to None.
        expr_encoder_layers (int, optional): Number of layers in expression encoder MLP.
            Defaults to 3.
        attention (str, optional): Attention mechanism type. One of:
            - "normal": Standard PyTorch attention
            - "legacy-flash": Flash attention via simpler-flash
            - "performer": Performer linear attention
            - "hyper": Compressed hyperbolic attention
            - "criss-cross": Criss-cross attention
            Defaults to "normal".
        expr_emb_style (str, optional): Expression embedding approach. One of:
            - "continuous": MLP on continuous expression values
            - "binned": Learned embeddings for discretized expression bins
            - "metacell": DeepSet encoder aggregating KNN neighbors
            Defaults to "continuous".
        n_input_bins (int, optional): Number of expression bins when using binned
            embedding. Required if expr_emb_style="binned". Defaults to 0.
        mvc_decoder (str, optional): Multi-view coding decoder architecture. One of:
            - None: No MVC decoder
            - "inner product": Dot product between cell and gene embeddings
            - "concat query": Concatenate cell embedding with gene queries
            - "sum query": Add cell embedding to gene queries
            Defaults to None.
        pred_embedding (list[str], optional): Class names to use for cell embeddings
            during prediction/logging. Defaults to None (use all).
        layers_cls (list[int], optional): Hidden layer sizes for classification heads.
            Defaults to [256, 128].
        classes (dict[str, int], optional): Classification targets mapping class names
            to number of categories. E.g., {"cell_type_ontology_term_id": 100}.
            Defaults to None.
        labels_hierarchy (dict[str, dict[int, list[int]]], optional): Hierarchical
            label structure for ontology-based classes. Maps parent indices to lists
            of children indices. Defaults to {}.
        label_decoders (dict[str, dict[int, str]], optional): Mapping from encoded
            integers back to label strings for each class. Used for logging/plotting.
            Defaults to None.
        compress_class_dim (dict[str, int], optional): Compressed embedding dimension
            for each class. Uses VAE or FSQ compression. Defaults to None.
        cell_specific_blocks (bool, optional): Use separate transformer for cell
            embeddings with cross-attention to gene transformer. Defaults to False.
        zinb (bool, optional): Use Zero-Inflated Negative Binomial distribution for
            expression reconstruction. If False, uses MSE loss. Defaults to True.
        splicing_head (bool, optional): Add separate decoder for spliced/unspliced
            expression. Defaults to False.
        do_adv_cls (bool, optional): Use adversarial classification to remove batch
            effects from cell type embeddings. Defaults to False.
        dropout (float, optional): Dropout rate throughout the model. Defaults to 0.1.
        use_metacell_token (bool, optional): Add learnable metacell token to distinguish
            single cells from metacells. Defaults to False.
        lr (float, optional): Base learning rate. Defaults to 0.0001.
        nb_features (int, optional): Number of random features for Performer attention.
            Defaults to None.
        sketcher_size (int, optional): Sketch size for sparse attention methods.
            Defaults to 200.
        feature_redraw_interval (int, optional): Steps between random feature redraws
            for Performer. Defaults to None.
        num_heads_kv (int, optional): Number of key-value heads (for MQA/GQA).
            Defaults to 4.
        d_model_cell (int, optional): Hidden dim for cell transformer when using
            cell_specific_blocks. Defaults to 128.
        nhead_cell (int, optional): Attention heads for cell transformer. Defaults to 4.
        nlayers_cell (int, optional): Layers in cell transformer. Defaults to 6.
        num_heads_kv_cell (int, optional): KV heads for cell transformer. Defaults to 4.
        drop_path_rate (float, optional): Stochastic depth rate. Defaults to 0.0.
        **attention_kwargs (dict): Additional arguments passed to FlashTransformer.

    Attributes:
        Training Configuration (set these before training):
            noise (list[float]): Dropout rates for denoising task. E.g., [0.6].
            mask_ratio (list[float]): Mask ratios for masked prediction. E.g., [0.15].
            cce_temp (float): Temperature for contrastive loss.
            cce_scale (float): Weight for contrastive cell embedding loss.
            ecs_scale (float): Weight for elastic cell similarity loss.
            ecs_threshold (float): Similarity threshold for ECS loss.
            mvc_scale (float): Weight for MVC reconstruction loss.
            class_scale (float): Weight for classification loss.
            lr_reduce_patience (int): Epochs before reducing learning rate.
            lr_reduce_factor (float): Factor to reduce learning rate by.
            warmup_duration (int): Steps for learning rate warmup.

        Prediction Configuration (set before predict):
            predict_mode (str): "none" or "generate" for expression generation.
            pred_embedding (list[str]): Classes to include in cell embeddings.
            get_attention_layer (list[int]): Layers to extract attention from.
            predict_depth_mult (float): Multiplier for depth in generation.
            pred_log_adata (bool): Whether to log predictions as AnnData.

    Example:
        >>> # Initialize model
        >>> model = scPRINT2(
        ...     genes=gene_list,
        ...     organisms=["NCBITaxon:9606"],
        ...     d_model=512,
        ...     nlayers=12,
        ...     classes={"cell_type_ontology_term_id": 100},
        ... )
        >>>
        >>> # Configure training
        >>> model.noise = [0.4, 0.6]
        >>> model.mask_ratio = [0.15, 0.3]
        >>>
        >>> # Train with PyTorch Lightning
        >>> trainer = L.Trainer(max_epochs=100)
        >>> trainer.fit(model, datamodule)
        >>>
        >>> # Generate embeddings
        >>> model.pred_embedding = ["cell_type_ontology_term_id"]
        >>> predictions = trainer.predict(model, datamodule)

    Note:
        The model is designed to work with scDataLoader's DataModule and Collator.
        Gene order must match between model initialization and data loading.
    """
    super().__init__()
    self.save_hyperparameters()
    # training flags
    self.noise = [0.6]
    self.cce_temp = 0.3
    self.lr = lr
    self.cce_scale = 0.2
    self.ecs_threshold = 0.4
    self.ecs_scale = 0.2
    self.mvc_scale = 1.0
    self.class_embd_diss_scale = 0.3
    self.adv_class_scale = 1.0
    self.do_adv_cls = do_adv_cls
    self.run_full_forward = True
    self.class_scale = 1
    self.zinb_and_mse = False
    self.do_next_tp = False
    self.do_generate = False
    self.var_context_length = False
    self.mask_ratio = []
    self.warmup_duration = 500
    self.weight_decay = 0.01
    self.optim = "adamW"
    self.fused_adam = False
    self.lr_reduce_patience = 2
    self.lr_reduce_factor = 0.6
    self.test_every = 20
    self.randsamp = True
    self.lr_reduce_monitor = "val_loss"
    self.name = ""
    self.set_step = None
    self.lrfinder_steps = 0
    self.doplot = False
    self.get_attention_layer = None
    self.embs = None
    self.pred_log_adata = True
    self.predict_depth_mult = 3
    self.predict_mode = "none"
    self.keep_all_labels_pred = False
    self.mask_zeros = False
    self.vae_kl_scale = 0.05
    self.vae_kl_warmup_steps = 40_000  # Default value, can be adjusted
    self.save_expr = False
    self.counter = 0

    # should be stored somehow
    self.d_model = d_model
    self.normalization = normalization
    self.attn_bias = attn_bias if attn_bias != "none" else None
    self.organisms = organisms
    self.nlayers = nlayers
    self.use_metacell_token = use_metacell_token
    self.mvc_decoder = mvc_decoder
    # need to store
    self.n_input_bins = n_input_bins
    self.attention = attention

    if classes is None:
        classes = {}
    self.label_counts = classes
    self.classes = list(classes.keys())

    self.label_decoders = label_decoders
    self.pred_embedding = pred_embedding
    self._genes = genes
    self.expr_emb_style = expr_emb_style
    if labels_hierarchy is None:
        labels_hierarchy = {}
    self.labels_hierarchy = labels_hierarchy
    self.hparams["classes"] = classes
    self.hparams["label_decoders"] = label_decoders
    self.hparams["organisms"] = organisms
    self.hparams["use_metacell_token"] = use_metacell_token
    # 20x more likely to drop a non TF compared to a TF
    self.tf_masker = WeightedMasker(self.genes, tf_weight=0.05)
    self.attn = utils.Attention(
        len(self.genes),
        additional_tokens=(
            (1 if self.use_metacell_token else 0)
            + ((len(classes) + 1) if not cell_specific_blocks else 0)
        ),
    )

    self.mat_labels_hierarchy = {}
    for k, v in labels_hierarchy.items():
        tens = torch.zeros((len(v), classes[k]))
        for k2, v2 in v.items():
            tens[k2 - classes[k], v2] = 1
        self.mat_labels_hierarchy[k] = tens.to(bool)

    # encoder
    # gene encoder
    if gene_pos_file is not None:
        gene_pos_enc = pd.read_parquet(gene_pos_file)
        if len(gene_pos_enc) < len(self.genes):
            print("Warning: only a subset of the genes available in the loc file.")
        for k, v in self._genes.items():
            tokeep = set(gene_pos_enc.index.tolist())
            self._genes[k] = [u for u in v if u in tokeep]
            if len(self._genes[k]) < 100:
                raise ValueError(
                    f"the gene pos file {gene_pos_file} does not match most of the genes given to the model for species {k}"
                )
        gene_pos_enc = gene_pos_enc.loc[self.genes, ["pos"]]

    if precpt_gene_emb is not None:
        embeddings = pd.read_parquet(precpt_gene_emb)
        if len(embeddings) < len(self.genes):
            print(
                "Warning: only a subset of the genes available in the embeddings file."
            )
        for k, v in self._genes.items():
            tokeep = set(embeddings.index.tolist())
            self._genes[k] = [u for u in v if u in tokeep]
            if len(self._genes[k]) < 100:
                raise ValueError(
                    f"the gene embeddings file {precpt_gene_emb} does not match most of the genes given to the model for species {k}"
                )
        embeddings = embeddings.loc[self.genes]
        print("number of genes: ", len(embeddings))
        if not memmap_gene_emb:
            sembeddings = torch.nn.AdaptiveAvgPool1d(d_model)(
                torch.tensor(embeddings.values, dtype=torch.float32)
            )
        else:
            embeddings = None
        gene_encoder = encoders.GeneEncoder(
            len(self.genes),
            d_model,
            weights_file=precpt_gene_emb if memmap_gene_emb else None,
            weights=sembeddings if not memmap_gene_emb else None,
            freeze=freeze_embeddings,
        )
    else:
        gene_encoder = encoders.GeneEncoder(
            len(self.genes), d_model, freeze=freeze_embeddings
        )
    if finetune_gene_emb:
        if not freeze_embeddings:
            raise ValueError(
                "finetune_gene_emb is True but freeze_embeddings is False"
            )
        # Create adapter layers after the frozen base encoder
        self.gene_encoder = torch.nn.Sequential(
            gene_encoder,
            torch.nn.Linear(d_model, d_model),
            torch.nn.ReLU(),
            torch.nn.Linear(d_model, d_model),
        )
    else:
        self.gene_encoder = gene_encoder
    # Positional Encoding
    if gene_pos_file is not None:
        # redoing it just in case some were dropped in the embedding file step
        gene_pos_enc = gene_pos_enc.loc[self.genes, "pos"].astype(int).tolist()
        self.pos_encoder = encoders.PositionalEncoding(
            d_model, gene_pos_enc=gene_pos_enc
        )
    else:
        self.pos_encoder = None
    # Value Encoder, NOTE: the scaling style is also handled in _encode method
    expr_d_model = d_model  # // 8 if finetune_gene_emb else d_model
    if expr_emb_style == "continuous":
        expr_encoder = encoders.ContinuousValueEncoder(
            expr_d_model, dropout, layers=expr_encoder_layers
        )
    elif expr_emb_style == "binned":
        assert n_input_bins > 0
        assert normalization == "raw", "shouldn't use normalization"
        expr_encoder = encoders.CategoryValueEncoder(n_input_bins, expr_d_model)
    elif expr_emb_style == "metacell":
        expr_encoder = encoders.EasyExprGNN(
            self_dim=expr_d_model * 2,
            output_dim=expr_d_model,
            shared_layers=expr_encoder_layers,
            dropout=dropout,
        )
    else:
        raise ValueError(
            f"expr_emb_style should be one of binned, continuous, metacell, "
            f"got {expr_emb_style}"
        )
    if finetune_gene_emb and False:
        self.expr_encoder = encoders.ExprBasedFT(
            d_model,
            gene_encoder,
            expr_encoder,
            dropout,
            layers=expr_encoder_layers,
            intermediary_d=int(d_model * 1.5),
        )
    else:
        self.expr_encoder = expr_encoder

    # Class Encoder
    # always have [base_cell_emb, time_embedding, depth_embedding] + any other class info
    # base cell embedding will store other cell specific information
    self.class_encoder = encoders.CategoryValueEncoder(
        len(self.classes) + 1,
        d_model if not cell_specific_blocks else d_model_cell,
    )

    if self.use_metacell_token:
        self.metacell_encoder = encoders.CategoryValueEncoder(2, d_model)
    # compute tensor for mat_labels_hierarchy
    # old parameters that can still be passed when loading older models (managed in the _on_load_ckpt function)
    for i in [
        "strict_loading",
        "optim",
        "weight_decay",
        "d_hid",
        "edge_dim",
        "prenorm",
        "domain_spec_batchnorm",
        "use_flash_attn",
        "cell_emb_style",
        "num_batch_labels",
        "transformer",
        "residual_in_fp32",
        "max_cont_len",
    ]:
        if i in attention_kwargs:
            attention_kwargs.pop(i)
    # attention
    # Linear
    if attention == "linear":
        # linear attention using the fast attention package
        # self.attention = FastattentionEncoder(
        #    d_model, nhead, d_hid, nlayers, dropout, "linear"
        # )
        raise NotImplementedError("Linear attention is not implemented")
    elif attention == "performer":
        self.transformer = Performer(
            dim=d_model,
            depth=nlayers,
            heads=nhead,
            dim_head=d_model // nhead,
            causal=False,
            attn_dropout=dropout,
            ff_dropout=dropout,
            qkv_bias=True,
            nb_features=nb_features,
            feature_redraw_interval=feature_redraw_interval,
        )
    else:
        self.transformer = FlashTransformer(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            attn_dropout=dropout,
            nlayers=nlayers,
            cross_attn=cell_specific_blocks,
            cross_dim=d_model_cell,
            attn_type="flash" if attention == "legacy-flash" else attention,
            num_heads_kv=num_heads_kv,
            sketcher_size=sketcher_size,
            drop_path_rate=drop_path_rate,
            **attention_kwargs,
        )
    if cell_specific_blocks:
        attention_kwargs.pop("num_heads_kv", None)
        self.cell_transformer = FlashTransformer(
            d_model=d_model_cell,
            nhead=nhead_cell,
            num_heads_kv=num_heads_kv_cell,
            nlayers=nlayers_cell,
            dropout=dropout,
            cross_attn=True,
            cross_dim=d_model,
            attn_type="flash" if attention == "legacy-flash" else "normal",
            **attention_kwargs,
        )
    else:
        self.cell_transformer = None

    # decoders
    # expression
    self.splicing_head = None
    if expr_emb_style == "binned":
        self.expr_decoder = decoders.ClsDecoder(
            d_model,
            n_input_bins,
            layers=[d_model // 2, d_model // 4],
            dropout=dropout,
        )
    else:
        self.expr_decoder = decoders.ExprDecoder(
            d_model,
            dropout=dropout,
            zinb=zinb,
            use_depth=True,
        )
        if splicing_head:
            self.splicing_head = decoders.ExprDecoder(
                d_model,
                dropout=dropout,
                zinb=zinb,
                use_depth=True,
            )
    # cls decoder
    self.cls_decoders = torch.nn.ModuleDict()
    # should be a very simple classifier for most things
    # (maybe scale with the number of classes) should be 1 layer...
    for clss, n_cls in classes.items():
        mdim = d_model_cell if cell_specific_blocks else self.d_model
        dim = compress_class_dim[clss] if compress_class_dim is not None else mdim
        self.cls_decoders[clss] = decoders.ClsDecoder(
            dim,
            n_cls,
            layers=layers_cls,
            dropout=dropout,
        )
    if "cell_type_ontology_term_id" in classes and self.do_adv_cls:
        mdim = d_model_cell if cell_specific_blocks else self.d_model
        dim = (
            compress_class_dim["cell_type_ontology_term_id"]
            if compress_class_dim is not None
            else mdim
        )
        if "assay_ontology_term_id" in classes:
            self.assay_relab = utils.relabel_assay_for_adv(
                self.label_decoders, self.labels_hierarchy
            )
            self.adv_assay_decoder = decoders.ClsDecoder(
                dim,
                len(set(self.assay_relab.values())),
                layers=layers_cls,
                dropout=dropout,
            )
        if len(self.organisms) > 1:
            self.adv_organism_decoder = decoders.ClsDecoder(
                dim,
                len(self.organisms),
                layers=layers_cls,
                dropout=dropout,
            )
    # expression decoder from batch embedding
    if mvc_decoder is not None:
        if cell_specific_blocks:
            raise ValueError(
                "MVC decoder is not supported for cell specific blocks"
            )
        self.mvc_decoder = decoders.MVCDecoder(
            d_model, arch_style=mvc_decoder, zinb=zinb, use_depth=True
        )
    else:
        self.mvc_decoder = None

    if compress_class_dim is not None:
        self.compressor = torch.nn.ModuleDict()
        dim = d_model_cell if cell_specific_blocks else self.d_model
        for k, v in compress_class_dim.items():
            if v >= 8:
                self.compressor[k] = decoders.VAEDecoder(
                    dim,
                    layers=[
                        128,
                        v,
                    ],
                    dropout=dropout,
                    return_latent=True,
                )
            else:
                self.compressor[k] = fsq.FSQ(levels=[2] * v, dim=dim)
    else:
        self.compressor = None

    self.apply(
        partial(
            utils._init_weights,
            n_layer=nlayers,
        )
    )
    for i, dec in self.cls_decoders.items():
        torch.nn.init.constant_(dec.out_layer.bias, -0.13)
    self.expr_encoder._init_weights()

genes property

Get flattened list of all genes in the model's vocabulary.

For multi-organism models, concatenates genes from all organisms in consistent order.

Returns:
  • list[str]

    Gene names in model vocabulary order.
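
As a rough illustration of the flattening described above, the sketch below concatenates per-organism gene lists. The function name `flatten_genes`, the dict layout, the example gene IDs, and the use of dict insertion order as the "consistent order" are all assumptions made for illustration, not the library's confirmed implementation.

```python
def flatten_genes(genes_by_organism: dict[str, list[str]]) -> list[str]:
    """Concatenate per-organism gene lists, one organism after another."""
    flat: list[str] = []
    # Assumption: dict insertion order defines the organism order.
    for organism_genes in genes_by_organism.values():
        flat.extend(organism_genes)
    return flat


# Hypothetical two-species vocabulary.
vocab = {
    "NCBITaxon:9606": ["ENSG00000000003", "ENSG00000000005"],
    "NCBITaxon:10090": ["ENSMUSG00000000001"],
}
flat_genes = flatten_genes(vocab)
```

Under this assumption, a gene's position in the flat list is its offset within its organism plus the number of genes of all preceding organisms.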

add_organism

Add a new organism to an existing model for transfer learning.

Extends the gene vocabulary and embeddings to include genes from a new organism. Useful for applying a pretrained model to a new species.

Parameters:
  • organism (str) –

    Organism ontology term ID (e.g., "NCBITaxon:10090" for mouse).

  • genes (Index) –

    Gene names/IDs for the new organism.

  • emb (DataFrame) –

    Gene embeddings DataFrame with genes as index. Will be resized to match model's d_model.

  • locs (DataFrame, default: None ) –

    Genomic positions with 'pos' column. Required if model uses positional encoding. Defaults to None.

Raises:
  • ValueError

    If model requires gene locations but none provided.

  • ValueError

    If gene positions exceed model's maximum position encoding.

Note

Only genes present in both genes and emb (and locs if provided) will be added. The model's gene encoder is expanded in-place.
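
The filtering rule in the note can be sketched with plain Python sets. The helper name `select_shared_genes` and the toy identifiers are hypothetical; the method itself operates on pandas objects.

```python
def select_shared_genes(genes, emb_index, locs_index=None):
    """Keep genes found in every provided table, in the order of `genes`."""
    shared = set(genes) & set(emb_index)
    if locs_index is not None:
        # Positions are only required when they are supplied.
        shared &= set(locs_index)
    return [g for g in genes if g in shared]


genes = ["g1", "g2", "g3", "g4"]
emb_index = ["g2", "g3", "g4", "g9"]
locs_index = ["g3", "g4", "g1"]

kept_without_locs = select_shared_genes(genes, emb_index)
kept_with_locs = select_shared_genes(genes, emb_index, locs_index)
```

Genes missing from any of the provided tables are dropped, so the added vocabulary can be smaller than the `genes` argument.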

Source code in scprint2/model/model.py
def add_organism(
    self, organism: str, genes: pd.Index, emb: pd.DataFrame, locs=None
):
    """
    Add a new organism to an existing model for transfer learning.

    Extends the gene vocabulary and embeddings to include genes from a new
    organism. Useful for applying a pretrained model to a new species.

    Args:
        organism (str): Organism ontology term ID (e.g., "NCBITaxon:10090" for mouse).
        genes (pd.Index): Gene names/IDs for the new organism.
        emb (pd.DataFrame): Gene embeddings DataFrame with genes as index.
            Will be resized to match model's d_model.
        locs (pd.DataFrame, optional): Genomic positions with 'pos' column.
            Required if model uses positional encoding. Defaults to None.

    Raises:
        ValueError: If model requires gene locations but none provided.
        ValueError: If gene positions exceed model's maximum position encoding.

    Note:
        Only genes present in both `genes` and `emb` (and `locs` if provided)
        will be added. The model's gene encoder is expanded in-place.
    """
    if self.pos_encoder is not None and locs is None:
        raise ValueError("this model needs gene locations to add a new organism")

    self.organisms.append(organism)
    if locs is not None:
        overlap = set(locs.index) & set(emb.index) & set(genes.index)
        genes = genes[genes.index.isin(overlap)]
        locs = locs.loc[genes.index]
        pos = locs["pos"]
        if self.pos_encoder.pe.shape[0] < max(pos):
            raise ValueError(
                f"the number of gene locs in the added organism needs to be less than {self.pos_encoder.pe.shape[0]}"
            )

        token_to_pos = {token: pos for token, pos in enumerate(pos)}
        arr = []
        for _, v in token_to_pos.items():
            arr.append(self.pos_encoder.pe[v - 1].to("cpu").numpy())
        pe = torch.Tensor(np.array(arr)).to(self.pos_encoder.pe.device)
        self.pos_encoder.pe = torch.cat([self.pos_encoder.pe, pe], dim=0)
    else:
        overlap = set(emb.index) & set(genes.index)
        genes = genes[genes.index.isin(overlap)]

    emb = emb.loc[genes.index]
    self._genes[organism] = genes.index.tolist()
    if self.gene_encoder is None:
        genc = self.expr_encoder.gene_encoder
    else:
        genc = self.gene_encoder
    if type(genc) is torch.nn.Sequential:
        enc = genc[0]
    else:
        enc = genc
    semb = torch.nn.AdaptiveAvgPool1d(self.d_model)(
        torch.tensor(emb.values, dtype=torch.float32)
    ).to(enc.embeddings.weight.data.device)
    if enc.memmap:
        print("todev.. will fail for now")

    embs = torch.cat([enc.embeddings.weight.data, semb], dim=0)
    enc.embeddings = nn.Embedding(
        embs.shape[0],
        embs.shape[1],
        padding_idx=None,
        _freeze=not enc.embeddings.weight.requires_grad,
    )
    enc.embeddings.weight.data.copy_(embs)
    enc.embeddings.weight.data = enc.embeddings.weight.data.to(self.device)
    if type(genc) is torch.nn.Sequential:
        genc[0] = enc
    else:
        genc = enc
    if self.gene_encoder is None:
        self.expr_encoder.gene_encoder = genc
    else:
        self.gene_encoder = genc
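
The embedding resize in `add_organism` goes through `torch.nn.AdaptiveAvgPool1d(self.d_model)`. The plain-Python sketch below illustrates the idea of adaptive average pooling on a single embedding row. The helper name `adaptive_avg_pool_1d` is hypothetical, and the window formula is an illustrative assumption about how the pooling windows are chosen, not a copy of the PyTorch implementation.

```python
def adaptive_avg_pool_1d(row, out_dim):
    """Average `row` over `out_dim` windows that together cover the whole row."""
    n = len(row)
    pooled = []
    for i in range(out_dim):
        start = (i * n) // out_dim                     # floor(i * n / out_dim)
        end = ((i + 1) * n + out_dim - 1) // out_dim   # ceil((i + 1) * n / out_dim)
        window = row[start:end]
        pooled.append(sum(window) / len(window))
    return pooled


# Shrink a 6-dimensional embedding to 3 dimensions.
resized = adaptive_avg_pool_1d([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3)
# Stretch a 2-dimensional embedding to 4 dimensions.
stretched = adaptive_avg_pool_1d([1.0, 2.0], 4)
```

Pooling works in both directions, so pretrained embeddings that are wider or narrower than `d_model` can be appended to the existing embedding table.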

configure_optimizers

@see pl.LightningModule

Source code in scprint2/model/model.py
def configure_optimizers(self):
    """@see pl.LightningModule"""
    # https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
    # not working because of poor weight decay implementation
    if self.optim == "adam":
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            betas=(0.9, 0.999),
            eps=1e-7,  # 1e-5 to 1e-8
            weight_decay=self.weight_decay,
            amsgrad=False,
            fused=self.fused_adam,
        )
    elif self.optim == "adamW":
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            betas=(0.9, 0.999),
            eps=1e-7,  # 1e-5 to 1e-8
            weight_decay=self.weight_decay,
            amsgrad=False,
            fused=self.fused_adam,
        )
    elif self.optim == "galore":
        raise NotImplementedError("Galore optimizer not implemented")
        # param_groups = [
        #    {
        #        "params": [
        #            v for k, v in self.named_parameters() if "transformer" not in k
        #        ]
        #    },
        #    {
        #        "params": [
        #            v for k, v in self.named_parameters() if "transformer" in k
        #        ],
        #        "rank": 128,
        #        "update_proj_gap": 200,
        #        "scale": 0.25,
        #        "proj_type": "std",
        #    },
        # ]
        # optimizer = GaLoreAdamW(param_groups, lr=self.hparams.lr)
    else:
        raise ValueError(f"Unknown optimizer: {self.optim}")
    if self.lr_reduce_monitor is None:
        print("no lr reduce factor")
        return [optimizer]
    # lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    #    optimizer,
    #    T_0=20000,
    #    T_mult=2,
    #    eta_min=1e-8,
    # )
    # interval = "step"
    # frequency = 10
    # lr_scheduler = optim.lr_scheduler.ExponentialLR(
    #    optimizer,
    #    gamma=0.85,
    # )
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        patience=self.lr_reduce_patience,
        factor=self.lr_reduce_factor,
    )
    interval = "epoch"
    frequency = 1
    # lr_scheduler = StepwiseCAWRWithWD(
    #     optimizer,
    #     T_0=20_000,
    #     T_mult=2,
    #     eta_min=1e-8,
    #     wd_decay=0.9
    # )
    lr_dict = {
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after an optimizer update.
        "interval": interval,
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": frequency,
        # Metric to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": self.lr_reduce_monitor,
    }
    self.lrfinder_steps = 0
    for val in self.trainer.callbacks:
        if type(val) is _LRCallback:
            self.lrfinder_steps = val.num_training
        if type(val) is LearningRateFinder:
            self.lrfinder_steps = val._num_training_steps
    return [optimizer], [lr_dict]
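
With `interval = "epoch"` and `frequency = 1`, the plateau scheduler is stepped once per epoch on the metric named by `lr_reduce_monitor`. The sketch below is a simplified, plain-Python simulation of that behavior. It deliberately ignores the threshold, cooldown, and minimum learning rate options of `ReduceLROnPlateau`, and the function name and metric values are made up for illustration.

```python
def simulate_reduce_on_plateau(metrics, lr, patience, factor):
    """Simplified plateau schedule for mode="min" (no threshold or cooldown)."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for value in metrics:
        if value < best:
            best = value
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            # The metric has stalled for more than `patience` epochs.
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


lrs = simulate_reduce_on_plateau(
    metrics=[1.0, 0.9, 0.9, 0.9, 0.9, 0.8],
    lr=1e-3,
    patience=2,
    factor=0.5,
)
```

In this toy run the learning rate is halved once, after the third consecutive epoch without improvement.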

forward

Complete forward pass through the scPRINT-2 model.

Encodes input expression data, processes through transformer(s), and decodes into expression predictions and cell classifications.

Parameters:
  • gene_pos (Tensor) –

    Gene indices of shape (batch, seq_len) mapping to positions in the model's gene vocabulary.

  • expression (Tensor, default: None ) –

    Expression values of shape (batch, seq_len). Can be raw counts or normalized depending on model config. Defaults to None.

  • neighbors (Tensor, default: None ) –

    KNN neighbor expressions of shape (batch, n_neighbors, seq_len) for metacell-style encoding. Defaults to None.

  • neighbors_info (Tensor, default: None ) –

    Neighbor weights of shape (batch, n_neighbors). Defaults to None.

  • mask (Tensor, default: None ) –

    Boolean mask of shape (batch, seq_len) where True indicates positions to mask (set to zero). Defaults to None.

  • req_depth (Tensor, default: None ) –

    Target sequencing depth of shape (batch,) for depth-conditional generation. Defaults to None.

  • get_gene_emb (bool, default: False ) –

    Return gene embeddings from transformer. Defaults to False.

  • metacell_token (Tensor, default: None ) –

    Binary metacell indicators of shape (batch,). Defaults to None.

  • depth_mult (Tensor, default: None ) –

    Expression depth multiplier. If None, uses sum of expression values. Defaults to None.

  • do_sample (bool, default: False ) –

    Sample from predicted distribution. Currently unused. Defaults to False.

  • do_mvc (bool, default: False ) –

    Compute multi-view coding predictions. Defaults to False.

  • do_class (bool, default: False ) –

    Compute classification predictions. Defaults to False.

  • get_attention_layer (list[int], default: None ) –

    Layer indices to extract attention weights from. Defaults to None.

  • mask_zeros (Tensor, default: None ) –

    Boolean mask for zero-expression genes of shape (batch, seq_len + num_special_tokens). Defaults to None.

Returns:
  • Dict[str, Tensor] | tuple[Dict[str, Tensor], list]

    Model outputs containing:

    - "mean": Predicted expression (batch, seq_len)
    - "disp": Dispersion parameters (batch, seq_len) [if ZINB]
    - "zero_logits": Zero-inflation logits (batch, seq_len) [if ZINB]
    - "input_cell_embs": Cell embeddings (batch, n_classes+1, d_model)
    - "input_cell_emb": Mean cell embedding (batch, d_model)
    - "output_cell_embs": Processed cell embeddings
    - "output_cell_emb": Final cell embedding
    - "cls_output_{class}": Classification logits for each class
    - "gene_embedding": Gene embeddings [if get_gene_emb]
    - "mvc_*": MVC predictions [if do_mvc]

    If get_attention_layer is not None, returns (outputs_dict, attention_list) where attention_list contains QKV tensors from specified layers.

Example

output = model(
    gene_pos=batch["genes"],
    expression=batch["x"],
    req_depth=batch["depth"],
    do_class=True,
)
predictions = output["mean"]
cell_types = output["cls_output_cell_type_ontology_term_id"].argmax(-1)
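
Because `forward` returns either a dict or a `(dict, list)` tuple depending on `get_attention_layer`, calling code may want to normalize the result. The helper below is a hypothetical convenience, not part of scPRINT-2, and it uses plain Python values in place of tensors.

```python
def split_forward_output(result):
    """Return (outputs, qkvs); qkvs is None when attention was not requested."""
    if isinstance(result, tuple):
        outputs, qkvs = result
        return outputs, qkvs
    return result, None


# Stand-ins for the two possible return shapes of forward().
outputs_only, no_qkvs = split_forward_output({"mean": [0.1, 0.2]})
outputs, qkvs = split_forward_output(({"mean": [0.1, 0.2]}, ["layer0_qkv"]))
```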

Source code in scprint2/model/model.py
def forward(
    self,
    gene_pos: Tensor,
    expression: Optional[Tensor] = None,
    neighbors: Optional[Tensor] = None,
    neighbors_info: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
    req_depth: Optional[Tensor] = None,
    get_gene_emb: bool = False,
    metacell_token: Optional[Tensor] = None,  # (minibatch, 1)
    depth_mult: Optional[Tensor] = None,
    do_sample: bool = False,
    do_mvc: bool = False,
    do_class: bool = False,
    get_attention_layer: Optional[list] = None,
    mask_zeros: Optional[Tensor] = None,
) -> Dict[str, Tensor] | tuple[Dict[str, Tensor], list]:
    """
    Complete forward pass through the scPRINT-2 model.

    Encodes input expression data, processes through transformer(s), and
    decodes into expression predictions and cell classifications.

    Args:
        gene_pos (Tensor): Gene indices of shape (batch, seq_len) mapping to
            positions in the model's gene vocabulary.
        expression (Tensor, optional): Expression values of shape (batch, seq_len).
            Can be raw counts or normalized depending on model config.
            Defaults to None.
        neighbors (Tensor, optional): KNN neighbor expressions of shape
            (batch, n_neighbors, seq_len) for metacell-style encoding.
            Defaults to None.
        neighbors_info (Tensor, optional): Neighbor weights of shape
            (batch, n_neighbors). Defaults to None.
        mask (Tensor, optional): Boolean mask of shape (batch, seq_len) where
            True indicates positions to mask (set to zero). Defaults to None.
        req_depth (Tensor, optional): Target sequencing depth of shape (batch,)
            for depth-conditional generation. Defaults to None.
        get_gene_emb (bool, optional): Return gene embeddings from transformer.
            Defaults to False.
        metacell_token (Tensor, optional): Binary metacell indicators of shape
            (batch,). Defaults to None.
        depth_mult (Tensor, optional): Expression depth multiplier. If None,
            uses sum of expression values. Defaults to None.
        do_sample (bool, optional): Sample from predicted distribution.
            Currently unused. Defaults to False.
        do_mvc (bool, optional): Compute multi-view coding predictions.
            Defaults to False.
        do_class (bool, optional): Compute classification predictions.
            Defaults to False.
        get_attention_layer (list[int], optional): Layer indices to extract
            attention weights from. Defaults to None.
        mask_zeros (Tensor, optional): Boolean mask for zero-expression genes
            of shape (batch, seq_len + num_special_tokens). Defaults to None.

    Returns:
        dict[str, Tensor] | tuple[dict, list]: Model outputs containing:
            - "mean": Predicted expression (batch, seq_len)
            - "disp": Dispersion parameters (batch, seq_len) [if ZINB]
            - "zero_logits": Zero-inflation logits (batch, seq_len) [if ZINB]
            - "input_cell_embs": Cell embeddings (batch, n_classes+1, d_model)
            - "input_cell_emb": Mean cell embedding (batch, d_model)
            - "output_cell_embs": Processed cell embeddings
            - "output_cell_emb": Final cell embedding
            - "cls_output_{class}": Classification logits for each class
            - "gene_embedding": Gene embeddings [if get_gene_emb]
            - "mvc_*": MVC predictions [if do_mvc]

            If get_attention_layer is not None, returns (outputs_dict, attention_list)
            where attention_list contains QKV tensors from specified layers.

    Example:
        >>> output = model(
        ...     gene_pos=batch["genes"],
        ...     expression=batch["x"],
        ...     req_depth=batch["depth"],
        ...     do_class=True,
        ... )
        >>> predictions = output["mean"]
        >>> cell_types = output["cls_output_cell_type_ontology_term_id"].argmax(-1)
    """
    cell_embs, encoding = self._encoder(
        gene_pos,
        expression,
        neighbors,
        neighbors_info,
        mask,
        metacell_token=metacell_token,
    )

    # attention_bias
    num = (1 if self.use_metacell_token else 0) + (
        (len(self.classes) + 1) if not self.cell_transformer else 0
    )
    if self.attn_bias is not None:
        if not hasattr(self, "nbias_sparse"):
            bias_path = os.path.join(self.attn_bias)
            # Keep as sparse matrix - much more memory efficient
            self.nbias_sparse = load_npz(bias_path)

        bias = torch.zeros(
            (
                gene_pos.shape[0],
                gene_pos.shape[1] + num,
                gene_pos.shape[1] + num,
            ),
            device=gene_pos.device,
            dtype=torch.float16,
        )

        fade_factor = 100

        # Extract only the needed values from sparse matrix
        batch_size = gene_pos.shape[0]

        # Vectorized extraction from sparse matrix
        for b in range(batch_size):
            indices = gene_pos[b].cpu().numpy()
            # Get submatrix for this batch's genes
            submatrix = self.nbias_sparse[np.ix_(indices, indices)]
            bias[b, num:, num:] = (
                torch.tensor(
                    submatrix.toarray(), device=gene_pos.device, dtype=torch.float16
                )
                * fade_factor
            )

        bias[:, num:, :num] = -10_000
    if not self.cell_transformer:
        encoding = torch.cat([cell_embs, encoding], dim=1)
    if type(self.transformer) is FlashTransformer:
        transformer_output = self.transformer(
            encoding,
            return_qkv=get_attention_layer,
            bias=bias if self.attn_bias is not None else None,
            bias_layer=list(range(self.nlayers - 1)),
            mask_zeros=mask_zeros,
        )
    elif type(self.transformer) is Performer:
        transformer_output = self.transformer(encoding)
    else:
        raise ValueError(f"Unknown transformer: {type(self.transformer)}")
    if get_attention_layer is not None:
        transformer_output, qkvs = transformer_output
    if self.cell_transformer:
        cell_embs = self.cell_transformer(cell_embs, x_kv=transformer_output)
    else:
        cell_embs, transformer_output = transformer_output.split(
            [
                len(self.classes) + 1,
                transformer_output.shape[1] - (len(self.classes) + 1),
            ],
            dim=1,
        )
    # if not provided we will mult by the current expression sum
    depth_mult = expression.sum(1) if depth_mult is None else depth_mult
    req_depth = torch.log2(1 + req_depth)
    res = self._expr_decoder(
        transformer_output[:, (1 if self.use_metacell_token else 0) :, :],
        depth_mult,
        req_depth,
        get_gene_emb,
    )
    res.update(
        self._cell_decoder(
            cell_embs,
            do_mvc,
            do_class,
            depth_mult,
            req_depth,
            gene_pos if do_mvc else None,
        )
    )
    return (res, qkvs) if get_attention_layer is not None else res
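
The attention-bias layout built in `forward` can be sketched for a single cell with nested lists. The values `fade_factor = 100` and `-10_000` come from the source above. The function names and the toy prior are hypothetical, and the real code works on torch tensors and a SciPy sparse matrix rather than lists.

```python
def count_special_tokens(n_classes, use_metacell_token, has_cell_transformer):
    """Number of non-gene tokens placed before the gene tokens (`num`)."""
    return (1 if use_metacell_token else 0) + (
        0 if has_cell_transformer else n_classes + 1
    )


def build_attention_bias(prior, n_special, fade_factor=100, block_value=-10_000):
    """Build one (n_special + n_genes) x (n_special + n_genes) bias matrix."""
    n_genes = len(prior)
    size = n_special + n_genes
    bias = [[0.0] * size for _ in range(size)]
    for i in range(n_genes):
        for j in range(n_genes):
            # Gene-gene block: the prior scaled by fade_factor.
            bias[n_special + i][n_special + j] = prior[i][j] * fade_factor
        for j in range(n_special):
            # Gene rows, special-token columns: large negative bias.
            bias[n_special + i][j] = float(block_value)
    return bias


num = count_special_tokens(n_classes=3, use_metacell_token=True, has_cell_transformer=False)
bias = build_attention_bias([[0.0, 0.5], [0.25, 0.0]], n_special=1)
```

Rows belonging to special tokens stay at zero, so only the gene rows carry the prior and the large negative entries.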

log_adata

Build an AnnData object from the stored predictions, write it to disk as an .h5ad file, and log the UMAP figure to TensorBoard and wandb when available.

see @utils.log_adata
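
The output file name is assembled from the step, the model name, the `name` argument, and the process rank. A minimal sketch of that naming scheme follows. The function name and the example values are hypothetical, and only the string layout mirrors the method body.

```python
def adata_output_path(save_dir, global_step, model_name, run_name, global_rank):
    """Mirror the file-name layout used when writing the AnnData file."""
    return f"{save_dir}/step_{global_step}_{model_name}_{run_name}_{global_rank}.h5ad"


path = adata_output_path("/tmp", 1200, "scprint2", "val", 0)
```

Including the rank in the name keeps processes in a multi-device run from overwriting each other's files.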

Source code in scprint2/model/model.py
def log_adata(self, gtclass=None, name=""):
    """
    Build an AnnData object from the stored predictions and write it to disk.
    The UMAP figure is logged to TensorBoard and wandb when available.

    see @utils.log_adata
    """
    try:
        mdir = self.logger.save_dir if self.logger.save_dir is not None else "/tmp"
    except Exception:
        mdir = "data/"
    if not os.path.exists(mdir):
        os.makedirs(mdir)
    adata, fig = utils.make_adata(
        genes=self.genes,
        embs=self.embs,
        pos=self.pos if self.save_expr else None,
        expr_pred=self.expr_pred if self.save_expr else None,
        classes=self.classes,
        pred=self.pred if not self.keep_all_labels_pred else None,
        label_decoders=self.label_decoders,
        labels_hierarchy=self.labels_hierarchy,
        gtclass=gtclass,
        doplot=self.doplot,
    )
    adata.write(
        str(mdir)
        + "/step_"
        + str(self.global_step)
        + "_"
        + str(self.name)
        + "_"
        + str(name)
        + "_"
        + str(self.global_rank)
        + ".h5ad"
    )
    if self.doplot and fig is not None:
        logged = False
        try:
            self.logger.experiment.add_figure(fig)
            logged = True
        except Exception:
            print("couldn't log to tensorboard")
        try:
            self.logger.log_image(key="umaps", images=[fig], step=self.global_step)
            logged = True
        except Exception:
            print("couldn't log to wandb")
        if not logged:
            fig.savefig(mdir + "/umap_" + self.name + "_" + name + ".png")

    return adata

on_fit_start

@see pl.LightningModule

Source code in scprint2/model/model.py
def on_fit_start(self):
    """@see pl.LightningModule"""
    if type(self.transformer) is FlashTransformer:
        for encoder_layers in self.transformer.blocks:
            encoder_layers.set_seq_parallel(True)
    for k, v in self.mat_labels_hierarchy.items():
        self.mat_labels_hierarchy[k] = v.to(self.device)

on_load_checkpoint

Handle checkpoint loading with backward compatibility.

Automatically handles:

  • Different class configurations between checkpoint and current model
  • Legacy parameter names and structures
  • Encoder/decoder mismatches with datamodule
  • Gene vocabulary differences
  • Early stopping callback state

Called automatically by PyTorch Lightning during checkpoint loading.

Parameters:
  • checkpoints (dict) –

    Checkpoint dictionary from torch.load().

Note

Prints warnings when configurations differ between checkpoint and current model. These should be reviewed to ensure expected behavior.
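
The legacy-key handling can be sketched on a plain dict. The key names are taken from the method's source. The helper `migrate_state_dict` is hypothetical, and it simplifies slightly by dropping any key that contains `grad_reverse_discriminator_loss` without first checking for the `.out_layer.bias` entry.

```python
LEGACY_KEY_RENAMES = {
    "gene_encoder.0.embedding.weight": "gene_encoder.0.embeddings.weight",
    "gene_encoder.embedding.weight": "gene_encoder.embeddings.weight",
}


def migrate_state_dict(state_dict):
    """Rename legacy keys and drop removed discriminator weights, in place."""
    for old, new in LEGACY_KEY_RENAMES.items():
        if old in state_dict:
            state_dict[new] = state_dict.pop(old)
    # Copy the keys first because entries are deleted while iterating.
    for key in list(state_dict):
        if "grad_reverse_discriminator_loss" in key:
            del state_dict[key]
    return state_dict


# Strings stand in for tensors in this toy checkpoint.
old_checkpoint = {
    "gene_encoder.0.embedding.weight": "W",
    "grad_reverse_discriminator_loss.out_layer.bias": "b",
    "transformer.blocks.0.weight": "T",
}
migrated = migrate_state_dict(old_checkpoint)
```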

Source code in scprint2/model/model.py
def on_load_checkpoint(self, checkpoints):
    """
    Handle checkpoint loading with backward compatibility.

    Automatically handles:
    - Different class configurations between checkpoint and current model
    - Legacy parameter names and structures
    - Encoder/decoder mismatches with datamodule
    - Gene vocabulary differences
    - Early stopping callback state

    Called automatically by PyTorch Lightning during checkpoint loading.

    Args:
        checkpoints (dict): Checkpoint dictionary from torch.load().

    Note:
        Prints warnings when configurations differ between checkpoint and
        current model. These should be reviewed to ensure expected behavior.
    """
    # if not the same number of labels (due to diff datasets)
    for name, clss in self.cls_decoders.items():
        size = checkpoints["state_dict"][
            "cls_decoders." + name + ".out_layer.bias"
        ].shape[0]
        if size != clss.out_layer.bias.shape[0]:
            self.cls_decoders[name].out_layer = torch.nn.Linear(
                clss.out_layer.weight.shape[1], size
            )
    # from older model versions
    self.normalization = checkpoints["hyper_parameters"].get("normalization", "sum")
    if (
        checkpoints["state_dict"].get("gene_encoder.0.embedding.weight", None)
        is not None
    ):
        # replace it with the new one gene_encoder.0.embeddings.weight in the state_dict
        checkpoints["state_dict"]["gene_encoder.0.embeddings.weight"] = checkpoints[
            "state_dict"
        ]["gene_encoder.0.embedding.weight"]
        del checkpoints["state_dict"]["gene_encoder.0.embedding.weight"]
    # older checkpoints may still contain the batch-effect discriminator,
    # which no longer exists in the model: drop its keys
    if (
        "grad_reverse_discriminator_loss.out_layer.bias"
        in checkpoints["state_dict"]
    ):
        for k in list(checkpoints["state_dict"].keys()):
            if "grad_reverse_discriminator_loss" in k:
                del checkpoints["state_dict"][k]
        print(
            "the discriminator for batch effect correction has been removed. "
            "dropping the legacy key."
        )
    # same legacy rename, for checkpoints whose gene_encoder is not a Sequential
    if (
        checkpoints["state_dict"].get("gene_encoder.embedding.weight", None)
        is not None
    ):
        # replace it with the new one gene_encoder.embeddings.weight in the state_dict
        checkpoints["state_dict"]["gene_encoder.embeddings.weight"] = checkpoints[
            "state_dict"
        ]["gene_encoder.embedding.weight"]
        del checkpoints["state_dict"]["gene_encoder.embedding.weight"]

    if "classes" in checkpoints["hyper_parameters"]:
        if self.label_counts != checkpoints["hyper_parameters"]["classes"]:
            if "label_counts" in checkpoints["hyper_parameters"] and set(
                checkpoints["hyper_parameters"]["label_counts"].keys()
            ) == set(checkpoints["hyper_parameters"]["classes"]):
                if self.classes != checkpoints["hyper_parameters"]["classes"]:
                    print("classes have changed, be careful")
                self.classes = checkpoints["hyper_parameters"]["classes"]
                self.label_counts = checkpoints["hyper_parameters"]["label_counts"]
                if self.classes == self.label_counts:
                    raise ValueError(
                        "classes and label_counts are the same, this is not allowed, please use another checkpoint"
                    )
            else:
                self.label_counts = checkpoints["hyper_parameters"]["classes"]
                if self.classes != list(
                    checkpoints["hyper_parameters"]["classes"].keys()
                ):
                    print("classes have changed, be careful")
                    self.classes = list(
                        checkpoints["hyper_parameters"]["classes"].keys()
                    )
        # else it is all good as expected

    else:
        print("no classes in the checkpoint, be careful")

    if checkpoints["state_dict"].get("pos_encoder.pe") is not None:
        if self.pos_encoder is None:
            self.pos_encoder = encoders.PositionalEncoding(
                self.d_model, gene_pos_enc=[0, 1, 2]
            )
        self.pos_encoder.pe = checkpoints["state_dict"]["pos_encoder.pe"]

    if self.label_decoders != checkpoints["hyper_parameters"][
        "label_decoders"
    ] or self.labels_hierarchy != checkpoints["hyper_parameters"].get(
        "labels_hierarchy", {}
    ):
        print("label decoders have changed, be careful")
        self.label_decoders = checkpoints["hyper_parameters"]["label_decoders"]
        self.labels_hierarchy = checkpoints["hyper_parameters"].get(
            "labels_hierarchy", {}
        )
        for k, v in self.labels_hierarchy.items():
            tens = torch.zeros((len(v), self.label_counts[k]))
            for k2, v2 in v.items():
                tens[k2 - self.label_counts[k], v2] = 1
            self.mat_labels_hierarchy[k] = tens.to(bool)

    if (
        "gene_pos_enc" in checkpoints["hyper_parameters"]
        and checkpoints["hyper_parameters"]["gene_pos_enc"] is not None
    ):
        if (
            self.pos_encoder is None
            or self.pos_encoder.gene_pos_enc
            != checkpoints["hyper_parameters"]["gene_pos_enc"]
        ):
            print(
                "Gene position encoding has changed in the dataloader compared to last time, trying to revert"
            )
            self.pos_encoder = encoders.PositionalEncoding(
                self.d_model,
                gene_pos_enc=checkpoints["hyper_parameters"]["gene_pos_enc"],
            )
            checkpoints["hyper_parameters"].pop("gene_pos_enc")
    mencoders = {}
    if type(checkpoints["hyper_parameters"]["genes"]) is list:
        print("converting a gene list-based model")
        org = checkpoints["hyper_parameters"].get("organisms", self.organisms)
        genedf = load_genes(org)
        checkpoints["hyper_parameters"]["genes"] = {
            i: genedf.index[
                (genedf.organism == i)
                & genedf.index.isin(checkpoints["hyper_parameters"]["genes"])
            ].tolist()
            for i in org
        }
    if "precpt_gene_emb" in checkpoints["hyper_parameters"]:
        checkpoints["hyper_parameters"].pop("precpt_gene_emb")

    if "gene_pos_file" in checkpoints["hyper_parameters"]:
        checkpoints["hyper_parameters"].pop("gene_pos_file")

    if "transformer" in checkpoints["hyper_parameters"]:
        checkpoints["hyper_parameters"]["attention"] = checkpoints[
            "hyper_parameters"
        ].pop("transformer")
    try:
        if self.trainer.datamodule.decoders != self.label_decoders:
            print("label decoders have changed, be careful")
            # if we don't have the same decoders, we need to update the one on the datamodule side
            for k, v in self.label_decoders.items():
                mencoders[k] = {va: ke for ke, va in v.items()}
            self.trainer.datamodule.encoders = mencoders

        es = None
        for k in self.trainer.callbacks:
            if isinstance(k, EarlyStopping):
                es = k
        if es is not None:
            es_key = "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"
            es_state = checkpoints["callbacks"].get(es_key)
            # only patch the saved state when the checkpoint actually stores one;
            # otherwise there is no previous patience to compare against
            if es_state is not None and es_state["patience"] != es.patience:
                prev = es_state["patience"]
                print(
                    "updating the early stopping parameter to {}".format(
                        es.patience
                    )
                )
                es_state["patience"] = es.patience
                if prev < es.patience:
                    es_state["stopped_epoch"] = 0

    except RuntimeError as e:
        if "scPRINT2 is not attached to a `Trainer`." in str(e):
            print("FYI: scPRINT2 is not attached to a `Trainer`.")
        else:
            raise e
    if (
        self.mvc_decoder is None
        and checkpoints["state_dict"].get("mvc_decoder.gene2query.weight")
        is not None
    ):
        for i in [
            "mvc_decoder.gene2query.weight",
            "mvc_decoder.gene2query.bias",
            "mvc_decoder.norm.weight",
            "mvc_decoder.norm.bias",
            "mvc_decoder.pred_var_zero.weight",
        ]:
            if i in checkpoints["state_dict"]:
                del checkpoints["state_dict"][i]
    org = checkpoints["hyper_parameters"].get("organisms")
    if self.organisms != org and org is not None:
        self.organisms = org
        try:
            self.trainer.datamodule.organisms = self.organisms
        except RuntimeError as e:
            if "scPRINT2 is not attached to a `Trainer`." not in str(e):
                raise e
    if self._genes != checkpoints["hyper_parameters"]["genes"]:
        self._genes = checkpoints["hyper_parameters"]["genes"]
    try:
        self.trainer.datamodule.set_valid_genes_collator(self.genes)
    except RuntimeError as e:
        if "scPRINT2 is not attached to a `Trainer`." not in str(e):
            raise e

    if not is_interactive():
        self.save_hyperparameters()
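
The legacy-key handling in the listing above repeats one pattern: when an old parameter name is present in the state dict, its value is moved to the new name and the old entry is deleted. A minimal, framework-free sketch of that pattern follows; the helper name `rename_legacy_keys` and the toy `state` dict are hypothetical and not part of scprint2.

```python
def rename_legacy_keys(state_dict, renames):
    """Move values stored under legacy keys to their new names, in place.

    `renames` maps old key -> new key. Keys that are absent are skipped.
    """
    for old, new in renames.items():
        if old in state_dict:
            state_dict[new] = state_dict.pop(old)
    return state_dict


# toy stand-in for checkpoints["state_dict"] (values would be tensors in practice)
state = {"gene_encoder.embedding.weight": [0.1, 0.2], "other.weight": [1.0]}
rename_legacy_keys(
    state,
    {
        "gene_encoder.embedding.weight": "gene_encoder.embeddings.weight",
        "gene_encoder.0.embedding.weight": "gene_encoder.0.embeddings.weight",
    },
)
```

Collecting the renames in one mapping keeps the old and new names side by side, which can make it easier to see which checkpoint versions are still supported.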

on_predict_epoch_end

@see pl.LightningModule

Source code in scprint2/model/model.py
def on_predict_epoch_end(self):
    """@see pl.LightningModule"""
    if self.pos.shape[0] < 100:
        return
    if self.pred_log_adata:
        print("adding on disk")
        return self.log_adata(name="predict_part_" + str(self.counter))

on_predict_epoch_start

@see pl.LightningModule

Source code in scprint2/model/model.py
def on_predict_epoch_start(self):
    """@see pl.LightningModule"""
    print("predict epoch start")
    self.embs = None
    self.attn.data = None
    self.attn.attn = None
    self.counter = 0
    if type(self.transformer) is FlashTransformer:
        for encoder_layers in self.transformer.blocks:
            encoder_layers.set_seq_parallel(False)

on_test_start

@see pl.LightningModule

Source code in scprint2/model/model.py
def on_test_start(self):
    """@see pl.LightningModule"""
    print("test start")
    for k, v in self.mat_labels_hierarchy.items():
        self.mat_labels_hierarchy[k] = v.to(self.device)

on_validation_epoch_end

@see pl.LightningModule

Source code in scprint2/model/model.py
def on_validation_epoch_end(self):
    """@see pl.LightningModule"""
    self.pos = None
    self.expr_pred = None
    self.do_adv_cls = self._store_adv_cls
    gathered_embs = self.all_gather(self.embs)
    # Merge the dictionaries from all processes
    for key in self.embs.keys():
        self.embs[key] = gathered_embs[key].view(-1, gathered_embs[key].shape[-1])
    self.info = self.all_gather(self.info).view(-1, self.info.shape[-1])
    self.pred = (
        self.all_gather(self.pred).view(-1, self.pred.shape[-1])
        if self.pred is not None
        else None
    )
    # self.pos = self.all_gather(self.pos).view(-1, self.pos.shape[-1])
    # self.expr_pred[0] = self.all_gather(self.expr_pred[0]).view(
    #     -1, self.expr_pred[0].shape[-1]
    # )
    # if len(self.expr_pred) > 1:
    #     self.expr_pred[1] = self.all_gather(self.expr_pred[1]).view(
    #         -1, self.expr_pred[1].shape[-1]
    #     )
    # self.expr_pred[2] = self.all_gather(self.expr_pred[2]).view(
    #     -1, self.expr_pred[2].shape[-1]
    # )

    if self.trainer.state.stage != "sanity_check":
        if self.trainer.is_global_zero:
            print("logging anndata")
            sch = self.lr_schedulers()
            if sch is not None:
                sch.step(self.trainer.callback_metrics["val_loss"])
            # run the test function on specific dataset
            if self.embs is not None:
                self.log_adata(
                    gtclass=self.info, name="validation_part_" + str(self.counter)
                )
            if (self.current_epoch + 1) % self.test_every == 0:
                self.on_test_epoch_end()
            # Synchronize all processes with a timeout
        if torch.distributed.is_initialized():
            # Set a timeout that's longer than your test typically takes
            # Write rank to file for debugging
            self.trainer.strategy.barrier()
    self.pred = None

optimizer_step

@see pl.LightningModule

Source code in scprint2/model/model.py
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    """@see pl.LightningModule"""
    # update params
    # manually warm up lr without a scheduler
    # making sure that we don't do this during lrfinder
    lr_scale = None
    prev_lr = None
    if (
        self.trainer.global_step < self.warmup_duration + self.lrfinder_steps
    ) and self.lrfinder_steps <= self.trainer.global_step:
        for i, pg in enumerate(optimizer.param_groups):
            lr_scale = min(
                1.0, float(self.trainer.global_step + 1) / self.warmup_duration
            )
            prev_lr = pg["lr"]
            pg["lr"] = lr_scale * self.hparams.lr
    for i, pg in enumerate(optimizer.param_groups):
        # if pg["lr"] < 2e-5:
        #    pg["lr"] = 2e-5
        self.log("lr_" + str(i), pg["lr"])
    if optimizer.param_groups[0]["lr"] > self.hparams.lr:
        if prev_lr is not None:
            pg["lr"] = prev_lr
        else:
            print("OPTIMIZER HAS INCREASED LR. WHYY?")
            print(optimizer.param_groups[0]["lr"], self.hparams.lr)
            optimizer.param_groups[0]["lr"] = self.hparams.lr

    optimizer.step(closure=optimizer_closure)
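
The manual warm-up in `optimizer_step` scales the learning rate linearly with the global step, but only inside the window that starts after the LR-finder steps. The sketch below isolates that arithmetic as a pure function; `warmup_lr` and its argument names are hypothetical helpers for illustration, mirroring `self.warmup_duration`, `self.lrfinder_steps`, and `self.hparams.lr`.

```python
def warmup_lr(global_step, base_lr, warmup_duration, lrfinder_steps=0):
    """Return the learning rate set by the warm-up branch.

    Returns None when the step is outside the warm-up window, in which
    case the listing above leaves the optimizer's learning rate untouched.
    """
    in_window = lrfinder_steps <= global_step < warmup_duration + lrfinder_steps
    if not in_window:
        return None
    # same scaling as in the listing: (step + 1) / warmup_duration, capped at 1
    lr_scale = min(1.0, float(global_step + 1) / warmup_duration)
    return lr_scale * base_lr


first = warmup_lr(0, base_lr=1e-4, warmup_duration=500)    # 1/500 of the base rate
last = warmup_lr(499, base_lr=1e-4, warmup_duration=500)   # full base rate
after = warmup_lr(500, base_lr=1e-4, warmup_duration=500)  # outside the window
```

Note that the scale uses the raw global step rather than the step counted from the end of the LR finder, so with a non-zero `lrfinder_steps` the warm-up does not start from the smallest scale.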

predict_step

Embed the given gene expression and encode the gene and cell embeddings.

Parameters:
  • batch (Dict[str, Tensor]) –

    Dictionary containing 'genes', 'x', 'depth', and optionally 'knn_cells' and 'knn_cells_info'.

  • batch_idx (int) –

    Index of the batch.

Returns:
  • Dict[str, Tensor]

    Dictionary containing model predictions.

Source code in scprint2/model/model.py
def predict_step(
    self, batch: Dict[str, Tensor], batch_idx: int
) -> Dict[str, Tensor]:
    """
    Embed the given gene expression and encode the gene and cell embeddings.

    Args:
        batch (Dict[str, Tensor]): Dictionary containing 'genes', 'x', 'depth', and optionally 'knn_cells' and 'knn_cells_info'.
        batch_idx (int): Index of the batch.

    Returns:
        Dict[str, Tensor]: Dictionary containing model predictions.
    """
    return self._predict(
        batch["genes"],
        batch["x"],
        batch["depth"],
        batch.get("knn_cells", None),
        batch.get("knn_cells_info", None),
        self.predict_mode,
        self.pred_embedding,
        self.get_attention_layer,
        self.predict_depth_mult,
    )

training_step

training_step defines the training loop. It is independent of forward.

@see pl.LightningModule

Returns:
  • Tensor –

    Total loss value for the training step.

Source code in scprint2/model/model.py
def training_step(
    self,
    batch: Dict[str, Tensor],
    batch_idx: int,
) -> Tensor:
    """
    training_step defines the training loop. It is independent of forward.

    @see pl.LightningModule

    Returns:
        Tensor: Total loss value for the training step.
    """
    total_loss, losses = self._full_training(
        batch=batch,
        noise=self.noise,
        do_next_tp=self.do_next_tp,
        cce_temp=self.cce_temp,
        do_generate=self.do_generate,
        run_full_forward=self.run_full_forward,
        mask_ratio=self.mask_ratio,
    )
    if total_loss is None or torch.isnan(total_loss):
        raise ValueError("Loss is NaN")
    try:
        self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
        self.log_dict(losses, prog_bar=True, sync_dist=True)
    except Exception as e:
        print(e)
        print(losses)
    return total_loss

validation_step

validation_step defines the validation loop. It is independent of forward. @see pl.LightningModule

Parameters:
  • batch (list[Tensor]) –

    @see training_step

Source code in scprint2/model/model.py
def validation_step(
    self,
    batch,
    batch_idx,
):
    """
    validation_step defines the validation loop. It is independent of forward.
    @see pl.LightningModule

    Args:
        batch (list[Tensor]): @see training_step
    """
    val_loss, losses = self._full_training(
        batch=batch,
        noise=self.noise,
        do_next_tp=self.do_next_tp,
        cce_temp=self.cce_temp,
        do_vae_kl=False,
        do_generate=self.do_generate,
        run_full_forward=self.run_full_forward,
        mask_ratio=self.mask_ratio,
    )
    expression = batch["x"]
    gene_pos = batch["genes"]
    depth = batch["depth"]
    metacell_token = batch.get("is_meta", None)
    knn_cells = batch.get("knn_cells", None)
    knn_cells_info = batch.get("knn_cells_info", None)

    # TODO: make this faster by only calling val loss
    if self.embs is not None:
        if self.pos.shape[0] < 100_000 / self.trainer.world_size:
            self.info = torch.cat([self.info, batch["class"]])
            self._predict(
                gene_pos,
                expression,
                depth,
                knn_cells=knn_cells,
                knn_cells_info=knn_cells_info,
                pred_embedding=self.pred_embedding,
                max_size_in_mem=120_000,
                metacell_token=metacell_token,
            )
    else:
        self.info = batch["class"]
        self._predict(
            gene_pos,
            expression,
            depth,
            knn_cells=knn_cells,
            knn_cells_info=knn_cells_info,
            pred_embedding=self.pred_embedding,
            max_size_in_mem=120_000,
            metacell_token=metacell_token,
        )
    self.log("val_loss", val_loss, sync_dist=True)
    expr_loss = mean(
        [
            v.cpu().item() if type(v) is Tensor else v
            for k, v in losses.items()
            if "expr" in k
        ]
    )
    self.log("val_loss_expr", expr_loss, sync_dist=True)
    cls_loss = mean(
        [
            v.cpu().item() if type(v) is Tensor else v
            for k, v in losses.items()
            if "cls" in k
        ]
    )
    self.log("val_loss_cls", cls_loss, sync_dist=True)
    # self.log_dict(losses, sync_dist=True)
    return val_loss
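
`validation_step` logs two summary metrics, `val_loss_expr` and `val_loss_cls`, by averaging every entry of the `losses` dict whose key contains `"expr"` or `"cls"`. A small sketch of that grouping is shown below. It assumes `mean` behaves like `statistics.mean`; the helper name `mean_loss_by_tag` and the loss keys in the toy dict are hypothetical.

```python
from statistics import mean


def mean_loss_by_tag(losses, tag):
    """Average all loss values whose key contains `tag`.

    Returns None when no key matches, since statistics.mean raises on
    an empty sequence.
    """
    values = [float(v) for k, v in losses.items() if tag in k]
    return mean(values) if values else None


# toy stand-in for the dict returned next to the total loss
losses = {"expr_nb": 2.0, "expr_zero": 4.0, "cls_cell_type": 1.0, "ecs": 0.5}
expr_loss = mean_loss_by_tag(losses, "expr")  # averages the two expr_* entries
cls_loss = mean_loss_by_tag(losses, "cls")    # averages the single cls_* entry
```

Because matching is by substring, any new loss whose name contains `"expr"` or `"cls"` is automatically folded into the corresponding summary.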

losses

scprint2.model.loss

Classes:

Name Description
AdversarialDiscriminatorLoss

Functions:

Name Description
contrastive_loss

Computes NT-Xent loss (InfoNCE) between two sets of vectors.

criterion_neg_log_bernoulli

Compute the negative log-likelihood of Bernoulli distribution

ecs

ecs Computes the similarity of cell embeddings based on a threshold.

grad_reverse

grad_reverse Reverses the gradient of the input tensor.

hierarchical_classification

Computes the classification loss for a given batch of predictions and ground truth labels.

masked_mae

Compute the masked MAE loss between input and target.

masked_mse

Compute the masked MSE loss between input and target.

masked_nb

Compute the masked negative binomial loss between input and target.

masked_relative_error

Compute the masked relative error between input and target.

mse

Compute the MSE loss between input and target.

nb

Computes the negative binomial (NB) loss.

nb_dist

nb_dist Computes the negative binomial distribution.

within_sample

Compute dissimilarity between embeddings within each sample

zinb

Computes zero-inflated negative binomial (ZINB) loss.

AdversarialDiscriminatorLoss

Bases: Module

Discriminator for the adversarial training for batch correction.

Parameters:
  • d_model (int) –

    The size of the input tensor.

  • n_cls (int) –

    The number of classes.

  • nlayers (int, default: 3 ) –

    The number of layers in the discriminator. Defaults to 3.

  • activation (callable, default: LeakyReLU ) –

    The activation function. Defaults to nn.LeakyReLU.

  • reverse_grad (bool, default: True ) –

    Whether to reverse the gradient. Defaults to True.

Methods:

Name Description
forward

Apply the discriminator to x and return the cross-entropy loss against batch_labels.

Source code in scprint2/model/loss.py
def __init__(
    self,
    d_model: int,
    n_cls: int,
    nlayers: int = 3,
    activation: callable = nn.LeakyReLU,
    reverse_grad: bool = True,
):
    """
    Discriminator for the adversarial training for batch correction.

    Args:
        d_model (int): The size of the input tensor.
        n_cls (int): The number of classes.
        nlayers (int, optional): The number of layers in the discriminator. Defaults to 3.
        activation (callable, optional): The activation function. Defaults to nn.LeakyReLU.
        reverse_grad (bool, optional): Whether to reverse the gradient. Defaults to True.
    """
    super().__init__()
    # module list
    self.decoder = nn.ModuleList()
    for _ in range(nlayers - 1):
        self.decoder.append(nn.Linear(d_model, d_model))
        self.decoder.append(nn.LayerNorm(d_model))
        self.decoder.append(activation())
    self.out_layer = nn.Linear(d_model, n_cls)
    self.reverse_grad = reverse_grad

forward

Parameters:
  • x (Tensor) –

    Tensor, shape [batch_size, embsize]

  • batch_labels (Tensor) –

    Tensor, shape [batch_size]

Source code in scprint2/model/loss.py
def forward(self, x: Tensor, batch_labels: Tensor) -> Tensor:
    """
    Args:
        x: Tensor, shape [batch_size, embsize]
        batch_labels: Tensor, shape [batch_size]
    """
    if self.reverse_grad:
        x = grad_reverse(x, lambd=1.0)
    for layer in self.decoder:
        x = layer(x)
    x = self.out_layer(x)
    return F.cross_entropy(x, batch_labels)

contrastive_loss

Computes NT-Xent loss (InfoNCE) between two sets of vectors.

Parameters:
  • x (Tensor) –

    Tensor of shape [batch_size, feature_dim]

  • y (Tensor) –

    Tensor of shape [batch_size, feature_dim]

  • temperature (float, default: 0.1 ) –

    Temperature parameter to scale the similarities. Lower values make the model more confident/selective. Typical values are between 0.1 and 0.5.

Returns:
  • Tensor –

    NT-Xent loss value

Note
  • Assumes x[i] and y[i] are positive pairs
  • All other combinations are considered negative pairs
  • Uses cosine similarity scaled by temperature
Source code in scprint2/model/loss.py
def contrastive_loss(x: Tensor, y: Tensor, temperature: float = 0.1) -> Tensor:
    """
    Computes NT-Xent loss (InfoNCE) between two sets of vectors.

    Args:
        x: Tensor of shape [batch_size, feature_dim]
        y: Tensor of shape [batch_size, feature_dim]
        temperature: Temperature parameter to scale the similarities.
            Lower values make the model more confident/selective.
            Typical values are between 0.1 and 0.5.

    Returns:
        Tensor: NT-Xent loss value

    Note:
        - Assumes x[i] and y[i] are positive pairs
        - All other combinations are considered negative pairs
        - Uses cosine similarity scaled by temperature
    """
    # Check input dimensions
    assert x.shape == y.shape, "Input tensors must have the same shape"
    batch_size = x.shape[0]

    # Compute cosine similarity matrix
    # x_unsqueeze: [batch_size, 1, feature_dim]
    # y_unsqueeze: [1, batch_size, feature_dim]
    # -> similarities: [batch_size, batch_size]
    similarities = (
        F.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0), dim=2) / temperature
    )

    # The positive pairs are on the diagonal
    labels = torch.arange(batch_size, device=x.device)

    # Cross entropy loss
    return F.cross_entropy(similarities, labels)
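
To make the computation in `contrastive_loss` concrete, here is a dependency-free sketch of the same idea on plain Python lists: temperature-scaled cosine similarities form a logit matrix, and the loss is the mean cross-entropy with the diagonal as the target class. The names `cosine` and `nt_xent` are hypothetical, and the small epsilon that `F.cosine_similarity` applies to the norms is omitted, so zero vectors are not supported here.

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def nt_xent(x, y, temperature=0.1):
    """Mean cross-entropy over rows of the similarity matrix, target = diagonal."""
    total = 0.0
    for i, xi in enumerate(x):
        logits = [cosine(xi, yj) / temperature for yj in y]
        # numerically stable log-sum-exp over the row
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_z - logits[i]
    return total / len(x)


x = [[1.0, 0.0], [0.0, 1.0]]
aligned = nt_xent(x, x)           # positives match: loss close to 0
mismatched = nt_xent(x, x[::-1])  # positives swapped: large loss
```

With orthogonal toy vectors the aligned case gives a loss near zero and the swapped case a loss near `1 / temperature`, which illustrates why lower temperatures sharpen the penalty for confusing positives with negatives.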

criterion_neg_log_bernoulli

Compute the negative log-likelihood of Bernoulli distribution

Source code in scprint2/model/loss.py
def criterion_neg_log_bernoulli(input: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """
    Compute the negative log-likelihood of Bernoulli distribution
    """
    mask = mask.float()
    bernoulli = torch.distributions.Bernoulli(probs=input)
    masked_log_probs = bernoulli.log_prob((target > 0).float()) * mask
    return -masked_log_probs.sum() / mask.sum()
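
Note that `criterion_neg_log_bernoulli` binarizes the target (`target > 0`) before scoring it, so it measures how well `input` predicts whether a gene is expressed at all, not how much. A sketch of the same arithmetic on flat Python lists, with the hypothetical name `neg_log_bernoulli`:

```python
import math


def neg_log_bernoulli(probs, target, mask):
    """Masked negative log-likelihood of 'target > 0' under Bernoulli(probs)."""
    total, count = 0.0, 0
    for p, t, m in zip(probs, target, mask):
        if not m:
            continue  # masked-out positions contribute nothing
        total += math.log(p) if t > 0 else math.log(1.0 - p)
        count += 1
    return -total / count


# counts 3 and 7 are both treated as "expressed"; the last position is masked out
nll = neg_log_bernoulli(
    probs=[0.9, 0.2, 0.5], target=[3, 0, 7], mask=[True, True, False]
)
```

As in the listing above, the result is undefined when the mask selects no positions, since the sum is divided by the number of unmasked entries.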

ecs

ecs Computes the similarity of cell embeddings based on a threshold.

Parameters:
  • cell_emb (Tensor) –

    A tensor representing cell embeddings.

  • ecs_threshold (float, default: 0.5 ) –

    A threshold for determining similarity. Defaults to 0.5.

Returns:
  • Tensor –

    A tensor representing the mean of 1 minus the square of the difference between the cosine similarity and the threshold.

Source code in scprint2/model/loss.py
def ecs(cell_emb: Tensor, ecs_threshold: float = 0.5) -> Tensor:
    """
    ecs Computes the similarity of cell embeddings based on a threshold.

    Args:
        cell_emb (Tensor): A tensor representing cell embeddings.
        ecs_threshold (float, optional): A threshold for determining similarity. Defaults to 0.5.

    Returns:
        Tensor: A tensor representing the mean of 1 minus the square of the difference between the cosine similarity and the threshold.
    """
    # Here using customized cosine similarity instead of F.cosine_similarity
    # to avoid the pytorch issue of similarity larger than 1.0, pytorch # 78064
    # normalize the embedding
    cell_emb_normed = F.normalize(cell_emb, p=2, dim=1)
    cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t())

    # mask out diagonal elements
    mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device)
    cos_sim = cos_sim.masked_fill(mask, 0.0)
    # only optimize positive similarities
    cos_sim = F.relu(cos_sim)
    return torch.mean(1 - (cos_sim - ecs_threshold) ** 2)
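
The returned value of `ecs` can be traced by hand on a tiny example. The sketch below reimplements the same steps on Python lists: L2-normalize, compute pairwise cosine similarity, zero the diagonal, clamp negatives to zero, then average `1 - (sim - threshold) ** 2` over all entries. The function name `ecs_sketch` is hypothetical, and the epsilon used by `F.normalize` is omitted, so zero vectors are not supported.

```python
import math


def ecs_sketch(cell_emb, ecs_threshold=0.5):
    """Mean of 1 - (relu(cos_sim) - threshold)^2 with the diagonal set to 0."""
    normed = []
    for row in cell_emb:
        norm = math.sqrt(sum(v * v for v in row))
        normed.append([v / norm for v in row])
    n = len(normed)
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i == j:
                sim = 0.0  # diagonal is masked to zero, but still averaged
            else:
                sim = sum(a * b for a, b in zip(normed[i], normed[j]))
            sim = max(sim, 0.0)  # only positive similarities are kept
            total += 1 - (sim - ecs_threshold) ** 2
    return total / (n * n)


# two cells whose cosine similarity equals the threshold of 0.5
cells = [[1.0, 0.0], [0.5, math.sqrt(3) / 2]]
value = ecs_sketch(cells)
```

Because the masked diagonal is included in the mean, each diagonal entry contributes a constant `1 - ecs_threshold ** 2`; only the off-diagonal terms depend on the embeddings.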

grad_reverse

grad_reverse Reverses the gradient of the input tensor.

Parameters:
  • x (Tensor) –

    The input tensor whose gradient is to be reversed.

  • lambd (float, default: 1.0 ) –

    The scaling factor for the reversed gradient. Defaults to 1.0.

Returns:
  • Tensor –

    The input tensor with its gradient reversed during the backward pass.

Source code in scprint2/model/loss.py
def grad_reverse(x: Tensor, lambd: float = 1.0) -> Tensor:
    """
    grad_reverse Reverses the gradient of the input tensor.

    Args:
        x (Tensor): The input tensor whose gradient is to be reversed.
        lambd (float, optional): The scaling factor for the reversed gradient. Defaults to 1.0.

    Returns:
        Tensor: The input tensor with its gradient reversed during the backward pass.
    """
    return GradReverse.apply(x, lambd)
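
The source of `GradReverse` is not shown on this page. Assuming it follows the usual gradient reversal layer formulation used in adversarial domain adaptation, the operation is the identity in the forward pass and flips (and scales) the gradient in the backward pass:

```latex
\text{forward: } R_\lambda(x) = x,
\qquad
\text{backward: } \frac{\partial \mathcal{L}}{\partial x}
  = -\lambda \, \frac{\partial \mathcal{L}}{\partial R_\lambda(x)}
```

Under that assumption, the discriminator placed after `grad_reverse` is trained to predict the batch label, while the layers before it receive the negated gradient and are pushed to make the batch label harder to predict.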

hierarchical_classification

Computes the classification loss for a given batch of predictions and ground truth labels.

Parameters:
  • pred (Tensor) –

    The predicted logits for the batch. Shape: (batch_size, n_labels)

  • cl (Tensor) –

    The ground truth labels for the batch. Shape: (batch_size,)

  • labels_hierarchy (Tensor, default: None ) –

    The hierarchical structure of the labels: a binary tensor of shape (number of parents, n_labels). Defaults to None, in which case the function acts as a regular classification loss. See this gist for details on how to compute it: https://gist.github.com/jkobject/5b36bc4807edb440b86644952a49781e

Raises:
  • ValueError

    If labels_hierarchy is None while some ground truth labels are parent labels, i.e. have an index greater than or equal to the number of predicted labels.

Returns:
  • Tensor –

    The computed binary cross entropy loss for the given batch.

Source code in scprint2/model/loss.py
def hierarchical_classification(
    pred: torch.Tensor,
    cl: torch.Tensor,
    labels_hierarchy: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Computes the classification loss for a given batch of predictions and ground truth labels.

    Args:
        pred (Tensor): The predicted logits for the batch. Shape: (batch_size, n_labels)
        cl (Tensor): The ground truth labels for the batch. Shape: (batch_size,)
        labels_hierarchy (Tensor, optional): The hierarchical structure of the labels. Defaults to None.
            A binary tensor of shape (number of parents, n_labels)
            if not given, will act as a regular classification loss
            see gist for more details of how one can compute it
            https://gist.github.com/jkobject/5b36bc4807edb440b86644952a49781e

    Raises:
        ValueError: If the labels_hierarchy is not found while the number of predicted
            labels is smaller than the number of ground truth labels.

    Returns:
        Tensor: The computed binary cross entropy loss for the given batch.
    """
    maxsize = pred.shape[1]
    newcl = torch.zeros(
        (pred.shape[0], maxsize), device=cl.device
    )  # batchsize * n_labels
    # if we don't know the label we set the weight to 0 else to 1
    valid_indices = (cl != -1) & (cl < maxsize)
    valid_cl = cl[valid_indices]
    newcl[valid_indices, valid_cl] = 1

    weight = torch.ones_like(newcl, device=cl.device)
    # if we don't know the label we set the weight to 0 for all labels
    weight[cl == -1, :] = 0
    # if we have non leaf values, we don't know so we don't compute grad and set weight to 0
    # and add labels that won't be counted but so that we can still use them
    if labels_hierarchy is not None and (cl >= maxsize).any():
        is_parent = cl >= maxsize
        subset_parent_weight = weight[is_parent]
        # we set the weight of the leaf elements for pred where we don't know the leaf, to 0
        # i.e. the elements where we will compute the max
        # in cl, parents are values past the maxsize
        # (if there are 10 leaf labels, the label 10, 14, or 15 is a parent at
        # row 0, 4, or 5 in the hierarchy matrix)
        subset_parent_weight[labels_hierarchy[cl[is_parent] - maxsize]] = 0
        weight[is_parent] = subset_parent_weight

        # we set their leaves to 1 (since the weight will be zero, this is not really useful)
        subset_parent_newcl = newcl[is_parent]
        subset_parent_newcl[labels_hierarchy[cl[is_parent] - maxsize]] = 1
        newcl[is_parent] = subset_parent_newcl

        # all parental nodes that have a 1 in the labels_hierarchy matrix are set to 1
        # for each parent label / row in labels_hierarchy matrix, the addnewcl is
        # the max of the newcl values where the parent label is 1
        newcl_expanded = newcl.unsqueeze(-1).expand(-1, -1, labels_hierarchy.shape[0])
        addnewcl = torch.max(newcl_expanded * labels_hierarchy.T, dim=1)[0]

        # for their weight, it is decreasing based on number of children they have
        # it is the same here as for parental labels, we don't want to compute
        # gradients when they are 0 meaning not parents of the true leaf label.
        # for now we weight related to how many labels they contain.
        addweight = addnewcl.clone() / (labels_hierarchy.sum(1) ** 0.5)

        # except if it is the cl label we know about?
        subset_parent_weight = addweight[is_parent]
        subset_parent_weight[:, cl[is_parent] - maxsize] = 1
        addweight[is_parent] = subset_parent_weight

        # we apply the same mask to the pred but now we want to compute
        # logsumexp instead of max since we want to keep the gradients
        # we also set to -inf since it is a more neutral element for logsumexp
        pred_expanded = (
            pred.clone().unsqueeze(-1).expand(-1, -1, labels_hierarchy.shape[0])
        )
        pred_expanded = pred_expanded * labels_hierarchy.T
        pred_expanded[pred_expanded == 0] = torch.finfo(pred.dtype).min
        addpred = torch.logsumexp(pred_expanded, dim=1)

        # we add the new labels to the cl
        newcl = torch.cat([newcl, addnewcl], dim=1)
        weight = torch.cat([weight, addweight], dim=1)
        pred = torch.cat([pred, addpred], dim=1)
    elif labels_hierarchy is None and (cl >= maxsize).any():
        raise ValueError("need to use labels_hierarchy for this usecase")

    myloss = torch.nn.functional.binary_cross_entropy_with_logits(
        pred, target=newcl, weight=weight
    )
    return myloss
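
The `labels_hierarchy` argument is a binary matrix with one row per parent label and one column per leaf label; a ground-truth label `cl >= n_labels` selects row `cl - n_labels`. The sketch below builds such a matrix from a mapping of parent index to leaf indices, using nested lists instead of tensors. The input format is an assumption based on how the checkpoint-loading code on this page fills `mat_labels_hierarchy`; the names `build_hierarchy_matrix` and `leaves_of` are hypothetical.

```python
def build_hierarchy_matrix(hierarchy, n_leaves):
    """Build a (number of parents, n_leaves) boolean matrix.

    `hierarchy` maps a parent label index (>= n_leaves, assumed contiguous)
    to the list of leaf indices that fall under that parent.
    """
    matrix = [[False] * n_leaves for _ in hierarchy]
    for parent, leaves in hierarchy.items():
        for leaf in leaves:
            matrix[parent - n_leaves][leaf] = True
    return matrix


def leaves_of(label, matrix, n_leaves):
    """Leaf indices covered by a parent label, following the cl - maxsize rule."""
    row = matrix[label - n_leaves]
    return [i for i, covered in enumerate(row) if covered]


# 4 leaf labels (0..3) and 2 parents: label 4 covers leaves 0-1, label 5 covers 2-3
matrix = build_hierarchy_matrix({4: [0, 1], 5: [2, 3]}, n_leaves=4)
covered = leaves_of(5, matrix, n_leaves=4)
```

In the loss above, the leaves covered by a cell's parent label get weight 0, so a cell annotated only at the parent level is not penalized for whichever of those leaves the model prefers.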

masked_mae

Compute the masked mean absolute error (MAE) loss between input and target.

Source code in scprint2/model/loss.py
def masked_mae(input: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """
    Compute the masked MAE loss between input and target.
    MAE = mean absolute error
    """
    mask = mask.float()
    loss = F.l1_loss(input * mask, target * mask, reduction="sum")
    return loss / mask.sum()

masked_mse

Compute the masked MSE loss between input and target.

Source code in scprint2/model/loss.py
def masked_mse(input: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """
    Compute the masked MSE loss between input and target.
    """
    mask = mask.float()
    input = torch.log2(input + 1)
    input = (input / torch.sum(input, dim=1, keepdim=True)) * 10000
    target = torch.log2(target + 1)
    target = (target / torch.sum(target, dim=1, keepdim=True)) * 10000
    loss = F.mse_loss(input * mask, target * mask, reduction="sum")
    return loss / mask.sum()
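
Before the squared error is taken, both input and target are passed through log2(x + 1) and each row is rescaled to sum to 10000. The plain-Python sketch below walks through that transformation for one row; `log_normalize` is a hypothetical helper written for illustration.

```python
import math

def log_normalize(row, scale=10000):
    """Apply log2(x + 1), then rescale the row so that it sums to `scale`."""
    logged = [math.log2(x + 1) for x in row]
    total = sum(logged)
    return [scale * v / total for v in logged]

row = [0.0, 1.0, 3.0]            # log2(x + 1) gives [0, 1, 2]
normalized = log_normalize(row)  # one third and two thirds of 10000
```

Because every row is rescaled to the same total, the loss compares log-profiles relative to their row sum. A row that contains only zeros has a zero sum and would produce a division by zero.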

masked_nb

Compute the masked negative binomial loss between input and target.

Source code in scprint2/model/loss.py
def masked_nb(input: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """
    Compute the masked negative binomial loss between input and target.
    """
    mask = mask.float()
    nb = torch.distributions.NegativeBinomial(total_count=target, probs=input)
    masked_log_probs = nb.log_prob(target) * mask
    return -masked_log_probs.sum() / mask.sum()

masked_relative_error

Compute the masked relative error between input and target.

Source code in scprint2/model/loss.py
def masked_relative_error(
    input: Tensor, target: Tensor, mask: torch.LongTensor
) -> Tensor:
    """
    Compute the masked relative error between input and target.
    """
    assert mask.any()
    loss = torch.abs(input[mask] - target[mask]) / (target[mask] + 1e-5)
    return loss.mean()
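
A small numeric illustration of the relative error follows, written with plain Python lists. It assumes a boolean mask; the sample values are made up for the example.

```python
def masked_relative_error(pred, target, mask, eps=1e-5):
    """Mean of |pred - target| / (target + eps) over the selected positions."""
    errors = [abs(p - t) / (t + eps) for p, t, m in zip(pred, target, mask) if m]
    assert errors, "mask must select at least one element"
    return sum(errors) / len(errors)

pred = [9.0, 3.0, 100.0]
target = [10.0, 2.0, 1.0]
mask = [True, True, False]  # the large error in the last position is excluded

# Relative errors at the kept positions are about 0.1 and 0.5.
err = masked_relative_error(pred, target, mask)
```

The small constant in the denominator only guards against targets equal to zero; for such targets the relative error becomes very large.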

mse

Compute the MSE loss between input and target.

Source code in scprint2/model/loss.py
def mse(input: Tensor, target: Tensor, mask=False) -> Tensor:
    """
    Compute the MSE loss between input and target.
    """
    if mask:
        return masked_mse(input, target, (target > 0))
    input = torch.log2(input + 1)
    input = (input / torch.sum(input, dim=1, keepdim=True)) * 10000
    target = torch.log2(target + 1)
    target = (target / torch.sum(target, dim=1, keepdim=True)) * 10000
    return F.mse_loss(input, target, reduction="mean")

nb

Computes the negative binomial (NB) loss.

This function was adapted from scvi-tools.

Parameters:
  • target (Tensor) –

    Ground truth data.

  • mu (Tensor) –

    Means of the negative binomial distribution (must have positive support).

  • theta (Tensor) –

    Inverse dispersion parameter (must have positive support).

  • eps (float, default: 0.0001 ) –

    Numerical stability constant. Defaults to 1e-4.

Returns:
  • Tensor( Tensor ) –

    NB loss value.

Source code in scprint2/model/loss.py
def nb(target: Tensor, mu: Tensor, theta: Tensor, eps=1e-4) -> Tensor:
    """
    Computes the negative binomial (NB) loss.

    This function was adapted from scvi-tools.

    Args:
        target (Tensor): Ground truth data.
        mu (Tensor): Means of the negative binomial distribution (must have positive support).
        theta (Tensor): Inverse dispersion parameter (must have positive support).
        eps (float, optional): Numerical stability constant. Defaults to 1e-4.

    Returns:
        Tensor: NB loss value.
    """
    if theta.ndimension() == 1:
        theta = theta.view(1, theta.size(0))

    log_theta_mu_eps = torch.log(theta + mu + eps)
    res = (
        theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + target * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(target + theta)
        - torch.lgamma(theta)
        - torch.lgamma(target + 1)
    )

    return -res.mean()
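
The expression inside `res` is the log-probability of a negative binomial parameterized by its mean `mu` and inverse dispersion `theta`. The scalar sketch below reproduces it with the standard library and checks that it defines a proper distribution with the expected mean. Setting `eps` to 0 and truncating the sum at 500 are choices made for this illustration.

```python
import math

def nb_log_prob(x, mu, theta, eps=0.0):
    """Log-probability of count x under NB(mean=mu, inverse dispersion=theta)."""
    log_theta_mu = math.log(theta + mu + eps)
    return (
        theta * (math.log(theta + eps) - log_theta_mu)
        + x * (math.log(mu + eps) - log_theta_mu)
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )

mu, theta = 3.0, 2.0
probs = [math.exp(nb_log_prob(x, mu, theta)) for x in range(500)]
total = sum(probs)                               # should be close to 1
mean = sum(x * p for x, p in enumerate(probs))   # should be close to mu
```

The loss returned by `nb` is the negative of this log-probability, averaged over all entries.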

nb_dist

Computes the negative log-likelihood of the observed data under a negative binomial distribution.

Parameters:
  • x (Tensor) –

    Torch Tensor of observed data.

  • mu (Tensor) –

    Torch Tensor of means of the negative binomial distribution (must have positive support).

  • theta (Tensor) –

    Torch Tensor of inverse dispersion parameter (must have positive support).

  • eps (float, default: 0.0001 ) –

    Numerical stability constant. Defaults to 1e-4.

Returns:
  • Tensor( Tensor ) –

    Negative binomial loss value.

Source code in scprint2/model/loss.py
def nb_dist(x: Tensor, mu: Tensor, theta: Tensor, eps=1e-4) -> Tensor:
    """
    Computes the negative log-likelihood of the observed data under a negative binomial distribution.

    Args:
        x (Tensor): Torch Tensor of observed data.
        mu (Tensor): Torch Tensor of means of the negative binomial distribution (must have positive support).
        theta (Tensor): Torch Tensor of inverse dispersion parameter (must have positive support).
        eps (float, optional): Numerical stability constant. Defaults to 1e-4.

    Returns:
        Tensor: Negative binomial loss value.
    """
    loss = -NegativeBinomial(mu=mu, theta=theta).log_prob(x)
    return loss

within_sample

Compute dissimilarity between embeddings within each sample using a combination of cosine and L2 distance

Parameters:
  • cell_embs (Tensor) –

    tensor of shape [batch_size, num_embeddings, embedding_dim]

Source code in scprint2/model/loss.py
def within_sample(cell_embs: Tensor):
    """
    Compute dissimilarity between embeddings within each sample
    using a combination of cosine and L2 distance

    Args:
        cell_embs: tensor of shape [batch_size, num_embeddings, embedding_dim]
    """
    batch_size, num_embeddings, emb_dim = cell_embs.shape

    # Normalize embeddings for cosine similarity
    cell_embs_norm = F.normalize(cell_embs, p=2, dim=-1)

    # Compute pairwise cosine similarities
    cos_sim = torch.bmm(cell_embs_norm, cell_embs_norm.transpose(1, 2))

    # Compute pairwise L2 distances (normalized by embedding dimension)
    l2_dist = torch.cdist(cell_embs, cell_embs, p=2) / np.sqrt(emb_dim)

    # Create mask for pairs (excluding self-similarity)
    mask = 1 - torch.eye(num_embeddings, device=cos_sim.device)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)

    # Combine losses:
    # - High cosine similarity should be penalized
    # - Small L2 distance should be penalized
    cos_loss = (cos_sim * mask).pow(2).mean()
    l2_loss = 1.0 / (l2_dist * mask + 1e-3).mean()

    return 0.5 * cos_loss + 0.5 * l2_loss
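
To see how the two terms behave, the sketch below recomputes the loss for a single sample with plain Python lists. It is a simplified, non-batched mirror of the listing above, and the embeddings are made-up values.

```python
import math

def within_sample_loss(embs):
    """0.5 * cos_loss + 0.5 * l2_loss for one sample (a list of embedding vectors)."""
    n, dim = len(embs), len(embs[0])
    cos_sq, l2 = [], []
    for i in range(n):
        for j in range(n):
            off_diag = 0.0 if i == j else 1.0  # mask that excludes self-pairs
            dot = sum(a * b for a, b in zip(embs[i], embs[j]))
            norm_i = math.sqrt(sum(a * a for a in embs[i]))
            norm_j = math.sqrt(sum(b * b for b in embs[j]))
            dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(embs[i], embs[j])))
            cos_sq.append((dot / (norm_i * norm_j) * off_diag) ** 2)
            l2.append(dist / math.sqrt(dim) * off_diag + 1e-3)
    cos_loss = sum(cos_sq) / len(cos_sq)
    l2_loss = 1.0 / (sum(l2) / len(l2))
    return 0.5 * cos_loss + 0.5 * l2_loss

orthogonal = within_sample_loss([[1.0, 0.0], [0.0, 1.0]])
identical = within_sample_loss([[1.0, 0.0], [1.0, 0.0]])
```

Identical embeddings give a far larger value than orthogonal ones, which is the intended penalty on embeddings of the same cell that collapse onto each other.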

zinb

Computes zero-inflated negative binomial (ZINB) loss.

This function was modified from scvi-tools.

Parameters:
  • target (Tensor) –

    Torch Tensor of ground truth data.

  • mu (Tensor) –

    Torch Tensor of means of the negative binomial (must have positive support).

  • theta (Tensor) –

    Torch Tensor of inverse dispersion parameter (must have positive support).

  • pi (Tensor) –

    Torch Tensor of logits of the dropout parameter (real support).

  • eps (float, default: 0.0001 ) –

    Numerical stability constant. Defaults to 1e-4.

Returns:
  • Tensor( Tensor ) –

    ZINB loss value.

Source code in scprint2/model/loss.py
def zinb(
    target: Tensor,
    mu: Tensor,
    theta: Tensor,
    pi: Tensor,
    eps=1e-4,
    mask=False,
) -> Tensor:
    """
    Computes zero-inflated negative binomial (ZINB) loss.

    This function was modified from scvi-tools.

    Args:
        target (Tensor): Torch Tensor of ground truth data.
        mu (Tensor): Torch Tensor of means of the negative binomial (must have positive support).
        theta (Tensor): Torch Tensor of inverse dispersion parameter (must have positive support).
        pi (Tensor): Torch Tensor of logits of the dropout parameter (real support).
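        eps (float, optional): Numerical stability constant. Defaults to 1e-4.
        mask (bool, optional): If True, average the loss only over entries where target > 0. Defaults to False.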

    Returns:
        Tensor: ZINB loss value.
    """
    #  uses log(sigmoid(x)) = -softplus(-x)
    softplus_pi = F.softplus(-pi)
    # eps to make it positive support and taking the log
    log_theta_mu_eps = torch.log(theta + mu + eps)
    pi_theta_log = -pi + theta * (torch.log(theta + eps) - log_theta_mu_eps)

    case_zero = F.softplus(pi_theta_log) - softplus_pi
    mul_case_zero = torch.mul((target < eps).type(torch.float32), case_zero)

    case_non_zero = (
        -softplus_pi
        + pi_theta_log
        + target * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(target + theta)
        - torch.lgamma(theta)
        - torch.lgamma(target + 1)
    )
    mul_case_non_zero = torch.mul((target > eps).type(torch.float32), case_non_zero)

    res = mul_case_zero + mul_case_non_zero
    # we want to minimize the loss but maximize the log-likelihood
    if mask:
        mask = (target > 0).float()
        res = res * mask
        return -res.sum() / mask.sum()
    return -res.mean()
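
The zero and non-zero cases above can be checked on scalars. The sketch below re-derives both branches with the standard library, using `eps = 0` for clarity, and confirms that the zero case equals log(p + (1 - p) * NB(0)), where p = sigmoid(pi) is the dropout probability. The function names are illustrative.

```python
import math

def softplus(x):
    """log(1 + exp(x)), written to stay finite for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def zinb_log_prob(x, mu, theta, pi_logit):
    """Log-probability of count x under a ZINB with dropout logit pi_logit."""
    log_theta_mu = math.log(theta + mu)
    pi_theta_log = -pi_logit + theta * (math.log(theta) - log_theta_mu)
    if x == 0:
        return softplus(pi_theta_log) - softplus(-pi_logit)
    return (
        -softplus(-pi_logit)
        + pi_theta_log
        + x * (math.log(mu) - log_theta_mu)
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )

mu, theta, pi_logit = 3.0, 2.0, 0.5
total = sum(math.exp(zinb_log_prob(x, mu, theta, pi_logit)) for x in range(500))
dropout_prob = 1.0 / (1.0 + math.exp(-pi_logit))
nb_zero = (theta / (theta + mu)) ** theta  # NB probability of observing 0
p_zero = math.exp(zinb_log_prob(0, mu, theta, pi_logit))
```

The non-zero branch works because -softplus(-pi) - pi equals log(1 - sigmoid(pi)), so it adds the log-probability of no dropout to the NB log-probability.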

utils

scprint2.model.utils

Classes:

Name Description
Attention
WeightedMasker

Functions:

Name Description
downsample_profile

This function downsamples the expression profile of a given single cell RNA matrix.

make_adata

This function creates an AnnData object from the given input parameters.

simple_masker

Randomly mask a batch of data.

test

Test the given model on the full set of benchmarks and save the results to JSON files.

zinb_sample

This function generates a sample from a Zero-Inflated Negative Binomial (ZINB) distribution.

Attention

Initialize the Attention class.

Parameters:
  • gene_dim (int) –

    The dimension of the gene.

  • additional_tokens (int, default: 0 ) –

    The number of additional tokens to add.

  • precomp_attn (bool, default: False ) –

    Whether to compute attention or it is precomputed

  • apply_softmax (bool, default: True ) –

    Whether to apply softmax to the attention.

  • sum_heads (bool, default: True ) –

    Whether to sum the heads.

Methods:

Name Description
add_attn

Aggregate the attention or data based on the precomp_attn flag.

add_qk

Add data to the internal storage.

get

Get the aggregated attention or data.

Source code in scprint2/model/utils.py
def __init__(
    self,
    gene_dim: int,
    precomp_attn: bool = False,
    apply_softmax: bool = True,
    sum_heads: bool = True,
    additional_tokens: int = 0,
):
    """
    Initialize the Attention class.

    Args:
        gene_dim (int): The dimension of the gene.
        additional_tokens (int): The number of additional tokens to add.
        precomp_attn (bool): Whether to compute attention or it is precomputed
        apply_softmax (bool): Whether to apply softmax to the attention.
        sum_heads (bool): Whether to sum the heads.
    """
    self.data: Optional[Tensor] = None
    self.gene_dim: int = gene_dim
    self.additional_tokens: int = additional_tokens
    self.div: Optional[Tensor] = None
    self.apply_softmax: bool = apply_softmax
    self.sum_heads: bool = sum_heads
    self.precomp_attn: bool = precomp_attn
    self.speciesloc: int = 0

add_attn

Aggregate the attention or data based on the precomp_attn flag.

Parameters:
  • x (List[Tensor]) –

    List of tensors to aggregate, each of size (batch, seq_len, 2, heads, emb).

  • pos (Tensor) –

    Position tensor.

Source code in scprint2/model/utils.py
def add_attn(
    self, x: List[Tensor], pos: Tensor, expr: Optional[Tensor] = None
) -> None:
    """
    Aggregate the attention or data based on the precomp_attn flag.

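    Args:
        x (List[Tensor]): List of tensors to aggregate, each of size (batch, seq_len, 2, heads, emb).
        pos (Tensor): Position tensor.
        expr (Tensor, optional): Expression values. When given and apply_softmax is True, attention between gene tokens is kept only where both genes have expression > 0.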
    """
    if self.data is None:
        self.data = torch.zeros(
            [
                self.gene_dim + self.additional_tokens,
                self.gene_dim + self.additional_tokens,
            ],
            device=pos.device,
            dtype=torch.float32,
        )
        self.div = torch.zeros(1, device=pos.device, dtype=torch.float32)
    for i, elem in enumerate(x):
        if self.apply_softmax:
            attn = torch.nn.functional.softmax(
                elem[:, :, 0, :, :].permute(0, 2, 1, 3)
                @ elem[:, :, 1, :, :].permute(0, 2, 3, 1),
                dim=-1,
            )
            if expr is not None:
                attn[:, :, self.additional_tokens :, self.additional_tokens :] = (
                    attn[:, :, self.additional_tokens :, self.additional_tokens :]
                    * (expr > 0).float().unsqueeze(1).unsqueeze(-1)
                    * (expr > 0).float().unsqueeze(1).unsqueeze(2)
                )
            self.data += attn.sum(0).mean(0)
        else:
            self.data[:, :] += (
                (
                    elem[:, :, 0, :, :].permute(0, 2, 1, 3)
                    @ elem[:, :, 1, :, :].permute(0, 2, 3, 1)
                )
                .sum(0)
                .mean(0)
            )
        self.div += 1
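
In the softmax branch above, index 0 and index 1 of the third dimension are multiplied as queries and keys, and a softmax is taken over the last axis; the listing applies no scaling factor before the softmax. The plain-Python sketch below shows the same computation for a single head. The helper names and the tiny inputs are illustrative assumptions.

```python
import math

def softmax(row):
    """Numerically stable softmax of a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def attention_matrix(queries, keys):
    """Row-wise softmax of the query-key dot products (no scaling factor)."""
    scores = [
        [sum(q_d * k_d for q_d, k_d in zip(q, k)) for k in keys]
        for q in queries
    ]
    return [softmax(row) for row in scores]

# Hypothetical 3 tokens with 2-dimensional queries and keys for one head.
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 2.0], [0.5, 0.5]]
attn = attention_matrix(queries, keys)  # 3 x 3, each row sums to 1
```

In `add_attn`, one such matrix is produced per cell and per head; the matrices are summed over the batch and averaged over heads before being accumulated into `self.data`.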

add_qk

Add data to the internal storage.

Parameters:
  • x (List[Tensor]) –

    List of tensors to add.

  • pos (Tensor) –

    Position tensor.

Source code in scprint2/model/utils.py
def add_qk(
    self, x: List[Tensor], pos: Tensor, expr: Optional[Tensor] = None
) -> None:
    """
    Add data to the internal storage.

    Args:
        x (List[Tensor]): List of tensors to add.
        pos (Tensor): Position tensor.
    """
    # this is a debugger line
    if self.data is None:
        self.data = torch.zeros(
            [len(x), self.gene_dim + self.additional_tokens] + list(x[0].shape[2:]),
            device=pos.device,
        )
        self.div = torch.zeros(
            self.gene_dim + self.additional_tokens, device=pos.device
        )
    for i in range(x[0].shape[0]):  # batch size
        loc = torch.cat(
            [
                torch.arange(self.additional_tokens, device=pos.device),
                pos[i] + self.additional_tokens - self.speciesloc,
            ]
        ).int()
        for j in range(len(x)):  # number of layers * heads
            self.data[j, loc, :, :, :] += x[j][i]
        self.div[loc] += 1

get

Get the aggregated attention or data.

Returns:
  • Optional[ndarray]

    Optional[np.ndarray]: The aggregated attention or data.

Source code in scprint2/model/utils.py
def get(self) -> Optional[np.ndarray]:
    """
    Get the aggregated attention or data.

    Returns:
        Optional[np.ndarray]: The aggregated attention or data.
    """
    if not self.precomp_attn:
        if self.data is None:
            return None
        # shape is (layers, genes, qkv, heads, emb)
        return self.data / self.div.view(1, self.div.shape[0], 1, 1, 1)
    else:
        if self.data is None:
            return None
        self.data.div_(self.div)
        return self.data

WeightedMasker

Randomly mask a batch of data.

Parameters:
  • genes (List[str]) –

    The list of genes the model might see.

  • TFs (List[str], default: fileToList(FILEDIR + '/../../data/main/TFs.txt') ) –

    The list of TFs the model can drop.

  • tf_weight (float, default: 10 ) –

    How likely it is to drop a non TF compared to a TF.

Source code in scprint2/model/utils.py
def __init__(
    self,
    genes: List[str],
    TFs: List[str] = utils.fileToList(FILEDIR + "/../../data/main/TFs.txt"),
    tf_weight: float = 10,
):
    """
    Randomly mask a batch of data.

    Args:
        genes (List[str]): The list of genes the model might see.
        TFs (List[str]): The list of TFs the model can drop.
        tf_weight (float): How likely it is to drop a non TF compared to a TF.
    """
    TFs = set(TFs)
    self.weights = torch.tensor([tf_weight if gene in TFs else 1 for gene in genes])
    self.max_to_drop = (self.weights == tf_weight).sum()
    self.tf_weight = tf_weight
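
The constructor turns the gene list into a per-gene weight vector. The following plain-Python sketch shows that construction on a toy vocabulary; the gene symbols and the helper `build_weights` are illustrative and do not come from the packaged TFs.txt file.

```python
def build_weights(genes, tfs, tf_weight=10):
    """Weight tf_weight for genes listed as TFs, 1 for all other genes."""
    tf_set = set(tfs)
    return [tf_weight if gene in tf_set else 1 for gene in genes]

genes = ["GATA1", "ACTB", "SPI1", "GAPDH"]  # illustrative vocabulary
tfs = ["GATA1", "SPI1", "MYC"]              # MYC is not in `genes`, so it has no effect

weights = build_weights(genes, tfs)
# Same rule as in the constructor: count the entries equal to tf_weight.
max_to_drop = sum(1 for w in weights if w == 10)
```

Because `max_to_drop` is computed by comparing weights with `tf_weight`, a `tf_weight` of 1 would make every gene count as a TF.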

downsample_profile

This function downsamples the expression profile of a given single cell RNA matrix.

The noise is applied based on the renoise parameter, the total counts of the matrix, and the number of genes. The function first calculates the noise threshold (scaler) based on the renoise parameter. It then generates an initial matrix count by applying a Poisson distribution to a random tensor scaled by the total counts and the number of genes. The function then models the sampling zeros by applying a Poisson distribution to a random tensor scaled by the noise threshold, the total counts, and the number of genes. The function also models the technical zeros by generating a random tensor and comparing it to the noise threshold. The final matrix count is calculated by subtracting the sampling zeros from the initial matrix count and multiplying by the technical zeros. The function ensures that the final matrix count is not less than zero by taking the maximum of the final matrix count and a tensor of zeros. The function returns the final matrix count.

Parameters:
  • mat (Tensor) –

    The input matrix.

  • dropout (float) –

    The renoise parameter.

Returns:
  • Tensor

    torch.Tensor: The matrix count after applying noise.

Source code in scprint2/model/utils.py
def downsample_profile(mat: Tensor, dropout: float, method="new", randsamp=False) -> Tensor:
    """
    This function downsamples the expression profile of a given single cell RNA matrix.

    The noise is applied based on the renoise parameter,
    the total counts of the matrix, and the number of genes. The function first calculates the noise
    threshold (scaler) based on the renoise parameter. It then generates an initial matrix count by
    applying a Poisson distribution to a random tensor scaled by the total counts and the number of genes.
    The function then models the sampling zeros by applying a Poisson distribution to a random tensor
    scaled by the noise threshold, the total counts, and the number of genes. The function also models
    the technical zeros by generating a random tensor and comparing it to the noise threshold. The final
    matrix count is calculated by subtracting the sampling zeros from the initial matrix count and
    multiplying by the technical zeros. The function ensures that the final matrix count is not less
    than zero by taking the maximum of the final matrix count and a tensor of zeros. The function
    returns the final matrix count.

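    Args:
        mat (torch.Tensor): The input matrix.
        dropout (float): The renoise parameter.
        method (str, optional): Downsampling scheme, one of "old", "jules", or "new". Defaults to "new".
        randsamp (bool, optional): If True, draw a separate dropout rate for each row of mat, uniformly in [0, dropout). Defaults to False.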

    Returns:
        torch.Tensor: The matrix count after applying noise.
    """
    # Randomly drop on average N counts to each element of expression using a heavy tail Gaussian distribution
    # here we try to get the scale of the distribution so as to remove the right number of counts from each gene
    # https://genomebiology.biomedcentral.com/articles/10.1186/s13059-022-02601-5#:~:text=Zero%20measurements%20in%20scRNA%2Dseq,generation%20of%20scRNA%2Dseq%20data.
    if randsamp:
        dropout = torch.rand(mat.shape[0], device=mat.device) * dropout
        dropout = (
            dropout.unsqueeze(1)
            if len(mat.shape) == 2
            else dropout.unsqueeze(1).unsqueeze(1)
        )
    if method == "old":
        totcounts = mat.sum(-1)
        ngenes = mat.shape[-1]
        tnoise = 1 - (1 - dropout) ** (1 / 2)
        # we model the sampling zeros (dropping 30% of the reads)
        res = torch.poisson(
            torch.rand(mat.shape, device=mat.device)
            * ((tnoise * totcounts.unsqueeze(-1)) / (0.5 * ngenes))
        ).int()
        # we model the technical zeros (dropping 50% of the genes)
        drop = (torch.rand(mat.shape, device=mat.device) > tnoise).int()

        mat = (mat - res) * drop
        return torch.maximum(
            mat,
            torch.zeros(
                (1, 1) if len(mat.shape) == 2 else (1, 1, 1),
                device=mat.device,
                dtype=torch.int,
            ),
        )
    elif method == "jules":
        scaler = (1 - dropout) ** (1 / 2)
        notdrop = (
            torch.rand(
                mat.shape,
                device=mat.device,
            )
            < scaler
        ).int()
        notdrop[mat == 0] = 0
        # apply the dropout after the poisson, right?
        return notdrop * torch.poisson(mat * scaler)
    elif method == "new":
        dropout = dropout * 1.1
        # we model the sampling zeros (dropping 30% of the reads)
        res = torch.poisson((mat * (dropout / 2))).int()
        # we model the technical zeros (dropping 50% of the genes)
        notdrop = (torch.rand(mat.shape, device=mat.device) >= (dropout / 2)).int()
        mat = (mat - res) * notdrop
        return torch.maximum(
            mat,
            torch.zeros(
                (1, 1) if len(mat.shape) == 2 else (1, 1, 1),
                device=mat.device,
                dtype=torch.int,
            ),
        )
    else:
        raise ValueError(f"method {method} not recognized")
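
To get a feel for how much signal the "new" and "jules" branches remove, the sketch below computes the expected fraction of a non-zero count that survives. This is a back-of-the-envelope approximation written for this page: it uses only the means of the Poisson and Bernoulli draws, and for "new" it ignores the final clamp at zero. The helper name is hypothetical.

```python
def expected_kept_fraction(dropout, method="new"):
    """Approximate E[output] / input for a single non-zero count."""
    if method == "new":
        half = dropout * 1.1 / 2
        # Poisson(mat * half) counts are subtracted on average, then the
        # entry survives with probability 1 - half.
        return (1 - half) * (1 - half)
    if method == "jules":
        scaler = (1 - dropout) ** 0.5
        # Poisson(mat * scaler) is drawn, then the entry survives with
        # probability scaler.
        return scaler * scaler
    raise ValueError(f"method {method} not recognized")

kept_new = expected_kept_fraction(0.6)             # (1 - 0.33) ** 2
kept_jules = expected_kept_fraction(0.6, "jules")  # 1 - 0.6
```

Under these assumptions the "jules" branch keeps a fraction 1 - dropout of the counts in expectation, while the "new" branch keeps roughly 45% for dropout = 0.6.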

make_adata

This function creates an AnnData object from the given input parameters.

Parameters:
  • genes (list) –

    List of genes that will be used as variable names.

  • embs (Tensor | Dict) –

    Embeddings of the cells. The shape of the tensor is (n_cells, n_features). If there are several embeddings, this is a dict of name -> tensor.

  • pos (Tensor, default: None ) –

    Positions of the cells. The shape of the tensor is (n_cells,).

  • expr_pred (List[Tensor], default: None ) –

    Predicted expression. Each tensor has shape (n_cells, n_genes). The first is mu, the second theta, and the third pi, if present.

  • classes (list, default: None ) –

    List of classes, the order should be the same as in the pred and gtclass tensors.

  • pred (Tensor, default: None ) –

    Predicted labels. The shape of the tensor is (n_cells, n_classes). Default is None.

  • label_decoders (dict, default: None ) –

    Dictionary to map class codes to class names. Default is None.

  • labels_hierarchy (dict, default: None ) –

    Dictionary representing the hierarchy of labels. Default is None, which is treated as an empty hierarchy. See the model for the definition.

  • gtclass (Tensor, default: None ) –

    Ground truth class values. Default is None.

  • doplot (bool, default: True ) –

    Whether to generate plots. Default is True.

Returns:
  • AnnData

    anndata.AnnData: The created AnnData object.

Source code in scprint2/model/utils.py
def make_adata(
    genes: List[str],
    embs: Union[Tensor, Dict[str, Tensor]],
    pos: Tensor = None,
    expr_pred: List[Tensor] = None,
    classes: List[str] = None,
    pred: Tensor = None,
    label_decoders: Optional[Dict] = None,
    labels_hierarchy: Optional[Dict] = None,
    gtclass: Optional[Tensor] = None,
    doplot: bool = True,
) -> AnnData:
    """
    This function creates an AnnData object from the given input parameters.

    Args:
        genes (list): List of genes that will be used as variable names.
        embs (torch.Tensor|Dict): Embeddings of the cells. The shape of the tensor is (n_cells, n_features).
            if multiple, it is a dict of name -> tensor
        pos (torch.Tensor): Positions of the cells. The shape of the tensor is (n_cells,).
        expr_pred (List[torch.Tensor]): Predicted expression. Each tensor has shape (n_cells, n_genes).
            The first is mu, the second theta, and the third pi, if present.
        classes (list): List of classes, the order should be the same as in the pred and gtclass tensors.
        pred (torch.Tensor, optional): Predicted labels. The shape of the tensor is (n_cells, n_classes). Default is None.
        label_decoders (dict, optional): Dictionary to map class codes to class names. Default is None.
        labels_hierarchy (dict, optional): Dictionary representing the hierarchy of labels. Default is None, which is treated as an empty hierarchy. See the model for the definition.
        gtclass (torch.Tensor, optional): Ground truth class values. Default is None.
        doplot (bool, optional): Whether to generate plots. Default is True.

    Returns:
        anndata.AnnData: The created AnnData object.
    """
    print("logging the anndata")
    colname = ["pred_" + i for i in classes]
    if pred is not None:
        obs = np.array(pred.to(device="cpu", dtype=torch.int32))
        # label_decoders is not cls_decoders: the former is a dict mapping class codes (ints)
        # to class names, the latter is the module that predicts the class
        if label_decoders is not None:
            obs = np.array(
                [
                    [label_decoders[classes[i]][n] for n in name]
                    for i, name in enumerate(obs.T)
                ]
            ).T
        if gtclass is not None:
            colname += classes
            nobs = np.array(gtclass.to(device="cpu", dtype=torch.int32))
            if label_decoders is not None:
                nobs = np.array(
                    [
                        [label_decoders[classes[i]][n] for n in name]
                        for i, name in enumerate(nobs.T)
                    ]
                ).T
            obs = np.hstack([obs, nobs])

    n_cells = embs[list(embs.keys())[0]].shape[0]
    layers = None
    size = len(genes)
    if pos is not None:
        minval = pos.min()
        maxval = pos.max()
        genes = genes[minval : maxval + 1]
        size = len(genes)
        pos = pos - minval
        mu_array = np.zeros((n_cells, size), dtype=np.float32)
        pos = pos.cpu().numpy()
        # Create empty array with same shape as expr_pred[0]
        # Fill array with values from expr_pred[0]
        for idx in range(n_cells):
            mu_array[idx, pos[idx]] = expr_pred[0][idx].cpu().numpy() + 1
        exist = mu_array.sum(0) != 0
        mu_array = mu_array[:, exist]
        mu_array[mu_array == 1] = 0
        layers = {
            "scprint_mu": mu_array,
            #  "used_scprint": csr_matrix(pos),
        }
        if len(expr_pred) > 1:
            theta_array = np.zeros((n_cells, size), dtype=np.float32)
            # Fill array with values from expr_pred[1]
            for idx in range(n_cells):
                theta_array[idx, pos[idx]] = expr_pred[1][idx].cpu().numpy()
            layers["scprint_theta"] = theta_array[:, exist]

            pi_array = np.zeros((n_cells, size), dtype=np.float32)
            # Fill array with values from expr_pred[2]
            for idx in range(n_cells):
                pi_array[idx, pos[idx]] = expr_pred[2][idx].cpu().numpy()
            layers["scprint_pi"] = pi_array[:, exist]
        genes = [n for i, n in enumerate(genes) if exist[i] > 0]
    else:
        genes = []
    adata = AnnData(
        X=csr_matrix((n_cells, len(genes))),
        layers=layers,
        obs=(
            pd.DataFrame(
                obs,
                columns=colname,
            )
            if pred is not None
            else None
        ),
        var=pd.DataFrame(index=genes),
    )

    for k, v in embs.items():
        adata.obsm["scprint_emb_" + k] = v.cpu().numpy()
        rep = "scprint_emb_" + k
    del embs
    accuracy = {}
    if labels_hierarchy is None:
        labels_hierarchy = {}
    if pred is not None:
        for clss in classes:
            if gtclass is not None:
                tr = translate(set(adata.obs[clss]), clss)
                if tr is not None:
                    adata.obs["conv_" + clss] = adata.obs[clss].replace(tr)
            tr = translate(set(adata.obs["pred_" + clss]), clss)
            if tr is not None:
                adata.obs["conv_pred_" + clss] = adata.obs["pred_" + clss].replace(tr)
            res = []
            if label_decoders is not None and gtclass is not None:
                class_topred = label_decoders[clss].values()
                if clss in labels_hierarchy:
                    cur_labels_hierarchy = {
                        label_decoders[clss][k]: [label_decoders[clss][i] for i in v]
                        for k, v in labels_hierarchy[clss].items()
                    }
                else:
                    cur_labels_hierarchy = {}
                for pred, true in adata.obs[["pred_" + clss, clss]].values:
                    if pred == true:
                        res.append(True)
                        continue
                    if len(labels_hierarchy) > 0:
                        if true in cur_labels_hierarchy:
                            res.append(pred in cur_labels_hierarchy[true])
                        elif true not in class_topred:
                            raise ValueError(
                                f"true label {true} not in available classes"
                            )
                        elif true != "unknown":
                            res.append(False)
                    elif true not in class_topred:
                        raise ValueError(f"true label {true} not in available classes")
                    elif true != "unknown":
                        res.append(False)
                    else:
                        pass
                accuracy["pred_" + clss] = sum(res) / len(res) if len(res) > 0 else 0
        adata.obs = adata.obs.astype("category")
    print(adata)
    if doplot and adata.shape[0] > 100:
        sc.pp.neighbors(adata, use_rep=rep)
        sc.tl.umap(adata)
        sc.tl.leiden(adata, key_added="sprint_leiden")
        if gtclass is not None:
            color = [
                i
                for pair in zip(
                    [
                        "conv_" + i if "conv_" + i in adata.obs.columns else i
                        for i in classes
                    ],
                    [
                        (
                            "conv_pred_" + i
                            if "conv_pred_" + i in adata.obs.columns
                            else "pred_" + i
                        )
                        for i in classes
                    ],
                )
                for i in pair
            ]
            fig, axs = plt.subplots(
                int(len(color) / 2), 2, figsize=(24, len(color) * 4)
            )
            plt.subplots_adjust(wspace=1)
            if len(color) > 2:
                for i, col in enumerate(color):
                    sc.pl.umap(
                        adata,
                        color=col,
                        ax=axs[i // 2, i % 2],
                        show=False,
                    )
                    acc = ""
                    if "pred_" in col and col.split("conv_")[-1] in accuracy:
                        acc = " (accuracy: {:.2f})".format(
                            accuracy[col.split("conv_")[-1]]
                        )
                    axs[i // 2, i % 2].set_title(col + " UMAP" + acc)
                    if "cell_type" in col:
                        axs[i // 2, i % 2].legend(fontsize="x-small")
                    axs[i // 2, i % 2].set_xlabel("UMAP1")
                    axs[i // 2, i % 2].set_ylabel("UMAP2")
            else:
                for i, col in enumerate(color):
                    sc.pl.umap(
                        adata,
                        color=col,
                        ax=axs[i % 2],
                        show=False,
                    )
                    acc = ""
                    if "pred_" in col and col.split("conv_")[-1] in accuracy:
                        acc = " (accuracy: {:.2f})".format(
                            accuracy[col.split("conv_")[-1]]
                        )
                    axs[i % 2].set_title(col + " UMAP" + acc)
                    if "cell_type" in col:
                        axs[i % 2].legend(fontsize="x-small")
                    axs[i % 2].set_xlabel("UMAP1")
                    axs[i % 2].set_ylabel("UMAP2")
        else:
            color = [
                (
                    "conv_pred_" + i
                    if "conv_pred_" + i in adata.obs.columns
                    else "pred_" + i
                )
                for i in classes
            ]
            if len(color) > 1:
                fig, axs = plt.subplots(len(color), 1, figsize=(16, len(color) * 8))
                for i, col in enumerate(color):
                    sc.pl.umap(
                        adata,
                        color=col,
                        ax=axs[i],
                        show=False,
                    )
                    acc = ""
                    if "pred_" in col and col.split("conv_")[-1] in accuracy:
                        acc = " (accuracy: {:.2f})".format(
                            accuracy[col.split("conv_")[-1]]
                        )
                    axs[i].set_title(col + " UMAP of " + rep + " embedding " + acc)
                    axs[i].set_xlabel("UMAP1")
                    axs[i].set_ylabel("UMAP2")
            else:
                fig = sc.pl.umap(
                    adata,
                    color=color,
                    show=False,
                    return_fig=True,
                )
        plt.show()
    else:
        fig = None
    return adata, fig
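
The plotting branch above builds its list of columns by interleaving each ground-truth column with the matching prediction column, preferring the "conv_" variant when it exists. A minimal sketch of that selection logic in plain Python follows; the column names and the helper `pick` are hypothetical and only serve the illustration.

```python
# Hypothetical class names and obs columns, made up for this illustration.
classes = ["cell_type_ontology_term_id", "disease_ontology_term_id"]
obs_columns = {
    "cell_type_ontology_term_id",
    "conv_cell_type_ontology_term_id",
    "pred_cell_type_ontology_term_id",
    "conv_pred_cell_type_ontology_term_id",
    "disease_ontology_term_id",
    "pred_disease_ontology_term_id",
}


def pick(preferred: str, fallback: str, columns: set) -> str:
    """Return the preferred column name if it exists, else the fallback."""
    return preferred if preferred in columns else fallback


truth = [pick("conv_" + c, c, obs_columns) for c in classes]
pred = [pick("conv_pred_" + c, "pred_" + c, obs_columns) for c in classes]

# Interleave: truth_0, pred_0, truth_1, pred_1, ...
color = [name for pair in zip(truth, pred) for name in pair]
print(color)
```

Each consecutive pair of entries then fills one row of the two-column subplot grid, with the ground truth on the left and the prediction on the right.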

simple_masker

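Generate a random boolean mask for a batch of data. Each entry is independently set to True with probability mask_ratio.

Parameters:
  • shape (List[int]) –

    The shape of the data, and therefore of the returned mask.

  • mask_ratio (float, default: 0.15 ) –

    The expected fraction of entries (genes) to mask. Defaults to 0.15.

Returns:
  • Tensor

    A boolean tensor of the given shape in which True marks a masked position.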

Source code in scprint2/model/utils.py
def simple_masker(
    shape: List[int],
    mask_ratio: float = 0.15,
) -> torch.Tensor:
    """
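    Generate a random boolean mask for a batch of data.

    Each entry is independently set to True with probability mask_ratio.

    Args:
        shape (List[int]): The shape of the data, and therefore of the returned mask.
        mask_ratio (float): The expected fraction of entries (genes) to mask. Defaults to 0.15.

    Returns:
        torch.Tensor: A boolean tensor of the given shape in which True marks a masked position.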
    """
    return torch.rand(shape) < mask_ratio
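
As a rough illustration of how such a mask might be used, the sketch below draws a mask with the same expression as simple_masker and applies it to a toy expression matrix. The tensor `expression` and the choice of zero as the fill value are assumptions made for the example, not something the source prescribes.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy batch: 4 cells x 1000 genes of non-negative values.
expression = torch.rand(4, 1000) * 10

# Same expression as in simple_masker: True with probability mask_ratio.
mask_ratio = 0.15
mask = torch.rand(expression.shape) < mask_ratio

# Hide the selected entries (zero is only an illustrative placeholder value).
masked_expression = expression.masked_fill(mask, 0.0)

observed_ratio = mask.float().mean().item()
print(f"masked fraction: {observed_ratio:.3f}")
```

The observed fraction fluctuates around mask_ratio because every entry is drawn independently; the number of masked genes per cell is therefore not fixed.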

test

Test the given model on the full set of benchmarks and save the results to JSON files.

Parameters:
  • model (Module) –

    The model to be tested.

  • filedir (str) –

    The directory where the data files are located.

  • do_class (bool, default: True ) –

    Whether to perform classification. Defaults to True.

  • maxcells_grn (int, default: 1024 ) –

    Maximum cells for GRN analysis. Defaults to 1024.

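Returns:
  • tuple

    A (metrics, tot) pair: metrics is a flat dict mapping metric names to float values, and tot maps each benchmark name to its full result.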

Source code in scprint2/model/utils.py
def test(
    model: torch.nn.Module,
    filedir: str,
    do_class: bool = True,
    maxcells_grn: int = 1024,
) -> tuple:
    """
    Test the given model on the full set of benchmarks and save the results to JSON files.

    Args:
        model (torch.nn.Module): The model to be tested.
        filedir (str): The directory where the data files are located.
        do_class (bool): Whether to perform classification. Defaults to True.
        maxcells_grn (int): Maximum cells for GRN analysis. Defaults to 1024.

    Returns:
        tuple: A (metrics, tot) pair. metrics is a flat dict mapping metric names
            to float values; tot maps each benchmark name to its full result.
    """
    metrics = {}
    tot = {}
    for dataset, path in EMBEDDING_DATASETS.items():
        res = embbed_task.default_benchmark(
            model,
            dataset=path,
            do_class=do_class,
            coarse=False,
        )
        tot["embed_" + dataset] = res
        metrics.update(
            {
                "emb_" + dataset + "/scib": float(res["scib"]["Total"]),
                "emb_" + dataset + "/scib_bio": float(res["scib"]["Bio conservation"]),
                "emb_"
                + dataset
                + "/scib_batch": float(res["scib"]["Batch correction"]),
                "emb_"
                + dataset
                + "/ct_class": float(
                    res["classif"].get("cell_type_ontology_term_id", {}).get("macro", 0)
                    if do_class
                    else 0
                ),
                "emb_"
                + dataset
                + "/ct_class_macro": float(
                    res["classif"].get("cell_type_ontology_term_id", {}).get("macro", 0)
                    if do_class
                    else 0
                ),
            }
        )
        print(metrics)
        gc.collect()
    for dataset, filepath in DENOISE_DATASETS.items():
        res = denoise_task.default_benchmark(model, dataset=filepath)
        tot["denoise_" + dataset] = res
        metrics.update(
            {
                "denoise_"
                + dataset
                + "/reco2full_vs_noisy2full": float(
                    res["reco2full"] - res["noisy2full"]
                ),
            }
        )
        print(metrics)
        gc.collect()
    res = grn_task.default_benchmark(
        model,
        "gwps",
        batch_size=32 if model.d_model <= 512 else 8,
        maxcells=maxcells_grn,
    )
    tot["grn_gwps"] = res
    metrics.update(
        {
            "grn_gwps/auprc_self": float(res["self"]["auprc"]),
            "grn_gwps/epr_self": float(res["self"]["epr"]),
            "grn_gwps/auprc_omni": float(res["omni"]["auprc"]),
            "grn_gwps/epr_omni": float(res["omni"]["epr"]),
            "grn_gwps/auprc": float(res["mean"]["auprc"]),
            "grn_gwps/epr": float(res["mean"]["epr"]),
        }
    )
    print(metrics)
    gc.collect()
    for dataset, filepath in {
        "old_kidney": "https://datasets.cellxgene.cziscience.com/ede85b09-454b-4374-bf60-5f675e989b64.h5ad",
        # "kidney": "https://datasets.cellxgene.cziscience.com/01bc7039-961f-4c24-b407-d535a2a7ba2c.h5ad",
        "lung_smart": "https://datasets.cellxgene.cziscience.com/6ebba0e0-a159-406f-8095-451115673a2c.h5ad",
        # filedir + "/../../data/yBCKp6HmXuHa0cZptMo7.h5ad",
    }.items():
        res = grn_task.default_benchmark(
            model,
            filepath,
            # kidney dataset (2.87, 1.27) (0.00147, 0.00133)
            batch_size=32 if model.d_model <= 512 else 8,
            maxcells=maxcells_grn,
            maxgenes=4000,
        )
        tot["grn_omni_" + dataset] = res
        metrics.update(
            {
                "grn_omni_"
                + dataset
                + "/auprc_class": float(
                    np.mean([i["auprc"] for k, i in res.items() if "_class" in k])
                ),
                "grn_omni_"
                + dataset
                + "/or_class": float(
                    np.mean([i["odd_ratio"] for k, i in res.items() if "_class" in k])
                ),
                "grn_omni_"
                + dataset
                + "/tf_enr_class": float(
                    np.sum(
                        [
                            i.get("TF_enr", False)
                            for k, i in res.items()
                            if "_class" in k
                        ]
                    )
                ),
                "grn_omni_"
                + dataset
                + "/tf_targ_enr_class": float(
                    np.mean(
                        [
                            i["significant_enriched_TFtargets"]
                            for k, i in res.items()
                            if "_class" in k
                        ]
                    )
                ),
                "grn_omni_"
                + dataset
                + "/auprc": float(
                    np.mean([i["auprc"] for k, i in res.items() if "_mean" in k])
                ),
                "grn_omni_"
                + dataset
                + "/epr": float(
                    np.mean([i["epr"] for k, i in res.items() if "_mean" in k])
                ),
                "grn_omni_"
                + dataset
                + "/or": float(
                    np.mean([i["odd_ratio"] for k, i in res.items() if "_mean" in k])
                ),
                "grn_omni_"
                + dataset
                + "/tf_enr": float(
                    np.sum(
                        [i.get("TF_enr", False) for k, i in res.items() if "_mean" in k]
                    )
                ),
                "grn_omni_"
                + dataset
                + "/tf_targ_enr": float(
                    np.mean(
                        [
                            i["significant_enriched_TFtargets"]
                            for k, i in res.items()
                            if "_mean" in k
                        ]
                    )
                ),
                # 'grn_omni/ct': res['classif']['cell_type_ontology_term_id']['accuracy'],
            }
        )
        print(metrics)
        gc.collect()
    return metrics, tot
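
The GRN summaries in the function above repeatedly average one field over the result entries whose key contains a given suffix ("_class" or "_mean"). A small sketch of that pattern, using only the standard library, is given below. The `res` dictionary, its keys and values, the dataset name "demo", and the helper `summarize` are all invented for illustration.

```python
from statistics import mean

# Hypothetical benchmark output; keys and numbers are made up.
res = {
    "celltypeA_class": {"auprc": 0.02, "epr": 1.5},
    "celltypeB_class": {"auprc": 0.04, "epr": 2.5},
    "celltypeA_mean": {"auprc": 0.01, "epr": 1.0},
}


def summarize(results: dict, suffix: str, field: str) -> float:
    """Average `field` over the entries whose key contains `suffix`."""
    return float(mean(v[field] for k, v in results.items() if suffix in k))


metrics = {
    "grn_omni_demo/auprc_class": summarize(res, "_class", "auprc"),
    "grn_omni_demo/epr_class": summarize(res, "_class", "epr"),
    "grn_omni_demo/auprc": summarize(res, "_mean", "auprc"),
}
print(metrics)
```

Keeping the metric names flat, with a "/" separating the benchmark from the measure, makes the dictionary easy to pass to a logger that groups metrics by prefix.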

zinb_sample

Generate a sample from a Zero-Inflated Negative Binomial (ZINB) distribution.

Parameters:
  • mu (Tensor) –

    The mean of the Negative Binomial (NB) distribution.

  • theta (Tensor) –

    The dispersion parameter of the NB distribution.

  • zi_probs (Tensor) –

    The zero-inflation probabilities.

  • sample_shape (Size, default: Size([]) ) –

    The output shape. Defaults to torch.Size([]).

Returns:
  • Tensor

    torch.Tensor: A sample from the ZINB distribution.

Source code in scprint2/model/utils.py
def zinb_sample(
    mu: torch.Tensor,
    theta: torch.Tensor,
    zi_probs: torch.Tensor,
    sample_shape: torch.Size = torch.Size([]),
) -> torch.Tensor:
    """
    Generate a sample from a Zero-Inflated Negative Binomial (ZINB) distribution.

    Args:
        mu (torch.Tensor): The mean of the Negative Binomial (NB) distribution.
        theta (torch.Tensor): The dispersion parameter of the NB distribution.
        zi_probs (torch.Tensor): The zero-inflation probabilities.
        sample_shape (torch.Size, optional): The output shape. Defaults to torch.Size([]).

    Returns:
        torch.Tensor: A sample from the ZINB distribution.
    """
    concentration = theta
    rate = theta / mu
    # Important remark: Gamma is parametrized by the rate = 1/scale!
    gamma_d = Gamma(concentration=concentration, rate=rate)
    p_means = gamma_d.sample(sample_shape)

    # Clamp the rates: distribution objects can behave erratically when
    # their parameters are too large
    l_train = torch.clamp(p_means, max=1e8)
    samp = Poisson(l_train).sample()  # shape: (n_samples, n_cells_batch, n_vars)
    is_zero = torch.rand_like(samp) <= zi_probs
    samp_ = torch.where(is_zero, torch.zeros_like(samp), samp)
    return samp_
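
The sampler above relies on the Gamma-Poisson mixture view of the negative binomial: a Poisson draw whose rate follows Gamma(concentration=theta, rate=theta / mu) is negative-binomial distributed with mean mu, and zero inflation then overwrites a fraction zi_probs of the draws with zero. The sketch below repeats those steps on toy parameters and checks the empirical mean against (1 - zi_probs) * mu. All parameter values are arbitrary choices for the example.

```python
import torch
from torch.distributions import Gamma, Poisson

torch.manual_seed(0)

# Hypothetical parameters for a batch of 2 cells x 3 genes.
mu = torch.full((2, 3), 5.0)        # NB mean
theta = torch.full((2, 3), 2.0)     # NB dispersion
zi_probs = torch.full((2, 3), 0.3)  # probability of a forced zero

n_samples = 5000
# Gamma with rate theta / mu has mean mu; it provides the Poisson rates.
rates = Gamma(concentration=theta, rate=theta / mu).sample(torch.Size([n_samples]))
counts = Poisson(torch.clamp(rates, max=1e8)).sample()

# Zero inflation: each count is replaced by 0 with probability zi_probs.
is_zero = torch.rand_like(counts) <= zi_probs
samples = torch.where(is_zero, torch.zeros_like(counts), counts)

empirical_mean = samples.mean().item()
expected_mean = (1 - 0.3) * 5.0
print(f"empirical mean {empirical_mean:.2f}, expected about {expected_mean:.2f}")
```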

encoder and decoder modules

scprint2.model.encoders

Classes:

Name Description
CategoryValueEncoder
ContinuousValueEncoder
DPositionalEncoding

The DPositionalEncoding module applies a two-dimensional positional encoding to a sequence of vectors.

EasyExprGNN
ExprBasedFT
GNN
GeneEncoder
PositionalEncoding

CategoryValueEncoder

Bases: Module

Encodes categorical values into a vector using an embedding layer and layer normalization.

Parameters:
  • num_embeddings (int) –

    The number of possible values.

  • embedding_dim (int) –

    The dimension of the output vectors.

  • padding_idx (int, default: None ) –

    The index of the padding token. Defaults to None.

Note: not used in the current version of scprint-2.

Source code in scprint2/model/encoders.py
def __init__(
    self,
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: Optional[int] = None,
):
    """
    Encodes categorical values into a vector using an embedding layer and layer normalization.

    Args:
        num_embeddings (int): The number of possible values.
        embedding_dim (int): The dimension of the output vectors.
        padding_idx (int, optional): The index of the padding token. Defaults to None.

    Note: not used in the current version of scprint-2.
    """
    super(CategoryValueEncoder, self).__init__()
    self.embedding = nn.Embedding(
        num_embeddings, embedding_dim, padding_idx=padding_idx
    )
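
For readers unfamiliar with the padding_idx argument, the following sketch shows the behavior of the underlying nn.Embedding on a toy vocabulary. The vocabulary size, embedding size, and token values are arbitrary and chosen only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical vocabulary of 5 categories, with index 0 reserved for padding.
embedding = nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)

tokens = torch.tensor([[0, 2, 4], [1, 0, 0]])  # [batch_size, seq_len]
vectors = embedding(tokens)                     # [batch_size, seq_len, embedding_dim]

# The padding row is initialized to zeros and receives no gradient updates.
padding_vector = vectors[0, 0]
print(vectors.shape, padding_vector)
```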

ContinuousValueEncoder

Bases: Module

Encode real-valued inputs into vectors using a neural network projection.

Parameters:
  • d_model (int) –

    The dimension of the output vectors.

  • dropout (float, default: 0.1 ) –

    The dropout rate applied between the layers of the encoder. Defaults to 0.1.

  • max_value (int, default: 100000 ) –

    The maximum value of the input; larger values are clamped to it. Defaults to 100_000.

  • layers (int, default: 1 ) –

    The number of layers in the encoder. Defaults to 1.

  • size (int, default: 1 ) –

    The size of the input. Defaults to 1.

Methods:

Name Description
forward

Encode a batch of values into vectors of size d_model.

Source code in scprint2/model/encoders.py
def __init__(
    self,
    d_model: int,
    dropout: float = 0.1,
    max_value: int = 100_000,
    layers: int = 1,
    size: int = 1,
):
    """
    Encode real-valued inputs into vectors using a neural network projection.

    Args:
        d_model (int): The dimension of the output vectors.
        dropout (float, optional): The dropout rate applied between the layers of the encoder. Defaults to 0.1.
        max_value (int, optional): The maximum value of the input; larger values are clamped to it. Defaults to 100_000.
        layers (int, optional): The number of layers in the encoder. Defaults to 1.
        size (int, optional): The size of the input. Defaults to 1.
    """
    super(ContinuousValueEncoder, self).__init__()
    self.max_value = max_value
    self.encoder = nn.ModuleList()
    self.output_dim = d_model
    # self.mask_value = nn.Embedding(1, d_model)
    self.encoder.append(nn.Linear(size, d_model))
    for _ in range(layers - 1):
        self.encoder.append(nn.LayerNorm(d_model))
        self.encoder.append(nn.ReLU())
        self.encoder.append(nn.Dropout(p=dropout))
        self.encoder.append(nn.Linear(d_model, d_model))

forward

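Parameters:
  • x (Tensor) –

    Tensor, shape [batch_size, seq_len]

  • mask (Tensor, default: None ) –

    Optional boolean tensor, shape [batch_size, seq_len]; encodings at positions where it is True are set to zero.

Returns:
  • Tensor

    Tensor, shape [batch_size, seq_len, d_model]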

Source code in scprint2/model/encoders.py
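def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    """
    Args:
        x: Tensor, shape [batch_size, seq_len]
        mask: Optional boolean Tensor, shape [batch_size, seq_len]. Encodings at
            positions where it is True are set to zero.

    Returns:
        Tensor, shape [batch_size, seq_len, d_model]
    """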
    # expand last dimension
    x = x.unsqueeze(-1)
    # use the mask embedding when x=-1
    # mask = (x == -1).float()
    x = torch.clamp(x, min=0, max=self.max_value)
    for val in self.encoder:
        x = val(x)
    if mask is not None:
        x = x.masked_fill_(mask.unsqueeze(-1), 0)
        # x = x.masked_fill_(mask.unsqueeze(-1), self.mask_value(0))
    return x
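
The forward pass above can be traced by hand on a tiny input. The sketch below mirrors its steps (unsqueeze, clamp, projection, masked fill) for the single-layer case with a stand-in nn.Linear. The input values and the name `projection` are chosen for illustration and do not come from the source.

```python
import torch
from torch import nn

torch.manual_seed(0)

d_model, max_value = 8, 100_000
projection = nn.Linear(1, d_model)  # stand-in for the layers=1, size=1 case

x = torch.tensor([[0.0, 3.0, -1.0, 2e6]])            # [batch_size=1, seq_len=4]
mask = torch.tensor([[False, True, False, False]])   # True marks positions to blank out

h = x.unsqueeze(-1)                        # [1, 4, 1]
h = torch.clamp(h, min=0, max=max_value)   # -1 becomes 0, 2e6 becomes 100000
h = projection(h)                          # [1, 4, d_model]
h = h.masked_fill(mask.unsqueeze(-1), 0)   # masked positions become zero vectors

print(h.shape)
```

Because negative inputs are clamped to zero before the projection, the first and third positions end up with identical encodings in this example.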

DPositionalEncoding

Bases: Module

The DPositionalEncoding module applies a two-dimensional positional encoding to a sequence of vectors: one sinusoidal table is indexed by a position along x and another by a position along y. This is necessary because the Transformer has no inherent notion of position in a sequence. Both encodings are added to the input embeddings, which allows the model to attend to positions in the sequence.

Parameters:
  • d_model (int) –

    The dimension of the input vectors.

  • max_len_x (int) –

    The maximum length in the x dimension.

  • max_len_y (int) –

    The maximum length in the y dimension.

  • maxvalue_x (float, default: 10000.0 ) –

    Maximum value for x dimension scaling. Defaults to 10000.0.

  • maxvalue_y (float, default: 10000.0 ) –

    Maximum value for y dimension scaling. Defaults to 10000.0.

Note: not used in the current version of scprint-2.

Methods:

Name Description
forward

Add the x and y positional encodings to the input.

Source code in scprint2/model/encoders.py
def __init__(
    self,
    d_model: int,
    max_len_x: int,
    max_len_y: int,
    maxvalue_x=10000.0,
    maxvalue_y=10000.0,
):
    super(DPositionalEncoding, self).__init__()
    position2 = torch.arange(max_len_y).unsqueeze(1)
    position1 = torch.arange(max_len_x).unsqueeze(1)

    half_n = d_model // 2

    div_term2 = torch.exp(
        torch.arange(0, half_n, 2) * (-math.log(maxvalue_y) / d_model)
    )
    div_term1 = torch.exp(
        torch.arange(0, half_n, 2) * (-math.log(maxvalue_x) / d_model)
    )
    pe1 = torch.zeros(max_len_x, 1, d_model)
    pe2 = torch.zeros(max_len_y, 1, d_model)
    pe1[:, 0, 0:half_n:2] = torch.sin(position1 * div_term1)
    pe1[:, 0, 1:half_n:2] = torch.cos(position1 * div_term1)
    pe2[:, 0, half_n::2] = torch.sin(position2 * div_term2)
    pe2[:, 0, 1 + half_n :: 2] = torch.cos(position2 * div_term2)
    # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
    self.register_buffer("pe1", pe1)
    self.register_buffer("pe2", pe2)

forward

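Parameters:
  • x (Tensor) –

    Tensor, shape [seq_len, batch_size, embedding_dim]

  • pos_x (Tensor) –

    Integer indices giving the position of each element along the x dimension (values below max_len_x).

  • pos_y (Tensor) –

    Integer indices giving the position of each element along the y dimension (values below max_len_y).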

Source code in scprint2/model/encoders.py
def forward(self, x: Tensor, pos_x: Tensor, pos_y: Tensor) -> Tensor:
    """
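    Args:
        x: Tensor, shape [seq_len, batch_size, embedding_dim]
        pos_x: Tensor of integer indices along the x dimension (values below max_len_x)
        pos_y: Tensor of integer indices along the y dimension (values below max_len_y)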
    """
    x = x + self.pe1[pos_x]
    x = x + self.pe2[pos_y]
    return x
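
The idea of a two-axis positional encoding can be sketched independently of the class above: build one sinusoidal table per axis, give each axis its own half of the channels, and look the rows up by integer position. The following is a simplified illustration under those assumptions. The helper `sinusoidal_table` and the frequency scaling it uses are not taken from the source and differ in detail from the listing above.

```python
import math

import torch


def sinusoidal_table(max_len: int, dim: int, maxvalue: float = 10000.0) -> torch.Tensor:
    """Standard 1D sin/cos table of shape [max_len, dim]; dim must be even."""
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(maxvalue) / dim))
    table = torch.zeros(max_len, dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table


d_model = 8
half = d_model // 2
pe_x = sinusoidal_table(16, half)  # table for the x axis
pe_y = sinusoidal_table(32, half)  # table for the y axis

pos_x = torch.tensor([0, 3, 5])
pos_y = torch.tensor([1, 1, 7])

x = torch.zeros(3, d_model)
x[:, :half] += pe_x[pos_x]  # first half of the channels encodes x
x[:, half:] += pe_y[pos_y]  # second half encodes y
print(x.shape)
```

Elements that share a position along one axis receive the same values in that axis's half of the channels, which is what lets attention compare positions per axis.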

EasyExprGNN

Bases: Module

Easy Expression Graph Neural Network

The main GNN used in scPRINT-2 for expression encoding. It is inspired by the DeepSets architecture and aggregates neighbor information.

Parameters:
  • self_dim (int, default: 64 ) –

    Dimension of the self features

  • output_dim (int, default: 32 ) –

    Output dimension

  • self_layers (int, default: 2 ) –

    Number of layers for self features

  • dropout (float, default: 0.1 ) –

    Dropout rate

  • shared_layers (int, default: 2 ) –

    Number of shared layers

  • neighbors_layers (int, default: 2 ) –

    Number of layers for neighbors features

Methods:

Name Description
forward

Forward pass of the Easy Expression GNN

Source code in scprint2/model/encoders.py
def __init__(
    self,
    self_dim: int = 64,
    output_dim: int = 32,
    self_layers: int = 2,
    dropout: float = 0.1,
    shared_layers: int = 2,
    neighbors_layers: int = 2,
):
    """
    Easy Expression Graph Neural Network

    The main GNN used in scPRINT-2 for expression encoding.
    It is inspired by the DeepSets architecture and aggregates neighbor information.

    Args:
        self_dim (int): Dimension of the self features
        output_dim (int): Output dimension
        self_layers (int): Number of layers for self features
        dropout (float): Dropout rate
        shared_layers (int): Number of shared layers
        neighbors_layers (int): Number of layers for neighbors features
    """
    super(EasyExprGNN, self).__init__()
    self.output_dim = output_dim
    self.self_dim = self_dim
    # neighbors
    self.neighbors_layers = nn.ModuleList()
    self.neighbors_layers.append(nn.Linear(2, self_dim // 2))
    for i in range(neighbors_layers - 1):
        self.neighbors_layers.append(nn.LayerNorm(self_dim // 2))
        self.neighbors_layers.append(nn.ReLU())
        self.neighbors_layers.append(nn.Dropout(p=dropout))
        self.neighbors_layers.append(nn.Linear(self_dim // 2, self_dim // 2))
    # self
    self.self_layers = nn.ModuleList()
    self.self_layers.append(nn.Linear(1, self_dim // 2))
    for i in range(self_layers - 1):
        self.self_layers.append(nn.LayerNorm(self_dim // 2))
        self.self_layers.append(nn.ReLU())
        self.self_layers.append(nn.Dropout(p=dropout))
        self.self_layers.append(nn.Linear(self_dim // 2, self_dim // 2))
    # shared
    self.shared_layers = nn.ModuleList()
    for i in range(shared_layers - 1):
        self.shared_layers.append(nn.Linear(self_dim, self_dim))
        self.shared_layers.append(nn.LayerNorm(self_dim))
        self.shared_layers.append(nn.ReLU())
        self.shared_layers.append(nn.Dropout(p=dropout))
    self.shared_layers.append(nn.Linear(self_dim, output_dim))

forward

Forward pass of the Easy Expression GNN

Parameters:
  • expr (Optional[Tensor], default: None ) –

    Tensor of shape (batch, seq_len) representing expression values

  • neighbors (Optional[Tensor], default: None ) –

    Tensor of shape (batch, seq_len, n_neighbors) representing neighbor indices

  • edge_info (Optional[Tensor], default: None ) –

    Tensor of shape (batch, seq_len, n_neighbors) representing edge information

  • mask (Optional[Tensor], default: None ) –

    Tensor of shape (batch, seq_len) representing mask for the input

Returns:
  • Tensor

    Tensor of shape (batch, seq_len, output_dim) representing the output features

Source code in scprint2/model/encoders.py
def forward(
    self,
    expr: Optional[Tensor] = None,
    neighbors: Optional[Tensor] = None,
    edge_info: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """
    Forward pass of the Easy Expression GNN

    Args:
        expr: Tensor of shape (batch, seq_len) representing expression values
        neighbors: Tensor of shape (batch, seq_len, n_neighbors) representing neighbor indices
        edge_info: Tensor of shape (batch, seq_len, n_neighbors) representing edge information
        mask: Tensor of shape (batch, seq_len) representing mask for the input

    Returns:
        Tensor of shape (batch, seq_len, output_dim) representing the output features
    """
    # batch, seq_len, neighbs
    if neighbors is None:
        neighbors = torch.zeros(
            (expr.shape[0], expr.shape[1], self.self_dim // 2), device=expr.device
        )
    else:
        neighbors = neighbors.transpose(1, 2)
        neighbors = torch.cat(
            [neighbors.unsqueeze(-1), edge_info.unsqueeze(-1)], dim=-1
        )
        for i, layer in enumerate(self.neighbors_layers):
            # batch, seq_len, neighbs, hidden_dim
            neighbors = layer(neighbors)
        neighbors = neighbors.sum(-2)
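    if expr is None:
        # no expression given: use zeros with the width the self branch would
        # produce, so that the concatenation below has size self_dim
        expr = torch.zeros(
            (neighbors.shape[0], neighbors.shape[1], self.self_dim // 2),
            device=neighbors.device,
        )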
    else:
        expr = expr.unsqueeze(-1)
        for i, layer in enumerate(self.self_layers):
            expr = layer(expr)
    x = torch.cat([expr, neighbors], dim=-1)
    for layer in self.shared_layers:
        # batch, seq_len, hidden_dim
        x = layer(x)
    if mask is not None:
        x = x.masked_fill(mask.unsqueeze(-1), 0)
    return x
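
The DeepSets-style aggregation used in the neighbor branch can be illustrated in isolation: apply the same small network to every neighbor, then sum over the neighbor axis, which makes the result independent of the neighbor order. The sketch below assumes toy shapes and a hypothetical two-layer network `phi`; it is not the module's actual layer stack.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch, seq_len, n_neighbors, hidden = 2, 5, 4, 16
phi = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, hidden))

# Hypothetical inputs: one expression value and one edge weight per neighbor.
neighbor_expr = torch.rand(batch, seq_len, n_neighbors)
edge_weight = torch.rand(batch, seq_len, n_neighbors)


def aggregate(values: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    feats = torch.stack([values, weights], dim=-1)  # [batch, seq_len, n_neighbors, 2]
    return phi(feats).sum(dim=-2)                   # [batch, seq_len, hidden]


pooled = aggregate(neighbor_expr, edge_weight)

# Shuffling the neighbor axis leaves the pooled features unchanged.
perm = torch.randperm(n_neighbors)
pooled_shuffled = aggregate(neighbor_expr[..., perm], edge_weight[..., perm])
print(pooled.shape)
```

Summation is what gives the order invariance; it also means the scale of the pooled features grows with the number of neighbors, unlike a mean.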

ExprBasedFT

Bases: Module

Combine a gene encoding and an expression encoding into a single vector per gene: the outputs of the two encoders are concatenated and projected by a small feed-forward network.

Parameters:
  • d_model (int) –

    The dimension of the output vectors.

  • gene_encoder (Module) –

    The gene name encoder module.

  • expr_encoder (Module, default: Identity() ) –

    The expression encoder module. Defaults to nn.Identity.

  • dropout (float, default: 0.1 ) –

    The dropout rate applied between the layers of the encoder. Defaults to 0.1.

  • layers (int, default: 2 ) –

    The number of layers in the encoder. Defaults to 2.

  • intermediary_d (int, default: 256 + 64 ) –

    The dimension of the intermediary layers. Defaults to 256 + 64.

Methods:

Name Description
forward

Encode genes and their expression values into combined vectors.

Source code in scprint2/model/encoders.py
def __init__(
    self,
    d_model: int,
    gene_encoder: nn.Module,
    expr_encoder: nn.Module = nn.Identity(),
    dropout: float = 0.1,
    layers: int = 2,
    intermediary_d: int = 256 + 64,
):
    """
    Combine a gene encoding and an expression encoding into a single vector per gene:
    the outputs of the two encoders are concatenated and projected by a small
    feed-forward network.

    Args:
        d_model (int): The dimension of the output vectors.
        gene_encoder (nn.Module): The gene name encoder module.
        expr_encoder (nn.Module, optional): The expression encoder module. Defaults to nn.Identity.
        dropout (float, optional): The dropout rate applied between the layers of the encoder. Defaults to 0.1.
        layers (int, optional): The number of layers in the encoder. Defaults to 2.
        intermediary_d (int, optional): The dimension of the intermediary layers. Defaults to 256 + 64.

    """
    super(ExprBasedFT, self).__init__()
    self.encoder = nn.ModuleList()
    # self.mask_value = nn.Embedding(1, d_model)
    self.add_module("gene_encoder", gene_encoder)
    self.add_module("expr_encoder", expr_encoder)
    expr_shape, gene_shape = (
        self.expr_encoder.output_dim,
        self.gene_encoder.output_dim,
    )
    self.encoder.append(nn.Linear(expr_shape + gene_shape, intermediary_d))
    for i in range(layers - 1):
        self.encoder.append(nn.LayerNorm(intermediary_d))
        self.encoder.append(nn.ReLU())
        self.encoder.append(nn.Dropout(p=dropout))
        self.encoder.append(
            nn.Linear(intermediary_d, intermediary_d if i < layers - 2 else d_model)
        )

forward

Parameters:
  • gene_pos (Tensor[batch_size, seq_len]) –

    Gene position indices, passed to the gene encoder.

  • expr ((Tensor[batch_size, seq_len], Optional), default: None ) –

    Expression values, passed to the expression encoder.

  • mask ((Tensor[batch_size, seq_len], Optional), default: None ) –

    Mask for the input, passed to the expression encoder.

  • neighbors ((Tensor[batch_size, seq_len, n_neighbors], Optional), default: None ) –

    Neighbor indices, passed to the expression encoder when it is a GNN.

  • neighbors_info ((Tensor[batch_size, seq_len, n_neighbors], Optional), default: None ) –

    Optional additional information about the neighbors, passed to the expression encoder when it is a GNN.

Source code in scprint2/model/encoders.py
def forward(
    self,
    gene_pos: Tensor,
    expr: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
    neighbors: Optional[Tensor] = None,
    neighbors_info: Optional[Tensor] = None,
) -> Tensor:
    """
    Args:
        gene_pos (Tensor[batch_size, seq_len]): Gene position indices,
            passed to the gene encoder.
        expr (Tensor[batch_size, seq_len], Optional): Expression values,
            passed to the expression encoder.
        mask (Tensor[batch_size, seq_len], Optional): Mask for the input,
            passed to the expression encoder.
        neighbors (Tensor[batch_size, seq_len, n_neighbors], Optional): Neighbor indices,
            passed to the expression encoder when it is a GNN.
        neighbors_info (Tensor[batch_size, seq_len, n_neighbors], Optional):
            Optional additional information about the neighbors,
            passed to the expression encoder when it is a GNN.
    """
    # encode the expression; without any expression information, use zeros
    if neighbors is None and expr is None:
        expr = torch.zeros(
            (gene_pos.shape[0], gene_pos.shape[1], self.expr_encoder.output_dim),
            dtype=torch.float32,
            device=gene_pos.device,
        )
        # if no expr information: consider that it is all masked
    else:
        expr = (
            self.expr_encoder(expr, mask=mask)
            if neighbors is None
            else self.expr_encoder(expr, neighbors, neighbors_info, mask=mask)
        )
    gene_pos = self.gene_encoder(gene_pos)
    x = torch.cat([expr, gene_pos], dim=-1)
    for val in self.encoder:
        x = val(x)
    return x
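
The fusion performed in this forward pass (encode expression, encode gene identity, concatenate, project) can be sketched with stand-in modules. In the example below, `gene_encoder`, `expr_encoder`, and `fusion` are simplified placeholders with arbitrary sizes, not the encoders used by scPRINT-2.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch, seq_len, n_genes = 2, 6, 100
gene_dim, expr_dim, d_model = 12, 4, 16

gene_encoder = nn.Embedding(n_genes, gene_dim)  # stand-in gene encoder
expr_encoder = nn.Linear(1, expr_dim)           # stand-in expression encoder
fusion = nn.Sequential(
    nn.Linear(expr_dim + gene_dim, 32),
    nn.LayerNorm(32),
    nn.ReLU(),
    nn.Linear(32, d_model),
)

gene_pos = torch.randint(0, n_genes, (batch, seq_len))
expr = torch.rand(batch, seq_len)

# Regular case: concatenate the two encodings along the feature axis.
expr_enc = expr_encoder(expr.unsqueeze(-1))
tokens = fusion(torch.cat([expr_enc, gene_encoder(gene_pos)], dim=-1))

# No expression available: an all-zero block takes the place of the encoding.
blank = torch.zeros(batch, seq_len, expr_dim)
tokens_no_expr = fusion(torch.cat([blank, gene_encoder(gene_pos)], dim=-1))
print(tokens.shape, tokens_no_expr.shape)
```

Using a zero block keeps the input width of the projection constant, so the same weights serve both the case with expression and the case without it.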

GNN

Bases: Module

Graph Neural Network model

Another implementation of a GNN layer that can be used for expression encoding. Supports GCN, GAT, GraphSAGE, and DeepSets architectures.

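Parameters:
  • input_dim (int, default: 1 ) –

    Dimension of input node features (1 or 2)

  • merge_dim (int, default: 32 ) –

    Hidden and output dimension of the per-node MLPs used when gnn_type is 'deepset'

  • output_dim (int, default: 256 ) –

    Dimension of output node features

  • num_layers (int, default: 2 ) –

    Number of GNN layers (must be greater than 1)

  • dropout (float, default: 0.1 ) –

    Dropout probability

  • gnn_type (str, default: 'deepset' ) –

    Type of GNN layer ('gcn', 'gat', 'sage', or 'deepset')

  • add_connection_feature (bool, default: False ) –

    Whether the output MLP expects one extra input feature in addition to the two merged representations (only used when gnn_type is 'deepset')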

Methods:

Name Description
forward

Forward pass

Source code in scprint2/model/encoders.py
def __init__(
    self,
    input_dim: int = 1,  # here, 1 or 2
    merge_dim: int = 32,
    output_dim: int = 256,
    num_layers: int = 2,
    dropout: float = 0.1,
    gnn_type: str = "deepset",
    add_connection_feature: bool = False,
):
    """
    Graph Neural Network model

    Another implementation of a GNN layer that can be used for expression encoding.
    Supports GCN, GAT, GraphSAGE, and DeepSets architectures.

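    Args:
        input_dim: Dimension of input node features (1 or 2)
        merge_dim: Hidden and output dimension of the per-node MLPs used when gnn_type is 'deepset'
        output_dim: Dimension of output node features
        num_layers: Number of GNN layers (must be greater than 1)
        dropout: Dropout probability
        gnn_type: Type of GNN layer ('gcn', 'gat', 'sage', or 'deepset')
        add_connection_feature: Whether the output MLP expects one extra input feature
            in addition to the two merged representations (only used when gnn_type is 'deepset')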
    """
    super().__init__()

    self.input_dim = input_dim
    self.output_dim = output_dim
    if num_layers < 2:
        raise ValueError("num_layers must be greater than 1")
    self.num_layers = num_layers
    self.dropout = dropout
    self.gnn_type = gnn_type
    self.add_connection_feature = add_connection_feature

    if gnn_type == "deepset":
        # Local MLP (phi) for processing individual nodes
        self.input_nn_layer = MLP(
            in_channels=input_dim,
            hidden_channels=merge_dim,
            out_channels=merge_dim,
            num_layers=num_layers,
            dropout=dropout,
            act="relu",
            norm="layer_norm",
        )

        self.input_self_layer = MLP(
            in_channels=input_dim,
            hidden_channels=merge_dim + 2,
            out_channels=merge_dim,
            num_layers=num_layers - 1,
            dropout=dropout,
            act="relu",
            norm="layer_norm",
        )

        # Global MLP (rho) for processing aggregated features
        self.output_layer = MLP(
            in_channels=(
                (merge_dim * 2) + 1 if add_connection_feature else merge_dim * 2
            ),
            hidden_channels=output_dim,
            out_channels=output_dim,
            num_layers=num_layers,
            dropout=dropout,
            act="relu",
            norm="layer_norm",
        )

        return

    # Select GNN layer type for other architectures
    else:
        if gnn_type == "gcn":
            gnn_layer = GCNConv
        elif gnn_type == "gat":
            gnn_layer = GATConv
        elif gnn_type == "sage":
            gnn_layer = SAGEConv
        else:
            raise ValueError(f"Unknown GNN type: {gnn_type}")

        self.gnn_layer = gnn_layer(
            output_dim,
            output_dim,
            add_self_loops=False,
            normalize=False,
            aggr="mean",
        )

forward

Forward pass

Parameters:
  • x (Tensor) –

    Node features [minibatch_size, ngenes]

  • neighbors (Tensor) –

    Neighbor nodes [minibatch_size, ngenes, n_neighbors] or [minibatch_size, ngenes, n_neighbors, 2]

  • edge_info (Tensor, default: None ) –

    Graph connectivity [2, num_edges] if gnn_type != deepset, Edge features [num_edges, 1] if gnn_type == deepset, or None if gnn_type == deepset and no edge features.

  • batch (Tensor, default: None ) –

    Batch assignment vector [num_nodes]

  • mask (Tensor, default: None ) –

    Mask tensor for the nodes.

Returns:
  • Tensor( Tensor ) –

    Node embeddings [num_nodes, hidden_dim]

Source code in scprint2/model/encoders.py
def forward(
    self,
    x: Tensor,
    neighbors: Tensor,
    edge_info: Optional[Tensor] = None,
    batch: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """
    Forward pass

    Args:
        x (Tensor): Node features [minibatch_size, ngenes]
        neighbors (Tensor): Neighbor nodes [minibatch_size, ngenes, n_neighbors] or [minibatch_size, ngenes, n_neighbors, 2]
        edge_info (Tensor, optional): Graph connectivity [2, num_edges] if gnn_type != deepset,
            Edge features [num_edges, 1] if gnn_type == deepset,
            or None if gnn_type == deepset and no edge features.
        batch (Tensor, optional): Batch assignment vector [num_nodes]
        mask (Tensor, optional): Mask tensor for the nodes.

    Returns:
        Tensor: Node embeddings [num_nodes, hidden_dim]
    """

    # Standard GNN forward pass
    x = x.unsqueeze(-1)
    neighbors = neighbors.unsqueeze(-1)
    if self.gnn_type == "deepset":
        neighbors = self.input_nn_layer(neighbors).sum(dim=-3)
        x = self.input_self_layer(x)
        x = torch.cat([x, neighbors], dim=-1)
    else:
        x = self.gnn_layer(x, edge_info)
        neighbors = self.gnn_layer(neighbors, edge_info)
        for layer in self.layers:
            x = layer(x, edge_info)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # TODO: to finish

    x = self.output_layer(x)
    if mask is not None:
        x = x.masked_fill_(mask.unsqueeze(-1), 0)
    return x
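
The 'deepset' branch above can be read as a phi / sum / rho pattern: one network applied to each neighbor value, a sum over neighbors, and an output network applied to the result concatenated with the node's own features. The following is a minimal NumPy sketch of that pattern, not the module's code. Single random linear maps stand in for the MLP blocks, all sizes are hypothetical, and it assumes neighbors is laid out as [minibatch_size, ngenes, n_neighbors] with pooling over the neighbor axis.

```python
import numpy as np

rng = np.random.default_rng(0)


def relu(a):
    return np.maximum(a, 0.0)


# Hypothetical sizes and random weights, chosen only for illustration.
batch, ngenes, n_neighbors, merge_dim, output_dim = 2, 5, 3, 4, 6
w_phi = rng.normal(size=(1, merge_dim))  # phi: per-neighbor map
w_self = rng.normal(size=(1, merge_dim))  # map for the node's own value
w_rho = rng.normal(size=(2 * merge_dim, output_dim))  # rho: after pooling


def deepset_encode(x, neighbors):
    """x: [batch, ngenes]; neighbors: [batch, ngenes, n_neighbors]."""
    # phi is applied to every neighbor value independently, then summed
    # over the neighbor axis, which makes the result order-invariant.
    pooled = relu(neighbors[..., None] @ w_phi).sum(axis=-2)
    own = relu(x[..., None] @ w_self)
    # rho sees the node's own features next to the pooled neighbor features.
    return relu(np.concatenate([own, pooled], axis=-1) @ w_rho)


x = rng.normal(size=(batch, ngenes))
neighbors = rng.normal(size=(batch, ngenes, n_neighbors))
out = deepset_encode(x, neighbors)
shuffled = deepset_encode(x, neighbors[..., ::-1])  # reversed neighbor order
```

Because the pooling step is a sum, reversing the neighbor order leaves the output unchanged up to floating-point rounding.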

GeneEncoder

Bases: Module

Encodes gene sequences into a continuous vector space using an embedding layer. Uses memory mapping for efficient access to large embedding files.

Parameters:
  • num_embeddings (int) –

    The number of possible values

  • embedding_dim (int) –

    The dimension of the output vectors

  • padding_idx (int, default: None ) –

    The index of the padding token

  • weights (Tensor, default: None ) –

    The initial weights for the embedding layer

  • weights_file (str, default: None ) –

    Path to parquet file containing embeddings

  • freeze (bool, default: False ) –

    Whether to freeze the weights of the embedding layer

Methods:

Name Description
__del__

Cleanup method to ensure proper handling of memory-mapped file.

forward

Forward pass of the encoder.

Source code in scprint2/model/encoders.py
def __init__(
    self,
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: Optional[int] = None,
    weights: Optional[Tensor] = None,
    weights_file: Optional[str] = None,
    freeze: bool = False,
):
    """
    Encodes gene sequences into a continuous vector space using an embedding layer.
    Uses memory mapping for efficient access to large embedding files.

    Args:
        num_embeddings (int): The number of possible values
        embedding_dim (int): The dimension of the output vectors
        padding_idx (int, optional): The index of the padding token
        weights (Tensor, optional): The initial weights for the embedding layer
        weights_file (str, optional): Path to parquet file containing embeddings
        freeze (bool, optional): Whether to freeze the weights of the embedding layer
    """
    super(GeneEncoder, self).__init__()
    self.output_dim = embedding_dim

    if weights_file is not None:
        self.memmap = True
        if not freeze:
            raise ValueError(
                "freeze must be True when using memory-mapped embeddings"
            )
        # Load the parquet file and create memory-mapped array
        import os

        import pandas as pd

        # Create memory-mapped file path
        self.mmap_file = f"{weights_file}.mmap"
        self.loc = None
        self.enc = None
        # Only create the memory-mapped file if it doesn't exist
        if not os.path.exists(self.mmap_file):
            print(f"Creating memory-mapped file for embeddings at {self.mmap_file}")
            df = pd.read_parquet(weights_file)
            embeddings = torch.nn.AdaptiveAvgPool1d(self.output_dim)(
                torch.tensor(df.values)
            )

            # Create memory-mapped array
            self.embeddings = np.memmap(
                self.mmap_file, dtype="float32", mode="w+", shape=embeddings.shape
            )
            # Copy data to memory-mapped array
            self.embeddings[:] = embeddings.numpy()
            # Flush pending writes to disk
            self.embeddings.flush()

            # Clean up memory
            del df
            del embeddings
        else:
            print(
                f"Loading existing memory-mapped embeddings from {self.mmap_file}"
            )
            # Load existing memory-mapped file
            self.embeddings = np.memmap(
                self.mmap_file,
                dtype="float32",
                mode="r",  # Read-only mode since we don't need to modify
                shape=(num_embeddings, embedding_dim),
            )
    else:
        self.memmap = False
        self.embeddings = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx, _freeze=freeze
        )
        if weights is not None:
            self.embeddings.weight.data.copy_(torch.Tensor(weights))

__del__

Cleanup method to ensure proper handling of memory-mapped file.

Source code in scprint2/model/encoders.py
def __del__(self):
    """Cleanup method to ensure proper handling of memory-mapped file."""
    if hasattr(self, "embeddings") and self.embeddings is not None:
        try:
            self.embeddings._mmap.close()
        except Exception:
            # Not memory-mapped, or already closed: nothing to release.
            pass

forward

Forward pass of the encoder.

Parameters:
  • x (Tensor) –

    Input tensor of indices [batch_size, seq_len]

Returns:
  • Tensor( Tensor ) –

    Embedded vectors [batch_size, seq_len, embedding_dim]

Source code in scprint2/model/encoders.py
def forward(self, x: Tensor) -> Tensor:
    """
    Forward pass of the encoder.

    Args:
        x (Tensor): Input tensor of indices [batch_size, seq_len]

    Returns:
        Tensor: Embedded vectors [batch_size, seq_len, embedding_dim]
    """
    if self.memmap:
        if self.loc is None or not torch.all(x.sum(1) == self.loc):
            self.enc = (
                torch.from_numpy(
                    self.embeddings[x.reshape(-1).cpu().numpy()].copy()
                )
                .reshape(x.shape + (-1,))
                .to(x.device)
            )
            self.loc = x.sum(1)
        return self.enc.clone()
    else:
        return self.embeddings(x)
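
The memory-mapped path relies on np.memmap: the table is written to disk once, reopened read-only, and rows are gathered by index on each call. Below is a small sketch of that round trip under stated assumptions. The file name, table contents, and sizes are hypothetical, and the sketch leaves out the caching on x.sum(1) and the conversion to a torch tensor that the forward method performs.

```python
import os
import tempfile

import numpy as np

num_embeddings, embedding_dim = 10, 4
table = np.arange(num_embeddings * embedding_dim, dtype="float32").reshape(
    num_embeddings, embedding_dim
)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embeddings.mmap")  # hypothetical file name

    # Write once: mode "w+" creates the file and maps it for writing.
    writer = np.memmap(path, dtype="float32", mode="w+", shape=table.shape)
    writer[:] = table
    writer.flush()
    del writer

    # Reopen read-only. The raw file stores no shape, so it must be supplied.
    reader = np.memmap(path, dtype="float32", mode="r", shape=table.shape)

    # Gather rows for a [batch_size, seq_len] index array.
    idx = np.array([[0, 3, 9], [2, 2, 5]])
    looked_up = reader[idx.reshape(-1)].copy().reshape(idx.shape + (-1,))
    del reader  # release the mapping before the directory is removed
```

The shape passed when reopening has to match the shape used when writing, since a raw memory-mapped file carries no metadata of its own.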

PositionalEncoding

Bases: Module

The PositionalEncoding module applies a positional encoding to a sequence of vectors. This is necessary for the Transformer model, which does not have any inherent notion of position in a sequence. The positional encoding is added to the input embeddings and allows the model to attend to positions in the sequence.

Parameters:
  • d_model (int) –

    The dimension of the input vectors.

  • gene_pos_enc (list[str], default: [] ) –

    The gene position encoding to use.

Note: not used in the current version of scprint-2.

Methods:

Name Description
forward

Look up the positional encodings for the given gene positions.

Source code in scprint2/model/encoders.py
def __init__(
    self,
    d_model: int,
    gene_pos_enc: list[str] = [],
):
    """
    The PositionalEncoding module applies a positional encoding to a sequence of vectors.
    This is necessary for the Transformer model, which does not have any inherent notion of
    position in a sequence. The positional encoding is added to the input embeddings and
    allows the model to attend to positions in the sequence.

    Args:
        d_model (int): The dimension of the input vectors.
        gene_pos_enc (list[str], optional): The gene position encoding to use.

    Note: not used in the current version of scprint-2.
    """
    super(PositionalEncoding, self).__init__()
    self.gene_pos_enc = gene_pos_enc
    max_len = max(gene_pos_enc)
    position = torch.arange(max_len).unsqueeze(1)
    # Create a dictionary to convert token to position
    token_to_pos = {token: pos for token, pos in enumerate(gene_pos_enc)}

    div_term = torch.exp(
        torch.arange(0, d_model, 2) * (-math.log(float(10_000)) / d_model)
    )
    pe = torch.zeros(max_len, 1, d_model)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    pe[:, 0, 1::2] = torch.cos(position * div_term)
    # we reorder them and map them to gene_id (position)
    arr = []
    for _, v in token_to_pos.items():
        arr.append(pe[v - 1].numpy())
    pe = torch.Tensor(np.array(arr))
    # Remove the unnecessary middle dimension since pe should be [m, d]
    # pe = pe.squeeze(1)
    self.register_buffer("pe", pe)

forward

Parameters:
  • gene_pos (Tensor) –

    Gene position indices, shape [seq_len, batch_size] or [seq_len]

Returns:
  • Tensor( Tensor ) –

    Positional encodings, shape [*gene_pos.shape, embedding_dim]

Source code in scprint2/model/encoders.py
def forward(self, gene_pos: Tensor) -> Tensor:
    """
    Args:
        gene_pos (Tensor): Gene position indices, shape [seq_len, batch_size] or [seq_len]

    Returns:
        Tensor: Positional encodings, shape [*gene_pos.shape, embedding_dim]
    """
    return torch.index_select(self.pe, 0, gene_pos.reshape(-1)).reshape(
        gene_pos.shape + (-1,)
    )
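
The table built in __init__ follows the usual sinusoidal scheme: even columns hold sines and odd columns hold cosines of the position scaled by a geometric range of frequencies. A dependency-free sketch of that table and of a lookup by position follows. The function name and sizes are hypothetical, d_model is assumed even, and row p here encodes position p directly, whereas the source reads row v - 1 for position v.

```python
import math


def sinusoidal_table(max_len: int, d_model: int) -> list[list[float]]:
    """Row p holds the encoding of position p; d_model is assumed even."""
    # Same frequencies as div_term: exp(i * -ln(10000) / d_model) for even i.
    div_term = [
        math.exp(i * (-math.log(10_000.0) / d_model)) for i in range(0, d_model, 2)
    ]
    table = []
    for pos in range(max_len):
        row = [0.0] * d_model
        row[0::2] = [math.sin(pos * f) for f in div_term]  # even columns
        row[1::2] = [math.cos(pos * f) for f in div_term]  # odd columns
        table.append(row)
    return table


pe = sinusoidal_table(max_len=4, d_model=6)

# Lookup by position index, analogous to index_select on the "pe" buffer.
gene_pos = [2, 0, 2]
encoded = [pe[p] for p in gene_pos]
```

Genes that share a position index receive identical encodings, which is the behavior the lookup in forward relies on.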

scprint2.model.decoders

Classes:

Name Description
ClsDecoder
ExprDecoder
GraphSDEExprDecoder
MVCDecoder
VAEDecoder

ClsDecoder

Bases: Module

Decoder for the classification task.

Parameters:
  • d_model (int) –

    Dimension of the input.

  • n_cls (int) –

    Number of classes.

  • layers (List[int], default: [256, 128] ) –

    List of hidden layers.

  • activation (Callable, default: ReLU ) –

    Activation function.

  • dropout (float, default: 0.1 ) –

    Dropout rate.

Methods:

Name Description
forward

Map input embeddings to class logits.

Source code in scprint2/model/decoders.py
def __init__(
    self,
    d_model: int,
    n_cls: int,
    layers: List[int] = [256, 128],
    activation: Callable = nn.ReLU,
    dropout: float = 0.1,
):
    """
    Decoder for the classification task.

    Args:
        d_model (int): Dimension of the input.
        n_cls (int): Number of classes.
        layers (List[int]): List of hidden layers.
        activation (Callable): Activation function.
        dropout (float): Dropout rate.
    """
    super(ClsDecoder, self).__init__()
    # module List
    layers = [d_model] + layers
    self.decoder = nn.Sequential()
    self.n_cls = n_cls
    for i, l in enumerate(layers[1:]):
        self.decoder.append(nn.Linear(layers[i], l))
        self.decoder.append(nn.LayerNorm(l))
        self.decoder.append(activation())
        self.decoder.append(nn.Dropout(dropout))
    self.out_layer = nn.Linear(layers[-1], n_cls)

forward

Parameters:
  • x (Tensor) –

    Tensor, shape [batch_size, embsize]

Source code in scprint2/model/decoders.py
def forward(self, x: Tensor) -> Tensor:
    """
    Args:
        x: Tensor, shape [batch_size, embsize]
    """
    x = self.decoder(x)
    return self.out_layer(x)

ExprDecoder

Bases: Module

Decoder for gene expression prediction.

Outputs the mean, dispersion, and zero logits, which are the parameters of a zero-inflated negative binomial (ZINB) distribution. When zinb is False, only the mean is returned.

Parameters:
  • d_model (int) –

    The dimension of the model. This is the size of the input feature vector.

  • nfirst_tokens_to_skip (int, default: 0 ) –

    The number of initial labels to skip in the sequence. Defaults to 0.

  • dropout (float, default: 0.1 ) –

    The dropout rate applied during training to prevent overfitting. Defaults to 0.1.

  • zinb (bool, default: True ) –

    Whether to use a zero inflated negative binomial distribution. Defaults to True.

  • use_depth (bool, default: False ) –

    Whether to use depth as an additional feature. Defaults to False.

Methods:

Name Description
forward

x is the output of the transformer, (batch, seq_len, d_model)

Source code in scprint2/model/decoders.py
def __init__(
    self,
    d_model: int,
    nfirst_tokens_to_skip: int = 0,
    dropout: float = 0.1,
    zinb: bool = True,
    use_depth: bool = False,
):
    """
    Decoder for gene expression prediction.

    Outputs the mean, dispersion, and zero logits, which are the parameters of a zero-inflated negative binomial (ZINB) distribution. When zinb is False, only the mean is returned.

    Args:
        d_model (int): The dimension of the model. This is the size of the input feature vector.
        nfirst_tokens_to_skip (int, optional): The number of initial labels to skip in the sequence. Defaults to 0.
        dropout (float, optional): The dropout rate applied during training to prevent overfitting. Defaults to 0.1.
        zinb (bool, optional): Whether to use a zero inflated negative binomial distribution. Defaults to True.
        use_depth (bool, optional): Whether to use depth as an additional feature. Defaults to False.
    """
    super(ExprDecoder, self).__init__()
    self.fc = nn.Sequential(
        nn.Linear(d_model if not use_depth else d_model + 1, d_model),
        nn.LayerNorm(d_model),
        nn.LeakyReLU(),
        nn.Dropout(dropout),
        nn.Linear(d_model, d_model),
        nn.LayerNorm(d_model),
        nn.LeakyReLU(),
    )
    self.pred_var_zero = nn.Linear(d_model, 3 if zinb else 1)
    self.zinb = zinb

forward

x is the output of the transformer, (batch, seq_len, d_model)

Source code in scprint2/model/decoders.py
def forward(
    self, x: Tensor, req_depth: Optional[Tensor] = None
) -> Dict[str, Tensor]:
    """x is the output of the transformer, (batch, seq_len, d_model)"""
    # we don't do it on the labels
    if req_depth is not None:
        x = torch.cat(
            [x, req_depth.unsqueeze(1).unsqueeze(-1).expand(-1, x.shape[1], -1)],
            dim=-1,
        )
    x = self.fc(x)
    if self.zinb:
        pred_value, var_value, zero_logits = self.pred_var_zero(x).split(
            1, dim=-1
        )  # (batch, seq_len)
        # zero_logits are returned as raw logits; a sigmoid would map them to a probability in (0, 1).
        return dict(
            mean=F.softmax(pred_value.squeeze(-1), dim=-1),
            disp=torch.exp(torch.clamp(var_value.squeeze(-1), max=15)),
            zero_logits=zero_logits.squeeze(-1),
        )
    else:
        pred_value = self.pred_var_zero(x)
        return dict(mean=F.softmax(pred_value.squeeze(-1), dim=-1))
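
To make the three ZINB outputs concrete, the sketch below evaluates the probability mass function of a zero-inflated negative binomial from a mean, an inverse-dispersion parameter, and a zero logit. It is a textbook formula written in plain Python, not code from scprint2. Treating disp as the inverse dispersion theta is an assumption (the source does not say which parameterization its loss uses), and the function name and numeric values are hypothetical.

```python
import math


def zinb_pmf(k: int, mean: float, theta: float, zero_logit: float) -> float:
    """P(X = k) under a zero-inflated negative binomial.

    theta is the NB inverse dispersion (an assumption about what disp encodes).
    """
    pi = 1.0 / (1.0 + math.exp(-zero_logit))  # sigmoid: zero-inflation weight
    log_nb = (
        math.lgamma(k + theta)
        - math.lgamma(theta)
        - math.lgamma(k + 1)
        + theta * math.log(theta / (theta + mean))
        + k * math.log(mean / (theta + mean))
    )
    # Extra mass pi is placed on zero; the rest follows the NB distribution.
    return pi * (k == 0) + (1.0 - pi) * math.exp(log_nb)


total = sum(zinb_pmf(k, mean=3.0, theta=2.0, zero_logit=-1.0) for k in range(500))
p_zero_inflated = zinb_pmf(0, mean=3.0, theta=2.0, zero_logit=-1.0)
p_zero_plain_nb = zinb_pmf(0, mean=3.0, theta=2.0, zero_logit=-50.0)
```

A very negative zero logit switches the inflation off, so the probability of a zero falls back to the plain negative binomial value (theta / (theta + mean)) ** theta.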

GraphSDEExprDecoder

Bases: Module

Initialize the GraphSDEExprDecoder module.

Parameters:
  • d_model (int) –

    The dimension of the model.

  • drift (Module) –

    The drift component of the SDE.

  • diffusion (Module) –

    The diffusion component of the SDE.

Source code in scprint2/model/decoders.py
def __init__(self, d_model: int, drift: nn.Module, diffusion: nn.Module):
    """
    Initialize the GraphSDEExprDecoder module.

    Args:
        d_model (int): The dimension of the model.
        drift (nn.Module): The drift component of the SDE.
        diffusion (nn.Module): The diffusion component of the SDE.
    """
    super().__init__()
    self.d_model = d_model
    self.drift = drift
    self.diffusion = diffusion

MVCDecoder

Bases: Module

Decoder for masked value prediction of cell embeddings.

Combines gene embeddings with cell embeddings to predict the mean, dispersion, and zero logits of a zero-inflated negative binomial distribution.

Parameters:
  • d_model (int) –

    Dimension of the gene embedding.

  • arch_style (str, default: 'inner product' ) –

    Architecture style of the decoder. Options: "inner product": Uses inner product between cell and gene embeddings, "concat query": Concatenates cell and gene embeddings, "sum query": Sums cell and gene embeddings. Defaults to "inner product".

  • tot_labels (int, default: 1 ) –

    Total number of labels in the input. Defaults to 1.

  • query_activation (Module, default: Sigmoid ) –

    Activation function for query vectors. Defaults to nn.Sigmoid.

  • hidden_activation (Module, default: PReLU ) –

    Activation function for hidden layers. Defaults to nn.PReLU.

  • use_depth (bool, default: False ) –

    Whether to use depth as an additional feature. Defaults to False.

  • zinb (bool, default: True ) –

    Whether to use a zero-inflated negative binomial distribution. Defaults to True.

Methods:

Name Description
forward

Predict expression parameters from cell and gene embeddings.

Source code in scprint2/model/decoders.py
def __init__(
    self,
    d_model: int,
    arch_style: str = "inner product",
    tot_labels: int = 1,
    query_activation: nn.Module = nn.Sigmoid,
    hidden_activation: nn.Module = nn.PReLU,
    use_depth: bool = False,
    zinb: bool = True,
) -> None:
    """
    Decoder for masked value prediction of cell embeddings.

    Combines gene embeddings with cell embeddings to predict the mean, dispersion, and
    zero logits of a zero-inflated negative binomial distribution.

    Args:
        d_model (int): Dimension of the gene embedding.
        arch_style (str, optional): Architecture style of the decoder. Options:
            "inner product": Uses inner product between cell and gene embeddings,
            "concat query": Concatenates cell and gene embeddings,
            "sum query": Sums cell and gene embeddings.
            Defaults to "inner product".
        tot_labels (int, optional): Total number of labels in the input. Defaults to 1.
        query_activation (nn.Module, optional): Activation function for query vectors. Defaults to nn.Sigmoid.
        hidden_activation (nn.Module, optional): Activation function for hidden layers. Defaults to nn.PReLU.
        use_depth (bool, optional): Whether to use depth as an additional feature. Defaults to False.
        zinb (bool, optional): Whether to use a zero-inflated negative binomial distribution. Defaults to True.
    """
    super(MVCDecoder, self).__init__()
    if arch_style == "inner product":
        self.gene2query = nn.Linear(
            d_model if not use_depth else d_model + 1, d_model
        )
        self.norm = nn.LayerNorm(d_model)
        self.query_activation = query_activation()
        self.pred_var_zero = nn.Linear(
            d_model, d_model * (3 if zinb else 1), bias=False
        )
    elif arch_style == "concat query":
        self.gene2query = nn.Linear(
            d_model if not use_depth else d_model + 1, d_model
        )
        self.query_activation = query_activation()
        self.fc1 = nn.Linear(d_model * (1 + tot_labels), d_model // 2)
        self.hidden_activation = hidden_activation()
        self.fc2 = nn.Linear(d_model // 2, (3 if zinb else 1))
    elif arch_style == "sum query":
        self.gene2query = nn.Linear(
            d_model if not use_depth else d_model + 1, d_model
        )
        self.query_activation = query_activation()
        self.fc1 = nn.Linear(d_model, 64)
        self.hidden_activation = hidden_activation()
        self.fc2 = nn.Linear(64, (3 if zinb else 1))
    else:
        raise ValueError(f"Unknown arch_style: {arch_style}")

    self.arch_style = arch_style
    self.do_detach = arch_style.endswith("detach")
    self.d_model = d_model
    self.zinb = zinb

forward

Parameters:
  • cell_emb (Tensor) –

    Tensor, shape (batch, embsize=d_model)

  • gene_embs (Tensor) –

    Tensor, shape (batch, seq_len, embsize=d_model)

  • req_depth (Optional[Tensor], default: None ) –

    Tensor, shape (batch,), optional depth information.

Returns:
  • Union[Tensor, Dict[str, Tensor]]

    Dict[str, Tensor]: A dictionary with the key mvc_mean and, if zinb is True, also mvc_disp and mvc_zero_logits.

Source code in scprint2/model/decoders.py
def forward(
    self,
    cell_emb: Tensor,
    gene_embs: Tensor,
    req_depth: Optional[Tensor] = None,
) -> Union[Tensor, Dict[str, Tensor]]:
    """
    Args:
        cell_emb: Tensor, shape (batch, embsize=d_model)
        gene_embs: Tensor, shape (batch, seq_len, embsize=d_model)
        req_depth: Tensor, shape (batch,), optional depth information.

    Returns:
        Dict[str, Tensor]: A dictionary with the key mvc_mean and, if zinb is True, also mvc_disp and mvc_zero_logits.
    """
    if req_depth is not None:
        gene_embs = torch.cat(
            [
                gene_embs,
                req_depth.unsqueeze(1)
                .unsqueeze(-1)
                .expand(-1, gene_embs.shape[1], -1),
            ],
            dim=-1,
        )
    if self.arch_style == "inner product":
        query_vecs = self.query_activation(self.norm(self.gene2query(gene_embs)))
        if self.zinb:
            pred, var, zero_logits = self.pred_var_zero(query_vecs).split(
                self.d_model, dim=-1
            )
        else:
            pred = self.pred_var_zero(query_vecs)
        cell_emb = cell_emb.unsqueeze(2)
        if self.zinb:
            pred, var, zero_logits = (
                torch.bmm(pred, cell_emb).squeeze(2),
                torch.bmm(var, cell_emb).squeeze(2),
                torch.bmm(zero_logits, cell_emb).squeeze(2),
            )
        else:
            pred = torch.bmm(pred, cell_emb).squeeze(2)
        # zero logits need to based on the cell_emb, because of input exprs
    elif self.arch_style == "concat query":
        query_vecs = self.query_activation(self.gene2query(gene_embs))
        # expand cell_emb to (batch, seq_len, embsize)
        cell_emb = cell_emb.unsqueeze(1).expand(-1, gene_embs.shape[1], -1)

        h = self.hidden_activation(
            self.fc1(torch.cat([cell_emb, query_vecs], dim=2))
        )
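        # Drop the trailing singleton dim so the final softmax runs over seq_len
        if self.zinb:
            pred, var, zero_logits = (
                t.squeeze(-1) for t in self.fc2(h).split(1, dim=-1)
            )
        else:
            pred = self.fc2(h).squeeze(-1)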
    elif self.arch_style == "sum query":
        query_vecs = self.query_activation(self.gene2query(gene_embs))
        cell_emb = cell_emb.unsqueeze(1)

        h = self.hidden_activation(self.fc1(cell_emb + query_vecs))
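        # Drop the trailing singleton dim so the final softmax runs over seq_len
        if self.zinb:
            pred, var, zero_logits = (
                t.squeeze(-1) for t in self.fc2(h).split(1, dim=-1)
            )
        else:
            pred = self.fc2(h).squeeze(-1)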
    if self.zinb:
        return dict(
            mvc_mean=F.softmax(pred, dim=-1),
            mvc_disp=torch.exp(torch.clamp(var, max=15)),
            mvc_zero_logits=zero_logits,
        )
    else:
        return dict(mvc_mean=F.softmax(pred, dim=-1))
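
The "inner product" style reduces to a batched dot product between each gene's query vector and the cell embedding, followed by a softmax over genes. The NumPy sketch below shows only that step, with random stand-ins for the learned projections. The sizes are hypothetical, and np.matmul plays the role that torch.bmm has in the listing above.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d_model = 2, 5, 8

cell_emb = rng.normal(size=(batch, d_model))  # one vector per cell
query_vecs = rng.normal(size=(batch, seq_len, d_model))  # one vector per gene

# Batched inner product: [batch, seq_len, d] x [batch, d, 1] -> [batch, seq_len]
logits = np.matmul(query_vecs, cell_emb[:, :, None]).squeeze(2)

# Softmax over the gene axis, so each cell's predicted means sum to one.
shifted = logits - logits.max(axis=-1, keepdims=True)
mvc_mean = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
```

Each entry of logits is a single dot product, so the per-gene prediction depends on the cell only through its embedding vector.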

VAEDecoder

Bases: Module

VAEDecoder for variational autoencoding of cell embeddings.

Parameters:
  • d_model (int) –

    Input dimension (original embedding size)

  • layers (List[int], default: [64, 64] ) –

    List of hidden layer sizes for encoder and decoder

  • activation (Callable, default: ReLU ) –

    Activation function to use

  • dropout (float, default: 0.1 ) –

    Dropout rate

  • return_latent (bool, default: False ) –

    Whether to return the latent vectors

Methods:

Name Description
forward

Forward pass through VAE.

kl_divergence

Compute KL divergence between N(mu, var) and N(0, 1).

reparameterize

Reparameterization trick: sample from N(mu, var) using noise drawn from N(0, 1).

Source code in scprint2/model/decoders.py
def __init__(
    self,
    d_model: int,
    layers: List[int] = [64, 64],
    activation: Callable = nn.ReLU,
    dropout: float = 0.1,
    return_latent: bool = False,
):
    """
    VAEDecoder for variational autoencoding of cell embeddings.

    Args:
        d_model (int): Input dimension (original embedding size)
        layers (List[int]): List of hidden layer sizes for encoder and decoder
        activation (Callable): Activation function to use
        dropout (float): Dropout rate
        return_latent (bool): Whether to return the latent vectors
    """
    super(VAEDecoder, self).__init__()

    # Encoder layers
    self.return_latent = return_latent
    encoder_layers = [d_model] + layers
    self.encoder = nn.Sequential()
    for i, (in_size, out_size) in enumerate(
        zip(encoder_layers[:-1], encoder_layers[1:])
    ):
        self.encoder.append(nn.Linear(in_size, out_size))
        self.encoder.append(nn.LayerNorm(out_size))
        self.encoder.append(activation())
        self.encoder.append(nn.Dropout(dropout))

    # VAE latent parameters
    self.fc_mu = nn.Linear(encoder_layers[-1], encoder_layers[-1])
    self.fc_var = nn.Linear(encoder_layers[-1], encoder_layers[-1])

    # Decoder layers
    decoder_layers = [encoder_layers[-1]] + list(reversed(layers[:-1])) + [d_model]
    self.decoder = nn.Sequential()
    for i, (in_size, out_size) in enumerate(
        zip(
            decoder_layers[:-1], decoder_layers[1:]
        )  # Changed to include final layer
    ):
        self.decoder.append(nn.Linear(in_size, out_size))
        if (
            i < len(decoder_layers) - 2
        ):  # Don't apply activation/norm to final layer
            self.decoder.append(nn.LayerNorm(out_size))
            self.decoder.append(activation())
            self.decoder.append(nn.Dropout(dropout))

forward

Forward pass through VAE.

Parameters:
  • x (Tensor) –

    Input tensor of shape [batch_size, d_model]

Returns:
  • Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]

    If self.return_latent is True, a 5-tuple:
      - reconstructed_x (Tensor): Reconstructed input, shape [batch_size, d_model]
      - mu (Tensor): Mean of the latent Gaussian, shape [batch_size, latent_dim]
      - log_var (Tensor): Log variance of the latent Gaussian, shape [batch_size, latent_dim]
      - encoded (Tensor): Encoder output before the latent projections, shape [batch_size, latent_dim]
      - kl_loss (Tensor): KL divergence loss (scalar tensor)

    Otherwise, a 2-tuple (reconstructed_x, kl_loss).

Source code in scprint2/model/decoders.py
def forward(
    self, x: Tensor
) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
    """
    Forward pass through VAE.

    Args:
        x (Tensor): Input tensor of shape [batch_size, d_model]

    Returns:
        If self.return_latent is True:
            Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
                - reconstructed_x (Tensor): Reconstructed input, shape [batch_size, d_model]
                - mu (Tensor): Mean of the latent Gaussian, shape [batch_size, latent_dim]
                - log_var (Tensor): Log variance of the latent Gaussian, shape [batch_size, latent_dim]
                - encoded (Tensor): Encoder output before the latent projections, shape [batch_size, latent_dim]
                - kl_loss (Tensor): KL divergence loss (scalar tensor)
        Else:
            Tuple[Tensor, Tensor]: (reconstructed_x, kl_loss)
    """
    # Encode
    encoded = self.encoder(x)

    # Get latent parameters
    mu = self.fc_mu(encoded)
    log_var = self.fc_var(encoded)
    log_var = torch.clamp(log_var, min=-10)

    # Sample latent vector
    kl_loss = self.kl_divergence(mu, log_var)
    # free_bits = 2.0  # per latent dim
    # kl_loss = torch.clamp(kl_loss / mu.size(-1), min=free_bits) * mu.size(-1)
    z = self.reparameterize(mu, log_var)

    # Decode
    decoded = self.decoder(z)

    if self.return_latent:
        return decoded, mu, log_var, encoded, kl_loss
    return decoded, kl_loss

kl_divergence

Compute KL divergence between N(mu, var) and N(0, 1).

Parameters:
  • mu (Tensor) –

    Mean of the latent Gaussian

  • log_var (Tensor) –

    Log variance of the latent Gaussian

Returns:
  • Tensor( Tensor ) –

    KL divergence loss

Source code in scprint2/model/decoders.py
def kl_divergence(self, mu: Tensor, log_var: Tensor) -> Tensor:
    """
    Compute KL divergence between N(mu, var) and N(0, 1).

    Args:
        mu (Tensor): Mean of the latent Gaussian
        log_var (Tensor): Log variance of the latent Gaussian

    Returns:
        Tensor: KL divergence loss
    """
    # KL(N(mu, var) || N(0, 1)) = -0.5 * sum(1 + log(var) - mu^2 - var)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    return kl_loss.mean()
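
As a quick check of the closed form in the comment above, the plain-Python sketch below evaluates the same expression for a single sample. The function name and input values are hypothetical. It covers only the per-sample sum, not the batch mean taken by kl_divergence.

```python
import math


def kl_to_standard_normal(mu: list[float], log_var: list[float]) -> float:
    """KL(N(mu, diag(exp(log_var))) || N(0, I)) for one sample."""
    return -0.5 * sum(1 + lv - m**2 - math.exp(lv) for m, lv in zip(mu, log_var))


# Identical distributions: the divergence vanishes.
kl_same = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])
# Shifting one mean by 1 at unit variance costs mu^2 / 2.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])
# Shrinking the variance to 0.25 is also penalized, even at zero mean.
kl_narrow = kl_to_standard_normal([0.0], [math.log(0.25)])
```

The term is zero only when mu is 0 and log_var is 0 in every latent dimension, and it grows as the latent distribution moves away from the standard normal prior.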

reparameterize

Reparameterization trick: sample from N(mu, var) using noise drawn from N(0, 1).

Parameters:
  • mu (Tensor) –

    Mean of the latent Gaussian

  • log_var (Tensor) –

    Log variance of the latent Gaussian

Returns:
  • Tensor( Tensor ) –

    Sampled latent vector

Source code in scprint2/model/decoders.py
def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
    """
    Reparameterization trick: sample from N(mu, var) using noise drawn from N(0, 1).

    Args:
        mu (Tensor): Mean of the latent Gaussian
        log_var (Tensor): Log variance of the latent Gaussian

    Returns:
        Tensor: Sampled latent vector
    """
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std
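
To illustrate why `mu + eps * std` yields samples from N(mu, var), here is a small scalar sketch using only the standard library. It is an illustration under stated assumptions rather than the library's tensor code: the function takes an explicit `rng` argument (a hypothetical addition) so the check is repeatable.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Draw z = mu + eps * std with eps ~ N(0, 1); scalar inputs for illustration."""
    std = math.exp(0.5 * log_var)
    eps = rng.gauss(0.0, 1.0)
    return mu + eps * std


rng = random.Random(0)  # fixed seed so the check is repeatable
mu, log_var = 2.0, math.log(0.25)  # variance 0.25, so std 0.5
samples = [reparameterize(mu, log_var, rng) for _ in range(20000)]

# Empirical mean and standard deviation should be close to 2.0 and 0.5.
sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))
```

All randomness is isolated in `eps`, so in the tensor version gradients can flow through `mu` and `log_var`.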

scprint2.model.fsq

Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505

Code adapted from the Jax version in Appendix A.1.

Classes:

Name Description
FSQ

Functions:

Name Description
round_ste

Round with straight through gradients.

FSQ

Bases: Module

Methods:

Name Description
bound

Bound z, an array of shape (..., d).

codes_to_indices

Converts a code to an index in the codebook.

forward

Projects z, quantizes it, and returns (out, indices, small).

indices_to_codes

Inverse of codes_to_indices.

quantize

Quantizes z, returns quantized zhat, same shape as z.

Source code in scprint2/model/fsq.py
def __init__(
    self,
    levels: List[int],
    dim: Optional[int] = None,
    num_codebooks=1,
    keep_num_codebooks_dim: Optional[bool] = None,
    scale: Optional[float] = None,
):
    super().__init__()
    _levels = torch.tensor(levels, dtype=int32)
    self.register_buffer("_levels", _levels, persistent=False)

    _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
    self.register_buffer("_basis", _basis, persistent=False)

    self.scale = scale

    codebook_dim = len(levels)
    self.codebook_dim = codebook_dim

    effective_codebook_dim = codebook_dim * num_codebooks
    self.num_codebooks = num_codebooks
    self.effective_codebook_dim = effective_codebook_dim

    keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
    assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
    self.keep_num_codebooks_dim = keep_num_codebooks_dim

    self.dim = default(dim, len(_levels) * num_codebooks)

    has_projections = self.dim != effective_codebook_dim
    self.project_in = (
        nn.Linear(self.dim, effective_codebook_dim)
        if has_projections
        else nn.Identity()
    )
    self.project_out = (
        nn.Linear(effective_codebook_dim, self.dim)
        if has_projections
        else nn.Identity()
    )
    self.has_projections = has_projections

    self.codebook_size = self._levels.prod().item()

    implicit_codebook = self.indices_to_codes(
        torch.arange(self.codebook_size), project_out=False
    )
    self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

bound

Bound z, an array of shape (..., d).

Source code in scprint2/model/fsq.py
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
    """Bound `z`, an array of shape (..., d)."""
    half_l = (self._levels - 1) * (1 - eps) / 2
    offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
    shift = (offset / half_l).tan()
    return (z + shift).tanh() * half_l - offset

codes_to_indices

Converts a code to an index in the codebook.

Source code in scprint2/model/fsq.py
def codes_to_indices(self, zhat: Tensor) -> Tensor:
    """Converts a `code` to an index in the codebook."""
    assert zhat.shape[-1] == self.codebook_dim
    zhat = self._scale_and_shift(zhat)
    return (zhat * self._basis).sum(dim=-1).to(int32)

forward

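Quantizes z and returns (out, indices, small): the quantized codes after the output projection, the codebook indices, and the input after the input projection (before quantization).

einstein notation:
  • b - batch
  • n - sequence (or flattened spatial dimensions); not used by the patterns in this method, which expect input of shape (b, dim)
  • d - feature dimension per codebook, i.e. len(levels)
  • c - number of codebooks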

Source code in scprint2/model/fsq.py
def forward(self, z: Tensor) -> Tensor:
    """
    Quantize `z` of shape (b, dim) and return `(out, indices, small)`.

    einstein notation
    b - batch
    n - sequence (or flattened spatial dimensions); not used by the patterns below
    d - feature dimension per codebook, i.e. len(levels)
    c - number of codebooks
    """
    assert z.shape[-1] == self.dim, (
        f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
    )
    small = self.project_in(z)

    z = rearrange(small, "b (c d) -> b c d", c=self.num_codebooks)

    codes = self.quantize(z)
    indices = self.codes_to_indices(codes)

    codes = rearrange(codes, "b c d -> b (c d)")

    out = self.project_out(codes)

    if not self.keep_num_codebooks_dim:
        indices = rearrange(indices, "... 1 -> ...")

    return out, indices, small

indices_to_codes

Inverse of codes_to_indices.

Source code in scprint2/model/fsq.py
def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor:
    """Inverse of `codes_to_indices`."""
    indices = rearrange(indices, "... -> ... 1")
    codes_non_centered = (indices // self._basis) % self._levels
    codes = self._scale_and_shift_inverse(codes_non_centered)

    if self.keep_num_codebooks_dim:
        codes = rearrange(codes, "... c d -> ... (c d)")

    if project_out:
        codes = self.project_out(codes)
    return codes
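
The index arithmetic used by `codes_to_indices` and `indices_to_codes` is a mixed-radix encoding: `_basis` holds the cumulative product of the preceding levels, and each digit is recovered with `(index // basis) % levels`. The sketch below reproduces just that arithmetic in pure Python. It works on non-centered integer digits, since the centering done by `_scale_and_shift` is not shown in this section; the helper names and the example `levels` are chosen for illustration only.

```python
from itertools import accumulate
from operator import mul

levels = [8, 5, 5, 5]  # example levels, chosen for illustration

# basis[i] is the product of all levels before position i.
basis = list(accumulate([1] + levels[:-1], mul))

# The implicit codebook has one entry per combination of digits.
codebook_size = 1
for n_levels in levels:
    codebook_size *= n_levels


def digits_to_index(digits):
    """Mixed-radix encoding of non-centered digits, digit i in [0, levels[i])."""
    return sum(d * b for d, b in zip(digits, basis))


def index_to_digits(index):
    """Inverse of digits_to_index."""
    return [(index // b) % n for b, n in zip(basis, levels)]


# Every index in the codebook survives a decode/encode round trip.
round_trip_ok = all(
    digits_to_index(index_to_digits(i)) == i for i in range(codebook_size)
)
```

Because the codebook is implicit in this arithmetic, no embedding table has to be learned or stored; the `implicit_codebook` buffer in `__init__` is simply the decoding of every index.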

quantize

Quantizes z, returns quantized zhat, same shape as z.

Source code in scprint2/model/fsq.py
def quantize(self, z: Tensor) -> Tensor:
    """Quantizes z, returns quantized zhat, same shape as z."""
    quantized = round_ste(self.bound(z))
    half_width = self._levels // 2  # Renormalize to [-1, 1].
    return quantized / half_width
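
To see what `bound` followed by rounding and renormalization produces, the following scalar sketch applies the same formulas to a single dimension with plain floats. It is an illustration only: it takes the number of levels as an argument instead of reading `self._levels`, and it uses Python's `round` where the tensor code uses `round_ste`, so it carries no gradient.

```python
import math


def bound(z, n_levels, eps=1e-3):
    """Squash z so that rounding yields exactly n_levels distinct integers."""
    half_l = (n_levels - 1) * (1 - eps) / 2
    offset = 0.5 if n_levels % 2 == 0 else 0.0  # half-step shift for even level counts
    shift = math.tan(offset / half_l)
    return math.tanh(z + shift) * half_l - offset


def quantize(z, n_levels):
    """Round the bounded value, then renormalize by the half width."""
    half_width = n_levels // 2
    return round(bound(z, n_levels)) / half_width


# Sweep inputs from -6.0 to 6.0 and collect the distinct codes that appear.
inputs = [x / 10 for x in range(-60, 61)]
codes_odd = sorted({quantize(z, 5) for z in inputs})
codes_even = sorted({quantize(z, 4) for z in inputs})
```

With 5 levels the codes fall on a symmetric five-point grid in [-1, 1]. With 4 levels the grid has four points and is asymmetric, because the half-step offset places one more code below zero than above it.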

round_ste

Round with straight through gradients.

Source code in scprint2/model/fsq.py
def round_ste(z: Tensor) -> Tensor:
    """Round with straight through gradients."""
    zhat = z.round()
    return z + (zhat - z).detach()