Created
September 7, 2025 10:45
-
-
Save yai333/0a063b7198f83d805ac244d2864ea83d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # Requirements | |
| # pydantic >= 2.5.0 | |
| # google-adk >= 1.13.0 | |
| # google-generativeai >= 0.8.5 | |
| # presidio-analyzer >= 2.2.354 | |
| # presidio-anonymizer >= 2.2.354 | |
| # spacy >= 3.4.0 | |
| # typing-extensions >= 4.8.0 | |
| import asyncio | |
| import json | |
| import os | |
| import re | |
| import sqlite3 | |
| import time | |
| from enum import Enum | |
| from typing import List, Literal, Union | |
| from pydantic import BaseModel, Field, field_validator, create_model | |
| from presidio_analyzer import AnalyzerEngine | |
| from presidio_anonymizer import AnonymizerEngine | |
| from google.adk import Runner | |
| from google.adk.agents import LlmAgent | |
| from google.adk.sessions import InMemorySessionService | |
| from google.genai import types | |
# Closed set of data types a generated SQL parameter may declare.
# Each member's value is the lowercase string form of its name.
class ParameterType(Enum):
    STRING = "string"      # text values
    INTEGER = "integer"    # whole numbers
    FLOAT = "float"        # fractional numbers
    DATE = "date"          # date values
    BOOLEAN = "boolean"    # true/false flags
# One named SQL parameter together with the placeholder it stands for.
class SQLParameter(BaseModel):
    # Name of the bind parameter as written in the SQL text.
    param_name: str = Field(
        description="Parameter name used in SQL (e.g., 'first_name', 'last_name')",
    )
    # The placeholder token this parameter was derived from.
    placeholder_ref: str = Field(
        description="Reference to the original placeholder (e.g., '<PERSON_1>')",
    )
    # Declared data type; STRING when the producer does not specify one.
    param_type: ParameterType = Field(
        default=ParameterType.STRING,
        description="SQL parameter data type",
    )
class StructuredSQLQuery(BaseModel):
    """Structured SQL query with parameter validation."""

    # SQL text; PII values appear as named bind parameters.
    sql_query: str = Field(
        description="SQL query with named parameters for PII (e.g., :person_1, :email_1)",
    )
    # One entry per bind parameter used in ``sql_query``.
    parameters: List[SQLParameter] = Field(
        description="List of parameters with their placeholder mappings",
    )

    @field_validator('sql_query')
    @classmethod
    def validate_sql_structure(cls, v):
        # Normalise surrounding whitespace, refuse blank input, and require
        # the statement to open with SELECT, WITH or a parenthesis
        # (case-insensitive). The trimmed text becomes the field value.
        trimmed = v.strip()
        if trimmed == "":
            raise ValueError("SQL query cannot be empty")
        if not trimmed.upper().startswith(('SELECT', 'WITH', '(')):
            raise ValueError("SQL must be a valid SQL statement")
        return trimmed
def extract_placeholders_from_text(text: str) -> List[str]:
    """Return the distinct ``<ENTITY_N>`` placeholders found in *text*.

    A placeholder is ``<``, one or more uppercase letters/underscores, an
    underscore, one or more digits, then ``>``. Duplicates are collapsed and
    the result is sorted lexicographically. Empty or ``None`` input yields
    an empty list.
    """
    if not text:
        return []
    unique_tokens = {
        found.group(0) for found in re.finditer(r'<[A-Z_]+_\d+>', text)
    }
    return sorted(unique_tokens)
def create_dynamic_sql_models(anonymized_question: str):
    """Create dynamic Pydantic models with placeholder constraints based on the anonymized question.

    Args:
        anonymized_question: Question text in which PII has already been
            replaced by ``<ENTITY_N>`` placeholders.

    Returns:
        A ``(parameter_model, query_model)`` pair. When the question holds
        no placeholders the static ``SQLParameter`` / ``StructuredSQLQuery``
        classes are returned unchanged; otherwise freshly built models whose
        ``placeholder_ref`` description lists the placeholders found.
    """
    available_placeholders = extract_placeholders_from_text(
        anonymized_question)
    if not available_placeholders:
        return SQLParameter, StructuredSQLQuery

    # Union type: PII placeholders OR any string for literals
    PlaceholderType = Union[Literal[tuple(available_placeholders)], str]

    DynamicSQLParameter = create_model(
        'DynamicSQLParameter',
        param_name=(str, Field(
            description="Parameter name used in SQL (e.g., 'first_name', 'last_name')")),
        placeholder_ref=(PlaceholderType, Field(
            description=f"PII placeholders: {', '.join(available_placeholders)} OR literal values like '100', 'Jazz'")),
        param_type=(ParameterType, Field(
            default=ParameterType.STRING, description="SQL parameter data type")),
        __base__=BaseModel
    )

    # Derive from StructuredSQLQuery (not plain BaseModel) so the
    # ``sql_query`` validator defined there also guards the dynamic model;
    # with a BaseModel base, queries containing PII skipped that check.
    # Both fields are re-declared so their definitions stay exactly as before.
    DynamicStructuredSQLQuery = create_model(
        'DynamicStructuredSQLQuery',
        sql_query=(str, Field(
            description="SQL query with named parameters for PII (e.g., :person_1, :email_1)")),
        parameters=(List[DynamicSQLParameter], Field(
            description="List of parameters with their placeholder mappings")),
        __base__=StructuredSQLQuery
    )
    return DynamicSQLParameter, DynamicStructuredSQLQuery
class MappingStorage:
    """Persistent storage for PII mappings using SQLite.

    Each method opens its own short-lived connection to ``db_path`` and
    closes it again, even when the statement fails.

    NOTE: This is a simplified implementation for demo purposes.
    In production, you should:
    1. Add session_id to the schema for session isolation
    2. Include user_id for multi-tenant scenarios
    3. Add TTL/expiration for compliance (GDPR right to be forgotten)
    4. Encrypt sensitive mappings at rest

    Production schema example:
        CREATE TABLE pii_mappings (
            session_id TEXT NOT NULL,
            user_id TEXT NOT NULL,
            original TEXT NOT NULL,
            entity_type TEXT NOT NULL,
            pseudonym TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            expires_at TIMESTAMP,
            PRIMARY KEY (session_id, original, entity_type)
        )
    """

    def __init__(self, db_path="pii_mappings.db"):
        """Remember the database path and make sure the table exists."""
        self.db_path = db_path
        self._init_db()

    def _run(self, sql, params=(), fetch=False):
        """Execute one statement on a fresh connection.

        Args:
            sql: Statement text with ``?`` placeholders.
            params: Values bound to the placeholders.
            fetch: When true, return all result rows; otherwise ``None``.

        The connection is always closed; the statement is committed on
        success and rolled back if it raises.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            # ``with conn`` scopes the transaction only; closing is done
            # explicitly in ``finally``.
            with conn:
                cursor = conn.execute(sql, params)
                return cursor.fetchall() if fetch else None
        finally:
            conn.close()

    @staticmethod
    def _row_to_mapping(row):
        """Convert an ``(original, pseudonym, entity_type)`` row to a dict."""
        return {'original': row[0], 'pseudonym': row[1], 'type': row[2]}

    def _init_db(self):
        """Initialize database with mappings table.

        NOTE: For demo purposes, mappings are global and persistent.
        Production should scope mappings by session_id.
        """
        self._run("""
            CREATE TABLE IF NOT EXISTS pii_mappings (
                original TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                pseudonym TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (original, entity_type)
            )
        """)

    def store_mapping(self, original: str, pseudonym: str, entity_type: str):
        """Store PII mapping, replacing any row with the same key.

        NOTE: Currently stores globally without session isolation.
        Production should include session_id in the storage:
            INSERT OR REPLACE INTO pii_mappings
            (session_id, original, entity_type, pseudonym)
            VALUES (?, ?, ?, ?)
        """
        self._run("""
            INSERT OR REPLACE INTO pii_mappings
            (original, entity_type, pseudonym)
            VALUES (?, ?, ?)
        """, (original, entity_type, pseudonym))

    def get_mapping(self, original: str, entity_type: str):
        """Return the mapping dict for ``(original, entity_type)`` or ``None``."""
        rows = self._run("""
            SELECT original, pseudonym, entity_type FROM pii_mappings
            WHERE original = ? AND entity_type = ?
        """, (original, entity_type), fetch=True)
        # (original, entity_type) is the primary key, so at most one row.
        return self._row_to_mapping(rows[0]) if rows else None

    def get_all_mappings(self) -> dict:
        """Return every stored mapping keyed ``mapping_0``, ``mapping_1``, ..."""
        rows = self._run(
            "SELECT original, pseudonym, entity_type FROM pii_mappings",
            fetch=True)
        return {
            f"mapping_{index}": self._row_to_mapping(row)
            for index, row in enumerate(rows)
        }

    def clear_mappings(self):
        """Delete every stored mapping."""
        self._run("DELETE FROM pii_mappings")
class PresidioPIIDetector:
    """Swap detected PII for ``<ENTITY_N>`` placeholders and back again.

    Mappings are written to a ``MappingStorage`` and mirrored in the
    in-memory ``session_mappings`` dict, keyed ``"<original>_<entity_type>"``.
    """

    def __init__(self, db_path="pii_mappings.db"):
        """Create the Presidio engines and the mapping store at *db_path*."""
        self.analyzer = AnalyzerEngine()
        self.anonymizer = AnonymizerEngine()
        self.mapping_storage = MappingStorage(db_path)
        # NOTE: Session mappings are in-memory only for demo purposes
        # Production should store these in DB with proper session_id scoping
        self.session_mappings = {}

    @staticmethod
    def _drop_overlapping(results):
        """Return a subset of *results* whose spans do not overlap.

        Candidates are taken in order of higher score, then longer span,
        then earlier start; a candidate is discarded if it intersects one
        already kept.
        """
        ranked = sorted(
            results,
            key=lambda r: (-r.score, -(r.end - r.start), r.start))
        kept = []
        for candidate in ranked:
            if all(candidate.end <= other.start or candidate.start >= other.end
                   for other in kept):
                kept.append(candidate)
        return kept

    def pseudonymize(self, text):
        """Replace detected PII in *text* with placeholders.

        Returns:
            ``(anonymized_text, detected_entities)``. Each entity dict has
            ``entity_type`` (lowercase), ``value``, ``start``, ``end`` and
            ``score``, listed from the last span in the text to the first.
            When nothing is detected the text is returned unchanged with
            an empty list.
        """
        allowed_entities = [
            'EMAIL_ADDRESS', 'PERSON', 'DATE_TIME', 'LOCATION',
            'PHONE_NUMBER',
            'AU_ABN',
            'AU_ACN',
            'AU_TFN',
            'AU_MEDICARE'
        ]
        results = self.analyzer.analyze(
            text=text, language='en', entities=allowed_entities)
        if not results:
            return text, []

        anonymized_text = text
        detected_entities = []
        session_entity_counters = {}

        # Replacement uses offsets into the ORIGINAL text. That is only safe
        # when spans are disjoint and handled right-to-left, so overlapping
        # detections are resolved first.
        spans = self._drop_overlapping(results)
        for result in sorted(spans, key=lambda x: x.start, reverse=True):
            original_value = text[result.start:result.end]
            entity_type = result.entity_type.lower()

            existing_mapping = self.mapping_storage.get_mapping(
                original_value, entity_type)
            if existing_mapping:
                # Reuse the stored placeholder so a value maps consistently.
                placeholder = existing_mapping['pseudonym']
            else:
                # First new value of this type asks storage for the next
                # free number; later ones in this call just increment.
                if entity_type not in session_entity_counters:
                    session_entity_counters[entity_type] = self._get_next_entity_counter(
                        entity_type)
                else:
                    session_entity_counters[entity_type] += 1
                next_counter = session_entity_counters[entity_type]
                placeholder = f"<{entity_type.upper()}_{next_counter}>"
                self.mapping_storage.store_mapping(
                    original_value, placeholder, entity_type)

            anonymized_text = (
                anonymized_text[:result.start] +
                placeholder +
                anonymized_text[result.end:]
            )
            self.session_mappings[f"{original_value}_{entity_type}"] = {
                'original': original_value,
                'pseudonym': placeholder,
                'type': entity_type,
                'score': result.score
            }
            detected_entities.append({
                'entity_type': entity_type,
                'value': original_value,
                'start': result.start,
                'end': result.end,
                'score': result.score
            })
        return anonymized_text, detected_entities

    def _get_next_entity_counter(self, entity_type: str) -> int:
        """Get the next available counter for an entity type across all existing mappings."""
        pattern = re.compile(rf'<{re.escape(entity_type.upper())}_(\d+)>')
        highest = 0
        # Persisted and in-memory mappings share the same dict shape.
        for source in (self.mapping_storage.get_all_mappings(),
                       self.session_mappings):
            for mapping in source.values():
                if mapping['type'] != entity_type:
                    continue
                match = pattern.search(mapping['pseudonym'])
                if match:
                    highest = max(highest, int(match.group(1)))
        return highest + 1

    def deanonymize(self, anonymized_text):
        """Substitute known placeholders in the text with their originals.

        Session mappings are applied first, then every stored mapping.
        """
        result = anonymized_text
        for source in (self.session_mappings,
                       self.mapping_storage.get_all_mappings()):
            for mapping in source.values():
                if mapping['pseudonym'] in result:
                    result = result.replace(
                        mapping['pseudonym'], mapping['original'])
        return result

    def get_mapping_storage(self):
        """Return the in-memory session mappings dict (not the SQLite store)."""
        return self.session_mappings

    def clear_all_mappings(self):
        """Forget both the session mappings and the persisted ones."""
        self.session_mappings.clear()
        self.mapping_storage.clear_mappings()
class SQLResponse(BaseModel):
    """SQL response with dynamic bindings"""

    # Required SQL text (``...`` marks the field as having no default).
    sql: str = Field(
        ...,
        description="SQLite SQL with dynamic placeholder bindings",
    )
    # Binding map; each instance gets its own empty dict by default.
    bindings: dict = Field(
        default_factory=dict,
        description="Only includes PII types actually used",
    )
# Schema summary interpolated into the LLM instruction prompts.
# "FK→Table" marks a foreign key; the JOIN list shows the matching ON clause.
# The arrows had been mis-encoded as a Greek beta and are restored here so the
# relationship notation in the prompt is readable.
CHINOOK_SCHEMA = """
Chinook Music Database (SQLite) with Relationships:
- Customer: CustomerId(PK), FirstName, LastName, Email, Phone, Country
- Invoice: InvoiceId(PK), CustomerId(FK→Customer), InvoiceDate, Total
- InvoiceLine: InvoiceLineId(PK), InvoiceId(FK→Invoice), TrackId(FK→Track), Quantity, UnitPrice
- Track: TrackId(PK), Name, AlbumId(FK→Album), GenreId(FK→Genre), Milliseconds, UnitPrice
- Album: AlbumId(PK), Title, ArtistId(FK→Artist)
- Artist: ArtistId(PK), Name
- Genre: GenreId(PK), Name
Common JOINs:
- Customer→Invoice: ON c.CustomerId = i.CustomerId
- Invoice→InvoiceLine: ON i.InvoiceId = il.InvoiceId
- InvoiceLine→Track: ON il.TrackId = t.TrackId
- Track→Album: ON t.AlbumId = a.AlbumId
- Album→Artist: ON a.ArtistId = ar.ArtistId
- Track→Genre: ON t.GenreId = g.GenreId
"""
class TextToSQLADK:
    """Text-to-SQL with Google ADK and PII protection.

    ``process_query`` runs the whole pipeline: pseudonymize the question,
    ask an LLM agent for structured SQL, then check that SQL against the
    local ``chinook.db`` SQLite file.
    """

    def __init__(self, api_key: str):
        """Build the PII detector, the two instruction prompts and the session service.

        Args:
            api_key: Accepted for interface compatibility; nothing in this
                class reads it.
        """
        self.pii_detector = PresidioPIIDetector()
        self.app_name = "chinook-sql-final"
        self.user_id = "test-user-final"
        self.types = types
        # Prompt used when the question contains PII placeholders.
        self.pii_instruction = f"""Convert natural language to SQLite SQL for the Chinook database with PII placeholders.
{CHINOOK_SCHEMA}
Use SQLite syntax with exact table/column names. Use NAMED PARAMETERS for ALL WHERE clause values.
Use proper JOINs for multi-table queries.
PARAMETER RULES:
- Use named parameters for ALL values in WHERE, HAVING, and conditional clauses: :param_1, :param_2, etc.
- Map each :param_name to its source placeholder
IMPORTANT NAME HANDLING:
- When searching by person names (PERSON placeholders), ALWAYS use FULL NAME matching
- Use: (FirstName || ' ' || LastName) = :person_1
- NOT just FirstName = :person_1
Examples:
- Mixed: "SELECT * FROM Customer WHERE (FirstName || ' ' || LastName) = :person_1 AND Country = :country_1"
- Pure PII: "SELECT * FROM Customer WHERE Email = :email_1 AND Phone = :phone_1" """
        # Prompt used when no PII was detected.
        self.no_pii_instruction = f"""Convert natural language to SQLite SQL for the Chinook database.
{CHINOOK_SCHEMA}
Use SQLite syntax with exact table/column names. Use LITERAL VALUES directly in SQL.
Use proper JOINs for multi-table queries.
RULES:
- Use literal string values directly in SQL with proper quotes
- Leave parameters array EMPTY
- Use proper SQL syntax: 'text values', numbers without quotes
Examples:
- Country: "SELECT * FROM Customer WHERE Country = 'Canada'"
- Gmail pattern: "SELECT * FROM Customer WHERE Email LIKE '%@gmail.com'"
- Artist: "SELECT * FROM Album a JOIN Artist ar ON a.ArtistId = ar.ArtistId WHERE ar.Name = 'AC/DC'"
- Number: "SELECT * FROM Track WHERE Milliseconds > 300000"
Parameters: [] (always empty)"""
        self.sql_agent = None
        self.session_service = InMemorySessionService()

    def normalize_binding_format(self, parsed_data: dict, pii_items: list) -> dict:
        """Convert model's natural binding format to expected schema format.

        Returns ``{'sql': ..., 'bindings': {...}}``. List-style bindings are
        re-keyed by PII type using *pii_items*; dict-style bindings are kept
        when every value is a ``<...>`` string or every value is a list;
        any other non-empty dict prints a warning and yields no bindings.
        """
        result = {'sql': parsed_data.get('sql', ''), 'bindings': {}}
        original_bindings = parsed_data.get('bindings', {})

        if isinstance(original_bindings, list):
            pii_by_type = {}
            for item in pii_items:
                pii_type = item.get('entity_type', 'unknown').lower()
                pii_by_type.setdefault(pii_type, []).append(item)
            for position in range(len(original_bindings)):
                # The first PII type that still has an item at this position
                # receives the next numbered placeholder.
                for pii_type, items in pii_by_type.items():
                    if position < len(items):
                        assigned = result['bindings'].setdefault(pii_type, [])
                        assigned.append(
                            f"<{pii_type.upper()}_{len(assigned) + 1}>")
                        break
        elif isinstance(original_bindings, dict):
            values = list(original_bindings.values())
            if not values:
                result['bindings'] = original_bindings
            elif all(isinstance(v, str) and v.startswith('<') and v.endswith('>') for v in values):
                result['bindings'] = original_bindings
            elif all(isinstance(v, list) for v in values):
                result['bindings'] = original_bindings
            else:
                print(
                    f"Warning: Received unrecognized bindings format: {original_bindings}")
        return result

    @staticmethod
    def _first_text(content):
        """Return the stripped text of the first text-bearing part, or ``None``."""
        for part in (getattr(content, 'parts', None) or []):
            text = getattr(part, 'text', None)
            if text:
                return text.strip()
        return None

    @staticmethod
    def _parse_structured_json(response_text: str):
        """Strip Markdown code fences from *response_text* and decode it as JSON."""
        # Removes every fence marker, with or without a "json" tag.
        without_fences = re.sub(r'```(?:json)?\s*', '', response_text)
        return json.loads(without_fences.strip())

    async def generate_sql(self, query: str, pii_items: list = None) -> dict:
        """Generate SQL with ADK using dynamic structured output.

        Args:
            query: The (already anonymized) question.
            pii_items: Detected PII entries; a non-empty list selects the
                PII instruction prompt. ``None`` is treated as empty.

        Returns:
            Dict with ``sql``, ``parameters``, ``session_id`` (first eight
            characters of the ADK session id) and ``bindings`` mapping each
            parameter name to its placeholder reference.

        Raises:
            json.JSONDecodeError: The final response was not valid JSON.
            RuntimeError: No usable final response was produced.
        """
        pii_items = pii_items or []
        try:
            _, DynamicStructuredSQLQuery = create_dynamic_sql_models(query)
            print(f"π§ Dynamic Schema for: '{query}'")
            print(f"π Schema Model: {DynamicStructuredSQLQuery.__name__}")
            for field_name, field_info in DynamicStructuredSQLQuery.model_fields.items():
                print(f" - {field_name}: {field_info.annotation}")
            print()

            has_pii = bool(pii_items)
            instruction = self.pii_instruction if has_pii else self.no_pii_instruction
            agent_description = "Generates structured SQL with PII placeholder handling" if has_pii else "Generates structured SQL with literal values"
            print(
                f"π Using {'PII' if has_pii else 'Non-PII'} Instruction Template")

            sql_agent = LlmAgent(
                model="gemini-2.5-flash-lite",
                name="dynamic_sql_agent",
                description=agent_description,
                instruction=instruction,
                output_schema=DynamicStructuredSQLQuery,
                output_key="structured_sql_result",
                tools=[],
                generate_content_config=self.types.GenerateContentConfig(
                    temperature=0
                ),
                disallow_transfer_to_parent=True,
                disallow_transfer_to_peers=True
            )
            runner = Runner(
                agent=sql_agent,
                session_service=self.session_service,
                app_name=self.app_name
            )
            session = await self.session_service.create_session(
                app_name=self.app_name,
                user_id=self.user_id
            )
            content = self.types.Content(
                role='user',
                parts=[self.types.Part(text=f"Convert to SQL: {query}")]
            )

            structured_response = None
            # Use the async event stream: the synchronous ``run`` wrapper
            # would block this coroutine's event loop.
            async for event in runner.run_async(
                user_id=self.user_id,
                session_id=session.id,
                new_message=content
            ):
                if structured_response is not None:
                    # Already have the answer; keep draining so the stream
                    # finishes normally instead of being abandoned.
                    continue
                if not (event.is_final_response() and event.content):
                    continue
                response_text = self._first_text(event.content)
                if response_text is None:
                    continue
                print(f"π ADK Structured Response: {response_text}")
                try:
                    structured_response = self._parse_structured_json(
                        response_text)
                except json.JSONDecodeError as e:
                    print(f"π JSON parsing error: {e}")
                    print(f"π Raw response: {response_text}")
                    raise

            if structured_response:
                result = {
                    'sql': structured_response.get('sql_query', ''),
                    'parameters': structured_response.get('parameters', []),
                    'session_id': session.id[:8]
                }
                bindings = {}
                for param in result.get('parameters', []):
                    param_name = param.get('param_name', '')
                    placeholder_ref = param.get('placeholder_ref', '')
                    if param_name and placeholder_ref:
                        bindings[param_name] = placeholder_ref
                result['bindings'] = bindings
                return result
            raise RuntimeError("ADK produced no structured response")
        except Exception as e:
            print(f"ADK error: {str(e)}")
            raise

    def generate_sql_sync(self, query: str, pii_items: list = None) -> dict:
        """Run ``generate_sql`` to completion on a new event loop."""
        return asyncio.run(self.generate_sql(query, pii_items))

    def validate_sql(self, sql: str, bindings: dict = None) -> dict:
        """Check that *sql* prepares (and, for SELECT, runs) against ``chinook.db``.

        Every name in *bindings* is bound to the dummy value ``'test_value'``;
        the mapped placeholders themselves are not used. SELECT statements
        without a LIMIT get ``LIMIT 3`` appended before being executed.

        Returns:
            ``{'valid': True, 'row_count': n}`` where *n* is the number of
            rows fetched (0 for statements that are only planned), or
            ``{'valid': False, 'error': message}``.
        """
        if not sql or sql.startswith('--'):
            return {'valid': False, 'error': 'Invalid SQL'}

        clean_sql = sql.strip().rstrip(';')
        dummy_bindings = {
            name.lstrip(':'): 'test_value' for name in (bindings or {})
        }
        row_count = 0
        try:
            # download from https://www.kaggle.com/datasets/nancyalaswad90/chinook-sample-database
            # Opened read-only: the SQL is LLM-generated, and a missing
            # file should be an error rather than a newly created empty DB.
            conn = sqlite3.connect("file:chinook.db?mode=ro", uri=True)
            try:
                cursor = conn.cursor()
                cursor.execute(
                    f"EXPLAIN QUERY PLAN {clean_sql}", dummy_bindings)
                if clean_sql.upper().startswith('SELECT'):
                    if 'LIMIT' not in clean_sql.upper():
                        clean_sql += ' LIMIT 3'
                    cursor.execute(clean_sql, dummy_bindings)
                    row_count = len(cursor.fetchall())
            finally:
                conn.close()
        except Exception as e:
            return {'valid': False, 'error': str(e)}
        return {'valid': True, 'row_count': row_count}

    def process_query(self, query: str) -> dict:
        """Process query with PII protection and SQL generation.

        Returns a dict with the original and anonymized question, the SQL
        and its bindings, the PII and literal (non-PII) parameters, the
        validation outcome, ``confidence`` (0.0 unless the generator
        supplies one) and ``execution_time`` in milliseconds.
        """
        start_time = time.time()
        anonymized, pii_entities = self.pii_detector.pseudonymize(query)

        # Report the placeholder the detector actually assigned; it is not
        # always "_1" because counters continue across stored mappings.
        known_mappings = self.pii_detector.get_mapping_storage()
        pii_items = []
        for entity in pii_entities:
            entity_type = entity['entity_type']
            mapping = known_mappings.get(f"{entity['value']}_{entity_type}")
            if mapping:
                pseudonym = mapping['pseudonym']
            else:
                pseudonym = f"<{entity_type.upper()}_1>"
            pii_items.append({
                'original': entity['value'],
                'pseudonym': pseudonym,
                'type': entity_type
            })

        sql_result = self.generate_sql_sync(anonymized, pii_items)

        # Extract non-PII parameters from the SQL result
        non_pii_items = []
        if 'parameters' in sql_result:
            pii_placeholders = {item['pseudonym'] for item in pii_items}
            for param in sql_result['parameters']:
                placeholder_ref = param.get('placeholder_ref', '')
                # If placeholder_ref is not a PII placeholder, it's a literal value
                if placeholder_ref not in pii_placeholders and not placeholder_ref.startswith('<'):
                    non_pii_items.append({
                        'param_name': param.get('param_name', ''),
                        'value': placeholder_ref,
                        'type': param.get('param_type', 'string')
                    })

        validation = self.validate_sql(
            sql_result.get('sql', ''), sql_result.get('bindings', {}))
        return {
            'query': query,
            'anonymized': anonymized,
            'sql': sql_result.get('sql', ''),
            'bindings': sql_result.get('bindings', {}),
            'pii_items': pii_items,
            'non_pii_items': non_pii_items,
            'valid': validation['valid'],
            'error': validation.get('error', None),
            'confidence': sql_result.get('confidence', 0.0),
            'execution_time': round((time.time() - start_time) * 1000, 1)
        }
def main():
    """Run the demo queries through ``TextToSQLADK`` and print a summary.

    Reads ``GOOGLE_API_KEY`` from the environment (``KeyError`` if unset),
    processes each query in ``test_queries``, prints the per-query result
    and finishes with aggregate counts and a pass/fail line (pass means
    more than 80% of the queries validated).
    """
    print("=" * 70)
    print("π GOOGLE ADK TEXT-TO-SQL TEST")
    print("With PII protection and session management")
    print("=" * 70)
    adk = TextToSQLADK(os.environ["GOOGLE_API_KEY"])
    print("β Final ADK implementation ready")

    # NOTE: the e-mail addresses below are example.com stand-ins; the
    # addresses in this list had been reduced to an "[email protected]"
    # obfuscation token, which contains no e-mail to detect.
    test_queries = [
        # PII with mixed literal values - tests Option 2 parameterization
        # "Find customer John Smith in Australia",
        # "Find customers named Bob Wilson with email bob.wilson@example.com in Australia",
        # Pure PII queries
        # "Find customer with email customer@example.com",
        # "Show customers with phone number 555-999-8888",
        # "List invoices for customer Mary Davis with email mary.davis@example.com and phone 555-123-4567",
        # "Find customers named Jane Smith with phone 555-444-3333 in location Paris",
        # "Show all customers with emails first@example.com and second@example.com",
        # Pure non-PII queries
        # "Show albums by AC/DC",
        # "Find all customers in Canada",
        # "List tracks longer than 300 seconds",
        # "Show albums released in year 2000",
        # "Find genres with more than 100 tracks",
        # Complex mixed scenarios - Multi-PII + literals
        "Find customers named Sarah Johnson in Canada who purchased rock albums",
        "List customers named John Doe in USA with total invoice amount greater than 1000",
        "List customers with phone 555-111-2222 or email test@example.com in France",
        "Show customers named Michael Johnson in Australia who bought jazz albums costing more than 100"
    ]
    print(f"Testing {len(test_queries)} queries...")
    print()

    results = []
    for i, query in enumerate(test_queries, 1):
        print("=" * 70)
        print(f"TEST {i}: {query}")
        print("=" * 70)
        result = adk.process_query(query)
        results.append(result)
        print(f"Anonymized: {result['anonymized']}")
        print(f"SQL: {result['sql']}")
        print(f"PII ({len(result['pii_items'])}):")
        for pii in result['pii_items']:
            print(f" {pii['original']} β {pii['pseudonym']} ({pii['type']})")
        if result.get('non_pii_items'):
            print(f"Non-PII ({len(result['non_pii_items'])}):")
            for item in result['non_pii_items']:
                print(
                    f" {item['param_name']} β {item['value']} ({item['type']})")
        print(f"Valid SQL: {'β ' if result['valid'] else 'β'}")
        if result['error']:
            print(f"Error: {result['error']}")
        print(f"Time: {result['execution_time']}ms")
        print()

    print("=" * 70)
    print("π― FINAL SUMMARY")
    print("=" * 70)
    total_tests = len(results)
    valid_sql = sum(1 for r in results if r['valid'])
    # Both ratios are guarded so an empty query list prints zeros instead
    # of raising ZeroDivisionError.
    avg_confidence = sum(r['confidence'] for r in results) / \
        total_tests if total_tests > 0 else 0
    valid_percent = valid_sql / total_tests * 100 if total_tests > 0 else 0
    total_pii = sum(len(r['pii_items']) for r in results)
    total_non_pii = sum(len(r.get('non_pii_items', [])) for r in results)
    print(f"Tests: {total_tests}")
    print(
        f"Valid SQL: {valid_sql}/{total_tests} ({valid_percent:.0f}%)")
    print(f"Avg Confidence: {avg_confidence:.2f}")
    print(f"Total PII Protected: {total_pii}")
    print(f"Total Non-PII Parameters: {total_non_pii}")
    print()
    print("π Google ADK Status:")
    if valid_sql == total_tests:
        print(" β PERFECT: All SQL queries valid")
    elif valid_sql > 0:
        print(" β οΈ PARTIAL: ADK running but SQL needs improvement")
    else:
        print(" β ISSUES: ADK running but SQL validation failing")
    success = valid_sql > total_tests * 0.8
    print(f"\n{'β PASSED' if success else 'β FAILED'}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment