Skip to content

Instantly share code, notes, and snippets.

@mlechu
Created July 26, 2025 22:12
Show Gist options
  • Select an option

  • Save mlechu/81675ac09b53c6e81f491a9996ccfd78 to your computer and use it in GitHub Desktop.

Select an option

Save mlechu/81675ac09b53c6e81f491a9996ccfd78 to your computer and use it in GitHub Desktop.
using JuliaSyntax
using JuliaLowering
const JS = JuliaSyntax
const JL = JuliaLowering
"""
    jsparse(s)

Parse string `s` with JuliaSyntax using the `:statement` rule and build a
`JL.SyntaxTree` from the resulting parse stream.
"""
function jsparse(s)
    parsed_stream = JS.parse!(JS.ParseStream(s); rule=:statement)
    return JS.build_tree(JL.SyntaxTree, parsed_stream)
end
"""
    jlower(mod, st0)

Run the JuliaLowering passes on syntax tree `st0` for module `mod`, in order:
`expand_forms_1`, `expand_forms_2`, `resolve_scopes`, `convert_closures`,
`linearize_ir`. Each pass receives the context and tree produced by the
previous one. Returns the `(context, tree)` tuple of the final pass.
"""
function jlower(mod, st0)
    expanded1_ctx, expanded1 = JL.expand_forms_1(mod, st0)
    expanded2_ctx, expanded2 = JL.expand_forms_2(expanded1_ctx, expanded1)
    scoped_ctx, scoped = JL.resolve_scopes(expanded2_ctx, expanded2)
    closure_ctx, closure_conv = JL.convert_closures(scoped_ctx, scoped)
    linear_ctx, linear_ir = JL.linearize_ir(closure_ctx, closure_conv)
    return (linear_ctx, linear_ir)
end
"""
    orig_nodes(st5::JL.SyntaxTree)

Follow the `source` provenance of `st5` (through `JL.NodeId` links and tuples
of sources) and collect, into a fresh `JL.SyntaxList`, every node whose
`source` is a `JL.SourceRef` or `LineNumberNode`. The result may contain
duplicates.
"""
function orig_nodes(st5::JL.SyntaxTree)
    provenance = st5.source
    collected = JL.SyntaxList(st5)
    return _orig_nodes!(st5, provenance, collected)
end
"""
    _orig_nodes!(st5, src, out)

Recursive worker for `orig_nodes`. Interpret `src` as the provenance of node
`st5` and push matching nodes onto `out`:

- `JL.SourceRef` / `LineNumberNode`: push `st5` itself.
- `JL.NodeId`: look that node up in `st5`'s graph and recurse on it with its
  own `source`.
- `Tuple`: handle each element as a provenance entry of `st5`.

Any other kind of `src` contributes nothing. Duplicates in `out` are OK.
Returns `out`.
"""
function _orig_nodes!(st5, src, out)
    if src isa JL.SourceRef || src isa LineNumberNode
        push!(out, st5)
    elseif src isa JL.NodeId
        src_node = JL.SyntaxTree(st5._graph, src)
        _orig_nodes!(src_node, src_node.source, out)
    elseif src isa Tuple
        # Iterate `src` itself, not `st5.source`: for a nested tuple the two
        # differ, and re-reading `st5.source` would recurse forever.
        for s in src
            _orig_nodes!(st5, s, out)
        end
    end
    return out
end
"""
    prepare_attrs(st::JL.SyntaxTree)

Return a tree for the same node as `st` whose graph has a mutable,
`Dict{Symbol,Any}` attribute table ("unfrozen") with an extra, initially empty
`:type` attribute (`Dict{Int,Any}`). Types can then be attached with
`JL.setattr!(graph, id; type=T)`.

Only the outer attribute table is new. The edge arrays and the existing
per-attribute dicts are shared with `st`'s graph.
"""
function prepare_attrs(st::JL.SyntaxTree)
    g = JL.syntax_graph(st)
    # Use an explicitly `Any`-valued table. `Dict(pairs(...))` would infer its
    # value type from the existing attribute dicts, so the `Dict{Int,Any}`
    # below could be converted to a narrower dict type (unable to hold
    # arbitrary types) or fail to convert at all.
    attrs = Dict{Symbol,Any}(pairs(g.attributes))
    attrs[:type] = Dict{Int,Any}()
    return JL.SyntaxTree(JL.SyntaxGraph(g.edge_ranges, g.edges, attrs), st._id)
end
"""
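    annotate_types(mod::Module, st0::JL.SyntaxTree, tt=nothing, world=Base.get_world_counter())

Main entrypoint. Lower `st0` with JuliaLowering for module `mod` and evaluate
the result, which must produce a `Function`. Infer that function for argument
types `tt` (default: `Base.default_tt(fn)`). Return a copy of `st0` whose
nodes carry a `type` attribute holding the inferred SSA-value and slot types.
The `world` argument is currently unused.

Note that JL-lowering and type annotation currently have to happen in one
operation, since the default lowerer doesn't preserve the information we need.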
"""
function annotate_types(mod::Module, st0::JL.SyntaxTree,
                        @nospecialize(tt=nothing),
                        world::UInt=Base.get_world_counter())
    # Lower the original tree first, then make an attribute-mutable copy of
    # `st0` that will receive the `type` annotations.
    _, st5 = jlower(mod, st0)
    annotated = prepare_attrs(st0)
    ex = JL.to_lowered_expr(mod, st5)
    @assert ex.head === :thunk && ex.args[1] isa Core.CodeInfo
    # Evaluate in `mod`, the module the code was lowered for. A plain `eval`
    # would use whatever module this file was loaded into.
    fn = Core.eval(mod, ex)
    !isa(fn, Function) && throw("Not a function")
    tt = something(tt, Base.default_tt(fn))
    # `world` is intentionally not used: its default is sampled before the
    # `eval` above defined `fn`, and `fn` has no methods in that world.
    _, ci, _ = get_inferred_result(fn, tt, Base.get_world_counter())

    # Hack: we want the JL code_info of the method body, but we only have the
    # one for the whole function definition, which contains it. Pick the
    # `method` child whose code_info has as many statements as the inferred code.
    method_idx = findfirst(JS.children(st5[1])) do c
        JS.kind(c) === JS.K"method" &&
            JS.numchildren(c) === 3 &&
            JS.kind(c[3]) === JS.K"code_info" &&
            JS.numchildren(c[3][1]) === length(ci.code)
    end
    method_idx === nothing &&
        error("annotate_types: no lowered method body matches the inferred code")
    inner_st5 = st5[1][method_idx][3]
    graph = annotated._graph

    # Assumes SSA value `i` of the inferred code corresponds to statement `i`
    # of the JL code_info (their lengths were matched above).
    for i in eachindex(ci.ssavaluetypes)
        for o in orig_nodes(inner_st5[1][i])
            JL.setattr!(graph, o._id; type=ci.ssavaluetypes[i])
        end
    end
    # Slot `i` of the inferred code is taken to be `inner_st5.slots[i]`.
    for i in eachindex(ci.slottypes)
        jl_slot::JL.Slot = inner_st5.slots[i]
        for o in orig_nodes(JL.SyntaxTree(inner_st5._graph, jl_slot.node_id))
            JL.setattr!(graph, o._id; type=ci.slottypes[i])
        end
    end
    return annotated
end
"""
    annotate_types(mod::Module, s::AbstractString, args...)

Convenience method: parse `s` with `jsparse`, then annotate the resulting
syntax tree. All remaining arguments are forwarded unchanged.
"""
function annotate_types(mod::Module, s::AbstractString, args...)
    tree = jsparse(s)
    return annotate_types(mod, tree, args...)
end
# Not necessary; note that a conversion from JS.SyntaxNode to JL.SyntaxTree is
# defined. (As written, the commented-out method below would call itself forever.)
# annotate_types(mod::Module, sn0::JS.SyntaxNode, args...) = annotate_types(mod::Module, sn0, args...)
# from TypedSyntax.jl
"""
    get_inferred_result(f, tt=Base.default_tt(f), world=Base.get_world_counter())

Find the unique method instance of `f` for argument types `tt` in world `world`.
Return `(mi, ci, rt)`: the `MethodInstance`, its unoptimized typed `CodeInfo`
(requested with `debuginfo=:source`), and the inferred return type.

Errors if no method instance, or more than one, is found.
"""
function get_inferred_result(@nospecialize(f), @nospecialize(tt=Base.default_tt(f)),
                             world::UInt=Base.get_world_counter())
    mis = Base.method_instances(f, tt, world)
    if length(mis) != 1
        sig = sprint(Base.show_tuple_as_call, Symbol(""), Base.signature_type(f, tt))
        isempty(mis) && error("no applicable type-inferred code found for ", sig)
        error("got $(length(mis)) possible type-inferred results for ", sig,
              ", you may need a more specialized signature")
    end
    mi = only(mis)
    # Infer in the same world the method instance was looked up in.
    ci, rt = code_typed1_by_method_instance(mi; optimize=false, debuginfo=:source,
                                            world=world)
    return mi, ci, rt
end
"""
    code_typed1_by_method_instance(mi::Core.MethodInstance; optimize=true,
                                   debuginfo=:default, world=Base.get_world_counter(),
                                   interp=Core.Compiler.NativeInterpreter(world))

Run type inference on `mi` with `interp`, also optimizing when `optimize` is
true. Return `code => rettype` as a `Pair{Core.CodeInfo,Any}`. Line numbers are
stripped when `debuginfo` resolves to `:none`.

Errors in a pure/generated-function context, when `world == typemax(UInt)`, or
when inference does not produce a `CodeInfo`.
"""
function code_typed1_by_method_instance(mi::Core.MethodInstance;
                                        optimize::Bool=true,
                                        debuginfo::Symbol=:default,
                                        world::UInt=Base.get_world_counter(),
                                        interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world))
    (ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
        error("code reflection should not be used from generated functions")
    debuginfo = Base.IRShow.debuginfo(debuginfo)
    code = Core.Compiler.typeinf_code(interp, mi.def::Method, mi.specTypes, mi.sparam_vals, optimize)
    # Validate before reading `rettype`: if no CodeInfo was produced
    # (presumably `nothing`), the field access would fail first and hide
    # this error message.
    code isa Core.CodeInfo || error("no code is available for ", mi)
    rt = code.rettype
    debuginfo === :none && Base.remove_linenums!(code)
    return Pair{Core.CodeInfo,Any}(code, rt)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment