Skip to content

Instantly share code, notes, and snippets.

@luiscosio
Created October 11, 2024 21:00
Show Gist options
  • Select an option

  • Save luiscosio/71cc8399b55b1d3f9dac1df4d5e933b2 to your computer and use it in GitHub Desktop.

Select an option

Save luiscosio/71cc8399b55b1d3f9dac1df4d5e933b2 to your computer and use it in GitHub Desktop.
Trie-Based Sparse Matrices for tokenizations
import networkx as nx
import matplotlib.pyplot as plt
# Trie Node Class
class TrieNode:
    """A single node of a character trie.

    Attributes are set in ``__init__``:
      * ``children`` maps one character to the child node reached by it.
      * ``is_end_of_word`` starts out False; ``insert`` sets it on the node
        that finishes a word.
    """

    def __init__(self):
        # Each instance gets its own flag and its own (initially empty) mapping.
        self.is_end_of_word: bool = False
        self.children: dict[str, "TrieNode"] = {}
# Function to insert a word into the trie
def insert(root, word):
    """Add ``word`` to the trie rooted at ``root``.

    Walks the trie one character at a time, creating a ``TrieNode`` for
    every character that has no child yet, and finally flags the node
    reached after the last character as the end of a word.  An empty
    ``word`` therefore flags ``root`` itself.  Returns None.
    """
    current = root
    for symbol in word:
        successor = current.children.get(symbol)
        if successor is None:
            # No branch for this character yet: grow the trie by one node.
            successor = current.children[symbol] = TrieNode()
        current = successor
    current.is_end_of_word = True
# Function to search for a word in the trie
def search(root, word):
    """Report whether ``word`` was inserted into the trie rooted at ``root``.

    Returns False as soon as a character has no matching child.  Otherwise
    returns the ``is_end_of_word`` flag of the node reached after the last
    character, so a mere prefix of an inserted word is not a match.
    """
    current = root
    for symbol in word:
        current = current.children.get(symbol)
        if current is None:
            # The path for this word breaks off here.
            return False
    return current.is_end_of_word
# Function to print all words in the trie
def print_trie(node, prefix=''):
    """Print every word stored at or below ``node``, one per line.

    ``prefix`` is the text already spelled out on the way to ``node`` and is
    prepended to each printed word.  Words come out in depth-first
    pre-order, visiting children in the order they are stored in each
    node's ``children`` mapping.
    """
    pending = [(node, prefix)]
    while pending:
        current, spelled = pending.pop()
        if current.is_end_of_word:
            print(spelled)
        # Push children back-to-front so the first child is popped (and its
        # whole subtree printed) before any of its later siblings.
        for symbol, child in reversed(list(current.children.items())):
            pending.append((child, spelled + symbol))
# Function to generate the trie visualization using NetworkX
def visualize_trie(root):
    """Draw the trie rooted at ``root`` as a directed graph and show it.

    Every trie node becomes one graph node.  The root is labelled 'Root';
    every other node is labelled with the character that leads to it from
    its parent.  Nodes whose ``is_end_of_word`` flag is set are drawn in
    light green, all others in light blue.  The layout comes from
    ``nx.spring_layout``.  The figure is displayed with ``plt.show()`` and
    nothing is returned.
    """
    node_size = 1500  # shared by both node layers so highlights cover the base node
    G = nx.DiGraph()
    labels = {}
    end_nodes = []

    def add_nodes_edges(node, parent_id=None, char=''):
        # Exactly one label is stored per visited node, so the number of
        # labels so far is the next unused (pre-order) node id.
        current_id = len(labels)
        labels[current_id] = 'Root' if parent_id is None else char
        G.add_node(current_id)
        if parent_id is not None:
            # Add edge from parent to current node
            G.add_edge(parent_id, current_id)
        # Record end-of-word nodes in the same pass that assigns the ids.
        # Recomputing ids in a second traversal is error-prone because a
        # sibling's id depends on the size of the subtrees visited before it.
        if node.is_end_of_word:
            end_nodes.append(current_id)
        for c, child in node.children.items():
            add_nodes_edges(child, current_id, c)

    add_nodes_edges(root)

    # Draw the graph
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G)
    nx.draw(G, pos, with_labels=False, arrows=True, node_size=node_size, node_color='lightblue')
    # Highlight end-of-word nodes
    nx.draw_networkx_nodes(G, pos, nodelist=end_nodes, node_size=node_size, node_color='lightgreen')
    # Labels go last so they are drawn after both node layers.
    nx.draw_networkx_labels(G, pos, labels, font_size=12)
    plt.title('Trie Visualization')
    plt.axis('off')
    plt.show()
# Main Code
if __name__ == "__main__":
    # Build a demo trie from a handful of words that share prefixes.
    root = TrieNode()
    words = ['taco', 'tacos', 'tac', 'tactical', 'taste', 'tab', 'tag', 'tale', 'team']
    for entry in words:
        insert(root, entry)

    # List everything that was stored.
    print("Words in the trie:")
    print_trie(root)

    # Look up a mix of stored words, bare prefixes and unknown words.
    test_words = ['taco', 'tacos', 'tac', 'tag', 'tabs', 'tact', 'tactical', 'taste', 'tale', 'team', 'tear']
    print("\nSearch results:")
    for candidate in test_words:
        found = search(root, candidate)
        print(f"Word '{candidate}' found in trie: {found}")

    # Finally open the graph rendering of the trie.
    visualize_trie(root)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment