Created
April 6, 2025 20:29
-
-
Save futurisold/a98e43f189e83c724b4a3ab9b02c23c9 to your computer and use it in GitHub Desktop.
(e.g.) projecting tokens onto sphere
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import colorsys | |
| import logging | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from rich.console import Console | |
| from rich.text import Text | |
| from scipy.interpolate import griddata | |
| from scipy.ndimage import gaussian_filter | |
| from symai import Symbol | |
# Print arrays with four digits of precision and with scientific notation suppressed.
jnp.set_printoptions(suppress=True, precision=4)
# Only let messages of level WARNING and above through the "jax" logger.
logging.getLogger("jax").setLevel(logging.WARNING)
def get_color(prob):
    """Map a score in [0, 1] to a hex colour between red (0.0) and green (1.0).

    The hue is picked in HSV space, where 0 is red and 1/3 (120°) is green;
    saturation and value are both fixed at 0.9.

    Args:
        prob: A number, or anything ``float()`` accepts (e.g. a 0-d array).
            Values outside [0, 1] are clamped to the nearest end of the range.

    Returns:
        The colour as a ``"#rrggbb"`` string.
    """
    # Clamp before scaling: without it a value above 1 drifts toward blue and a
    # negative value wraps around to magenta, i.e. off the red-to-green ramp.
    level = min(max(float(prob), 0.0), 1.0)
    hue = level / 3  # 0..1 -> 0..1/3 (red -> green in HSV)
    r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.9)
    return f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}"
def viz_token_entropy(logprobs, flowing_text=False):
    """
    Print the tokens to the console, coloured by the entropy of their top-k alternatives.

    Entropies are min-max normalised over the whole sequence and inverted, so the
    lowest-entropy token is rendered green and the highest-entropy one red. When
    all entropies are equal, every token gets the green end of the scale.

    Args:
        logprobs: Re-iterable collection (it is traversed twice) of token records
            exposing ``token`` and ``top_logprobs``.
        flowing_text: If true, print a single continuous text whose tokens are
            coloured; otherwise print one line per token, on a coloured
            background, followed by its entropy value.
    """
    console = Console()
    console.print("\n[bold]Token Entropy Visualization:[/bold]")

    # All entropies are needed up front to know the normalisation range.
    entropies = [compute_entropy(entry.top_logprobs) for entry in logprobs]
    highest = max(entropies) if entropies else 1.0
    lowest = min(entropies) if entropies else 0.0
    spread = highest - lowest

    def certainty_color(entropy):
        # Inverted min-max normalisation: low entropy -> 1.0 (green), high -> 0.0 (red).
        if spread > 0:
            return get_color(1.0 - ((entropy - lowest) / spread))
        return get_color(1.0)

    if not flowing_text:
        # One line per token, with the raw entropy appended in a dim style.
        for entry, entropy in zip(logprobs, entropies):
            line = Text()
            line.append(f"{entry.token}", style=f"black on {certainty_color(entropy)}")
            line.append(f" (entropy: {entropy:.3f})", style="dim")
            console.print(line)
        return

    # Flowing mode: accumulate every token into one Text and print it once.
    flow = Text()
    for entry, entropy in zip(logprobs, entropies):
        flow.append(entry.token, style=str(certainty_color(entropy)))
    console.print(flow)
def compute_probs(logprobs):
    """Return ``exp(logprob)`` for every record in ``logprobs`` as a 1-D JAX array.

    Each record must expose a ``logprob`` attribute; the output keeps the input order.
    """
    log_values = jnp.array([entry.logprob for entry in logprobs])
    return jnp.exp(log_values)
def compute_entropy(top_logprobs):
    """Return the entropy ``-sum(p * log p)`` (in nats) over a list of alternatives.

    The probabilities are used exactly as given, i.e. they are not renormalised
    over the alternatives that are present.

    Args:
        top_logprobs: Iterable of records exposing a ``logprob`` attribute.

    Returns:
        A 0-d JAX array holding the entropy; zero for an empty input.
    """
    log_p = jnp.array([alt.logprob for alt in top_logprobs])
    p = jnp.exp(log_p)
    # Multiply by the supplied log-probabilities instead of recomputing log(exp(.)):
    # when a log-probability is -inf, or so negative that exp() underflows to 0,
    # log(p) is -inf and 0 * -inf would turn the whole sum into NaN. Such
    # alternatives contribute nothing to the entropy, so they are masked to 0.
    terms = jnp.where(p > 0, p * log_p, 0.0)
    return -jnp.sum(terms)
@jax.jit
def compute_curvature(
    R_GRID,
    dX_t, dY_t, dZ_t,
    dX_p, dY_p, dZ_p,
    dX_tt, dY_tt, dZ_tt,
    dX_tp, dY_tp, dZ_tp,
    dX_pp, dY_pp, dZ_pp
):
    """Gaussian curvature of a parametric surface at every grid point.

    The curvature is the ratio of the determinants of the second and first
    fundamental forms, ``K = (L*N - M**2) / (E*G - F**2)``, evaluated for all
    grid cells at once.

    Args:
        R_GRID: 2-D radius grid. Its values are not used in the computation;
            the derivative grids are expected to share its shape.
        dX_t, dY_t, dZ_t: Components of the first derivative w.r.t. theta.
        dX_p, dY_p, dZ_p: Components of the first derivative w.r.t. phi.
        dX_tt, dY_tt, dZ_tt: Components of the second derivative (theta, theta).
        dX_tp, dY_tp, dZ_tp: Components of the mixed derivative (theta, phi).
        dX_pp, dY_pp, dZ_pp: Components of the second derivative (phi, phi).

    Returns:
        Array with the grid's shape holding K; cells where the surface is
        degenerate (normal length or first-form determinant below 1e-14 in
        magnitude) are set to 0.
    """
    # Pack the per-component grids into (..., 3) vector fields.
    rt = jnp.stack([dX_t, dY_t, dZ_t], axis=-1)
    rp = jnp.stack([dX_p, dY_p, dZ_p], axis=-1)
    rtt = jnp.stack([dX_tt, dY_tt, dZ_tt], axis=-1)
    rtp = jnp.stack([dX_tp, dY_tp, dZ_tp], axis=-1)
    rpp = jnp.stack([dX_pp, dY_pp, dZ_pp], axis=-1)

    def dot(a, b):
        # Row-wise dot product over the trailing component axis.
        return jnp.sum(a * b, axis=-1)

    # First fundamental form coefficients.
    E = dot(rt, rt)
    F = dot(rt, rp)
    G = dot(rp, rp)
    det_g = E * G - F ** 2

    # Unnormalised surface normal and its length.
    n_unnorm = jnp.cross(rt, rp)
    n_len = jnp.linalg.norm(n_unnorm, axis=-1)

    degenerate = jnp.logical_or(n_len < 1e-14, jnp.abs(det_g) < 1e-14)
    # Swap the denominators for 1 on degenerate cells so no inf/NaN is ever
    # produced; those cells are overwritten with 0 below anyway.
    safe_len = jnp.where(degenerate, 1.0, n_len)
    safe_det = jnp.where(degenerate, 1.0, det_g)

    n = n_unnorm / safe_len[..., None]

    # Second fundamental form coefficients.
    L = dot(rtt, n)
    M = dot(rtp, n)
    N_val = dot(rpp, n)
    det_L = L * N_val - M ** 2

    return jnp.where(degenerate, 0.0, det_L / safe_det)
@jax.jit
def find_nearest_token(grid_theta, grid_phi, data_theta, data_phi):
    """Return the index of the data point closest to one grid location.

    Distance is Euclidean in the (theta, phi) plane, with the theta difference
    wrapped around a full turn of 2π so that angles near 0 and near 2π count as
    neighbours. The phi difference is not wrapped.
    """
    full_turn = 2 * jnp.pi
    # Angular gap in theta, taking the shorter way around the circle.
    d_theta = jnp.abs(data_theta - grid_theta)
    d_theta = jnp.minimum(d_theta, full_turn - d_theta)
    d_phi = jnp.abs(data_phi - grid_phi)
    return jnp.argmin(jnp.sqrt(jnp.square(d_theta) + jnp.square(d_phi)))
def viz_token_projection(
    logprobs,
    title="Curvature with Outlier Clipping & Pole Offsets",
    cmap="blackbody",
    grid_size=150,
    smooth_radius=None,     # smooth radius field
    smooth_curvature=None,  # smooth final curvature
    clip_curvature=False,   # clip curvature outliers
    phi_epsilon=0.05,       # avoid exact 0 or π at the poles
    mod_factor=0.125,
    interp_method="nearest",
    fname="output.html"
):
    """Project tokens onto a deformed sphere and plot it coloured by Gaussian curvature.

    Each token is placed at an azimuth theta given by its position in the
    sequence and a polar angle phi given by its probability; the radius at that
    point is ``1 + mod_factor * entropy``. The scattered radii are interpolated
    onto a regular (theta, phi) grid, the curvature is evaluated per grid cell,
    and the result is shown as a Plotly surface with per-cell hover text.

    Args:
        logprobs: Re-iterable collection of token records exposing ``token``,
            ``logprob`` and ``top_logprobs``.
        title: Figure title.
        cmap: Plotly colorscale used for the curvature.
        grid_size: Number of grid samples along both theta and phi.
        smooth_radius: If not None, sigma of a Gaussian filter applied to the
            interpolated radius grid.
        smooth_curvature: If not None, sigma of a Gaussian filter applied to
            the curvature grid.
        clip_curvature: If true, clip the curvature to its 5th-95th percentile
            range so the colour scale is more informative.
        phi_epsilon: Margin that keeps phi away from the poles at 0 and π.
        mod_factor: Weight of the entropy in the radius.
        interp_method: ``method`` forwarded to ``scipy.interpolate.griddata``.
        fname: Path of the HTML file to write; ``None`` skips writing.

    Raises:
        ValueError: If ``logprobs`` contains no tokens.

    Side effects:
        Displays the figure via ``fig.show()`` and, unless ``fname`` is None,
        writes it to ``fname``.
    """
    # 1) Unpacking
    toks = [t.token for t in logprobs]
    if not toks:
        raise ValueError("viz_token_projection needs at least one token in `logprobs`.")
    probs = compute_probs(logprobs)
    ents = jnp.array([compute_entropy(lp.top_logprobs) for lp in logprobs])

    # 2) Map each token to (theta, phi)
    # Spherical coordinates convention:
    #   * radial distance: r ≥ 0,
    #   * polar angle:     0 rad ≤ φ < π rad,
    #   * azimuth:         0 rad ≤ θ < 2π rad.
    #
    # The parametric surface is
    #   r(θ,φ) = R(θ,φ) • [sinφ • cosθ, sinφ • sinθ, cosφ], with θ and φ in the above ranges.
    n_tokens = len(toks)
    toks_pos = jnp.arange(n_tokens)
    # Position -> [0, 1] -> [0, 2π]. The max() guard keeps a single-token input
    # at theta = 0 instead of dividing 0 by 0.
    toks_pos_norm = toks_pos / max(n_tokens - 1, 1)
    data_theta = toks_pos_norm * 2 * jnp.pi

    # Probability -> [phi_epsilon, π - phi_epsilon], so 0 and π are never hit exactly.
    data_phi = phi_epsilon + (probs / probs.max()) * (jnp.pi - 2 * phi_epsilon)

    # Radius function: base + mod_factor * entropy
    base_radius = 1.0
    data_R = base_radius + mod_factor * ents

    # 3) Interpolate (theta, phi) -> R over a 2D grid
    theta_lin = jnp.linspace(0, 2 * jnp.pi, grid_size)
    phi_lin = jnp.linspace(phi_epsilon, jnp.pi - phi_epsilon, grid_size)
    THETA, PHI = jnp.meshgrid(theta_lin, phi_lin)

    pts = jnp.column_stack([data_theta, data_phi])  # shape=(N,2)
    R_GRID = jnp.array(griddata(
        np.asarray(pts),
        np.asarray(data_R),
        (np.asarray(THETA), np.asarray(PHI)),
        method=interp_method,
        fill_value=base_radius,
    ))

    if smooth_radius is not None:
        R_GRID = jnp.array(gaussian_filter(np.asarray(R_GRID), sigma=smooth_radius))

    # 4) Gaussian Curvature
    # Computed as the ratio of the determinants of the second and first
    # fundamental forms
    # (https://en.wikipedia.org/wiki/Gaussian_curvature#Alternative_formulas),
    # which needs the first derivatives w.r.t. θ and φ and the second
    # derivatives θθ, θφ and φφ of the surface.
    #
    # NOTE(review): the derivatives below are taken with r held fixed at each
    # grid point, so the variation of R_GRID across the grid does not enter
    # them. For this parametrisation the ratio then reduces to 1 / R**2 per
    # cell. Confirm whether that is intended, or whether dR/dθ and dR/dφ should
    # contribute.
    def fX(r, phi, theta): return r * jnp.sin(phi) * jnp.cos(theta)  # X
    def fY(r, phi, theta): return r * jnp.sin(phi) * jnp.sin(theta)  # Y
    def fZ(r, phi, theta): return r * jnp.cos(phi)                   # Z

    I, J = R_GRID.shape
    args = (R_GRID.ravel(), PHI.ravel(), THETA.ravel())
    coords = (fX, fY, fZ)

    def first_ord_deriv(f, arg):
        # Per-point derivative of f w.r.t. positional argument `arg`, as an (I, J) grid.
        return jax.vmap(jax.grad(f, argnums=arg))(*args).reshape(I, J)

    def second_ord_deriv(f, first, second):
        # Per-point derivative w.r.t. `first`, then `second`, as an (I, J) grid.
        return jax.vmap(jax.grad(jax.grad(f, argnums=first), argnums=second))(*args).reshape(I, J)

    X, Y, Z = (f(*args).reshape(I, J) for f in coords)

    # Argument indices of fX/fY/fZ: 1 = phi (p), 2 = theta (t).
    dX_p, dY_p, dZ_p = (first_ord_deriv(f, 1) for f in coords)
    dX_t, dY_t, dZ_t = (first_ord_deriv(f, 2) for f in coords)
    dX_pp, dY_pp, dZ_pp = (second_ord_deriv(f, 1, 1) for f in coords)
    dX_tt, dY_tt, dZ_tt = (second_ord_deriv(f, 2, 2) for f in coords)
    dX_tp, dY_tp, dZ_tp = (second_ord_deriv(f, 2, 1) for f in coords)

    K = compute_curvature(
        R_GRID,
        dX_t, dY_t, dZ_t,
        dX_p, dY_p, dZ_p,
        dX_tt, dY_tt, dZ_tt,
        dX_tp, dY_tp, dZ_tp,
        dX_pp, dY_pp, dZ_pp
    )

    if smooth_curvature is not None:
        K = jnp.array(gaussian_filter(np.asarray(K), sigma=smooth_curvature))

    # Clip out extreme outliers so color scale is more informative.
    if clip_curvature:
        lower_clip = jnp.percentile(K, 5)
        upper_clip = jnp.percentile(K, 95)
        K = jnp.clip(K, lower_clip, upper_clip)

    # 5) Prepare figure
    # Resolve the nearest token for every grid cell in one batched pass
    # (vectorised along a row, rows processed one after another) rather than
    # issuing one jitted call and one device round-trip per cell.
    nearest_in_row = jax.vmap(find_nearest_token, in_axes=(0, 0, None, None))
    nearest = np.asarray(jax.lax.map(
        lambda row: nearest_in_row(row[0], row[1], data_theta, data_phi),
        (THETA, PHI),
    ))

    probs_np = np.asarray(probs)
    ents_np = np.asarray(ents)
    K_np = np.asarray(K)

    # hover_text uses the same [i, j] layout as X/Y/Z/K so that every cell
    # reports the token and curvature of the point it is drawn at.
    hover_text = np.empty((I, J), dtype=object)
    for i, j in np.ndindex(I, J):
        t_idx = int(nearest[i, j])
        hover_text[i, j] = (
            f"Token: {toks[t_idx]}<br>"
            f"Entropy: {ents_np[t_idx]:.3f}<br>"
            f"Prob: {probs_np[t_idx]:.3g}<br>"
            f"GaussCurv: {K_np[i, j]:.3f}"
        )

    fig = go.Figure(data=[go.Surface(
        x=np.asarray(X), y=np.asarray(Y), z=np.asarray(Z),
        surfacecolor=K_np,
        colorscale=cmap,
        hoverinfo="text",
        text=hover_text,
        hovertemplate="%{text}<extra></extra>",
        colorbar=dict(title="Curvature (K)")
    )])
    fig.update_layout(
        title=title,
        paper_bgcolor="#121212",      # Dark background for the entire plot area
        plot_bgcolor="#121212",       # Dark background for the plot itself
        font=dict(color="#f8f8f2"),   # Light text for better contrast
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            aspectmode='cube',
            camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
            bgcolor="#121212",        # Ensure the 3D scene also has dark background
        )
    )
    fig.show()
    if fname is not None:
        fig.write_html(fname)
if __name__ == "__main__":
    # Demo run: request an essay together with per-token log-probabilities,
    # then render both visualisations from the returned token records.
    top_k = 20  # alternatives requested per token (the maximum, as far as the author knows)
    with_logprobs = True

    topic = Symbol('Topic: Scientific Method')
    res = topic.query(
        'Write a philosophical essay about the topic.',
        logprobs=with_logprobs,
        top_logprobs=top_k,
        raw_output=True,
        seed=42,
        temperature=1.0,
    )
    first_choice = res.choices[0]
    message = first_choice.message.content
    logprobs = first_choice.logprobs.content

    # Console view: tokens coloured by entropy, as one flowing text.
    viz_token_entropy(logprobs, flowing_text=True)

    # Geometric view: curvature-coloured surface.
    viz_token_projection(
        logprobs,
        title="Curvature visualization (method='nearest', temperature=1.0)",
        cmap="blackbody",
        grid_size=200,
        smooth_radius=None,     # smooth radius field
        smooth_curvature=None,  # smooth final curvature
        clip_curvature=True,    # clip curvature outliers
        phi_epsilon=0.05,       # avoid exact 0 or π at the poles
        mod_factor=0.15,
        interp_method="nearest",
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment