# Gist RomeoV/88e81df4855a17f26efa721ed8e94d81
# Created September 5, 2025 09:25
#=
The goal is to write a C-callable function matrixsum_cc that computes the sum of
a square matrix. The matrix is passed via a raw pointer and a size parameter, but
may have one of several structures. In our example, it can be a dense matrix or a
diagonal matrix, which we specify by passing a C-style enum value, stored as an
integer.

In the end, we want to write a function
```julia
@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)
    return sum(m)
end
```
that is trimmable, i.e., that we can pass to =juliac --trim= to get a static C library or executable.

The problem is that this function is, by default, not "type grounded": not all
variable types within the function can be inferred from the argument types alone.
In particular, m will be either a Matrix or a Diagonal, depending on the *value*
of mattype. We therefore need a way to go from the s-expression style formulation,
where the type information is stored in a runtime variable, to a formulation that
is fully "type grounded" and compatible with multiple dispatch.
We can do this by leveraging "closed" sum types, as shown below.

First, we need a return type for our build_matrix function.
Since the result can be either a dense or a diagonal matrix, we define a "closed" sum type,
i.e., a sum type that enumerates all possible variants up front.
=#
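#=
To see the problem concretely, here is a minimal sketch of a "naive" builder that does not use a sum type.
=build_naive= is hypothetical and not part of this gist. It uses a Julia vector instead of a raw pointer
and a plain =Int= tag instead of the enum. Asking inference for its return type, given only the argument
types, yields a =Union= of the two matrix types rather than one concrete type.

```julia
using LinearAlgebra: Diagonal

# Hypothetical "naive" builder (not part of the gist): the concrete return type
# depends on the *value* of `kind`, not on the argument types.
function build_naive(sz::Int, data::Vector{Float64}, kind::Int)
    if kind == 1
        return reshape(data, sz, sz)   # Matrix{Float64}
    else
        return Diagonal(data[1:sz])    # Diagonal{Float64, Vector{Float64}}
    end
end

# Ask inference for the return type, given only the argument *types*.
rt = only(Base.return_types(build_naive, (Int, Vector{Float64}, Int)))
println(rt)                  # a Union of the two matrix types
println(isconcretetype(rt))  # false: the result is not type grounded
```

Every caller of such a function inherits this non-concrete type, which is what the sum type below avoids.
=#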
import Pkg
Pkg.activate(; temp=true)
Pkg.add(["Moshi", "CEnum", "JET"])

import LinearAlgebra: Diagonal
import Moshi.Data: @data

@data MyMatrix{T<:Number} begin
    DenseMat(Matrix{T})
    DiagMat(Diagonal{T, Vector{T}})  # <- Make sure this is a concrete type. `Diagonal{T}` wouldn't be enough.
end
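#=
The comment on =DiagMat= matters because =Diagonal= has two type parameters: the element type and the
storage type. A quick check with Base's =isconcretetype= shows that only the fully parameterized form is
concrete.

```julia
using LinearAlgebra: Diagonal

# Diagonal{T, V} has two parameters; fixing only T leaves a UnionAll over V.
println(isconcretetype(Diagonal{Float64}))                   # false
println(isconcretetype(Diagonal{Float64, Vector{Float64}}))  # true

# Wrapping a Vector{Float64} produces exactly the concrete type.
d = Diagonal([1.0, 2.0, 3.0])
println(typeof(d) === Diagonal{Float64, Vector{Float64}})    # true
```

In a plain Julia struct, a field declared as =Diagonal{T}= would have an abstract type. Code that unwraps it
could then no longer infer a concrete matrix type.
=#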
#=
For convenience, we also define a CEnum that lets us pass the type of the matrix in from C.
We wrap it in a module just to keep it in its own namespace.
=#
module MatType
using CEnum
@cenum Enum::Cint begin
    dense = 1
    diag = 2
end
end
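#=
For comparison, Base's own =@enum= also accepts =Cint= as its base type, so its values are exchanged with C
as 32-bit integers. The sketch below defines a hypothetical =MatKind= enum (not part of the gist) to show
the integer round trip.

One behavioral difference is worth knowing. A Base enum throws an =ArgumentError= when converting an integer
that is not a declared member. CEnum is meant to mirror C enums more closely, and in C nothing stops an
enum-typed int from holding an undeclared value. A defensive =else= branch, like the one in =build_matrix=
below, therefore remains useful.

```julia
# Hypothetical Base-enum analogue of MatType.Enum, backed by a 32-bit Cint.
@enum MatKind::Cint begin
    dense_kind = 1
    diag_kind = 2
end

println(sizeof(MatKind))   # 4, the size of a C int
println(Cint(diag_kind))   # 2
println(MatKind(Cint(1)))  # dense_kind

# Base enums validate their input: an undeclared value throws an ArgumentError.
invalid = try
    MatKind(Cint(3))
    false
catch err
    err isa ArgumentError
end
println(invalid)           # true
```
=#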
#=
Now we are ready to write our build_matrix function. The function takes the MatType enum value
(an integer), builds the corresponding Julia matrix, and returns it wrapped in the sum type, so
that its return type is stable.
=#
function build_matrix(sz::Integer, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::MyMatrix.Type{Cdouble}
    if mattype == MatType.dense
        # Wrap the raw pointer as an sz-by-sz Matrix (no copy; memory is read column-major)
        m_dense = unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
        return MyMatrix.DenseMat(m_dense)  # Wrap it in the sum type variant
    elseif mattype == MatType.diag
        # Wrap the first sz values as the diagonal of a Diagonal matrix (no copy)
        m_diag = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
        return MyMatrix.DiagMat(m_diag)  # Wrap it in the sum type variant
    else
        error("Unmatched MatType")
    end
end
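#=
The two branches read the raw buffer differently, so the C caller has to agree on the memory layout.
This small sketch uses a Julia-owned buffer as a stand-in for memory handed over from C:

- For the dense case, =unsafe_wrap= reads =sz*sz= doubles in column-major order.
- For the diagonal case, only the first =sz= doubles are read.

=GC.@preserve= keeps the buffer alive while its raw pointer is in use. When the memory comes from C, the C
side must likewise keep it valid for the duration of the call.

```julia
using LinearAlgebra: Diagonal

buf = [1.0, 2.0, 3.0, 4.0]   # stands in for memory handed over from C

dense, diag_sum = GC.@preserve buf begin
    p = pointer(buf)
    # Dense case: sz*sz values, interpreted in column-major order, no copy.
    m = unsafe_wrap(Matrix{Float64}, p, (2, 2))
    # Diagonal case: only the first sz values are used as the diagonal.
    d = Diagonal(unsafe_wrap(Vector{Float64}, p, 2))
    (copy(m), sum(d))        # copy so the result does not alias `buf`
end

println(dense)     # [1.0 3.0; 2.0 4.0]
println(diag_sum)  # 3.0
```
=#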
#=
Now comes the magic. For each variant of our sum type, we can call our kernel (here the =sum= function)
with a concrete argument type!
Crucially, to maintain type stability, our kernel needs to return the same type for each variant
(here =T=, e.g. =Float64=).
Instead of =sum=, this could also be something like =linsolve= or any other kernel.
=#
import Base: sum
import Moshi.Match: @match

function sum(m::MyMatrix.Type{T})::T where {T}
    @match m begin
        MyMatrix.DenseMat(mat) => sum(mat)  # mat::Matrix{T}
        MyMatrix.DiagMat(mat) => sum(mat)   # mat::Diagonal{T, Vector{T}}
    end
end
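#=
To make the mechanism visible without Moshi, here is a hand-rolled sketch of the same idea. It is *not* how
Moshi implements =@data=. Instead, it relies on type inference narrowing a small =Union= field through =isa=
checks. The hypothetical wrapper =ClosedMat= is a concrete struct whose single field is a =Union= of exactly
the two concrete matrix types. Each branch therefore calls the kernel on a concrete type. Because both
branches return =T=, the inferred return type is concrete even without a return-type annotation.

```julia
using LinearAlgebra: Diagonal

# Hypothetical closed sum type: the field type lists every allowed variant explicitly.
struct ClosedMat{T<:Number}
    data::Union{Matrix{T}, Diagonal{T, Vector{T}}}
end

# The "match": each branch sees a concrete type and returns a T.
function matsum(m::ClosedMat{T}) where {T}
    x = m.data
    if x isa Matrix{T}
        return sum(x)
    else
        return sum(x)   # here x is known to be Diagonal{T, Vector{T}}
    end
end

rt = only(Base.return_types(matsum, (ClosedMat{Float64},)))
println(rt)  # Float64

println(matsum(ClosedMat{Float64}([1.0 2.0; 3.0 4.0])))    # 10.0
println(matsum(ClosedMat{Float64}(Diagonal([1.0, 2.0]))))  # 3.0
```
=#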
#=
Finally, we can write our C interface: we pass in a raw pointer and the MatType enum value (an integer)
and get our result in a type-stable way.
=#
import Base: @ccallable

@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)  # m::MyMatrix.Type{Cdouble}, regardless of mattype
    return sum(m)
end
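#=
Before compiling with =juliac=, the C calling convention can be exercised from within Julia. The sketch
below uses a hypothetical, dependency-free stand-in called =matrixsum_demo=. It differs from the real
function in two ways: a plain =Cint= tag replaces =MatType.Enum=, and unknown tags return =NaN= instead of
throwing. The sketch calls it through a raw C function pointer obtained with =@cfunction=, i.e., through the
same signature a C caller would use. This only checks the argument and return plumbing; it does not perform
any trimming.

```julia
using LinearAlgebra: Diagonal

# Hypothetical stand-in for matrixsum_cc: a Cint tag replaces MatType.Enum,
# and unknown tags return NaN instead of throwing.
function matrixsum_demo(sz::Cint, ptr::Ptr{Cdouble}, kind::Cint)::Cdouble
    n = Int(sz)
    if kind == 1
        return sum(unsafe_wrap(Matrix{Cdouble}, ptr, (n, n)))       # dense, column-major
    elseif kind == 2
        return sum(Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, n)))  # diagonal entries only
    else
        return NaN
    end
end

# A raw C function pointer with the C signature double (*)(int, double *, int).
fptr = @cfunction(matrixsum_demo, Cdouble, (Cint, Ptr{Cdouble}, Cint))

buf = [1.0, 2.0, 3.0, 4.0]
dense_sum = GC.@preserve buf ccall(fptr, Cdouble, (Cint, Ptr{Cdouble}, Cint), 2, pointer(buf), 1)
diag_sum  = GC.@preserve buf ccall(fptr, Cdouble, (Cint, Ptr{Cdouble}, Cint), 2, pointer(buf), 2)
bad_sum   = GC.@preserve buf ccall(fptr, Cdouble, (Cint, Ptr{Cdouble}, Cint), 2, pointer(buf), 7)
```

With the buffer above, the dense call sums all four values and the diagonal call sums only the first two.
=#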
#=
To validate type stability, we can inspect the function with @code_warntype, or test it directly with JET.
=#
using JET

# @test_opt analyzes the call statically; it does not execute matrixsum_cc.
densemat = rand(3, 3)
@test_opt matrixsum_cc(Cint(3), pointer(densemat), MatType.dense)

diagmat = Diagonal(rand(3))
@test_opt matrixsum_cc(Cint(3), pointer(diagmat.diag), MatType.diag)
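#=
If JET is not at hand, the =Test= standard library's =@inferred= offers a lighter-weight check. It throws
when the concrete type of a call's result differs from the type inferred from the argument types alone.
It only inspects the *return* type, not intermediate variables, so it complements rather than replaces the
JET analysis. The hypothetical functions below (not part of the gist) illustrate this: the builder fails the
check, while a function that merely sums the builder's output passes it.

```julia
using Test: @inferred
using LinearAlgebra: Diagonal

# Hypothetical examples. Not grounded: the result type depends on the value of `kind`.
build_by_value(kind::Int, v::Vector{Float64}) = kind == 1 ? reshape(v, 2, 2) : Diagonal(v[1:2])
# Both variants sum to a Float64, so the return type is concrete.
sum_by_value(kind::Int, v::Vector{Float64}) = sum(build_by_value(kind, v))

v = [1.0, 2.0, 3.0, 4.0]

s = @inferred sum_by_value(1, v)    # passes, returns 10.0

build_failed = try
    @inferred build_by_value(1, v)  # inferred Union does not match Matrix{Float64}
    false
catch
    true
end
println(build_failed)  # true
```
=#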