Skip to content

Instantly share code, notes, and snippets.

@robbiemu
Created September 2, 2025 22:43
Show Gist options
  • Select an option

  • Save robbiemu/8ca2bc910f953fda890fcc354bb9dbe7 to your computer and use it in GitHub Desktop.

Select an option

Save robbiemu/8ca2bc910f953fda890fcc354bb9dbe7 to your computer and use it in GitHub Desktop.
LangChain general search wrapper (à la LiteLLM)
import asyncio
import os
import time
from collections import deque
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional

import httpx
# NOTE(review): kept because existing imports are never dropped here; confirm
# that a `contextlib_cli` package providing `anext_generator` actually exists.
from contextlib_cli import anext_generator
# Import all the necessary wrappers from the LangChain ecosystem
from langchain_community.utilities import (
    BraveSearchWrapper,
    DuckDuckGoSearchRun,
    BingSearchAPIWrapper,
    SerpAPIWrapper,
    WikipediaAPIWrapper,
    ArxivAPIWrapper,
    PubMedAPIWrapper,
)
from langchain_community.utilities.tavily_search import TavilySearchAPIRun
from langchain_community.utilities.you import YouSearchAPIWrapper
from langchain_google_community import GoogleSearchAPIWrapper
class AsyncRateLimitManager:
    """
    Manages rate limits for different API providers in an async environment.

    Providers that pass an explicit ``requests_per_second`` to ``acquire`` are
    throttled with a one-second sliding window. Header-based providers start
    with a conservative default (1 req/sec) and switch to dynamic limiting
    once ``update_from_headers`` has parsed an ``x-ratelimit-remaining`` value.
    """

    # Boundary used to interpret ``x-ratelimit-reset``: values at or above it
    # are taken as absolute Unix timestamps, smaller ones as seconds from now.
    _EPOCH_THRESHOLD = 1_000_000_000

    def __init__(self):
        # Maps provider name -> mutable state dict (see _get_provider_state).
        self._providers: Dict[str, Dict[str, Any]] = {}

    def _get_provider_state(self, provider: str) -> Dict[str, Any]:
        """Initializes and/or returns the state for a given provider."""
        if provider not in self._providers:
            self._providers[provider] = {
                "lock": asyncio.Lock(),
                "initialized": False,  # Flipped once a "remaining" header parses
                "remaining": 1,
                "reset_time": 0,  # Unix time at which the quota resets
                "request_timestamps": deque(),
            }
        return self._providers[provider]

    @staticmethod
    async def _wait_for_window_slot(timestamps: deque, max_per_second: float) -> None:
        """
        Waits until fewer than ``max_per_second`` calls were recorded during
        the last second, then records the current call in ``timestamps``.
        """
        while True:
            now = time.time()
            while timestamps and timestamps[0] <= now - 1.0:
                timestamps.popleft()
            if len(timestamps) < max_per_second:
                break
            # Sleep until the oldest call leaves the window, then re-check:
            # the sleep may wake marginally early.
            await asyncio.sleep(1.0 - (now - timestamps[0]))
        timestamps.append(time.time())

    @staticmethod
    def _first_int(value: Any) -> Optional[int]:
        """
        Parses the first comma-separated field of a header value as an int.

        Returns None when the header is absent or not an integer, so callers
        can keep their previous (conservative) state instead of crashing.
        """
        if value is None:
            return None
        try:
            return int(str(value).split(",")[0].strip())
        except ValueError:
            return None

    @asynccontextmanager
    async def acquire(
        self,
        provider: str,
        requests_per_second: Optional[float] = None,
    ):
        """
        An async context manager to acquire a slot for an API call.
        Waits if the rate limit has been reached.

        Args:
            provider: Key identifying the provider whose state is used.
            requests_per_second: Fixed limit for providers without rate-limit
                headers. When falsy, the header-based strategy is used.
        """
        state = self._get_provider_state(provider)
        # The lock serialises the bookkeeping only; it is released before the
        # caller's request runs so requests themselves may overlap.
        async with state["lock"]:
            if requests_per_second:
                # Case 1: fixed, explicit time-based limit.
                await self._wait_for_window_slot(
                    state["request_timestamps"], requests_per_second
                )
            elif not state["initialized"]:
                # Case 2a: header-based but no headers seen yet -> 1 req/sec.
                await self._wait_for_window_slot(state["request_timestamps"], 1)
            else:
                # Case 2b: header-based, driven by the last parsed headers.
                if state["remaining"] <= 0:
                    sleep_duration = state["reset_time"] - time.time()
                    if sleep_duration > 0:
                        await asyncio.sleep(sleep_duration)
                state["remaining"] -= 1
        yield

    def update_from_headers(self, provider: str, headers: Dict[str, Any]):
        """
        Updates the rate limit state from API response headers.

        Reads ``x-ratelimit-remaining`` and ``x-ratelimit-reset``
        (case-insensitively). Comma-separated values such as ``"0, 1995"`` are
        accepted and only the first field is used. Unparseable values are
        ignored, leaving the previous state in place.
        """
        state = self._get_provider_state(provider)
        headers = {k.lower(): v for k, v in headers.items()}

        remaining = self._first_int(headers.get("x-ratelimit-remaining"))
        reset = self._first_int(headers.get("x-ratelimit-reset"))

        if remaining is not None:
            state["remaining"] = remaining
            state["initialized"] = True  # Mark as initialized
        if reset is not None:
            # Brave Search documents this header as "seconds from now"; storing
            # it unchanged would always compare as being in the past. Large
            # values are kept as-is for providers that send an epoch timestamp.
            if reset >= self._EPOCH_THRESHOLD:
                state["reset_time"] = reset
            else:
                state["reset_time"] = time.time() + reset
class SearchProviderProxy:
    """
    An async, rate-limited proxy for various search providers.

    It uses a RateLimitManager to avoid 429 errors and can be configured
    to use different backends like "brave/search".
    """

    def __init__(
        self,
        provider: str,
        rate_limit_manager: AsyncRateLimitManager,
        http_client: httpx.AsyncClient,
    ):
        """
        Initializes the proxy.

        Args:
            provider (str): Identifier like "brave/search".
            rate_limit_manager: An instance of AsyncRateLimitManager.
            http_client: An instance of httpx.AsyncClient for making API calls.

        Raises:
            ValueError: If ``provider`` is not a supported identifier.
        """
        self.provider = provider
        self.rate_limit_manager = rate_limit_manager
        self.http_client = http_client
        self.client, self.is_custom = self._get_client()

    def _get_client(self) -> "tuple[Any, bool]":
        """
        Maps the provider string to a handler.

        Returns a tuple of (handler, is_custom_implementation).
        'is_custom' is True if we are making the HTTP call directly,
        False if we are using a standard LangChain wrapper.

        Raises:
            ValueError: If the provider string is not recognised.
        """
        # Providers for which we need custom logic to get headers
        if self.provider == "brave/search":
            return self._run_brave_async, True

        # Mapping for standard LangChain wrappers
        provider_map = {
            "google/search": GoogleSearchAPIWrapper,
            "tavily/search": TavilySearchAPIRun,
            "duckduckgo/search": DuckDuckGoSearchRun,
            "bing/search": BingSearchAPIWrapper,
            "serpapi/search": SerpAPIWrapper,
            "you/search": YouSearchAPIWrapper,
            "arxiv/search": ArxivAPIWrapper,
            "pubmed/search": PubMedAPIWrapper,
            "wikipedia/search": WikipediaAPIWrapper,
        }
        client_class = provider_map.get(self.provider)
        if client_class is None:
            raise ValueError(f"Unsupported provider: '{self.provider}'.")
        return client_class(), False

    async def _run_brave_async(self, query: str, **kwargs: Any) -> str:
        """
        Custom implementation for Brave Search to handle rate limits.

        Extra keyword arguments are forwarded as query-string parameters.

        Raises:
            ValueError: If BRAVE_API_KEY is not set.
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        api_key = os.getenv("BRAVE_API_KEY")
        if not api_key:
            raise ValueError("BRAVE_API_KEY environment variable not set.")

        params = {"q": query, **kwargs}
        headers = {"X-Subscription-Token": api_key}

        response = await self.http_client.get(
            "https://api.search.brave.com/res/v1/web/search",
            params=params,
            headers=headers,
        )
        # Record the rate-limit headers before raising: an error response
        # (e.g. 429) is exactly when the limiter needs the fresh values.
        self.rate_limit_manager.update_from_headers(self.provider, response.headers)
        response.raise_for_status()

        # Process and return the result similarly to the LC wrapper
        data = response.json()
        if not data.get("web") or not data["web"].get("results"):
            return "No good search results found."

        snippets = [
            f"Snippet {i+1}: {result.get('description', 'N/A')}"
            for i, result in enumerate(data["web"]["results"])
        ]
        return "\n".join(snippets)

    async def run(self, query: str, **kwargs: Any) -> str:
        """
        Runs a search query using the configured provider, respecting rate limits.

        Args:
            query (str): The search query.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The search results as a string.

        Raises:
            NotImplementedError: If a wrapped client exposes no ``run`` method.
        """
        # Validate before acquiring so a misconfigured client does not
        # consume a rate-limit slot.
        if not self.is_custom and not hasattr(self.client, 'run'):
            raise NotImplementedError(
                f"Client for '{self.provider}' has no 'run' method."
            )

        # Use a time-based limit for providers where we don't get headers
        requests_per_second = None if self.is_custom else 2.0

        async with self.rate_limit_manager.acquire(
            self.provider, requests_per_second=requests_per_second
        ):
            if self.is_custom:
                # Custom clients are already async and handle headers
                return await self.client(query, **kwargs)
            # For standard sync wrappers, run them in a thread to avoid blocking
            return await asyncio.to_thread(self.client.run, query, **kwargs)
# --- Example Usage ---
async def main():
    """
    Runs a small demo of SearchProviderProxy.

    Fires three concurrent DuckDuckGo queries, then (only when BRAVE_API_KEY
    is set) three concurrent Brave queries, printing the first 100 characters
    of each result. Errors are printed per provider rather than raised.
    """
    rate_manager = AsyncRateLimitManager()
    async with httpx.AsyncClient() as http_client:
        print("--- Testing DuckDuckGo (fixed 2 req/sec limit) ---")
        try:
            ddg_proxy = SearchProviderProxy("duckduckgo/search", rate_manager, http_client)
            ddg_queries = (
                "Benefits of serverless computing?",
                "What is WebAssembly?",
                "Latest AI news",
            )
            results = await asyncio.gather(*[ddg_proxy.run(q) for q in ddg_queries])
            for i, res in enumerate(results):
                print(f"Result {i+1}: " + res[:100] + "...")
        except Exception as e:
            print(f"Error with DuckDuckGo: {e}")

        print("\n" + "="*50 + "\n")

        print("--- Testing Brave Search (starts with 1 req/sec, then uses headers) ---")
        # Brave needs credentials; bail out early when they are missing.
        if not os.getenv("BRAVE_API_KEY"):
            print("BRAVE_API_KEY env var not set. Skipping Brave Search test.")
            return

        try:
            brave_proxy = SearchProviderProxy("brave/search", rate_manager, http_client)
            brave_queries = (
                "What is LangGraph?",
                "Key features of Rust programming language?",
                "Async programming in Python",
            )
            # The shared rate manager spaces these calls out.
            results = await asyncio.gather(*[brave_proxy.run(q) for q in brave_queries])
            for i, res in enumerate(results):
                print(f"Result {i+1}: " + res[:100] + "...")
        except Exception as e:
            print(f"Error with Brave Search: {e}")


# Run the demo only when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())
@robbiemu
Copy link
Author

robbiemu commented Sep 3, 2025

lol, posting this here. thanks qwen code, Claude sonnet 4, et al! (aclarai is my own project)

Contribution to Web Search Tool

Hello Robbie,

We've been using your web search tool implementation from this gist as part of our Data Scout Agent at AClaraI, and we wanted to contribute back some improvements we've made to the code.

Key Improvements

  1. Fixed Import Issues: Corrected contextlib_cli import to use the standard library contextlib.asynccontextmanager.

  2. Improved Rate Limiting for Brave Search: Enhanced parsing of Brave Search API headers which return comma-separated values (e.g., "0, 1995"). Added error handling to gracefully fall back to conservative rate limiting if parsing fails.

  3. Updated Environment Variable Name: Changed from BRAVE_API_KEY to BRAVE_SEARCH_API_KEY for better clarity.

  4. Added Tavily Search Integration: Updated to use the correct import path for Tavily Search with langchain_tavily.

  5. Added Google Search Integration: Added support for Google Search API through langchain_google_community.

  6. Enhanced Error Handling: Improved error handling throughout the code with better exception handling and more descriptive error messages.

Why These Changes Matter

These improvements make the web search tool more robust and production-ready:

  • The rate limiting fixes prevent unexpected API throttling with Brave Search
  • The import fixes ensure the code runs correctly with standard libraries
  • The added search providers give users more options
  • Better error handling makes debugging easier

Updated Code

Below is the complete updated implementation with our improvements:

import os
import time
import asyncio
import httpx
from typing import Any, Dict, Optional
from contextlib import asynccontextmanager
from collections import deque

# Import all the necessary wrappers from the LangChain ecosystem
from langchain_community.utilities import (
    BraveSearchWrapper,
    BingSearchAPIWrapper,
    SerpAPIWrapper,
    WikipediaAPIWrapper,
    ArxivAPIWrapper,
    PubMedAPIWrapper,
)
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_tavily import TavilySearch
from langchain_community.utilities.you import YouSearchAPIWrapper
from langchain_google_community import GoogleSearchAPIWrapper


class AsyncRateLimitManager:
    """
    Manages rate limits for different API providers in an async environment.

    Providers that pass an explicit ``requests_per_second`` to ``acquire`` are
    throttled with a one-second sliding window. Header-based providers start
    with a conservative default (1 req/sec) and switch to dynamic limiting
    once ``update_from_headers`` has parsed an ``x-ratelimit-remaining`` value.
    """

    # Boundary used to interpret ``x-ratelimit-reset``: values at or above it
    # are taken as absolute Unix timestamps, smaller ones as seconds from now.
    _EPOCH_THRESHOLD = 1_000_000_000

    def __init__(self):
        # Maps provider name -> mutable state dict (see _get_provider_state).
        self._providers: Dict[str, Dict[str, Any]] = {}

    def _get_provider_state(self, provider: str) -> Dict[str, Any]:
        """Initializes and/or returns the state for a given provider."""
        if provider not in self._providers:
            self._providers[provider] = {
                "lock": asyncio.Lock(),
                "initialized": False,  # Flipped once a "remaining" header parses
                "remaining": 1,
                "reset_time": 0,  # Unix time at which the quota resets
                "request_timestamps": deque(),
            }
        return self._providers[provider]

    @staticmethod
    async def _wait_for_window_slot(timestamps: deque, max_per_second: float) -> None:
        """
        Waits until fewer than ``max_per_second`` calls were recorded during
        the last second, then records the current call in ``timestamps``.
        """
        while True:
            now = time.time()
            while timestamps and timestamps[0] <= now - 1.0:
                timestamps.popleft()
            if len(timestamps) < max_per_second:
                break
            # Sleep until the oldest call leaves the window, then re-check:
            # the sleep may wake marginally early.
            await asyncio.sleep(1.0 - (now - timestamps[0]))
        timestamps.append(time.time())

    @staticmethod
    def _first_int(value: Any) -> Optional[int]:
        """
        Parses the first comma-separated field of a header value as an int.

        Returns None when the header is absent or not an integer, so callers
        can keep their previous (conservative) state instead of crashing.
        """
        if value is None:
            return None
        try:
            return int(str(value).split(",")[0].strip())
        except ValueError:
            return None

    @asynccontextmanager
    async def acquire(
        self,
        provider: str,
        requests_per_second: Optional[float] = None,
    ):
        """
        An async context manager to acquire a slot for an API call.
        Waits if the rate limit has been reached.

        Args:
            provider: Key identifying the provider whose state is used.
            requests_per_second: Fixed limit for providers without rate-limit
                headers. When falsy, the header-based strategy is used.
        """
        state = self._get_provider_state(provider)
        # The lock serialises the bookkeeping only; it is released before the
        # caller's request runs so requests themselves may overlap.
        async with state["lock"]:
            if requests_per_second:
                # Case 1: fixed, explicit time-based limit.
                await self._wait_for_window_slot(
                    state["request_timestamps"], requests_per_second
                )
            elif not state["initialized"]:
                # Case 2a: header-based but no headers seen yet -> 1 req/sec.
                await self._wait_for_window_slot(state["request_timestamps"], 1)
            else:
                # Case 2b: header-based, driven by the last parsed headers.
                if state["remaining"] <= 0:
                    sleep_duration = state["reset_time"] - time.time()
                    if sleep_duration > 0:
                        await asyncio.sleep(sleep_duration)
                state["remaining"] -= 1
        yield

    def update_from_headers(self, provider: str, headers: Dict[str, Any]):
        """
        Updates the rate limit state from API response headers.

        Reads ``x-ratelimit-remaining`` and ``x-ratelimit-reset``
        (case-insensitively). Comma-separated values such as ``"0, 1995"`` are
        accepted and only the first field is used. Unparseable values are
        ignored, leaving the previous state in place.
        """
        state = self._get_provider_state(provider)
        headers = {k.lower(): v for k, v in headers.items()}

        remaining = self._first_int(headers.get("x-ratelimit-remaining"))
        reset = self._first_int(headers.get("x-ratelimit-reset"))

        if remaining is not None:
            state["remaining"] = remaining
            state["initialized"] = True  # Mark as initialized
        if reset is not None:
            # Brave Search documents this header as "seconds from now"; storing
            # it unchanged would always compare as being in the past. Large
            # values are kept as-is for providers that send an epoch timestamp.
            if reset >= self._EPOCH_THRESHOLD:
                state["reset_time"] = reset
            else:
                state["reset_time"] = time.time() + reset


class SearchProviderProxy:
    """
    An async, rate-limited proxy for various search providers.

    It uses a RateLimitManager to avoid 429 errors and can be configured
    to use different backends like "brave/search".
    """

    def __init__(
        self,
        provider: str,
        rate_limit_manager: AsyncRateLimitManager,
        http_client: httpx.AsyncClient,
    ):
        """
        Initializes the proxy.

        Args:
            provider (str): Identifier like "brave/search".
            rate_limit_manager: An instance of AsyncRateLimitManager.
            http_client: An instance of httpx.AsyncClient for making API calls.

        Raises:
            ValueError: If ``provider`` is not a supported identifier.
        """
        self.provider = provider
        self.rate_limit_manager = rate_limit_manager
        self.http_client = http_client
        self.client, self.is_custom = self._get_client()

    def _get_client(self) -> "tuple[Any, bool]":
        """
        Maps the provider string to a handler.

        Returns a tuple of (handler, is_custom_implementation).
        'is_custom' is True if we are making the HTTP call directly,
        False if we are using a standard LangChain wrapper.

        Raises:
            ValueError: If the provider string is not recognised.
        """
        # Providers for which we need custom logic to get headers
        if self.provider == "brave/search":
            return self._run_brave_async, True

        # Mapping for standard LangChain wrappers
        provider_map = {
            "google/search": GoogleSearchAPIWrapper,
            "tavily/search": TavilySearch,
            "duckduckgo/search": DuckDuckGoSearchRun,
            "bing/search": BingSearchAPIWrapper,
            "serpapi/search": SerpAPIWrapper,
            "you/search": YouSearchAPIWrapper,
            "arxiv/search": ArxivAPIWrapper,
            "pubmed/search": PubMedAPIWrapper,
            "wikipedia/search": WikipediaAPIWrapper,
        }
        client_class = provider_map.get(self.provider)
        if client_class is None:
            raise ValueError(f"Unsupported provider: '{self.provider}'.")
        return client_class(), False

    async def _run_brave_async(self, query: str, **kwargs: Any) -> str:
        """
        Custom implementation for Brave Search to handle rate limits.

        Extra keyword arguments are forwarded as query-string parameters.

        Raises:
            ValueError: If BRAVE_SEARCH_API_KEY is not set.
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        api_key = os.getenv("BRAVE_SEARCH_API_KEY")
        if not api_key:
            raise ValueError("BRAVE_SEARCH_API_KEY environment variable not set.")

        params = {"q": query, **kwargs}
        headers = {"X-Subscription-Token": api_key}

        response = await self.http_client.get(
            "https://api.search.brave.com/res/v1/web/search",
            params=params,
            headers=headers,
        )
        # Record the rate-limit headers before raising: an error response
        # (e.g. 429) is exactly when the limiter needs the fresh values.
        self.rate_limit_manager.update_from_headers(self.provider, response.headers)
        response.raise_for_status()

        # Process and return the result similarly to the LC wrapper
        data = response.json()
        if not data.get("web") or not data["web"].get("results"):
            return "No good search results found."

        snippets = [
            f"Snippet {i+1}: {result.get('description', 'N/A')}"
            for i, result in enumerate(data["web"]["results"])
        ]
        return "\n".join(snippets)

    async def run(self, query: str, **kwargs: Any) -> str:
        """
        Runs a search query using the configured provider, respecting rate limits.

        Args:
            query (str): The search query.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The search results as a string.

        Raises:
            NotImplementedError: If a wrapped client exposes no ``run`` method.
        """
        # Validate before acquiring so a misconfigured client does not
        # consume a rate-limit slot.
        if not self.is_custom and not hasattr(self.client, 'run'):
            raise NotImplementedError(f"Client for '{self.provider}' has no 'run' method.")

        # Use a time-based limit for providers where we don't get headers
        requests_per_second = None if self.is_custom else 2.0

        async with self.rate_limit_manager.acquire(
            self.provider, requests_per_second=requests_per_second
        ):
            if self.is_custom:
                # Custom clients are already async and handle headers
                return await self.client(query, **kwargs)
            # For standard sync wrappers, run them in a thread to avoid blocking
            return await asyncio.to_thread(self.client.run, query, **kwargs)


# --- Example Usage ---


async def main():
    """
    Runs a small demo of SearchProviderProxy.

    Fires three concurrent DuckDuckGo queries, then (only when
    BRAVE_SEARCH_API_KEY is set) three concurrent Brave queries, printing the
    first 100 characters of each result. Errors are printed per provider
    rather than raised.
    """
    rate_manager = AsyncRateLimitManager()
    async with httpx.AsyncClient() as http_client:
        print("--- Testing DuckDuckGo (fixed 2 req/sec limit) ---")
        try:
            ddg_proxy = SearchProviderProxy("duckduckgo/search", rate_manager, http_client)
            ddg_queries = (
                "Benefits of serverless computing?",
                "What is WebAssembly?",
                "Latest AI news",
            )
            results = await asyncio.gather(*[ddg_proxy.run(q) for q in ddg_queries])
            for i, res in enumerate(results):
                print(f"Result {i+1}: " + res[:100] + "...")
        except Exception as e:
            print(f"Error with DuckDuckGo: {e}")

        print("\n" + "="*50 + "\n")

        print("--- Testing Brave Search (starts with 1 req/sec, then uses headers) ---")
        # Brave needs credentials; bail out early when they are missing.
        if not os.getenv("BRAVE_SEARCH_API_KEY"):
            print("BRAVE_SEARCH_API_KEY env var not set. Skipping Brave Search test.")
            return

        try:
            brave_proxy = SearchProviderProxy("brave/search", rate_manager, http_client)
            brave_queries = (
                "What is LangGraph?",
                "Key features of Rust programming language?",
                "Async programming in Python",
            )
            # The shared rate manager spaces these calls out.
            results = await asyncio.gather(*[brave_proxy.run(q) for q in brave_queries])
            for i, res in enumerate(results):
                print(f"Result {i+1}: " + res[:100] + "...")
        except Exception as e:
            print(f"Error with Brave Search: {e}")


# Run the demo only when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

We'd love to contribute these improvements back to your project. Since this is a standalone gist, we're not sure of the best way to contribute these changes. Would you be interested in a pull request with these changes, or would you prefer some other method?

Best regards,
The AClaraI Team

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment