Skip to content

Instantly share code, notes, and snippets.

@jorgensd
Last active June 9, 2024 14:51
Show Gist options
  • Select an option

  • Save jorgensd/addf4aab06c6820bb2d8138dbb26c7d1 to your computer and use it in GitHub Desktop.

Select an option

Save jorgensd/addf4aab06c6820bb2d8138dbb26c7d1 to your computer and use it in GitHub Desktop.
Local solver for DOLFINx
# Local solver using Scipy
# To be used with DOLFINx and a local-assembler
#
# Author: Jørgen S. Dokken
#
# License: MIT
from typing import List, Optional

import numba
import numpy
import numpy.typing
import scipy
import scipy.linalg
def plu_solve(P: numpy.typing.NDArray[numpy.float64],
              L: numpy.typing.NDArray[numpy.float64],
              U: numpy.typing.NDArray[numpy.float64],
              b_local: numpy.typing.NDArray[numpy.float64],
              x: numpy.typing.NDArray[numpy.float64]) -> numpy.typing.NDArray[numpy.float64]:
    """Solve a pivoted LU-factorized system using a forward and a backward solve.

    The factors are those returned by ``scipy.linalg.lu(A)``, i.e.
    ``A = P @ L @ U``. Because ``P`` is a permutation matrix, ``P^{-1} = P^T``,
    so ``A x = b_local`` is solved as ``L U x = P^T b_local``.

    Args:
        P: Permutation matrix from ``scipy.linalg.lu``.
        L: Lower-triangular factor.
        U: Upper-triangular factor.
        b_local: Right-hand side (a vector, or a 2D array whose columns are
            separate right-hand sides).
        x: Preallocated output array, overwritten with the solution. It is
            used as ``out`` for ``numpy.dot``, so it must have the shape of
            ``P.T @ b_local``, the result dtype, and be C-contiguous.

    Returns:
        ``x`` (the same object that was passed in), for convenience.
    """
    # Permute the right-hand side directly into the output buffer.
    numpy.dot(P.T, b_local, out=x)
    # ``overwrite_b=True`` only *permits* SciPy to reuse x's memory; it does
    # not guarantee an in-place result (e.g. a C-ordered 2D ``x`` is copied).
    # Always write the returned solution back into ``x``.
    x[...] = scipy.linalg.solve_triangular(L, x, lower=True, overwrite_b=True)
    x[...] = scipy.linalg.solve_triangular(U, x, lower=False, overwrite_b=True)
    return x
class LocalSolver():
    """Cache of per-cell LU factorizations used for repeated local solves.

    :meth:`factorize` computes ``scipy.linalg.lu`` of a cell's local matrix
    once and stores the factors; :meth:`solve` reuses them for any number of
    right-hand sides.
    """
    # Maps a cell index to its position in Ps/Ls/Us; -1 means the cell has
    # not been factorized yet.
    lookup_index: numpy.typing.NDArray[numpy.int32]
    # Caches for matrix factorization (A_local = P @ L @ U for each cell)
    Ps: List[numpy.typing.NDArray[numpy.float64]]
    Ls: List[numpy.typing.NDArray[numpy.float64]]
    Us: List[numpy.typing.NDArray[numpy.float64]]
    # Output buffer written and returned by ``solve``. Allocated lazily on the
    # first call and overwritten by every subsequent call.
    sol: Optional[numpy.typing.NDArray[numpy.float64]]

    def __init__(self, num_cells_local: int):
        """Create an empty cache with room for ``num_cells_local`` cells."""
        self.lookup_index = numpy.full(num_cells_local, -1, dtype=numpy.int32)
        self.Ps = []
        self.Ls = []
        self.Us = []
        self.sol = None

    def _check_cell(self, cell: int):
        """Assert that ``cell`` is a non-negative index below the cell count."""
        # Negative indices must be rejected explicitly: numpy would otherwise
        # wrap them around to cells at the end of ``lookup_index``.
        num_cells = len(self.lookup_index)
        assert 0 <= cell < num_cells, f"Illegal {cell=}, expected 0 <= cell < {num_cells}"

    def factorize(self, A_local: numpy.typing.NDArray, cell: int):
        """
        Factorize the local matrix ``A_local`` and store the lookup for ``cell``.

        If ``cell`` already has cached factors they are replaced in place;
        otherwise the new factors are appended to the cache.
        """
        self._check_cell(cell)
        P, L, U = scipy.linalg.lu(A_local)
        if (input_index := self.lookup_index[cell]) >= 0:
            print(f"Cache exists, overwriting {input_index} for {cell}")
            self.Ps[input_index] = P
            self.Ls[input_index] = L
            self.Us[input_index] = U
        else:
            # Insert new matrices since no LU-decomposition exists for this cell
            print(f"New {cell} inserting at {len(self.Ps)}")
            self.lookup_index[cell] = len(self.Ps)
            self.Ps.append(P)
            self.Ls.append(L)
            self.Us.append(U)

    def solve(self, b_local: numpy.typing.NDArray[numpy.float64], cell: int):
        """Solve the local system of ``cell`` for ``b_local`` using cached factors.

        Returns ``self.sol``, a buffer shared between calls: the next call to
        ``solve`` overwrites it, so copy the result if it must be kept. Every
        ``b_local`` must have the same shape as the first one passed in.
        """
        self._check_cell(cell)
        index = self.lookup_index[cell]
        assert index >= 0, f"Missing factorization for {cell=}. Have you called `factorize(A_local, {cell})`"
        if self.sol is None:
            self.sol = numpy.empty_like(b_local)
        else:
            assert self.sol.shape == numpy.shape(b_local), "Inconsistent size of solution vector and input vector"
        plu_solve(self.Ps[index], self.Ls[index], self.Us[index], b_local, self.sol)
        return self.sol
# --- Demonstration ----------------------------------------------------------
# Dense 6x6 system: interior rows carry the stencil (-1/h, 2/h, -1/h); the
# first and last rows are identity rows that fix x[0] and x[5] to b[0], b[5].
h = 0.2
A = (numpy.diag(numpy.full(6, 2/h))
     + numpy.diag(numpy.full(5, -1/h), k=1)
     + numpy.diag(numpy.full(5, -1/h), k=-1))
A[0, :] = 0
A[0, 0] = 1
A[-1, :] = 0
A[-1, -1] = 1

# Reference solve with explicitly computed factors.
P, L, U = scipy.linalg.lu(A)
b = numpy.zeros(6, dtype=numpy.float64)
b[0] = -3
b[-1] = 2
x = numpy.empty_like(b)
plu_solve(P, L, U, b, x)
print("Hardcoded sol", x)

# The same solve through the cache. Factorizing a cell a second time goes
# through the "overwrite existing cache entry" branch of `factorize`.
solver = LocalSolver(10)
solver.factorize(A, 0)
solver.factorize(A, 0)
solver.factorize(A, 1)
solver.factorize(A, 1)
print(solver.solve(b, 0))
print(solver.solve(b, 1))
# Cell 2 was never factorized, so calling `solver.solve(b, 2)` here would
# trip the "Missing factorization" assertion in `solve`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment