sqlalchemy recipe for paradedb
"""
sqlalchemy discussion: https://github.com/sqlalchemy/sqlalchemy/discussions/10841
"""
def dump_dict(v):
"""recursively dump dict into rust style, which no quotes for key"""
match v:
case str():
return f'"{v}"'
case int():
return f'{v}'
case float():
return f'{v}'
case dict():
items = []
for i, j in v.items():
items.append(f'{i}: {dump_dict(j)}')
return '{ ' + ','.join(items) + ' }'
case _:
raise NotImplementedError(v)
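
# e.g. dump_dict({'tokenizer': {'type': 'raw'}}) returns
# '{ tokenizer: { type: "raw" } }': keys unquoted, string values double-quoted.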


def kw_kv(k, v):
  """Wrap a key/value pair into text as a postgresql function kwarg: 'k => v'."""
  match v:
    case dict():
      return text(f"{k} => '{dump_dict(v)}'")
    case _:
      return text(f"{k} => '{v}'")


def create_bm25_index(
  index_name: str,
  schema_name: str,
  table_name: str,
  key_field: str,
  text_fields: dict[str, dict] | None = None,
  numeric_fields: dict[str, dict] | None = None,
  boolean_fields: dict[str, dict] | None = None,
  json_fields: dict[str, dict] | None = None,
):
  """Helper function to generate the bm25 index creation DDL."""
  args = [
    kw_kv('index_name', index_name),
    kw_kv('schema_name', schema_name),
    kw_kv('table_name', table_name),
    kw_kv('key_field', key_field),
  ]
  for k, v in zip(
      ['text_fields', 'numeric_fields', 'boolean_fields', 'json_fields'],
      [text_fields, numeric_fields, boolean_fields, json_fields]):
    if v is not None:
      args.append(kw_kv(k, v))
  return func.paradedb.create_bm25(*args)


def drop_bm25_index(index_name: str):
  """Helper function to generate the bm25 index drop DDL."""
  return func.paradedb.drop_bm25(text(f"'{index_name}'"))


# attach create/drop events to the table
table = Table(...)

create_index = create_bm25_index(
  index_name=f'{table.name}_search_idx',
  schema_name=f'{table.schema}',
  table_name=f'{table.name}',
  key_field='id',
  text_fields={
    'description': {
      'tokenizer': {'type': 'chinese_compatible'},
    },
    'name': {
      'tokenizer': {'type': 'chinese_compatible'},
    },
    'category': {
      'tokenizer': {'type': 'raw'},
    },
  },
  numeric_fields={
    'account_id': {
      'tokenizer': {'type': 'int4'},
    },
  },
)
drop_index = drop_bm25_index(f'{table.name}_search_idx')

event.listen(table, 'after_create', DDL(f'CALL {create_index};'))
event.listen(table, 'after_drop', DDL(f'CALL {drop_index};'))
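
# For reference, the listeners above emit DDL roughly of this shape
# (illustrative names, not captured from a live session):
#   CALL paradedb.create_bm25(index_name => 'items_search_idx',
#     schema_name => 'public', table_name => 'items', key_field => 'id',
#     text_fields => '{ category: { tokenizer: { type: "raw" } } }');
#   CALL paradedb.drop_bm25('items_search_idx');
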
@DeoLeung (Author)

updated for pg_search v2

import re

PINYIN_TOKENIZER = re.compile(
  r'(chuang|shuang|zhuang|chang|cheng|chong|chuai|chuan|guang|huang|jiang|jiong|kuang|liang|niang|qiang|qiong|shang|sheng|shuai|shuan|xiang|xiong|zhang|zheng|zhong|zhuai|zhuan|bang|beng|bian|biao|bing|cang|ceng|chai|chan|chao|chen|chou|chua|chui|chun|chuo|cong|cuan|dang|deng|dian|diao|ding|dong|duan|fang|feng|gang|geng|gong|guai|guan|hang|heng|hong|huai|huan|jian|jiao|jing|juan|kang|keng|kong|kuai|kuan|lang|leng|lian|liao|ling|long|luan|mang|meng|mian|miao|ming|nang|neng|nian|niao|ning|nong|nuan|pang|peng|pian|piao|ping|qian|qiao|qing|quan|rang|reng|rong|ruan|sang|seng|shai|shan|shao|shei|shen|shou|shua|shui|shun|shuo|song|suan|tang|teng|tian|tiao|ting|tong|tuan|wang|weng|xian|xiao|xing|xuan|yang|ying|yong|yuan|zang|zeng|zhai|zhan|zhao|zhei|zhen|zhou|zhua|zhui|zhun|zhuo|zong|zuan|ang|bai|ban|bao|bei|ben|bie|bin|cai|can|cao|cen|cha|che|chi|chu|cou|cui|cun|cuo|dai|dan|dao|dei|den|dia|die|diu|dou|dui|dun|duo|eng|fan|fei|fen|fou|gai|gan|gao|gei|gen|gou|gua|gui|gun|guo|hai|han|hao|hei|hen|hou|hua|hui|hun|huo|jia|jie|jin|jiu|jue|jun|kai|kan|kao|kei|ken|kou|kua|kui|kun|kuo|lai|lan|lao|lei|lia|lie|lin|liu|lou|lun|luo|lve|mai|man|mao|mei|men|mie|min|miu|mou|nai|nan|nao|nei|nen|nie|nin|niu|nou|nun|nuo|nve|pai|pan|pao|pei|pen|pie|pin|pou|qia|qie|qin|qiu|que|qun|ran|rao|ren|rou|rua|rui|run|ruo|sai|san|sao|sen|sha|she|shi|shu|sou|sui|sun|suo|tai|tan|tao|tei|tie|tou|tui|tun|tuo|wai|wan|wei|wen|xia|xie|xin|xiu|xue|xun|yan|yao|yin|you|yue|yun|zai|zan|zao|zei|zen|zha|zhe|zhi|zhu|zou|zui|zun|zuo|ai|an|ao|ba|bi|bo|bu|ca|ce|ci|cu|da|de|di|du|ei|en|er|fa|fo|fu|ga|ge|gu|ha|he|hu|ji|ju|ka|ke|ku|la|le|li|lo|lu|lv|ma|me|mi|mo|mu|na|ne|ni|nu|nv|ou|pa|pi|po|pu|qi|qu|re|ri|ru|sa|se|si|su|ta|te|ti|tu|wa|wo|wu|xi|xu|ya|ye|yi|yo|yu|za|ze|zi|zu|zh|ch|sh|[a-z])'
)
ALL_ENGLISH = re.compile(r'[a-zA-Z\s]+$')
ALL_WORDS = re.compile(r'\w+')
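# e.g. PINYIN_TOKENIZER.findall('wangjiawei') -> ['wang', 'jia', 'wei'];
# PINYIN_TOKENIZER.findall('wjw') -> ['w', 'j', 'w'] (single letters are the
# final fallback branch of the longest-match-first alternation).
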
from enum import StrEnum
from typing import Literal, Optional

import orjson
from pydantic import BaseModel, model_serializer
from sqlalchemy import cast, func
from sqlalchemy.dialects.postgresql import ARRAY, TEXT, array
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.types import UserDefinedType

from util.regex import ALL_ENGLISH, PINYIN_TOKENIZER
from util.sqlalchemy import kw_kv


class ParadeDBTokenizerEnum(StrEnum):
  """https://docs.paradedb.com/documentation/indexing/tokenizers#available-tokenizers"""

  literal = 'literal'
  literal_normalized = 'literal_normalized'
  simple = 'simple'
  regex_pattern = 'regex_pattern'
  whitespace = 'whitespace'
  ngram = 'ngram'
  source_code = 'source_code'
  # splits by character
  chinese_compatible = 'chinese_compatible'
  # splits by word
  chinese_lindera = 'chinese_lindera'
  jieba = 'jieba'
  # splits by character
  unicode_words = 'unicode_words'
  icu = 'icu'


class TokenizerType(UserDefinedType):
  """
  Represents a ParadeDB / pg_search `pdb.*` tokenizer type, e.g. `pdb.unicode_words`.
  """

  cache_ok = True

  def __init__(self, tokenizer: ParadeDBTokenizerEnum, *args):
    super().__init__()
    self.tokenizer = tokenizer
    self.args = args

  def get_col_spec(self, **kw):
    # This defines how the type is emitted in DDL or CAST
    if not self.args:
      return f'pdb.{self.tokenizer}'
    _args = []
    for i in self.args:
      match i:
        case str():
          _args.append(f"'{i}'")
        case _:
          _args.append(str(i))
    return f'pdb.{self.tokenizer}({", ".join(_args)})'

  def bind_processor(self, dialect):
    # if you plan to send Python values of this type -> DB, define this
    def process(value):
      # Typically you won't send this type directly, so you can just return value
      return value

    return process

  def result_processor(self, dialect, coltype):
    # when reading from DB
    def process(value):
      # value might be a list of tokens (text[])
      return value

    return process
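
# e.g. TokenizerType(ParadeDBTokenizerEnum.ngram, 2, 3).get_col_spec() renders
# as 'pdb.ngram(2, 3)'; string arguments are wrapped in single quotes.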


class ParadeDBStemmer(StrEnum):
  """https://docs.paradedb.com/api-reference/indexing/token_filters#stemming"""

  Arabic = 'Arabic'
  Danish = 'Danish'
  Dutch = 'Dutch'
  English = 'English'
  Finnish = 'Finnish'
  French = 'French'
  German = 'German'
  Greek = 'Greek'
  Hungarian = 'Hungarian'
  Italian = 'Italian'
  Norwegian = 'Norwegian'
  Portuguese = 'Portuguese'
  Romanian = 'Romanian'
  Russian = 'Russian'
  Spanish = 'Spanish'
  Swedish = 'Swedish'
  Tamil = 'Tamil'
  Turkish = 'Turkish'


class ParadeDBTokenizer(BaseModel):
  type: ParadeDBTokenizerEnum
  lowercase: bool | None = None
  remove_long: int | None = None
  stemmer: ParadeDBStemmer | None = None
  alias: str | None = None

  def sa(self, *extra):
    # `extra` lets subclasses pass positional type arguments (pattern,
    # min_gram, ...) ahead of the shared keyword-style options.
    args = []
    if self.lowercase is not None:
      args.append(f'lowercase={self.lowercase}')
    if self.remove_long is not None:
      args.append(f'remove_long={self.remove_long}')
    if self.stemmer is not None:
      args.append(f'stemmer={self.stemmer}')
    if self.alias is not None:
      args.append(f'alias={self.alias}')
    return TokenizerType(self.type, *extra, *args)

  @model_serializer(mode='plain')
  def dump(self) -> dict:
    # only the tokenizer type is serialized here; the remaining options are
    # emitted through sa() as SQL type arguments
    kwargs = {'type': self.type}
    return kwargs


class ParadeDBTokenizerICU(ParadeDBTokenizer):
  type: Literal[ParadeDBTokenizerEnum.icu] = ParadeDBTokenizerEnum.icu


class ParadeDBTokenizerLiteral(ParadeDBTokenizer):
  type: Literal[ParadeDBTokenizerEnum.literal] = ParadeDBTokenizerEnum.literal


class ParadeDBTokenizerWhitespace(ParadeDBTokenizer):
  type: Literal[ParadeDBTokenizerEnum.whitespace] = (
    ParadeDBTokenizerEnum.whitespace
  )


class ParadeDBTokenizerUnicodeWords(ParadeDBTokenizer):
  type: Literal[ParadeDBTokenizerEnum.unicode_words] = (
    ParadeDBTokenizerEnum.unicode_words
  )


class ParadeDBTokenizerRegexPattern(ParadeDBTokenizer):
  type: Literal[ParadeDBTokenizerEnum.regex_pattern] = (
    ParadeDBTokenizerEnum.regex_pattern
  )
  pattern: str

  def sa(self):
    return super().sa(self.pattern)


class ParadeDBTokenizerNgram(ParadeDBTokenizer):
  type: Literal[ParadeDBTokenizerEnum.ngram] = ParadeDBTokenizerEnum.ngram
  min_gram: int
  max_gram: int

  def sa(self):
    return super().sa(self.min_gram, self.max_gram)


ParadeDBTokenizers = (
  ParadeDBTokenizerNgram
  | ParadeDBTokenizerRegexPattern
  | ParadeDBTokenizerUnicodeWords
  | ParadeDBTokenizerICU
  | ParadeDBTokenizerLiteral
  | ParadeDBTokenizerWhitespace
  | ParadeDBTokenizer
)
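
# e.g. ParadeDBTokenizerNgram(min_gram=2, max_gram=3).sa() builds
# TokenizerType(ParadeDBTokenizerEnum.ngram, 2, 3), rendered in DDL as
# pdb.ngram(2, 3).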


class ParadeDBRecord(StrEnum):
  basic = 'basic'
  freq = 'freq'
  position = 'position'


class ParadeDBNormalizer(StrEnum):
  raw = 'raw'
  lowercase = 'lowercase'


class ParadeDBIndexWith(BaseModel):
  class TextField(BaseModel):
    indexed: Optional[bool] = None
    stored: Optional[bool] = None
    fast: Optional[bool] = None
    fieldnorms: Optional[bool] = None
    tokenizer: Optional[ParadeDBTokenizers] = None
    record: Optional[ParadeDBRecord] = None
    normalizer: Optional[ParadeDBNormalizer] = None

  class NumericField(BaseModel):
    indexed: Optional[bool] = None
    stored: Optional[bool] = None
    fast: Optional[bool] = None

  class BooleanField(BaseModel):
    indexed: Optional[bool] = None
    stored: Optional[bool] = None
    fast: Optional[bool] = None

  class JsonField(BaseModel):
    indexed: Optional[bool] = None
    stored: Optional[bool] = None
    fast: Optional[bool] = None
    expand_dots: Optional[bool] = None
    tokenizer: Optional[ParadeDBTokenizers] = None
    record: Optional[ParadeDBRecord] = None
    normalizer: Optional[ParadeDBNormalizer] = None

  class DatetimeField(BaseModel):
    indexed: Optional[bool] = None
    stored: Optional[bool] = None
    fast: Optional[bool] = None

  key_field: str
  text_fields: dict[str, TextField] | None = None
  numeric_fields: dict[str, NumericField] | None = None
  boolean_fields: dict[str, BooleanField] | None = None
  json_fields: dict[str, JsonField] | None = None
  datetime_fields: dict[str, DatetimeField] | None = None

  @model_serializer(mode='plain')
  def dump(self):
    kwargs = {'key_field': self.key_field}
    for name in (
      'text_fields',
      'numeric_fields',
      'boolean_fields',
      'json_fields',
      'datetime_fields',
    ):
      fields = getattr(self, name)
      if fields is not None:
        kwargs[name] = orjson.dumps(
          {k: v.model_dump(exclude_none=True) for k, v in fields.items()}
        ).decode()
    # single-quote every value so it can be spliced into SQL options
    for k, v in kwargs.items():
      kwargs[k] = "'" + v + "'"
    return kwargs
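
# e.g. (illustrative) ParadeDBIndexWith(
#   key_field='id',
#   text_fields={'name': ParadeDBIndexWith.TextField(fast=True)},
# ).dump() yields {'key_field': "'id'", 'text_fields': ...}: every value is
# single-quoted and the field dicts are JSON-encoded, e.g. '{"name":{"fast":true}}'.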


class ParadedbV2:
  """paradedb(pg_search) v2 helper

  NOTICE: for array types, need to pass in value as list
  """

  @staticmethod
  def functions():
    """safe functions under pdb.*"""
    return ['all', 'score', 'regex', 'regex_phrase', 'parse']

  @staticmethod
  def all(column: ColumnElement):
    """https://docs.paradedb.com/documentation/query-builder/compound/all#all"""
    return column.op('@@@')(func.pdb.all())

  @staticmethod
  def score(column: ColumnElement):
    """https://docs.paradedb.com/documentation/sorting/score"""
    return func.pdb.score(column)

  @staticmethod
  def term(column: ColumnElement, value: str | list):
    """
    if pass in list, will be treated as term set
    https://docs.paradedb.com/documentation/full-text/term
    """
    if isinstance(value, list):
      value = array(value)
    return column.op('===')(value)
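
  # e.g., with a hypothetical table t: pdb.term(t.c.category, 'books') renders
  # category === 'books' (via a bind parameter); a list value such as
  # ['a', 'b'] becomes a term-set match.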

  @staticmethod
  def match_disjunction(
    column: ColumnElement,
    value: str | list,
  ):
    """
    or
    https://docs.paradedb.com/documentation/full-text/match#match-disjunction
    """
    if isinstance(value, list):
      value = cast(array(value, type_=TEXT), ARRAY(TEXT))
    return column.op('|||')(value)

  @staticmethod
  def match_conjunction(
    column: ColumnElement,
    value: str | list,
  ):
    """
    and
    https://docs.paradedb.com/documentation/full-text/match#match-conjunction
    """
    if isinstance(value, list):
      value = cast(array(value, type_=TEXT), ARRAY(TEXT))
      # TODO(Deo): this may raise an error occasionally
      # value = array(value)
    return column.op('&&&')(value)

  @staticmethod
  def regex(column: ColumnElement, pattern: str):
    """https://docs.paradedb.com/documentation/query-builder/term/regex"""
    return column.op('@@@')(func.pdb.regex(pattern))

  @staticmethod
  def regex_phrase(
    column: ColumnElement,
    *patterns: str,
    slope: int | None = None,
    max_expansions: int | None = None,
  ):
    """https://docs.paradedb.com/documentation/query-builder/phrase/regex-phrase"""
    params = [array(patterns)]
    if slope is not None:
      params.append(slope)
    if max_expansions is not None:
      params.append(max_expansions)
    return column.op('@@@')(func.pdb.regex_phrase(*params))

  @staticmethod
  def pinyin_regex_phrase(
    column: ColumnElement,
    value: str,
    slope: int | None = None,
    max_expansions: int | None = None,
    generated_pinyin: bool = False,
  ):
    """split pinyin and use regex_phrase

    wangjiawei -> wang.* jia.* wei.*
    wjw -> w.* j.* w.*

    if generated_pinyin is True, the patterns target pinyin that was generated
    with '|' as the separator, e.g. w|j|w
    """

    if not ALL_ENGLISH.match(value):
      # only latin letters/whitespace can be split into pinyin
      return None
    pinyins = PINYIN_TOKENIZER.findall(value.lower())

    if generated_pinyin:
      # generate pinyin with | as separator
      pinyins = [rf'.*\|{i}.*' for i in pinyins]
    else:
      # generate pinyin with .* as separator
      pinyins = [rf'{i}.*' for i in pinyins]
    return ParadedbV2.regex_phrase(
      column,
      *pinyins,
      slope=slope,
      max_expansions=max_expansions,
    )
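
  # e.g. pdb.pinyin_regex_phrase(t.c.name, 'wjw') searches with the patterns
  # ['w.*', 'j.*', 'w.*']; with generated_pinyin=True they become
  # ['.*\|w.*', '.*\|j.*', '.*\|w.*'].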

  @staticmethod
  def parse(
    column: ColumnElement,
    query: str,
    lenient: bool = False,
    conjunction_mode: bool = False,
  ):
    """https://docs.paradedb.com/documentation/query-builder/compound/query-parser"""
    args = [query]
    if lenient:
      args.append(kw_kv('lenient', lenient))
    if conjunction_mode:
      args.append(kw_kv('conjunction_mode', conjunction_mode))
    return column.op('@@@')(func.pdb.parse(*args))
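
  # e.g. pdb.parse(t.c.description, 'red shoes', conjunction_mode=True) renders
  # as description @@@ pdb.parse('red shoes', conjunction_mode => 'True')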


pdb = ParadedbV2()
