(e.g.) symai contract
| """ | |
| Ontology-based Triplet Extraction | |
| --------------------------------- | |
| This module implements a system for extracting semantic triplets from text | |
| based on a predefined ontology schema. | |
| """ | |
| # Third-party imports | |
| from tqdm import tqdm | |
| from loguru import logger | |
| from pydantic import Field | |
| from tokenizers import Tokenizer | |
| # SymbolicAI imports | |
| from symai import Expression, Symbol | |
| from symai.components import FileReader, MetadataTracker | |
| from symai.models import LLMDataModel | |
| from symai.strategy import contract | |
| # Chonkie imports | |
| from chonkie import ( | |
| RecursiveChunker, | |
| SDPMChunker, | |
| SemanticChunker, | |
| SentenceChunker, | |
| TokenChunker | |
| ) | |
| from chonkie.embeddings.base import BaseEmbeddings | |
| # Constants | |
| DEFAULT_TOKENIZER = "gpt2" | |
| DEFAULT_EMBEDDING_MODEL = "minishlab/potion-base-8M" | |
| DEFAULT_CONFIDENCE_THRESHOLD = 0.7 | |
| CHUNKER_MAPPING = { | |
| "TokenChunker": TokenChunker, | |
| "SentenceChunker": SentenceChunker, | |
| "RecursiveChunker": RecursiveChunker, | |
| "SemanticChunker": SemanticChunker, | |
| "SDPMChunker": SDPMChunker, | |
| } | |


class ChonkieChunker(Expression):
    """A text chunking utility that supports multiple chunking strategies.

    This class provides various text chunking methods, including token-based,
    sentence-based, recursive, and semantic chunking approaches.

    Args:
        tokenizer_name (str): Name of the tokenizer to use
        embedding_model_name (str | BaseEmbeddings): Name or instance of the embedding model
        **symai_kwargs: Additional kwargs for the Expression base class
    """

    def __init__(
        self,
        tokenizer_name: str = DEFAULT_TOKENIZER,
        embedding_model_name: str | BaseEmbeddings = DEFAULT_EMBEDDING_MODEL,
        **symai_kwargs,
    ):
        super().__init__(**symai_kwargs)
        self.tokenizer_name = tokenizer_name
        self.embedding_model_name = embedding_model_name

    def forward(self, data: Symbol[str | list[str]], chunker_name: str = "RecursiveChunker", **chunker_kwargs) -> Symbol[list[str]]:
        chunker = self._resolve_chunker(chunker_name, **chunker_kwargs)
        chunks = [self._clean_text(chunk.text) for chunk in chunker(data.value)]
        return self._to_symbol(chunks)

    def _resolve_chunker(self, chunker_name: str, **chunker_kwargs) -> TokenChunker | SentenceChunker | RecursiveChunker | SemanticChunker | SDPMChunker:
        # Token-, sentence-, and recursive chunkers take a tokenizer; the
        # semantic variants take an embedding model instead.
        if chunker_name in ["TokenChunker", "SentenceChunker", "RecursiveChunker"]:
            tokenizer = Tokenizer.from_pretrained(self.tokenizer_name)
            return CHUNKER_MAPPING[chunker_name](tokenizer, **chunker_kwargs)
        elif chunker_name in ["SemanticChunker", "SDPMChunker"]:
            return CHUNKER_MAPPING[chunker_name](embedding_model=self.embedding_model_name, **chunker_kwargs)
        else:
            raise ValueError(f"Chunker {chunker_name} not found. Available chunkers: {list(CHUNKER_MAPPING.keys())}. See the docs (https://docs.chonkie.ai/getting-started/introduction) for more info.")

    def _clean_text(self, text: str) -> str:
        """Cleans text by removing problematic characters."""
        text = text.replace('\x00', '')  # Remove null bytes (\x00)
        text = text.encode('utf-8', errors='ignore').decode('utf-8')  # Drop invalid UTF-8 sequences
        return text
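
# Illustrative usage of ChonkieChunker (a minimal sketch; the sample string and
# chunk size below are placeholders, not part of this module):
#
#   chunker = ChonkieChunker(tokenizer_name="gpt2")
#   chunks = chunker(
#       data=Symbol("A long document to split..."),
#       chunker_name="SentenceChunker",
#       chunk_size=512,
#   )
#   print(chunks.value)  # -> list[str] of cleaned chunk texts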


# Data Models
class Entity(LLMDataModel):
    """Represents an entity in the ontology."""
    name: str = Field(description="Name of the entity")
    type: str = Field(description="Type/category of the entity")


class Relationship(LLMDataModel):
    """Represents a relationship type in the ontology."""
    name: str = Field(description="Name of the relationship")


class OntologySchema(LLMDataModel):
    """Defines the ontology schema with allowed entities and relationships."""
    entities: list[Entity] = Field(description="List of valid entity types")
    relationships: list[Relationship] = Field(description="List of valid relationship types")


class TripletInput(LLMDataModel):
    """Input for triplet extraction."""
    text: str = Field(description="Text to extract triplets from")
    ontology: OntologySchema = Field(description="Ontology schema to use for extraction")


class Triplet(LLMDataModel):
    """A semantic triplet with typed entities and relationship."""
    subject: Entity
    predicate: Relationship
    object: Entity
    confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="Confidence score for the extracted triplet [0, 1]")


class TripletOutput(LLMDataModel):
    """Collection of extracted triplets forming a knowledge graph."""
    triplets: list[Triplet] | None = Field(default=None, description="List of extracted triplets")
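
# Illustrative shape of one extracted triplet (hypothetical values, shown only
# to make clear how the data models above compose):
#
#   Triplet(
#       subject=Entity(name="X Corp.", type="ORG"),
#       predicate=Relationship(name="collects"),
#       object=Entity(name="usage data", type="PERSONAL_DATA"),
#       confidence=0.92,
#   )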


@contract(
    pre_remedy=False,
    post_remedy=True,
    verbose=True,
    remedy_retry_params=dict(
        tries=1,
        delay=0.5,
        max_delay=15,
        jitter=0.1,
        backoff=2,
        graceful=False
    )
)
class OntologyTripletExtractor(Expression):
    """Extracts typed triplets according to an ontology schema."""

    def __init__(self, threshold: float = DEFAULT_CONFIDENCE_THRESHOLD, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def forward(self, input: TripletInput, **kwargs) -> TripletOutput:
        # The contract stores the validated LLM result in `contract_result`;
        # fall back to an empty output if validation did not succeed.
        if self.contract_result is None:
            return TripletOutput(triplets=None)
        return self.contract_result

    def pre(self, input: TripletInput) -> bool:
        # No semantic validation of the input for now
        return True

    def post(self, output: TripletOutput) -> bool:
        if output.triplets is None:
            return True
        for triplet in output.triplets:
            if triplet.confidence < self.threshold:
                raise ValueError(f"Confidence score {triplet.confidence} must meet the threshold {self.threshold}! Extract only relationships between entities that are meaningful and relevant.")
        return True

    @property
    def prompt(self) -> str:
        return (
            "You are an expert at extracting semantic relationships from text according to ontology schemas. "
            "For the given text and ontology:\n"
            "1. Identify entities matching the allowed entity types\n"
            "2. Extract relationships between entities matching the defined relationship types\n"
            "3. Assign confidence scores based on certainty of extraction\n"
            "4. Ensure all entity and relationship types conform to the ontology\n"
            "5. Do not duplicate triplets\n"
            "6. If no triplets can be found, default to None"
        )
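
# Illustrative call flow (a sketch; the contract decorator runs `pre`, the LLM
# call, and `post` with remedies before `forward` returns; the input text
# below is a placeholder):
#
#   extractor = OntologyTripletExtractor(threshold=0.7)
#   output = extractor(input=TripletInput(text="...", ontology=create_sample_ontology()))
#   if output.triplets:
#       for t in output.triplets:
#           print(t.subject.name, t.predicate.name, t.object.name)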


def create_sample_ontology() -> OntologySchema:
    """Creates a sample ontology schema for demonstration purposes."""
    return OntologySchema(
        entities=[
            # People and Organizations
            Entity(name="Person", type="PERSON"),
            Entity(name="Organization", type="ORG"),
            Entity(name="Location", type="LOC"),
            # Legal Entities
            Entity(name="Agreement", type="AGREEMENT"),
            Entity(name="Policy", type="POLICY"),
            Entity(name="Service", type="SERVICE"),
            Entity(name="Feature", type="FEATURE"),
            Entity(name="Right", type="RIGHT"),
            Entity(name="Obligation", type="OBLIGATION"),
            # Data Related
            Entity(name="PersonalData", type="PERSONAL_DATA"),
            Entity(name="DataCategory", type="DATA_CATEGORY"),
            Entity(name="DataProcessor", type="DATA_PROCESSOR"),
            # Time and Events
            Entity(name="Date", type="DATE"),
            Entity(name="Event", type="EVENT"),
            # Financial
            Entity(name="Payment", type="PAYMENT"),
            Entity(name="Currency", type="CURRENCY")
        ],
        relationships=[
            # Organizational Relations
            Relationship(name="works_for"),
            Relationship(name="located_in"),
            Relationship(name="owns"),
            Relationship(name="operates"),
            # Legal Relations
            Relationship(name="governs"),
            Relationship(name="requires"),
            Relationship(name="prohibits"),
            Relationship(name="permits"),
            Relationship(name="provides"),
            # Data Relations
            Relationship(name="processes"),
            Relationship(name="collects"),
            Relationship(name="stores"),
            Relationship(name="shares"),
            Relationship(name="transfers"),
            # Temporal Relations
            Relationship(name="starts_on"),
            Relationship(name="ends_on"),
            Relationship(name="modified_on"),
            # Financial Relations
            Relationship(name="charges"),
            Relationship(name="pays"),
            Relationship(name="costs")
        ]
    )


def main():
    # Initialize components; the tokenizer name configures the chunker
    reader = FileReader()
    chunker = ChonkieChunker(tokenizer_name="Xenova/gpt-4o")
    extractor = OntologyTripletExtractor()

    # Load and process sample text
    sample_text = reader("x-terms-of-service-2024-11-15.pdf")
    chunks = chunker(data=Symbol(sample_text[0]), chunk_size=1024).value

    # Extract triplets
    triplets = []
    usage = None
    with MetadataTracker() as tracker:
        for chunk in tqdm(chunks, desc="Processing chunks"):
            input_data = TripletInput(
                text=chunk,
                ontology=create_sample_ontology()
            )
            try:
                result = extractor(input=input_data)
                if result.triplets:
                    triplets.extend(result.triplets)
            except Exception as e:
                logger.error(f"Error extracting triplets from chunk: {chunk}")
                logger.error(f"Error message: {str(e)}")
        extractor.contract_perf_stats()
        try:
            usage = tracker.usage
        except Exception as e:
            logger.error(f"Error getting usage information: {str(e)}")

    # Display results
    for triplet in triplets:
        if triplet:
            logger.info(f"\n-------\n{triplet}\n-------\n")
    logger.info(f"\nAPI Usage:\n{usage}")
    logger.info("\nExtraction Completed!\n")


if __name__ == "__main__":
    main()