Skip to content

Instantly share code, notes, and snippets.

@Wybxc
Created March 14, 2023 08:17
Show Gist options
  • Select an option

  • Save Wybxc/af5b42003be122fe4256af54eac825fa to your computer and use it in GitHub Desktop.

Select an option

Save Wybxc/af5b42003be122fe4256af54eac825fa to your computer and use it in GitHub Desktop.
Simple wrapper around the OpenAI GPT-3.5 Chat Completions API
from typing import Optional, Sequence, Tuple
import requests
import json
SYSTEM = "You are a helpful assistant."


def chat_complete(
    *,
    system: str = SYSTEM,
    history: Sequence[Tuple[str, str]] = (),
    question: str = "",
    api_key_path: str = "api_key.txt",
    proxy: Optional[str] = None,
    model: str = "gpt-3.5-turbo",
    timeout: Optional[float] = None,
):
    """Stream a chat completion from the OpenAI Chat Completions endpoint.

    Builds the message list from *system*, the prior *history* turns and the
    new *question*, POSTs it with ``"stream": True`` and yields the text of
    each streamed delta as it arrives. Because this is a generator, the key
    file is read and the request is sent only when iteration starts.

    Args:
        system: System prompt sent as the first message.
        history: Prior ``(user, assistant)`` exchanges, oldest first.
        question: The new user message.
        api_key_path: Path to a text file containing the API key
            (surrounding whitespace is stripped).
        proxy: Optional proxy URL used for HTTPS traffic.
        model: Chat model name sent in the request payload.
        timeout: Optional ``requests`` timeout in seconds (connect and
            per-read). ``None`` waits indefinitely, as before.

    Yields:
        Successive content fragments (``str``). A chunk whose delta carries
        no text yields ``""``.

    Raises:
        OSError: If the API key file cannot be read.
        requests.HTTPError: If the API answers with a non-2xx status.
    """
    messages = [{"role": "system", "content": system}]
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": question})

    payload = {"model": model, "messages": messages, "stream": True}
    with open(api_key_path, encoding="utf-8") as key_file:
        api_key = key_file.read().strip()
    headers = {"Authorization": f"Bearer {api_key}"}

    # Using the response as a context manager closes the streamed connection
    # even when we return early on [DONE] or the caller abandons the generator.
    with requests.post(
        "https://api.openai.com/v1/chat/completions",
        json=payload,
        headers=headers,
        proxies={"https": proxy} if proxy else None,
        stream=True,
        timeout=timeout,
    ) as response:
        # An error body (bad key, rate limit, ...) contains no "data:" lines;
        # without this check the generator would silently yield nothing.
        response.raise_for_status()
        for raw_line in response.iter_lines(decode_unicode=False):
            line = raw_line.decode("utf-8")
            if not line.startswith("data:"):
                continue  # blank keep-alive lines and other event-stream fields
            data = line[len("data:"):].strip()
            if data == "[DONE]":
                return
            if not data:
                continue
            choices = json.loads(data).get("choices") or []
            if not choices:
                continue
            # "content" can be absent or null (e.g. a role-only first chunk);
            # normalise both to "" so callers can always concatenate.
            yield choices[0].get("delta", {}).get("content") or ""
import contextlib
@contextlib.contextmanager
def auto_wrap(max_length: int):
    """Yield a function that soft-wraps streamed text at *max_length* columns.

    The yielded ``wrap(text)`` is meant to be called with successive
    fragments of a stream. It returns each fragment unchanged, or prefixed
    with a newline when appending the fragment would push the current output
    line past *max_length* characters. Line breaks already contained in a
    fragment reset the column count. A newline is printed when the ``with``
    block exits, including when it exits with an exception.

    Args:
        max_length: Maximum number of characters per displayed line.
    """
    column = 0  # characters already emitted on the current output line

    def wrap(text: str) -> str:
        nonlocal column
        # Only the part before the fragment's first newline extends the
        # current line; anything after a newline starts a fresh line.
        first_segment = text.split("\n", 1)[0]
        prefix = ""
        # Never break an empty line: that would only print a blank line.
        if column and column + len(first_segment) > max_length:
            prefix = "\n"
            column = 0
        if "\n" in text:
            column = len(text) - text.rfind("\n") - 1
        else:
            column += len(text)
        return prefix + text

    try:
        yield wrap
    finally:
        # Leave the cursor on a fresh line even if streaming failed midway.
        print()
def join_display(gen, max_length: int = 120):
    """Echo streamed text fragments to stdout and return them concatenated.

    Each fragment from *gen* is printed as soon as it arrives, soft-wrapped
    by ``auto_wrap`` at *max_length* columns, so the reply is visible while
    it streams in.

    Args:
        gen: Iterable of string fragments, e.g. the generator returned by
            ``chat_complete``.
        max_length: Display line width. The default of 120 matches the
            previously hard-coded value.

    Returns:
        All fragments joined exactly as received. The display-only wrapping
        newlines are not included.
    """
    parts = []
    with auto_wrap(max_length) as wrap:
        for item in gen:
            parts.append(item)
            # flush=True: interactive stdout is line-buffered, so without it
            # fragments would only appear once a newline is printed.
            print(wrap(item), end="", flush=True)
    return "".join(parts)
import json
question = """
Hello, what's your name?
""".strip()
s = join_display(
    chat_complete(
        system="You are a helpful assistant.",
        question=question,
        proxy="http://127.0.0.1:7890",
    )
)
# Take the span from the first "{" to the last "}" so any prose the model
# puts around its JSON is ignored.
start, end = s.find("{"), s.rfind("}") + 1
if start == -1 or end <= start:
    raise ValueError(f"No JSON object found in the model's reply: {s!r}")
attributes = json.loads(s[start:end])
if not attributes:
    raise ValueError(f"Expected a non-empty JSON object, got: {attributes!r}")
# Keep only the value of the first top-level key. NOTE(review): presumably
# the prompt asks for a one-key wrapper object — confirm against the prompt.
attributes = next(iter(attributes.values()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment