Created
February 18, 2025 12:18
-
-
Save maeste/3806e6e0d11cf309298166a33682c568 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re

import spacy
from scipy.spatial.distance import cosine
from sentence_transformers import SentenceTransformer
# Identifiers of the pretrained models this script relies on.
_SPACY_PIPELINE = "en_core_web_sm"
_SENTENCE_ENCODER = "all-MiniLM-L6-v2"

# Both models are loaded once, at import time, and reused by every call below.
nlp = spacy.load(_SPACY_PIPELINE)
embedding_model = SentenceTransformer(_SENTENCE_ENCODER)
def normalize_text(text):
    """Reduce *text* to its subject, verb and object tokens.

    Question marks and periods are stripped and the text is lowercased
    before it is parsed with the module-level ``nlp`` pipeline. Each token
    is then assigned to at most one role, checked in this order:

    1. subject - the dependency label contains ``"subj"``
    2. verb    - the part-of-speech tag is ``"VERB"``
    3. object  - the dependency label contains ``"obj"``

    Tokens matching none of these are dropped.

    Args:
        text: The sentence or question to normalize.

    Returns:
        A single space-joined string listing all subjects first, then all
        verbs, then all objects, each group in document order.
    """
    cleaned = re.sub(r'[?.]+', '', text).lower()

    # One bucket per grammatical role; insertion order is the output order.
    roles = {"subject": [], "verb": [], "object": []}

    for tok in nlp(cleaned):
        dep_label = tok.dep_
        if "subj" in dep_label:
            role = "subject"
        elif tok.pos_ == "VERB":
            role = "verb"
        elif "obj" in dep_label:
            role = "object"
        else:
            # Not a subject, verb or object: excluded from the normal form.
            continue
        roles[role].append(tok.text)

    ordered_words = roles["subject"] + roles["verb"] + roles["object"]
    return " ".join(ordered_words)
def compare_texts(question, statement):
    """Print and return similarity scores for a question/statement pair.

    Two cosine similarities are computed from ``embedding_model`` encodings:
    one between the texts as given, and one between their
    ``normalize_text`` forms. A short report with the inputs, the
    normalized forms and both scores is printed to stdout.

    Args:
        question: The question text.
        statement: The statement text to compare with the question.

    Returns:
        A ``(original_sim, normalized_sim)`` tuple holding the similarity of
        the raw texts and of the normalized texts, so callers can use the
        scores programmatically instead of parsing the printed report.
    """
    # Normalized (subject/verb/object) forms of both inputs.
    norm_question = normalize_text(question)
    norm_statement = normalize_text(statement)

    # Embeddings of the texts as given.
    q_emb = embedding_model.encode(question)
    s_emb = embedding_model.encode(statement)

    # Embeddings of the normalized texts.
    norm_q_emb = embedding_model.encode(norm_question)
    norm_s_emb = embedding_model.encode(norm_statement)

    # cosine() yields a distance, so 1 - distance converts it to a similarity.
    original_sim = 1 - cosine(q_emb, s_emb)
    normalized_sim = 1 - cosine(norm_q_emb, norm_s_emb)

    print(f"\nComparing:\nQ: {question}\nS: {statement}")
    print(f"\nNormalized forms:\nQ: {norm_question}\nS: {norm_statement}")
    print("\nSimilarity scores:")
    print(f"Original: {original_sim:.3f}")
    print(f"Normalized: {normalized_sim:.3f}")
    print("-" * 50)

    return original_sim, normalized_sim
# Demo inputs: each entry is a (question, statement) pair in which the
# statement answers the question.
test_pairs = [
    ("Where is the book?", "The book is on the table."),
    ("What did John eat for lunch?", "John ate a sandwich for lunch."),
    ("How fast does the car go?", "The car goes 200 mph."),
    ("Who wrote this code?", "Sarah wrote the code yesterday."),
    ("When will the meeting start?", "The meeting starts at 3 PM."),
]

# Print the similarity report for every pair, in list order.
for question, statement in test_pairs:
    compare_texts(question, statement)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment