Created
October 20, 2025 19:20
-
-
Save zaemyung/dab949d7486badae4ec4321f77339751 to your computer and use it in GitHub Desktop.
plot_graphs.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| import matplotlib.pyplot as plt | |
| import networkx as nx | |
| import regex as re | |
| from common_utils import draw_graph, load_pickle_file | |
| from matplotlib.gridspec import GridSpec | |
def visualize_graph_motifs(motifs, motif_size, save_path, selected_motif_indices=None, show_edge_label=True):
    """Draw every motif on one grid figure and save it to ``save_path``.

    Each motif is laid out with Graphviz ``dot`` in its own subplot titled
    ``G<index>``, where ``<index>`` is the motif's position in ``motifs``.
    Motifs whose index appears in ``selected_motif_indices[motif_size]`` get a
    red title on a light-yellow background; all others get a black title.

    Args:
        motifs: List of graphs to draw. When ``show_edge_label`` is true every
            edge must carry a ``'label_0'`` attribute.
        motif_size: Key used to look up the highlighted indices in
            ``selected_motif_indices``.
        save_path: Destination handed to ``plt.savefig``.
        selected_motif_indices: Container indexed by ``motif_size`` that yields
            the motif indices to highlight, or ``None`` to highlight nothing.
        show_edge_label: Whether to draw edge labels (a ``'/'`` label is shown
            as ``'hyp.'``).
    """
    assert isinstance(motifs, list)

    # The default of None used to fail with a TypeError on the first motif;
    # treat it as "nothing highlighted". The lookup is also done once here
    # instead of once per motif.
    if selected_motif_indices is None:
        highlighted = ()
    else:
        highlighted = selected_motif_indices[motif_size]

    n_graphs = len(motifs)
    print("[info] amount of graphs: ", n_graphs)

    # Use a wider grid for large collections.
    column = 20 if n_graphs >= 100 else 7
    # Keep at least one row so the figure height is non-zero for an empty list.
    row = max(1, math.ceil(n_graphs / column))
    print(f"[info] will provide a {column} x {row} figure")

    plt.clf()
    fig = plt.figure(figsize=(column * 5, row * 4))

    for index, g in enumerate(motifs):
        _row, _col = divmod(index, column)
        print(f" img position: {_col}, {_row}")
        ax = fig.add_subplot(row, column, index + 1, facecolor='lightyellow')

        pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
        # Negate both coordinates, i.e. rotate the dot layout by 180 degrees.
        pos = {node: (-x, -y) for node, (x, y) in pos.items()}
        nx.draw(g, pos=pos, ax=ax, node_size=200, arrowsize=50)

        if show_edge_label:
            edge_labels = {
                (u, v): 'hyp.' if d['label_0'] == '/' else d['label_0']
                for u, v, d in g.edges(data=True)
            }
            nx.draw_networkx_edge_labels(g, pos=pos, ax=ax, edge_labels=edge_labels, font_size=20)

        if index in highlighted:
            ax.set_title(f"G{index}", fontsize=30, color='red', backgroundcolor='lightyellow')
        else:
            ax.set_title(f"G{index}", fontsize=30, color='black')

    plt.tight_layout()
    plt.savefig(save_path, facecolor=fig.get_facecolor(), edgecolor='none', transparent=False)
def draw_single_motif(G, save_path, color_edus=False):
    """Draw a single motif graph on a 5x5 inch figure and save it to ``save_path``.

    The graph is laid out with Graphviz ``dot`` and drawn without node labels.
    Edge labels are taken from each edge's ``'label_0'`` attribute: ``'/'`` is
    drawn as an empty string, hyphens become underscores and a fixed set of
    relation names is abbreviated so the labels fit at the large font size.

    Args:
        G: Graph to draw; every edge must carry a ``'label_0'`` attribute.
        save_path: Destination handed to ``plt.savefig`` (saved with a
            transparent background).
        color_edus: When true, nodes whose label starts with
            ``span_<i>-<j>`` and has ``i == j`` are drawn in red (presumably
            single-EDU spans -- confirm against the graph builder). All other
            nodes, including labels that do not match the pattern, keep the
            default networkx node colour.
    """
    plt.clf()
    fig = plt.figure(figsize=(5, 5))

    pos = nx.nx_agraph.graphviz_layout(G, prog='dot')
    # Negate both coordinates, i.e. rotate the dot layout by 180 degrees.
    pos = {node: (-x, -y) for node, (x, y) in pos.items()}

    default_nx_node_color = '#1f78b4'
    if color_edus:
        span_pattern = re.compile(r'span_(\d+)-(\d+)')
        color_map = []
        for node_label in G.nodes():
            m = span_pattern.match(str(node_label))
            # A label that does not match used to raise AttributeError on
            # ``m.group``; such nodes now simply keep the default colour.
            is_single_index_span = m is not None and int(m.group(1)) == int(m.group(2))
            color_map.append('red' if is_single_index_span else default_nx_node_color)
    else:
        color_map = [default_nx_node_color] * len(G.nodes())

    nx.draw(G, pos=pos, node_size=350, arrowsize=70, with_labels=False, node_color=color_map)

    # Applied in order to every label; 'Background' is left unabbreviated.
    abbreviations = (
        ('-', '_'),
        ('Attribution', 'Attrib.'),
        ('Organization', 'Org.'),
        ('Summary', 'Sum.'),
        ('Textual', 'Text'),
        ('Change', 'Chg.'),
        ('Unit', 'U.'),
        ('Temporal', 'Temp.'),
        ('Elaboration', 'Elab.'),
        ('Condition', 'Cond.'),
        ('Enablement', 'Enable.'),
        ('Contrast', 'Contr.'),
        ('Evaluation', 'Eval.'),
        ('Explanation', 'Expl.'),
    )
    edge_labels = {}
    for u, v, d in G.edges(data=True):
        label = d['label_0']
        if label == '/':
            label = ''
        for full_name, short_name in abbreviations:
            label = label.replace(full_name, short_name)
        edge_labels[(u, v)] = label
    nx.draw_networkx_edge_labels(G, pos=pos, edge_labels=edge_labels, font_size=40)

    fig.tight_layout()
    plt.savefig(save_path, transparent=True)
    # Close rather than just clear: pyplot keeps every figure alive until it
    # is closed, so repeated calls would otherwise accumulate figures.
    plt.close(fig)
def visualize_graph_motifs_three_sizes(motifs, chosen_indices, save_path, show_edge_label=True):
    """Plot three example motifs (sizes 3, 6 and 9) on one figure and save it.

    The figure uses a 3x2 grid: ``motifs[0]`` occupies the top-left cell,
    ``motifs[1]`` the two cells below it, and ``motifs[2]`` the whole right
    column. Each panel is titled "Single-Triad", "Double-Triad" or
    "Triple-Triad" followed by ``G<chosen_indices[size]>``.

    Args:
        motifs: List of graphs; the first three entries are drawn as the
            size-3, size-6 and size-9 examples respectively.
        chosen_indices: Container indexed by motif size (3, 6, 9) giving the
            motif number shown in each panel title.
        save_path: Destination handed to ``plt.savefig``.
        show_edge_label: Whether to draw (abbreviated) edge labels taken from
            each edge's ``'label_0'`` attribute.
    """
    assert isinstance(motifs, list)
    print("[info] amount of graphs: ", len(motifs))

    n_cols, n_rows = 2, 3
    print(f"[info] will provide a {n_cols} x {n_rows} figure")

    plt.clf()
    fig = plt.figure(figsize=(n_cols * 7, n_rows * 5))

    # Substitutions applied, in order, to every edge label.
    shorthand = (
        ('-', '_'),
        ('Attribution', 'Attrib.'),
        ('Organization', 'Org.'),
        ('Summary', 'Sum.'),
        ('Textual', 'Text'),
        ('Change', 'Chg.'),
        ('Unit', 'U.'),
        ('Temporal', 'Temp.'),
        ('Elaboration', 'Elab.'),
        ('Enablement', 'Enable.'),
        ('Background', 'Backgr.'),
        ('Contrast', 'Contr.'),
        ('Evaluation', 'Eval.'),
        ('Explanation', 'Expl.'),
    )
    title_prefix = {3: "Single-Triad", 6: "Double-Triad", 9: "Triple-Triad"}

    def _plot(g, ax, motif_size):
        # Lay out with dot, then negate both coordinates of every node.
        layout = nx.nx_agraph.graphviz_layout(g, prog='dot')
        layout = {node: (-x, -y) for node, (x, y) in layout.items()}
        nx.draw(g, pos=layout, ax=ax, node_size=350, arrowsize=70)

        if show_edge_label:
            labels = {}
            for src, dst, attrs in g.edges(data=True):
                text = 'hyp.' if attrs['label_0'] == '/' else attrs['label_0']
                for long_form, short_form in shorthand:
                    text = text.replace(long_form, short_form)
                labels[(src, dst)] = text
            nx.draw_networkx_edge_labels(g, pos=layout, ax=ax, edge_labels=labels, font_size=40)

        motif_number = chosen_indices[motif_size]
        print(motif_number)
        # Sizes other than 3, 6 and 9 are left untitled.
        prefix = title_prefix.get(motif_size)
        if prefix is not None:
            ax.set_title(f"{prefix}\nG{motif_number}", fontsize=45, color='black')

    gs = GridSpec(3, 2, figure=fig)
    panels = (
        (gs[0, 0], 0, 3),
        (gs[1:, 0], 1, 6),
        (gs[:, 1:], 2, 9),
    )
    for cell, position, size in panels:
        _plot(motifs[position], fig.add_subplot(cell), size)

    plt.tight_layout()
    plt.savefig(save_path, facecolor=fig.get_facecolor(), edgecolor='none', transparent=False)
if __name__ == '__main__':
    import os

    # Script entry point: renders the 3/6/9 example figure and one single
    # size-3 motif into ``plot_dir``.
    plot_dir = 'data/plots'
    # savefig does not create missing directories, so make sure the output
    # folder exists before plotting.
    os.makedirs(plot_dir, exist_ok=True)

    motif_size_to_path = {
        3: 'data/motifs/M_3_HC3-DeepfakeTextDetect.pkl',
        6: 'data/motifs/M_6_HC3-DeepfakeTextDetect.pkl',
        9: 'data/motifs/M_9-triangular_HC3-DeepfakeTextDetect.pkl',
    }

    # Only needed by the optional overview plot below (presumably keyed by
    # motif size -- confirm against the file that produces this pickle).
    all_selected_indices = load_pickle_file('data/motifs/selected_motif_indices_one-std_HC3-DeepfakeTextDetect.pkl')
    # visualize_graph_motifs(load_pickle_file(motif_size_to_path[3]), 3,
    #                        f'{plot_dir}/all_motifs_of_size_three.pdf',
    #                        selected_motif_indices=all_selected_indices)

    # One hand-picked example motif per size, in the order 3, 6, 9 expected by
    # visualize_graph_motifs_three_sizes.
    chosen_indices = {3: 16, 6: 28, 9: 987}
    chosen_motifs = []
    for motif_sz, index in chosen_indices.items():
        motifs = load_pickle_file(motif_size_to_path[motif_sz])
        chosen_motifs.append(motifs[index])
    print(chosen_motifs)
    visualize_graph_motifs_three_sizes(chosen_motifs, chosen_indices, f'{plot_dir}/3_6_9_motifs_examples.pdf')

    # Stand-alone drawing of size-3 motif number 18.
    motifs = load_pickle_file(motif_size_to_path[3])
    background_motif = motifs[18]
    draw_single_motif(background_motif, f'{plot_dir}/3_background_motif.svg')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment