@RomeoV
Created September 5, 2025 09:25
#=
The goal is to write a C-callable function =matrixsum_cc= that computes
the sum of a square matrix. The matrix is passed via a raw pointer and a
size parameter, but may have one of several structures. In our example, it can
be a dense matrix or a diagonal matrix, which we specify by passing a
C-style enum value, stored as an integer.
In the end, we want to write a function
```julia
@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)
    return sum(m)
end
```
that is trimmable, i.e., one that we can pass to =juliac --trim= to get a standalone C library or executable.
The problem is that this function, as written, is not "type grounded", i.e.,
not every variable's type can be determined from the argument types alone.
In particular, =m= is either a =Matrix= or a =Diagonal=, depending on the *value* of =mattype= rather than its type.
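To make the problem concrete, here is a minimal sketch of the naive approach using only Base and LinearAlgebra.
The function =build_matrix_naive= and the integer codes =1=/=2= are hypothetical stand-ins for =build_matrix= and =MatType=.
Asking inference for the return type yields a =Union= rather than a single concrete type.
```julia
import LinearAlgebra: Diagonal

# Hypothetical naive builder: the concrete return type depends on the *value* of `kind`.
function build_matrix_naive(sz::Integer, ptr::Ptr{Cdouble}, kind::Integer)
    if kind == 1
        return unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
    else
        return Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
    end
end

# Inference only sees argument *types*, so it can only conclude "one of the two".
rt = only(Base.return_types(build_matrix_naive, (Cint, Ptr{Cdouble}, Cint)))
println(rt)  # a Union of Matrix{Float64} and Diagonal{Float64, Vector{Float64}}
```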
We therefore need a way to go from the s-expression-style formulation, where the type information is stored in a runtime value,
to a formulation that is fully "type grounded" and compatible with multiple dispatch.
We can do this by leveraging "closed" sum types, as shown below.
We first need a datatype for our =build_matrix= function to return.
Since the result can be either a dense or a diagonal matrix, we define a "closed" sum type,
i.e., a sum type whose set of variants is fixed up front.
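For intuition, a closed sum type can be approximated by hand in plain Julia: a concrete wrapper whose single field holds a small, fixed =Union= of concrete types.
This is a hypothetical sketch (=ManualMatrix= and =manual_sum= are made-up names), not how Moshi lays out its types internally; Moshi additionally generates the variant constructors and the =@match= syntax for us.
The key property is the same: the wrapper type is concrete, and each branch of =manual_sum= works on a concrete matrix type.
```julia
import LinearAlgebra: Diagonal

# Hypothetical hand-rolled closed sum type: the set of variants is fixed by the Union.
struct ManualMatrix{T<:Number}
    data::Union{Matrix{T}, Diagonal{T, Vector{T}}}
end

# "Match" on the active variant; `isa` narrows `d` to a concrete type in each branch.
function manual_sum(m::ManualMatrix{T})::T where {T}
    d = m.data
    if d isa Matrix{T}
        return sum(d)
    else
        return sum(d)  # here d is known to be a Diagonal{T, Vector{T}}
    end
end

dense_m = ManualMatrix{Float64}([1.0 2.0; 3.0 4.0])
diag_m = ManualMatrix{Float64}(Diagonal([1.0, 2.0]))
manual_sum(dense_m), manual_sum(diag_m)  # (10.0, 3.0)
```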
=#
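# Install the dependencies into a temporary environment so the default environment stays untouched.
import Pkg
Pkg.activate(; temp=true)
Pkg.add(["Moshi", "CEnum", "JET"])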
import LinearAlgebra: Diagonal
import Moshi.Data: @data
@data MyMatrix{T<:Number} begin
    DenseMat(Matrix{T})
    DiagMat(Diagonal{T, Vector{T}}) # <- Make sure this is a concrete type. `Diagonal{T}` wouldn't be enough.
end
#=
For convenience, we also define a CEnum that lets us pass in the type of the matrix from C.
We wrap it in a module to keep its names (=dense=, =diag=) in their own namespace.
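As a point of comparison, Base's own =@enum= also accepts an explicit integer base type.
The sketch below uses a hypothetical Base-only analogue (module =MatTypeBase=, type =Kind=) to show that such an enum occupies exactly one =Cint=, which is what lets C pass it as a plain integer.
Note that a Base enum throws an =ArgumentError= when constructed from an integer that is not one of its listed values.
```julia
module MatTypeBase
# Hypothetical Base-only analogue of the CEnum definition, with a Cint base type.
@enum Kind::Cint begin
    dense = 1
    diag = 2
end
end

sizeof(MatTypeBase.Kind) == sizeof(Cint)          # true: same size as a C int
Integer(MatTypeBase.diag)                          # 2, stored as a Cint
MatTypeBase.Kind(Cint(1)) === MatTypeBase.dense    # true
# MatTypeBase.Kind(Cint(3)) throws an ArgumentError: 3 is not a listed value.
```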
=#
module MatType
using CEnum
@cenum Enum::Cint begin
    dense = 1
    diag = 2
end
end
#=
Now we are ready to write our =build_matrix= function. It takes the =MatType= enum value (an integer),
builds the corresponding Julia matrix, and returns it wrapped in the sum type so that the return type is concrete, i.e., type stable.
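Two properties of =unsafe_wrap= matter for =build_matrix=: it does not copy, so the Julia array aliases the caller's buffer, and a =Matrix= is read in column-major order.
The sketch below illustrates both with a small Julia-owned buffer =buf= standing in for memory owned by C (the values are made up); =GC.@preserve= keeps the buffer alive while its raw pointer is in use.
```julia
import LinearAlgebra: Diagonal

buf = [1.0, 2.0, 3.0, 4.0]  # column-major storage of the 2x2 matrix [1 3; 2 4]

result = GC.@preserve buf begin
    ptr = pointer(buf)
    A = unsafe_wrap(Matrix{Float64}, ptr, (2, 2))        # no copy: A aliases buf
    D = Diagonal(unsafe_wrap(Vector{Float64}, ptr, 2))   # reads only buf[1] and buf[2]
    A[2, 1] = 20.0                                      # writes through to buf[2]
    (A[1, 2], sum(A), sum(D))
end

result  # (3.0, 28.0, 21.0)
buf     # [1.0, 20.0, 3.0, 4.0]
```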
=#
function build_matrix(sz::Integer, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::MyMatrix.Type{Cdouble}
    if mattype == MatType.dense
        # Unsafe pointer logic to create the Matrix
        m_dense = unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
        return MyMatrix.DenseMat(m_dense) # Wrap it in the sum type variant
    elseif mattype == MatType.diag
        m_diag = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
        return MyMatrix.DiagMat(m_diag) # Wrap it
    else
        error("Unmatched MatType")
    end
end
#=
Now comes the magic. For each variant of our sum type, we can call our kernel (here the =sum= function)
with a concrete argument type!
Crucially, to maintain type stability, our kernel needs to return the same type for each variant (here =T=, e.g. =Float64=).
Instead of =sum=, this could also be something like =linsolve= or any other computation.
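The same-return-type requirement can be illustrated without Moshi.
In this hypothetical sketch (=MatVariant=, =matvec_variant= and =scale_variant= are made-up names), a matrix-vector product returns a =Vector{Float64}= for both variants, whereas scaling by a number returns a =Matrix= for one variant and a =Diagonal= for the other, so its inferred result is no longer a single concrete type.
```julia
import LinearAlgebra: Diagonal

const MatVariant = Union{Matrix{Float64}, Diagonal{Float64, Vector{Float64}}}

# Same result type for every variant: both products return a Vector{Float64}.
function matvec_variant(m::MatVariant, v::Vector{Float64})
    if m isa Matrix{Float64}
        return m * v
    else
        return m * v
    end
end

# Different result type per variant: Matrix * scalar is a Matrix, Diagonal * scalar a Diagonal.
function scale_variant(m::MatVariant)
    if m isa Matrix{Float64}
        return m * 2.0
    else
        return m * 2.0
    end
end

typeof(matvec_variant(Diagonal([1.0, 2.0]), [1.0, 1.0]))  # Vector{Float64}
typeof(scale_variant(Diagonal([1.0, 2.0])))               # Diagonal{Float64, Vector{Float64}}
Base.return_types(scale_variant, (MatVariant,))           # not a single concrete type
```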
=#
import Base: sum
import Moshi.Match: @match
function sum(m::MyMatrix.Type{T})::T where {T}
    @match m begin
        MyMatrix.DenseMat(mat) => sum(mat)
        MyMatrix.DiagMat(mat) => sum(mat)
    end
end
#=
Finally, we can write our C interface, which takes the matrix size, a raw pointer, and the =MatType= enum value (an integer)
and returns the result in a type-stable way.
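Before compiling with =juliac=, the C calling convention can be exercised from within Julia: =@cfunction= produces a C function pointer and =ccall= invokes it the way C code would.
The sketch below uses a hypothetical stand-in, =matrixsum_demo=, that takes the matrix kind as a plain =Cint= (1 = dense, 2 = diagonal) so that it needs no packages; it is not the =matrixsum_cc= from this script.
```julia
import LinearAlgebra: Diagonal

# Hypothetical stand-in for matrixsum_cc: `kind` is 1 (dense) or 2 (diagonal).
function matrixsum_demo(sz::Cint, ptr::Ptr{Cdouble}, kind::Cint)::Cdouble
    if kind == 1
        return sum(unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz)))
    elseif kind == 2
        return sum(Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz)))
    else
        return NaN  # unknown kind; a real interface might return an error code instead
    end
end

# A raw C function pointer with the C signature double (*)(int, double *, int).
fptr = @cfunction(matrixsum_demo, Cdouble, (Cint, Ptr{Cdouble}, Cint))

data = [1.0, 2.0, 3.0, 4.0]
dense_sum = GC.@preserve data ccall(fptr, Cdouble, (Cint, Ptr{Cdouble}, Cint), 2, pointer(data), 1)  # 10.0
diag_sum = GC.@preserve data ccall(fptr, Cdouble, (Cint, Ptr{Cdouble}, Cint), 2, pointer(data), 2)   # 3.0
```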
=#
import Base: @ccallable
@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)
    return sum(m)
end
#=
To validate type stability, we can inspect the functions with =@code_warntype=, or test them directly with JET's =@test_opt=.
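A coarser check that needs only the =Test= standard library is =@inferred=: it throws if the inferred return type of a call differs from the type actually returned.
The sketch below applies it to two hypothetical functions, =grounded= and =ungrounded=. Unlike JET's =@test_opt=, =@inferred= only compares return types and does not flag dynamic dispatch inside the function body.
```julia
using Test
import LinearAlgebra: Diagonal

# Hypothetical grounded function: the result type follows from the argument type.
grounded(x::Vector{Float64}) = sum(x)

# Hypothetical ungrounded function: the result type depends on the *value* of `flag`.
ungrounded(flag::Bool) = flag ? zeros(2, 2) : Diagonal(zeros(2))

@inferred grounded([1.0, 2.0])   # passes and returns 3.0

# Inference only knows "Matrix or Diagonal" here, so @inferred throws.
inferred_ok = try
    @inferred ungrounded(true)
    true
catch
    false
end
```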
=#
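using JET
# @test_opt statically analyzes each call for runtime dispatch; it does not execute it.
densemat = rand(3, 3)
@test_opt matrixsum_cc(Cint(3), pointer(densemat), MatType.dense)
diagmat = Diagonal(rand(3))
@test_opt matrixsum_cc(Cint(3), pointer(diagmat.diag), MatType.diag)
# Runtime sanity check; GC.@preserve keeps the arrays alive while their raw pointers are in use.
GC.@preserve densemat @assert matrixsum_cc(Cint(3), pointer(densemat), MatType.dense) ≈ sum(densemat)
GC.@preserve diagmat @assert matrixsum_cc(Cint(3), pointer(diagmat.diag), MatType.diag) ≈ sum(diagmat)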