Skip to content

Instantly share code, notes, and snippets.

@cohnt
Created August 17, 2025 14:19
Show Gist options
  • Select an option

  • Save cohnt/8d8d927a3c0077498a1bc06136ff6a17 to your computer and use it in GitHub Desktop.

Select an option

Save cohnt/8d8d927a3c0077498a1bc06136ff6a17 to your computer and use it in GitHub Desktop.
Calculation of scalar curvature of an embedded submanifold with JAX
# Vibecoded with ChatGPT
import jax
import jax.numpy as jnp
def scalar_curvature_projector(f, x0, n, k, rtol=1e-8):
    """
    Compute the scalar curvature of M = {x : f(x) = 0} at x0,
    using the projector formula (basis-free Gauss equation).

    With J the k x n Jacobian of f, C = J J^T the Gram matrix of the normals,
    P = I - J^T C^{-1} J the orthogonal projector onto the tangent space and
    H_i the Hessian of the i-th component of f, the Gauss equation gives

        S = sum_ij (C^{-1})_ij [ tr(P H_i P) tr(P H_j P) - tr(P H_i P H_j P) ].

    The formula only uses derivatives of f at x0. If x0 is off {f = 0}, the
    result is the scalar curvature of the level set of f through x0.

    Parameters
    ----------
    f : callable
        Smooth function f: R^n -> R^k, written in JAX. Should return a
        length-k array; a scalar return value is also accepted when k == 1.
    x0 : array_like, shape (n,)
        Point near or on the manifold.
    n : int
        Input dimension of f.
    k : int
        Output dimension of f (codimension of M).
    rtol : float
        Relative tolerance for the rank test of J. Singular values not
        exceeding rtol * (largest singular value) count as zero. Values below
        the machine epsilon of the working dtype (about 1.2e-7 for float32,
        which JAX uses unless ``jax_enable_x64`` is enabled) cannot detect
        numerical rank deficiency.

    Returns
    -------
    S : jax.Array
        Scalar curvature at x0, as a 0-d array.

    Raises
    ------
    ValueError
        If x0 does not have shape (n,), if the Jacobian or Hessian of f has
        an unexpected shape, or if J does not have rank k at x0.

    Notes
    -----
    The shape and rank checks branch on concrete values, so this function
    cannot be wrapped in ``jax.jit``.
    """
    # `float` maps to JAX's default float dtype: float64 when jax_enable_x64 is
    # set, float32 otherwise. Requesting jnp.float64 explicitly yields the same
    # dtype but warns about truncation on every call when x64 is disabled.
    x0 = jnp.asarray(x0, dtype=float)
    if x0.shape != (n,):
        raise ValueError(f"x0 must have shape ({n},), got {x0.shape}")

    def g(x):
        # Promote a scalar output (allowed for k == 1) to a length-1 vector so
        # the Jacobian is (k, n) and the Hessian stack is (k, n, n).
        return jnp.atleast_1d(f(x))

    # Jacobian J (k x n); its rows are the normals of the level set.
    J = jax.jacfwd(g)(x0)
    if J.shape != (k, n):
        raise ValueError(f"Jacobian has shape {J.shape}, expected ({k},{n})")

    # Rank check via SVD. Singular values come back sorted in descending order.
    # Once this passes, k <= n, U is (k, k), and Vt is (k, n) with orthonormal
    # rows spanning the normal space.
    U, Svals, Vt = jnp.linalg.svd(J, full_matrices=False)
    rank = jnp.sum(Svals > rtol * Svals[0])
    if rank != k:
        raise ValueError(f"Expected rank {k}, got {int(rank)} at x0.")

    # Both quantities are taken from the SVD instead of inverting J J^T, which
    # would square J's condition number:
    #   inverse normal metric  C^{-1} = (J J^T)^{-1} = U diag(1/s^2) U^T
    #   tangent projector      P = I - J^T C^{-1} J = I - V V^T
    Cinv = (U / Svals**2) @ U.T  # (k, k)
    PT = jnp.eye(n, dtype=J.dtype) - Vt.T @ Vt  # (n, n)

    # Hessians H_i of each component of f, stacked as (k, n, n).
    H = jax.jacfwd(jax.jacrev(g))(x0)
    if H.shape != (k, n, n):
        raise ValueError(f"Hessians have shape {H.shape}, expected ({k},{n},{n})")

    # Tangential parts A_i = P H_i P (matmul broadcasts over the leading axis).
    # P is idempotent, so tr(P H_i P H_j P) = tr(A_i A_j).
    A = PT @ H @ PT  # (k, n, n)
    tr_i = jnp.trace(A, axis1=1, axis2=2)  # (k,)
    B = jnp.einsum("iab,jba->ij", A, A)  # (k, k): B_ij = tr(A_i A_j)

    # Gauss equation: S = |tr II|^2 - |II|^2, with normal indices raised by C^{-1}.
    return jnp.sum(Cinv * (jnp.outer(tr_i, tr_i) - B))
# Example 1: the round sphere of radius R in R^3, cut out by the single
# constraint f(x) = ||x||^2 - R^2 (codimension k = 1).
R = 2.0
def f(x):
    """Return the sphere constraint ||x||^2 - R^2 as a length-1 array."""
    return jnp.atleast_1d(jnp.dot(x, x) - R**2)
x0 = jnp.array([R, 0.0, 0.0])  # (R, 0, 0) lies on the sphere
S = scalar_curvature_projector(f, x0, k=1, n=3)
# A round d-sphere of radius R has S = d(d-1)/R^2; here d = 2, so S = 2/R^2.
print("Scalar curvature:", S)
print("Expected:", 2 / R**2, end="\n\n")
def f_affine(x):
    """Return the constraints x[0] = 1 and x[1] = 2, which define a flat 2-plane in R^4."""
    return jnp.array([x[i] - c for i, c in enumerate((1.0, 2.0))])
x0 = jnp.array([1.0, 2.0, 3.0, -1.0])  # satisfies x[0] = 1 and x[1] = 2
S = scalar_curvature_projector(f_affine, x0, k=2, n=4)
# A flat affine subspace has no curvature, so S should vanish.
print("Affine subspace curvature:", S)
print("Expected:", 0, end="\n\n")
def f_hyperboloid(x):
    """Return the constraint -x^2 + y^2 + z^2 - 1 (one-sheeted hyperboloid) as a length-1 array."""
    u, v, w = x[0], x[1], x[2]
    return jnp.array([-u**2 + v**2 + w**2 - 1.0])
def f_saddle(x):
    """Return the constraint z - (x^2 - y^2) (saddle graph) as a length-1 array; x = (x, y, z)."""
    px, py, pz = x[0], x[1], x[2]
    height = px**2 - py**2
    return jnp.array([pz - height])
# Saddle z = x^2 - y^2 at the origin: the principal curvatures are +2 and -2,
# so the Gaussian curvature is K = -4 and the scalar curvature is 2K = -8.
x0 = jnp.array([0.0, 0.0, 0.0])  # origin
S = scalar_curvature_projector(f_saddle, x0, n=3, k=1)
print("Saddle surface curvature at origin:", S)
print("Expected:", -8)
print()
# One-sheeted hyperboloid y^2 + z^2 - x^2 = 1 at the waist point (0, 1, 0):
# the Gaussian curvature there is K = -1, so the scalar curvature is 2K = -2.
x0 = jnp.array([0.0, 1.0, 0.0])
S = scalar_curvature_projector(f_hyperboloid, x0, n=3, k=1)
print("Hyperboloid curvature at (0, 1, 0):", S)
print("Expected:", -2)
print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment