# Script dependencies. `exla` and `nx` both come from the elixir-nx/nx GitHub
# repository (one sparse checkout per subdirectory); `override: true` makes this
# `nx` entry win over the `nx` requirement declared by other dependencies.
Mix.install([
  {:exla, github: "elixir-nx/nx", sparse: "exla"},
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:pythonx, "~> 0.4.7"}
])

# pyproject-style description of the Python side: Python 3.13.x with a pinned
# NumPy release.
pyproject_toml = """
[project]
name = "project"
version = "0.0.0"
requires-python = "==3.13.*"
dependencies = [
"numpy==2.2.2"
]
"""

Pythonx.uv_init(pyproject_toml)
defmodule NxComputation do
  @moduledoc """
  Example of a `defn` computation that calls back into regular Elixir code,
  which in turn exchanges messages with another process.
  """

  import Nx.Defn

  @doc """
  Callback handed to `Nx.runtime_call` by `foo/2`.

  Sends `{:runtime_callback, tensor, self()}` to `pid`, then waits for a
  `{:result, value}` message and returns `Nx.add(tensor, value)`.

  The wait has no timeout: the calling process blocks until a matching message
  arrives.
  """
  def runtime_callback(tensor, pid) do
    send(pid, {:runtime_callback, tensor, self()})

    addend =
      receive do
        {:result, value} -> value
      end

    Nx.add(tensor, addend)
  end

  @doc """
  Multiplies `x` by 10, routes the product through `runtime_callback/2`, and
  divides what the callback returns by 10.

  `opts[:pid]` is forwarded to the callback as its second argument; it is the
  process the callback notifies and then waits on.
  """
  defn foo(x, opts) do
    scaled = Nx.multiply(x, 10)

    # `scaled` is passed twice, exactly as in the call being wrapped
    # (presumably once as the output template and once as the argument).
    scaled
    |> Nx.runtime_call(scaled, opts[:pid], &runtime_callback/2)
    |> Nx.divide(10)
  end
end
# Drive NxComputation.foo/2 from a task so that this process stays free to
# answer the runtime callback.
pid = self()

result =
  Task.async(fn ->
    foo_args = [Nx.iota({4}), [pid: pid]]
    EXLA.jit_apply(&NxComputation.foo/2, foo_args)
  end)

# Wait for the callback to announce itself; it includes the pid of the process
# that is blocked waiting for our reply.
target_pid =
  receive do
    {:runtime_callback, tensor, callback_pid} ->
      IO.inspect(tensor, label: "received from callback")
      callback_pid
  end

# Unblock the callback, then collect the task's result.
send(target_pid, {:result, 100})
Task.await(result)
defmodule NxPythonBridgeComputation do
  @moduledoc """
  Example of bridging a `defn` computation to Python/NumPy via
  `Nx.runtime_call` and Pythonx.
  """

  import Nx.Defn

  # Python program evaluated by python_callback/1. It expects the globals
  # `data` (raw tensor bytes), `type` (Nx type as a tuple such as ("s", 32))
  # and `shape`, and its last expression evaluates to either
  # ("ok", (bytes, shape, nx_type)) or ("error", message).
  @python_source """
  import numpy as np

  def to_nx_type(dtype):
      print(dtype)
      if dtype == np.int32:
          return ("s", 32)
      if dtype == np.float32:
          return ("f", 32)
      if dtype == np.float64:
          return ("f", 64)
      return None

  def main(data, type, shape):
      # Example type conversion, this is just so we
      # can experiment with different input types
      if type == ("s", 32):
          dtype = np.int32
      elif type == ("f", 32):
          dtype = np.float32
      else:
          return "error", "invalid type"

      array = np.frombuffer(data, dtype=dtype).reshape(shape)

      # Upcast before scaling so the result is float64 for every accepted
      # input type. A float32 array multiplied by a Python float stays
      # float32, which would not match the f64 output declared on the
      # Elixir side.
      result = array.astype(np.float64) * 20.0

      nx_type = to_nx_type(result.dtype)
      if nx_type is None:
          return "error", "unsupported result dtype"

      return "ok", (result.tobytes(), result.shape, nx_type)

  main(data, type, shape)
  """

  @doc """
  Multiplies `tensor` by 20.0 in NumPy and returns the product as an Nx tensor.

  The tensor's bytes, type and shape are handed to the embedded Python program;
  the bytes, shape and type it returns are turned back into a tensor. Signed
  32-bit integer and 32-bit float inputs are accepted, and the result is always
  64-bit float, matching the output template built in `run/1`.

  Raises `RuntimeError` with the Python-provided message when the Python side
  reports an error (for example, an unsupported input type).
  """
  def python_callback(tensor) do
    # Raw bytes backing the tensor.
    # NOTE(review): the original note says that for Nx.BinaryBackend this is
    # just a reference to the existing data - confirm.
    data = Nx.to_binary(tensor)

    # NOTE(review): the original note says the binary is passed to Python by
    # reference, so the NumPy array points at BEAM-owned memory - confirm
    # against the Pythonx documentation.
    globals = %{"data" => data, "type" => Nx.type(tensor), "shape" => Nx.shape(tensor)}

    {result, _globals} = Pythonx.eval(@python_source, globals)

    case Pythonx.decode(result) do
      {"ok", {binary, shape, {type, size}}} ->
        # NOTE(review): the original note says the bytes come back as a
        # reference rather than a copy - confirm.
        binary
        |> Nx.from_binary({String.to_existing_atom(type), size})
        |> Nx.reshape(shape)

      {"error", message} when is_binary(message) ->
        raise message
    end
  end

  @doc """
  Runs `python_callback/1` on `x` from inside a `defn`.

  The expected output is described by a template with `x`'s shape and type
  `{:f, 64}`. Returns `{x, result}`.
  """
  defn run(x) do
    out = %{Nx.to_template(x) | type: {:f, 64}}
    result = Nx.runtime_call(out, x, &python_callback/1)
    {x, result}
  end
end
# JIT-compile and run the Python bridge on a four-element signed 32-bit iota,
# with the `cache: false` option passed to EXLA.
result =
  EXLA.jit_apply(&NxPythonBridgeComputation.run/1, [Nx.iota({4}, type: {:s, 32})], cache: false)