184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def plot_subgraph(
    self,
    seed: str,
    gene_col: str = "symbol",
    max_genes: int = 10,
    only: float = 0.3,
    palette: list = base_color_palette,
    interactive: bool = True,
    do_enr: bool = False,
    **kwargs,
):
    """
    Plot a subgraph of the gene regulatory network (GRN) centered around a seed gene.

    Args:
        seed (str or list): Either a single value of ``self.var[gene_col]`` (e.g. a
            gene symbol) or a list of ``self.grn`` index labels. For a string, the
            subgraph is the seed plus its ``max_genes`` highest-weighted entries in
            its ``self.grn`` row. A list is used as-is as the set of nodes.
        gene_col (str, optional): Column of ``self.var`` holding the gene names used
            to look up ``seed``, to label nodes, and as the enrichment background.
            Defaults to "symbol".
        max_genes (int, optional): Number of top-weighted partners of a string seed
            to include. Defaults to 10.
        only (float, optional): Edge filter. If less than 1, weights below ``only``
            are set to 0. If 1 or greater, only the ``int(only)`` largest weights in
            the subgraph are kept and all others are set to 0. Defaults to 0.3.
        palette (list, optional): Colors for plotting. ``palette[0]`` colors every
            node; ``palette[1]`` highlights a string seed. Defaults to
            base_color_palette.
        interactive (bool, optional): If True, render with d3graph; otherwise draw a
            networkx DiGraph with matplotlib. Defaults to True.
        do_enr (bool, optional): If True, run Enrichr (via gseapy) on the subgraph's
            genes and print the top 20 results. Applies in both plotting modes.
            Defaults to False.
        **kwargs: Forwarded to ``d3graph.set_node_properties`` (interactive mode only).

    Returns:
        d3graph or networkx.DiGraph: The d3graph object if ``interactive`` is True,
        otherwise the DiGraph that was drawn.

    Raises:
        IndexError: If ``seed`` is a string not present in ``self.var[gene_col]``.
    """
    import numpy as np

    # Map GRN index labels to display names from the chosen var column.
    rn = self.var[gene_col].to_dict()

    is_single_seed = isinstance(seed, str)
    if is_single_seed:
        gene_id = self.var[self.var[gene_col] == seed].index[0]
        top_partners = (
            self.grn.loc[gene_id]
            .sort_values(ascending=False)
            .head(max_genes)
            .index.tolist()
        )
        # The seed can be among its own top partners (a self-loop). De-duplicate
        # while preserving order so the matrix has no repeated row/column labels.
        elem = list(dict.fromkeys(top_partners + [gene_id]))
    else:
        elem = seed
    mat = self.grn.loc[elem, elem].rename(columns=rn).rename(index=rn)

    if only < 1:
        # Fractional threshold: zero out edges weaker than `only`.
        mat[mat < only] = 0
    else:
        # Count: keep only the int(only) strongest entries, selected by position.
        # Matching on values would also keep any ties elsewhere in the matrix.
        # The stable descending sort breaks ties by row-major position, like
        # Series.nlargest(keep="first"). NaN sorts last.
        values = mat.to_numpy(dtype=float)
        flat = values.ravel()
        top_positions = np.argsort(-flat, kind="stable")[: int(only)]
        keep = np.zeros(flat.shape, dtype=bool)
        keep[top_positions] = True
        mat = mat.where(keep.reshape(values.shape), 0)

    # Scale weights for display.
    mat = mat * 100
    color = [palette[0]] * len(mat)
    if is_single_seed:
        color[mat.columns.get_loc(seed)] = palette[1]
    mat = mat.T

    if interactive:
        graph = d3graph()
        graph.graph(mat, color=None)
        graph.set_node_properties(color=color, fontcolor="#000000", **kwargs)
        graph.set_edge_properties(directed=True)
        graph.show(notebook=True)
    else:
        graph = nx.from_pandas_adjacency(mat, create_using=nx.DiGraph())
        plt.figure(figsize=(15, 15))  # large canvas so node labels stay legible
        nx.draw(graph, with_labels=True, arrows=True)
        plt.show()

    if do_enr:
        # Node labels come from `gene_col`, so the background must come from the
        # same column for the enrichment to be consistent.
        enr = gp.enrichr(
            gene_list=mat.index.tolist(),
            gene_sets=[
                "KEGG_2021_Human",
                "MSigDB_Hallmark_2020",
                "Reactome_2022",
                "Tabula_Sapiens",
                "WikiPathway_2023_Human",
                "TF_Perturbations_Followed_by_Expression",
                "Reactome",
                "PPI_Hub_Proteins",
                "OMIM_Disease",
                "GO_Molecular_Function_2023",
            ],
            organism="Human",  # adjust for non-human data
            background=self.var[gene_col].tolist(),
        )
        print(enr.res2d.head(20))
    return graph
|