Last active
January 9, 2023 09:20
-
-
Save arenasys/9161a27bbcbf86f28ab3aeae8bd22fe0 to your computer and use it in GitHub Desktop.
prompt parser
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import lark | |
class WeightedTree(lark.Tree):
    """Parse-tree node type handed to the prompt grammar as its ``tree_class``.

    It adds nothing to ``lark.Tree`` itself; ``parse_prompt`` stores a
    ``weight`` attribute on instances while it walks the tree.
    """
# Grammar for the prompt mini-language:
#   start      - any mix of prompts and stray bracket/paren/colon characters
#                (the leading "!" keeps those tokens in the tree)
#   emphasis   - "(text)"
#   deemphasis - "[text]"
#   numeric    - "(text:NUMBER)"
#   scheduled  - "[first:second:NUMBER]", where "first:" is optional
#   alternate  - "[a|b|...]" with at least two options
#   plain      - a run of characters that are not \ [ ] ( ) : | , or any
#                backslash-escaped character
_PROMPT_GRAMMAR_TEXT = r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasis | deemphasis | numeric | scheduled | alternate | plain | WHITESPACE)*
emphasis: "(" prompt ")"
deemphasis: "[" prompt "]"
numeric: "(" prompt ":" [_WHITESPACE] NUMBER [_WHITESPACE]")"
scheduled: "[" [prompt ":"] prompt ":" [_WHITESPACE] NUMBER [_WHITESPACE]"]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
_WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
"""

# Compiled parser; every tree node it produces is a WeightedTree.
prompt_grammar = lark.Lark(_PROMPT_GRAMMAR_TEXT, tree_class=WeightedTree)
def parse_prompt(prompt, steps):
    """Resolve a prompt into the weighted text fragments used at each step.

    The prompt is parsed with ``prompt_grammar`` and then evaluated for every
    step from ``steps`` down to 1. Consecutive steps that resolve to the same
    fragments are collapsed into a single entry.

    Args:
        prompt: Prompt text written in the grammar's mini-language.
        steps: Total number of steps to evaluate (steps are numbered from 1).

    Returns:
        A list of ``(step, fragments)`` tuples in ascending step order.
        ``fragments`` is a list of ``[text, weight]`` pairs, and ``step`` is
        the highest step of the run of steps that share those fragments.
        The list is empty when ``steps`` is less than 1.
    """

    def propagate(node, output, step, weight):
        """Walk ``node`` depth-first, appending ``[text, weight]`` to ``output``."""
        if isinstance(node, WeightedTree):
            node.weight = weight
            children = node.children
            kind = node.data
            if kind == "emphasis":
                node.weight *= 1.1
            elif kind == "deemphasis":
                node.weight /= 1.1
            elif kind == "numeric":
                # "(text:NUMBER)": scale by the number, descend only into text.
                node.weight *= float(node.children[1])
                children = [node.children[0]]
            elif kind == "scheduled":
                # Children are the branch prompt(s) followed by the NUMBER.
                # When the optional first branch is absent the parser may
                # either leave a None placeholder or omit the slot entirely
                # (depends on lark's maybe_placeholders setting - confirm for
                # the installed version); both layouts are handled here.
                *branches, boundary = node.children
                before = branches[0] if len(branches) > 1 else None
                after = branches[-1]
                children = [before] if step <= float(boundary) else [after]
            elif kind == "alternate":
                # Cycle through the options, one per step.
                children = [children[step % len(children)]]
            for child in children:
                propagate(child, output, step, node.weight)
        elif node:
            # Token leaf. None placeholders and empty tokens are falsy and
            # contribute nothing.
            if output and output[-1][1] == weight:
                # Same weight as the previous fragment: merge the text.
                output[-1][0] += str(node)
            else:
                output.append([str(node), weight])

    def extract(tree, step):
        """Return the ``[text, weight]`` fragments of ``tree`` at ``step``."""
        output = []
        propagate(tree, output, step, 1.0)
        return output

    tree = prompt_grammar.parse(prompt)
    schedules = []
    # Walk from the last step backwards so each recorded step is the final
    # step of its run of identical fragments.
    for step in range(steps, 0, -1):
        scheduled = extract(tree, step)
        if not schedules or schedules[-1][1] != scheduled:
            schedules.append((step, scheduled))
    schedules.reverse()
    return schedules
# Demo inputs: each entry is (prompt text, total number of steps).
prompts = [
    ("a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).)", 30),
    ("a [blue:[green:(crystal):8]:5] jeweled crown", 10),
    ("a [(dog)|[cat]|cow] standing in a field", 7),
]

# Resolve every demo prompt and print its schedule, one entry per line.
for text, total_steps in prompts:
    resolved = parse_prompt(text, total_steps)
    print(text, total_steps)
    for last_step, fragments in resolved:
        print("\t", last_step, fragments)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment