Created
October 11, 2024 21:00
-
-
Save luiscosio/71cc8399b55b1d3f9dac1df4d5e933b2 to your computer and use it in GitHub Desktop.
Trie-Based Sparse Matrices for tokenizations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
# Trie node: a single position in the prefix tree
class TrieNode:
    """One node of a character trie.

    Attributes are plain and public:
    ``children`` maps a single character to the child ``TrieNode`` reached
    by following that character, and ``is_end_of_word`` is True when the
    path from the root to this node spells a complete inserted word.
    """

    def __init__(self):
        # A freshly created node terminates no word and has no outgoing edges.
        self.is_end_of_word = False
        self.children = {}
# Add a word to the trie, creating any missing nodes along its path
def insert(root, word):
    """Insert *word* into the trie rooted at *root*.

    Walks the trie one character at a time, creating a new ``TrieNode``
    whenever the current node has no child for the next character, and
    marks the node reached after the last character as the end of a word.
    """
    current = root
    for letter in word:
        following = current.children.get(letter)
        if following is None:
            # No edge for this character yet: extend the path.
            following = TrieNode()
            current.children[letter] = following
        current = following
    current.is_end_of_word = True
# Exact-match lookup of a word in the trie
def search(root, word):
    """Return whether *word* was inserted into the trie rooted at *root*.

    Follows the characters of *word* from *root*. Returns False as soon as
    a character has no matching child; otherwise returns the
    ``is_end_of_word`` flag of the node reached, so a mere prefix of a
    stored word is not reported as found.
    """
    current = root
    for letter in word:
        current = current.children.get(letter)
        if current is None:
            # The path breaks here, so the word cannot be present.
            return False
    return current.is_end_of_word
# Print every stored word, one per line
def print_trie(node, prefix=''):
    """Print all words stored at or below *node*, each preceded by *prefix*.

    Words are emitted in depth-first pre-order: a node's own word (if it
    ends one) is printed before the words in its subtrees, and children
    are visited in the iteration order of the ``children`` dict.
    """
    # Explicit stack of (node, text spelled so far) pairs.
    pending = [(node, prefix)]
    while pending:
        current, spelled = pending.pop()
        if current.is_end_of_word:
            print(spelled)
        # Push children in reverse so the first child is popped (visited) first.
        for letter, child in reversed(list(current.children.items())):
            pending.append((child, spelled + letter))
# Function to generate the trie visualization using NetworkX
def visualize_trie(root):
    """Draw the trie rooted at *root* as a directed graph and show it.

    Every trie node becomes one graph node with a unique integer id
    assigned in depth-first pre-order (the root is 0). The root is labelled
    'Root'; every other node is labelled with the character on the edge
    leading to it. Nodes that end a word are drawn again in light green on
    top of the light-blue base drawing. The layout comes from
    ``nx.spring_layout`` and the figure is displayed with ``plt.show()``.

    Args:
        root: Trie node exposing ``children`` (char -> node mapping) and
            ``is_end_of_word``.
    """
    G = nx.DiGraph()
    labels = {}
    end_nodes = []
    node_id = 0  # Next unused graph-node identifier

    def add_nodes_edges(node, parent_id=None, char=''):
        """Add *node* and its whole subtree to the graph."""
        nonlocal node_id
        current_id = node_id
        node_id += 1

        labels[current_id] = 'Root' if parent_id is None else char
        G.add_node(current_id)
        if parent_id is not None:
            G.add_edge(parent_id, current_id)

        # Record end-of-word nodes during the same traversal that assigns
        # ids. Re-deriving ids in a second walk (sibling ids as parent + 1,
        # + 2, ...) does not match pre-order numbering once a sibling has
        # descendants, and would highlight the wrong nodes.
        if node.is_end_of_word:
            end_nodes.append(current_id)

        for c, child in node.children.items():
            add_nodes_edges(child, current_id, c)

    add_nodes_edges(root)

    # Draw the graph
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G)
    nx.draw(G, pos, with_labels=False, arrows=True, node_size=1500, node_color='lightblue')
    nx.draw_networkx_labels(G, pos, labels, font_size=12)

    # Highlight end-of-word nodes
    nx.draw_networkx_nodes(G, pos, nodelist=end_nodes, node_color='lightgreen')

    plt.title('Trie Visualization')
    plt.axis('off')
    plt.show()
# Script entry point: build a demo trie, list it, query it, then draw it
if __name__ == "__main__":
    # Demo vocabulary and the lookups to try against it
    words = [
        'taco', 'tacos', 'tac', 'tactical', 'taste',
        'tab', 'tag', 'tale', 'team',
    ]
    test_words = [
        'taco', 'tacos', 'tac', 'tag', 'tabs', 'tact',
        'tactical', 'taste', 'tale', 'team', 'tear',
    ]

    # Populate a fresh trie with the vocabulary
    root = TrieNode()
    for word in words:
        insert(root, word)

    # Dump the stored words
    print("Words in the trie:")
    print_trie(root)

    # Report the outcome of each lookup
    print("\nSearch results:")
    for word in test_words:
        result = search(root, word)
        print(f"Word '{word}' found in trie: {result}")

    # Finally render the trie as a graph
    visualize_trie(root)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment