Created
September 2, 2025 22:43
-
-
Save robbiemu/8ca2bc910f953fda890fcc354bb9dbe7 to your computer and use it in GitHub Desktop.
Langchain general search wrapper (a la litellm)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import os
import time
from collections import deque
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional

import httpx

try:
    from contextlib_cli import anext_generator
except ImportError:
    # `contextlib_cli` is not installed; the stdlib decorator provides the
    # async-context-manager behavior the code expects from this name.
    anext_generator = asynccontextmanager

# Import all the necessary wrappers from the LangChain ecosystem
from langchain_community.utilities import (
    BraveSearchWrapper,
    DuckDuckGoSearchRun,
    BingSearchAPIWrapper,
    SerpAPIWrapper,
    WikipediaAPIWrapper,
    ArxivAPIWrapper,
    PubMedAPIWrapper,
)
from langchain_community.utilities.tavily_search import TavilySearchAPIRun
from langchain_community.utilities.you import YouSearchAPIWrapper
from langchain_google_community import GoogleSearchAPIWrapper
class AsyncRateLimitManager:
    """
    Manages rate limits for different API providers in an async environment.

    Two strategies are supported per provider:

    * Time-based: the caller passes ``requests_per_second`` to ``acquire`` and
      a sliding window of request timestamps enforces that rate.
    * Header-based: the caller omits ``requests_per_second``. Until the first
      response headers are recorded through ``update_from_headers`` a
      conservative 1 req/sec window is used; afterwards the
      ``x-ratelimit-remaining`` / ``x-ratelimit-reset`` values drive waiting.
    """

    # Reset values below this are interpreted as "seconds from now" (the form
    # Brave Search documents); anything larger is taken as an absolute Unix
    # timestamp. 1e9 seconds is far longer than any plausible relative window
    # and far earlier than any plausible current epoch time.
    _EPOCH_THRESHOLD = 1_000_000_000

    def __init__(self):
        # Maps provider name -> mutable state dict (see _get_provider_state).
        self._providers: Dict[str, Dict[str, Any]] = {}

    def _get_provider_state(self, provider: str) -> Dict[str, Any]:
        """Initializes and/or returns the state for a given provider."""
        if provider not in self._providers:
            self._providers[provider] = {
                "lock": asyncio.Lock(),
                "initialized": False,  # Starts as uninitialized
                "remaining": 1,
                "reset_time": 0,
                "request_timestamps": deque(),
            }
        return self._providers[provider]

    @staticmethod
    def _parse_header_int(value: Any) -> Optional[int]:
        """
        Returns the first integer found in a rate-limit header value.

        Header values may be comma-separated lists such as ``"0, 1995"``; only
        the first entry is used. Returns None when the value is missing or
        cannot be parsed, so callers can keep their previous state.
        """
        if value is None:
            return None
        first = str(value).split(",", 1)[0].strip()
        try:
            return int(float(first))
        except (ValueError, OverflowError):
            return None

    @staticmethod
    async def _wait_for_window_slot(
        timestamps: deque, requests_per_second: float
    ) -> None:
        """
        Blocks until one more request fits the sliding window, then records it.

        Rates of at least 1 allow ``int(rate)`` requests per one-second window;
        rates below 1 allow a single request per ``1 / rate`` seconds.
        """
        if requests_per_second >= 1:
            window, max_requests = 1.0, int(requests_per_second)
        else:
            window, max_requests = 1.0 / requests_per_second, 1
        while True:
            now = time.time()
            # Drop timestamps that have aged out of the window.
            while timestamps and timestamps[0] <= now - window:
                timestamps.popleft()
            if len(timestamps) < max_requests:
                break
            # The oldest entry is still inside the window, so this is > 0.
            await asyncio.sleep(window - (now - timestamps[0]))
        timestamps.append(time.time())

    @asynccontextmanager
    async def acquire(
        self,
        provider: str,
        requests_per_second: Optional[float] = None,
    ):
        """
        An async context manager to acquire a slot for an API call.
        Waits if the rate limit has been reached.

        Args:
            provider (str): Provider identifier whose state is used.
            requests_per_second: Fixed rate for time-based limiting. When
                omitted (or 0) the header-based strategy is used instead.

        Raises:
            ValueError: If ``requests_per_second`` is negative.
        """
        if requests_per_second is not None and requests_per_second < 0:
            raise ValueError("requests_per_second must be non-negative.")
        state = self._get_provider_state(provider)
        # The lock serializes slot bookkeeping per provider; it is released
        # before yielding so the guarded calls themselves can overlap.
        async with state["lock"]:
            if requests_per_second:
                # Case 1: Provider uses a fixed, explicit time-based limit.
                await self._wait_for_window_slot(
                    state["request_timestamps"], requests_per_second
                )
            elif not state["initialized"]:
                # Case 2a: Header-based but no headers seen yet. Use a safe
                # default of 1 req/sec.
                await self._wait_for_window_slot(
                    state["request_timestamps"], 1.0
                )
            else:
                # Case 2b: Header-based and initialized. Use dynamic values.
                if state["remaining"] <= 0:
                    sleep_duration = state["reset_time"] - time.time()
                    if sleep_duration > 0:
                        await asyncio.sleep(sleep_duration)
                state["remaining"] -= 1
        yield

    def update_from_headers(self, provider: str, headers: Dict[str, Any]):
        """
        Updates the rate limit state from API response headers.

        Reads ``x-ratelimit-remaining`` and ``x-ratelimit-reset``
        (case-insensitively). Values that cannot be parsed are ignored, which
        leaves the provider on its previous (possibly conservative) limits.
        """
        state = self._get_provider_state(provider)
        normalized = {str(k).lower(): v for k, v in headers.items()}
        remaining = self._parse_header_int(normalized.get("x-ratelimit-remaining"))
        reset = self._parse_header_int(normalized.get("x-ratelimit-reset"))
        if remaining is not None:
            state["remaining"] = remaining
            state["initialized"] = True  # Mark as initialized
        if reset is not None:
            if reset < self._EPOCH_THRESHOLD:
                # Relative value: convert to the absolute time acquire() uses.
                state["reset_time"] = time.time() + reset
            else:
                state["reset_time"] = reset
class SearchProviderProxy:
    """
    An async, rate-limited proxy for various search providers.
    It uses an AsyncRateLimitManager to avoid 429 errors and can be configured
    to use different backends like "brave/search".
    """

    BRAVE_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"

    def __init__(
        self,
        provider: str,
        rate_limit_manager: "AsyncRateLimitManager",
        http_client: "httpx.AsyncClient",
        requests_per_second: float = 2.0,
    ):
        """
        Initializes the proxy.

        Args:
            provider (str): Identifier like "brave/search".
            rate_limit_manager: An instance of AsyncRateLimitManager.
            http_client: An instance of httpx.AsyncClient for making API calls.
            requests_per_second: Fixed rate applied to providers that go
                through a standard wrapper (no rate-limit headers available).

        Raises:
            ValueError: If the provider identifier is not supported.
        """
        self.provider = provider
        self.rate_limit_manager = rate_limit_manager
        self.http_client = http_client
        self.requests_per_second = requests_per_second
        self.client, self.is_custom = self._get_client()

    def _get_client(self) -> tuple[Any, bool]:
        """
        Maps the provider string to a handler.
        Returns a tuple of (handler, is_custom_implementation).
        'is_custom' is True if we are making the HTTP call directly,
        False if we are using a standard LangChain wrapper.
        """
        # Providers for which we need custom logic to get headers
        if self.provider == "brave/search":
            return self._run_brave_async, True
        # Mapping for standard LangChain wrappers
        provider_map = {
            "google/search": GoogleSearchAPIWrapper,
            "tavily/search": TavilySearchAPIRun,
            "duckduckgo/search": DuckDuckGoSearchRun,
            "bing/search": BingSearchAPIWrapper,
            "serpapi/search": SerpAPIWrapper,
            "you/search": YouSearchAPIWrapper,
            "arxiv/search": ArxivAPIWrapper,
            "pubmed/search": PubMedAPIWrapper,
            "wikipedia/search": WikipediaAPIWrapper,
        }
        if self.provider not in provider_map:
            raise ValueError(f"Unsupported provider: '{self.provider}'.")
        # Wrappers are instantiated with no arguments.
        return provider_map[self.provider](), False

    @staticmethod
    def _format_brave_results(data: Dict[str, Any]) -> str:
        """Renders a Brave web-search JSON payload as numbered snippets."""
        results = (data.get("web") or {}).get("results")
        if not results:
            return "No good search results found."
        return "\n".join(
            f"Snippet {position}: {result.get('description', 'N/A')}"
            for position, result in enumerate(results, start=1)
        )

    async def _run_brave_async(self, query: str, **kwargs: Any) -> str:
        """
        Custom implementation for Brave Search to handle rate limits.

        Extra keyword arguments are forwarded as query-string parameters.

        Raises:
            ValueError: If the BRAVE_API_KEY environment variable is not set.
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        api_key = os.getenv("BRAVE_API_KEY")
        if not api_key:
            raise ValueError("BRAVE_API_KEY environment variable not set.")
        response = await self.http_client.get(
            self.BRAVE_SEARCH_URL,
            params={"q": query, **kwargs},
            headers={"X-Subscription-Token": api_key},
        )
        # Record the headers before raising: an error response (e.g. 429)
        # should still update the limiter with the reported quota state.
        self.rate_limit_manager.update_from_headers(self.provider, response.headers)
        response.raise_for_status()
        # Process and return the result similarly to the LC wrapper
        return self._format_brave_results(response.json())

    async def run(self, query: str, **kwargs: Any) -> str:
        """
        Runs a search query using the configured provider, respecting rate limits.

        Args:
            query (str): The search query.
            **kwargs: Additional keyword arguments passed to the handler.

        Returns:
            str: The search results as a string.

        Raises:
            NotImplementedError: If a wrapper client exposes no 'run' method.
        """
        # Use a time-based limit for providers where we don't get headers
        requests_per_second = None if self.is_custom else self.requests_per_second
        async with self.rate_limit_manager.acquire(
            self.provider, requests_per_second=requests_per_second
        ):
            if self.is_custom:
                # Custom clients are already async and handle headers
                return await self.client(query, **kwargs)
            if not hasattr(self.client, "run"):
                raise NotImplementedError(
                    f"Client for '{self.provider}' has no 'run' method."
                )
            # For standard sync wrappers, run them in a thread to avoid blocking
            return await asyncio.to_thread(self.client.run, query, **kwargs)
# --- Example Usage ---
async def main():
    """
    Exercise SearchProviderProxy against two backends and print the results.

    DuckDuckGo goes through the fixed time-based limit; Brave Search (only
    when BRAVE_API_KEY is set) goes through the header-based limit. Both
    share one rate-limit manager and one HTTP client.
    """
    limiter = AsyncRateLimitManager()
    async with httpx.AsyncClient() as http_client:
        print("--- Testing DuckDuckGo (fixed 2 req/sec limit) ---")
        try:
            duck = SearchProviderProxy("duckduckgo/search", limiter, http_client)
            duck_queries = (
                "Benefits of serverless computing?",
                "What is WebAssembly?",
                "Latest AI news",
            )
            answers = await asyncio.gather(*[duck.run(q) for q in duck_queries])
            for position, text in enumerate(answers, start=1):
                print(f"Result {position}: " + text[:100] + "...")
        except Exception as e:
            # Demo boundary: report the failure and move on to the next provider.
            print(f"Error with DuckDuckGo: {e}")

        print("\n" + "=" * 50 + "\n")
        print("--- Testing Brave Search (starts with 1 req/sec, then uses headers) ---")
        if not os.getenv("BRAVE_API_KEY"):
            print("BRAVE_API_KEY env var not set. Skipping Brave Search test.")
            return

        try:
            brave = SearchProviderProxy("brave/search", limiter, http_client)
            # These three will be spaced out by the default 1 req/sec limit
            brave_queries = (
                "What is LangGraph?",
                "Key features of Rust programming language?",
                "Async programming in Python",
            )
            answers = await asyncio.gather(*[brave.run(q) for q in brave_queries])
            for position, text in enumerate(answers, start=1):
                print(f"Result {position}: " + text[:100] + "...")
        except Exception as e:
            print(f"Error with Brave Search: {e}")


if __name__ == "__main__":
    asyncio.run(main())
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
lol, posting this here. thanks qwen code, Claude sonnet 4, et al! (aclarai is my own project)
Contribution to Web Search Tool
Hello Robbie,
We've been using your web search tool implementation from this gist as part of our Data Scout Agent at AClaraI, and we wanted to contribute back some improvements we've made to the code.
Key Improvements
- Fixed Import Issues: Corrected the `contextlib_cli` import to use the standard library's `contextlib.asynccontextmanager`.
- Improved Rate Limiting for Brave Search: Enhanced parsing of Brave Search API headers, which return comma-separated values (e.g., "0, 1995"). Added error handling to gracefully fall back to conservative rate limiting if parsing fails.
- Updated Environment Variable Name: Changed from `BRAVE_API_KEY` to `BRAVE_SEARCH_API_KEY` for better clarity.
- Added Tavily Search Integration: Updated to use the correct import path for Tavily Search with `langchain_tavily`.
- Added Google Search Integration: Added support for Google Search API through `langchain_google_community`.
- Enhanced Error Handling: Improved error handling throughout the code with better exception handling and more descriptive error messages.
Why These Changes Matter
These improvements make the web search tool more robust and production-ready:
Updated Code
Below is the complete updated implementation with our improvements:
We'd love to contribute these improvements back to your project. Since this is a standalone gist, we're not sure the best way to contribute these changes. Would you be interested in a pull request with these changes, or would you prefer some other method?
Best regards,
The AClaraI Team