Skip to content

Instantly share code, notes, and snippets.

@arenasys
Last active January 9, 2023 09:20
Show Gist options
  • Select an option

  • Save arenasys/9161a27bbcbf86f28ab3aeae8bd22fe0 to your computer and use it in GitHub Desktop.

Select an option

Save arenasys/9161a27bbcbf86f28ab3aeae8bd22fe0 to your computer and use it in GitHub Desktop.
prompt parser
import lark
class WeightedTree(lark.Tree):
    """Parse-tree node type produced by ``prompt_grammar``.

    It adds nothing to ``lark.Tree``. Having a dedicated subclass lets the
    prompt walker tell rule nodes apart from leaf tokens with a simple type
    check.
    """
# Lark grammar for the prompt mini-language; parse trees use WeightedTree nodes.
#   start      - sequence of prompts. Stray "[", "]", "(", ")" or ":" characters
#                are matched by the fallback regex and kept as literal tokens
#                ("!" keeps all tokens), instead of failing the parse.
#   emphasis   - "(text)"
#   deemphasis - "[text]"
#   numeric    - "(text:N)"
#   scheduled  - "[from:to:N]" or "[to:N]"
#   alternate  - "[a|b|...]"
#   plain      - run of ordinary characters; a backslash escapes the next char.
# N is lark's common SIGNED_NUMBER. Whitespace around it is matched by the
# underscore-prefixed _WHITESPACE terminal, so it is filtered out of the tree.
# Other whitespace is the named WHITESPACE terminal and is kept.
# NOTE: for "[to:N]" the missing "from" part shows up as a None child only
# because lark's maybe_placeholders option is on (the default since lark 1.0).
# NOTE(review): "|" is excluded from both `plain` and the start fallback, so a
# "|" outside "[...]" (e.g. "a | b") presumably makes parse() raise — confirm.
prompt_grammar = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasis | deemphasis | numeric | scheduled | alternate | plain | WHITESPACE)*
emphasis: "(" prompt ")"
deemphasis: "[" prompt "]"
numeric: "(" prompt ":" [_WHITESPACE] NUMBER [_WHITESPACE]")"
scheduled: "[" [prompt ":"] prompt ":" [_WHITESPACE] NUMBER [_WHITESPACE]"]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
_WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""", tree_class=WeightedTree)
def parse_prompt(prompt, steps):
    """Expand a weighted/scheduled prompt into a per-step schedule.

    ``prompt`` is parsed once with ``prompt_grammar``. The tree is then
    evaluated for every step from ``steps`` down to 1. Steps are numbered
    starting at 1.

    Syntax handled while walking the tree:
      * ``(text)`` multiplies the weight of ``text`` by 1.1.
      * ``[text]`` divides the weight of ``text`` by 1.1.
      * ``(text:N)`` multiplies the weight of ``text`` by ``N``.
      * ``[a:b:N]`` uses ``a`` while ``step <= N`` and ``b`` afterwards.
        ``[b:N]`` uses nothing while ``step <= N`` and ``b`` afterwards.
      * ``[a|b|c]`` cycles through the alternatives, using ``a`` at step 1,
        ``b`` at step 2, and so on.

    Args:
        prompt: The prompt text.
        steps: Number of steps to schedule.

    Returns:
        A list of ``(step, chunks)`` tuples in ascending ``step`` order.
        ``chunks`` is a list of ``[text, weight]`` pairs; adjacent text with an
        identical weight is merged into one pair. Each entry applies to every
        step after the previous entry's step, up to and including its own
        ``step``. The list is empty when ``steps < 1``.

    Raises:
        Whatever ``prompt_grammar.parse`` raises for input the grammar cannot
        match (lark parse errors).
    """

    def extract(tree, step):
        """Flatten ``tree`` into merged ``[text, weight]`` chunks for ``step``."""

        def propagate(node, output, step, weight):
            if isinstance(node, WeightedTree):
                children = node.children
                kind = node.data
                if kind == "emphasis":
                    weight *= 1.1
                elif kind == "deemphasis":
                    weight /= 1.1
                elif kind == "numeric":
                    # children: [prompt, NUMBER]
                    weight *= float(node.children[1])
                    children = [node.children[0]]
                elif kind == "scheduled":
                    # children: [before, after, NUMBER]. `before` is a None
                    # placeholder for "[after:N]", or is absent entirely when
                    # lark's maybe_placeholders is disabled.
                    *parts, when = node.children
                    before, after = (None, *parts)[-2:]
                    children = [before if step <= float(when) else after]
                elif kind == "alternate":
                    # Steps start at 1, so step 1 must select the first option.
                    children = [children[(step - 1) % len(children)]]
                # The weight is passed down as an argument rather than stored
                # on the (shared, reused-per-step) tree.
                for child in children:
                    propagate(child, output, step, weight)
            elif node:
                # Leaf token; None placeholders are skipped by the truthiness test.
                text = str(node)
                if output and output[-1][1] == weight:
                    output[-1][0] += text
                else:
                    output.append([text, weight])

        output = []
        propagate(tree, output, step, 1.0)
        return output

    tree = prompt_grammar.parse(prompt)
    # Walk from the last step to the first and record an entry only when the
    # chunks change. Each entry therefore marks the last step its chunks apply
    # to. Reverse at the end for ascending order.
    schedules = []
    for step in range(steps, 0, -1):
        scheduled = extract(tree, step)
        if not schedules or schedules[-1][1] != scheduled:
            schedules.append((step, scheduled))
    schedules.reverse()
    return schedules
# Demo: print the step schedule produced for a few sample prompts.
prompts = [
    ("a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).)", 30),
    ("a [blue:[green:(crystal):8]:5] jeweled crown", 10),
    ("a [(dog)|[cat]|cow] standing in a field", 7),
]
for prompt, steps in prompts:
    # Parse before printing so a parse error aborts before any output for it.
    schedule = parse_prompt(prompt, steps)
    print(f"{prompt} {steps}")
    for step, line in schedule:
        print(f"\t {step} {line}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment