Documentation for the model

Model description

scprint.model.model

scPrint

Bases: LightningModule, PyTorchModelHubMixin

scPRINT transformer for single-cell biology and the inference of gene regulatory networks

Parameters:
  • genes (list) –

    List of gene names the model will work with.

  • precpt_gene_emb (str, default: None ) –

    Path to a parquet file containing gene embeddings of size (len(genes), d_model). Should be in the same order as the genes. Defaults to None.

  • gene_pos_enc (list, default: None ) –

    Gene position encoding of the same size as genes. Provides a location value for each gene in genes. Defaults to None.

  • d_model (int, default: 512 ) –

    Dimension of the model. Defaults to 512.

  • nhead (int, default: 8 ) –

    Number of heads in the multihead attention models. Defaults to 8.

  • d_hid (int, default: 512 ) –

    Dimension of the feedforward network model. Defaults to 512.

  • nlayers (int, default: 6 ) –

    Number of layers in the transformer model. Defaults to 6.

  • expr_encoder_layers (int, default: 2 ) –

    Number of layers in the expression encoder. Defaults to 2.

  • layers_cls (list[int], default: [] ) –

    List specifying the number of layers in the classifier. Defaults to [].

  • classes (Dict[str, int], default: {} ) –

    Classes to predict with the number of classes for each. Defaults to {}.

  • labels_hierarchy (Dict[str, Dict[int, list[int]]], default: {} ) –

    Class hierarchy for classes with hierarchical classes. Defaults to {}.

  • dropout (float, default: 0.2 ) –

    Dropout value. Defaults to 0.2.

  • transformer (str, default: 'fast' ) –

    Transformer type to use. One of "linear", "flash", "flashsparse", "scprint". Defaults to "fast".

  • domain_spec_batchnorm (str, default: 'None' ) –

    Whether to apply domain-specific batch normalization. Defaults to "None".

  • expr_emb_style (str, default: 'continuous' ) –

    Style of input embedding. One of "continuous", "binned_pos", "cont_pos". Defaults to "continuous".

  • mvc_decoder (str, default: 'None' ) –

    Style of MVC decoder. One of "None", "inner product", "concat query", "sum query". Defaults to "None".

  • pred_embedding (list[str], default: [] ) –

    List of classes to use for plotting embeddings. Defaults to [].

  • cell_emb_style (str, default: 'cls' ) –

    Style of cell embedding. One of "cls", "avg-pool", "w-pool". Defaults to "cls".

  • freeze_embeddings (bool, default: True ) –

    Whether to freeze the embeddings during training. Defaults to True.

  • label_decoders (Optional[Dict[str, Dict[int, str]]], default: None ) –

    Label decoders to use for plotting the UMAP during validations. Defaults to None.

  • zinb (bool, default: True ) –

    Whether to use Zero-Inflated Negative Binomial distribution. Defaults to True.

  • lr (float, default: 0.0001 ) –

    Learning rate. Defaults to 0.0001.

  • optim (str, default: 'adamW' ) –

    Optimizer type. Defaults to "adamW".

  • weight_decay (float, default: 0.01 ) –

    Weight decay for the optimizer. Defaults to 0.01.

  • **flash_attention_kwargs (dict, default: {} ) –

    Additional keyword arguments passed to the transformer; see @flashformer.py

Notes

For other parameters of the model that are not part of its class definition, see @trainer.trainer.py

Raises:
  • ValueError

    If the expr_emb_style is not one of "continuous", "binned_pos", "cont_pos".
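
A minimal construction sketch, for orientation only: the gene identifiers and class counts below are hypothetical placeholders, and every argument not shown keeps the default documented above.

from scprint.model.model import scPrint

# hypothetical gene vocabulary and label set, for illustration only
genes = ["ENSG00000121410", "ENSG00000175899", "ENSG00000166710"]
model = scPrint(
    genes=genes,
    d_model=512,
    nhead=8,
    nlayers=6,
    classes={"cell_type_ontology_term_id": 10},  # label name -> number of classes
    transformer="flash",
)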

Source code in scprint/model/model.py
def __init__(
    self,
    genes: list,
    organisms: list = ["NCBITaxon:9606"],
    precpt_gene_emb: Optional[str] = None,
    gene_pos_enc: Optional[list] = None,
    normalization: str = "sum",
    d_model: int = 512,
    nhead: int = 8,
    attn_bias: str = "none",
    d_hid: int = 512,
    edge_dim: int = 12,
    nlayers: int = 6,
    expr_encoder_layers: int = 2,
    layers_cls: list[int] = [],
    classes: Dict[str, int] = {},
    labels_hierarchy: Dict[str, Dict[int, list[int]]] = {},
    dropout: float = 0.2,
    transformer: str = "fast",
    expr_emb_style: str = "continuous",  # "binned_pos", "cont_pos"
    domain_spec_batchnorm: str = "None",
    n_input_bins: int = 0,
    num_batch_labels: int = 0,
    mvc_decoder: str = "None",
    pred_embedding: list[str] = [],
    cell_emb_style: str = "cls",
    freeze_embeddings: bool = True,
    label_decoders: Optional[Dict[str, Dict[int, str]]] = None,
    zinb: bool = True,
    lr: float = 0.0001,
    optim="adamW",  # TODEL
    weight_decay=0.01,  # TODEL
    **flash_attention_kwargs,
):
    """
    scPRINT transformer for single cell biology and the inference of Gene Regulatory networks

    Args:
        genes (list): List of gene names the model will work with.
        precpt_gene_emb (str, optional): Path to a parquet file containing gene embeddings of size (len(genes), d_model). Should be in the same order as the genes. Defaults to None.
        gene_pos_enc (list, optional): Gene position encoding of the same size as genes. Provides a location value for each gene in genes. Defaults to None.
        d_model (int, optional): Dimension of the model. Defaults to 512.
        nhead (int, optional): Number of heads in the multihead attention models. Defaults to 8.
        d_hid (int, optional): Dimension of the feedforward network model. Defaults to 512.
        nlayers (int, optional): Number of layers in the transformer model. Defaults to 6.
        expr_encoder_layers (int, optional): Number of layers in the expression encoder. Defaults to 2.
        layers_cls (list[int], optional): List specifying the number of layers in the classifier. Defaults to [].
        classes (Dict[str, int], optional): Classes to predict with the number of classes for each. Defaults to {}.
        labels_hierarchy (Dict[str, Dict[int, list[int]]], optional): Class hierarchy for classes with hierarchical classes. Defaults to {}.
        dropout (float, optional): Dropout value. Defaults to 0.2.
        transformer (str, optional): Transformer type to use. One of "linear", "flash", "flashsparse", "scprint". Defaults to "fast".
        domain_spec_batchnorm (str, optional): Whether to apply domain-specific batch normalization. Defaults to "None".
        expr_emb_style (str, optional): Style of input embedding. One of "continuous", "binned_pos", "cont_pos". Defaults to "continuous".
        mvc_decoder (str, optional): Style of MVC decoder. One of "None", "inner product", "concat query", "sum query". Defaults to "None".
        pred_embedding (list[str], optional): List of classes to use for plotting embeddings. Defaults to [].
        cell_emb_style (str, optional): Style of cell embedding. One of "cls", "avg-pool", "w-pool". Defaults to "cls".
        freeze_embeddings (bool, optional): Whether to freeze the embeddings during training. Defaults to True.
        label_decoders (Optional[Dict[str, Dict[int, str]]], optional): Label decoders to use for plotting the UMAP during validations. Defaults to None.
        zinb (bool, optional): Whether to use Zero-Inflated Negative Binomial distribution. Defaults to True.
        lr (float, optional): Learning rate. Defaults to 0.0001.
        optim (str, optional): Optimizer type. Defaults to "adamW".
        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 0.01.
        **flash_attention_kwargs (dict): Additional keyword arguments for the model. see @flashformer.py

    Notes:
        for other parameters of the model that are not part of its class definition, see @trainer.trainer.py

    Raises:
        ValueError: If the expr_emb_style is not one of "continuous", "binned_pos", "cont_pos".
    """
    super().__init__()
    # training flags
    self.do_denoise = True
    self.noise = [0.6]
    self.do_cce = False
    self.cce_sim = 0.5
    self.cce_scale = 0.002
    self.do_ecs = False
    self.ecs_threshold = 0.3
    self.ecs_scale = 0.05
    self.do_mvc = False
    self.mvc_scale = 1.0
    self.class_embd_diss_scale = 0.2
    self.do_adv_cls = False
    self.adv_class_scale = 0.1
    self.do_cls = False
    self.mean_attn_tot = None
    self.mean_attn_tot_c = 0
    self.do_adv_batch = False
    self.run_full_forward = True
    self.class_scale = 0.4
    self.do_next_tp = False
    self.do_generate = False
    self.mask_ratio = []
    self.warmup_duration = 500
    self.weight_decay = 0.01
    self.optim = "adamW"
    self.fused_adam = False
    self.lr_reduce_patience = 1
    self.lr_reduce_factor = 0.6
    self.lr_reduce_monitor = "val_loss"
    self.name = ""
    self.lr = lr
    self.lrfinder_steps = 0
    self.doplot = True
    self.get_attention_layer = []
    self.embs = None
    self.pred_log_adata = True
    self.attn = utils.Attention(len(classes) + 2 + len(genes))
    self.predict_depth_mult = 3
    self.predict_mode = "none"
    self.keep_all_cls_pred = False
    # should be stored somehow
    self.d_model = d_model
    self.normalization = normalization
    self.organisms = organisms
    self.edge_dim = edge_dim
    self.attn_bias = attn_bias
    self.nlayers = nlayers
    self.gene_pos_enc = gene_pos_enc
    self.mvc_decoder = mvc_decoder
    self.domain_spec_batchnorm = domain_spec_batchnorm
    # need to store
    self.n_input_bins = n_input_bins
    self.transformer = transformer
    self.label_counts = classes
    self.classes = list(classes.keys())
    self.cell_emb_style = cell_emb_style
    self.label_decoders = label_decoders
    self.pred_embedding = pred_embedding
    # compute tensor for mat_labels_hierarchy
    self.mat_labels_hierarchy = {}
    self.labels_hierarchy = labels_hierarchy
    if "strict_loading" in flash_attention_kwargs:
        flash_attention_kwargs.pop("strict_loading")

    for k, v in labels_hierarchy.items():
        tens = torch.zeros((len(v), classes[k]))
        for k2, v2 in v.items():
            tens[k2 - classes[k], v2] = 1
        self.mat_labels_hierarchy[k] = tens.to(bool)
    self.expr_emb_style = expr_emb_style

    if self.expr_emb_style not in ["category", "continuous", "none"]:
        raise ValueError(
            f"expr_emb_style should be one of category, continuous, scaling, "
            f"got {expr_emb_style}"
        )
    if cell_emb_style not in ["cls", "avg-pool", "w-pool"]:
        raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}")

    self.genes = genes
    self.vocab = {i: n for i, n in enumerate(genes)}

    # encoder
    # gene encoder
    if precpt_gene_emb is not None:
        embeddings = pd.read_parquet(precpt_gene_emb).loc[self.genes]
        if len(embeddings) == 0:
            raise ValueError(
                f"the gene embeddings file {precpt_gene_emb} does not contain any of the genes given to the model"
            )
        elif len(embeddings) < len(self.genes):
            print(
                "Warning: only a subset of the genes available in the embeddings file."
            )
            print("number of genes: ", len(embeddings))
        sembeddings = torch.nn.AdaptiveAvgPool1d(d_model)(
            torch.tensor(embeddings.values)
        )

        self.gene_encoder = encoders.GeneEncoder(
            len(self.vocab), d_model, weights=sembeddings, freeze=freeze_embeddings
        )
    else:
        self.gene_encoder = encoders.GeneEncoder(len(self.vocab), d_model)

    # Value Encoder, NOTE: the scaling style is also handled in _encode method
    if expr_emb_style in ["continuous", "full_pos"]:
        self.expr_encoder = encoders.ContinuousValueEncoder(
            d_model, dropout, layers=expr_encoder_layers
        )
    elif expr_emb_style == "binned_pos":
        assert n_input_bins > 0
        self.expr_encoder = encoders.CategoryValueEncoder(n_input_bins, d_model)
    else:
        self.expr_encoder = torch.nn.Identity()

    # Positional Encoding
    if self.gene_pos_enc is not None:
        max_len = max(gene_pos_enc)
        token_to_pos = {token: pos for token, pos in enumerate(self.gene_pos_enc)}
        self.pos_encoder = encoders.PositionalEncoding(
            d_model, max_len=max_len, token_to_pos=token_to_pos
        )

    self.cell_embs_count = len(self.classes) + 2
    # Class Encoder
    # always have [base_cell_emb, time_embedding, depth_embedding] + any other class info
    # base cell embedding will store other cell specific information
    self.class_encoder = encoders.CategoryValueEncoder(
        self.cell_embs_count - 1, d_model
    )
    # self.time_encoder = encoders.ContinuousValueEncoder(d_model, dropout)
    self.depth_encoder = encoders.ContinuousValueEncoder(
        d_model, dropout, layers=expr_encoder_layers
    )

    # Transformer
    # Linear
    if transformer == "linear":
        # linear transformer using the fast transformer package
        # self.transformer = FastTransformerEncoder(
        #    d_model, nhead, d_hid, nlayers, dropout, "linear"
        # )
        raise NotImplementedError("Linear transformer is not implemented")
    # regular or flash
    else:
        self.transformer = FlashTransformerEncoder(
            d_model,
            nhead,
            nlayers,
            dropout=dropout,
            use_flash_attn=(transformer == "flash"),
            **flash_attention_kwargs,
        )

    # decoders
    # expression
    self.expr_decoder = decoders.ExprDecoder(
        d_model,
        nfirst_tokens_to_skip=self.cell_embs_count,
        dropout=dropout,
        zinb=zinb,
    )
    # cls decoder
    self.cls_decoders = torch.nn.ModuleDict()
    # should be a very simple classifier for most things
    # (maybe scale with the number of classes) should be 1 layer...
    for clss, n_cls in classes.items():
        self.cls_decoders[clss] = decoders.ClsDecoder(
            d_model, n_cls, layers=layers_cls, dropout=dropout
        )

    # Batch effect correction via adversarial training on batch classes
    if num_batch_labels > 0:
        self.grad_reverse_discriminator_loss = loss.AdversarialDiscriminatorLoss(
            d_model,
            n_cls=num_batch_labels,
        )
    else:
        self.grad_reverse_discriminator_loss = None

    # expression decoder from batch embbedding
    if mvc_decoder != "None":
        self.mvc_decoder = decoders.MVCDecoder(
            d_model,
            arch_style=mvc_decoder,
        )
    else:
        self.mvc_decoder = None

    self.apply(
        partial(
            utils._init_weights,
            n_layer=nlayers,
        )
    )
    for i, dec in self.cls_decoders.items():
        torch.nn.init.constant_(dec.out_layer.bias, -0.13)
    self.save_hyperparameters()

configure_optimizers

@see pl.LightningModule

Source code in scprint/model/model.py
def configure_optimizers(self):
    """@see pl.LightningModule"""
    # https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
    # not working because of poor weight decay implem
    if self.optim == "adam":
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=self.weight_decay,
            amsgrad=False,
            fused=self.fused_adam,
        )
    elif self.optim == "adamW":
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=self.weight_decay,
            amsgrad=False,
            fused=self.fused_adam,
        )
    elif self.optim == "galore":
        raise NotImplementedError("Galore optimizer not implemented")
        # param_groups = [
        #    {
        #        "params": [
        #            v for k, v in self.named_parameters() if "transformer" not in k
        #        ]
        #    },
        #    {
        #        "params": [
        #            v for k, v in self.named_parameters() if "transformer" in k
        #        ],
        #        "rank": 128,
        #        "update_proj_gap": 200,
        #        "scale": 0.25,
        #        "proj_type": "std",
        #    },
        # ]
        # optimizer = GaLoreAdamW(param_groups, lr=self.hparams.lr)
    else:
        raise ValueError(f"Unknown optimizer: {self.optim}")
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        patience=self.lr_reduce_patience,
        factor=self.lr_reduce_factor,
        verbose=True,
    )
    lr_dict = {
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after a optimizer update.
        "interval": "epoch",
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": 1,
        # Metric to to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": self.lr_reduce_monitor,
    }
    self.lrfinder_steps = 0
    for val in self.trainer.callbacks:
        if type(val) is _LRCallback:
            self.lrfinder_steps = val.num_training
        if type(val) is LearningRateFinder:
            self.lrfinder_steps = val._num_training_steps
    return [optimizer], [lr_dict]
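
Because the optimizer settings above are read from plain attributes (self.optim, self.weight_decay, self.lr_reduce_patience, self.lr_reduce_factor, ...), they can be adjusted on an existing model before training. A hedged sketch, assuming the scPrint instance `model` from the construction example near the top of this page:

# hypothetical tweak: switch to plain Adam and a gentler LR schedule before trainer.fit
model.optim = "adam"
model.weight_decay = 0.0
model.lr_reduce_patience = 2
model.lr_reduce_factor = 0.5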

forward

forward also called on self(), a full forward pass on the model

Parameters:
  • gene_pos (Tensor) –

    A tensor of shape (minibatch, seq_len) representing the genes used for each cell in the minibatch.

  • expression (Tensor, default: None ) –

    A tensor of shape (minibatch, seq_len) representing the expression levels of genes in the minibatch. Defaults to None.

  • mask (Tensor, default: None ) –

    A tensor of shape (minibatch, seq_len) used to mask certain elements in the sequence during the forward pass. Defaults to None.

  • full_depth (Tensor, default: None ) –

    A tensor of shape (minibatch,) representing the full depth of each sequence in the minibatch. Defaults to None.

  • timepoint (Tensor, default: None ) –

    A tensor of shape (minibatch,) representing the timepoint associated with each sequence in the minibatch. Defaults to None.

  • get_gene_emb (bool, default: False ) –

    A flag indicating whether to return the gene embeddings. If True, the gene embeddings are included in the output. Defaults to False.

  • do_sample (bool, default: False ) –

    A flag indicating whether to sample the expression levels. If True, the expression levels are sampled during the forward pass. Defaults to False.

  • get_attention_layer (list, default: [] ) –

    A list indicating which attention layers to return. If not empty, the specified attention layers are included in the output. Defaults to [].

Returns:
  • dict of output Tensors: A dictionary containing the output tensors from the forward pass. The keys of the dictionary depend on the input flags (get_gene_emb, do_sample, get_attention_layer). At a minimum, the dictionary contains the following:

    - "mean": the mean expression levels
    - "zero_logits": the logits for zero-inflated expression levels
    - "disp": the dispersion parameter
    - "cell_embs": the cell embeddings per class
    - "cell_emb": the main cell embedding
    - "cls_output": the output of the classifier

Source code in scprint/model/model.py
def forward(
    self,
    gene_pos: Tensor,
    expression: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
    full_depth: Optional[Tensor] = None,
    timepoint: Optional[Tensor] = None,  # (new_minibatch_of_nxt_cells,)
    get_gene_emb: bool = False,
    depth_mult: Optional[Tensor] = None,
    do_sample: bool = False,
    do_mvc: bool = False,
    do_class: bool = False,
    get_attention_layer: list = [],
):
    """
    forward also called on self(), a full forward pass on the model

    Args:
        gene_pos (Tensor): A tensor of shape (minibatch, seq_len)
            representing the genes used for each cell in the minibatch.
        expression (Tensor, optional): A tensor of shape (minibatch, seq_len)
            representing the expression levels of genes in the minibatch. Defaults to None.
        mask (Tensor, optional): A tensor of shape (minibatch, seq_len)
            used to mask certain elements in the sequence during the forward pass. Defaults to None.
        full_depth (Tensor, optional): A tensor of shape (minibatch,)
            representing the full depth of each sequence in the minibatch. Defaults to None.
        timepoint (Tensor, optional): A tensor of shape (minibatch,)
            representing the timepoint associated with each sequence in the minibatch. Defaults to None.
        get_gene_emb (bool, optional): A flag indicating whether to return the gene embeddings.
            If True, the gene embeddings are included in the output. Defaults to False.
        do_sample (bool, optional): A flag indicating whether to sample the expression levels.
            If True, the expression levels are sampled during the forward pass. Defaults to False.
        get_attention_layer (list, optional): A list indicating which attention layers to return.
            If not empty, the specified attention layers are included in the output. Defaults to [].

    Returns:
        dict of output Tensors: A dictionary containing the output tensors from the forward pass.
            The keys of the dictionary depend on the input flags (get_gene_emb, do_sample, get_attention_layer).
            At a minimum, the dictionary contains the following:
            - "mean": the mean expression levels
            - "zero_logits": the logits for zero-inflated expression levels
            - "disp": the dispersion parameter
            - "cell_embs": the cell embeddings per class
            - "cell_emb": the main cell embedding
            - "cls_output": the output of the classifier
    """
    encoding = self._encoder(gene_pos, expression, mask, full_depth, timepoint)

    if self.attn_bias != "none":
        if not hasattr(self, "nbias"):
            self.nbias = torch.Tensor(
                load_npz(FILEDIR + "/../../data/bias_sparse.npz").todense()
            ).to(device=gene_pos.device, dtype=torch.float16)
        num = len(self.classes) + 2
        bias = torch.zeros(
            (
                gene_pos.shape[0],
                gene_pos.shape[1] + num,
                gene_pos.shape[1] + num,
            ),
            device=gene_pos.device,
            dtype=torch.float16,
        )
        bias[:, num:, :num] = -10_000  # do not pay attention to the cls embeddings
        bias[:, num:, num:] = self.nbias[gene_pos[:, :, None], gene_pos[:, None, :]]
    transformer_output = self.transformer(
        encoding,
        return_qkv=get_attention_layer,
        bias=bias if self.attn_bias != "none" else None,
        bias_layer=list(range(self.nlayers - 1)),
    )

    depth_mult = expression.sum(1) if depth_mult is None else depth_mult
    if len(get_attention_layer) > 0:
        transformer_output, qkvs = transformer_output
        return (
            self._decoder(
                transformer_output,
                depth_mult,
                get_gene_emb,
                do_sample,
                do_mvc,
                do_class,
            ),
            qkvs,
        )
    else:
        return self._decoder(
            transformer_output,
            depth_mult,
            get_gene_emb,
            do_sample,
            do_mvc,
            do_class,
        )
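
A hedged sketch of calling forward directly, assuming the scPrint instance `model` from the construction example near the top of this page; the batch size and sequence length are arbitrary, and gene_pos holds indices into model.genes:

import torch

gene_pos = torch.randint(0, len(model.genes), (4, 128))  # (minibatch, seq_len)
expression = torch.rand(4, 128) * 10                     # (minibatch, seq_len)
out = model(gene_pos, expression, full_depth=expression.sum(1))
print(out["mean"].shape, out["cell_emb"].shape)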

get_cell_embs

get_cell_embs

Parameters:
  • layer_output (Tensor) –

    The output tensor from a layer in the model.

Raises:
  • ValueError

    Raised when an unknown cell embedding style is encountered.

Returns:
  • Tensor

    The cell embeddings tensor.

Source code in scprint/model/model.py
def get_cell_embs(self, layer_output):
    """
    get_cell_embs

    Args:
        layer_output (Tensor): The output tensor from a layer in the model.

    Raises:
        ValueError: Raised when an unknown cell embedding style is encountered.

    Returns:
        Tensor: The cell embeddings tensor.
    """
    if self.cell_emb_style == "cls" and self.classes is not None:
        # (minibatch, embsize)
        cell_emb = layer_output[:, : 2 + len(self.classes)]
    elif self.cell_emb_style == "avg-pool":
        cell_emb = torch.mean(layer_output, dim=1)
    else:
        raise ValueError(f"Unknown cell_emb_style: {self.cell_emb_style}")
    return cell_emb

log_adata

log_adata will log an adata from predictions. It will log to tensorboard and wandb if available

see @utils.log_adata

Source code in scprint/model/model.py
def log_adata(self, gtclass=None, name=""):
    """
    log_adata will log an adata from predictions.
    It will log to tensorboard and wandb if available

    see @utils.log_adata
    """
    try:
        mdir = self.logger.save_dir if self.logger.save_dir is not None else "/tmp"
    except:
        mdir = "data/"
    if not os.path.exists(mdir):
        os.makedirs(mdir)
    adata, fig = utils.make_adata(
        self.embs,
        self.classes,
        self.pred if not self.keep_all_cls_pred else None,
        self.attn.get(),
        self.global_step,
        self.label_decoders,
        self.labels_hierarchy,
        gtclass,
        self.name + "_" + name + "_" + str(self.global_rank),
        mdir,
        self.doplot,
    )
    if self.doplot:
        try:
            self.logger.experiment.add_figure(fig)
        except:
            print("couldn't log to tensorboard")
        try:
            self.logger.log_image(key="umaps", images=[fig])
        except:
            print("couldn't log to wandb")

    return adata

on_fit_start

@see pl.LightningModule

Source code in scprint/model/model.py
def on_fit_start(self):
    """@see pl.LightningModule"""
    if type(self.transformer) is FlashTransformerEncoder:
        for encoder_layers in self.transformer.blocks:
            encoder_layers.set_seq_parallel(True)
    for k, v in self.mat_labels_hierarchy.items():
        self.mat_labels_hierarchy[k] = v.to(self.device)

on_predict_epoch_end

@see pl.LightningModule

Source code in scprint/model/model.py
def on_predict_epoch_end(self):
    """@see pl.LightningModule will"""
    if self.pos.shape[0] < 100:
        return
    if self.pred_log_adata:
        print("adding on disk")
        return self.log_adata(name="predict_part_" + str(self.counter))

on_predict_epoch_start

@see pl.LightningModule

Source code in scprint/model/model.py
def on_predict_epoch_start(self):
    """@see pl.LightningModule"""
    self.embs = None
    self.attn.data = None
    self.attn.attn = None
    self.counter = 0
    if type(self.transformer) is FlashTransformerEncoder:
        for encoder_layers in self.transformer.blocks:
            encoder_layers.set_seq_parallel(False)

on_validation_epoch_end

@see pl.LightningModule

Source code in scprint/model/model.py
def on_validation_epoch_end(self):
    """@see pl.LightningModule"""
    self.embs = self.all_gather(self.embs).view(-1, self.embs.shape[-1])
    self.info = self.all_gather(self.info).view(-1, self.info.shape[-1])
    self.pred = (
        self.all_gather(self.pred).view(-1, self.pred.shape[-1])
        if self.pred is not None
        else None
    )
    self.pos = self.all_gather(self.pos).view(-1, self.pos.shape[-1])
    if not self.trainer.is_global_zero:
        # print("you are not on the main node. cancelling logging step")
        return
    if self.trainer.state.stage != "sanity_check":
        sch = self.lr_schedulers()
        sch.step(self.trainer.callback_metrics["val_loss"])
        # run the test function on specific dataset
        self.log_adata(
            gtclass=self.info, name="validation_part_" + str(self.counter)
        )
        if (self.current_epoch + 1) % 30 == 0:
            self.on_test_epoch_end()

optimizer_step

@see pl.LightningModule

Source code in scprint/model/model.py
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    """@see pl.LightningModule"""
    # update params
    optimizer.step(closure=optimizer_closure)

    # manually warm up lr without a scheduler
    # making sure that we don't do this during lrfinder
    for i, pg in enumerate(optimizer.param_groups):
        if (
            self.global_step < self.warmup_duration + self.lrfinder_steps
        ) and self.lrfinder_steps < self.global_step:
            lr_scale = min(1.0, float(self.global_step + 1) / self.warmup_duration)
            pg["lr"] = lr_scale * self.hparams.lr
    for i, pg in enumerate(optimizer.param_groups):
        # if pg["lr"] < 2e-5:
        #    pg["lr"] = 2e-5
        self.log("lr_" + str(i), pg["lr"])

predict_step

embed given gene expression, encode the gene embedding and cell embedding.

Returns:
  • Tensor

    The model predictions for the batch (the output of self._predict).

Source code in scprint/model/model.py
def predict_step(self, batch, batch_idx):
    """
    embed given gene expression, encode the gene embedding and cell embedding.

    Args:
        batch @see training_step

    Returns:
        Tensor: the model predictions for the batch (the output of self._predict).
    """
    return self._predict(
        batch["genes"],
        batch["x"],
        batch["depth"],
        self.predict_mode,
        self.pred_embedding,
        self.get_attention_layer,
        self.predict_depth_mult,
    )

training_step

training_step defines the train loop. It is independent of forward

@see pl.LightningModule

Returns:
  • Tensor

    The total training loss.

Source code in scprint/model/model.py
def training_step(
    self,
    batch: Dict[str, Tensor],
    batch_idx,
):
    """
    training_step defines the train loop. It is independent of forward

    @see pl.LightningModule

    Returns:
        Tensor: the total training loss.
    """
    # TASK 1 & 2 & 3 (first pass, expression reconstruction, label prediction)
    total_loss, losses = self._full_training(
        batch=batch,
        do_denoise=self.do_denoise,
        noise=self.noise,
        do_next_tp=self.do_next_tp,
        do_cce=self.do_cce,
        cce_sim=self.cce_sim,
        do_ecs=self.do_ecs,
        do_mvc=self.do_mvc,
        do_adv_cls=self.do_adv_cls,
        do_adv_batch=self.do_adv_batch,
        do_cls=self.do_cls,
        do_generate=self.do_generate,
        run_full_forward=self.run_full_forward,
        mask_ratio=self.mask_ratio,
    )
    self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
    self.log_dict(losses, prog_bar=True, sync_dist=True)
    return total_loss
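
training_step (and validation_step below) are driven by a Lightning Trainer and a datamodule that yields the expected batch dictionary (keys such as "x", "genes", "depth", "class"). A hedged sketch, with the datamodule left as an assumption; the import path may be pytorch_lightning on older setups:

import lightning.pytorch as pl

trainer = pl.Trainer(max_epochs=1, precision="16-mixed")
# `datamodule` is assumed to yield dicts with "x", "genes", "depth", "class", ...
trainer.fit(model, datamodule=datamodule)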

validation_step

validation_step defines the validation loop. It is independent of forward @see pl.LightningModule

Parameters:
  • batch (list[Tensor]) –

    @see training_step

Source code in scprint/model/model.py
def validation_step(
    self,
    batch,
    batch_idx,
):
    """
    validation_step defines the validation loop. It is independent of forward
    @see pl.LightningModule

    Args:
        batch (list[Tensor]): @see training_step
    """
    val_loss, losses = self._full_training(
        batch=batch,
        do_denoise=self.do_denoise,
        noise=self.noise,
        do_next_tp=self.do_next_tp,
        do_cce=self.do_cce,
        cce_sim=self.cce_sim,
        do_ecs=self.do_ecs,
        do_mvc=self.do_mvc,
        do_adv_cls=self.do_adv_cls,
        do_adv_batch=self.do_adv_batch,
        do_cls=self.do_cls,
        do_generate=self.do_generate,
        run_full_forward=self.run_full_forward,
        mask_ratio=self.mask_ratio,
    )
    expression = batch["x"]
    gene_pos = batch["genes"]
    depth = batch["depth"]
    # TODO: make this faster by only calling val loss
    if self.embs is not None:
        if self.embs.shape[0] < 100_000:
            self.info = torch.cat([self.info, batch["class"]])
            self._predict(
                gene_pos,
                expression,
                depth,
                pred_embedding=self.pred_embedding,
                max_size_in_mem=100_000,
            )
    else:
        self.info = batch["class"]
        self._predict(
            gene_pos,
            expression,
            depth,
            pred_embedding=self.pred_embedding,
            max_size_in_mem=100_000,
        )
    self.log("val_loss", val_loss, sync_dist=True)
    self.log_dict(losses, sync_dist=True)
    return val_loss

losses

scprint.model.loss

AdversarialDiscriminatorLoss

Bases: Module

Discriminator for the adversarial training for batch correction.

Parameters:
  • d_model (int) –

    The size of the input tensor.

  • n_cls (int) –

    The number of classes.

  • nlayers (int, default: 3 ) –

    The number of layers in the discriminator. Defaults to 3.

  • activation (callable, default: LeakyReLU ) –

    The activation function. Defaults to nn.LeakyReLU.

  • reverse_grad (bool, default: True ) –

    Whether to reverse the gradient. Defaults to True.

Source code in scprint/model/loss.py
def __init__(
    self,
    d_model: int,
    n_cls: int,
    nlayers: int = 3,
    activation: callable = nn.LeakyReLU,
    reverse_grad: bool = True,
):
    """
    Discriminator for the adversarial training for batch correction.

    Args:
        d_model (int): The size of the input tensor.
        n_cls (int): The number of classes.
        nlayers (int, optional): The number of layers in the discriminator. Defaults to 3.
        activation (callable, optional): The activation function. Defaults to nn.LeakyReLU.
        reverse_grad (bool, optional): Whether to reverse the gradient. Defaults to True.
    """
    super().__init__()
    # module list
    self.decoder = nn.ModuleList()
    for _ in range(nlayers - 1):
        self.decoder.append(nn.Linear(d_model, d_model))
        self.decoder.append(nn.LayerNorm(d_model))
        self.decoder.append(activation())
    self.out_layer = nn.Linear(d_model, n_cls)
    self.reverse_grad = reverse_grad

forward

Parameters:
  • x (Tensor) –

    Tensor, shape [batch_size, embsize]

Source code in scprint/model/loss.py
def forward(self, x: Tensor, batch_labels: Tensor) -> Tensor:
    """
    Args:
        x: Tensor, shape [batch_size, embsize]
    """
    if self.reverse_grad:
        x = grad_reverse(x, lambd=1.0)
    for layer in self.decoder:
        x = layer(x)
    x = self.out_layer(x)
    return F.cross_entropy(x, batch_labels)
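
A hedged usage sketch; the embedding size, batch size, and number of batch labels are arbitrary:

import torch
from scprint.model.loss import AdversarialDiscriminatorLoss

disc = AdversarialDiscriminatorLoss(d_model=512, n_cls=10)
cell_emb = torch.randn(8, 512)            # (batch_size, embsize)
batch_labels = torch.randint(0, 10, (8,))
adv_loss = disc(cell_emb, batch_labels)   # scalar cross-entropy, with gradient reversal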

classification

Computes the classification loss for a given batch of predictions and ground truth labels.

Parameters:
  • clsname (str) –

    The name of the label.

  • pred (Tensor) –

    The predicted logits for the batch.

  • cl (Tensor) –

    The ground truth labels for the batch.

  • maxsize (int) –

    The number of possible labels.

  • labels_hierarchy (dict, default: {} ) –

    The hierarchical structure of the labels. Defaults to {}.

Raises:
  • ValueError

    If the clsname is not found in the labels_hierarchy dictionary.

Returns:
  • Tensor( Tensor ) –

    The computed binary cross entropy loss for the given batch.

Source code in scprint/model/loss.py
def classification(
    clsname: str,
    pred: torch.Tensor,
    cl: torch.Tensor,
    maxsize: int,
    labels_hierarchy: Optional[Dict[str, Dict[int, list[int]]]] = {},
) -> torch.Tensor:
    """
    Computes the classification loss for a given batch of predictions and ground truth labels.

    Args:
        clsname (str): The name of the label.
        pred (Tensor): The predicted logits for the batch.
        cl (Tensor): The ground truth labels for the batch.
        maxsize (int): The number of possible labels.
        labels_hierarchy (dict, optional): The hierarchical structure of the labels. Defaults to {}.

    Raises:
        ValueError: If the clsname is not found in the labels_hierarchy dictionary.

    Returns:
        Tensor: The computed binary cross entropy loss for the given batch.
    """
    newcl = torch.zeros(
        (cl.shape[0], maxsize), device=cl.device
    )  # batchsize * n_labels
    # if we don't know the label we set the weight to 0 else to 1
    valid_indices = (cl != -1) & (cl < maxsize)
    valid_cl = cl[valid_indices]
    newcl[valid_indices, valid_cl] = 1

    weight = torch.ones_like(newcl, device=cl.device)
    weight[cl == -1, :] = 0
    inv = cl >= maxsize
    # if we have non leaf values, we don't know so we don't compute grad and set weight to 0
    # and add labels that won't be counted but so that we can still use them
    if inv.any():
        if clsname in labels_hierarchy.keys():
            clhier = labels_hierarchy[clsname]

            inv_weight = weight[inv]
            # we set the weight of the elements that are not leaf to 0
            # i.e. the elements where we will compute the max
            inv_weight[clhier[cl[inv] - maxsize]] = 0
            weight[inv] = inv_weight

            addnewcl = torch.ones(
                weight.shape[0], device=pred.device
            )  # no need to set the other to 0 as the weight of the loss is set to 0
            addweight = torch.zeros(weight.shape[0], device=pred.device)
            addweight[inv] = 1
            # computing hierarchical labels and adding them to cl
            addpred = pred.clone()
            # we only keep the elements where we need to compute the max,
            # for the rest we set them to -inf, so that they won't have any impact on the max()
            inv_addpred = addpred[inv]
            inv_addpred[inv_weight.to(bool)] = torch.finfo(pred.dtype).min
            addpred[inv] = inv_addpred

            # differentiable max
            addpred = torch.logsumexp(addpred, dim=-1)

            # we add the new labels to the cl
            newcl = torch.cat([newcl, addnewcl.unsqueeze(1)], dim=1)
            pred = torch.cat([pred, addpred.unsqueeze(1)], dim=1)
            weight = torch.cat([weight, addweight.unsqueeze(1)], dim=1)
        else:
            raise ValueError("need to use labels_hierarchy for this usecase")

    myloss = torch.nn.functional.binary_cross_entropy_with_logits(
        pred, target=newcl, weight=weight
    )
    return myloss
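
A short example of the flat (non-hierarchical) case; -1 marks an unknown label whose loss weight is zeroed out. The label name is a hypothetical placeholder:

import torch
from scprint.model.loss import classification

pred = torch.randn(3, 5)       # logits for 3 cells over 5 possible labels
cl = torch.tensor([2, -1, 4])  # ground-truth labels; -1 = unknown
loss_val = classification("cell_type_ontology_term_id", pred, cl, maxsize=5)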

criterion_neg_log_bernoulli

Compute the negative log-likelihood of Bernoulli distribution

Source code in scprint/model/loss.py
def criterion_neg_log_bernoulli(input: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """
    Compute the negative log-likelihood of Bernoulli distribution
    """
    mask = mask.float()
    bernoulli = torch.distributions.Bernoulli(probs=input)
    masked_log_probs = bernoulli.log_prob((target > 0).float()) * mask
    return -masked_log_probs.sum() / mask.sum()

ecs

ecs Computes the similarity of cell embeddings based on a threshold.

Parameters:
  • cell_emb (Tensor) –

    A tensor representing cell embeddings.

  • ecs_threshold (float, default: 0.5 ) –

    A threshold for determining similarity. Defaults to 0.5.

Returns:
  • Tensor( Tensor ) –

    A tensor representing the mean of 1 minus the square of the difference between the cosine similarity and the threshold.

Source code in scprint/model/loss.py
def ecs(cell_emb: Tensor, ecs_threshold: float = 0.5) -> Tensor:
    """
    ecs Computes the similarity of cell embeddings based on a threshold.

    Args:
        cell_emb (Tensor): A tensor representing cell embeddings.
        ecs_threshold (float, optional): A threshold for determining similarity. Defaults to 0.5.

    Returns:
        Tensor: A tensor representing the mean of 1 minus the square of the difference between the cosine similarity and the threshold.
    """
    # Here using customized cosine similarity instead of F.cosine_similarity
    # to avoid the pytorch issue of similarity larger than 1.0, pytorch # 78064
    # normalize the embedding
    cell_emb_normed = F.normalize(cell_emb, p=2, dim=1)
    cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t())

    # mask out diagonal elements
    mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device)
    cos_sim = cos_sim.masked_fill(mask, 0.0)
    # only optimize positive similarities
    cos_sim = F.relu(cos_sim)
    return torch.mean(1 - (cos_sim - ecs_threshold) ** 2)
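
A hedged sketch with random embeddings, just to show the expected input shape (cells x embedding dimension):

import torch
from scprint.model.loss import ecs

cell_emb = torch.randn(16, 512)  # (n_cells, embsize)
ecs_loss = ecs(cell_emb, ecs_threshold=0.3)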

grad_reverse

grad_reverse Reverses the gradient of the input tensor.

Parameters:
  • x (Tensor) –

    The input tensor whose gradient is to be reversed.

  • lambd (float, default: 1.0 ) –

    The scaling factor for the reversed gradient. Defaults to 1.0.

Returns:
  • Tensor( Tensor ) –

    The input tensor with its gradient reversed during the backward pass.

Source code in scprint/model/loss.py
def grad_reverse(x: Tensor, lambd: float = 1.0) -> Tensor:
    """
    grad_reverse Reverses the gradient of the input tensor.

    Args:
        x (Tensor): The input tensor whose gradient is to be reversed.
        lambd (float, optional): The scaling factor for the reversed gradient. Defaults to 1.0.

    Returns:
        Tensor: The input tensor with its gradient reversed during the backward pass.
    """
    return GradReverse.apply(x, lambd)
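
A hedged illustration: the forward pass leaves the tensor unchanged while the backward pass flips (and scales by lambd) the gradient sign:

import torch
from scprint.model.loss import grad_reverse

x = torch.randn(4, 8, requires_grad=True)
grad_reverse(x, lambd=1.0).sum().backward()
print(x.grad)  # roughly -1 everywhere, instead of the +1 a plain sum() would give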

masked_mae

Compute the masked MAE loss between input and target. MAE = mean absolute error

Source code in scprint/model/loss.py
def masked_mae(input: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """
    Compute the masked MAE loss between input and target.
    MAE = mean absolute error
    """
    mask = mask.float()
    loss = F.l1_loss(input * mask, target * mask, reduction="sum")
    return loss / mask.sum()

masked_mse

Compute the masked MSE loss between input and target.

Source code in scprint/model/loss.py
def masked_mse(input: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """
    Compute the masked MSE loss between input and target.
    """
    mask = mask.float()
    loss = F.mse_loss(input * mask, target * mask, reduction="sum")
    return loss / mask.sum()
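
A small sketch; only positions where mask is nonzero contribute to the loss:

import torch
from scprint.model.loss import masked_mse

input = torch.rand(2, 5)
target = torch.rand(2, 5)
mask = torch.tensor([[1, 1, 0, 0, 1],
                     [0, 1, 1, 0, 0]])
loss_val = masked_mse(input, target, mask)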

masked_nb

Compute the masked negative binomial loss between input and target.

Source code in scprint/model/loss.py
def masked_nb(input: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """
    Compute the masked negative binomial loss between input and target.
    """
    mask = mask.float()
    nb = torch.distributions.NegativeBinomial(total_count=target, probs=input)
    masked_log_probs = nb.log_prob(target) * mask
    return -masked_log_probs.sum() / mask.sum()

masked_relative_error

Compute the masked relative error between input and target.

Source code in scprint/model/loss.py
def masked_relative_error(
    input: Tensor, target: Tensor, mask: torch.LongTensor
) -> Tensor:
    """
    Compute the masked relative error between input and target.
    """
    assert mask.any()
    loss = torch.abs(input[mask] - target[mask]) / (target[mask] + 1e-6)
    return loss.mean()

mse

Compute the MSE loss between input and target.

Source code in scprint/model/loss.py
def mse(input: Tensor, target: Tensor) -> Tensor:
    """
    Compute the MSE loss between input and target.
    """
    input = torch.log2(input + 1)
    input = (input / torch.sum(input, dim=1, keepdim=True)) * 10000
    target = torch.log2(target + 1)
    target = target / torch.sum(target, dim=1, keepdim=True) * 10000
    return F.mse_loss(input, target, reduction="mean")

nb

Computes the negative binomial (NB) loss.

This function was adapted from scvi-tools.

Parameters:
  • target (Tensor) –

    Ground truth data.

  • mu (Tensor) –

    Means of the negative binomial distribution (must have positive support).

  • theta (Tensor) –

    Inverse dispersion parameter (must have positive support).

  • eps (float, default: 1e-08 ) –

    Numerical stability constant. Defaults to 1e-8.

Returns:
  • Tensor

    NB loss value.

Source code in scprint/model/loss.py
def nb(target: Tensor, mu: Tensor, theta: Tensor, eps=1e-8):
    """
    Computes the negative binomial (NB) loss.

    This function was adapted from scvi-tools.

    Args:
        target (Tensor): Ground truth data.
        mu (Tensor): Means of the negative binomial distribution (must have positive support).
        theta (Tensor): Inverse dispersion parameter (must have positive support).
        eps (float, optional): Numerical stability constant. Defaults to 1e-8.

    Returns:
        Tensor: NB loss value.
    """
    if theta.ndimension() == 1:
        theta = theta.view(1, theta.size(0))

    log_theta_mu_eps = torch.log(theta + mu + eps)
    res = (
        theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + target * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(target + theta)
        - torch.lgamma(theta)
        - torch.lgamma(target + 1)
    )

    return -res.mean()

nb_dist

nb_dist Computes the negative binomial distribution.

Parameters:
  • x (Tensor) –

    Torch Tensor of observed data.

  • mu (Tensor) –

    Torch Tensor of means of the negative binomial distribution (must have positive support).

  • theta (Tensor) –

    Torch Tensor of inverse dispersion parameter (must have positive support).

  • eps (float, default: 1e-08 ) –

    Numerical stability constant. Defaults to 1e-8.

Returns:
  • Tensor

    Negative binomial loss value.

Source code in scprint/model/loss.py
def nb_dist(x: Tensor, mu: Tensor, theta: Tensor, eps=1e-8):
    """
    nb_dist Computes the negative binomial distribution.

    Args:
        x (Tensor): Torch Tensor of observed data.
        mu (Tensor): Torch Tensor of means of the negative binomial distribution (must have positive support).
        theta (Tensor): Torch Tensor of inverse dispersion parameter (must have positive support).
        eps (float, optional): Numerical stability constant. Defaults to 1e-8.

    Returns:
        Tensor: Negative binomial loss value.
    """
    loss = -NegativeBinomial(mu=mu, theta=theta).log_prob(x)
    return loss

similarity

Dot product or cosine similarity

Source code in scprint/model/loss.py
def similarity(x: Tensor, y: Tensor, temp: float) -> Tensor:
    """
    Dot product or cosine similarity
    """
    res = F.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) / temp
    labels = torch.arange(res.size(0)).long().to(device=res.device)
    return F.cross_entropy(res, labels)

zinb

Computes zero-inflated negative binomial (ZINB) loss.

This function was modified from scvi-tools.

Parameters:
  • target (Tensor) –

    Torch Tensor of ground truth data.

  • mu (Tensor) –

    Torch Tensor of means of the negative binomial (must have positive support).

  • theta (Tensor) –

    Torch Tensor of inverse dispersion parameter (must have positive support).

  • pi (Tensor) –

    Torch Tensor of logits of the dropout parameter (real support).

  • eps (float, default: 1e-08 ) –

    Numerical stability constant. Defaults to 1e-8.

Returns:
  • Tensor

    ZINB loss value.

Source code in scprint/model/loss.py
def zinb(
    target: Tensor,
    mu: Tensor,
    theta: Tensor,
    pi: Tensor,
    eps=1e-8,
):
    """
    Computes zero-inflated negative binomial (ZINB) loss.

    This function was modified from scvi-tools.

    Args:
        target (Tensor): Torch Tensor of ground truth data.
        mu (Tensor): Torch Tensor of means of the negative binomial (must have positive support).
        theta (Tensor): Torch Tensor of inverse dispersion parameter (must have positive support).
        pi (Tensor): Torch Tensor of logits of the dropout parameter (real support).
        eps (float, optional): Numerical stability constant. Defaults to 1e-8.

    Returns:
        Tensor: ZINB loss value.
    """
    #  uses log(sigmoid(x)) = -softplus(-x)
    softplus_pi = F.softplus(-pi)
    # eps to make it positive support and taking the log
    log_theta_mu_eps = torch.log(theta + mu + eps)
    pi_theta_log = -pi + theta * (torch.log(theta + eps) - log_theta_mu_eps)

    case_zero = F.softplus(pi_theta_log) - softplus_pi
    mul_case_zero = torch.mul((target < eps).type(torch.float32), case_zero)

    case_non_zero = (
        -softplus_pi
        + pi_theta_log
        + target * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(target + theta)
        - torch.lgamma(theta)
        - torch.lgamma(target + 1)
    )
    mul_case_non_zero = torch.mul((target > eps).type(torch.float32), case_non_zero)

    res = mul_case_zero + mul_case_non_zero
    # we want to minimize the loss but maximize the log likelihood
    return -res.mean()
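
A hedged sketch with synthetic tensors; mu and theta need positive support while pi is a logit:

import torch
import torch.nn.functional as F
from scprint.model.loss import zinb

target = torch.poisson(torch.full((2, 4), 3.0))  # synthetic counts
mu = F.softplus(torch.randn(2, 4))               # positive support
theta = F.softplus(torch.randn(2, 4)) + 1e-4     # positive support
pi = torch.randn(2, 4)                           # dropout logits (real support)
loss_val = zinb(target, mu, theta, pi)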

utils

scprint.model.utils

Attention

Initialize the Attention class.

Parameters:
  • gene_dim (int) –

    The dimension of the gene.

  • comp_attn (bool, default: False ) –

    Whether to compute attention. Defaults to False.

Source code in scprint/model/utils.py
def __init__(
    self,
    gene_dim: int,
    comp_attn: bool = False,
    apply_softmax: bool = False,
    sum_heads: bool = True,
):
    """
    Initialize the Attention class.

    Args:
        gene_dim (int): The dimension of the gene.
        comp_attn (bool, optional): Whether to compute attention. Defaults to False.
    """
    self.data: Optional[Tensor] = None
    self.gene_dim: int = gene_dim
    self.div: Optional[Tensor] = None
    self.comp_attn: bool = comp_attn
    self.apply_softmax: bool = apply_softmax
    self.sum_heads: bool = sum_heads
    self.shared_qk: bool = True
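
For reference, scPrint.__init__ builds this helper as utils.Attention(len(classes) + 2 + len(genes)), i.e. one slot per cell-embedding token plus one per gene. A hedged standalone sketch with made-up sizes:

from scprint.model import utils

# hypothetical sizes: 3 predicted classes and a 2000-gene vocabulary
attn = utils.Attention(gene_dim=3 + 2 + 2000)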

add_attn

Aggregate the attention or data based on the comp_attn flag.

Parameters:
  • x (List[Tensor]) –

    List of tensors to aggregate.

  • pos (Tensor) –

    Position tensor.

Source code in scprint/model/utils.py
def add_attn(
    self, x: List[Tensor], pos: Tensor, expr: Optional[Tensor] = None
) -> None:
    """
    Aggregate the attention or data based on the comp_attn flag.

    Args:
        x (List[Tensor]): List of tensors to aggregate.
        pos (Tensor): Position tensor.
    """
    if self.data is None:
        self.data = torch.zeros(
            [self.gene_dim, self.gene_dim, len(x) * x[0].shape[3]],
            device="cuda",
            dtype=torch.float32,
        )
        self.div = torch.zeros(1, device="cuda", dtype=torch.float32)

    for i, elem in enumerate(x):
        batch, seq_len, _, heads, _ = elem.shape
        if self.apply_softmax:
            attn = torch.nn.functional.softmax(
                elem[:, :, 0, :, :].permute(0, 2, 1, 3)
                @ elem[:, :, 1, :, :].permute(0, 2, 3, 1),
                dim=-1,
            )
            if expr is not None:
                attn = attn * (expr > 0).float()
            self.data[:, :, heads * i : heads * (i + 1)] += (
                attn.sum(0).permute(1, 2, 0) / batch
            )
        else:
            self.data[:, :, heads * i : heads * (i + 1)] += (
                elem[:, :, 0, :, :].permute(0, 2, 1, 3)
                @ elem[:, :, 1, :, :].permute(0, 2, 3, 1)
            ).sum(0).permute(1, 2, 0) / batch
    self.div += 1

add_qk

Add data to the internal storage.

Parameters:
  • x (List[Tensor]) –

    List of tensors to add.

  • pos (Tensor) –

    Position tensor.

Source code in scprint/model/utils.py
def add_qk(
    self, x: List[Tensor], pos: Tensor, expr: Optional[Tensor] = None
) -> None:
    """
    Add data to the internal storage.

    Args:
        x (List[Tensor]): List of tensors to add.
        pos (Tensor): Position tensor.
    """
    if self.data is None:
        self.data = torch.zeros(
            [len(x), self.gene_dim] + list(x[0].shape[2:]), device="cuda"
        )
        self.div = torch.zeros(self.gene_dim, device="cuda")
    for i in range(x[0].shape[0]):
        loc = torch.cat([torch.arange(8, device="cuda"), pos[i] + 8]).int()
        for j in range(len(x)):
            self.data[j, loc, :, :, :] += x[j][i]
        self.div[loc] += 1

get

Get the aggregated attention or data.

Returns:
  • Optional[ndarray]

    Optional[np.ndarray]: The aggregated attention or data.

Source code in scprint/model/utils.py
def get(self) -> Optional[np.ndarray]:
    """
    Get the aggregated attention or data.

    Returns:
        Optional[np.ndarray]: The aggregated attention or data.
    """
    if self.shared_qk:
        if self.data is None:
            return None
        # shape is (layers, genes, qkv, heads, emb)
        return self.data / self.div.view(1, self.div.shape[0], 1, 1, 1)
    else:
        if self.data is None:
            return None
        self.data.div_(self.div)
        return self.data

downsample_profile

This function downsamples the expression profile of a given single cell RNA matrix.

The noise is applied based on the renoise parameter, the total counts of the matrix, and the number of genes:

  • The noise threshold (scaler) is calculated from the renoise parameter.
  • An initial count matrix is generated by applying a Poisson distribution to a random tensor scaled by the total counts and the number of genes.
  • Sampling zeros are modeled by applying a Poisson distribution to a random tensor scaled by the noise threshold, the total counts, and the number of genes.
  • Technical zeros are modeled by generating a random tensor and comparing it to the noise threshold.
  • The final count matrix is computed by subtracting the sampling zeros from the initial counts and multiplying by the technical zeros, then clamping at zero.

The function returns the final count matrix.

Parameters:
  • mat (Tensor) –

    The input matrix.

  • dropout (float) –

    The renoise parameter.

Returns:
  • torch.Tensor: The matrix count after applying noise.

Source code in scprint/model/utils.py
def downsample_profile(mat: Tensor, dropout: float, method="new"):
    """
    This function downsamples the expression profile of a given single cell RNA matrix.

    The noise is applied based on the renoise parameter,
    the total counts of the matrix, and the number of genes. The function first calculates the noise
    threshold (scaler) based on the renoise parameter. It then generates an initial matrix count by
    applying a Poisson distribution to a random tensor scaled by the total counts and the number of genes.
    The function then models the sampling zeros by applying a Poisson distribution to a random tensor
    scaled by the noise threshold, the total counts, and the number of genes. The function also models
    the technical zeros by generating a random tensor and comparing it to the noise threshold. The final
    matrix count is calculated by subtracting the sampling zeros from the initial matrix count and
    multiplying by the technical zeros. The function ensures that the final matrix count is not less
    than zero by taking the maximum of the final matrix count and a tensor of zeros. The function
    returns the final matrix count.

    Args:
        mat (torch.Tensor): The input matrix.
        dropout (float): The renoise parameter.

    Returns:
        torch.Tensor: The matrix count after applying noise.
    """
    # Randomly drop on average N counts to each element of expression using a heavy tail Gaussian distribution
    # here we try to get the scale of the distribution so as to remove the right number of counts from each gene
    # https://genomebiology.biomedcentral.com/articles/10.1186/s13059-022-02601-5#:~:text=Zero%20measurements%20in%20scRNA%2Dseq,generation%20of%20scRNA%2Dseq%20data.
    if method == "old":
        totcounts = mat.sum(1)
        batch = mat.shape[0]
        ngenes = mat.shape[1]
        tnoise = 1 - (1 - dropout) ** (1 / 2)
        # we model the sampling zeros (dropping 30% of the reads)
        res = torch.poisson(
            torch.rand((batch, ngenes)).to(device=mat.device)
            * ((tnoise * totcounts.unsqueeze(1)) / (0.5 * ngenes))
        ).int()
        # we model the technical zeros (dropping 50% of the genes)
        drop = (torch.rand((batch, ngenes)) > tnoise).int().to(device=mat.device)

        mat = (mat - res) * drop
        return torch.maximum(mat, torch.Tensor([[0]]).to(device=mat.device)).int()
    elif method == "jules":
        scaler = (1 - dropout) ** (1 / 2)
        notdrop = (
            torch.rand(
                mat.shape,
                device=mat.device,
            )
            < scaler
        ).int()
        notdrop[mat == 0] = 0
        # apply the dropout after the poisson, right?
        return notdrop * torch.poisson(mat * scaler)
    elif method == "new":
        batch = mat.shape[0]
        ngenes = mat.shape[1]
        dropout = dropout * 1.1
        # we model the sampling zeros (dropping 30% of the reads)
        res = torch.poisson((mat * (dropout / 2))).int()
        # we model the technical zeros (dropping 50% of the genes)
        notdrop = (
            torch.rand((batch, ngenes), device=mat.device) >= (dropout / 2)
        ).int()
        mat = (mat - res) * notdrop
        return torch.maximum(
            mat, torch.zeros((1, 1), device=mat.device, dtype=torch.int)
        )
    else:
        raise ValueError(f"method {method} not recognized")
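
Example (a sketch; assumes downsample_profile is importable from scprint.model.utils, per the source path above, and uses a synthetic counts matrix):

import torch
from scprint.model.utils import downsample_profile

# hypothetical toy batch: 4 cells x 100 genes of Poisson-like counts
counts = torch.poisson(torch.rand(4, 100) * 5)
# remove roughly half of the signal using the default "new" scheme
noisy = downsample_profile(counts, dropout=0.5)
assert noisy.shape == counts.shape
assert bool((noisy <= counts).all())  # counts can only be reduced, never increased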

make_adata

This function creates an AnnData object from the given input parameters.

Parameters:
  • embs (Tensor) –

    Embeddings of the cells. The shape of the tensor is (n_cells, n_features).

  • labels (list) –

    List of labels for the predicted classes.

  • pred (Tensor, default: None ) –

    Predicted labels. The shape of the tensor is (n_cells, n_classes). Default is None.

  • attention (Tensor, default: None ) –

    Attention weights. Default is None.

  • step (int, default: 0 ) –

    Step number for storing the AnnData without overwriting others. Default is 0.

  • label_decoders (dict, default: None ) –

    Dictionary to map class codes to class names. Default is None.

  • labels_hierarchy (dict, default: {} ) –

    Dictionary representing the hierarchy of labels. Default is {}.

  • gtclass (Tensor, default: None ) –

    Ground truth class. Default is None.

  • name (str, default: '' ) –

    Name of the AnnData object. Default is an empty string.

  • mdir (str, default: '/tmp' ) –

    Directory to save the AnnData object. Default is "/tmp".

  • doplot (bool, default: True ) –

    Whether to generate plots. Default is True.

Returns:
  • anndata.AnnData: The created AnnData object.

Source code in scprint/model/utils.py
def make_adata(
    embs: Tensor,
    labels: List[str],
    pred: Tensor = None,
    attention: Optional[Tensor] = None,
    step: int = 0,
    label_decoders: Optional[Dict] = None,
    labels_hierarchy: Dict = {},
    gtclass: Optional[Tensor] = None,
    name: str = "",
    mdir: str = "/tmp",
    doplot: bool = True,
):
    """
    This function creates an AnnData object from the given input parameters.

    Args:
        embs (torch.Tensor): Embeddings of the cells. The shape of the tensor is (n_cells, n_features).
        labels (list): List of labels for the predicted classes.
        pred (torch.Tensor, optional): Predicted labels. The shape of the tensor is (n_cells, n_classes). Default is None.
        attention (torch.Tensor, optional): Attention weights. Default is None.
        step (int, optional): Step number for storing the AnnData without overwriting others. Default is 0.
        label_decoders (dict, optional): Dictionary to map class codes to class names. Default is None.
        labels_hierarchy (dict, optional): Dictionary representing the hierarchy of labels. Default is {}.
        gtclass (torch.Tensor, optional): Ground truth class. Default is None.
        name (str, optional): Name of the AnnData object. Default is an empty string.
        mdir (str, optional): Directory to save the AnnData object. Default is "/tmp".
        doplot (bool, optional): Whether to generate plots. Default is True.

    Returns:
        anndata.AnnData: The created AnnData object.
    """
    colname = ["pred_" + i for i in labels]
    if pred is not None:
        obs = np.array(pred.to(device="cpu", dtype=torch.int32))
        # label_decoders is not cls_decoders: one is a dict mapping class codes (ints)
        # to class names, the other is the module that predicts the class
        if label_decoders is not None:
            obs = np.array(
                [
                    [label_decoders[labels[i]][n] for n in name]
                    for i, name in enumerate(obs.T)
                ]
            ).T

        if gtclass is not None:
            colname += labels
            nobs = np.array(gtclass.to(device="cpu", dtype=torch.int32))
            if label_decoders is not None:
                nobs = np.array(
                    [
                        [label_decoders[labels[i]][n] for n in name]
                        for i, name in enumerate(nobs.T)
                    ]
                ).T
            obs = np.hstack([obs, nobs])

        adata = AnnData(
            np.array(embs.to(device="cpu", dtype=torch.float32)),
            obs=pd.DataFrame(
                obs,
                columns=colname,
            ),
        )
        accuracy = {}
        for label in labels:
            if gtclass is not None:
                tr = translate(adata.obs[label].tolist(), label)
                if tr is not None:
                    adata.obs["conv_" + label] = adata.obs[label].replace(tr)
            tr = translate(adata.obs["pred_" + label].tolist(), label)
            if tr is not None:
                adata.obs["conv_pred_" + label] = adata.obs["pred_" + label].replace(tr)
            res = []
            if label_decoders is not None and gtclass is not None:
                class_topred = label_decoders[label].values()
                if label in labels_hierarchy:
                    cur_labels_hierarchy = {
                        label_decoders[label][k]: [label_decoders[label][i] for i in v]
                        for k, v in labels_hierarchy[label].items()
                    }
                else:
                    cur_labels_hierarchy = {}
                for pred, true in adata.obs[["pred_" + label, label]].values:
                    if pred == true:
                        res.append(True)
                        continue
                    if len(labels_hierarchy) > 0:
                        if true in cur_labels_hierarchy:
                            res.append(pred in cur_labels_hierarchy[true])
                        elif true not in class_topred:
                            raise ValueError(
                                f"true label {true} not in available classes"
                            )
                        elif true != "unknown":
                            res.append(False)
                    elif true not in class_topred:
                        raise ValueError(f"true label {true} not in available classes")
                    elif true != "unknown":
                        res.append(False)
                    else:
                        pass
                accuracy["pred_" + label] = sum(res) / len(res) if len(res) > 0 else 0
        adata.obs = adata.obs.astype("category")
    else:
        adata = AnnData(
            np.array(embs.to(device="cpu", dtype=torch.float32)),
        )
    if False:  # disabled code path: would store the per-gene Q and K projections from `attention` in adata.varm
        adata.varm["Qs"] = (
            attention[:, :, 0, :, :]
            .permute(1, 3, 0, 2)
            .view(
                attention.shape[0],
                attention.shape[1],
                attention.shape[3] * attention.shape[4],
            )
            .detach()
            .cpu()
            .numpy()
        )
        adata.varm["Ks"] = (
            attention[:, :, 1, :, :]
            .permute(1, 3, 0, 2)
            .view(
                attention.shape[0],
                attention.shape[1],
                attention.shape[3] * attention.shape[4],
            )
            .detach()
            .cpu()
            .numpy()
        )
    print(adata)
    if doplot and adata.shape[0] > 100 and pred is not None:
        sc.pp.neighbors(adata, use_rep="X")
        sc.tl.umap(adata)
        sc.tl.leiden(adata, key_added="sprint_leiden")
        if gtclass is not None:
            color = [
                i
                for pair in zip(
                    [
                        "conv_" + i if "conv_" + i in adata.obs.columns else i
                        for i in labels
                    ],
                    [
                        (
                            "conv_pred_" + i
                            if "conv_pred_" + i in adata.obs.columns
                            else "pred_" + i
                        )
                        for i in labels
                    ],
                )
                for i in pair
            ]
            fig, axs = plt.subplots(
                int(len(color) / 2), 2, figsize=(24, len(color) * 4)
            )
            plt.subplots_adjust(wspace=1)
            for i, col in enumerate(color):
                sc.pl.umap(
                    adata,
                    color=col,
                    ax=axs[i // 2, i % 2],
                    show=False,
                )
                acc = ""
                if "_pred_" in col and col.split("conv_")[-1] in accuracy:
                    acc = " (accuracy: {:.2f})".format(accuracy[col.split("conv_")[-1]])
                axs[i // 2, i % 2].set_title(col + " UMAP" + acc)
                if "cell_type" in col:
                    axs[i // 2, i % 2].legend(fontsize="x-small")
                axs[i // 2, i % 2].set_xlabel("UMAP1")
                axs[i // 2, i % 2].set_ylabel("UMAP2")
        else:
            color = [
                (
                    "conv_pred_" + i
                    if "conv_pred_" + i in adata.obs.columns
                    else "pred_" + i
                )
                for i in labels
            ]
            fig, axs = plt.subplots(len(color), 1, figsize=(16, len(color) * 8))
            for i, col in enumerate(color):
                sc.pl.umap(
                    adata,
                    color=col,
                    ax=axs[i],
                    show=False,
                )
                acc = ""
                if "_pred_" in col and col.split("conv_")[-1] in accuracy:
                    acc = " (accuracy: {:.2f})".format(accuracy[col.split("conv_")[-1]])
                axs[i].set_title(col + " UMAP" + acc)
                axs[i].set_xlabel("UMAP1")
                axs[i].set_ylabel("UMAP2")
        plt.show()
    else:
        fig = None
    adata.write(mdir + "/step_" + str(step) + "_" + name + ".h5ad")
    return adata, fig
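
A minimal sketch of calling make_adata with embeddings only; in practice pred, gtclass and label_decoders come from a trained scPrint model's predictions, and the shapes and names below are purely illustrative:

import torch
from scprint.model.utils import make_adata

embs = torch.randn(200, 64)  # 200 cells, 64-dimensional embeddings
labels = ["cell_type_ontology_term_id", "sex_ontology_term_id"]
# without `pred`, the function only wraps the embeddings into an AnnData
# object and writes it to `mdir`
adata, fig = make_adata(embs, labels, doplot=False, name="example", mdir="/tmp")
print(adata)  # AnnData object with 200 obs x 64 vars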

simple_masker

Randomly mask a batch of data.

Parameters:
  • shape (list[int]) –

    The shape of the data.

  • mask_ratio (float, default: 0.15 ) –

    The ratio of genes to mask, default to 0.15.

Returns:
  • Tensor

    torch.Tensor: A tensor of masked data.

Source code in scprint/model/utils.py
def simple_masker(
    shape: list[int],
    mask_ratio: float = 0.15,
) -> torch.Tensor:
    """
    Randomly mask a batch of data.

    Args:
        shape (list[int]): The shape of the data.
        mask_ratio (float): The ratio of genes to mask, default to 0.15.

    Returns:
        torch.Tensor: A tensor of masked data.
    """
    return torch.rand(shape) < mask_ratio
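
For example (a sketch, assuming the function is importable from scprint.model.utils):

import torch
from scprint.model.utils import simple_masker

# boolean mask over a batch of 8 cells x 2000 genes, ~15% True
mask = simple_masker([8, 2000], mask_ratio=0.15)
print(mask.shape, mask.dtype)   # torch.Size([8, 2000]) torch.bool
print(mask.float().mean())      # close to 0.15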

test

Test the given model on the full set of benchmarks and save the results to JSON files.

Parameters:
  • model (Module) –

    The model to be tested.

  • name (str) –

    The name to be used for the output JSON files.

  • filedir (str) –

    The directory where the data files are located.

Returns:
  • dict

    dict: A dictionary of the computed benchmark metrics.

Source code in scprint/model/utils.py
def test(model: torch.nn.Module, name: str, filedir: str) -> dict:
    """
    Test the given model on the full set of benchmarks and save the results to JSON files.

    Args:
        model (torch.nn.Module): The model to be tested.
        name (str): The name to be used for the output JSON files.
        filedir (str): The directory where the data files are located.

    Returns:
        dict: A dictionary of the computed benchmark metrics.
    """
    metrics = {}
    res = embbed_task.default_benchmark(
        model, default_dataset="lung", do_class=True, coarse=False
    )
    f = open("metrics_" + name + ".json", "a")
    f.write(json.dumps({"embed_lung": res}, indent=4))
    f.close()
    metrics.update(
        {
            "emb_lung/scib": float(res["scib"]["Total"]),
            "emb_lung/ct_class": float(
                res["classif"]["cell_type_ontology_term_id"]["accuracy"]
            ),
        }
    )
    print(metrics)
    res = embbed_task.default_benchmark(
        model, default_dataset="pancreas", do_class=True, coarse=False
    )
    f = open("metrics_" + name + ".json", "a")
    f.write(json.dumps({"embed_panc": res}, indent=4))
    f.close()
    metrics.update(
        {
            "emb_panc/scib": float(res["scib"]["Total"]),
            "emb_panc/ct_class": float(
                res["classif"]["cell_type_ontology_term_id"]["accuracy"]
            ),
        }
    )
    print(metrics)
    gc.collect()
    res = denoise_task.default_benchmark(
        model, filedir + "/../../data/gNNpgpo6gATjuxTE7CCp.h5ad"
    )
    metrics.update(
        {
            "denoise/reco2full_vs_noisy2full": float(
                res["reco2full"] - res["noisy2full"]
            ),
        }
    )
    gc.collect()
    print(metrics)
    f = open("metrics_" + name + ".json", "a")
    f.write(json.dumps({"denoise": res}, indent=4))
    f.close()
    res = grn_task.default_benchmark(
        model, "gwps", batch_size=32 if model.d_model <= 512 else 8
    )
    f = open("metrics_" + name + ".json", "a")
    f.write(json.dumps({"grn_gwps": res}, default=lambda o: str(o), indent=4))
    f.close()
    metrics.update(
        {
            "grn_gwps/auprc_self": float(res["self"]["auprc"]),
            "grn_gwps/epr_self": float(res["self"]["epr"]),
            "grn_gwps/auprc_omni": float(res["omni"]["auprc"]),
            "grn_gwps/epr_omni": float(res["omni"]["epr"]),
            "grn_gwps/auprc": float(res["mean"]["auprc"]),
            "grn_gwps/epr": float(res["mean"]["epr"]),
        }
    )
    print(metrics)
    gc.collect()
    res = grn_task.default_benchmark(
        model, "sroy", batch_size=32 if model.d_model <= 512 else 8
    )
    f = open("metrics_" + name + ".json", "a")
    f.write(json.dumps({"grn_sroy": res}, default=lambda o: str(o), indent=4))
    f.close()
    metrics.update(
        {
            "grn_sroy/auprc_self": float(
                np.mean(
                    [
                        i["auprc"]
                        for k, i in res.items()
                        if k.startswith("self_")
                        and not any(
                            x in k for x in ["chip_", "ko_", "classifier", "_base"]
                        )
                    ]
                )
            ),
            "grn_sroy/epr_self": float(
                np.mean(
                    [
                        i["epr"]
                        for k, i in res.items()
                        if k.startswith("self_")
                        and not any(
                            x in k for x in ["chip_", "ko_", "classifier", "_base"]
                        )
                    ]
                )
            ),
            "grn_sroy/auprc_omni": float(
                np.mean(
                    [
                        i["auprc"]
                        for k, i in res.items()
                        if k.startswith("omni_")
                        and not any(
                            x in k for x in ["chip_", "ko_", "classifier", "_base"]
                        )
                    ]
                )
            ),
            "grn_sroy/epr_omni": float(
                np.mean(
                    [
                        i["epr"]
                        for k, i in res.items()
                        if k.startswith("omni_")
                        and not any(
                            x in k for x in ["chip_", "ko_", "classifier", "_base"]
                        )
                    ]
                )
            ),
            "grn_sroy/auprc": float(
                np.mean(
                    [
                        i["auprc"]
                        for k, i in res.items()
                        if k.startswith("mean_")
                        and not any(
                            x in k for x in ["chip_", "ko_", "classifier", "_base"]
                        )
                    ]
                )
            ),
            "grn_sroy/epr": float(
                np.mean(
                    [
                        i["epr"]
                        for k, i in res.items()
                        if k.startswith("mean_")
                        and not any(
                            x in k for x in ["chip_", "ko_", "classifier", "_base"]
                        )
                    ]
                )
            ),
        }
    )
    print(metrics)
    gc.collect()
    res = grn_task.default_benchmark(
        model,
        filedir + "/../../data/yBCKp6HmXuHa0cZptMo7.h5ad",
        batch_size=32 if model.d_model <= 512 else 8,
    )
    f = open("metrics_" + name + ".json", "a")
    f.write(json.dumps({"grn_omni": res}, default=lambda o: str(o), indent=4))
    f.close()
    metrics.update(
        {
            "grn_omni/auprc_class": float(
                np.mean([i["auprc"] for k, i in res.items() if "_class" in k])
            ),
            "grn_omni/epr_class": float(
                np.mean([i["epr"] for k, i in res.items() if "_class" in k])
            ),
            "grn_omni/tf_enr_class": float(
                np.sum(
                    [i.get("TF_enr", False) for k, i in res.items() if "_class" in k]
                )
            ),
            "grn_omni/tf_targ_enr_class": float(
                np.mean(
                    [
                        i["significant_enriched_TFtargets"]
                        for k, i in res.items()
                        if "_class" in k
                    ]
                )
            ),
            "grn_omni/auprc": float(
                np.mean([i["auprc"] for k, i in res.items() if "_mean" in k])
            ),
            "grn_omni/epr": float(
                np.mean([i["epr"] for k, i in res.items() if "_mean" in k])
            ),
            "grn_omni/tf_enr": float(
                np.sum([i.get("TF_enr", False) for k, i in res.items() if "_mean" in k])
            ),
            "grn_omni/tf_targ_enr": float(
                np.mean(
                    [
                        i["significant_enriched_TFtargets"]
                        for k, i in res.items()
                        if "_mean" in k
                    ]
                )
            ),
            # 'grn_omni/ct': res['classif']['cell_type_ontology_term_id']['accuracy'],
        }
    )
    return metrics

translate

This function translates the given ontology term id(s) into their names, based on the specified ontology type.

Parameters:
  • val (str / list / set / dict / Counter) –

    The value to be translated.

  • t (str, default: 'cell_type_ontology_term_id' ) –

    The type of translation to be performed. Defaults to "cell_type_ontology_term_id".

Returns:
  • dict

    A dictionary with the translated values.

Source code in scprint/model/utils.py
def translate(
    val: Union[str, list, set, dict, Counter], t: str = "cell_type_ontology_term_id"
):
    """
    translate This function translates the given value based on the specified type.

    Args:
        val (str/list/set/dict/Counter): The value to be translated.
        t (str, optional): The type of translation to be performed. Defaults to "cell_type_ontology_term_id".

    Returns:
        dict: A dictionary with the translated values.
    """
    if t == "cell_type_ontology_term_id":
        obj = bt.CellType.filter().df().set_index("ontology_id")
    elif t == "assay_ontology_term_id":
        obj = bt.ExperimentalFactor.filter().df().set_index("ontology_id")
    elif t == "tissue_ontology_term_id":
        obj = bt.Tissue.filter().df().set_index("ontology_id")
    elif t == "disease_ontology_term_id":
        obj = bt.Disease.filter().df().set_index("ontology_id")
    elif t == "self_reported_ethnicity_ontology_term_id":
        obj = bt.Ethnicity.filter().df().set_index("ontology_id")
    else:
        return None
    if type(val) is str:
        if val == "unknown":
            return {val: val}
        return {val: obj.loc[val]["name"]}
    elif type(val) is list or type(val) is set:
        return {i: obj.loc[i]["name"] if i != "unknown" else i for i in set(val)}
    elif type(val) is dict or type(val) is Counter:
        return {obj.loc[k]["name"] if k != "unknown" else k: v for k, v in val.items()}
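
A usage sketch; this assumes a configured lamindb/bionty instance with the relevant ontology registered, and the returned names depend on the installed ontology version:

from scprint.model.utils import translate

# single ontology id -> {id: name}, e.g. {"CL:0000786": "plasma cell"}
translate("CL:0000786", t="cell_type_ontology_term_id")

# lists and sets are translated element-wise; "unknown" is passed through unchanged
translate(["CL:0000786", "unknown"], t="cell_type_ontology_term_id")

# unsupported ontology types return None
translate("CL:0000786", t="some_other_type") is None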

weighted_masker

Randomly mask a batch of data.

Parameters:
  • shape (list[int]) –

    The shape of the data.

  • mask_ratio (float, default: 0.15 ) –

    The ratio of genes to mask, default to 0.15.

  • mask_prob (Optional[Union[Tensor, ndarray]], default: None ) –

    Optional per-gene masking probabilities (one value per feature), used to bias which genes get masked. Defaults to None (uniform).

  • mask_value (int, default: 1 ) –

    The value used to mark masked positions. Defaults to 1.

Returns:
  • Tensor

    torch.Tensor: A tensor of masked data.

Source code in scprint/model/utils.py
def weighted_masker(
    shape: list[int],
    mask_ratio: float = 0.15,
    mask_prob: Optional[Union[torch.Tensor, np.ndarray]] = None,  # n_features
    mask_value: int = 1,
) -> torch.Tensor:
    """
    Randomly mask a batch of data.

    Args:
        shape (list[int]): The shape of the data.
        mask_ratio (float): The ratio of genes to mask, default to 0.15.
        mask_prob (torch.Tensor | np.ndarray, optional): Per-gene masking probabilities (n_features). Defaults to None (uniform).
        mask_value (int): The value used to mark masked positions, default to 1.

    Returns:
        torch.Tensor: A tensor of masked data.
    """
    mask = []
    for _ in range(shape[0]):
        m = np.zeros(shape[1])
        loc = np.random.choice(
            a=shape[1], size=int(shape[1] * mask_ratio), replace=False, p=mask_prob
        )
        m[loc] = mask_value
        mask.append(m)

    return torch.Tensor(np.array(mask)).to(torch.bool)
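
A sketch showing how a per-gene probability vector can bias the mask (mask_prob must sum to 1, as required by np.random.choice):

import numpy as np
from scprint.model.utils import weighted_masker

probs = np.ones(1000)
probs[-100:] = 5.0            # make the last 100 genes 5x more likely to be masked
probs = probs / probs.sum()   # probabilities must sum to 1
mask = weighted_masker([4, 1000], mask_ratio=0.2, mask_prob=probs)
print(mask.shape, mask.dtype)  # torch.Size([4, 1000]) torch.bool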

zinb_sample

This function generates a sample from a Zero-Inflated Negative Binomial (ZINB) distribution.

Parameters:
  • mu (Tensor) –

    The mean of the Negative Binomial (NB) distribution.

  • theta (Tensor) –

    The dispersion parameter of the NB distribution.

  • zi_probs (Tensor) –

    The zero-inflation probabilities.

  • sample_shape (Size, default: Size([]) ) –

    The output shape. Defaults to torch.Size([]).

Returns:
  • torch.Tensor: A sample from the ZINB distribution.

Source code in scprint/model/utils.py
def zinb_sample(
    mu: torch.Tensor,
    theta: torch.Tensor,
    zi_probs: torch.Tensor,
    sample_shape: torch.Size = torch.Size([]),
):
    """
    zinb_sample This function generates a sample from a Zero-Inflated Negative Binomial (ZINB) distribution.

    Args:
        mu (torch.Tensor): The mean of the Negative Binomial (NB) distribution.
        theta (torch.Tensor): The dispersion parameter of the NB distribution.
        zi_probs (torch.Tensor): The zero-inflation probabilities.
        sample_shape (torch.Size, optional): The output shape. Defaults to torch.Size([]).

    Returns:
        torch.Tensor: A sample from the ZINB distribution.
    """
    concentration = theta
    rate = theta / mu
    # Important remark: Gamma is parametrized by the rate = 1/scale!
    gamma_d = Gamma(concentration=concentration, rate=rate)
    p_means = gamma_d.sample(sample_shape)

    # Clamping as distributions objects can have buggy behaviors when
    # their parameters are too high
    l_train = torch.clamp(p_means, max=1e8)
    samp = Poisson(l_train).sample()  # Shape : (n_samples, n_cells_batch, n_vars)
    is_zero = torch.rand_like(samp) <= zi_probs
    samp_ = torch.where(is_zero, torch.zeros_like(samp), samp)
    return samp_
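
For example (a sketch with arbitrary parameter tensors):

import torch
from scprint.model.utils import zinb_sample

mu = torch.rand(2, 5) * 10 + 0.1         # NB mean, strictly positive
theta = torch.full((2, 5), 2.0)          # NB dispersion
zi_probs = torch.full((2, 5), 0.3)       # probability of a structural zero
sample = zinb_sample(mu, theta, zi_probs)
print(sample.shape)                      # torch.Size([2, 5])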

encoder and decoder modules

scprint.model.encoders

CategoryValueEncoder

Bases: Module

Encodes categorical values into a vector using an embedding layer and layer normalization.

Parameters:
  • num_embeddings (int) –

    The number of possible values.

  • embedding_dim (int) –

    The dimension of the output vectors.

  • padding_idx (int, default: None ) –

    The index of the padding token. Defaults to None.

Returns:
  • torch.Tensor: A tensor representing the encoded categorical values.

Note: not used in the current version of scprint.

Source code in scprint/model/encoders.py
def __init__(
    self,
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: Optional[int] = None,
):
    """
    Encodes categorical values into a vector using an embedding layer and layer normalization.

    Args:
        num_embeddings (int): The number of possible values.
        embedding_dim (int): The dimension of the output vectors.
        padding_idx (int, optional): The index of the padding token. Defaults to None.

    Returns:
        torch.Tensor: A tensor representing the encoded categorical values.

    Note: not used in the current version of scprint.
    """
    super(CategoryValueEncoder, self).__init__()
    self.embedding = nn.Embedding(
        num_embeddings, embedding_dim, padding_idx=padding_idx
    )

ContinuousValueEncoder

Bases: Module

Encodes real number values into a vector using a neural network projection.

Parameters:
  • d_model (int) –

    The dimension of the input vectors.

  • dropout (float, default: 0.1 ) –

    The dropout rate to apply to the output of the positional encoding.

  • max_value (int, default: 100000 ) –

    The maximum value of the input. Defaults to 100_000.

  • layers (int, default: 1 ) –

    The number of layers in the encoder. Defaults to 1.

  • size (int, default: 1 ) –

    The size of the input. Defaults to 1.

Returns:
  • torch.Tensor: A tensor representing the encoded continuous values.

Source code in scprint/model/encoders.py
def __init__(
    self,
    d_model: int,
    dropout: float = 0.1,
    max_value: int = 100_000,
    layers: int = 1,
    size: int = 1,
):
    """
    Encode real number values to a vector using neural nets projection.

    Args:
        d_model (int): The dimension of the input vectors.
        dropout (float, optional): The dropout rate to apply to the output of the positional encoding.
        max_value (int, optional): The maximum value of the input. Defaults to 100_000.
        layers (int, optional): The number of layers in the encoder. Defaults to 1.
        size (int, optional): The size of the input. Defaults to 1.

    Returns:
        torch.Tensor: A tensor representing the encoded continuous values.
    """
    super(ContinuousValueEncoder, self).__init__()
    self.max_value = max_value
    self.encoder = nn.ModuleList()
    self.encoder.append(nn.Linear(size, d_model))
    for _ in range(layers - 1):
        self.encoder.append(nn.LayerNorm(d_model))
        self.encoder.append(nn.ReLU())
        self.encoder.append(nn.Dropout(p=dropout))
        self.encoder.append(nn.Linear(d_model, d_model))

forward

Parameters:
  • x (Tensor) –

    Tensor, shape [batch_size, seq_len]

Source code in scprint/model/encoders.py
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
    """
    Args:
        x: Tensor, shape [batch_size, seq_len]
    """
    # expand last dimension
    x = x.unsqueeze(-1)
    # use the mask embedding when x=-1
    # mask = (x == -1).float()
    x = torch.clamp(x, min=0, max=self.max_value)
    for val in self.encoder:
        x = val(x)
    if mask is not None:
        x = x.masked_fill_(mask.unsqueeze(-1), 0)
    return x
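
A minimal sketch of encoding raw expression values (import path follows the source location above):

import torch
from scprint.model.encoders import ContinuousValueEncoder

enc = ContinuousValueEncoder(d_model=32, layers=2)
expr = torch.rand(2, 20) * 50            # 2 cells x 20 genes of expression values
emb = enc(expr)
print(emb.shape)                         # torch.Size([2, 20, 32])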

DPositionalEncoding

Bases: Module

The PositionalEncoding module applies a positional encoding to a sequence of vectors. This is necessary for the Transformer model, which does not have any inherent notion of position in a sequence. The positional encoding is added to the input embeddings and allows the model to attend to positions in the sequence.

Parameters:
  • d_model (int) –

    The dimension of the input vectors.

  • max_len_x (int) –

    The maximum length along the first (x) position axis.

  • max_len_y (int) –

    The maximum length along the second (y) position axis.

  • maxvalue_x (float, default: 10000.0 ) –

    The maximum frequency value used for the x-axis encoding. Defaults to 10000.0.

  • maxvalue_y (float, default: 10000.0 ) –

    The maximum frequency value used for the y-axis encoding. Defaults to 10000.0.

Note: not used in the current version of scprint.

Source code in scprint/model/encoders.py
def __init__(
    self,
    d_model: int,
    max_len_x: int,
    max_len_y: int,
    maxvalue_x=10000.0,
    maxvalue_y=10000.0,
):
    super(DPositionalEncoding, self).__init__()
    position2 = torch.arange(max_len_y).unsqueeze(1)
    position1 = torch.arange(max_len_x).unsqueeze(1)

    half_n = d_model // 2

    div_term2 = torch.exp(
        torch.arange(0, half_n, 2) * (-math.log(maxvalue_y) / d_model)
    )
    div_term1 = torch.exp(
        torch.arange(0, half_n, 2) * (-math.log(maxvalue_x) / d_model)
    )
    pe1 = torch.zeros(max_len_x, 1, d_model)
    pe2 = torch.zeros(max_len_y, 1, d_model)
    pe1[:, 0, 0:half_n:2] = torch.sin(position1 * div_term1)
    pe1[:, 0, 1:half_n:2] = torch.cos(position1 * div_term1)
    pe2[:, 0, half_n::2] = torch.sin(position2 * div_term2)
    pe2[:, 0, 1 + half_n :: 2] = torch.cos(position2 * div_term2)
    # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
    self.register_buffer("pe1", pe1)
    self.register_buffer("pe2", pe2)

forward

Parameters:
  • x (Tensor) –

    Tensor, shape [seq_len, batch_size, embedding_dim]

Source code in scprint/model/encoders.py
def forward(self, x: Tensor, pos_x: Tensor, pos_y: Tensor) -> Tensor:
    """
    Args:
        x: Tensor, shape [seq_len, batch_size, embedding_dim]
    """
    x = x + self.pe1[pos_x]
    x = x + self.pe2[pos_y]
    return x

GeneEncoder

Bases: Module

Encodes gene sequences into a continuous vector space using an embedding layer.

The output is then normalized using a LayerNorm.

Parameters:
  • num_embeddings (int) –

    The number of possible values.

  • embedding_dim (int) –

    The dimension of the output vectors.

  • padding_idx (int, default: None ) –

    The index of the padding token. Defaults to None.

  • weights (Tensor, default: None ) –

    The initial weights for the embedding layer. Defaults to None.

  • freeze (bool, default: False ) –

    Whether to freeze the weights of the embedding layer. Defaults to False.

Note: not used in the current version of scprint.

Source code in scprint/model/encoders.py
def __init__(
    self,
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: Optional[int] = None,
    weights: Optional[Tensor] = None,
    freeze: bool = False,
):
    """
    Encodes gene sequences into a continuous vector space using an embedding layer.

    The output is then normalized using a LayerNorm.

    Args:
        num_embeddings (int): The number of possible values.
        embedding_dim (int): The dimension of the output vectors.
        padding_idx (int, optional): The index of the padding token. Defaults to None.
        weights (Tensor, optional): The initial weights for the embedding layer. Defaults to None.
        dropout (float, optional): The dropout rate to apply to the output of the positional encoding. Defaults to 0.1.
        freeze (bool, optional): Whether to freeze the weights of the embedding layer. Defaults to False.

    Note: not used in the current version of scprint.
    """
    super(GeneEncoder, self).__init__()
    self.embedding = nn.Embedding(
        num_embeddings, embedding_dim, padding_idx=padding_idx, _freeze=freeze
    )

    if weights is not None:
        # concat a zero vector to the weight
        # this is to make the embedding of the padding token to be zero
        # weights = torch.cat(
        #    [torch.Tensor(weights), torch.zeros(1, embedding_dim)], dim=0
        # )
        self.embedding.weight.data.copy_(torch.Tensor(weights))

PositionalEncoding

Bases: Module

The PositionalEncoding module applies a positional encoding to a sequence of vectors. This is necessary for the Transformer model, which does not have any inherent notion of position in a sequence. The positional encoding is added to the input embeddings and allows the model to attend to positions in the sequence.

Parameters:
  • d_model (int) –

    The dimension of the input vectors.

  • max_len (int) –

    The maximum length of a sequence that this module can handle.

  • token_to_pos (dict[str, int]) –

    Mapping from token to position, used to reorder the encodings by gene id.

  • maxval (float, default: 10000.0 ) –

    The maximum frequency value of the sinusoidal encoding. Defaults to 10000.0.

Note: not used in the current version of scprint.

Source code in scprint/model/encoders.py
def __init__(
    self,
    d_model: int,
    max_len: int,
    token_to_pos: dict[str, int],  # [token, pos]
    maxval=10000.0,
):
    """
    The PositionalEncoding module applies a positional encoding to a sequence of vectors.
    This is necessary for the Transformer model, which does not have any inherent notion of
    position in a sequence. The positional encoding is added to the input embeddings and
    allows the model to attend to positions in the sequence.

    Args:
        d_model (int): The dimension of the input vectors.
        max_len (int): The maximum length of a sequence that this module can handle.
        token_to_pos (dict[str, int]): Mapping from token to position, used to reorder the encodings by gene id.
        maxval (float, optional): The maximum frequency value of the sinusoidal encoding. Defaults to 10000.0.

    Note: not used in the current version of scprint.
    """
    super(PositionalEncoding, self).__init__()
    position = torch.arange(max_len).unsqueeze(1)

    # Create a dictionary to convert token to position

    div_term = torch.exp(
        torch.arange(0, d_model, 2) * (-math.log(maxval) / d_model)
    )
    pe = torch.zeros(max_len, 1, d_model)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    pe[:, 0, 1::2] = torch.cos(position * div_term)
    # we reorder them and map them to gene_id (position)
    arr = []
    for _, v in token_to_pos.items():
        arr.append(pe[v - 1].numpy())
    pe = torch.Tensor(np.array(arr))
    self.register_buffer("pe", pe)

forward

Parameters:
  • x

    Tensor, shape [seq_len, batch_size, embedding_dim]

Source code in scprint/model/encoders.py
def forward(self, gene_pos: Tensor) -> Tensor:
    """
    Args:
        x: Tensor, shape [seq_len, batch_size, embedding_dim]
    """
    return torch.index_select(self.pe, 0, gene_pos.view(-1)).view(
        gene_pos.shape + (-1,)
    )

scprint.model.decoders

ClsDecoder

Bases: Module

ClsDecoder Decoder for classification task.

Parameters:
  • d_model (int) –

    int, dimension of the input.

  • n_cls (int) –

    int, number of classes.

  • layers (list[int], default: [256, 128] ) –

    list[int], list of hidden layers.

  • activation (Callable, default: ReLU ) –

    nn.Module, activation function.

  • dropout (float, default: 0.1 ) –

    float, dropout rate.

Returns:
  • Tensor, shape [batch_size, n_cls]

Source code in scprint/model/decoders.py
def __init__(
    self,
    d_model: int,
    n_cls: int,
    layers: list[int] = [256, 128],
    activation: Callable = nn.ReLU,
    dropout: float = 0.1,
):
    """
    ClsDecoder Decoder for classification task.

    Args:
        d_model: int, dimension of the input.
        n_cls: int, number of classes.
        layers: list[int], list of hidden layers.
        activation: nn.Module, activation function.
        dropout: float, dropout rate.

    Returns:
        Tensor, shape [batch_size, n_cls]
    """
    super(ClsDecoder, self).__init__()
    # module list
    layers = [d_model] + layers
    self.decoder = nn.Sequential()
    for i, l in enumerate(layers[1:]):
        self.decoder.append(nn.Linear(layers[i], l))
        self.decoder.append(nn.LayerNorm(l))
        self.decoder.append(activation())
        self.decoder.append(nn.Dropout(dropout))
    self.out_layer = nn.Linear(layers[-1], n_cls)

forward

Parameters:
  • x (Tensor) –

    Tensor, shape [batch_size, embsize]

Source code in scprint/model/decoders.py
def forward(self, x: Tensor) -> Tensor:
    """
    Args:
        x: Tensor, shape [batch_size, embsize]
    """
    x = self.decoder(x)
    return self.out_layer(x)
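
A minimal sketch (hypothetical sizes):

import torch
from scprint.model.decoders import ClsDecoder

cls_dec = ClsDecoder(d_model=64, n_cls=10, layers=[32])
logits = cls_dec(torch.randn(4, 64))     # 4 cell embeddings
print(logits.shape)                      # torch.Size([4, 10])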

ExprDecoder

Bases: Module

ExprDecoder Decoder for the gene expression prediction.

Will output the mean, variance and zero logits, parameters of a zero inflated negative binomial distribution.

Parameters:
  • d_model (int) –

    The dimension of the model. This is the size of the input feature vector.

  • nfirst_tokens_to_skip (int, default: 0 ) –

    The number of initial labels to skip in the sequence. Defaults to 0.

  • dropout (float, default: 0.1 ) –

    The dropout rate applied during training to prevent overfitting. Defaults to 0.1.

  • zinb (bool, default: True ) –

    Whether to use a zero inflated negative binomial distribution. Defaults to True.

Source code in scprint/model/decoders.py
def __init__(
    self,
    d_model: int,
    nfirst_tokens_to_skip: int = 0,
    dropout: float = 0.1,
    zinb: bool = True,
):
    """
    ExprDecoder Decoder for the gene expression prediction.

    Will output the mean, variance and zero logits, parameters of a zero inflated negative binomial distribution.

    Args:
        d_model (int): The dimension of the model. This is the size of the input feature vector.
        nfirst_tokens_to_skip (int, optional): The number of initial labels to skip in the sequence. Defaults to 0.
        dropout (float, optional): The dropout rate applied during training to prevent overfitting. Defaults to 0.1.
        zinb (bool, optional): Whether to use a zero inflated negative binomial distribution. Defaults to True.
    """
    super(ExprDecoder, self).__init__()
    self.nfirst_tokens_to_skip = nfirst_tokens_to_skip
    self.fc = nn.Sequential(
        nn.Linear(d_model, d_model),
        nn.LayerNorm(d_model),
        nn.LeakyReLU(),
        nn.Dropout(dropout),
        nn.Linear(d_model, d_model),
        nn.LayerNorm(d_model),
        nn.LeakyReLU(),
    )
    self.pred_var_zero = nn.Linear(d_model, 3 if zinb else 1)
    self.zinb = zinb

forward

x is the output of the transformer, (batch, seq_len, d_model)

Source code in scprint/model/decoders.py
def forward(self, x: Tensor) -> Dict[str, Tensor]:
    """x is the output of the transformer, (batch, seq_len, d_model)"""
    # we don't do it on the labels
    x = self.fc(x[:, self.nfirst_tokens_to_skip :, :])
    if self.zinb:
        pred_value, var_value, zero_logits = self.pred_var_zero(x).split(
            1, dim=-1
        )  # (batch, seq_len)
        # The sigmoid function is used to map the zero_logits to a probability between 0 and 1.
        return dict(
            mean=F.softmax(pred_value.squeeze(-1), dim=-1),
            disp=torch.exp(torch.clamp(var_value.squeeze(-1), max=15)),
            zero_logits=zero_logits.squeeze(-1),
        )
    else:
        pred_value = self.pred_var_zero(x)
        return dict(mean=F.softmax(pred_value.squeeze(-1), dim=-1))
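
A minimal sketch, assuming the first two tokens of the sequence are label/cell tokens to skip:

import torch
from scprint.model.decoders import ExprDecoder

dec = ExprDecoder(d_model=32, nfirst_tokens_to_skip=2)
transformer_out = torch.randn(2, 12, 32)  # 2 cells, 2 skipped tokens + 10 genes
out = dec(transformer_out)
print(out["mean"].shape)                  # torch.Size([2, 10])
print(out["disp"].shape, out["zero_logits"].shape)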

GraphSDEExprDecoder

Bases: Module

Initialize the ExprNeuralSDEDecoder module.

Parameters:
  • d_model (int) –

    The dimension of the model.

  • drift (Module) –

    The drift component of the SDE.

  • diffusion (Module) –

    The diffusion component of the SDE.

Source code in scprint/model/decoders.py
def __init__(self, d_model: int, drift: nn.Module, diffusion: nn.Module):
    """
    Initialize the ExprNeuralSDEDecoder module.

    Args:
        d_model (int): The dimension of the model.
        drift (nn.Module): The drift component of the SDE.
        diffusion (nn.Module): The diffusion component of the SDE.
    """
    super().__init__()
    self.d_model = d_model
    self.drift = drift
    self.diffusion = diffusion

MVCDecoder

Bases: Module

MVCDecoder Decoder for the masked value prediction for cell embeddings.

Will use the gene embeddings with the cell embeddings to predict the mean, variance and zero logits

Parameters:
  • d_model (int) –

    Dimension of the gene embedding.

  • arch_style (str, default: 'inner product' ) –

    Architecture style of the decoder, one of "inner product", "cell product", "concat query" or "sum query".

  • tot_labels (int, default: 1 ) –

    Number of label/cell embeddings concatenated to the query (used by the "concat query" style). Defaults to 1.

  • query_activation (Module, default: Sigmoid ) –

    Activation function for the query vectors. Defaults to nn.Sigmoid.

  • hidden_activation (Module, default: PReLU ) –

    Activation function for the hidden layers. Defaults to nn.PReLU.

Source code in scprint/model/decoders.py
def __init__(
    self,
    d_model: int,
    arch_style: str = "inner product",
    tot_labels: int = 1,
    query_activation: nn.Module = nn.Sigmoid,
    hidden_activation: nn.Module = nn.PReLU,
) -> None:
    """
    MVCDecoder Decoder for the masked value prediction for cell embeddings.

    Will use the gene embeddings with the cell embeddings to predict the mean, variance and zero logits

    Args:
        d_model (:obj:`int`): dimension of the gene embedding.
        arch_style (:obj:`str`): architecture style of the decoder, choice from
            1. "inner product" or 2. "cell product" 3. "concat query" or 4. "sum query".
        query_activation (:obj:`nn.Module`): activation function for the query
            vectors. Defaults to nn.Sigmoid.
        hidden_activation (:obj:`nn.Module`): activation function for the hidden
            layers. Defaults to nn.PReLU.
    """
    super(MVCDecoder, self).__init__()
    if arch_style == "inner product":
        self.gene2query = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.query_activation = query_activation()
        self.pred_var_zero = nn.Linear(d_model, d_model * 3, bias=False)
    elif arch_style == "concat query":
        self.gene2query = nn.Linear(d_model, d_model)
        self.query_activation = query_activation()
        # integer division: nn.Linear requires integer feature sizes
        self.fc1 = nn.Linear(d_model * (1 + tot_labels), d_model // 2)
        self.hidden_activation = hidden_activation()
        self.fc2 = nn.Linear(d_model // 2, 3)
    elif arch_style == "sum query":
        self.gene2query = nn.Linear(d_model, d_model)
        self.query_activation = query_activation()
        self.fc1 = nn.Linear(d_model, 64)
        self.hidden_activation = hidden_activation()
        self.fc2 = nn.Linear(64, 3)
    else:
        raise ValueError(f"Unknown arch_style: {arch_style}")

    self.arch_style = arch_style
    self.do_detach = arch_style.endswith("detach")
    self.d_model = d_model

forward

Parameters:
  • cell_emb (Tensor) –

    Tensor, shape (batch, embsize=d_model)

  • gene_embs (Tensor) –

    Tensor, shape (batch, seq_len, embsize=d_model)

Source code in scprint/model/decoders.py
def forward(
    self,
    cell_emb: Tensor,
    gene_embs: Tensor,
) -> Union[Tensor, Dict[str, Tensor]]:
    """
    Args:
        cell_emb: Tensor, shape (batch, embsize=d_model)
        gene_embs: Tensor, shape (batch, seq_len, embsize=d_model)
    """
    if self.arch_style == "inner product":
        query_vecs = self.query_activation(self.norm(self.gene2query(gene_embs)))
        pred, var, zero_logits = self.pred_var_zero(query_vecs).split(
            self.d_model, dim=-1
        )
        cell_emb = cell_emb.unsqueeze(2)
        pred, var, zero_logits = (
            torch.bmm(pred, cell_emb).squeeze(2),
            torch.bmm(var, cell_emb).squeeze(2),
            torch.bmm(zero_logits, cell_emb).squeeze(2),
        )
        # zero logits need to based on the cell_emb, because of input exprs
    elif self.arch_style == "concat query":
        query_vecs = self.query_activation(self.gene2query(gene_embs))
        # expand cell_emb to (batch, seq_len, embsize)
        cell_emb = cell_emb.unsqueeze(1).expand(-1, gene_embs.shape[1], -1)

        h = self.hidden_activation(
            self.fc1(torch.cat([cell_emb, query_vecs], dim=2))
        )
        pred, var, zero_logits = self.fc2(h).split(1, dim=-1)
    elif self.arch_style == "sum query":
        query_vecs = self.query_activation(self.gene2query(gene_embs))
        cell_emb = cell_emb.unsqueeze(1)

        h = self.hidden_activation(self.fc1(cell_emb + query_vecs))
        pred, var, zero_logits = self.fc2(h).split(1, dim=-1)
    return dict(
        mvc_mean=F.softmax(pred, dim=-1),
        mvc_disp=torch.exp(torch.clamp(var, max=15)),
        mvc_zero_logits=zero_logits,
    )
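
A minimal sketch with the default "inner product" architecture:

import torch
from scprint.model.decoders import MVCDecoder

mvc = MVCDecoder(d_model=64)
cell_emb = torch.randn(2, 64)             # one embedding per cell
gene_embs = torch.randn(2, 50, 64)        # 50 gene embeddings per cell
out = mvc(cell_emb, gene_embs)
print(out["mvc_mean"].shape)              # torch.Size([2, 50])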

flashattention

scprint.model.flash_attn.flashformer

FlashTransformerEncoder

Bases: Module

FlashTransformerEncoder a transformer encoder with flash attention.

Parameters:
  • d_model (int) –

    The dimension of the input vectors.

  • nhead (int) –

    The number of attention heads.

  • nlayers (int) –

    The number of layers in the transformer.

  • dropout (float, default: 0.1 ) –

    The dropout rate to apply to the output of the positional encoding. Defaults to 0.1.

  • residual_in_fp32 (bool, default: True ) –

    Whether to force the residual to be in fp32 format. Defaults to True.

  • num_heads_kv (Optional[int], default: None ) –

    The number of heads for key/value. Defaults to None.

  • checkpointing (bool, default: False ) –

    Whether to use gradient checkpointing. Defaults to False.

  • fused_dropout_add_ln (bool, default: False ) –

    Whether to fuse dropout, addition and layer normalization operations. Defaults to False.

  • return_residual (bool, default: False ) –

    Whether to return the residual. Defaults to False.

  • prenorm (bool, default: True ) –

    Whether to use pre-normalization. Defaults to True.

  • mlp_ratio (float, default: 4.0 ) –

    The ratio for MLP. Defaults to 4.0.

  • fused_mlp (bool, default: False ) –

    Whether to use fused MLP. Defaults to False.

  • fused_bias_fc (bool, default: False ) –

    Whether to fuse bias and fully connected layers. Defaults to False.

  • sequence_parallel (bool, default: False ) –

    Whether to use sequence parallelism. Defaults to False.

  • drop_path_rate (float, default: 0.0 ) –

    The drop path rate. Defaults to 0.0.

  • weight_init (str, default: '' ) –

    The weight initialization method. Defaults to "".

Raises:
  • ImportError

    Raised when Triton is not installed but fused_dropout_add_ln is set to True.

  • NotImplementedError

    Raised when an unsupported operation is attempted.

Source code in scprint/model/flash_attn/flashformer.py
def __init__(
    self,
    d_model: int,
    nhead: int,
    nlayers: int,
    dropout: float = 0.1,
    residual_in_fp32: bool = True,
    num_heads_kv: Optional[int] = None,
    checkpointing: bool = False,
    fused_dropout_add_ln: bool = False,
    return_residual: bool = False,
    prenorm: bool = True,
    mlp_ratio: float = 4.0,
    fused_mlp: bool = False,
    fused_bias_fc: bool = False,
    sequence_parallel: bool = False,
    drop_path_rate: float = 0.0,
    use_flash_attn: bool = True,
    weight_init: str = "",
):
    """
    FlashTransformerEncoder a transformer encoder with flash attention.

    Args:
        d_model (int): The dimension of the input vectors.
        nhead (int): The number of attention heads.
        nlayers (int): The number of layers in the transformer.
        dropout (float, optional): The dropout rate to apply to the output of the positional encoding. Defaults to 0.1.
        residual_in_fp32 (bool, optional): Whether to force the residual to be in fp32 format. Defaults to True.
        num_heads_kv (_type_, optional): The number of heads for key/value. Defaults to None.
        checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
        fused_dropout_add_ln (bool, optional): Whether to fuse dropout, addition and layer normalization operations. Defaults to False.
        return_residual (bool, optional): Whether to return the residual. Defaults to False.
        prenorm (bool, optional): Whether to use pre-normalization. Defaults to True.
        mlp_ratio (float, optional): The ratio for MLP. Defaults to 4.0.
        fused_mlp (bool, optional): Whether to use fused MLP. Defaults to False.
        fused_bias_fc (bool, optional): Whether to fuse bias and fully connected layers. Defaults to False.
        sequence_parallel (bool, optional): Whether to use sequence parallelism. Defaults to False.
        drop_path_rate (float, optional): The drop path rate. Defaults to 0.0.
        weight_init (str, optional): The weight initialization method. Defaults to "".

    Raises:
        ImportError: Raised when Triton is not installed but fused_dropout_add_ln is set to True.
        NotImplementedError: Raised when an unsupported operation is attempted.
    """
    super(FlashTransformerEncoder, self).__init__()

    self.blocks = nn.ModuleList()
    dpr = [
        x.item() for x in torch.linspace(0, drop_path_rate, nlayers)
    ]  # stochastic depth decay rule

    for i in range(nlayers):
        mlp = create_mlp_cls(d_model, mlp_ratio, nn.GELU, fused_mlp)
        attention = partial(
            MHA,
            num_heads=nhead,
            dropout=dropout,
            causal=False,
            use_flash_attn=use_flash_attn,
            num_heads_kv=num_heads_kv,
            checkpointing=checkpointing,
            fused_bias_fc=fused_bias_fc,
            layer_idx=i,
        )
        # or use parallelBlock where attn & MLP are done in parallel
        encoder_layers = Block(
            d_model,
            attention,
            mlp,
            prenorm=prenorm,
            # need to set it here for now although it hinders some performances as it returns the residual and I need to see what to do with it
            # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel,  # for more parallelism
            resid_dropout1=dropout,
            resid_dropout2=dropout,
            drop_path1=dpr[i - 1] if i > 0 else 0.0,
            drop_path2=dpr[i],
            fused_dropout_add_ln=fused_dropout_add_ln,
            return_residual=return_residual,
        )
        self.blocks.append(encoder_layers)

    self.prenorm = prenorm
    self.dropout = nn.Dropout(p=dropout)
    self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
    self.norm = torch.nn.LayerNorm(d_model, eps=1e-6)

    self.fused_dropout_add_ln = fused_dropout_add_ln
    if self.fused_dropout_add_ln and layer_norm_fn is None:
        raise ImportError("Triton is not installed")

    if sequence_parallel:
        # sequence parallelism only matters when combined with tensor parallelism across GPUs
        # (to further increase the context length); it is not needed here
        raise NotImplementedError("sequence_parallel not implemented yet")

    self.init_weights(weight_init)
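
A minimal construction sketch (not from the library docs): the import path scprint.model.flash_attn.flashformer and the argument values below are assumptions; adjust them to your installation.

import torch
from scprint.model.flash_attn.flashformer import FlashTransformerEncoder  # assumed module path

encoder = FlashTransformerEncoder(
    d_model=512,
    nhead=8,
    nlayers=6,
    dropout=0.1,
    use_flash_attn=False,  # fall back to the non-flash attention path (e.g. on CPU)
)
# builds nlayers prenorm Blocks plus a final LayerNorm, as in __init__ above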

scprint.model.flash_attn.mha

CrossAttention

Bases: Module

Implement the scaled dot product attention with softmax.

Parameters:
  • softmax_scale (float, default: None ) –

    The temperature to use for the softmax attention. Defaults to 1/sqrt(d_keys), where d_keys is computed at runtime.

  • attention_dropout (float, default: 0.0 ) –

    The dropout rate to apply to the attention. Defaults to 0.0.

Source code in scprint/model/flash_attn/mha.py
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
    super().__init__()
    self.causal = causal
    self.softmax_scale = softmax_scale
    self.drop = nn.Dropout(attention_dropout)

forward

Implements the multihead softmax attention.

Parameters:
  • q (Tensor) –

    The tensor containing the query. (B, Sq, H, D)

  • kv (Tensor) –

    The tensor containing the key and value. (B, Sk, 2, H_k, D)

  • causal (bool, default: None ) –

    If passed, will override self.causal.

  • key_padding_mask (Tensor, default: None ) –

    Boolean mask to apply to the attention weights. True means to keep, False means to mask out. (B, Sk)

Source code in scprint/model/flash_attn/mha.py
def forward(self, q, kv, causal=None, key_padding_mask=None, bias=None):
    """Implements the multihead softmax attention.

    Args
        q: The tensor containing the query. (B, Sq, H, D)
        kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
        causal: if passed, will override self.causal
        key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
            False means to mask out. (B, Sk)
    """
    batch_size, seqlen_q = q.shape[0], q.shape[1]
    causal = self.causal if causal is None else causal
    seqlen_k = kv.shape[1]
    assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
    if kv.shape[3] != q.shape[2]:  # MQA/GQA
        kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
    k, v = kv.unbind(dim=2)
    softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if key_padding_mask is not None:
        padding_mask = torch.full(
            (batch_size, seqlen_k),
            -10000.0,
            dtype=scores.dtype,
            device=scores.device,
        )
        padding_mask.masked_fill_(key_padding_mask, 0.0)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
    if causal:
        # causal mask needs to take into account the difference between seqlen_q and seqlen_k
        row_idx = rearrange(
            torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
        )
        col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
        sk = (
            seqlen_k
            if key_padding_mask is None
            else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        causal_mask = col_idx > row_idx + sk - seqlen_q
        scores = scores.masked_fill(causal_mask, -10000.0)
    attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
    attention_drop = self.drop(attention)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    return output
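
A shape sketch (assuming CrossAttention is importable from scprint.model.flash_attn.mha, the file shown above): this pure-PyTorch path runs on CPU and illustrates the (B, Sq, H, D) / (B, Sk, 2, H_k, D) layouts.

import torch
from scprint.model.flash_attn.mha import CrossAttention

attn = CrossAttention(causal=False, attention_dropout=0.0)
B, Sq, Sk, H, D = 2, 4, 6, 8, 64
q = torch.randn(B, Sq, H, D)
kv = torch.randn(B, Sk, 2, H, D)  # key and value packed along dim=2
out = attn(q, kv)                 # -> (B, Sq, H, D)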

FlashCrossAttention

Bases: Module

Implement the scaled dot product attention with softmax.

Parameters:
  • softmax_scale (float, default: None ) –

    The temperature to use for the softmax attention. (default: 1/sqrt(d_keys) where d_keys is computed at runtime)

  • attention_dropout (float, default: 0.0 ) –

    The dropout rate to apply to the attention (default: 0.0)

Source code in scprint/model/flash_attn/mha.py
def __init__(
    self,
    causal=False,
    softmax_scale=None,
    attention_dropout=0.0,
    alibi_slopes=None,
    deterministic=False,
):
    """
    Implement the scaled dot product attention with softmax.

    Args
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at
            runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """
    super().__init__()
    assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
    self.causal = causal
    self.softmax_scale = softmax_scale
    self.drop = nn.Dropout(attention_dropout)
    self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
    self.deterministic = deterministic

forward

Implements the multihead softmax attention.

Parameters:
  • q (Tensor) –

    The tensor containing the query. (B, Sq, H, D)

  • kv (Tensor) –

    The tensor containing the key and value. (B, Sk, 2, H_k, D)

  • causal (bool, default: None ) –

    If passed, will override self.causal.

  • cu_seqlens (Tensor, default: None ) –

    (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q.

  • max_seqlen (int, default: None ) –

    Maximum sequence length in the batch of q.

  • cu_seqlens_k (Tensor, default: None ) –

    (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into kv.

  • max_seqlen_k (int, default: None ) –

    Maximum sequence length in the batch of k and v.

Source code in scprint/model/flash_attn/mha.py
def forward(
    self,
    q,
    kv,
    causal=None,
    cu_seqlens=None,
    max_seqlen=None,
    cu_seqlens_k=None,
    max_seqlen_k=None,
):
    """
    Implements the multihead softmax attention.

    Args
        q: The tensor containing the query. (B, Sq, H, D)
        kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
        causal: if passed, will override self.causal
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        max_seqlen: int. Maximum sequence length in the batch of q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_k: int. Maximum sequence length in the batch of k and v.
    """
    assert q.dtype in [torch.float16, torch.bfloat16]
    assert q.is_cuda and kv.is_cuda
    causal = self.causal if causal is None else causal
    batch_size, _ = q.shape[0], q.shape[1]
    assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
    return flash_attn_kvpacked_func(
        q,
        kv,
        None,
        # self.drop.p if self.training else 0.0,
        causal,
        self.softmax_scale,
        # alibi_slopes=self.alibi_slopes,
        # deterministic=self.deterministic,
    )
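
A sketch only: FlashCrossAttention dispatches to flash_attn_kvpacked_func, so it needs the flash-attn kernels, a CUDA device and fp16/bf16 tensors; the shapes mirror the CrossAttention example above.

import torch
from scprint.model.flash_attn.mha import FlashCrossAttention

if torch.cuda.is_available():
    attn = FlashCrossAttention(causal=False)
    B, Sq, Sk, H, D = 2, 4, 6, 8, 64
    q = torch.randn(B, Sq, H, D, device="cuda", dtype=torch.float16)
    kv = torch.randn(B, Sk, 2, H, D, device="cuda", dtype=torch.float16)
    out = attn(q, kv)  # -> (B, Sq, H, D)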

FlashSelfAttention

Bases: Module

Implement the scaled dot product attention with softmax.

Parameters:
  • softmax_scale (float, default: None ) –

    The temperature to use for the softmax attention. (default: 1/sqrt(d_keys) where d_keys is computed at runtime)

  • attention_dropout (float, default: 0.0 ) –

    The dropout rate to apply to the attention (default: 0.0)

  • causal (bool, default: False ) –

    Whether to use causal attention. Defaults to False.

Source code in scprint/model/flash_attn/mha.py
def __init__(
    self,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
    attention_dropout: float = 0.0,
    alibi_slopes: Optional[Any] = None,
    deterministic: bool = False,
    use_triton: bool = False,
):
    """Implement the scaled dot product attention with softmax.

    Args:
        softmax_scale (float, optional): The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at
            runtime)
        attention_dropout (float, optional): The dropout rate to apply to the attention
            (default: 0.0)
        causal (bool, optional): Whether to use causal attention. Defaults to False.
    """
    super().__init__()
    if flash_attn_qkvpacked_func is None:
        print("FlashAttention is not installed, using triton instead")
        use_triton = True
    self.use_triton = use_triton
    self.causal = causal
    self.softmax_scale = softmax_scale

forward

Implements the multihead softmax attention.

Parameters:
  • qkv (Tensor) –

    The tensor containing the query, key, and value. If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D). If cu_seqlens is not None and max_seqlen is not None, then qkv has shape (total, 3, H, D), where total is the sum of the sequence lengths in the batch.

  • causal (bool, default: None ) –

    If passed, will override self.causal.

  • cu_seqlens (Tensor, default: None ) –

    (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into qkv.

  • max_seqlen (int, default: None ) –

    Maximum sequence length in the batch.

Returns:
  • out

    (total, H, D) if cu_seqlens is not None and max_seqlen is not None, else (B, S, H, D).

Source code in scprint/model/flash_attn/mha.py
def forward(
    self,
    qkv: torch.Tensor,
    causal: Optional[bool] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_k: Optional[int] = None,
    bias: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Implements the multihead softmax attention.

    Args
        qkv (Tensor): The tensor containing the query, key, and value.
            If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
            If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
            (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
        causal (bool): if passed, will override self.causal
        cu_seqlens (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen (int). Maximum sequence length in the batch.
    Returns:
        out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
            else (B, S, H, D).
    """
    assert qkv.dtype in [torch.float16, torch.bfloat16]
    assert qkv.is_cuda
    causal = self.causal if causal is None else causal
    if self.use_triton:
        raise NotImplementedError("OpenAI's flashattention is not implemented")
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        # return triton_attention(
        #    qkv[:, :, 0],
        #    qkv[:, :, 1],
        #    qkv[:, :, 2],
        #    bias,
        #    causal,
        #    self.softmax_scale,
        # )
    else:
        return flash_attn_qkvpacked_func(
            qkv,
            bias,
            # self.drop.p if self.training else 0.0,
            causal,
            self.softmax_scale,
        )
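
A sketch only: FlashSelfAttention consumes a packed qkv tensor of shape (B, S, 3, H, D) and, like the cross-attention variant, requires the flash-attn kernels, CUDA and half precision.

import torch
from scprint.model.flash_attn.mha import FlashSelfAttention

if torch.cuda.is_available():
    attn = FlashSelfAttention(causal=False)
    B, S, H, D = 2, 128, 8, 64
    qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=torch.bfloat16)
    out = attn(qkv)  # -> (B, S, H, D)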

LinearResidual

Bases: Linear

Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.

MHA

Bases: Module

MHA Multi-head self-attention and cross-attention

Parameters:
  • num_heads_kv (int, default: None ) –

    can be used to toggle MQA / GQA. If None, use num_heads.

  • return_residual (bool, default: False ) –

    whether to return the input x along with the output. This is for performance reasons: for post-norm architectures, returning the input allows us to fuse the backward of nn.Linear with the residual connection. Defaults to False.

  • checkpointing (bool, default: False ) –

    whether to use checkpointing to save memory. Defaults to False.

  • cross_attn (bool, default: False ) –

    whether to use cross-attention. Defaults to False.

  • qkv_proj_bias (bool, default: True ) –

    whether to use bias in the query, key, value projection. Defaults to True.

  • out_proj_bias (bool, default: True ) –

    whether to use bias in the output projection. Defaults to True.

  • dropout (float, default: 0.0 ) –

    dropout rate. Defaults to 0.0.

  • softmax_scale (float, default: None ) –

    The temperature to use for the softmax attention.

  • causal (bool, default: False ) –

    whether to use causal attention. Defaults to False.

  • layer_idx (int, default: None ) –

    layer index for inference cache. Defaults to None.

  • dwconv (bool, default: False ) –

    whether to use depthwise convolution. Defaults to False.

  • fused_bias_fc (bool, default: False ) –

    whether to use fused_bias_fc. Defaults to False.

  • use_flash_attn (bool, default: False ) –

    whether to use FlashAttention. Defaults to False.

  • device (device, default: None ) –

    device. Defaults to None.

  • dtype (dtype, default: None ) –

    dtype. Defaults to None.

Source code in scprint/model/flash_attn/mha.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    num_heads_kv: Optional[int] = None,
    cross_attn: bool = False,
    qkv_proj_bias: bool = True,
    out_proj_bias: bool = True,
    dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    layer_idx: Optional[int] = None,
    dwconv: bool = False,
    rotary_emb_dim: int = 0,
    rotary_emb_base: float = 10000.0,
    rotary_emb_scale_base: Optional[float] = None,
    rotary_emb_interleaved: bool = False,
    use_alibi: bool = False,
    fused_bias_fc: bool = False,
    use_flash_attn: bool = False,
    return_residual: bool = False,
    checkpointing: bool = False,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """
    MHA Multi-head self-attention and cross-attention

    Args:
        embed_dim (int): total embedding dimension; must be divisible by num_heads.
        num_heads (int): number of attention heads.
        num_heads_kv (int): can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual (bool, optional): whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
            Defaults to False.
        checkpointing (bool, optional): whether to use checkpointing to save memory.
            Defaults to False.
        cross_attn (bool, optional): whether to use cross-attention. Defaults to False.
        qkv_proj_bias (bool, optional): whether to use bias in the query, key, value projection. Defaults to True.
        out_proj_bias (bool, optional): whether to use bias in the output projection. Defaults to True.
        dropout (float, optional): dropout rate. Defaults to 0.0.
        softmax_scale (float, optional): The temperature to use for the softmax attention.
        causal (bool, optional): whether to use causal attention. Defaults to False.
        layer_idx (int, optional): layer index for inference cache. Defaults to None.
        dwconv (bool, optional): whether to use depthwise convolution. Defaults to False.
        fused_bias_fc (bool, optional): whether to use fused_bias_fc. Defaults to False.
        use_flash_attn (bool, optional): whether to use FlashAttention. Defaults to False.
        device (torch.device, optional): device. Defaults to None.
        dtype (torch.dtype, optional): dtype. Defaults to None.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()
    self.embed_dim = embed_dim
    self.cross_attn = cross_attn
    self.causal = causal
    self.layer_idx = layer_idx
    self.dwconv = dwconv
    self.rotary_emb_dim = rotary_emb_dim
    self.use_flash_attn = use_flash_attn
    if self.use_flash_attn and (flash_attn_kvpacked_func is None):
        print(
            "you requested flash transformer but it requires the flash package which is not installed"
        )
        print("falling back to regular transformer...")
        self.use_flash_attn = False

        # NOT flash transformer using the special triton kernel
        # or parallelMHA (add the process group thing and faster)
    self.return_residual = return_residual
    self.checkpointing = checkpointing
    alibi_slopes = None

    self.num_heads = num_heads
    self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
    assert (
        self.num_heads % self.num_heads_kv == 0
    ), "num_heads must be divisible by num_heads_kv"
    assert (
        self.embed_dim % num_heads == 0
    ), "embed_dim must be divisible by num_heads"
    self.head_dim = self.embed_dim // num_heads
    qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
    kv_dim = 2 * self.head_dim * self.num_heads_kv

    if self.rotary_emb_dim > 0:
        assert (
            not cross_attn
        ), "MHA with rotary embedding does not support cross-attention yet"
        assert RotaryEmbedding is not None, "rotary_emb is not installed"
        self.rotary_emb = RotaryEmbedding(
            self.rotary_emb_dim,
            base=rotary_emb_base,
            scale_base=rotary_emb_scale_base,
            interleaved=rotary_emb_interleaved,
            device=device,
        )

    if fused_bias_fc and FusedDense is None:
        raise ImportError("fused_dense is not installed")
    linear_cls = nn.Linear if not fused_bias_fc else FusedDense
    linear_resid_cls = (
        LinearResidual
        if not fused_bias_fc
        else partial(FusedDense, return_residual=True)
    )
    wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
    inner_attn_cls = (
        partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
        if self.use_flash_attn
        else SelfAttention
    )
    inner_cross_attn_cls = (
        partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
        if self.use_flash_attn
        else CrossAttention
    )
    if not self.cross_attn:
        self.Wqkv = wqkv_cls(
            embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs
        )
    else:
        self.Wq = linear_cls(
            embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs
        )
        self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
    if self.dwconv:
        if self.num_heads_kv == self.num_heads:
            self.dwconv_qkv = nn.Conv1d(
                qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
            )
        else:
            self.dwconv_q = nn.Conv1d(
                embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
            )
            self.dwconv_kv = nn.Conv1d(
                kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim
            )
    self.inner_attn = inner_attn_cls(
        causal=causal,
        softmax_scale=softmax_scale,
        attention_dropout=dropout,
    )
    self.inner_cross_attn = inner_cross_attn_cls(
        causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
    )
    self.out_proj = linear_cls(
        embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs
    )

forward

Parameters:
  • x (Tensor) –

    (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total is the sum of the sequence lengths in the batch.

  • x_kv (Optional[Tensor], default: None ) –

    (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.

  • cu_seqlens (Optional[Tensor], default: None ) –

    (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into x. Only applicable when using FlashAttention.

  • max_seqlen (Optional[int], default: None ) –

    int. Maximum sequence length in the batch.

  • key_padding_mask (Optional[Tensor], default: None ) –

    boolean mask, True means to keep, False means to mask out. (batch, seqlen). Only applicable when not using FlashAttention.

  • mixer_subset (Optional[Tensor], default: None ) –

    for cross-attention only. If not None, will take a subset of x before applying the query projection. Useful for e.g., ViT where we only care about the CLS token in the last layer.

  • inference_params (Optional[dict], default: None ) –

    for generation. Adapted from Megatron-LM (and Apex): https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470

  • return_qkv (bool, default: False ) –

    whether to return the qkv tensor. Defaults to False.

Returns:
  • out

    (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total is the sum of the sequence lengths in the batch.

  • qkv

    (batch, seqlen, 3, hidden_dim) if cu_seqlens is None and max_seqlen is None, else (total, 3, hidden_dim) where total is the sum of the sequence lengths in the batch.

Source code in scprint/model/flash_attn/mha.py
def forward(
    self,
    x: torch.Tensor,
    x_kv: Optional[torch.Tensor] = None,
    key_padding_mask: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    mixer_subset: Optional[torch.Tensor] = None,
    inference_params: Optional[dict] = None,
    return_qkv: bool = False,
    **kwargs,
):
    """
    Args:
        x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
            cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
            is the sum of the sequence lengths in the batch.
        x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into x. Only applicable when using
            FlashAttention.
        max_seqlen: int. Maximum sequence length in the batch.
        key_padding_mask: boolean mask, True means to keep, False means to mask out.
            (batch, seqlen). Only applicable when not using FlashAttention.
        mixer_subset: for cross-attention only. If not None, will take a subset of x
            before applying the query projection. Useful for e.g., ViT where we only care
            about the CLS token in the last layer.
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        return_qkv: whether to return the qkv tensor. Defaults to False.

    Returns:
        out: (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
            else (total, hidden_dim) where total is the sum of the sequence lengths in the batch.
        qkv: (batch, seqlen, 3, hidden_dim) if cu_seqlens is None and max_seqlen is None,
            else (total, 3, hidden_dim) where total is the sum of the sequence lengths in the batch.
    """
    if cu_seqlens is not None:
        assert max_seqlen is not None
        assert key_padding_mask is None
        assert self.use_flash_attn
        assert not self.dwconv
        assert self.rotary_emb_dim == 0
    if key_padding_mask is not None:
        assert cu_seqlens is None
        assert max_seqlen is None
        assert not self.use_flash_attn
    if inference_params is not None:
        assert key_padding_mask is None
        assert cu_seqlens is None and max_seqlen is None
        assert not self.dwconv

    kwargs = (
        {}  # "cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
        if self.use_flash_attn
        else {"key_padding_mask": key_padding_mask, **kwargs}
    )
    seqlen_offset = (
        0
        if inference_params is None
        else (
            inference_params.lengths_per_sample
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
    )
    rotary_max_seqlen = (
        inference_params.max_seqlen if inference_params is not None else None
    )
    batch, seqlen = x.shape[:2]
    if not self.cross_attn and self.num_heads_kv == self.num_heads:
        assert x_kv is None and mixer_subset is None
        if not self.return_residual:
            qkv = self.Wqkv(x)  # .to(torch.float16, device="cuda")
        else:
            qkv, x = self.Wqkv(x)
        if self.dwconv:
            qkv = rearrange(
                self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
                "b d s -> b s d",
            ).contiguous()
        qkv = rearrange(
            qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
        )
        if (
            inference_params is None
            or inference_params.seqlen_offset == 0
            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
            or not self.use_flash_attn
        ):
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(
                    qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is None:
                if not self.checkpointing:
                    context = self.inner_attn(qkv, **kwargs)
                else:
                    context = torch.utils.checkpoint.checkpoint(
                        self.inner_attn, qkv, **kwargs
                    )
            else:
                context = self._update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            context = self._apply_rotary_update_kvcache_attention(
                qkv[:, :, 0], qkv[:, :, 1:], inference_params
            )
    else:
        if self.cross_attn:
            if not self.return_residual:
                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                kv = self.Wkv(x_kv if x_kv is not None else x)
            else:
                if x_kv is not None:
                    kv, x_kv = self.Wkv(x_kv)
                else:
                    kv, x = self.Wkv(x)
                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
        else:
            assert self.num_heads_kv != self.num_heads
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            q = qkv[..., : self.num_heads * self.head_dim]
            kv = qkv[..., self.num_heads * self.head_dim :]
        q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
        kv = rearrange(
            kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
        )
        if self.dwconv:
            q = rearrange(
                self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
                "b d s -> b s d",
            ).contiguous()
            kv = rearrange(
                self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
                "b d s -> b s d",
            ).contiguous()
        if (
            inference_params is None
            or inference_params.seqlen_offset == 0
            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
            or not self.use_flash_attn
        ):
            if self.rotary_emb_dim > 0:
                q, kv = self.rotary_emb(
                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is None:
                if not self.checkpointing:
                    context = self.inner_cross_attn(q, kv, **kwargs)
                else:
                    context = torch.utils.checkpoint.checkpoint(
                        self.inner_cross_attn, q, kv, **kwargs
                    )
            else:
                context = self._update_kvcache_attention(q, kv, inference_params)
        else:
            context = self._apply_rotary_update_kvcache_attention(
                q, kv, inference_params
            )
    out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
    if return_qkv:
        return out if not self.return_residual else (out, x), qkv
    else:
        return out if not self.return_residual else (out, x)
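
A minimal self-attention sketch (values are illustrative): with use_flash_attn=False the module falls back to the pure-PyTorch SelfAttention path and runs on CPU.

import torch
from scprint.model.flash_attn.mha import MHA

mha = MHA(embed_dim=512, num_heads=8, use_flash_attn=False)
x = torch.randn(2, 16, 512)  # (batch, seqlen, hidden_dim)
out = mha(x)                 # -> (2, 16, 512)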

SelfAttention

Bases: Module

Implement the scaled dot product attention with softmax.

Parameters:
  • softmax_scale

    The temperature to use for the softmax attention. (default: 1/sqrt(d_keys) where d_keys is computed at runtime)

  • attention_dropout

    The dropout rate to apply to the attention (default: 0.0)

Source code in scprint/model/flash_attn/mha.py
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
    super().__init__()
    self.causal = causal
    self.softmax_scale = softmax_scale
    self.drop = nn.Dropout(attention_dropout)

forward

Implements the multihead softmax attention.

Parameters:
  • qkv

    The tensor containing the query, key, and value. (B, S, 3, H, D)

  • causal

    if passed, will override self.causal

  • key_padding_mask

    boolean mask to apply to the attention weights. True means to keep, False means to mask out. (B, S)

Source code in scprint/model/flash_attn/mha.py
def forward(self, qkv, causal=None, key_padding_mask=None, bias=None):
    """
    Implements the multihead softmax attention.

    Args:
        qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
        causal: if passed, will override self.causal
        key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
            False means to mask out. (B, S)
    """
    batch_size, seqlen = qkv.shape[0], qkv.shape[1]
    causal = self.causal if causal is None else causal
    q, k, v = qkv.unbind(dim=2)
    softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if key_padding_mask is not None:
        padding_mask = torch.full(
            (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
        )
        padding_mask.masked_fill_(key_padding_mask, 0.0)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(
            torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
        )
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
    attention_drop = self.drop(attention)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    return output
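
A shape sketch for this non-flash path: qkv packs query, key and value along dim=2.

import torch
from scprint.model.flash_attn.mha import SelfAttention

attn = SelfAttention(causal=False, attention_dropout=0.0)
B, S, H, D = 2, 16, 8, 64
qkv = torch.randn(B, S, 3, H, D)
out = attn(qkv)  # -> (B, S, H, D)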

scprint.model.flash_attn.mlp

Mlp

Bases: Module

Multi-layer perceptron (MLP) module.

Parameters:
  • in_features (int) –

    Size of each input sample.

  • hidden_features (Optional[int], default: None ) –

    Size of the hidden layer. Defaults to 4 * in_features.

  • out_features (Optional[int], default: None ) –

    Size of each output sample. Defaults to in_features.

  • activation (Callable[[Tensor], Tensor], default: gelu ) –

    Activation function. Defaults to F.gelu.

  • bias1 (bool, default: True ) –

    If set to False, the first linear layer will not learn an additive bias. Defaults to True.

  • bias2 (bool, default: True ) –

    If set to False, the second linear layer will not learn an additive bias. Defaults to True.

  • return_residual (bool, default: False ) –

    If set to True, the forward method will return a tuple (output, input). Defaults to False.

  • device (Optional[device], default: None ) –

    The desired device of the parameters. Defaults to None.

  • dtype (Optional[dtype], default: None ) –

    The desired data type of the parameters. Defaults to None.

Source code in scprint/model/flash_attn/mlp.py
def __init__(
    self,
    in_features: int,
    hidden_features: Optional[int] = None,
    out_features: Optional[int] = None,
    activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
    bias1: bool = True,
    bias2: bool = True,
    return_residual: bool = False,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """
    Multi-layer perceptron (MLP) module.

    Args:
        in_features (int): Size of each input sample.
        hidden_features (Optional[int], optional): Size of the hidden layer. Defaults to 4 * in_features.
        out_features (Optional[int], optional): Size of each output sample. Defaults to in_features.
        activation (Callable[[torch.Tensor], torch.Tensor], optional): Activation function. Defaults to F.gelu.
        bias1 (bool, optional): If set to False, the first linear layer will not learn an additive bias. Defaults to True.
        bias2 (bool, optional): If set to False, the second linear layer will not learn an additive bias. Defaults to True.
        return_residual (bool, optional): If set to True, the forward method will return a tuple (output, input). Defaults to False.
        device (Optional[torch.device], optional): The desired device of the parameters. Defaults to None.
        dtype (Optional[torch.dtype], optional): The desired data type of the parameters. Defaults to None.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()
    out_features = out_features if out_features is not None else in_features
    hidden_features = (
        hidden_features if hidden_features is not None else in_features * 4
    )
    self.return_residual = return_residual
    self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
    self.activation = activation
    self.fc2 = nn.Linear(
        hidden_features, out_features, bias=bias2, **factory_kwargs
    )

forward

Forward pass of the MLP.

Parameters:
  • x (Tensor) –

    Input tensor.

Returns:
  • Union[Tensor, Tuple[Tensor, Tensor]]

    Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Output tensor, or a tuple (output, input) if return_residual is True.

Source code in scprint/model/flash_attn/mlp.py
def forward(
    self, x: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Forward pass of the MLP.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Output tensor, or a tuple (output, input) if return_residual is True.
    """
    y = self.fc1(x)
    y = self.activation(y)
    y = self.fc2(y)
    return y if not self.return_residual else (y, x)
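
A minimal sketch: hidden_features defaults to 4 * in_features, and return_residual=True makes forward return an (output, input) pair.

import torch
from scprint.model.flash_attn.mlp import Mlp

mlp = Mlp(in_features=512)
x = torch.randn(2, 16, 512)
y = mlp(x)                    # -> (2, 16, 512)

mlp_res = Mlp(in_features=512, return_residual=True)
y, x_back = mlp_res(x)        # (output, input)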

scprint.model.flash_attn.block

Block

Bases: Module

For prenorm=True, this Block has a slightly different structure compared to a regular prenorm Transformer block. The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. [Ref: https://arxiv.org/abs/2002.04745] Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both the hidden_states (output of the MLP) and the residual. This is for performance reasons, as we can fuse the dropout, add and LayerNorm. The residual needs to be provided (except for the very first block).

For prenorm=False, this Block has the same structure as a regular postnorm Transformer block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

Parameters:
  • dim (int) –

    the number of features in the input.

  • mixer_cls (Optional[Callable], default: None ) –

    the class to use for the mixer layer. Defaults to None.

  • mlp_cls (Optional[Callable], default: None ) –

    the class to use for the mlp layer. Defaults to None.

  • norm_cls (Callable, default: partial(LayerNorm, eps=1e-06) ) –

    the class to use for the layer norm. Defaults to partial(nn.LayerNorm, eps=1e-6).

  • dropout_cls (Type[Dropout], default: Dropout ) –

    the class to use for the dropout. Defaults to nn.Dropout.

  • prenorm (bool, default: True ) –

    whether to use pre-norm or post-norm. Defaults to True.

  • resid_dropout1 (float, default: 0.0 ) –

    the dropout probability for the first dropout layer. Defaults to 0.0.

  • resid_dropout2 (float, default: 0.0 ) –

    the dropout probability for the second dropout layer. Defaults to 0.0.

  • drop_path1 (float, default: 0.0 ) –

    the drop path probability for the first drop path layer. Defaults to 0.0.

  • drop_path2 (float, default: 0.0 ) –

    the drop path probability for the second drop path layer. Defaults to 0.0.

  • fused_dropout_add_ln (bool, default: False ) –

    whether to fuse the dropout, add and layer norm. Defaults to False.

  • return_residual (bool, default: False ) –

    whether each of the sub-layers (mixer and mlp) will return the residual. This is for performance reasons: for post-norm architectures, returning the input allows us to fuse the backward of nn.Linear with the residual connection. Defaults to False.

  • residual_in_fp32 (bool, default: False ) –

    whether to keep the residual in fp32. Keeping the residual in fp32 makes mixed-precision (fp16/bf16) training more stable; it is only compatible with prenorm=True. Defaults to False.

  • sequence_parallel (bool, default: False ) –

    whether to use sequence parallelism. Defaults to False.

  • mark_shared_params (bool, default: False ) –

    whether to mark the norm parameters as "shared_params". This is useful when we want to sync the norm parameters across workers. Defaults to False.

Source code in scprint/model/flash_attn/block.py
def __init__(
    self,
    dim: int,
    mixer_cls: Optional[Callable] = None,
    mlp_cls: Optional[Callable] = None,
    norm_cls: Callable = partial(nn.LayerNorm, eps=1e-6),
    dropout_cls: Type[nn.Dropout] = nn.Dropout,
    prenorm: bool = True,
    resid_dropout1: float = 0.0,
    resid_dropout2: float = 0.0,
    drop_path1: float = 0.0,
    drop_path2: float = 0.0,
    fused_dropout_add_ln: bool = False,
    return_residual: bool = False,
    residual_in_fp32: bool = False,
    sequence_parallel: bool = False,
    mark_shared_params: bool = False,
):
    """
    For prenorm=True, this Block has a slightly different structure compared to a regular
    prenorm Transformer block.
    The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
    [Ref: https://arxiv.org/abs/2002.04745]
    Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
    the hidden_states (output of the MLP) and the residual.
    This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
    The residual needs to be provided (except for the very first block).

    For prenorm=False, this Block has the same structure as a regular postnorm Transformer
    block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

    Args:
        dim (int): the number of features in the input.
        mixer_cls (Optional[Callable], optional): the class to use for the mixer layer. Defaults to None.
        mlp_cls (Optional[Callable], optional): the class to use for the mlp layer. Defaults to None.
        norm_cls (Callable, optional): the class to use for the layer norm. Defaults to partial(nn.LayerNorm, eps=1e-6).
        dropout_cls (Type[nn.Dropout], optional): the class to use for the dropout. Defaults to nn.Dropout.
        prenorm (bool, optional): whether to use pre-norm or post-norm. Defaults to True.
        resid_dropout1 (float, optional): the dropout probability for the first dropout layer. Defaults to 0.0.
        resid_dropout2 (float, optional): the dropout probability for the second dropout layer. Defaults to 0.0.
        drop_path1 (float, optional): the drop path probability for the first drop path layer. Defaults to 0.0.
        drop_path2 (float, optional): the drop path probability for the second drop path layer. Defaults to 0.0.
        fused_dropout_add_ln (bool, optional): whether to fuse the dropout, add and layer norm. Defaults to False.
        return_residual (bool, optional): whether each of the sub-layers (mixer and mlp) will return the residual.
            This is for performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
            Defaults to False.
        residual_in_fp32 (bool, optional): whether to keep the residual in fp32. Keeping the residual in fp32
            makes mixed-precision (fp16/bf16) training more stable; only compatible with prenorm=True.
            Defaults to False.
        sequence_parallel (bool, optional): whether to use sequence parallelism. Defaults to False.
        mark_shared_params (bool, optional): whether to mark the norm parameters as "shared_params".
            This is useful when we want to sync the norm parameters across workers. Defaults to False.
    """
    super().__init__()
    self.prenorm = prenorm
    self.fused_dropout_add_ln = fused_dropout_add_ln
    self.return_residual = return_residual
    self.residual_in_fp32 = residual_in_fp32
    if self.residual_in_fp32:
        assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
    if mixer_cls is None:
        mixer_cls = partial(MHA, num_heads=dim // 64)
    if mlp_cls is None:
        mlp_cls = partial(Mlp, hidden_features=4 * dim)
    self.mixer = mixer_cls(dim)
    self.dropout1 = dropout_cls(resid_dropout1)
    self.drop_path1 = StochasticDepth(drop_path1, mode="row")
    self.norm1 = norm_cls(dim)
    self.mlp = mlp_cls(dim)
    if not isinstance(self.mlp, nn.Identity):
        self.dropout2 = dropout_cls(resid_dropout2)
        self.drop_path2 = StochasticDepth(drop_path2, mode="row")
        self.norm2 = norm_cls(dim)

    if self.fused_dropout_add_ln:
        assert layer_norm_fn is not None, "Triton is not installed"
        assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
            self.dropout1, nn.Dropout
        )

    # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
    # then the input to each worker in the tensor parallel group will be different.
    # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
    # For now this is not an issue because we always use sequence_parallel=True during training
    # and only use sequence_parallel=False during inference.

    # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
    if sequence_parallel:
        for p in self.norm1.parameters():
            p._sequence_parallel = True
        if hasattr(self, "norm2"):
            for p in self.norm2.parameters():
                p._sequence_parallel = True
    # Mark the norm parameters as "shared_params" so that we sync their values at init.
    if mark_shared_params:
        for p in self.norm1.parameters():
            p._shared_params = True
        if hasattr(self, "norm2"):
            for p in self.norm2.parameters():
                p._shared_params = True

forward

Pass the input through the encoder layer.

Parameters:
  • hidden_states (Tensor) –

    The sequence to be passed to the encoder layer. This is a required argument.

  • residual (Optional[Tensor], default: None ) –

    This argument is used differently based on the normalization method. If postnorm is used, residual should be None. If prenorm is used, hidden_states is updated as Attn/MLP(LN(residual)).

  • mixer_subset (Optional[Tensor], default: None ) –

    This argument is used only for cross-attention. If not None, a subset of the input sequence 'x' is taken before applying the query projection. This is particularly useful for models like ViT where only the CLS token in the last layer is of interest.

  • mixer_kwargs (Optional[Dict[str, Any]], default: None ) –

    This argument is used only for cross-attention. It is a dictionary of additional arguments to be passed to the mixer.

  • return_qkv (bool, default: False ) –

    If True, the function will return the query, key, and value tensors.

Returns:
  • Tensor or Tuple[Tensor, Tensor]: The output tensor of the encoder layer.

  • If return_qkv is True, the function will return a tuple of the output tensor and the query, key, and value tensors.

Source code in scprint/model/flash_attn/block.py
def forward(
    self,
    hidden_states: Tensor,
    residual: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    src_mask: Optional[Tensor] = None,
    is_causal: Optional[bool] = None,
    src_key_padding_mask: Optional[Tensor] = None,
    mixer_subset: Optional[Tensor] = None,
    mixer_kwargs: Optional[Dict[str, Any]] = None,
    return_qkv: bool = False,
):
    r"""Pass the input through the encoder layer.

    Args:
        hidden_states (Tensor): The sequence to be passed to the encoder layer. This is a required argument.
        residual (Optional[Tensor]): This argument is used differently based on the normalization method.
            If postnorm is used, residual should be None. If prenorm is used, hidden_states is updated as Attn/MLP(LN(residual)).
        mixer_subset: This argument is used only for cross-attention.
            If not None, a subset of the input sequence 'x' is taken before applying the query projection.
            This is particularly useful for models like ViT where only the CLS token in the last layer is of interest.
        mixer_kwargs: This argument is used only for cross-attention.
            It is a dictionary of additional arguments to be passed to the mixer.
        return_qkv: If True, the function will return the query, key, and value tensors.

    Returns:
        Tensor or Tuple[Tensor, Tensor]: The output tensor of the encoder layer.
        If return_qkv is True, the function will return a tuple of the output tensor and the query, key, and value tensors.
    """
    if self.prenorm:
        if not self.fused_dropout_add_ln:
            dropped = self.drop_path1(self.dropout1(hidden_states))
            residual = (dropped + residual) if residual is not None else dropped
            hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            if self.drop_path1.p == 0 or not self.training:
                rowscale1 = None
            else:
                rowscale1 = self.drop_path1(
                    torch.ones(
                        hidden_states.shape[:-1],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                )
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                rowscale=rowscale1,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
        if mixer_kwargs is None:
            mixer_kwargs = {}
        if mixer_subset is not None:
            mixer_kwargs["mixer_subset"] = mixer_subset
        hidden_states = self.mixer(
            hidden_states, return_qkv=return_qkv, bias=bias, **mixer_kwargs
        )
        if return_qkv:
            qkv = hidden_states[1]
            hidden_states = hidden_states[0]
        if mixer_subset is not None:
            residual = residual[:, mixer_subset]
        if not isinstance(self.mlp, nn.Identity):
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path2(self.dropout2(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(
                    residual.to(dtype=self.norm2.weight.dtype)
                )
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path2.p == 0 or not self.training:
                    rowscale2 = None
                else:
                    rowscale2 = self.drop_path2(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    eps=self.norm2.eps,
                    dropout_p=self.dropout2.p if self.training else 0.0,
                    rowscale=rowscale2,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )
            hidden_states = self.mlp(hidden_states)
        return (
            (hidden_states, residual)
            if not return_qkv
            else (
                hidden_states,
                residual,
                qkv,
            )
        )
    # if not prenorm (disregard for scPRINT)
    else:
        assert residual is None
        mixer_out = self.mixer(
            hidden_states,
            return_qkv=return_qkv,
            bias=bias,
            **(mixer_kwargs if mixer_kwargs is not None else {})
        )
        if return_qkv:
            qkv = mixer_out[-1]
            mixer_out = mixer_out[:-1]
        if self.return_residual:  # mixer out is actually a pair here
            mixer_out, hidden_states = mixer_out
        if not self.fused_dropout_add_ln:
            hidden_states = self.norm1(
                (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                    dtype=self.norm1.weight.dtype
                )
            )
        else:
            if self.drop_path1.p == 0 or not self.training:
                rowscale1 = None
            else:
                rowscale1 = self.drop_path1(
                    torch.ones(
                        mixer_out.shape[:-1],
                        device=mixer_out.device,
                        dtype=mixer_out.dtype,
                    )
                )
            hidden_states = layer_norm_fn(
                mixer_out,
                self.norm1.weight,
                self.norm1.bias,
                residual=hidden_states,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                rowscale=rowscale1,
                prenorm=False,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
        if not isinstance(self.mlp, nn.Identity):
            mlp_out = self.mlp(hidden_states)
            if self.return_residual:  # mlp out is actually a pair here
                mlp_out, hidden_states = mlp_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm2(
                    (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                        dtype=self.norm2.weight.dtype
                    )
                )
            else:
                if self.drop_path2.p == 0 or not self.training:
                    rowscale2 = None
                else:
                    rowscale2 = self.drop_path2(
                        torch.ones(
                            mlp_out.shape[:-1],
                            device=mlp_out.device,
                            dtype=mlp_out.dtype,
                        )
                    )
                hidden_states = layer_norm_fn(
                    mlp_out,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=hidden_states,
                    eps=self.norm2.eps,
                    dropout_p=self.dropout2.p if self.training else 0.0,
                    rowscale=rowscale2,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )
        return hidden_states if not return_qkv else (hidden_states, qkv)
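
A sketch of the prenorm calling convention, with illustrative mixer/mlp choices: each Block returns (hidden_states, residual), and the residual is fed back into the next block.

import torch
from functools import partial
from scprint.model.flash_attn.block import Block
from scprint.model.flash_attn.mha import MHA
from scprint.model.flash_attn.mlp import Mlp

dim = 512
mixer_cls = partial(MHA, num_heads=8, use_flash_attn=False)
mlp_cls = partial(Mlp, hidden_features=4 * dim)
blocks = torch.nn.ModuleList(
    [Block(dim, mixer_cls, mlp_cls, prenorm=True) for _ in range(2)]
)

x = torch.randn(2, 16, dim)
hidden_states, residual = x, None
for block in blocks:
    hidden_states, residual = block(hidden_states, residual)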

scprint.model.flash_attn.flashattention

Experimental implementation of FlashAttention in Triton. Tested with triton==2.0.0.dev20221202. Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions other than 64: https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 We'll update this implementation with the new Triton backend once this is fixed.

We use the FlashAttention implementation from Phil Tillet as a starting point: https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

Changes:
- Implement both causal and non-causal attention.
- Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support attention bias.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of small batch size * nheads.

Caution:
  • This is an experimental implementation. The forward pass should be quite robust, but I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
  • This implementation has only been tested on A100.
  • If you plan to use a headdim other than 64 or 128, you should test for race conditions (due to the Triton compiler), as done in tests/test_flash_attn.py "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions for different head dimensions (40, 48, 64, 80, 88, 96, 128), but I'm still not 100% confident that there are none left for other head dimensions.

Differences between this Triton version and the CUDA version:
  • The Triton version doesn't support dropout.
  • The Triton forward is generally faster than the CUDA forward, while the Triton backward is generally slower than the CUDA backward. Overall, Triton forward + backward is slightly slower than CUDA forward + backward.
  • The Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
  • The Triton version supports attention bias, while the CUDA version doesn't.
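
Since additive attention bias is the main feature the CUDA kernels lack, here is a minimal sketch of building an ALiBi-style bias in the (1, nheads, 1, seqlen_k) causal layout described in the parameter docs below. The alibi_bias helper, the power-of-two head count, and the CUDA/fp16 setup are illustrative assumptions, not part of the package.

import torch

def alibi_bias(nheads: int, seqlen_k: int, device: str = "cuda", dtype=torch.float16) -> torch.Tensor:
    # Hypothetical helper (not part of scprint): ALiBi-style additive bias in the
    # causal layout (1, nheads, 1, seqlen_k). The per-query term -slope * i is
    # constant within each softmax row and cancels out, so only the +slope * j
    # part (which depends only on the key position j) needs to be materialized.
    slopes = torch.tensor(
        [2.0 ** (-8.0 * (h + 1) / nheads) for h in range(nheads)],  # ALiBi slopes; power-of-two nheads assumed
        device=device,
        dtype=torch.float32,
    )
    positions = torch.arange(seqlen_k, device=device, dtype=torch.float32)
    bias = slopes[:, None] * positions[None, :]      # (nheads, seqlen_k)
    return bias[None, :, None, :].to(dtype)          # (1, nheads, 1, seqlen_k)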

FlashAttnFunc

Bases: Function

forward staticmethod

Perform the forward pass of FlashAttention.

Parameters:
  • q (Tensor) –

    Query tensor of shape (batch_size, seqlen_q, nheads, headdim).

  • k (Tensor) –

    Key tensor of shape (batch_size, seqlen_k, nheads, headdim).

  • v (Tensor) –

    Value tensor of shape (batch_size, seqlen_k, nheads, headdim).

  • bias (Optional[Tensor], default: None ) –

    Bias tensor, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k). For example, an ALiBi mask for causal attention would have shape (1, nheads, 1, seqlen_k), and an ALiBi mask for non-causal attention would have shape (1, nheads, seqlen_q, seqlen_k).

  • causal (bool, default: False ) –

    Whether to apply causal masking. Default is False.

  • softmax_scale (Optional[float], default: None ) –

    Scaling factor for the softmax operation. Default is None.

Returns:
  • Tensor

    torch.Tensor: Output tensor after applying FlashAttention.

Source code in scprint/model/flash_attn/flashattention.py
@staticmethod
def forward(
    ctx: torch.autograd.Function,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Perform the forward pass of FlashAttention.

    Args:
        q (torch.Tensor): Query tensor of shape (batch_size, seqlen_q, nheads, headdim).
        k (torch.Tensor): Key tensor of shape (batch_size, seqlen_k, nheads, headdim).
        v (torch.Tensor): Value tensor of shape (batch_size, seqlen_k, nheads, headdim).
        bias (Optional[torch.Tensor]): Bias tensor, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k).
        causal (bool): Whether to apply causal masking. Default is False.
        softmax_scale (Optional[float]): Scaling factor for the softmax operation. Default is None.

    Returns:
        torch.Tensor: Output tensor after applying FlashAttention.
    """
    # Make sure that the last dimension is contiguous
    q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
    o, lse, ctx.softmax_scale = _flash_attn_forward(
        q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
    )
    ctx.save_for_backward(q, k, v, o, lse, bias)
    ctx.causal = causal
    return o
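
A usage sketch for the unpacked variant; the shapes, device, and fp16 dtype are illustrative assumptions (the Triton kernels expect CUDA tensors with a contiguous last dimension).

import torch

from scprint.model.flash_attn.flashattention import FlashAttnFunc

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 128, 256, 8, 64
q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)

# autograd.Function subclasses are invoked through .apply(), which takes positional
# arguments only, so they follow the forward signature: (q, k, v, bias, causal, softmax_scale).
out = FlashAttnFunc.apply(q, k, v, None, False, 1.0 / headdim ** 0.5)
out.sum().backward()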

FlashAttnKVPackedFunc

Bases: Function

forward staticmethod

Perform the forward pass of FlashAttention with packed key and value tensors.

Parameters:
  • q (Tensor) –

    Query tensor of shape (batch, seqlen_q, nheads, headdim).

  • kv (Tensor) –

    Key and value tensor of shape (batch, seqlen_k, 2, nheads, headdim).

  • bias (Optional[Tensor], default: None ) –

    Bias tensor, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k).

  • causal (bool, default: False ) –

    Whether to apply causal masking. Default is False.

  • softmax_scale (Optional[float], default: None ) –

    Scaling factor for the softmax operation. Default is None.

Returns:
  • Tensor

    torch.Tensor: Output tensor after applying FlashAttention.

Source code in scprint/model/flash_attn/flashattention.py
@staticmethod
def forward(
    ctx,
    q: torch.Tensor,
    kv: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Perform the forward pass of FlashAttention with packed key and value tensors.

    Args:
        q (torch.Tensor): Query tensor of shape (batch, seqlen_q, nheads, headdim).
        kv (torch.Tensor): Key and value tensor of shape (batch, seqlen_k, 2, nheads, headdim).
        bias (Optional[torch.Tensor]): Bias tensor, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k).
        causal (bool): Whether to apply causal masking. Default is False.
        softmax_scale (Optional[float]): Scaling factor for the softmax operation. Default is None.

    Returns:
        torch.Tensor: Output tensor after applying FlashAttention.
    """
    # Make sure that the last dimension is contiguous
    q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
    o, lse, ctx.softmax_scale = _flash_attn_forward(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        bias=bias,
        causal=causal,
        softmax_scale=softmax_scale,
    )
    ctx.save_for_backward(q, kv, o, lse, bias)
    ctx.causal = causal
    return o
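
A usage sketch for the packed key/value (cross-attention) variant; shapes, device, and dtype are illustrative assumptions.

import torch

from scprint.model.flash_attn.flashattention import FlashAttnKVPackedFunc

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 128, 256, 8, 64
q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
# Keys and values are stacked along dim 2, matching the documented
# (batch, seqlen_k, 2, nheads, headdim) layout.
kv = torch.randn(batch, seqlen_k, 2, nheads, headdim, device="cuda", dtype=torch.float16)

# Positional arguments follow the forward signature: (q, kv, bias, causal, softmax_scale).
out = FlashAttnKVPackedFunc.apply(q, kv, None, False, 1.0 / headdim ** 0.5)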

FlashAttnQKVPackedFunc

Bases: Function

forward staticmethod

Forward pass for FlashAttention.

Parameters:
  • ctx (Function) –

    The context object to save information for backward computation.

  • qkv (Tensor) –

    Input tensor of shape (batch, seqlen, 3, nheads, headdim).

  • bias (Optional[Tensor], default: None ) –

    Optional bias tensor, shape broadcastable to (batch, nheads, seqlen, seqlen). For example, an ALiBi mask for causal attention would have shape (1, nheads, 1, seqlen), and an ALiBi mask for non-causal attention would have shape (1, nheads, seqlen, seqlen).

  • causal (bool, default: False ) –

    Whether to apply causal masking. Default is False.

  • softmax_scale (Optional[float], default: None ) –

    Optional scaling factor for softmax. Default is None.

Returns:
  • Tensor

    torch.Tensor: Output tensor after applying FlashAttention.

Source code in scprint/model/flash_attn/flashattention.py
@staticmethod
def forward(
    ctx: torch.autograd.Function,
    qkv: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Forward pass for FlashAttention.

    Args:
        ctx (torch.autograd.Function): The context object to save information for backward computation.
        qkv (torch.Tensor): Input tensor of shape (batch, seqlen, 3, nheads, headdim).
        bias (Optional[torch.Tensor]): Optional bias tensor, shape broadcastable to (batch, nheads, seqlen, seqlen).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen).
        causal (bool): Whether to apply causal masking. Default is False.
        softmax_scale (Optional[float]): Optional scaling factor for softmax. Default is None.

    Returns:
        torch.Tensor: Output tensor after applying FlashAttention.
    """
    # Make sure that the last dimension is contiguous
    if qkv.stride(-1) != 1:
        qkv = qkv.contiguous()
    o, lse, ctx.softmax_scale = _flash_attn_forward(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        bias=bias,
        causal=causal,
        softmax_scale=softmax_scale,
    )
    ctx.save_for_backward(qkv, o, lse, bias)
    ctx.causal = causal
    return o
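
A usage sketch for the fully packed (self-attention) variant with causal masking; shapes, device, and dtype are illustrative assumptions.

import torch

from scprint.model.flash_attn.flashattention import FlashAttnQKVPackedFunc

batch, seqlen, nheads, headdim = 2, 128, 8, 64
# Self-attention only: queries, keys, and values share one tensor stacked along dim 2.
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)

# Positional arguments follow the forward signature: (qkv, bias, causal, softmax_scale).
out = FlashAttnQKVPackedFunc.apply(qkv, None, True, 1.0 / headdim ** 0.5)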

scprint.model.flash_attn.activations

bias_gelu_back

Assume that y has shape (B, D) and bias has shape (D)

Source code in scprint/model/flash_attn/activations.py
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
    ) + 0.5 * (1 + tanh_out)
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
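
bias_gelu_back hand-codes the derivative of the tanh-approximated GELU applied to y + bias. Below is a minimal sketch of checking it against autograd; the bias_gelu forward is re-derived here from the same constants for the comparison and is an assumption, not an import from the package.

import torch

from scprint.model.flash_attn.activations import bias_gelu_back

def bias_gelu(y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Tanh approximation of GELU applied to (y + bias), using the same
    # constants as bias_gelu_back above (re-derived here for the check).
    x = bias + y
    return 0.5 * x * (1 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

y = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
bias = torch.randn(16, dtype=torch.float64, requires_grad=True)
g = torch.randn(4, 16, dtype=torch.float64)

# Autograd gradients of the reference forward...
bias_gelu(y, bias).backward(g)
# ...should match the hand-written backward.
grad_y, grad_bias = bias_gelu_back(g, y.detach(), bias.detach())
assert torch.allclose(grad_y, y.grad) and torch.allclose(grad_bias, bias.grad)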

scprint.model.flash_attn.layer_norm

layer_norm_ref

Reference implementation of Layer Normalization with optional dropout and residual connections.

Parameters:
  • x (Tensor) –

    Input tensor.

  • weight (Tensor) –

    Weight tensor for normalization.

  • bias (Optional[Tensor], default: None ) –

    Bias tensor for normalization.

  • residual (Optional[Tensor], default: None ) –

    Residual tensor to be added to the input.

  • x1 (Optional[Tensor], default: None ) –

    Additional input tensor for parallel LayerNorm.

  • weight1 (Optional[Tensor], default: None ) –

    Additional weight tensor for parallel LayerNorm.

  • bias1 (Optional[Tensor], default: None ) –

    Additional bias tensor for parallel LayerNorm.

  • eps (float, default: 1e-06 ) –

    Epsilon value to avoid division by zero.

  • dropout_p (float, default: 0.0 ) –

    Dropout probability.

  • rowscale (Optional[Tensor], default: None ) –

    Row scaling tensor.

  • prenorm (bool, default: False ) –

    Whether to return the prenormalized output.

  • dropout_mask (Optional[Tensor], default: None ) –

    Dropout mask for the input tensor.

  • dropout_mask1 (Optional[Tensor], default: None ) –

    Dropout mask for the additional input tensor.

  • upcast (bool, default: False ) –

    Whether to upcast the input tensors to float.

Returns:
  • Union[Tensor, Tuple[Tensor, ...]]

    Union[torch.Tensor, Tuple[torch.Tensor, ...]]: Normalized output tensor(s).

Source code in scprint/model/flash_attn/layer_norm.py
def layer_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None,
    x1: Optional[torch.Tensor] = None,
    weight1: Optional[torch.Tensor] = None,
    bias1: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    dropout_p: float = 0.0,
    rowscale: Optional[torch.Tensor] = None,
    prenorm: bool = False,
    dropout_mask: Optional[torch.Tensor] = None,
    dropout_mask1: Optional[torch.Tensor] = None,
    upcast: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """
    Reference implementation of Layer Normalization with optional dropout and residual connections.

    Args:
        x (torch.Tensor): Input tensor.
        weight (torch.Tensor): Weight tensor for normalization.
        bias (Optional[torch.Tensor]): Bias tensor for normalization.
        residual (Optional[torch.Tensor]): Residual tensor to be added to the input.
        x1 (Optional[torch.Tensor]): Additional input tensor for parallel LayerNorm.
        weight1 (Optional[torch.Tensor]): Additional weight tensor for parallel LayerNorm.
        bias1 (Optional[torch.Tensor]): Additional bias tensor for parallel LayerNorm.
        eps (float): Epsilon value to avoid division by zero.
        dropout_p (float): Dropout probability.
        rowscale (Optional[torch.Tensor]): Row scaling tensor.
        prenorm (bool): Whether to return the prenormalized output.
        dropout_mask (Optional[torch.Tensor]): Dropout mask for the input tensor.
        dropout_mask1 (Optional[torch.Tensor]): Dropout mask for the additional input tensor.
        upcast (bool): Whether to upcast the input tensors to float.

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, ...]]: Normalized output tensor(s).
    """
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(
        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
    ).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)
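
A usage sketch of the residual + prenorm path; the dimensions are illustrative assumptions.

import torch

from scprint.model.flash_attn.layer_norm import layer_norm_ref

d_model = 512
x = torch.randn(4, d_model)
residual = torch.randn(4, d_model)
weight = torch.ones(d_model)
bias = torch.zeros(d_model)

# With prenorm=True the function also returns the pre-normalization sum
# (x + residual), which a pre-norm transformer block carries forward as the
# next residual stream.
out, new_residual = layer_norm_ref(x, weight, bias, residual=residual, prenorm=True)
assert torch.allclose(new_residual, x + residual)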

rms_norm_ref

Reference implementation of RMS Normalization with optional dropout and residual connections.

Parameters:
  • x (Tensor) –

    Input tensor.

  • weight (Tensor) –

    Weight tensor for normalization.

  • bias (Optional[Tensor], default: None ) –

    Bias tensor for normalization.

  • residual (Optional[Tensor], default: None ) –

    Residual tensor to be added to the input.

  • x1 (Optional[Tensor], default: None ) –

    Additional input tensor for parallel RMSNorm.

  • weight1 (Optional[Tensor], default: None ) –

    Additional weight tensor for parallel RMSNorm.

  • bias1 (Optional[Tensor], default: None ) –

    Additional bias tensor for parallel RMSNorm.

  • eps (float, default: 1e-06 ) –

    Epsilon value to avoid division by zero.

  • dropout_p (float, default: 0.0 ) –

    Dropout probability.

  • rowscale (Optional[Tensor], default: None ) –

    Row scaling tensor.

  • prenorm (bool, default: False ) –

    Whether to return the prenormalized output.

  • dropout_mask (Optional[Tensor], default: None ) –

    Dropout mask for the input tensor.

  • dropout_mask1 (Optional[Tensor], default: None ) –

    Dropout mask for the additional input tensor.

  • upcast (bool, default: False ) –

    Whether to upcast the input tensors to float.

Returns:
  • Union[Tensor, Tuple[Tensor, ...]]

    Union[torch.Tensor, Tuple[torch.Tensor, ...]]: Normalized output tensor(s).

Source code in scprint/model/flash_attn/layer_norm.py
def rms_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None,
    x1: Optional[torch.Tensor] = None,
    weight1: Optional[torch.Tensor] = None,
    bias1: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    dropout_p: float = 0.0,
    rowscale: Optional[torch.Tensor] = None,
    prenorm: bool = False,
    dropout_mask: Optional[torch.Tensor] = None,
    dropout_mask1: Optional[torch.Tensor] = None,
    upcast: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """
    Reference implementation of RMS Normalization with optional dropout and residual connections.

    Args:
        x (torch.Tensor): Input tensor.
        weight (torch.Tensor): Weight tensor for normalization.
        bias (Optional[torch.Tensor]): Bias tensor for normalization.
        residual (Optional[torch.Tensor]): Residual tensor to be added to the input.
        x1 (Optional[torch.Tensor]): Additional input tensor for parallel RMSNorm.
        weight1 (Optional[torch.Tensor]): Additional weight tensor for parallel RMSNorm.
        bias1 (Optional[torch.Tensor]): Additional bias tensor for parallel RMSNorm.
        eps (float): Epsilon value to avoid division by zero.
        dropout_p (float): Dropout probability.
        rowscale (Optional[torch.Tensor]): Row scaling tensor.
        prenorm (bool): Whether to return the prenormalized output.
        dropout_mask (Optional[torch.Tensor]): Dropout mask for the input tensor.
        dropout_mask1 (Optional[torch.Tensor]): Dropout mask for the additional input tensor.
        upcast (bool): Whether to upcast the input tensors to float.

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, ...]]: Normalized output tensor(s).
    """
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = (
            (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)
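
A usage sketch comparing rms_norm_ref against a hand-written RMS normalization; the dimensions are illustrative assumptions.

import torch

from scprint.model.flash_attn.layer_norm import rms_norm_ref

d_model = 512
x = torch.randn(4, d_model)
weight = torch.randn(d_model)

out = rms_norm_ref(x, weight, eps=1e-6)

# RMSNorm by hand: no mean subtraction, scale by the reciprocal root mean square.
rstd = 1.0 / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
assert torch.allclose(out, x * rstd * weight)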