Skip to content

Instantly share code, notes, and snippets.

@daviddavo
Created January 24, 2025 09:22
Show Gist options
  • Select an option

  • Save daviddavo/21943f4998b54b8681760175de578239 to your computer and use it in GitHub Desktop.

Select an option

Save daviddavo/21943f4998b54b8681760175de578239 to your computer and use it in GitHub Desktop.
Implementation of the alternative-basis matrix multiplication example from Math Stack Exchange question https://math.stackexchange.com/posts/5026786
# See https://math.stackexchange.com/posts/5026786
def submatrices(A):
    """Split a square matrix into a 2x2 grid of equally sized square blocks.

    With ``h = len(A) // 2`` the result satisfies
    ``blocks[i][j][k][l] == A[h*i + k][h*j + l]``, i.e. ``blocks[i][j]`` is
    the block in block-row ``i`` and block-column ``j``.

    Raises:
        ValueError: if ``len(A)`` is odd. Such a matrix cannot be halved
            evenly; the previous behaviour silently dropped the last row and
            column, which corrupted any product computed from the blocks.
    """
    n = len(A)
    if n % 2:
        raise ValueError(f"cannot split a {n}x{n} matrix into equal halves")
    h = n // 2
    return [
        [
            [[A[h * i + k][h * j + l] for l in range(h)] for k in range(h)]
            for j in range(2)
        ]
        for i in range(2)
    ]
def combine(A):
    """Reassemble a 2x2 grid of square blocks into one square matrix.

    This is the inverse of ``submatrices``. The block side length is read
    from ``A[0][0]``. Output row ``bi*size + r`` is block-row ``bi``'s row
    ``r``, formed by concatenating that row across both block-columns.
    """
    size = len(A[0][0])
    merged = []
    for block_row in range(2):
        for r in range(size):
            merged.append(
                [A[block_row][block_col][r][c] for block_col in range(2) for c in range(size)]
            )
    return merged
def madd(A, B):
    """Return the element-wise sum ``A + B`` as a new matrix.

    Both the row and column index run over ``range(len(A))``, so the inputs
    are expected to be square and of equal size.
    """
    idx = range(len(A))
    return [[A[r][c] + B[r][c] for c in idx] for r in idx]
def msub(A, B):
    """Return the element-wise difference ``A - B`` as a new matrix.

    Both the row and column index run over ``range(len(A))``, so the inputs
    are expected to be square and of equal size.
    """
    idx = range(len(A))
    return [[A[r][c] - B[r][c] for c in idx] for r in idx]
def phi(As):
    """Apply the input change of basis used by ``schwartz`` to a 2x2 block grid.

    Given blocks ``[[a00, a01], [a10, a11]]``, returns
    ``[[a00, a01], [a10, a11 + (a01 - a10)]]``. The three unchanged blocks
    are the very same objects as in ``As``; they are not copied.
    """
    top = [As[0][0], As[0][1]]
    # Only the bottom-right block is rewritten.
    corner = madd(As[1][1], msub(As[0][1], As[1][0]))
    bottom = [As[1][0], corner]
    return [top, bottom]
def vopt(Cs):
    """Post-process a 2x2 grid of output blocks produced by ``schwartz``.

    Given blocks ``[[c00, c01], [c10, c11]]``, returns
    ``[[c00, c01 - c11], [c11 - c10, c11]]``. ``c00`` and ``c11`` are
    reused as-is, not copied.
    """
    c11 = Cs[1][1]
    upper_right = msub(Cs[0][1], c11)
    lower_left = msub(c11, Cs[1][0])
    return [[Cs[0][0], upper_right], [lower_left, c11]]
def naive_mmul(A, B):
    """Reference O(m*p*q) matrix product of an m x p matrix A and a p x q matrix B.

    Matrices are lists of rows. The previous version used ``len(A)`` for
    every dimension, which was only right for square inputs. Rectangular
    inputs now work; square inputs give exactly the same result as before,
    with the same summation order.

    Raises:
        ValueError: if a row of A does not have ``len(B)`` entries.
    """
    inner = len(B)
    cols = len(B[0]) if B else 0
    for row in A:
        if len(row) != inner:
            raise ValueError(
                f"inner dimensions differ: row of length {len(row)} vs {inner} rows in B"
            )
    return [
        [sum(row[k] * B[k][j] for k in range(inner)) for j in range(cols)]
        for row in A
    ]
def _phi_rec(M):
    """Apply ``phi`` at every level of the recursive 2x2 block split of M."""
    if len(M) <= 1:
        return M
    blocks = submatrices(M)
    # phi is linear, so transforming each block first and then applying phi
    # at this level equals the full recursive basis change.
    inner = [[_phi_rec(blk) for blk in row] for row in blocks]
    return combine(phi(inner))


def _vopt_rec(M):
    """Apply ``vopt`` at every level of the recursive 2x2 block split of M."""
    if len(M) <= 1:
        return M
    blocks = submatrices(M)
    inner = [[_vopt_rec(blk) for blk in row] for row in blocks]
    return combine(vopt(inner))


def schwartz(A, B, transform=True):
    """Multiply square matrices with a 7-multiplication recursion in an alternative basis.

    Args:
        A, B: square matrices (lists of rows) of equal size. Each recursion
            level halves the side via ``submatrices``, so the side length
            must be a power of two.
        transform: if true (the default), both inputs are converted into the
            alternative basis at *every* recursion level (``phi``), the core
            recursion runs entirely in that basis, and the result is
            converted back at every level (``vopt``). The return value is
            then the ordinary product ``A @ B``. If false, only the core
            recursion runs on A and B as given. This is the internal entry
            point for operands that are already transformed; its result is
            not the ordinary product in general.

    Returns:
        A new list-of-lists matrix (``[]`` for empty input).
    """
    if not A:
        # Without this guard an empty matrix split into empty blocks forever.
        return []
    if transform:
        # The previous version applied phi/vopt only at the top level. The
        # recursive core calls (transform=False) treat their operands as
        # already transformed, so untransformed sub-blocks gave wrong products
        # for every size above 2x2; e.g. I4 @ I4 had an all-ones top-left block.
        core = schwartz(_phi_rec(A), _phi_rec(B), False)
        return _vopt_rec(core)

    if len(A) == 1:
        return [[A[0][0] * B[0][0]]]

    As = submatrices(A)
    Bs = submatrices(B)

    # Linear combinations of the input blocks.
    t1 = madd(As[1][0], As[1][1])
    t2 = msub(As[1][1], As[0][1])
    t3 = msub(As[1][1], As[0][0])
    t4 = msub(Bs[1][1], Bs[0][0])
    t5 = madd(Bs[1][0], Bs[1][1])
    t6 = msub(Bs[1][1], Bs[0][1])

    # Seven recursive block products, all within the alternative basis.
    M1 = schwartz(As[0][0], Bs[0][0], False)
    M2 = schwartz(As[0][1], Bs[1][0], False)
    M3 = schwartz(As[1][0], t4, False)
    M4 = schwartz(As[1][1], Bs[1][1], False)
    M5 = schwartz(t1, t5, False)
    M6 = schwartz(t2, t6, False)
    M7 = schwartz(t3, Bs[0][1], False)

    Cs = [
        [madd(M1, M2), msub(M5, M7)],
        [madd(M3, M6), msub(msub(madd(M5, M6), M2), M4)],
    ]
    return combine(Cs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment