Last active
March 10, 2025 17:28
-
-
Save finsberg/b7db6d8df23d6b9ee8f43481f767b303 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from dataclasses import dataclass | |
| import shutil | |
| from mpi4py import MPI | |
| from petsc4py import PETSc | |
| import dolfinx.fem.petsc | |
| from dolfinx_external_operator import ( | |
| FEMExternalOperator, | |
| evaluate_external_operators, | |
| evaluate_operands, | |
| replace_external_operators, | |
| ) | |
| import numpy.typing as npt | |
| import numpy as np | |
| import basix.ufl | |
| import ufl | |
def v_exact_func(x, t):
    """Return the manufactured membrane potential as a UFL expression.

    The expression is ``cos(2*pi*x[0]) * cos(2*pi*x[1]) * sin(t)``.

    Args:
        x: Indexable spatial coordinate; components 0 and 1 are used.
        t: Time, anything accepted by ``ufl.sin``.
    """
    cos_x = ufl.cos(2 * ufl.pi * x[0])
    cos_y = ufl.cos(2 * ufl.pi * x[1])
    temporal = ufl.sin(t)
    return cos_x * cos_y * temporal
def s_exact_func(x, t):
    """Return the manufactured ODE state ``s`` as a UFL expression.

    The expression is ``-cos(2*pi*x[0]) * cos(2*pi*x[1]) * cos(t)``.

    Args:
        x: Indexable spatial coordinate; components 0 and 1 are used.
        t: Time, anything accepted by ``ufl.cos``.
    """
    # The sign is attached to the first factor, as in the one-line formula.
    neg_cos_x = -ufl.cos(2 * ufl.pi * x[0])
    cos_y = ufl.cos(2 * ufl.pi * x[1])
    temporal = ufl.cos(t)
    return neg_cos_x * cos_y * temporal
def ode_exact_func(x, t):
    """Return the exact ODE state vector ``(v, s)`` as a UFL vector expression.

    Args:
        x: Spatial coordinate forwarded to ``v_exact_func`` and ``s_exact_func``.
        t: Time forwarded to the same two functions.
    """
    components = (v_exact_func(x, t), s_exact_func(x, t))
    return ufl.as_vector(components)
def ac_func(x, t):
    """Return the source term used on the right-hand side of the PDE.

    The expression is
    ``8 * pi**2 * cos(2*pi*x[0]) * cos(2*pi*x[1]) * sin(t)``.

    Args:
        x: Indexable spatial coordinate; components 0 and 1 are used.
        t: Time, anything accepted by ``ufl.sin``.
    """
    amplitude = 8 * ufl.pi**2
    cos_x = ufl.cos(2 * ufl.pi * x[0])
    cos_y = ufl.cos(2 * ufl.pi * x[1])
    return amplitude * cos_x * cos_y * ufl.sin(t)
@dataclass
class ODE:
    """Callable that advances the ODE states by one forward-Euler step.

    Instances are returned by the external-operator callback so that the
    time, the step size and the parameters travel with the callable.
    """

    # Time object; its ``.value`` is read on every call.
    time: dolfinx.fem.Constant
    # Step size handed to the ODE stepper.
    dt: float
    # Forwarded unchanged to the ODE stepper.
    parameters: npt.NDArray

    def __call__(self, v: npt.NDArray, ode_states: npt.NDArray):
        """Advance ``(v, s)`` one step and return the new states flattened.

        Args:
            v: Values of the potential; flattened before use.
            ode_states: Two-dimensional array of previous ODE states. Every
                second column starting at column 1 is taken as ``s``
                (assumes each row interleaves ``v`` and ``s`` per point;
                TODO confirm against the external-operator layout).

        Returns:
            One-dimensional array alternating new ``v`` and new ``s`` values.
        """
        previous_s = ode_states[:, 1::2]
        stacked = np.vstack((v.flatten(), previous_s.flatten()))
        advanced = simple_ode_forward_euler(
            stacked,
            self.time.value,
            self.dt,
            parameters=self.parameters,
        )
        # Column-major flattening of the (2, n) result interleaves v and s,
        # exactly like transposing first and flattening row by row.
        return advanced.flatten(order="F")
def simple_ode_forward_euler(states, t, dt, parameters):
    """Take one forward-Euler step of the system ``v' = -s``, ``s' = v``.

    Args:
        states: Array whose first axis has exactly two entries, ``v`` then ``s``.
        t: Current time; not used by this model.
        dt: Step size.
        parameters: Not used by this model.

    Returns:
        A new array with the shape and dtype of ``states`` holding the updated
        ``v`` in row 0 and the updated ``s`` in row 1.
    """
    potential, recovery = states
    advanced = np.zeros_like(states)
    # Both rows are computed from the old values, so the order is irrelevant.
    advanced[1] = recovery + potential * dt
    advanced[0] = potential - recovery * dt
    return advanced
def VTXWriter(mesh, path, functions, engine="BP5", write: bool = False):
    """Return a VTX writer, or a no-op stand-in when output is disabled.

    Args:
        mesh: Either a mesh (anything exposing a ``comm`` attribute) or the
            MPI communicator itself. Both are accepted because the callers
            in this file pass ``mesh.comm`` directly.
        path: Output path handed to ``dolfinx.io.VTXWriter``.
        functions: Functions handed to ``dolfinx.io.VTXWriter``.
        engine: Engine name handed to ``dolfinx.io.VTXWriter``.
        write: When false, no file is created and a dummy object with no-op
            ``write`` and ``close`` methods is returned.

    Returns:
        A ``dolfinx.io.VTXWriter`` when ``write`` is true, otherwise the dummy.
    """
    if write:
        # Fall back to the argument itself when it has no ``comm`` attribute,
        # i.e. when a communicator was passed instead of a mesh.
        comm = getattr(mesh, "comm", mesh)
        return dolfinx.io.VTXWriter(comm, path, functions, engine=engine)

    class DummyVTXWriter:
        """Drop-in replacement that ignores every call."""

        def write(self, *args, **kwargs):
            pass

        def close(self):
            pass

    return DummyVTXWriter()
def splitting_scheme(N=50, M=1.0, dt=0.01, T=1.0, quad_degree=4, save_vtx=False):
    """Solve the coupled PDE/ODE test problem with an operator splitting.

    Every time step first advances the ODE states with forward Euler at the
    quadrature points and then solves an implicit diffusion step for the
    potential on a P1 space, starting from the ODE-updated potential.

    Args:
        N: Number of cells per direction of the unit-square triangle mesh.
        M: Factor multiplying the diffusion term.
        dt: Time step; must be positive.
        T: End time.
        quad_degree: Degree of the quadrature element and of the integration
            measure.
        save_vtx: When true, ``v_exact`` and ``v_pde`` are written to
            ``splitting_scheme.bp`` after every step.

    Returns:
        The L2 norm of ``v_pde - v_exact`` at the final time.
    """
    comm = MPI.COMM_WORLD
    mesh = dolfinx.mesh.create_unit_square(comm, N, N, dolfinx.cpp.mesh.CellType.triangle)
    time = dolfinx.fem.Constant(mesh, dolfinx.default_scalar_type(0.0))
    x = ufl.SpatialCoordinate(mesh)
    I_s = ac_func(x, time)

    # Scalar quadrature space holding the ODE states at the integration points
    el = basix.ufl.quadrature_element(
        scheme="default", degree=quad_degree, cell=mesh.ufl_cell().cellname()
    )
    V_ode = dolfinx.fem.functionspace(mesh, el)
    v_ode = dolfinx.fem.Function(V_ode)
    s = dolfinx.fem.Function(V_ode)
    s.interpolate(
        dolfinx.fem.Expression(s_exact_func(x, time), V_ode.element.interpolation_points())
    )

    # The exact potential contains a factor sin(t) and therefore vanishes at
    # t = 0, so the freshly created (zero) v_ode is already the initial value.
    states = np.zeros((2, s.x.array.size))
    states[1, :] = s.x.array
    states[0, :] = v_ode.x.array

    V_pde = dolfinx.fem.functionspace(mesh, ("P", 1))
    v_pde = dolfinx.fem.Function(V_pde, name="v_pde")
    C_m = 1.0
    dx = ufl.Measure("dx", domain=mesh, metadata={"quadrature_degree": quad_degree})

    # Define variational formulation
    v = ufl.TrialFunction(V_pde)
    w = ufl.TestFunction(V_pde)
    a = C_m * v * w * dx + dt * ufl.inner(M * ufl.grad(v), ufl.grad(w)) * dx
    L = C_m * v_ode * w * dx + dt * I_s * w * dx

    # The matrix does not change between steps, so it is assembled once
    solver = dolfinx.fem.petsc.LinearProblem(a, L, u=v_pde)
    dolfinx.fem.petsc.assemble_matrix(solver.A, solver.a)
    solver.A.assemble()

    v_expr = dolfinx.fem.Expression(v_pde, V_ode.element.interpolation_points())
    v_exact_expr = dolfinx.fem.Expression(
        v_exact_func(x, time), V_pde.element.interpolation_points()
    )
    v_exact = dolfinx.fem.Function(V_pde, name="v_exact")

    path = "splitting_scheme.bp"
    shutil.rmtree(path, ignore_errors=True)
    # The VTXWriter wrapper in this file takes the mesh and reads ``mesh.comm``
    vtx = VTXWriter(mesh, path, [v_exact, v_pde], engine="BP5", write=save_vtx)

    # Count the steps up front instead of testing ``time.value < T``: the
    # accumulated clock can end a few ulps short of T (ten steps of 0.1 sum
    # to 0.9999999999999999), which would trigger one step too many.
    num_steps = max(int(np.ceil(T / dt - 1e-9)), 0)
    for _ in range(num_steps):
        # ODE step
        states[:] = simple_ode_forward_euler(states, time.value, dt, parameters=None)
        v_ode.x.array[:] = states[0, :]

        # PDE step: reassemble the right-hand side and solve
        with solver.b.localForm() as b_loc:
            b_loc.set(0)
        dolfinx.fem.petsc.assemble_vector(solver.b, solver.L)
        solver.b.ghostUpdate(
            addv=PETSc.InsertMode.ADD,
            mode=PETSc.ScatterMode.REVERSE,
        )
        solver.solver.solve(solver.b, v_pde.x.petsc_vec)
        v_pde.x.scatter_forward()

        # Make sure to update previous states
        v_ode.interpolate(v_expr)
        states[0, :] = v_ode.x.array

        time.value += dt
        v_exact.interpolate(v_exact_expr)
        vtx.write(time.value)

    error = dolfinx.fem.form((v_pde - v_exact) ** 2 * dx)
    E = np.sqrt(comm.allreduce(dolfinx.fem.assemble_scalar(error), MPI.SUM))
    vtx.close()
    return E
def external_operator(N=50, M=1.0, dt=0.01, T=1.0, quad_degree=4, save_vtx=False):
    """Solve the coupled PDE/ODE test problem using an external operator.

    The ODE update is wrapped in a ``FEMExternalOperator`` that depends on the
    PDE solution and on the ODE states of the previous step. Its first
    component enters the right-hand side of the implicit diffusion step.

    Args:
        N: Number of cells per direction of the unit-square triangle mesh.
        M: Factor multiplying the diffusion term.
        dt: Time step; must be positive.
        T: End time.
        quad_degree: Degree of the quadrature element and of the integration
            measure.
        save_vtx: When true, ``v_exact`` and ``v_pde`` are written to
            ``external_operator.bp`` after every step.

    Returns:
        The L2 norm of ``v_pde - v_exact`` at the final time.
    """
    comm = MPI.COMM_WORLD
    mesh = dolfinx.mesh.create_unit_square(comm, N, N, dolfinx.cpp.mesh.CellType.triangle)
    time = dolfinx.fem.Constant(mesh, dolfinx.default_scalar_type(0.0))
    x = ufl.SpatialCoordinate(mesh)
    I_s = ac_func(x, time)

    # We create a function space for all of the ODE states (including V)
    el = basix.ufl.quadrature_element(
        scheme="default",
        degree=quad_degree,
        cell=mesh.ufl_cell().cellname(),
        value_shape=(2,),
    )
    V_ode = dolfinx.fem.functionspace(mesh, el)

    # We keep track of the states from a previous time step
    ode_states_old = dolfinx.fem.Function(V_ode)
    # And interpolate the initial condition
    ode_states_old.interpolate(
        dolfinx.fem.Expression(ode_exact_func(x, time), V_ode.element.interpolation_points())
    )

    # Here we return a function that will be called by the external operator
    def f_external(derivatives: tuple[int, ...]):
        if derivatives == (0, 0):
            # Here we should return a callable, but we would also like to be
            # able to save state. Therefore we create an instance of a class
            # that has a __call__ method. For now let this object hold
            # information about the time, dt and parameters.
            return ODE(time, dt, parameters=None)
        # Derivatives of the operator are not available. Raise instead of
        # returning the exception class, which callers would try to call.
        raise NotImplementedError(
            f"External operator derivatives {derivatives} are not implemented"
        )

    # Function space for the PDE
    V_pde = dolfinx.fem.functionspace(mesh, ("P", 1))
    v_pde = dolfinx.fem.Function(V_pde, name="v_pde")

    # We create an external operator. This will be a function in the function
    # space defined by the `function_space` argument. In our case this will be
    # the ODE states of the current time step, which depend on the PDE
    # solution and the previous ODE states.
    ode_states = FEMExternalOperator(
        v_pde,
        ode_states_old,
        function_space=V_ode,
        external_function=f_external,
    )
    # Here we extract the first subfunction from the ode_states, which will be
    # the membrane potential
    v_ode = ufl.split(ode_states)[0]

    C_m = 1.0
    dx = ufl.Measure("dx", domain=mesh, metadata={"quadrature_degree": quad_degree})

    # Define variational formulation
    v = ufl.TrialFunction(V_pde)
    w = ufl.TestFunction(V_pde)
    a = C_m * v * w * dx + dt * ufl.inner(M * ufl.grad(v), ufl.grad(w)) * dx
    L = C_m * v_ode * w * dx + dt * I_s * w * dx

    # Update right hand side with external operators
    L_updated, operators = replace_external_operators(L)
    L_compiled = dolfinx.fem.form(L_updated)

    # Create solver and assemble the (time-independent) matrix once
    solver = dolfinx.fem.petsc.LinearProblem(a, L_compiled, u=v_pde)
    dolfinx.fem.petsc.assemble_matrix(solver.A, solver.a)
    solver.A.assemble()

    # Create expression for the exact solution so that we can compare the
    # solutions
    v_exact_expr = dolfinx.fem.Expression(
        v_exact_func(x, time), V_pde.element.interpolation_points()
    )
    v_exact = dolfinx.fem.Function(V_pde, name="v_exact")

    path = "external_operator.bp"
    shutil.rmtree(path, ignore_errors=True)
    # Let us save both the exact solution and the computed solution. The
    # VTXWriter wrapper in this file takes the mesh and reads ``mesh.comm``.
    vtx = VTXWriter(mesh, path, [v_exact, v_pde], engine="BP5", write=save_vtx)

    # Count the steps up front instead of testing ``time.value < T``: the
    # accumulated clock can end a few ulps short of T (ten steps of 0.1 sum
    # to 0.9999999999999999), which would trigger one step too many.
    num_steps = max(int(np.ceil(T / dt - 1e-9)), 0)
    for _ in range(num_steps):
        # Evaluate the operands of the external operators
        coefficients = evaluate_operands(operators)
        # Evaluate the external operators with those operand values
        evaluate_external_operators(operators, coefficients)

        # Reassemble the right-hand side and solve the PDE step
        with solver.b.localForm() as b_loc:
            b_loc.set(0)
        dolfinx.fem.petsc.assemble_vector(solver.b, solver.L)
        solver.b.ghostUpdate(
            addv=PETSc.InsertMode.ADD,
            mode=PETSc.ScatterMode.REVERSE,
        )
        solver.solver.solve(solver.b, v_pde.x.petsc_vec)
        v_pde.x.scatter_forward()

        # Make sure to update the old states
        ode_states_old.interpolate(ode_states.ref_coefficient)

        time.value += dt
        # Refresh the exact solution before writing so that the stored
        # v_exact belongs to the same time as v_pde.
        v_exact.interpolate(v_exact_expr)
        vtx.write(time.value)

    error = dolfinx.fem.form((v_pde - v_exact) ** 2 * dx)
    E = np.sqrt(comm.allreduce(dolfinx.fem.assemble_scalar(error), MPI.SUM))
    vtx.close()
    return E
def _report_convergence(title, scheme, name, values, **fixed):
    """Run ``scheme`` once per value of one parameter and print the observed order.

    Args:
        title: Heading printed before the study.
        scheme: Callable returning an error; called with ``name=value`` plus
            the ``fixed`` keyword arguments.
        name: Keyword name of the parameter that is varied.
        values: Values of that parameter, each half the resolution step of the
            previous one (the order is computed with ``log2``).
        **fixed: Keyword arguments passed unchanged to every run.

    Returns:
        The list of errors, one per value.
    """
    print(f"\n{title}")
    errors = []
    for value in values:
        err = scheme(**{name: value}, **fixed)
        print(f"{name}={value}, error={err}")
        errors.append(err)
        if len(errors) > 1:
            order = np.log2(errors[-2] / errors[-1])
            print(f"Order of convergence: {order}")
    return errors


def main():
    """Print spatial and temporal convergence studies for both schemes."""
    mesh_sizes = [4, 8, 16, 32, 64]
    time_steps = [0.1 / 2**i for i in range(4)]

    _report_convergence(
        "Splitting scheme (spatial)", splitting_scheme, "N", mesh_sizes, T=1.0, dt=0.0005
    )
    _report_convergence(
        "External operator (spatial)", external_operator, "N", mesh_sizes, T=1.0, dt=0.0005
    )
    # The temporal studies report the time step that was actually used.
    _report_convergence("Splitting scheme (temporal)", splitting_scheme, "dt", time_steps)
    _report_convergence("External operator (temporal)", external_operator, "dt", time_steps)
# Run the convergence studies only when the file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment