Skip to content

Instantly share code, notes, and snippets.

@polvalente
Last active December 5, 2025 09:25
Show Gist options
  • Select an option

  • Save polvalente/bd0eb1aa8b2a1fa01734019a60b27ca5 to your computer and use it in GitHub Desktop.

Select an option

Save polvalente/bd0eb1aa8b2a1fa01734019a60b27ca5 to your computer and use it in GitHub Desktop.
Example Python-Nx runtime bridge

Untitled notebook

Mix.install [
  {:exla, github: "elixir-nx/nx", sparse: "exla"},
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:pythonx, "~> 0.4.7"}
]

Section

# Provision a Python 3.13 environment with NumPy pinned, for the Python
# bridge computation further down in this notebook.
pyproject_toml = """
[project]
name = "project"
version = "0.0.0"
requires-python = "==3.13.*"
dependencies = [
  "numpy==2.2.2"
]
"""

Pythonx.uv_init(pyproject_toml)
defmodule NxComputation do
  @moduledoc """
  Demonstrates `Nx.runtime_call/4`: a `defn` that hands an intermediate
  tensor to plain Elixir code (`runtime_callback/2`) while the compiled
  computation is running.
  """
  import Nx.Defn

  @doc """
  Runtime callback invoked from inside `foo/2`.

  Sends `{:runtime_callback, tensor, self()}` to `pid`, then waits for a
  `{:result, infeed}` reply and returns `Nx.add(tensor, infeed)`.

  The returned tensor has to agree with the output template passed to
  `Nx.runtime_call/4` in `foo/2` (the shape and type of `tensor`).
  NOTE(review): a reply that promotes the type (e.g. a float `infeed` for an
  integer tensor) would not match that template — confirm how
  `Nx.runtime_call` reports the mismatch.

  `timeout` (milliseconds, or `:infinity`) bounds the wait so that a
  controller which never answers cannot stall the compiled computation
  indefinitely; when it expires this raises. NOTE(review): how a raise
  inside a runtime callback surfaces to the `jit_apply` caller is not shown
  here — confirm with the Nx/EXLA version in use.
  """
  def runtime_callback(tensor, pid, timeout \\ 30_000) do
    send(pid, {:runtime_callback, tensor, self()})

    receive do
      {:result, infeed} ->
        Nx.add(tensor, infeed)
    after
      timeout ->
        raise "runtime_callback: no {:result, _} reply from #{inspect(pid)} within #{timeout}ms"
    end
  end

  @doc """
  Multiplies `x` by 10, passes the intermediate through `runtime_callback/2`,
  then divides the callback's result by 10.

  `opts[:pid]` is forwarded to the callback as its second argument and must
  be the process that will answer with `{:result, infeed}`.
  """
  defn foo(x, opts) do
    y = x * 10
    # First `y`: output template (shape/type the callback must return).
    # Second `y`: the tensor actually handed to the callback.
    result = Nx.runtime_call(y, y, opts[:pid], &runtime_callback/2)
    result / 10
  end
end

# Drive NxComputation.foo/2: run it in a Task, answer its runtime callback
# from this process, then collect the final result.
caller = self()

task =
  Task.async(fn ->
    EXLA.jit_apply(&NxComputation.foo/2, [Nx.iota({4}), [pid: caller]])
  end)

# Wait for runtime_callback/2 to hand over the intermediate tensor together
# with the pid that is blocked waiting for our reply. The wait is bounded so
# a computation that stalls before reaching the callback cannot hang this
# cell forever.
callback_pid =
  receive do
    {:runtime_callback, tensor, from} ->
      IO.inspect(tensor, label: "received from callback")
      from
  after
    30_000 -> raise "timed out waiting for a :runtime_callback message"
  end

# runtime_callback/2 adds this value to the tensor it received.
send(callback_pid, {:result, 100})

Task.await(task)
defmodule NxPythonBridgeComputation do
  @moduledoc """
  Example bridge that runs part of an Nx computation in Python/NumPy via
  `Nx.runtime_call/3` and Pythonx.
  """
  import Nx.Defn

  # Python source evaluated by python_callback/1. It reads the globals
  # `data` (raw tensor bytes), `nx_type` (the Nx type, e.g. ("s", 32)) and
  # `shape` (a tuple), and evaluates to either
  # ("ok", (bytes, shape, (kind, bits))) or ("error", message).
  #
  # The product is always cast to float64 so it matches the {:f, 64} output
  # template declared in run/1, whatever the input type.
  @python_code """
  import numpy as np

  # Nx type -> NumPy dtype used to interpret the incoming buffer.
  NX_TO_NUMPY = {
      ("s", 8): np.int8,
      ("s", 16): np.int16,
      ("s", 32): np.int32,
      ("s", 64): np.int64,
      ("u", 8): np.uint8,
      ("u", 16): np.uint16,
      ("u", 32): np.uint32,
      ("u", 64): np.uint64,
      ("f", 16): np.float16,
      ("f", 32): np.float32,
      ("f", 64): np.float64,
  }

  # NumPy dtype.kind -> Nx type kind for the outgoing buffer.
  NUMPY_KIND_TO_NX = {"i": "s", "u": "u", "f": "f"}

  def to_nx_type(dtype):
      return (NUMPY_KIND_TO_NX[dtype.kind], dtype.itemsize * 8)

  def main(data, nx_type, shape):
      dtype = NX_TO_NUMPY.get(nx_type)
      if dtype is None:
          return "error", "invalid type"

      array = np.frombuffer(data, dtype=dtype).reshape(shape)

      # Explicit cast: e.g. float32 * 20.0 stays float32 in NumPy 2, which
      # would not match the float64 template on the Nx side.
      result = (array * 20.0).astype(np.float64)

      return "ok", (result.tobytes(), result.shape, to_nx_type(result.dtype))

  main(data, nx_type, shape)
  """

  @doc """
  Runtime callback for `run/1`: multiplies `tensor` by 20 in NumPy and
  returns the product as a `{:f, 64}` tensor with the same shape.

  Raises if the tensor's type has no entry in the Python-side NumPy mapping
  (e.g. `{:bf, 16}` or complex types), using the message reported by Python.
  """
  def python_callback(tensor) do
    # NOTE(review): per the original author, for Nx.BinaryBackend
    # `Nx.to_binary/1` returns a reference to the existing data and Pythonx
    # passes the binary by reference, so NumPy's `frombuffer` view points at
    # BEAM-owned memory — confirm against the Pythonx version in use.
    # `array * 20.0` allocates a new array, so that memory is never written.
    globals = %{
      "data" => Nx.to_binary(tensor),
      "nx_type" => Nx.type(tensor),
      "shape" => Nx.shape(tensor)
    }

    {result, _globals} = Pythonx.eval(@python_code, globals)

    case Pythonx.decode(result) do
      {"ok", {data, shape, {kind, size}}} ->
        # `tobytes()` on the Python side produces a fresh bytes object, so
        # `data` is a copy of the NumPy result rather than a view into it.
        data
        |> Nx.from_binary({String.to_existing_atom(kind), size})
        |> Nx.reshape(shape)

      {"error", message} ->
        raise message
    end
  end

  @doc """
  Returns `{x, result}`, where `result` is `x * 20` computed in NumPy by
  `python_callback/1`, as a `{:f, 64}` tensor with the shape of `x`.
  """
  defn run(x) do
    # Output template for the runtime call: x's shape, but float64, which is
    # the type python_callback/1 always produces.
    out = %{Nx.to_template(x) | type: {:f, 64}}
    result = Nx.runtime_call(out, x, &python_callback/1)
    {x, result}
  end
end
# Run the Python bridge once on the s32 tensor [0, 1, 2, 3].
# `cache: false` turns off EXLA's compilation cache for this call.
# NOTE(review): the reason is not stated; presumably to avoid reusing a stale
# compiled executable while iterating on the notebook — confirm.
input = Nx.iota({4}, type: :s32)

result = EXLA.jit_apply(&NxPythonBridgeComputation.run/1, [input], cache: false)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.