Created
July 13, 2024 17:24
-
-
Save isakb/83018fa6487b9993322385e1a182f623 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import logging
import os
import xml.etree.ElementTree as ET

import requests
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
# Requires the OPENAI_API_KEY environment variable (read by ChatOpenAI), e.g.:
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"
_ARXIV_API_URL = "https://export.arxiv.org/api/query"
_ATOM_NS = "{http://www.w3.org/2005/Atom}"
_ARXIV_TIMEOUT_SECONDS = 30
_logger = logging.getLogger(__name__)


def _atom_text(element, tag):
    """Return the whitespace-collapsed text of *element*'s Atom child *tag*, or ""."""
    child = element.find(_ATOM_NS + tag)
    if child is None or not child.text:
        return ""
    # arXiv hard-wraps titles and abstracts across lines; collapse every run of
    # whitespace so downstream consumers (the LLM) see clean single-line text.
    return " ".join(child.text.split())


def _parse_arxiv_entry(entry):
    """Convert one Atom <entry> element into a result dict, or None if it has no title."""
    title = _atom_text(entry, "title")
    if not title:
        return None
    html_link = next(
        (
            link.get("href")
            for link in entry.findall(_ATOM_NS + "link")
            if link.get("type") == "text/html"
        ),
        None,
    )
    return {
        "title": title,
        "summary": _atom_text(entry, "summary"),
        "authors": [
            _atom_text(author, "name")
            for author in entry.findall(_ATOM_NS + "author")
        ],
        # Fall back to the entry's <id> when no text/html link is present,
        # instead of dropping the whole entry.
        "link": html_link or _atom_text(entry, "id"),
    }


@tool()
def arxiv_search(query: str) -> str:
    """Search academic papers on ArXiv.
    Useful for searching academic topics. Note: Papers are not peer reviewed.
    """
    # NOTE: the docstring above doubles as the tool description shown to the
    # LLM, so implementation notes live in these comments instead.
    #
    # Returns a JSON array (indent=2) of {title, summary, authors, link} dicts
    # for up to 10 matches. Raises requests.HTTPError / requests.Timeout on
    # transport failures and ET.ParseError if the response is not valid XML.
    if not query or not query.strip():
        # An empty "all:" query is not meaningful; skip the network round-trip.
        return json.dumps([], indent=2)

    # Let requests URL-encode the query: raw f-string interpolation broke on
    # "&" (new parameter), "#" (fragment) and "+" (decoded as a space).
    response = requests.get(
        _ARXIV_API_URL,
        params={"search_query": f"all:{query}", "start": 0, "max_results": 10},
        timeout=_ARXIV_TIMEOUT_SECONDS,
    )
    response.raise_for_status()
    root = ET.fromstring(response.content)

    results = []
    for entry in root.findall(_ATOM_NS + "entry"):
        parsed = _parse_arxiv_entry(entry)
        if parsed is None:
            _logger.warning("Skipping arXiv entry without a title")
            continue
        results.append(parsed)
    return json.dumps(results, indent=2)
@tool()
def summarize_results(results: str) -> str:
    """Summarize the given search results.
    Useful for summarizing a list of search results. Input should be a JSON string containing search results.
    """
    # NOTE: the docstring above doubles as the tool description shown to the
    # LLM; keep implementation notes in comments.
    #
    # Builds a fresh ChatOpenAI model per call (temperature 0.2) and asks it
    # for a high-level summary of `results`. Returns only the generated text.
    llm = ChatOpenAI(temperature=0.2)
    prompt = ChatPromptTemplate.from_template(
        """You are an intelligent senior research assistant. Analyze the following list of papers and provide a concise summary of the research landscape:
{results}
Provide a high-level summary of the key trends and findings in the research landscape."""
    )
    chain = prompt | llm
    message = chain.invoke({"results": results})
    # `prompt | llm` yields a chat message object; str() of it included
    # repr-style metadata (response_metadata, id, ...). Return just the text.
    content = message.content
    return content if isinstance(content, str) else str(content)
class ResearchAgent:
    """Tool-calling agent that works toward a research goal.

    It is given the arXiv search and result-summarization tools. The goal
    string fills the prompt's ``{input}`` variable (in a system message and
    as a human turn). The executor is capped at ``max_iterations`` steps.
    """

    def __init__(self, initial_goal: str, max_iterations: int = 5):
        self.goal = initial_goal
        self.max_iterations = max_iterations
        self.tools = [arxiv_search, summarize_results]

        chat_messages = [
            ("system", "You are an AI research assistant. Your goal is: {input}"),
            ("system", "You have access to the following tools:\n\n{tool_names}"),
            ("human", "Use the tools to achieve the goal. Think step-by-step and use tools when necessary."),
            ("human", "{input}"),
            # Slot that the tools agent fills with its intermediate tool calls
            # and their results between steps.
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
        agent_prompt = ChatPromptTemplate.from_messages(chat_messages)
        # Temperature 0 for deterministic tool selection.
        chat_model = ChatOpenAI(temperature=0)

        self.agent = create_openai_tools_agent(chat_model, self.tools, agent_prompt)
        self.agent_executor = AgentExecutor(
            agent=self.agent,
            tools=self.tools,
            verbose=True,
            max_iterations=self.max_iterations,
            callbacks=[],
        )

    def run(self):
        """Execute the agent on ``self.goal``.

        The ``{tool_names}`` prompt variable is filled with one
        ``"- name: description"`` line per tool.

        Returns:
            Whatever ``self.agent_executor.invoke`` returns. For a LangChain
            AgentExecutor this is a mapping with the inputs plus the agent's
            ``"output"``.
        """
        tool_listing = "\n".join(
            f"- {registered.name}: {registered.description}"
            for registered in self.tools
        )
        payload = {"input": self.goal, "tool_names": tool_listing}
        return self.agent_executor.invoke(payload)
def invoke_research_agent(initial_goal: str, max_iterations: int = 5) -> str:
    """Run a research agent on *initial_goal* and return its final answer.

    A fresh ``ResearchAgent`` is built for every call. It may use the arXiv
    search and summarization tools for at most *max_iterations* agent steps.
    The goal itself is passed through unchanged and is not rewritten between
    steps.

    Args:
        initial_goal (str): The research goal given to the agent.
        max_iterations (int): Maximum number of agent steps before the
            executor stops.

    Returns:
        str: The agent's final answer (the executor result's ``"output"``).
    """
    result = ResearchAgent(initial_goal, max_iterations).run()
    # run() hands back the executor's whole result mapping (inputs echoed
    # back plus the answer); the declared return type is str, so surface only
    # the answer text.
    return result["output"]
# Script entry point: run a single research session on a sample goal and
# print whatever the agent returns.
if __name__ == "__main__":
    research_goal = "Research the latest about how to improve RAG performance for legal texts."
    outcome = invoke_research_agent(research_goal, max_iterations=5)
    print(f"Final Result: {outcome}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment