Created
August 17, 2025 14:19
-
-
Save cohnt/8d8d927a3c0077498a1bc06136ff6a17 to your computer and use it in GitHub Desktop.
Calculation of scalar curvature of an embedded submanifold with JAX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Vibecoded with ChatGPT | |
| import jax | |
| import jax.numpy as jnp | |
def scalar_curvature_projector(f, x0, n, k, rtol=1e-8):
    """
    Compute the scalar curvature of M = {x : f(x) = 0} at x0.

    Uses the basis-free Gauss equation for a submanifold of flat R^n.
    With J = Df(x0) (k x n), normal metric C = J J^T, tangent projector
    P = I - J^T C^{-1} J and component Hessians H_i, let G_i = P H_i P. Then

        S = sum_{i,j} (C^{-1})_{ij} * (tr(G_i) tr(G_j) - tr(G_i G_j)),

    i.e. |mean curvature vector|^2 - |second fundamental form|^2.

    Parameters
    ----------
    f : callable
        Smooth function f: R^n -> R^k, written in JAX. Must return a length-k array.
    x0 : array_like, shape (n,)
        Point near or on the manifold. Only the first and second derivatives
        of f at x0 are used and f(x0) is never checked, so for a point off
        M the result is the scalar curvature of the level set
        {x : f(x) = f(x0)} passing through x0.
    n : int
        Input dimension of f.
    k : int
        Output dimension of f (codimension of M). For k == 0 the manifold is
        an open subset of flat R^n and 0 is returned.
    rtol : float
        Relative tolerance for the rank test of J: singular values not
        exceeding rtol * (largest singular value) are treated as zero.

    Returns
    -------
    S : jax.Array
        0-d array holding the scalar curvature at x0.

    Raises
    ------
    ValueError
        If x0, the Jacobian, or the Hessian stack has an unexpected shape, or
        if J does not have full row rank k at x0.

    Notes
    -----
    x0 is requested as float64. Under JAX's default configuration
    (``jax_enable_x64`` off) this is truncated to float32. All arithmetic
    then runs in float32, where an rtol below float32 epsilon cannot
    separate rank deficiency from rounding.

    The shape and rank checks branch on concrete values, so this function
    cannot be wrapped in ``jax.jit`` as written.
    """
    x0 = jnp.asarray(x0, dtype=jnp.float64)
    if x0.shape != (n,):
        raise ValueError(f"x0 must have shape ({n},), got {x0.shape}")

    # Jacobian J (k x n); its rows are the gradients of the constraints.
    J = jax.jacfwd(f)(x0)
    if J.shape != (k, n):
        raise ValueError(f"Jacobian has shape {J.shape}, expected ({k},{n})")

    # No constraints: M is an open subset of flat R^n. This is handled up
    # front because the rank test below needs at least one singular value.
    if k == 0:
        return jnp.zeros((), dtype=x0.dtype)

    # Rank check via singular values only (U, V are not needed).
    Svals = jnp.linalg.svd(J, compute_uv=False)
    rank = int(jnp.sum(Svals > rtol * jnp.max(Svals)))
    if rank != k:
        raise ValueError(f"Expected rank {k}, got {rank} at x0.")

    # Normal metric C = J J^T (k x k, invertible once J has full row rank);
    # its inverse raises the normal indices.
    C = J @ J.T
    Cinv = jnp.linalg.inv(C)

    # Tangent projector P = I - J^T C^{-1} J (n x n, symmetric, idempotent).
    PT = jnp.eye(n, dtype=x0.dtype) - J.T @ Cinv @ J

    # Hessians H_i of every component, stacked as (k, n, n).
    H = jax.jacfwd(jax.jacrev(f))(x0)
    if H.shape != (k, n, n):
        raise ValueError(f"Hessians have shape {H.shape}, expected ({k},{n},{n})")

    # Tangential parts G_i = P H_i P for all i at once: (k, n, n).
    # matmul broadcasts the (n, n) projector over the leading k axis.
    G = PT @ H @ PT
    # tr_i = tr(P H_i P)
    tr_i = jnp.trace(G, axis1=1, axis2=2)
    # B_ij = tr(G_i G_j) = tr(P H_i P H_j P), using P @ P == P.
    B = jnp.einsum("iab,jba->ij", G, G)

    # Gauss equation: contract the normal indices with C^{-1}.
    return jnp.sum(Cinv * (jnp.outer(tr_i, tr_i) - B))
# Example manifolds. Each is the zero set of a JAX function R^n -> R^k.
R = 2.0


def f(x):
    """Sphere of radius R in R^3: f(x) = ||x||^2 - R^2, k=1."""
    return jnp.array([jnp.dot(x, x) - R**2])


def f_affine(x):
    """Affine 2-plane {x[0] = 1, x[1] = 2} in R^4, k=2."""
    return jnp.array([x[0] - 1.0, x[1] - 2.0])


def f_hyperboloid(x):
    """One-sheeted hyperboloid -x^2 + y^2 + z^2 = 1 in R^3, k=1."""
    return jnp.array([-x[0]**2 + x[1]**2 + x[2]**2 - 1.0])


def f_saddle(x):
    """Saddle surface z = x^2 - y^2 in R^3, k=1."""
    # x = (x, y, z)
    return jnp.array([x[2] - (x[0]**2 - x[1]**2)])


def _run_examples():
    """Print computed and expected scalar curvature for each example manifold."""
    # For a surface in R^3, S = 2K with K the Gauss curvature:
    #   sphere: d(d-1)/R^2 with d=2 => 2/R^2
    #   hyperboloid waist point (0, 1, 0): K = -1 => S = -2
    #   saddle at origin: K = -4 => S = -8
    examples = [
        ("Scalar curvature:", f, jnp.array([R, 0.0, 0.0]), 3, 1, 2 / R**2),
        ("Affine subspace curvature:", f_affine, jnp.array([1.0, 2.0, 3.0, -1.0]), 4, 2, 0),
        ("Hyperboloid curvature at (0, 1, 0):", f_hyperboloid, jnp.array([0.0, 1.0, 0.0]), 3, 1, -2),
        ("Saddle surface curvature at origin:", f_saddle, jnp.array([0.0, 0.0, 0.0]), 3, 1, -8),
    ]
    for label, fn, x0, n, k, expected in examples:
        S = scalar_curvature_projector(fn, x0, n=n, k=k)
        print(label, S)
        print("Expected:", expected)
        print()


# Run the examples only when executed as a script, not on import.
if __name__ == "__main__":
    _run_examples()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment