Documentation for GRNAnnData module

grnndata.GRNAnnData

Bases: AnnData

An AnnData object with a GRN matrix in varp["GRN"]

Parameters:
  • grn (csr_matrix | ndarray) –

    scipy.sparse.csr_matrix | np.ndarray. A square matrix in which non-zero entries mark edges and zeros mark their absence. The rows and columns should correspond to the genes in the AnnData object: the row index corresponds to genes that are regulators and the column index corresponds to genes that are targets, which fixes the direction of each edge.

@see https://anndata.readthedocs.io for more information on AnnData objects
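
As a minimal sketch of the row/column convention described above, the following builds a small GRN matrix and lists its edges. The gene names and weights are made up for illustration, and only NumPy and SciPy are used, so nothing here depends on the grnndata package itself.

```python
import numpy as np
from scipy.sparse import csr_matrix

genes = ["TF1", "TF2", "G3"]  # hypothetical gene names

# Rows are regulators, columns are targets.
dense = np.array(
    [
        [0.0, 0.8, 0.5],  # TF1 regulates TF2 (0.8) and G3 (0.5)
        [0.0, 0.0, 0.3],  # TF2 regulates G3 (0.3)
        [0.0, 0.0, 0.0],  # G3 regulates nothing
    ]
)
grn = csr_matrix(dense)

# List every stored (non-zero) edge as (regulator, target, weight).
coo = grn.tocoo()
edges = [
    (genes[i], genes[j], float(w))
    for i, j, w in zip(coo.row, coo.col, coo.data)
]
```

A matrix like this is what the grn keyword expects, in the same way the get method on this page calls GRNAnnData(X=..., obs=..., var=..., grn=...).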

Source code in grnndata/GRNAnnData.py
def __init__(self, *args, grn: scipy.sparse.csr_matrix | np.ndarray, **kwargs):
    """An AnnData object with a GRN matrix in varp["GRN"]

    Args:
        grn: scipy.sparse.csr_matrix | np.ndarray. A square matrix in which
            non-zero entries mark edges and zeros mark their absence. The rows
            and columns should correspond to the genes in the AnnData object.
            The row index corresponds to genes that are regulators and the
            column index corresponds to genes that are targets.

    @see https://anndata.readthedocs.io for more information on AnnData objects
    """
    super(GRNAnnData, self).__init__(*args, **kwargs)
    self.varp["GRN"] = grn

grn property

Property that returns the gene regulatory network (GRN) as a pandas DataFrame. The index and columns of the DataFrame are the gene names stored in 'var_names'.

Returns:
  • pd.DataFrame: The GRN as a DataFrame with gene names as index and columns.
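
The source of this property is not shown on this page, so the following is only a sketch of the idea, assuming the property densifies the stored matrix and labels both axes with the gene names. The names var_names and grn_matrix are stand-ins for the object's var_names and varp["GRN"].

```python
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

var_names = pd.Index(["TF1", "TF2", "G3"])  # hypothetical gene names
grn_matrix = csr_matrix(
    np.array([[0.0, 0.8, 0.5], [0.0, 0.0, 0.3], [0.0, 0.0, 0.0]])
)

# Densify and label both axes with the gene names.
grn_df = pd.DataFrame(grn_matrix.toarray(), index=var_names, columns=var_names)

# Label-based lookup: weight of the edge TF1 -> G3.
weight = grn_df.loc["TF1", "G3"]
```

With gene names on both axes, an edge can be looked up by (regulator, target) label instead of by integer position.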

regulators property

Returns the regulators' connections in the GRN as a pandas DataFrame.

Returns:
  • pd.DataFrame: The regulators of the GRN as a DataFrame with gene names as index and columns.

targets property

Returns the targets' connections in the GRN as a pandas DataFrame.

Returns:
  • pd.DataFrame: The targets of the GRN as a DataFrame with gene names as index and columns.

concat

Concatenate two GRNAnnData objects.

Parameters:
  • other (GRNAnnData) –

    The other GRNAnnData object to concatenate with

Raises:
  • ValueError

    Can only concatenate with another GRNAnnData object

Returns:
  • GRNAnnData

    The concatenated GRNAnnData object

Source code in grnndata/GRNAnnData.py
def concat(self, other):
    """
    Concatenate two GRNAnnData objects.

    Args:
        other (GRNAnnData): The other GRNAnnData object to concatenate with

    Raises:
        ValueError: Can only concatenate with another GRNAnnData object

    Returns:
        GRNAnnData: The concatenated GRNAnnData object
    """
    if not isinstance(other, GRNAnnData):
        raise ValueError("Can only concatenate with another GRNAnnData object")
    return GRNAnnData(
        self.concatenate(other),
        grn=scipy.sparse.vstack([self.varp["GRN"], other.varp["GRN"]]),
    )
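
To see what the scipy.sparse.vstack call in concat does to the GRN, here is a small shape check on two made-up stand-in matrices. It only exercises SciPy, not GRNAnnData itself.

```python
import numpy as np
import scipy.sparse

n_genes = 3
grn_a = scipy.sparse.csr_matrix(np.eye(n_genes))              # stand-in for self.varp["GRN"]
grn_b = scipy.sparse.csr_matrix(np.ones((n_genes, n_genes)))  # stand-in for other.varp["GRN"]

# Vertical stacking appends the rows of grn_b below those of grn_a.
stacked = scipy.sparse.vstack([grn_a, grn_b])
print(stacked.shape)  # (6, 3)
```

Stacking two n x n networks this way gives a 2n x n matrix, which is worth comparing with the square-matrix requirement stated for grn before relying on the result.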

extract_links

Extracts the edge weights stored in varp["GRN"] and returns them as a pandas DataFrame.

The resulting DataFrame has the following structure:

    TF   Gene   Score
    A    B      5
    C    D      8

Where 'TF' and 'Gene' are the names of the regulator and target genes of each edge in the regulatory network, and 'Score' is the corresponding weight. The actual column names are set by the columns parameter.

Parameters:
  • columns (list, default: ['regulator', 'target', 'weight'] ) –

    The names of the columns in the resulting DataFrame. Defaults to ['regulator', 'target', 'weight'].

Returns:
  • pd.DataFrame: The extracted scores as a DataFrame.

Source code in grnndata/GRNAnnData.py
def extract_links(
    self,
    columns: list = [
        "regulator",
        "target",
        "weight",
    ],  # output col names (e.g. 'TF', 'gene', 'score')
):
    """
    Extracts the edge weights stored in varp["GRN"] and returns them as a pandas DataFrame.

    The resulting DataFrame has the following structure:
        TF   Gene   Score
        A    B      5
        C    D      8

    Where 'TF' and 'Gene' are the names of the regulator and target genes of each edge in the regulatory network, and 'Score' is the corresponding weight.

    Args:
        columns (list, optional): The names of the columns in the resulting DataFrame. Defaults to ['regulator', 'target', 'weight'].

    Returns:
        pd.DataFrame: The extracted scores as a DataFrame.
    """
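    # csr_matrix has no .row/.col attributes, so convert to COO first
    # (coo_matrix also accepts a dense ndarray).
    grn = scipy.sparse.coo_matrix(self.varp["GRN"])
    return pd.DataFrame(
        [
            a
            for a in zip(
                [self.var_names[i] for i in grn.row],
                [self.var_names[i] for i in grn.col],
                grn.data,
            )
        ],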
        columns=columns,
    ).sort_values(by=columns[2], ascending=False)
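
The conversion from a matrix to a long-format edge table can be sketched outside the class as follows. The helper links_from_matrix and the gene names are hypothetical. The data reproduces the A/B/5 and C/D/8 example from the description, and the matrix is converted to COO format so that row and column indices of the non-zero entries are available.

```python
import numpy as np
import pandas as pd
import scipy.sparse

var_names = pd.Index(["A", "B", "C", "D"])  # hypothetical gene names
dense = np.zeros((4, 4))
dense[0, 1] = 5.0  # A -> B
dense[2, 3] = 8.0  # C -> D
grn = scipy.sparse.csr_matrix(dense)


def links_from_matrix(grn, var_names, columns=("regulator", "target", "weight")):
    """Return one row per non-zero entry, sorted by decreasing weight."""
    coo = scipy.sparse.coo_matrix(grn)  # accepts sparse or dense input
    names = np.asarray(var_names)
    links = pd.DataFrame(
        {
            columns[0]: names[coo.row],  # row index -> regulator name
            columns[1]: names[coo.col],  # column index -> target name
            columns[2]: coo.data,        # stored edge weight
        }
    )
    return links.sort_values(by=columns[2], ascending=False).reset_index(drop=True)


links = links_from_matrix(grn, var_names)
```

Passing columns=("TF", "Gene", "Score") to this helper would give the column names used in the description above.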

get

Get a sub-GRNAnnData object containing only the specified genes.

Parameters:
  • elem (str | list) –

    The gene names to include in the sub-GRNAnnData object

Returns:
  • GRNAnnData –

    The sub-GRNAnnData object with only the specified genes

Source code in grnndata/GRNAnnData.py
def get(self, elem: str | list[str]) -> "GRNAnnData":
    """
    Get a sub-GRNAnnData object containing only the specified genes.

    Args:
        elem (str | list): The gene names to include in the sub-GRNAnnData object

    Returns:
        GRNAnnData: The sub-GRNAnnData object with only the specified genes
    """
    if isinstance(elem, str):
        elem = [elem]
    loc = self.var.index.isin(elem)
    reg = self.varp["GRN"][loc][:, loc]
    if len(reg.shape) == 1:
        reg = np.array([reg])
    sub = GRNAnnData(X=self.X[:, loc], obs=self.obs, var=self.var[loc], grn=reg)
    sub.varm["Targets"] = self.varp["GRN"][loc]
    sub.varm["Regulators"] = self.varp["GRN"].T[loc]
    sub.uns["regulated_genes"] = self.var.index.tolist()
    return sub
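
The three slices taken in get can be illustrated on a plain NumPy array. This is a sketch with made-up gene names and weights. The variables var_index and grn stand in for self.var.index and self.varp["GRN"].

```python
import numpy as np
import pandas as pd

var_index = pd.Index(["TF1", "TF2", "G3", "G4"])  # hypothetical gene names
grn = np.arange(16, dtype=float).reshape(4, 4)    # made-up weights

wanted = ["TF1", "G3"]
loc = var_index.isin(wanted)  # boolean mask over all genes

sub_grn = grn[loc][:, loc]  # edges among the selected genes only
targets = grn[loc]          # selected genes as regulators, all genes as targets
regulators = grn.T[loc]     # selected genes as targets, all genes as regulators
```

Because loc is a boolean mask, the selected genes keep the order they have in var, not the order in which they were requested.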

plot_subgraph

plot_subgraph plots a subgraph of the gene regulatory network (GRN) centered around a seed gene.

Parameters:
  • seed (str or list) –

    The seed gene or list of genes around which the subgraph will be centered.

  • gene_col (str, default: 'symbol' ) –

    The column name in the .var DataFrame that contains gene identifiers. Defaults to "symbol".

  • max_genes (int, default: 10 ) –

    The maximum number of genes to include in the subgraph. Defaults to 10.

  • only (float, default: 0.3 ) –

    The threshold for filtering connections. If less than 1, it is used as a minimum weight threshold. If 1 or greater, it is used as the number of top connections to retain. Defaults to 0.3.

  • palette (list, default: base_color_palette ) –

    The color palette to use for plotting. Defaults to base_color_palette.

  • interactive (bool, default: True ) –

    Whether to create an interactive plot. Defaults to True.

  • do_enr (bool, default: False ) –

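    Whether to perform enrichment analysis on the subgraph. Defaults to False. In the source listed for this method, the enrichment step is only reached when interactive is False, because the interactive branch returns first.

  • **kwargs (dict) –

    Additional keyword arguments forwarded to d3graph's set_node_properties when interactive is True.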

Returns:
  • d3graph or nx.DiGraph: The d3graph object if interactive is True, otherwise the networkx DiGraph built from the subgraph.
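
How a single seed and max_genes determine the genes in the subgraph can be sketched with pandas alone. The DataFrame grn_df stands in for the grn property, and the gene names and weights are made up.

```python
import pandas as pd

genes = ["TF1", "G2", "G3", "G4"]  # hypothetical gene names
grn_df = pd.DataFrame(
    [
        [0.0, 0.2, 0.9, 0.5],
        [0.0, 0.0, 0.1, 0.0],
        [0.0, 0.0, 0.0, 0.4],
        [0.0, 0.0, 0.0, 0.0],
    ],
    index=genes,
    columns=genes,
)

seed, max_genes = "TF1", 2

# Strongest outgoing edges of the seed, followed by the seed itself.
neighbours = (
    grn_df.loc[seed].sort_values(ascending=False).head(max_genes).index.tolist()
)
elem = neighbours + [seed]

# Square sub-network restricted to the selected genes.
sub = grn_df.loc[elem, elem]
```

Here the seed's row is sorted by weight, so the selection follows the seed's targets. When a list of genes is passed as seed, the method uses that list directly instead.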

Source code in grnndata/GRNAnnData.py
def plot_subgraph(
    self,
    seed: str | list,
    gene_col: str = "symbol",
    max_genes: int = 10,
    only: float = 0.3,
    palette: list = base_color_palette,
    interactive: bool = True,
    do_enr: bool = False,
    **kwargs: dict
):
    """
    plot_subgraph plots a subgraph of the gene regulatory network (GRN) centered around a seed gene.

    Args:
        seed (str or list): The seed gene or list of genes around which the subgraph will be centered.
        gene_col (str, optional): The column name in the .var DataFrame that contains gene identifiers. Defaults to "symbol".
        max_genes (int, optional): The maximum number of genes to include in the subgraph. Defaults to 10.
        only (float, optional): The threshold for filtering connections. If less than 1, it is used as a minimum weight threshold. If 1 or greater, it is used as the number of top connections to retain. Defaults to 0.3.
        palette (list, optional): The color palette to use for plotting. Defaults to base_color_palette.
        interactive (bool, optional): Whether to create an interactive plot. Defaults to True.
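        do_enr (bool, optional): Whether to perform enrichment analysis on the subgraph. Defaults to False.
        **kwargs: Additional keyword arguments forwarded to d3graph's set_node_properties when interactive is True.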

    Returns:
        d3graph or nx.DiGraph: The d3graph object if interactive is True, otherwise the networkx DiGraph built from the subgraph.
    """
    rn = {k: v for k, v in self.var[gene_col].items()}
    if isinstance(seed, str):
        gene_id = self.var[self.var[gene_col] == seed].index[0]
        elem = self.grn.loc[gene_id].sort_values(ascending=False).head(
            max_genes
        ).index.tolist() + [gene_id]
    else:
        elem = seed

    mat = self.grn.loc[elem, elem].rename(columns=rn).rename(index=rn)
    if only < 1:
        mat[mat < only] = 0
    else:
        top_connections = mat.stack().nlargest(only)
        top_connections.index.names = ["Gene1", "Gene2"]
        top_connections.name = "Weight"
        top_connections = top_connections.reset_index()
        mat.index.name += "_2"
        # Set anything not in the top N connections to 0
        mask = mat.stack().isin(
            top_connections.set_index(["Gene1", "Gene2"])["Weight"]
        )
        mat[~mask.unstack()] = 0
    mat = mat * 100
    color = [palette[0]] * len(mat)
    if isinstance(seed, str):
        color[mat.columns.get_loc(seed)] = palette[1]
    print(color, mat.index)
    mat = mat.T
    if interactive:
        d3 = d3graph()
        d3.graph(mat, color=None)
        d3.set_node_properties(color=color, fontcolor="#000000", **kwargs)
        d3.set_edge_properties(directed=True)
        d3.show(notebook=True)
        return d3
    else:
        # Create a graph from the DataFrame
        G = nx.from_pandas_adjacency(mat, create_using=nx.DiGraph())
        # Draw the graph
        plt.figure(figsize=(15, 15))  # Increase the size of the plot
        nx.draw(G, with_labels=True, arrows=True)
        plt.show()
    if do_enr:
        enr = gp.enrichr(
            gene_list=list(G.nodes),
            gene_sets=[
                "KEGG_2021_Human",
                "MSigDB_Hallmark_2020",
                "Reactome_2022",
                "Tabula_Sapiens",
                "WikiPathway_2023_Human",
                "TF_Perturbations_Followed_by_Expression",
                "Reactome",
                "PPI_Hub_Proteins",
                "OMIM_Disease",
                "GO_Molecular_Function_2023",
            ],
            organism="Human",  # change accordingly
            # description='pathway',
            # cutoff=0.08, # test dataset, use lower value for real case
            background=self.var[gene_col].tolist(),
        )
        print(enr.res2d.head(20))
    return G
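
The two modes of the only parameter can be sketched as a standalone function. This is a simplified variant, not the method's exact code: the hypothetical helper filter_edges finds the weight of the N-th strongest edge and zeroes everything below it, whereas the listing above matches the top connections by value with stack, nlargest and isin. Gene names and weights are made up.

```python
import numpy as np
import pandas as pd


def filter_edges(mat: pd.DataFrame, only: float) -> pd.DataFrame:
    """Zero out weak edges: weight cutoff if only < 1, else keep the top-N edges."""
    out = mat.copy()
    if only < 1:
        # Treat `only` as a minimum weight.
        out[out < only] = 0
    else:
        # Weight of the N-th strongest edge; ties at the cutoff are all kept.
        cutoff = np.sort(out.to_numpy(), axis=None)[-int(only)]
        out[out < cutoff] = 0
    return out


genes = ["TF1", "G2", "G3"]  # hypothetical gene names
mat = pd.DataFrame(
    [[0.0, 0.2, 0.9], [0.0, 0.0, 0.5], [0.1, 0.0, 0.0]],
    index=genes,
    columns=genes,
)

by_weight = filter_edges(mat, 0.3)  # keeps weights of at least 0.3
top_two = filter_edges(mat, 2)      # keeps the two largest weights
```

In this example both calls keep the same two edges (0.9 and 0.5). The sketch assumes that only, when 1 or greater, does not exceed the number of entries in the matrix.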