@futurisold · Last active March 2, 2025
(e.g.) symai contract
"""
Ontology-based Triplet Extraction
---------------------------------
This module implements a system for extracting semantic triplets from text
based on a predefined ontology schema.
"""
# Third-party imports
from tqdm import tqdm
from loguru import logger
from pydantic import Field
from tokenizers import Tokenizer
# SymbolicAI imports
from symai import Expression, Symbol
from symai.components import FileReader, MetadataTracker
from symai.models import LLMDataModel
from symai.strategy import contract
# Chonkie imports
from chonkie import (
    RecursiveChunker,
    SDPMChunker,
    SemanticChunker,
    SentenceChunker,
    TokenChunker,
)
from chonkie.embeddings.base import BaseEmbeddings
# Constants
DEFAULT_TOKENIZER = "gpt2"
DEFAULT_EMBEDDING_MODEL = "minishlab/potion-base-8M"
DEFAULT_CONFIDENCE_THRESHOLD = 0.7
CHUNKER_MAPPING = {
    "TokenChunker": TokenChunker,
    "SentenceChunker": SentenceChunker,
    "RecursiveChunker": RecursiveChunker,
    "SemanticChunker": SemanticChunker,
    "SDPMChunker": SDPMChunker,
}
class ChonkieChunker(Expression):
    """A text chunking utility that supports multiple chunking strategies.

    This class provides various text chunking methods including token-based,
    sentence-based, recursive, and semantic chunking approaches.

    Args:
        tokenizer_name (str): Name of the tokenizer to use
        embedding_model_name (str | BaseEmbeddings): Name or instance of embedding model
        **symai_kwargs: Additional kwargs for the Expression base class
    """
    def __init__(
        self,
        tokenizer_name: str = DEFAULT_TOKENIZER,
        embedding_model_name: str | BaseEmbeddings = DEFAULT_EMBEDDING_MODEL,
        **symai_kwargs,
    ):
        super().__init__(**symai_kwargs)
        self.tokenizer_name = tokenizer_name
        self.embedding_model_name = embedding_model_name

    def forward(self, data: Symbol[str | list[str]], chunker_name: str = "RecursiveChunker", **chunker_kwargs) -> Symbol[list[str]]:
        chunker = self._resolve_chunker(chunker_name, **chunker_kwargs)
        chunks = [self._clean_text(chunk.text) for chunk in chunker(data.value)]
        return self._to_symbol(chunks)

    def _resolve_chunker(self, chunker_name: str, **chunker_kwargs) -> TokenChunker | SentenceChunker | RecursiveChunker | SemanticChunker | SDPMChunker:
        if chunker_name in ["TokenChunker", "SentenceChunker", "RecursiveChunker"]:
            tokenizer = Tokenizer.from_pretrained(self.tokenizer_name)
            return CHUNKER_MAPPING[chunker_name](tokenizer, **chunker_kwargs)
        elif chunker_name in ["SemanticChunker", "SDPMChunker"]:
            return CHUNKER_MAPPING[chunker_name](embedding_model=self.embedding_model_name, **chunker_kwargs)
        else:
            raise ValueError(
                f"Chunker {chunker_name} not found. Available chunkers: {list(CHUNKER_MAPPING.keys())}. "
                "See docs (https://docs.chonkie.ai/getting-started/introduction) for more info."
            )

    def _clean_text(self, text: str) -> str:
        """Cleans text by removing problematic characters."""
        text = text.replace('\x00', '')  # Remove null bytes (\x00)
        text = text.encode('utf-8', errors='ignore').decode('utf-8')  # Drop invalid UTF-8 sequences
        return text
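# Example (hypothetical usage sketch, not part of the original gist): chunking a
# raw string with the default RecursiveChunker. `Symbol` wraps the input as
# `forward` expects, and `chunk_size` is forwarded to the underlying chunker.
#
#   chunker = ChonkieChunker()
#   chunks = chunker(data=Symbol("Some long document text ..."), chunk_size=256)
#   print(chunks.value)  # -> list[str] of cleaned chunks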
# Data Models
class Entity(LLMDataModel):
    """Represents an entity in the ontology"""
    name: str = Field(description="Name of the entity")
    type: str = Field(description="Type/category of the entity")

class Relationship(LLMDataModel):
    """Represents a relationship type in the ontology"""
    name: str = Field(description="Name of the relationship")

class OntologySchema(LLMDataModel):
    """Defines the ontology schema with allowed entities and relationships"""
    entities: list[Entity] = Field(description="List of valid entity types")
    relationships: list[Relationship] = Field(description="List of valid relationship types")

class TripletInput(LLMDataModel):
    """Input for triplet extraction"""
    text: str = Field(description="Text to extract triplets from")
    ontology: OntologySchema = Field(description="Ontology schema to use for extraction")

class Triplet(LLMDataModel):
    """A semantic triplet with typed entities and relationship"""
    subject: Entity
    predicate: Relationship
    object: Entity
    confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="Confidence score for the extracted triplet [0, 1]")
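# Example (illustrative sketch): how the models above would encode the statement
# "X Corp. is located in San Francisco"; the entity names are hypothetical, the
# types and relationship come from the sample ontology defined further below:
#
#   Triplet(
#       subject=Entity(name="X Corp.", type="ORG"),
#       predicate=Relationship(name="located_in"),
#       object=Entity(name="San Francisco", type="LOC"),
#       confidence=0.95,
#   )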
class TripletOutput(LLMDataModel):
    """Collection of extracted triplets forming a knowledge graph"""
    triplets: list[Triplet] | None = Field(default=None, description="List of extracted triplets")
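# The @contract decorator below wraps the class in SymbolicAI's design-by-contract
# flow: the typed input/output models constrain the LLM call, `pre`/`post` act as
# validation guards, and (with post_remedy=True) a failing `post` check triggers a
# remedy loop governed by `remedy_retry_params`; graceful=False means a final
# failure raises instead of passing silently.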
@contract(
    pre_remedy=False,
    post_remedy=True,
    verbose=True,
    remedy_retry_params=dict(
        tries=1,
        delay=0.5,
        max_delay=15,
        jitter=0.1,
        backoff=2,
        graceful=False
    )
)
class OntologyTripletExtractor(Expression):
    """Extracts typed triplets according to an ontology schema"""
    def __init__(self, threshold: float = DEFAULT_CONFIDENCE_THRESHOLD, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def forward(self, input: TripletInput, **kwargs) -> TripletOutput:
        if self.contract_result is None:
            return TripletOutput(triplets=None)
        return self.contract_result

    def pre(self, input: TripletInput) -> bool:
        # No semantic validation for now
        return True

    def post(self, output: TripletOutput) -> bool:
        if output.triplets is None:
            return True
        for triplet in output.triplets:
            if triplet.confidence < self.threshold:
                raise ValueError(
                    f"Confidence score {triplet.confidence} must be at or above the threshold {self.threshold}! "
                    "Extract relationships between entities that are meaningful and relevant."
                )
        return True

    @property
    def prompt(self) -> str:
        return (
            "You are an expert at extracting semantic relationships from text according to ontology schemas. "
            "For the given text and ontology:\n"
            "1. Identify entities matching the allowed entity types\n"
            "2. Extract relationships between entities matching the defined relationship types\n"
            "3. Assign confidence scores based on certainty of extraction\n"
            "4. Ensure all entity and relationship types conform to the ontology\n"
            "5. Do not duplicate triplets\n"
            "6. If triplets can't be found, default to None"
        )
def create_sample_ontology() -> OntologySchema:
    """Creates a sample ontology schema for demonstration purposes."""
    return OntologySchema(
        entities=[
            # People and Organizations
            Entity(name="Person", type="PERSON"),
            Entity(name="Organization", type="ORG"),
            Entity(name="Location", type="LOC"),
            # Legal Entities
            Entity(name="Agreement", type="AGREEMENT"),
            Entity(name="Policy", type="POLICY"),
            Entity(name="Service", type="SERVICE"),
            Entity(name="Feature", type="FEATURE"),
            Entity(name="Right", type="RIGHT"),
            Entity(name="Obligation", type="OBLIGATION"),
            # Data Related
            Entity(name="PersonalData", type="PERSONAL_DATA"),
            Entity(name="DataCategory", type="DATA_CATEGORY"),
            Entity(name="DataProcessor", type="DATA_PROCESSOR"),
            # Time and Events
            Entity(name="Date", type="DATE"),
            Entity(name="Event", type="EVENT"),
            # Financial
            Entity(name="Payment", type="PAYMENT"),
            Entity(name="Currency", type="CURRENCY")
        ],
        relationships=[
            # Organizational Relations
            Relationship(name="works_for"),
            Relationship(name="located_in"),
            Relationship(name="owns"),
            Relationship(name="operates"),
            # Legal Relations
            Relationship(name="governs"),
            Relationship(name="requires"),
            Relationship(name="prohibits"),
            Relationship(name="permits"),
            Relationship(name="provides"),
            # Data Relations
            Relationship(name="processes"),
            Relationship(name="collects"),
            Relationship(name="stores"),
            Relationship(name="shares"),
            Relationship(name="transfers"),
            # Temporal Relations
            Relationship(name="starts_on"),
            Relationship(name="ends_on"),
            Relationship(name="modified_on"),
            # Financial Relations
            Relationship(name="charges"),
            Relationship(name="pays"),
            Relationship(name="costs")
        ]
    )
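# Example (hypothetical sketch): running the extractor on a single snippet with
# the sample ontology; requires a configured SymbolicAI LLM backend. The text
# and threshold here are illustrative, not from the original gist.
#
#   extractor = OntologyTripletExtractor(threshold=0.8)
#   out = extractor(input=TripletInput(
#       text="X Corp. is headquartered in San Francisco.",
#       ontology=create_sample_ontology(),
#   ))
#   print(out.triplets)  # -> list[Triplet], or None if nothing passed `post`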
def main():
    # Initialize components
    # NOTE: the original passed tokenizer_name to OntologyTripletExtractor, which
    # does not use it; it belongs to the chunker (assumed fix).
    reader = FileReader()
    chunker = ChonkieChunker(tokenizer_name="Xenova/gpt-4o")
    extractor = OntologyTripletExtractor()

    # Load and process sample text
    sample_text = reader("x-terms-of-service-2024-11-15.pdf")
    chunks = chunker(data=Symbol(sample_text[0]), chunk_size=1024).value

    # Extract triplets
    triplets = []
    usage = None
    with MetadataTracker() as tracker:
        for chunk in tqdm(chunks, desc="Processing chunks"):
            input_data = TripletInput(
                text=chunk,
                ontology=create_sample_ontology()
            )
            try:
                result = extractor(input=input_data)
                if result.triplets:
                    triplets.extend(result.triplets)
            except Exception as e:
                logger.error(f"Error extracting triplets from chunk: {chunk}")
                logger.error(f"Error message: {str(e)}")
        # Report contract statistics and the token usage collected by the tracker
        extractor.contract_perf_stats()
        try:
            usage = tracker.usage
        except Exception as e:
            logger.error(f"Error getting usage information: {str(e)}")

    # Display results
    for triplet in triplets:
        if triplet:
            logger.info(f"\n-------\n{triplet}\n-------\n")
    logger.info(f"\nAPI Usage:\n{usage}")
    logger.info("\nExtraction Completed!\n")

if __name__ == "__main__":
    main()
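# To run this script you would need (assumption: current PyPI package names):
#   pip install symbolicai chonkie loguru tqdm tokenizers pydantic
# plus a configured LLM backend for SymbolicAI and the PDF referenced in main().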