Skip to content

Instantly share code, notes, and snippets.

@juanjux
Last active November 25, 2025 20:48
Show Gist options
  • Select an option

  • Save juanjux/573a3902c1994ba5ba654ffa5db5c027 to your computer and use it in GitHub Desktop.

Select an option

Save juanjux/573a3902c1994ba5ba654ffa5db5c027 to your computer and use it in GitHub Desktop.
Find unwrap() and expect() calls outside test code in a Rust codebase
#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
# "tree-sitter>=0.22.0",
# "tree-sitter-rust>=0.21.0",
# ]
# ///
"""
Find all unwrap() and expect() calls in Rust code that are NOT inside test functions/modules.
Usage: python find_unwraps.py <directory>

Exits with status 1 if any such call is found or the arguments are invalid, 0 otherwise.
"""
import re
import sys
from pathlib import Path

import tree_sitter_rust
from tree_sitter import Language, Parser
# Attribute text that marks the following item as test-only code. Matches
# #[test], attributes whose path ends in ::test (e.g. #[tokio::test] or
# #[tokio::test(flavor = "multi_thread")]) and #[cfg(test)], tolerating
# whitespace. #[cfg(not(test))] deliberately does NOT match. Compound
# predicates such as #[cfg(all(test, ...))] are not recognised.
_TEST_ATTR_RE = re.compile(
    r"""
    \#\s*\[\s*                        # attribute opener "#["
    (?:
        (?:\w+\s*::\s*)*test\s*[(\]]  # test, or a path ending in ::test
      | cfg\s*\(\s*test\s*\)          # cfg(test)
    )
    """,
    re.VERBOSE,
)

# Item kinds whose attributes can make their whole body test-only.
_TEST_SCOPED_ITEMS = frozenset({'function_item', 'mod_item', 'impl_item', 'trait_item'})


def _has_test_attribute(item, source_code):
    """Return True if an attribute written directly before *item* marks it as test code."""
    sibling = item.prev_sibling
    while sibling is not None:
        if sibling.type == 'attribute_item':
            attr_text = source_code[sibling.start_byte:sibling.end_byte].decode('utf8', errors='replace')
            if _TEST_ATTR_RE.match(attr_text):
                return True
        elif sibling.type not in ('line_comment', 'block_comment'):
            # Reached the previous item: any attributes further back belong to it.
            break
        sibling = sibling.prev_sibling
    return False


def is_inside_test(node, source_code):
    """Check whether *node* is nested inside an item annotated as test-only.

    Walks up from ``node.parent``. For every enclosing function, module,
    ``impl`` or ``trait`` item, the attributes immediately preceding it are
    inspected (comments between them are skipped). The item counts as test
    code when one of them is ``#[test]``, an attribute whose path ends in
    ``::test`` (assumed to be a test-harness macro such as ``tokio::test``),
    or ``#[cfg(test)]``.

    Args:
        node: tree-sitter node to classify.
        source_code: the ``bytes`` the tree was parsed from; used to read
            attribute text.

    Returns:
        True if some enclosing item is test-only, False otherwise.
    """
    current = node.parent
    while current is not None:
        if current.type in _TEST_SCOPED_ITEMS and _has_test_attribute(current, source_code):
            return True
        current = current.parent
    return False
def _unwrap_method(call, source_code):
    """Return ``(field_node, name)`` if *call* is ``<expr>.unwrap(...)``/``<expr>.expect(...)``, else None."""
    function_node = call.child_by_field_name('function')
    if function_node is None or function_node.type != 'field_expression':
        return None
    field_node = function_node.child_by_field_name('field')
    if field_node is None:
        return None
    name = source_code[field_node.start_byte:field_node.end_byte].decode('utf8', errors='replace')
    if name not in ('unwrap', 'expect'):
        return None
    return field_node, name


def _line_text(source_code, byte_offset):
    """Return the stripped, decoded source line that contains *byte_offset*."""
    start = source_code.rfind(b'\n', 0, byte_offset) + 1
    end = source_code.find(b'\n', byte_offset)
    if end == -1:
        end = len(source_code)
    return source_code[start:end].decode('utf8', errors='replace').strip()


def find_unwrap_calls(node, source_code, results, file_path):
    """Append every ``.unwrap()`` / ``.expect(...)`` method call under *node* to *results*.

    The tree is walked iteratively with an explicit stack, in the same
    pre-order as a recursive walk. This means deeply nested expressions (long
    method chains, large generated code) cannot hit Python's recursion limit.
    Calls for which ``is_inside_test`` returns True are skipped.

    Each finding is a dict with keys:
        ``file``: ``str(file_path)``;
        ``line``/``column``: 1-based position of the method name itself
            (so chained ``a.unwrap().b().unwrap()`` calls are distinguishable
            and multi-line chains point at the ``.unwrap`` line);
        ``text``: the stripped source line containing the method name;
        ``method``: ``'unwrap'`` or ``'expect'``.

    Args:
        node: tree-sitter node to start from (usually the tree's root).
        source_code: the ``bytes`` the tree was parsed from.
        results: list that findings are appended to (mutated in place).
        file_path: path recorded in each finding.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        if current.type == 'call_expression':
            hit = _unwrap_method(current, source_code)
            if hit is not None and not is_inside_test(current, source_code):
                method_node, method_name = hit
                row, col = method_node.start_point
                results.append({
                    'file': str(file_path),
                    'line': row + 1,
                    'column': col + 1,
                    'text': _line_text(source_code, method_node.start_byte),
                    'method': method_name,
                })
        # Push children reversed so they are popped in source order.
        stack.extend(reversed(current.children))
def analyze_rust_file(file_path, parser):
    """Return the unwrap()/expect() findings for a single ``.rs`` file.

    Any exception raised while reading, parsing or scanning the file is
    reported on stderr and produces an empty list. This way, one bad file
    does not stop the caller's loop over the remaining files.
    """
    findings = []
    try:
        with open(file_path, 'rb') as handle:
            code = handle.read()
        tree = parser.parse(code)
        find_unwrap_calls(tree.root_node, code, findings, file_path)
    except Exception as exc:
        print(f"Error analyzing {file_path}: {exc}", file=sys.stderr)
        return []
    return findings
def main():
    """Scan the directory named on the command line and report unwrap()/expect() calls.

    Prints each finding as ``file:line:column [method]`` followed by the
    source line. Usage and argument errors are written to stderr and exit
    with status 1.

    Returns:
        1 if at least one call outside test code was found, 0 otherwise
        (used as the process exit status by the entry guard).
    """
    if len(sys.argv) < 2:
        print("Usage: python find_unwraps.py <directory>", file=sys.stderr)
        sys.exit(1)
    directory = Path(sys.argv[1])
    if not directory.is_dir():
        print(f"Error: {directory} is not a directory", file=sys.stderr)
        sys.exit(1)

    # Initialize tree-sitter
    parser = Parser(Language(tree_sitter_rust.language()))

    # Sort so the report order does not depend on filesystem iteration order.
    rust_files = sorted(directory.rglob('*.rs'))
    print(f"Analyzing {len(rust_files)} Rust files in {directory}...")
    print()

    all_results = []
    for rust_file in rust_files:
        all_results.extend(analyze_rust_file(rust_file, parser))

    if not all_results:
        print("No unwrap()/expect() calls found outside of tests!")
        return 0

    print(f"Found {len(all_results)} unwrap()/expect() calls outside of tests:\n")
    for result in all_results:
        print(f"{result['file']}:{result['line']}:{result['column']} [{result['method']}]")
        print(f" {result['text']}")
        print()
    return 1
if __name__ == '__main__':
    # main() returns the exit status; SystemExit carries it to the shell.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment