Extending KnowledgeStorage from Chroma -> Qdrant example
The first file defines QdrantStorage, which swaps CrewAI's default Chroma-backed KnowledgeStorage for Qdrant. It relies on the qdrant-client package and its fastembed integration, so documents and queries are embedded locally when saved and searched.
from typing import Any, Dict, List, Optional, Union

from qdrant_client import QdrantClient

from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage


class QdrantStorage(KnowledgeStorage):
    """Extends KnowledgeStorage to store and search knowledge embeddings in Qdrant."""

    def __init__(
        self,
        collection_name: Optional[str] = None,
        storage_path: Optional[str] = None,
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
    ):
        # Namespace the collection so knowledge entries don't collide with other data.
        self.collection_name = (
            f"knowledge_{collection_name}" if collection_name else "knowledge"
        )
        self.initialize_knowledge_storage(storage_path, qdrant_url, qdrant_api_key)

    def search(
        self,
        query: List[str],
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Dict[str, Any]]:
        # Uses Qdrant's fastembed integration, which embeds the query text locally.
        points = self.client.query(
            self.collection_name,
            query_text=query,
            query_filter=filter,
            limit=limit,
            score_threshold=score_threshold,
        )
        results = [
            {
                "id": point.id,
                "metadata": point.metadata,
                "context": point.document,
                "score": point.score,
            }
            for point in points
        ]
        return results

    def initialize_knowledge_storage(
        self,
        storage_path: Optional[str] = None,
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
    ):
        # Prefer a remote Qdrant server, then local on-disk storage, then in-memory.
        if qdrant_url and qdrant_api_key:
            self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        elif storage_path:
            self.client = QdrantClient(path=storage_path)
        else:
            self.client = QdrantClient(":memory:")

        if not self.client.collection_exists(self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=self.client.get_fastembed_vector_params(),
            )

    def reset(self) -> None:
        self.client.delete_collection(self.collection_name)

    def save(
        self,
        documents: List[str],
        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
    ) -> None:
        # Normalize metadata so there is one dict per document.
        if isinstance(metadata, dict):
            metadata = [metadata] * len(documents)
        elif metadata is None:
            metadata = [{}] * len(documents)
        try:
            print("Adding documents to Qdrant")
            self.client.add(
                self.collection_name, documents=documents, metadata=metadata
            )
        except Exception as e:
            print(f"Error adding documents to Qdrant: {e}")
            raise e
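Before wiring this class into an agent, it can be exercised on its own. The snippet below is a minimal, hypothetical sketch (not part of the gist): it assumes qdrant-client is installed with the fastembed extra, since client.add and client.query embed text locally, and it falls back to the in-memory client because no storage path or URL is passed. The collection name and sample documents are made up for illustration.

# Hypothetical standalone check of QdrantStorage (not part of the original gist).
# Assumes: pip install "qdrant-client[fastembed]" so add()/query() can embed locally.
storage = QdrantStorage(collection_name="demo")  # no path/URL given -> in-memory client

storage.save(
    documents=[
        "Qdrant is an open-source vector database.",
        "CrewAI agents can use custom knowledge storage backends.",
    ],
    metadata={"source": "example"},
)

# search() mirrors CrewAI's KnowledgeStorage interface and takes a list of query strings.
for hit in storage.search(query=["What is Qdrant?"], limit=2):
    print(hit["score"], hit["context"])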
The second file is a usage example. It defines RegexTextKnowledgeSource, a custom knowledge source that chunks text with a regex pattern, loads it with a passage about RLHF reward hacking, and attaches both the source and the QdrantStorage above to a CrewAI agent that answers a question about the text.
import os
import re
from typing import List

from dotenv import load_dotenv

from crewai import Agent, Crew, Process, Task
from crewai.agents.qdrant_knowledge_storage import QdrantStorage
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource

load_dotenv()


# CUSTOM KNOWLEDGE SOURCE
class RegexTextKnowledgeSource(BaseKnowledgeSource):
    """A knowledge source that chunks text using regex patterns."""

    text: str
    chunk_pattern: str = r"\n\n+"  # Default splits on multiple newlines

    def validate_content(self) -> str:
        """Validate and return the text content.

        Returns:
            The validated text content.
        """
        if not isinstance(self.text, str):
            raise ValueError("Text content must be a string")
        return self.text

    def add(self) -> None:
        """Process content, chunk it, and save to storage if configured."""
        validated_text = self.validate_content()
        self.chunks = self._chunk_content(validated_text)
        if self.storage:
            self._save_documents()

    def _chunk_content(self, text: str) -> List[str]:
        """Chunk the text content using the regex pattern.

        Args:
            text: The text to chunk.

        Returns:
            List of text chunks.
        """
        # First split using the regex pattern.
        chunks = re.split(self.chunk_pattern, text)
        # Filter out empty chunks and normalize whitespace.
        chunks = [chunk.strip() for chunk in chunks if chunk.strip()]
        # Further process chunks to respect size constraints.
        processed_chunks = []
        for chunk in chunks:
            if len(chunk) <= self.chunk_size:
                processed_chunks.append(chunk)
            else:
                # Split long chunks while maintaining overlap.
                start = 0
                while start < len(chunk):
                    end = start + self.chunk_size
                    if end < len(chunk):
                        # Prefer to break on the last period or space before the size limit.
                        last_period = chunk.rfind(".", start, end)
                        last_space = chunk.rfind(" ", start, end)
                        break_point = max(last_period, last_space)
                        if break_point == -1:
                            break_point = end
                        processed_chunks.append(chunk[start:break_point])
                        # Step back by the overlap, but always move forward to avoid an infinite loop.
                        start = max(break_point - self.chunk_overlap, start + 1)
                    else:
                        processed_chunks.append(chunk[start:])
                        break
        # Filter out chunks that are too small (minimum 100 characters).
        return [c for c in processed_chunks if len(c) >= 100]


# Create a knowledge source
string_source = RegexTextKnowledgeSource(
    text="""
Reward hacking occurs when a reinforcement learning (RL) agent exploits flaws or ambiguities in the reward function to achieve high rewards, without genuinely learning or completing the intended task. Reward hacking exists because RL environments are often imperfect, and it is fundamentally challenging to accurately specify a reward function.

With language models generalizing to a broad spectrum of tasks and RLHF becoming a de facto method for alignment training, reward hacking in RL training of language models has become a critical practical challenge. Instances where the model learns to modify unit tests to pass coding tasks, or where responses contain biases that mimic a user's preference, are concerning and are likely one of the major blockers for real-world deployment of more autonomous use cases of AI models.

Most of the past work on this topic has been quite theoretical and focused on defining or demonstrating the existence of reward hacking. However, research into practical mitigations, especially in the context of RLHF and LLMs, remains limited. I especially want to call out for more research efforts directed toward understanding and developing mitigations for reward hacking in the future. I hope to cover the mitigation part in a dedicated post soon.

Background

Reward Function in RL

The reward function defines the task, and reward shaping significantly impacts learning efficiency and accuracy in reinforcement learning. Designing a reward function for an RL task often feels like a 'dark art'. Many factors contribute to this complexity: How do you decompose a big goal into small goals? Is the reward sparse or dense? How do you measure success? Various choices may lead to good or problematic learning dynamics, including unlearnable tasks or hackable reward functions. There is a long history of research on how to do reward shaping in RL.

For example, in a 1999 paper by Ng et al., the authors studied how to modify the reward function in Markov Decision Processes (MDPs) such that the optimal policy remains unchanged. They found that a linear transformation works. Given an MDP M = (S, A, T, γ, R), we want to create a transformed MDP M' = (S, A, T, γ, R') where R' = R + F and F: S × A × S → ℝ, such that we can guide the learning algorithm to be more efficient. Given a real-valued function Φ: S → ℝ, F is a potential-based shaping function if for all s ∈ S, a ∈ A, s' ∈ S:

F(s, a, s') = γΦ(s') - Φ(s)

This guarantees that the sum of discounted F terms, F(s_t, a_t, s_{t+1}) + γF(s_{t+1}, a_{t+1}, s_{t+2}) + ..., ends up being 0. If F is such a potential-based shaping function, it is both sufficient and necessary to ensure M and M' share the same optimal policies.
""",
)  # type: ignore

agent = Agent(
    role="About Blog",
    goal="You know everything about the text. Ensure the {question} gets passed to the rag_tool as the query and get the answer.",
    backstory="""You are a master at understanding text and its details.""",
    verbose=True,
    allow_delegation=False,
    knowledge_sources=[string_source],
    knowledge_storage=QdrantStorage(
        storage_path="qdrant_store_example",
    ),
    llm="gpt-4o",
)

task = Task(
    description="Answer the following question about the text: {question}.",
    expected_output="An answer to the question",
    agent=agent,
    output_file="output.txt",
)

crew = Crew(
    agents=[agent],
    tasks=[task],
    verbose=True,
    process=Process.sequential,
)

result = crew.kickoff(
    inputs={
        "question": "What is RLHF reward hacking?",
    },
)
print("result", result)