Skip to content

Instantly share code, notes, and snippets.

@finsberg
Last active March 10, 2025 17:28
Show Gist options
  • Select an option

  • Save finsberg/b7db6d8df23d6b9ee8f43481f767b303 to your computer and use it in GitHub Desktop.

Select an option

Save finsberg/b7db6d8df23d6b9ee8f43481f767b303 to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
import shutil
from mpi4py import MPI
from petsc4py import PETSc
import dolfinx.fem.petsc
from dolfinx_external_operator import (
FEMExternalOperator,
evaluate_external_operators,
evaluate_operands,
replace_external_operators,
)
import numpy.typing as npt
import numpy as np
import basix.ufl
import ufl
def v_exact_func(x, t):
    """Manufactured potential v(x, t) = cos(2*pi*x0) * cos(2*pi*x1) * sin(t).

    ``x`` is indexed as ``x[0]``, ``x[1]`` (a UFL spatial coordinate in this
    file) and ``t`` is the time (a dolfinx Constant in this file). The result
    is a UFL expression.
    """
    wave = 2 * ufl.pi
    spatial = ufl.cos(wave * x[0]) * ufl.cos(wave * x[1])
    return spatial * ufl.sin(t)
def s_exact_func(x, t):
    """Manufactured ODE state s(x, t) = -cos(2*pi*x0) * cos(2*pi*x1) * cos(t).

    Together with ``v_exact_func`` this satisfies v' = -s and s' = v pointwise.
    The result is a UFL expression.
    """
    wave = 2 * ufl.pi
    cos_x = ufl.cos(wave * x[0])
    cos_y = ufl.cos(wave * x[1])
    return -cos_x * cos_y * ufl.cos(t)
def ode_exact_func(x, t):
    """Return the exact ODE state pair ``(v, s)`` as a length-2 UFL vector."""
    components = (v_exact_func(x, t), s_exact_func(x, t))
    return ufl.as_vector(components)
def ac_func(x, t):
    """Applied current I_s = 8*pi**2 * cos(2*pi*x0) * cos(2*pi*x1) * sin(t).

    For v = cos(2*pi*x0) cos(2*pi*x1) sin(t) one has -div(grad v) = 8*pi**2 v,
    so with a unit diffusion coefficient this source cancels the diffusion term.
    The result is a UFL expression.
    """
    amplitude = 8 * ufl.pi**2
    wave = 2 * ufl.pi
    return amplitude * ufl.cos(wave * x[0]) * ufl.cos(wave * x[1]) * ufl.sin(t)
@dataclass
class ODE:
    """Callable that advances the ODE states by one step.

    An instance is returned by the external operator's ``external_function``
    and is then called with the evaluated operands.
    """

    # Constant whose ``value`` is forwarded as the current time
    time: dolfinx.fem.Constant
    # Time-step size
    dt: float
    # Forwarded to the ODE stepper (``None`` is passed in this file)
    parameters: npt.NDArray

    def __call__(self, v: npt.NDArray, ode_states: npt.NDArray):
        """Advance ``(v, s)`` one step and return the new states interleaved.

        ``v`` is the first operand (the PDE potential in this file). Only the
        ``s`` entries (every second column starting at 1) are taken from
        ``ode_states``; the old potential is replaced by ``v``. The result is a
        flat array laid out as ``[v0, s0, v1, s1, ...]``.
        """
        # assumes v and ode_states[:, 1::2] hold the same number of entries - TODO confirm
        previous_s = ode_states[:, 1::2]
        stacked = np.stack((np.ravel(v), np.ravel(previous_s)))
        advanced = simple_ode_forward_euler(
            stacked, self.time.value, self.dt, parameters=self.parameters
        )
        # Column-major flattening of a (2, n) array interleaves its two rows.
        return advanced.flatten(order="F")
def simple_ode_forward_euler(states, t, dt, parameters):
    """Take one explicit Euler step of the linear system v' = -s, s' = v.

    Parameters
    ----------
    states
        Two rows ``(v, s)``.
    t, parameters
        Not used by this right-hand side.
    dt
        Step size.

    Returns
    -------
    numpy.ndarray
        A new array shaped like ``states`` holding the updated ``(v, s)``.
        Both rows are computed from the old values; ``states`` is not modified.
    """
    v_old, s_old = states
    stepped = np.zeros_like(states)
    stepped[1] = s_old + dt * v_old
    stepped[0] = v_old - dt * s_old
    return stepped
class _NoOpVTXWriter:
    """Stand-in for ``dolfinx.io.VTXWriter`` that discards all output."""

    def write(self, *args, **kwargs):
        pass

    def close(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()


def VTXWriter(mesh, path, functions, engine="BP5", write: bool = False):
    """Create a VTX (ADIOS2) writer, or a no-op writer when ``write`` is False.

    Parameters
    ----------
    mesh
        A dolfinx mesh or an MPI communicator. If the object has a ``comm``
        attribute, that communicator is used. Otherwise the object itself is
        taken as the communicator. The callers in this file historically pass
        ``mesh.comm``, which previously crashed on ``comm.comm``.
    path
        Output path of the ADIOS2 file.
    functions
        Functions to write; forwarded to ``dolfinx.io.VTXWriter``.
    engine
        ADIOS2 engine name.
    write
        When False, nothing is written and ``dolfinx.io`` is not used.

    Returns
    -------
    An object with ``write(t)`` and ``close()`` methods that can also be used
    as a context manager.
    """
    if not write:
        return _NoOpVTXWriter()
    # Accept either a mesh (use its communicator) or a communicator directly.
    comm = getattr(mesh, "comm", mesh)
    return dolfinx.io.VTXWriter(comm, path, functions, engine=engine)
def splitting_scheme(N=50, M=1.0, dt=0.01, T=1.0, quad_degree=4, save_vtx=False):
    """Solve the manufactured test problem with an operator-splitting scheme.

    Each time step does three things:

    1. Advance the ODE states ``(v, s)``, stored at quadrature points, with one
       forward-Euler step of :func:`simple_ode_forward_euler`.
    2. Solve the backward-Euler PDE step
       ``C_m (v - v_ode) / dt - div(M grad v) = I_s`` for ``v`` in P1.
    3. Interpolate the PDE solution back to the quadrature points, so the next
       ODE step starts from it.

    Both the ODE and the PDE step of step ``n`` see ``time = n * dt``.

    Parameters
    ----------
    N : int
        Cells per direction of the unit-square triangle mesh.
    M : float
        Diffusion coefficient.
    dt : float
        Time-step size.
    T : float
        Final time. ``ceil(T / dt)`` steps are taken (none if ``T <= 0``).
    quad_degree : int
        Degree of the quadrature element and of the integration measure.
    save_vtx : bool
        If True, write ``v_exact`` and ``v_pde`` to ``splitting_scheme.bp``
        after every step.

    Returns
    -------
    float
        L2 norm of ``v_pde - v_exact`` at the final time, summed over all ranks.
    """
    comm = MPI.COMM_WORLD
    mesh = dolfinx.mesh.create_unit_square(comm, N, N, dolfinx.mesh.CellType.triangle)
    time = dolfinx.fem.Constant(mesh, dolfinx.default_scalar_type(0.0))
    x = ufl.SpatialCoordinate(mesh)
    I_s = ac_func(x, time)

    # Scalar quadrature space; v and s each get one Function in it.
    el = basix.ufl.quadrature_element(
        scheme="default", degree=quad_degree, cell=mesh.ufl_cell().cellname()
    )
    V_ode = dolfinx.fem.functionspace(mesh, el)
    v_ode = dolfinx.fem.Function(V_ode)
    s = dolfinx.fem.Function(V_ode)
    s.interpolate(
        dolfinx.fem.Expression(s_exact_func(x, time), V_ode.element.interpolation_points())
    )
    # v_exact(x, 0) contains sin(0) = 0, so v_ode keeps its zero initial value.
    states = np.zeros((2, s.x.array.size))
    states[0, :] = v_ode.x.array
    states[1, :] = s.x.array

    V_pde = dolfinx.fem.functionspace(mesh, ("P", 1))
    v_pde = dolfinx.fem.Function(V_pde, name="v_pde")

    C_m = 1.0
    dx = ufl.Measure("dx", domain=mesh, metadata={"quadrature_degree": quad_degree})
    v = ufl.TrialFunction(V_pde)
    w = ufl.TestFunction(V_pde)
    # Backward-Euler PDE step: C_m (v - v_ode) = dt * (div(M grad v) + I_s)
    a = C_m * v * w * dx + dt * ufl.inner(M * ufl.grad(v), ufl.grad(w)) * dx
    L = C_m * v_ode * w * dx + dt * I_s * w * dx

    solver = dolfinx.fem.petsc.LinearProblem(a, L, u=v_pde)
    # The bilinear form is time independent, so the matrix is assembled once.
    dolfinx.fem.petsc.assemble_matrix(solver.A, solver.a)
    solver.A.assemble()

    # Evaluates the P1 solution at the quadrature points of the ODE space.
    v_expr = dolfinx.fem.Expression(v_pde, V_ode.element.interpolation_points())
    v_exact_expr = dolfinx.fem.Expression(
        v_exact_func(x, time), V_pde.element.interpolation_points()
    )
    v_exact = dolfinx.fem.Function(V_pde, name="v_exact")

    path = "splitting_scheme.bp"
    if save_vtx:
        # Only clear old output when it is about to be overwritten. Do it on a
        # single rank, then synchronise before the (collective) writer opens it.
        if comm.rank == 0:
            shutil.rmtree(path, ignore_errors=True)
        comm.barrier()
    vtx = VTXWriter(mesh, path, [v_exact, v_pde], engine="BP5", write=save_vtx)

    # Count steps rather than testing ``time < T`` on an accumulated float.
    # For example, ten additions of 0.1 give 0.9999999999999999 < 1.0, which
    # would take an extra step past T.
    num_steps = max(int(np.ceil(T / dt - 1e-10)), 0)
    try:
        for step in range(num_steps):
            # ODE step on the quadrature-point states
            states[:] = simple_ode_forward_euler(states, time.value, dt, parameters=None)
            v_ode.x.array[:] = states[0, :]

            # PDE step: reassemble only the right-hand side
            with solver.b.localForm() as b_loc:
                b_loc.set(0)
            dolfinx.fem.petsc.assemble_vector(solver.b, solver.L)
            solver.b.ghostUpdate(
                addv=PETSc.InsertMode.ADD,
                mode=PETSc.ScatterMode.REVERSE,
            )
            solver.solver.solve(solver.b, v_pde.x.petsc_vec)
            v_pde.x.scatter_forward()

            # The next ODE step starts from the PDE solution.
            v_ode.interpolate(v_expr)
            states[0, :] = v_ode.x.array

            time.value = (step + 1) * dt
            if save_vtx:
                v_exact.interpolate(v_exact_expr)
                vtx.write(float(time.value))
    finally:
        vtx.close()

    # Also covers the zero-step case, where the loop never interpolated v_exact.
    v_exact.interpolate(v_exact_expr)
    error = dolfinx.fem.form((v_pde - v_exact) ** 2 * dx)
    return np.sqrt(comm.allreduce(dolfinx.fem.assemble_scalar(error), MPI.SUM))
def external_operator(N=50, M=1.0, dt=0.01, T=1.0, quad_degree=4, save_vtx=False):
    """Solve the manufactured test problem with the ODE step as an external operator.

    All ODE states ``(v, s)`` live in a vector-valued quadrature space. A
    ``FEMExternalOperator`` depending on ``(v_pde, ode_states_old)`` produces
    the new states by calling an :class:`ODE` instance. Its first component
    then enters the right-hand side of the backward-Euler PDE step
    ``C_m (v - v_ode) / dt - div(M grad v) = I_s``, solved for ``v`` in P1.

    Both the operator evaluation and the PDE solve of step ``n`` see
    ``time = n * dt``.

    Parameters
    ----------
    N : int
        Cells per direction of the unit-square triangle mesh.
    M : float
        Diffusion coefficient.
    dt : float
        Time-step size.
    T : float
        Final time. ``ceil(T / dt)`` steps are taken (none if ``T <= 0``).
    quad_degree : int
        Degree of the quadrature element and of the integration measure.
    save_vtx : bool
        If True, write ``v_exact`` and ``v_pde`` to ``external_operator.bp``
        after every step.

    Returns
    -------
    float
        L2 norm of ``v_pde - v_exact`` at the final time, summed over all ranks.
    """
    comm = MPI.COMM_WORLD
    mesh = dolfinx.mesh.create_unit_square(comm, N, N, dolfinx.mesh.CellType.triangle)
    time = dolfinx.fem.Constant(mesh, dolfinx.default_scalar_type(0.0))
    x = ufl.SpatialCoordinate(mesh)
    I_s = ac_func(x, time)

    # One vector-valued quadrature space holds all ODE states, including v.
    el = basix.ufl.quadrature_element(
        scheme="default",
        degree=quad_degree,
        cell=mesh.ufl_cell().cellname(),
        value_shape=(2,),
    )
    V_ode = dolfinx.fem.functionspace(mesh, el)
    # ODE states of the previous time step, initialised with the exact (v, s) at t = 0
    ode_states_old = dolfinx.fem.Function(V_ode)
    ode_states_old.interpolate(
        dolfinx.fem.Expression(ode_exact_func(x, time), V_ode.element.interpolation_points())
    )

    def f_external(derivatives: tuple[int, ...]):
        """Return the callable implementing the operator for ``derivatives``.

        Only the operator itself (``derivatives == (0, 0)``) is available. An
        :class:`ODE` instance is returned instead of a plain function, so it
        can hold the time, dt and parameters.

        Raises
        ------
        NotImplementedError
            For any derivative with respect to the operands
            (``v_pde`` first, ``ode_states_old`` second).
        """
        if derivatives == (0, 0):
            return ODE(time, dt, parameters=None)
        # Raise instead of returning the exception class. A returned class
        # would later be *called* and silently produce an exception instance.
        raise NotImplementedError(
            f"Derivative {derivatives} of the ODE external operator is not implemented"
        )

    # Function space for the PDE
    V_pde = dolfinx.fem.functionspace(mesh, ("P", 1))
    v_pde = dolfinx.fem.Function(V_pde, name="v_pde")

    # The external operator lives in V_ode. It represents the ODE states at the
    # current time step, which depend on the PDE solution and the previous
    # ODE states.
    ode_states = FEMExternalOperator(
        v_pde,
        ode_states_old,
        function_space=V_ode,
        external_function=f_external,
    )
    # The first component of the ODE states is the membrane potential.
    v_ode = ufl.split(ode_states)[0]

    C_m = 1.0
    dx = ufl.Measure("dx", domain=mesh, metadata={"quadrature_degree": quad_degree})
    v = ufl.TrialFunction(V_pde)
    w = ufl.TestFunction(V_pde)
    # Backward-Euler PDE step: C_m (v - v_ode) = dt * (div(M grad v) + I_s)
    a = C_m * v * w * dx + dt * ufl.inner(M * ufl.grad(v), ufl.grad(w)) * dx
    L = C_m * v_ode * w * dx + dt * I_s * w * dx
    # Swap the external operator in the right-hand side for its backing coefficient.
    L_updated, operators = replace_external_operators(L)
    L_compiled = dolfinx.fem.form(L_updated)

    solver = dolfinx.fem.petsc.LinearProblem(a, L_compiled, u=v_pde)
    # The bilinear form is time independent, so the matrix is assembled once.
    dolfinx.fem.petsc.assemble_matrix(solver.A, solver.a)
    solver.A.assemble()

    # Exact solution, used for comparison with the computed one
    v_exact_expr = dolfinx.fem.Expression(
        v_exact_func(x, time), V_pde.element.interpolation_points()
    )
    v_exact = dolfinx.fem.Function(V_pde, name="v_exact")

    path = "external_operator.bp"
    if save_vtx:
        # Only clear old output when it is about to be overwritten. Do it on a
        # single rank, then synchronise before the (collective) writer opens it.
        if comm.rank == 0:
            shutil.rmtree(path, ignore_errors=True)
        comm.barrier()
    vtx = VTXWriter(mesh, path, [v_exact, v_pde], engine="BP5", write=save_vtx)

    # Count steps rather than testing ``time < T`` on an accumulated float.
    # For example, ten additions of 0.1 give 0.9999999999999999 < 1.0, which
    # would take an extra step past T.
    num_steps = max(int(np.ceil(T / dt - 1e-10)), 0)
    try:
        for step in range(num_steps):
            # Evaluate the operands, then the external operator (the ODE step).
            coefficients = evaluate_operands(operators)
            evaluate_external_operators(operators, coefficients)

            # PDE step: reassemble only the right-hand side
            with solver.b.localForm() as b_loc:
                b_loc.set(0)
            dolfinx.fem.petsc.assemble_vector(solver.b, solver.L)
            solver.b.ghostUpdate(
                addv=PETSc.InsertMode.ADD,
                mode=PETSc.ScatterMode.REVERSE,
            )
            solver.solver.solve(solver.b, v_pde.x.petsc_vec)
            v_pde.x.scatter_forward()

            # The states just computed become the "old" states of the next step.
            ode_states_old.interpolate(ode_states.ref_coefficient)

            time.value = (step + 1) * dt
            if save_vtx:
                # Interpolate before writing so the output is at the same time level.
                v_exact.interpolate(v_exact_expr)
                vtx.write(float(time.value))
    finally:
        vtx.close()

    # Also covers the zero-step case, where the loop never interpolated v_exact.
    v_exact.interpolate(v_exact_expr)
    error = dolfinx.fem.form((v_pde - v_exact) ** 2 * dx)
    return np.sqrt(comm.allreduce(dolfinx.fem.assemble_scalar(error), MPI.SUM))
def _convergence_study(title, run, parameter, values, **fixed):
    """Run a refinement sequence and print errors and observed convergence orders.

    Calls ``run(**{parameter: value}, **fixed)`` for each entry of ``values``
    and prints ``"<parameter>=<value>, error=<err>"``. From the second entry
    on, it also prints the observed order ``log2(e_prev / e_curr)``. This
    assumes each value halves the mesh size or time step of the previous one.

    Returns the list of errors.
    """
    print(title)
    errors = []
    for value in values:
        err = run(**{parameter: value}, **fixed)
        print(f"{parameter}={value}, error={err}")
        errors.append(err)
        if len(errors) > 1:
            order = np.log2(errors[-2] / errors[-1])
            print(f"Order of convergence: {order}")
    return errors


def main():
    """Run spatial and temporal convergence studies for both schemes.

    The spatial studies use a small fixed time step so the spatial error
    dominates. The temporal studies use the functions' default mesh size.
    """
    resolutions = [4, 8, 16, 32, 64]
    time_steps = [0.1 / 2**i for i in range(4)]

    _convergence_study(
        "\nSplitting scheme (spatial)", splitting_scheme, "N", resolutions, T=1.0, dt=0.0005
    )
    _convergence_study(
        "\nExternal operator (spatial)", external_operator, "N", resolutions, T=1.0, dt=0.0005
    )
    # Previously these printed the stale spatial N instead of the current dt.
    _convergence_study("\nSplitting scheme (temporal)", splitting_scheme, "dt", time_steps)
    _convergence_study("\nExternal operator (temporal)", external_operator, "dt", time_steps)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment