Skip to content

Instantly share code, notes, and snippets.

@n-WN
Created March 14, 2026 08:24
Show Gist options
  • Select an option

  • Save n-WN/51471d40eca6470823ec071c2910ce0b to your computer and use it in GitHub Desktop.

Select an option

Save n-WN/51471d40eca6470823ec071c2910ce0b to your computer and use it in GitHub Desktop.
AliCTF 2026 Final - Kitten Sign exploit (RSA + ECDSA + SM2 triple signature forgery)
#!/usr/bin/env python3
"""
Kitten Sign - Full Exploit
AliyunCTF 2026 Final
Attack chain:
1. Collect encrypt samples, recover RSA modulus & SM2 public key
2. Factor RSA via genus-2 Jacobian (call test.py)
3. Precompute BKZ-25 lattice basis (call sage subprocess)
4. Invalid-curve oracle: forge decrypt commands to recover SM2 private key via CRT
5. Sign "cat /flag.txt" with all 3 schemes, get flag
"""
from __future__ import annotations
import hashlib
import json
import math
import os
import random
import socket
import subprocess
import sys
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
from gmssl import sm2, sm3
# ── Constants ──────────────────────────────────────────────────────────────
# Exponent applied to RSA signatures when checking sig^E against the message.
E = 65537

# SM2 curve parameters: group order n, field prime p, coefficients a and b,
# and the affine coordinates of the base point G.
SM2_N = 0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123
SM2_P = 0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF
SM2_A = 0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC
SM2_B = 0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93
SM2_GX = 0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7
SM2_GY = 0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0

# Byte markers the client scans for: PROMPT ends a server turn, RESPONSE
# prefixes the line that carries the JSON reply.
PROMPT = "📤 ".encode("utf-8")
RESPONSE = "📥 ".encode("utf-8")
# Invalid curves y^2 = x^3 + a*x + b' with smooth-order subgroups
# For each b', the "usable" part of the curve order is a product of small primes.
# We query one point per prime factor to recover d mod that prime.
# The full curve order = usable_part * cofactor.
# To get a point of order exactly `l` (a prime factor), we:
# 1. Pick random point P on the curve
# 2. Compute Q = [curve_order / l] * P
# 3. If Q != O and [l]*Q == O, then Q has order l.
# Invalid curve subgroups for oracle queries.
# Crash prob ≈ 1-(1-1/l)^256. For l=449: ~43%, l=503: ~40%.
# Small primes: high crash risk but needed for CRT bits.
# Large primes: negligible crash but slow table build (must be precomputed).
# All primes LCM = 259 bits (including 53, 71).
# Without 53, 71: 247 bits → need ~512 brute-force candidates.
# Maps each invalid-curve coefficient b' to the prime subgroup orders that are
# queried on that curve (one oracle query per prime).
INVALID_CURVES = {
    b_prime: {"primes": list(primes)}
    for b_prime, primes in (
        (38, (449, 755071, 1799533, 8713961, 14345957)),
        (15, (503, 3530591, 27352219)),
        (19, (3709033, 5629919)),
        (14, (3691, 421349, 1307507)),
    )
}
# ── SM2 point arithmetic (via gmssl internals) ────────────────────────────
def sm2_engine(public_key: str = "", private_key: str = "") -> sm2.CryptSM2:
    """Create a gmssl ``CryptSM2`` object for the given key strings.

    Both keys default to the empty string, which is how the point helpers in
    this file obtain an engine purely for its curve arithmetic.
    """
    engine = sm2.CryptSM2(public_key=public_key, private_key=private_key)
    return engine
def point_to_hex(point: tuple[int, int]) -> str:
    """Encode an affine point as x || y, each a 64-character lowercase hex string."""
    x_coord, y_coord = point[0], point[1]
    return format(x_coord, "064x") + format(y_coord, "064x")
def hex_to_point(blob: str) -> tuple[int, int]:
    """Decode the first 128 hex characters of ``blob`` as an (x, y) point.

    Characters beyond position 128 are ignored.
    """
    x_hex = blob[:64]
    y_hex = blob[64:128]
    return int(x_hex, 16), int(y_hex, 16)
def point_add(p1, p2):
    """Add two affine points; ``None`` stands for the point at infinity.

    Returns the sum as an (x, y) tuple, or ``None`` when the result is the
    point at infinity or gmssl reports no result.
    """
    # The point at infinity is the identity element.
    if p1 is None:
        return p2
    if p2 is None:
        return p1
    if p1 == p2:
        # gmssl's _add_point does not handle P + P, so double via scalar mult.
        return point_mul(2, p1)
    if p1[0] == p2[0]:
        # Same x but different y: P + (-P) = O.
        return None
    engine = sm2_engine()
    # First operand is passed in Jacobian form with Z = 1 appended.
    jacobian = engine._add_point(point_to_hex(p1) + "1", point_to_hex(p2))
    affine_hex = None if jacobian is None else engine._convert_jacb_to_nor(jacobian)
    return None if affine_hex is None else hex_to_point(affine_hex)
def point_neg(point):
    """Return the negation (x, -y mod p) of an affine point."""
    x_coord, y_coord = point[0], point[1]
    return x_coord, -y_coord % SM2_P
def point_mul(k, point):
    """Compute [k]point with gmssl's ``_kg``.

    Returns the product as an (x, y) tuple, or ``None`` when gmssl returns
    ``None`` (treated by callers as the point at infinity).
    """
    product_hex = sm2_engine()._kg(k, point_to_hex(point))
    return None if product_hex is None else hex_to_point(product_hex)
def point_mul_custom_curve(k, point_hex, a_hex=None, b_hex=None):
    """Compute [k]P for a hex-encoded point, possibly on a curve with another b.

    ``a_hex`` and ``b_hex`` are accepted but not used: per the original
    author's note, gmssl's ``_kg`` only uses a and p, never b, so the same
    call works on curves that differ only in b.
    """
    product_hex = sm2_engine()._kg(k, point_hex)
    if product_hex is None:
        return None
    return hex_to_point(product_hex)
def tonelli_shanks(n, p):
    """Return a square root of ``n`` modulo the odd prime ``p``.

    Returns 0 when n ≡ 0 (mod p) and ``None`` when n is a non-residue.
    """
    n %= p
    if n == 0:
        return 0
    euler_exp = (p - 1) // 2
    if pow(n, euler_exp, p) != 1:
        return None
    if p % 4 == 3:
        # Direct formula for p ≡ 3 (mod 4).
        return pow(n, (p + 1) // 4, p)

    # Write p - 1 = odd * 2**two_power.
    odd, two_power = p - 1, 0
    while odd % 2 == 0:
        odd //= 2
        two_power += 1

    # Smallest z >= 2 that is a quadratic non-residue.
    z = 2
    while pow(z, euler_exp, p) != p - 1:
        z += 1

    m = two_power
    c = pow(z, odd, p)
    t = pow(n, odd, p)
    r = pow(n, (odd + 1) // 2, p)
    while t != 1:
        # Find the least i with t^(2^i) == 1.
        i = 1
        probe = pow(t, 2, p)
        while probe != 1:
            probe = pow(probe, 2, p)
            i += 1
        b = pow(c, 1 << (m - i - 1), p)
        b_squared = pow(b, 2, p)
        m, c, t, r = i, b_squared, (t * b_squared) % p, (r * b) % p
    return r
def lift_x(x, b=SM2_B):
    """Return every point with the given x on y^2 = x^3 + a*x + b (mod p).

    The list is empty when the right-hand side is a non-residue, has one entry
    when y = 0, and otherwise holds (x, y) followed by (x, -y mod p).
    """
    rhs = (pow(x, 3, SM2_P) + SM2_A * x + b) % SM2_P
    root = tonelli_shanks(rhs, SM2_P)
    if root is None:
        return []
    ordinates = [root] if root == 0 else [root, -root % SM2_P]
    return [(x, y) for y in ordinates]
# ── Network ────────────────────────────────────────────────────────────────
@dataclass
class EncryptReply:
    """Fields of one reply to the service's "encrypt" option."""

    message_hex: str  # signed message, hex-encoded
    rsa_sig: int  # RSA signature as an integer
    ecdsa_sig: list[int]  # ECDSA signature components as sent by the server
    sm2_sig: str  # SM2 signature as hex, r in the first 64 characters

    @property
    def message(self) -> bytes:
        """The signed message decoded from hex."""
        return bytes.fromhex(self.message_hex)

    @property
    def sm2_rs(self) -> tuple[int, int]:
        """The SM2 signature split into integers (r, s)."""
        r_hex, s_hex = self.sm2_sig[:64], self.sm2_sig[64:]
        return int(r_hex, 16), int(s_hex, 16)
class KittenClient:
    """Line-oriented JSON client for the Kitten Sign service.

    Intended to be used as a context manager: entering connects and consumes
    the banner up to the first prompt, leaving closes the socket.
    """

    def __init__(self, host, port, timeout=10.0):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.sock = None
        # Bytes received from the server but not yet consumed.
        self.buf = bytearray()

    def __enter__(self):
        self.buf.clear()
        self.sock = socket.create_connection((self.host, self.port), timeout=self.timeout)
        try:
            self._recv_until(PROMPT)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so close the
            # freshly opened socket here instead of leaking it.
            self.__exit__(None, None, None)
            raise
        return self

    def __exit__(self, *a):
        if self.sock:
            self.sock.close()
            self.sock = None

    def _recv_until(self, marker):
        """Read until ``marker`` is buffered; return the bytes up to and including it.

        Raises ConnectionError if the peer closes the connection first.
        """
        while marker not in self.buf:
            chunk = self.sock.recv(4096)
            if not chunk:
                raise ConnectionError("closed")
            self.buf.extend(chunk)
        end = self.buf.index(marker) + len(marker)
        out = bytes(self.buf[:end])
        # Keep anything after the marker for the next read.
        del self.buf[:end]
        return out

    def _request(self, payload):
        """Send one compact JSON line and return the decoded reply.

        The reply is the first line starting with RESPONSE before the next
        prompt; ValueError is raised when no such line is present.
        """
        wire = json.dumps(payload, separators=(",", ":")).encode() + b"\n"
        self.sock.sendall(wire)
        blob = self._recv_until(PROMPT)
        for line in blob.splitlines():
            if line.startswith(RESPONSE):
                return json.loads(line[len(RESPONSE):].decode())
        raise ValueError(f"no response in {blob!r}")

    def encrypt(self, plaintext_hex="00"):
        """Call the "encrypt" option and wrap the reply in an EncryptReply."""
        r = self._request({"option": "encrypt", "plaintext": plaintext_hex})
        return EncryptReply(r["message"], int(r["rsa_sig"]), list(r["ecdsa_sig"]), str(r["sm2_sig"]))

    def decrypt(self, message_hex, rsa_sig, sm2_sig, ecdsa_sig=None):
        """Call the "decrypt" option and return the decoded JSON reply.

        When ``ecdsa_sig`` is omitted, ``[1.5, 1.5]`` is sent (the value the
        exploit uses as its ECDSA bypass).
        """
        if ecdsa_sig is None:
            ecdsa_sig = [1.5, 1.5]
        return self._request({
            "option": "decrypt",
            "message": message_hex,
            "rsa_sig": rsa_sig,
            "ecdsa_sig": ecdsa_sig,
            "sm2_sig": sm2_sig,
        })
# ── RSA ────────────────────────────────────────────────────────────────────
def rsa_message_int(message: bytes) -> int:
    """Interpret the message bytes as a big-endian unsigned integer."""
    value = int.from_bytes(message, byteorder="big")
    return value
def recover_rsa_modulus(samples: list[EncryptReply]) -> int:
    """Recover the RSA modulus n from signed samples.

    Each signature satisfies sig^e ≡ m (mod n), so n divides sig^e - m taken
    over the integers. The power is therefore computed in full with gmpy2 (no
    modular reduction, which would destroy the divisibility) and n is obtained
    as the gcd of the first two such values with the small primes 2..13
    divided out. Remaining samples are used to verify the candidate and, on a
    mismatch, are folded into the gcd as well.

    Per the original author's note this takes roughly four minutes because the
    operands have about 334M bits.
    """
    import gmpy2

    def without_small_primes(value):
        """Divide out every factor of 2, 3, 5, 7, 11 and 13."""
        for prime in (2, 3, 5, 7, 11, 13):
            while value % prime == 0:
                value //= prime
        return value

    print("[*] Computing full sig^e for RSA modulus recovery (~4 min)...")
    # Only two residuals are needed for the initial gcd.
    residuals = []
    for label, sample in enumerate(samples[:2], 1):
        message_int = rsa_message_int(sample.message)
        started = time.time()
        residual = gmpy2.mpz(sample.rsa_sig) ** E - message_int
        print(f" sample {label}: {time.time()-started:.1f}s ({residual.bit_length()} bits)")
        residuals.append(residual)

    gcd_started = time.time()
    g = int(gmpy2.gcd(abs(residuals[0]), abs(residuals[1])))
    print(f" gcd: {time.time()-gcd_started:.1f}s")
    g = without_small_primes(g)

    # Verify with the remaining samples; numbering continues at 3.
    for i, s in enumerate(samples[2:], 3):
        m = rsa_message_int(s.message)
        if pow(s.rsa_sig, E, g) == m % g:
            continue
        print(f" [!] Sample {i} failed verification, retrying with full GCD")
        c = gmpy2.mpz(s.rsa_sig) ** E - m
        g = without_small_primes(int(gmpy2.gcd(g, abs(c))))

    print(f"[+] RSA n: {g.bit_length()} bits")
    return g
def factor_rsa(n: int) -> tuple[int, int, int]:
    """Factor n = p*q*r by running the external ``test.py`` helper.

    The helper is looked up four directories above this file, falling back to
    a fixed path in the author's Downloads folder. Its stdout is scanned for
    lines of the form ``p = ...``, ``q = ...`` and ``r = ...``.

    Raises RuntimeError when the helper exits with a non-zero status, and
    KeyError when one of the three factors is missing from its output.
    """
    helper = Path(__file__).resolve().parent.parent.parent.parent / "test.py"
    if not helper.exists():
        helper = Path("/Users/lov3/Downloads/test.py")
    print(f"[*] Factoring RSA modulus ({n.bit_length()} bits) via {helper}")
    command = [sys.executable, str(helper), "--n", hex(n), "--progress"]
    result = subprocess.run(command, capture_output=True, text=True, timeout=300)
    if result.returncode != 0:
        raise RuntimeError(f"RSA factoring failed: {result.stderr}")

    factors = {}
    for line in result.stdout.splitlines():
        for name in ("p", "q", "r"):
            if not line.startswith(f"{name} = "):
                continue
            # A later matching line overrides an earlier one.
            factors[name] = int(line.split("=")[1].strip())
    return factors["p"], factors["q"], factors["r"]
def rsa_sign(message: bytes, d: int, n: int) -> int:
    """Return the textbook RSA signature m^d mod n of the big-endian message integer."""
    message_int = int.from_bytes(message, "big")
    return pow(message_int, d, n)
# ── SM2 public key recovery ───────────────────────────────────────────────
def recover_sm2_pubkey(samples: list[EncryptReply]) -> tuple[int, int]:
    """Recover the SM2 public key from signed samples.

    Candidates are derived from the first sample by solving the verification
    equation R = [s]G + [t]Q for Q, with t = r + s and x(R) ≡ r - e (mod n).
    Each candidate is then checked against the remaining samples with gmssl's
    verifier.

    Raises RuntimeError unless exactly one candidate survives.
    """
    generator = (SM2_GX, SM2_GY)

    def candidates_from(sample):
        """List every Q consistent with one sample's signature."""
        e = int(sample.message.hex(), 16) % SM2_N
        r, s = sample.sm2_rs
        t = (r + s) % SM2_N
        if t == 0:
            return []
        s_times_g = point_mul(s, generator)
        if s_times_g is None:
            return []
        t_inverse = pow(t, -1, SM2_N)
        minus_sg = point_neg(s_times_g)
        found = []
        # x(R) is only known modulo n, so also try the representative + n.
        for x in {(r - e) % SM2_N, (r - e + SM2_N) % SM2_P}:
            if x >= SM2_P:
                continue
            for R in lift_x(x):
                t_times_q = point_add(R, minus_sg)
                if t_times_q is None:
                    continue
                q = point_mul(t_inverse, t_times_q)
                if q is not None:
                    found.append(q)
        return found

    cands = candidates_from(samples[0])
    verifier = sm2_engine()
    for sample in samples[1:]:
        surviving = []
        for q in cands:
            verifier.public_key = point_to_hex(q)
            if verifier.verify(sample.sm2_sig, sample.message):
                surviving.append(q)
        cands = surviving
    if len(cands) != 1:
        raise RuntimeError(f"Expected 1 SM2 pubkey candidate, got {len(cands)}")
    return cands[0]
# ── SM2 universal forgery ─────────────────────────────────────────────────
def forge_sm2(pubkey, s=None, t=None):
    """SM2 universal forgery without the private key.

    Picks (or accepts) scalars s and t, forms R = [s]G + [t]Q and sets
    r = t - s, so that the signature (r, s) verifies for any message whose
    integer value is congruent to e = r - x(R) modulo n.

    Returns ``(signature_hex, e)``. Raises ValueError when R is the point at
    infinity or r is zero.
    """
    rng = random.Random()
    if s is None:
        s = rng.randrange(1, SM2_N)
    if t is None:
        t = rng.randrange(1, SM2_N)
    s_times_g = point_mul(s, (SM2_GX, SM2_GY))
    t_times_q = point_mul(t, pubkey)
    R = point_add(s_times_g, t_times_q)
    if R is None:
        raise ValueError("R is infinity")
    r = (t - s) % SM2_N
    if r == 0:
        raise ValueError("r=0")
    required_residue = (r - R[0] % SM2_N) % SM2_N
    return f"{r:064x}{s:064x}", required_residue
# ── BKZ lattice for residue construction ───────────────────────────────────
def build_bkz_basis_sage():
    """Build and BKZ-reduce the 130-dimensional kernel lattice with Sage.

    The lattice consists of the integer vectors v with
    sum(v[i] * 256^(129 - i)) ≡ 0 (mod n) for the SM2 group order n. It is
    LLL-reduced and then BKZ-reduced with block sizes 10, 15, 20 and 25 in a
    ``sage -python`` subprocess, which prints the result as JSON.

    Returns the reduced basis as a list of 130 lists of ints. Raises
    RuntimeError when the subprocess exits with a non-zero status.
    """
    # ``sage -python`` is a plain Python interpreter, so the Sage names used
    # below must be imported explicitly from sage.all. The script text must
    # stay at column 0 because it is passed verbatim to ``-c``.
    sage_script = r'''
import json, sys, time
from sage.all import ZZ, matrix
n_sm2 = 0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123
L = 130
a = [int(pow(256, L - 1 - i, n_sm2)) for i in range(L)]
inv_a0 = int(pow(a[0], -1, n_sm2))
B = matrix(ZZ, L, L)
B[0, 0] = ZZ(n_sm2)
for i in range(1, L):
    val = ZZ((-a[i] * inv_a0) % n_sm2)
    if val > n_sm2 // 2:
        val -= n_sm2
    B[i, 0] = val
    B[i, i] = 1
t0 = time.time()
B = B.LLL()
for bs in [10, 15, 20, 25]:
    B = B.BKZ(block_size=bs)
print(json.dumps({
    "basis": [[int(B[i][j]) for j in range(L)] for i in range(L)],
    "time": time.time() - t0
}))
'''
    print("[*] Building BKZ-25 reduced basis via Sage...")
    t0 = time.time()
    result = subprocess.run(
        ["sage", "-python", "-c", sage_script],
        capture_output=True, text=True, timeout=120,
    )
    if result.returncode != 0:
        raise RuntimeError(f"Sage BKZ failed: {result.stderr[:500]}")
    data = json.loads(result.stdout.strip())
    print(f"[+] BKZ done in {time.time() - t0:.1f}s (sage reported {data['time']:.1f}s)")
    return data["basis"]
class ResidueConstructor:
"""Given a BKZ-reduced basis, find 130 decimal-digit bytes with a target residue."""
def __init__(self, basis):
import mpmath
self.L = 130
self.n = SM2_N
self.CENTER = 52 # byte '4'
self.BOUND = 4
self.a = [pow(256, self.L - 1 - i, self.n) for i in range(self.L)]
self.inv_a0 = pow(self.a[0], -1, self.n)
# Store basis rows as plain Python lists of int
self.int_rows = [list(row) for row in basis]
# Compute Gram-Schmidt with high precision (mpmath)
self._compute_gs_mpmath()
def _compute_gs_mpmath(self):
"""Compute GS orthogonalization with exact rational arithmetic (Fraction),
yielding a fast integer-only Babai."""
from fractions import Fraction
L = self.L
print("[*] Computing exact rational GS decomposition...")
# Exact rational GS
gs = [[Fraction(x) for x in row] for row in self.int_rows]
gs_nsq = [Fraction(0)] * L
for i in range(L):
for j in range(i):
if gs_nsq[j] == 0: continue
dot_ij = sum(gs[i][k] * gs[j][k] for k in range(L))
mu = dot_ij / gs_nsq[j]
for k in range(L):
gs[i][k] -= mu * gs[j][k]
gs_nsq[i] = sum(gs[i][k] * gs[i][k] for k in range(L))
# Store as (numerator_row, denominator) pairs for the GS vectors
# Each gs[i] = gs_num[i] / gs_den[i] where gs_den[i] = lcm of denominators
self._gs_num = []
self._gs_den = []
self._gs_nsq_num = []
self._gs_nsq_den = []
for i in range(L):
# Find common denominator for gs[i]
den = 1
for k in range(L):
den = den * gs[i][k].denominator // math.gcd(den, gs[i][k].denominator)
num_row = [int(gs[i][k] * den) for k in range(L)]
self._gs_num.append(num_row)
self._gs_den.append(den)
# gs_nsq[i] = sum(gs[i][k]^2) = sum(num[k]^2) / den^2
nsq_num = gs_nsq[i].numerator
nsq_den = gs_nsq[i].denominator
self._gs_nsq_num.append(nsq_num)
self._gs_nsq_den.append(nsq_den)
print("[+] GS decomposition complete")
def _babai_cvp(self, T_val):
"""Exact integer Babai nearest plane."""
L = self.L
c0 = (T_val * self.inv_a0) % self.n
if c0 > self.n // 2:
c0 -= self.n
# rem = target vector, scaled by gs_den[i] for each level
# We keep rem as exact integers in the original lattice coordinate space
# Actually, we use the standard Babai with exact dot products.
#
# rem starts as (c0, 0, 0, ..., 0)
# For each i from L-1 to 0:
# ci = round(dot(rem, gs[i]) / dot(gs[i], gs[i]))
# rem -= ci * int_rows[i]
#
# dot(rem, gs[i]) where rem is integer and gs[i] = gs_num[i] / gs_den[i]
# = sum(rem[k] * gs_num[i][k]) / gs_den[i]
# ci = round(sum(rem[k] * gs_num[i][k]) / gs_den[i] / (gs_nsq_num[i] / gs_nsq_den[i]))
# = round(sum(rem[k] * gs_num[i][k]) * gs_nsq_den[i] / (gs_den[i] * gs_nsq_num[i]))
rem = [0] * L
rem[0] = c0
coeffs = [0] * L
for i in range(L - 1, -1, -1):
nsq_num = self._gs_nsq_num[i]
if nsq_num == 0: continue
nsq_den = self._gs_nsq_den[i]
gs_num = self._gs_num[i]
gs_den = self._gs_den[i]
# numerator of dot(rem, gs[i]) / gs_nsq[i]
# = (sum(rem[k]*gs_num[k]) / gs_den) / (nsq_num / nsq_den)
# = sum(rem[k]*gs_num[k]) * nsq_den / (gs_den * nsq_num)
dot_num = sum(rem[k] * gs_num[k] for k in range(L))
# ci = round(dot_num * nsq_den / (gs_den * nsq_num))
numer = dot_num * nsq_den
denom = gs_den * nsq_num
# round = (2*numer + denom) // (2*denom) for positive, etc.
ci = _round_div(numer, denom)
coeffs[i] = ci
for k in range(L):
rem[k] -= ci * self.int_rows[i][k]
return list(rem)
def _round_div(a, b):
"""Compute round(a/b) for integers a, b with b > 0."""
if b < 0:
a, b = -a, -b
q, r = divmod(a, b)
if 2 * r >= b:
q += 1
return q
def find_digits(self, prefix: bytes, c3_hex: str, max_trials=50000):
"""Find 130 decimal-digit characters for C2 such that the full message
b"decrypt " + C1_hex + digit_C2 + C3_hex has SM2 residue matching some
freely-chosen forgery target.
Returns (digit_string, target_residue_T) or raises RuntimeError.
"""
# We need: int((prefix + digits + c3_suffix).hex(), 16) ≡ T (mod n)
# prefix contribution: int(prefix.hex(), 16) * 256^(L + len(c3_bytes))
# c3 contribution: int(c3_hex_bytes.hex(), 16) (the last bytes)
# digit contribution: sum(digit_byte[i] * 256^(L + len(c3_bytes) - 1 - i - len(prefix)))
# ... Actually it's simpler: the full message bytes are
# msg_bytes = prefix_bytes + digit_bytes + c3_suffix_bytes
# and e = int(msg_bytes.hex(), 16) % n
# We want to find digit_bytes such that e ≡ T for some T we get to choose.
# The digit bytes contribute:
# sum(digit_byte[i] * 256^(total_len - prefix_len - i - 1)) for i in 0..L-1
# where total_len = len(prefix) + L + len(c3_suffix)
c3_suffix = c3_hex.encode() # ASCII bytes of the hex C3
total_len = len(prefix) + self.L + len(c3_suffix)
# prefix_val = int(prefix.hex(), 16) * 256^(L + len(c3_suffix))
prefix_val = int(prefix.hex(), 16) * pow(256, self.L + len(c3_suffix), self.n) % self.n
# c3_val = int(c3_suffix.hex(), 16)
c3_val = int(c3_suffix.hex(), 16) % self.n
# For digit position i (0-indexed within the digit block):
# coefficient = 256^(len(c3_suffix) + L - 1 - i)
c3_len = len(c3_suffix)
digit_coeffs = [pow(256, c3_len + self.L - 1 - i, self.n) for i in range(self.L)]
# We need: prefix_val + sum(digit_byte[i] * digit_coeffs[i]) + c3_val ≡ T (mod n)
# Let target_digit_sum = T - prefix_val - c3_val (mod n)
# We want: sum(digit_byte[i] * digit_coeffs[i]) ≡ target_digit_sum (mod n)
# With digit_byte[i] = CENTER + v[i], v[i] in [-BOUND, BOUND]:
# sum((CENTER + v[i]) * digit_coeffs[i]) ≡ target_digit_sum (mod n)
# sum(v[i] * digit_coeffs[i]) ≡ target_digit_sum - CENTER * sum(digit_coeffs[i]) (mod n)
center_sum = (self.CENTER * sum(digit_coeffs)) % self.n
# The lattice basis we built uses coefficients a[i] = 256^(L-1-i) mod n
# But our actual coefficients are digit_coeffs[i] = 256^(c3_len + L - 1 - i) mod n
# = a[i] * 256^c3_len mod n
# So the lattice solution v satisfies sum(v[i] * a[i]) ≡ T_lattice (mod n)
# And we need sum(v[i] * digit_coeffs[i]) = sum(v[i] * a[i] * 256^c3_len) ≡ target
# So T_lattice = target * 256^(-c3_len) mod n
inv_256_c3 = pow(pow(256, c3_len, self.n), -1, self.n)
best_max = 999
for trial in range(max_trials):
T_rand = random.randrange(0, self.n)
v = self._babai_cvp(T_rand)
mx = max(abs(x) for x in v)
if mx < best_max:
best_max = mx
if mx <= self.BOUND:
# Verify lattice residue
res = sum(v[i] * self.a[i] for i in range(self.L)) % self.n
if res != T_rand:
continue # float precision issue
# Convert T_rand back to the actual message residue T
# T_lattice = (T - prefix_val - c3_val - center_sum) * inv_256_c3 mod n
# So T = T_lattice * 256^c3_len + prefix_val + c3_val + center_sum mod n
actual_T = (T_rand * pow(256, c3_len, self.n) + prefix_val + c3_val + center_sum) % self.n
digit_bytes = bytes([self.CENTER + v[i] for i in range(self.L)])
digit_str = digit_bytes.decode('ascii')
# Verify: all valid digits
if not all(48 <= b <= 57 for b in digit_bytes):
continue
# Full verification
full_msg = prefix + digit_bytes + c3_suffix
actual_residue = int(full_msg.hex(), 16) % self.n
if actual_residue != actual_T:
# Precision issue, skip
continue
print(f"[+] Residue found at trial {trial}, max|v|={mx}")
return digit_str, actual_T
raise RuntimeError(f"No residue solution found in {max_trials} trials (best max|v|={best_max})")
# ── SM3 KDF ────────────────────────────────────────────────────────────────
def sm3_kdf(x, y, out_len):
    """Derive ``out_len`` bytes from a point with gmssl's SM3 KDF.

    gmssl expects its input as the ASCII hex of the coordinates, so x and y
    are formatted as 64 hex characters each before the call.
    """
    encoded = (format(x, "064x") + format(y, "064x")).encode("utf8")
    derived_hex = sm3.sm3_kdf(encoded, out_len)
    return bytes.fromhex(derived_hex)
def sm3_kdf_fast(x, y, out_len=8):
    """Fast SM3 KDF over a point's coordinates using hashlib.

    Hashes the 64 raw coordinate bytes followed by the 32-bit counter value 1
    and returns the first ``out_len`` bytes of the digest. Only the first KDF
    block is computed, so at most 32 bytes can be produced.

    Raises ValueError when ``out_len`` is outside 0..32; a slice of the single
    digest would otherwise silently return the wrong number of bytes.
    """
    if not 0 <= out_len <= 32:
        raise ValueError(f"sm3_kdf_fast supports 0 <= out_len <= 32, got {out_len}")
    coord_bytes = bytes.fromhex(f"{x:064x}{y:064x}")
    # The availability of 'sm3' in hashlib depends on the OpenSSL build.
    return hashlib.new('sm3', coord_bytes + b'\x00\x00\x00\x01').digest()[:out_len]
def ec_add_affine(P, Q):
    """Add two affine points on a curve y^2 = x^3 + SM2_A*x + b over F_p.

    The formulas do not involve b, so the same routine serves every curve that
    shares a and p. ``None`` represents the point at infinity.
    """
    if P is None:
        return Q
    if Q is None:
        return P
    (x1, y1), (x2, y2) = P, Q
    p = SM2_P
    if x1 != x2:
        # Chord through two distinct points.
        slope = (y2 - y1) * pow(x2 - x1, -1, p) % p
    elif y1 != y2 or y1 == 0:
        # P + (-P), or doubling a point of order two: the result is O.
        return None
    else:
        # Tangent at P (doubling).
        slope = (3 * x1 * x1 + SM2_A) * pow(2 * y1, -1, p) % p
    x3 = (slope * slope - x1 - x2) % p
    y3 = (slope * (x1 - x3) - y1) % p
    return (x3, y3)
# ── Invalid curve oracle ──────────────────────────────────────────────────
def get_curve_order_sage(b_prime):
    """Return the order of y^2 = x^3 + a*x + b' over F_p, computed by Sage.

    The result is cached in ``/tmp/kitten_curve_order_<b'>.txt`` and read back
    from there on later calls. Raises RuntimeError when the Sage subprocess
    exits with a non-zero status.
    """
    cache_path = Path(f"/tmp/kitten_curve_order_{b_prime}.txt")
    if cache_path.exists():
        return int(cache_path.read_text().strip())

    # Script text stays at column 0 because it is passed verbatim to ``-c``.
    script = f"""
from sage.all import GF, EllipticCurve
p = {SM2_P}
a = {SM2_A}
b = {b_prime}
F = GF(p)
E = EllipticCurve(F, [a, b])
print(E.order())
"""
    completed = subprocess.run(
        ["sage", "-python", "-c", script],
        capture_output=True,
        text=True,
        timeout=60,
    )
    if completed.returncode != 0:
        raise RuntimeError(f"Sage curve order failed: {completed.stderr[:300]}")
    curve_order = int(completed.stdout.strip())
    cache_path.write_text(str(curve_order))
    return curve_order
def find_point_of_order(b_prime, target_prime, curve_order=None):
    """Find a point of order ``target_prime`` on y^2 = x^3 + a*x + b'.

    Up to 200 random x coordinates are lifted to the curve; each lifted point
    is multiplied by curve_order / target_prime and accepted when the result
    is not the point at infinity but target_prime times it is.

    Returns the point, or ``None`` when no attempt succeeds. Raises ValueError
    when ``target_prime`` does not divide the curve order. The curve order is
    obtained from Sage when it is not supplied.
    """
    if curve_order is None:
        curve_order = get_curve_order_sage(b_prime)
    cofactor, leftover = divmod(curve_order, target_prime)
    if leftover != 0:
        raise ValueError(f"target_prime {target_prime} does not divide curve order")
    for _ in range(200):
        x = random.randrange(1, SM2_P)
        lifted = lift_x(x, b=b_prime)
        if not lifted:
            continue
        # [cofactor]P has order dividing target_prime.
        candidate = point_mul(cofactor, lifted[0])
        if candidate is None:
            continue  # landed on O, try another x
        if point_mul(target_prime, candidate) is None:
            return candidate
    return None
def _prime_factors(n):
"""Return set of prime factors of n."""
factors = set()
d = 2
while d * d <= n:
while n % d == 0:
factors.add(d)
n //= d
d += 1
if n > 1:
factors.add(n)
return factors
def build_oracle_point_table(point, order, b_prime):
    """Map the 8-byte SM3 KDF prefix of [r]point to r, for r = 1..order-1.

    Enumeration stops early if a multiple turns out to be the point at
    infinity. ``b_prime`` is accepted but not used.
    """
    lookup = {}
    current = point
    for multiple in range(1, order):
        if current is None:
            break
        # 8 bytes of KDF output are enough to identify the multiple.
        lookup[sm3_kdf(current[0], current[1], 8)] = multiple
        current = point_add(current, point)
    return lookup
def oracle_query(client, c1_hex, c2_hex, c3_hex, rsa_d, rsa_n, sm2_pubkey, residue_constructor):
    """Send a forged decrypt command with a chosen C1 and return the plaintext hex.

    The message is ``b"decrypt " + c1_hex + <digits> + c3_hex``; the digit
    block and the SM2 signature are produced together by
    ``forge_and_construct``, the RSA signature with the recovered private
    exponent, and the ECDSA check is bypassed with ``[1.5, 1.5]``.
    ``c2_hex`` is accepted for interface compatibility but not used, because
    the C2 part is the generated digit block.

    Returns the "plaintext" field of the reply ("" if absent), or ``None``
    when the connection drops (treated as a server crash) or the server
    reports an error.
    """
    prefix = b"decrypt " + c1_hex.encode()
    # The signature and the digit payload must be derived together: the
    # forgery dictates the message residue the digits then have to hit.
    sm2_sig, _digit_str, full_message = forge_and_construct(
        sm2_pubkey, prefix, c3_hex, residue_constructor
    )
    rsa_sig = rsa_sign(full_message, rsa_d, rsa_n)
    ecdsa_sig = [1.5, 1.5]  # ECDSA bypass
    try:
        reply = client.decrypt(full_message.hex(), rsa_sig, sm2_sig, ecdsa_sig)
    except ConnectionError:
        return None  # server crashed
    if "error" in reply:
        print(f"[!] Oracle error: {reply['error']}")
        return None
    return reply.get("plaintext", "")
def forge_sm2_for_residue(pubkey, target_e):
    """Placeholder for forging an SM2 signature for a prescribed residue.

    Always raises NotImplementedError. In the universal forgery the residue is
    e = t - s - x([s]G + [t]Q) mod n, which is nonlinear in s and t; hitting a
    given 256-bit ``target_e`` by trying random pairs succeeds with
    probability about 1/n per trial and is infeasible. The workable order is
    the reverse one: pick (s, t) first, take the residue the forgery dictates,
    and then construct digits matching it, as ``forge_and_construct`` does.
    """
    raise NotImplementedError("Use forge_and_construct instead")
def forge_and_construct(sm2_pubkey, prefix, c3_hex, residue_constructor):
    """Produce an SM2 forgery together with a digit payload that matches it.

    For random scalars (s, t) the universal forgery fixes the residue
    e = t - s - x([s]G + [t]Q) mod n that the signed message must have. For
    each such e, Babai's algorithm looks for L digit bytes so that
    ``prefix + digits + c3_hex.encode()`` has exactly that residue. Scalars
    are drawn in batches so that the scalar multiplications are shared by all
    batch_size^2 (s, t) combinations.

    Returns ``(sm2_sig_hex, digit_string, full_message_bytes)``. Raises
    RuntimeError when 20 batches produce no match.
    """
    c3_suffix = c3_hex.encode()
    c3_len = len(c3_suffix)
    L = residue_constructor.L
    n = SM2_N
    CENTER = residue_constructor.CENTER
    BOUND = residue_constructor.BOUND
    a = residue_constructor.a

    # The digit block sits c3_len bytes above the end of the message, so the
    # lattice weights a[i] are scaled by 256^c3_len.
    scale = pow(256, c3_len, n)
    inv_scale = pow(scale, -1, n)
    prefix_val = int.from_bytes(prefix, "big") * pow(256, L + c3_len, n) % n
    c3_val = int.from_bytes(c3_suffix, "big") % n
    center_sum = CENTER * scale * sum(a) % n
    # Residue contributed by everything except the digit offsets v.
    fixed_part = (prefix_val + c3_val + center_sum) % n

    generator = (SM2_GX, SM2_GY)
    # Independent final check of each candidate signature.
    verifier = sm2_engine(public_key=point_to_hex(sm2_pubkey))

    batch_size = 40
    for batch_round in range(20):
        s_vals = [random.randrange(1, n) for _ in range(batch_size)]
        t_vals = [random.randrange(1, n) for _ in range(batch_size)]
        sg_points = [point_mul(s, generator) for s in s_vals]
        tq_points = [point_mul(t, sm2_pubkey) for t in t_vals]
        for i, (s_val, sg) in enumerate(zip(s_vals, sg_points)):
            if sg is None:
                continue
            for j, (t_val, tq) in enumerate(zip(t_vals, tq_points)):
                if tq is None:
                    continue
                R = point_add(sg, tq)
                if R is None:
                    continue
                r = (t_val - s_val) % n
                if r == 0:
                    continue
                e_actual = (r - R[0] % n) % n
                # Lattice target that makes the digit offsets supply the
                # remainder of the required residue.
                T_lat = (e_actual - fixed_part) * inv_scale % n
                v = residue_constructor._babai_cvp(T_lat)
                mx = max(abs(x) for x in v)
                if mx > BOUND:
                    continue
                if sum(vi * ai for vi, ai in zip(v, a)) % n != T_lat:
                    continue
                digit_bytes = bytes(CENTER + vi for vi in v)
                if not all(48 <= b <= 57 for b in digit_bytes):
                    continue
                full_msg = prefix + digit_bytes + c3_suffix
                if int.from_bytes(full_msg, "big") % n != e_actual:
                    continue
                sig = f"{r:064x}{s_val:064x}"
                if not verifier.verify(sig, full_msg):
                    continue
                print(f"[+] Forgery+residue found (batch {batch_round}, pair {i},{j}), max|v|={mx}")
                return sig, digit_bytes.decode('ascii'), full_msg
    raise RuntimeError("Failed to find forgery+residue")
def build_kdf_table(pt, order, cache_dir=Path("/tmp/kitten_kdf_tables")):
    """Map the 8-byte KDF prefix of [r]pt to r, for r = 1..order-1.

    Uses ``ec_add_affine`` and ``sm3_kdf_fast`` for speed. Tables with
    ``order <= 5000000`` are cached as JSON in ``cache_dir`` and reloaded on
    later calls with the same point and order.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    # The file name is derived from a SHA-256 digest of the key. The built-in
    # hash() of a str is salted per process, so it would produce a different
    # name on every run and the cache would never be found again.
    cache_key = f"{pt[0]:064x}_{pt[1]:064x}_{order}"
    digest = hashlib.sha256(cache_key.encode()).hexdigest()[:8]
    cache_file = cache_dir / f"kdf_{order}_{digest}.json"
    if cache_file.exists():
        print(f"[*] Loading cached KDF table for order {order}...")
        data = json.loads(cache_file.read_text())
        # JSON keys are hex strings; convert them back to bytes.
        return {bytes.fromhex(k): v for k, v in data.items()}

    print(f"[*] Building KDF table for order {order}...")
    t0 = time.time()
    table = {}
    Q = pt
    for r_val in range(1, order):
        table[sm3_kdf_fast(Q[0], Q[1], 8)] = r_val
        Q = ec_add_affine(Q, pt)
        if r_val % 1000000 == 0:
            elapsed = time.time() - t0
            rate = r_val / elapsed
            eta = (order - r_val) / rate
            print(f" {r_val}/{order} ({r_val*100//order}%), {elapsed:.0f}s, ETA {eta:.0f}s")
    elapsed = time.time() - t0
    print(f"[+] KDF table: {len(table)} entries in {elapsed:.1f}s")

    if order <= 5000000:  # only cache small-ish tables
        data = {k.hex(): v for k, v in table.items()}
        cache_file.write_text(json.dumps(data))
        print(f"[+] Cached to {cache_file}")
    return table
def recover_d_mod_l(client, b_prime, subgroup_order, curve_order,
                    rsa_d, rsa_n, sm2_pubkey, residue_constructor, max_retries=3):
    """Recover d mod ``subgroup_order`` with one invalid-curve oracle query.

    A point of the given order on the curve with coefficient b' is sent as C1
    inside a forged decrypt command. The server's plaintext XORed with the C2
    digits yields the KDF stream of [d]C1, whose first 8 bytes are looked up
    in a table of all multiples of the point.

    Returns the residue, or ``None`` when no suitable point is found, the
    server reports an error or an empty plaintext, reconnecting after a crash
    fails, or all ``max_retries`` attempts fail.
    """
    for retry in range(max_retries):
        pt = find_point_of_order(b_prime, subgroup_order, curve_order)
        if pt is None:
            print(f"[!] Could not find point of order {subgroup_order}")
            return None
        table = build_kdf_table(pt, subgroup_order)
        print(f"[+] KDF table: {len(table)} entries")
        c1_hex = point_to_hex(pt)
        c3_hex = "0" * 64
        prefix = b"decrypt " + c1_hex.encode()
        # Forge the decrypt command.
        sm2_sig, digit_str, full_msg = forge_and_construct(
            sm2_pubkey, prefix, c3_hex, residue_constructor
        )
        rsa_sig_val = rsa_sign(full_msg, rsa_d, rsa_n)
        try:
            reply = client.decrypt(full_msg.hex(), rsa_sig_val, sm2_sig, [1.5, 1.5])
        except ConnectionError:
            print(f"[!] Server crashed (retry {retry+1}/{max_retries})")
            # The old socket is unusable after the peer went away; open a new
            # connection so that the next attempt can actually reach the server.
            try:
                client.__exit__(None, None, None)
                client.__enter__()
            except OSError as exc:
                print(f"[!] Reconnect failed: {exc}")
                return None
            continue
        if "error" in reply:
            print(f"[!] Decrypt error: {reply['error']}")
            return None
        plaintext_hex = reply.get("plaintext", "")
        if not plaintext_hex:
            print("[!] Empty plaintext")
            return None
        # The ciphertext is C1 (64 bytes) || C2 || C3 (32 bytes); the digit
        # string is valid hex, so C2 is its byte decoding.
        ct_bytes = bytes.fromhex(c1_hex + digit_str + c3_hex)
        c2_bytes = ct_bytes[64:-32]
        pt_bytes = bytes.fromhex(plaintext_hex)
        # plaintext = C2 XOR KDF stream, hence KDF stream = plaintext XOR C2.
        kdf_output = bytes(x ^ y for x, y in zip(pt_bytes[:len(c2_bytes)], c2_bytes))
        kdf_prefix = kdf_output[:8]
        if kdf_prefix in table:
            d_mod_l = table[kdf_prefix]
            print(f"[+] d ≡ {d_mod_l} (mod {subgroup_order})")
            return d_mod_l
        print(f"[!] KDF not in table (retry {retry+1}/{max_retries})")
    print(f"[!] Failed to recover d mod {subgroup_order}")
    return None
def baby_giant_dlog(point, base_point, order, b_prime):
    """Find ``k`` with ``point == [k] * base_point``, or ``None`` if not found.

    Args:
        point: target point; ``None`` denotes the point at infinity.
        base_point: generator of the subgroup being searched.
        order: order of that subgroup.
        b_prime: curve constant of the invalid curve. It is not used by the
            arithmetic here and is kept for interface compatibility.

    Orders up to 100000 are searched by direct enumeration; larger orders
    use baby-step giant-step with ``m = isqrt(order) + 1``.
    """
    # For small subgroups, direct enumeration is fine.
    if order <= 100000:
        acc = base_point
        for k in range(1, order):
            if acc == point:
                return k
            acc = point_add(acc, base_point)
        # k = 0 is only a valid answer for the identity, and only when
        # base_point really has the claimed order.
        if point is None and point_mul(order, base_point) is None:
            return 0
        return None

    # Baby-step giant-step. math.isqrt is exact for arbitrarily large ints,
    # whereas a float square root loses precision above 2**53 and could make
    # m too small to cover the whole range [0, order).
    m = math.isqrt(order) + 1

    baby = {}
    acc = None  # identity
    for j in range(m):
        if acc == point:
            return j
        baby["inf" if acc is None else acc] = j
        acc = point_add(acc, base_point)

    # Giant step: repeatedly add [-m] * base_point, computed as [order - m] * base_point.
    neg_m_base = point_mul(order - m, base_point)
    gamma = point
    for i in range(m):
        key = "inf" if gamma is None else gamma
        if key in baby:
            return (i * m + baby[key]) % order
        gamma = point_add(gamma, neg_m_base)
    return None
def crt(residues, moduli):
    """Chinese Remainder Theorem for moduli that need not be coprime.

    Args:
        residues: iterable of integers ``r_i``.
        moduli: iterable of positive integers ``m_i``, paired with ``residues``.

    Returns:
        ``(x, M)`` with ``x ≡ r_i (mod m_i)`` for every pair, ``0 <= x < M``
        and ``M = lcm(m_i)``. Empty input yields ``(0, 1)``.

    Raises:
        ValueError: if two congruences contradict each other.
    """
    x, M = 0, 1
    for r, m in zip(residues, moduli):
        g = math.gcd(M, m)
        if (r - x) % g != 0:
            raise ValueError(f"Incompatible: {x} mod {M} vs {r} mod {m}")
        lcm = M * m // g
        # M/g and m/g are coprime, so the built-in modular inverse exists.
        # pow(..., -1, 1) is 0, which correctly leaves x unchanged when m | M.
        inv = pow(M // g, -1, m // g)
        x = (x + M * ((r - x) // g) * inv) % lcm
        M = lcm
    return x % M, M
def _extended_gcd(a, b):
    """Return ``(g, u, v)`` such that ``a*u + b*v == g == gcd(a, b)``.

    Implemented iteratively so that large inputs (e.g. consecutive Fibonacci
    numbers with thousands of bits) cannot exhaust the recursion limit. The
    coefficients are the same ones the classic recursive formulation yields.
    """
    prev_r, r = a, b
    prev_u, u = 1, 0
    prev_v, v = 0, 1
    while r != 0:
        q = prev_r // r
        prev_r, r = r, prev_r - q * r
        prev_u, u = u, prev_u - q * u
        prev_v, v = v, prev_v - q * v
    return prev_r, prev_u, prev_v
# ── Main exploit ──────────────────────────────────────────────────────────
def cmd_precompute():
    """Build the offline artefacts: the BKZ basis cache and the KDF tables.

    The BKZ basis is computed and written to /tmp only when the cache file is
    missing. For every curve in ``INVALID_CURVES`` and every listed prime, a
    point of that order is located and its KDF table is built; any on-disk
    caching of those tables is left to ``build_kdf_table``.
    """
    cache_path = Path("/tmp/kitten_bkz_basis.json")

    # BKZ basis
    if cache_path.exists():
        print("[+] BKZ basis already cached")
    else:
        reduced = build_bkz_basis_sage()
        cache_path.write_text(json.dumps(reduced))
        print(f"[+] BKZ basis cached to {cache_path}")

    # KDF tables for each invalid curve prime
    for b_prime, info in INVALID_CURVES.items():
        full_order = get_curve_order_sage(b_prime)
        print(f"\n[*] Curve b'={b_prime}, order={full_order}")
        for ell in info["primes"]:
            gen = find_point_of_order(b_prime, ell, full_order)
            if gen is None:
                print(f" [!] Could not find point of order {ell}")
                continue
            started = time.time()
            kdf_table = build_kdf_table(gen, ell)
            print(f" [{ell}]: {len(kdf_table)} entries, {time.time()-started:.1f}s")
def main():
    """Run the full exploit against the remote service.

    Steps, mirroring the numbered comments below:
      0. Load or compute the BKZ basis and build the residue constructor.
      1. Connect and collect four encrypt samples.
      2-3. Recover and factor the RSA modulus, derive the private exponent.
      4. Recover the SM2 public key.
      5. Smoke-test the forgery / residue construction.
      6. Use the invalid-curve oracle to gather ``d mod l`` residues.
      7. Combine residues with CRT and pin down the SM2 private key.
      8. Sign ``cat /flag.txt`` and print the server's reply.

    With ``--precompute`` only the offline tables are built and no connection
    is made. The function returns ``None`` in every case; failures are
    reported on stdout.
    """
    import argparse
    import secrets

    parser = argparse.ArgumentParser(description="Kitten Sign exploit")
    parser.add_argument("--host", default="223.6.249.127")
    parser.add_argument("--port", type=int, default=26467)
    parser.add_argument("--timeout", type=float, default=10.0)
    parser.add_argument("--skip-lattice-cache", action="store_true")
    parser.add_argument("--precompute", action="store_true", help="Precompute tables offline")
    args = parser.parse_args()
    if args.precompute:
        cmd_precompute()
        return

    LATTICE_CACHE = Path("/tmp/kitten_bkz_basis.json")
    sm2_generator = (SM2_GX, SM2_GY)
    # Upper bound on brute-forced key candidates when CRT leaves bits unknown.
    # Each candidate costs one scalar multiplication, so anything much larger
    # cannot finish within the session budget tracked below.
    max_candidates = 1 << 16

    # ── Step 0: Precompute BKZ basis (can be done offline) ─────────────
    if LATTICE_CACHE.exists() and not args.skip_lattice_cache:
        print("[*] Loading cached BKZ basis...")
        basis = json.loads(LATTICE_CACHE.read_text())
    else:
        basis = build_bkz_basis_sage()
        LATTICE_CACHE.write_text(json.dumps(basis))
        print(f"[+] Cached basis to {LATTICE_CACHE}")
    residue_ctor = ResidueConstructor(basis)
    print("[+] Residue constructor ready")

    # ── Step 1: Connect and collect samples ────────────────────────────
    with KittenClient(args.host, args.port, timeout=args.timeout) as client:
        t_start = time.time()
        print("[*] Collecting encrypt samples...")
        samples = []
        for i in range(4):
            samples.append(client.encrypt("00"))
            print(f" sample {i+1}/4 collected")

        # ── Step 2: Recover RSA modulus ────────────────────────────────
        print("[*] Recovering RSA modulus...")
        n = recover_rsa_modulus(samples)
        print(f"[+] RSA n = {n.bit_length()} bits")

        # ── Step 3: Factor RSA ─────────────────────────────────────────
        p, q, r = factor_rsa(n)
        d_rsa = pow(E, -1, (p - 1) * (q - 1) * (r - 1))
        print(f"[+] RSA factored: p={p.bit_length()}b, q={q.bit_length()}b, r={r.bit_length()}b")
        # Verify RSA: our signature on a sample must verify under (E, n)
        test_msg = samples[0].message
        test_sig = rsa_sign(test_msg, d_rsa, n)
        assert pow(test_sig, E, n) == rsa_message_int(test_msg), "RSA verify failed"
        print("[+] RSA signing verified")

        # ── Step 4: Recover SM2 public key ─────────────────────────────
        print("[*] Recovering SM2 public key...")
        sm2_pub = recover_sm2_pubkey(samples)
        print(f"[+] SM2 pubkey: {point_to_hex(sm2_pub)[:32]}...")

        # ── Step 5: Test forgery + residue construction ────────────────
        print("[*] Testing forgery + residue construction...")
        c1_test = "0" * 128
        c3_test = "0" * 64
        prefix_test = b"decrypt " + c1_test.encode()
        # Only the constructed message is inspected; this is a local dry run.
        _, _, msg_test = forge_and_construct(
            sm2_pub, prefix_test, c3_test, residue_ctor
        )
        print(f"[+] Test forgery OK, message length={len(msg_test)}")
        elapsed = time.time() - t_start
        print(f"[*] Setup complete in {elapsed:.1f}s, starting oracle phase...")

        # ── Step 6: Invalid-curve oracle to recover SM2 private key ────
        residues = []
        moduli = []
        # Collect all (b_prime, prime_factor) pairs, sorted by crash risk (safest first)
        oracle_targets = []
        for b_prime, info in INVALID_CURVES.items():
            for prime_factor in info["primes"]:
                crash_prob = 1 - (1 - 1/prime_factor)**256
                oracle_targets.append((crash_prob, b_prime, prime_factor))
        oracle_targets.sort()  # Safest first

        curve_orders = {}
        for crash_prob, b_prime, factor in oracle_targets:
            print(f"\n[*] Curve b'={b_prime}, order={factor} (crash≈{crash_prob*100:.1f}%)")
            # Compute curve order (cached per curve)
            if b_prime not in curve_orders:
                print(f"[*] Computing curve order for b'={b_prime}...")
                curve_orders[b_prime] = get_curve_order_sage(b_prime)
            # Send oracle query
            d_mod = recover_d_mod_l(
                client, b_prime, factor, curve_orders[b_prime],
                d_rsa, n, sm2_pub, residue_ctor
            )
            if d_mod is not None:
                residues.append(d_mod)
                moduli.append(factor)
                _, total_mod = crt(residues, moduli)
                bits = total_mod.bit_length()
                print(f"[+] CRT: {bits} bits accumulated ({len(residues)} residues)")
                if bits >= 256:
                    break
            elapsed = time.time() - t_start
            print(f"[*] Elapsed: {elapsed:.0f}s / 600s")
            if elapsed > 500:
                print("[!] Running low on time, stopping oracle phase")
                break

        # ── Step 7: Recover SM2 private key via CRT ───────────────────
        if not residues:
            # With no residue the CRT modulus is 1 and "trying candidates"
            # would mean scanning the whole group; give up instead.
            print("[-] Could not determine SM2 private key")
            return
        d_sm2, total_mod = crt(residues, moduli)
        print(f"\n[*] CRT result: d_sm2 = {d_sm2}")
        print(f"[*] Total modulus bits: {total_mod.bit_length()}")
        if total_mod.bit_length() < 256:
            # Try candidates d_sm2, d_sm2 + total_mod, ... below SM2_N.
            print("[!] Not enough CRT bits, trying candidates...")
            # Number of values d_sm2 + k*total_mod that are < SM2_N (ceil division).
            n_candidates = -((d_sm2 - SM2_N) // total_mod)
            if n_candidates > max_candidates:
                print(f"[-] ~2^{n_candidates.bit_length()} candidates is too many to enumerate")
                print("[-] Could not determine SM2 private key")
                return
            found = False
            for k in range(n_candidates):
                cand = d_sm2 + k * total_mod
                # Verify: [cand]*G == sm2_pub?
                if point_mul(cand, sm2_generator) == sm2_pub:
                    d_sm2 = cand
                    found = True
                    print(f"[+] Found SM2 private key at k={k}")
                    break
            if not found:
                print("[-] Could not determine SM2 private key")
                return
        else:
            # Verify the CRT value directly against the public key
            check = point_mul(d_sm2, sm2_generator)
            if check != sm2_pub:
                print("[!] CRT result doesn't match pubkey, trying d_sm2 candidates...")
                found = False
                for k in range(100):
                    for cand in [d_sm2 + k * total_mod, d_sm2 - k * total_mod]:
                        cand %= SM2_N
                        check = point_mul(cand, sm2_generator)
                        if check == sm2_pub:
                            d_sm2 = cand
                            found = True
                            break
                    if found:
                        break
                if not found:
                    print("[-] SM2 private key verification failed")
                    return
        print(f"[+] SM2 private key recovered: {d_sm2:064x}")

        # ── Step 8: Sign "cat /flag.txt" and get flag ──────────────────
        flag_msg = b"cat /flag.txt"
        flag_msg_hex = flag_msg.hex()
        # RSA signature
        flag_rsa_sig = rsa_sign(flag_msg, d_rsa, n)
        # ECDSA bypass
        flag_ecdsa_sig = [1.5, 1.5]
        # SM2 signature (using recovered private key)
        sm2_inst = sm2_engine(
            public_key=point_to_hex(sm2_pub),
            private_key=f"{d_sm2:064x}",
        )
        # Generate a random k for SM2 signing
        k_hex = secrets.token_hex(32)
        flag_sm2_sig = sm2_inst.sign(flag_msg, k_hex)
        # Verify locally before spending the request
        assert sm2_inst.verify(flag_sm2_sig, flag_msg), "SM2 local verify failed"
        print("[+] All signatures ready")

        # Send
        reply = client.decrypt(flag_msg_hex, flag_rsa_sig, flag_sm2_sig, flag_ecdsa_sig)
        print(f"\n[*] Server reply: {json.dumps(reply, indent=2)}")
        if "flag" in reply:
            print(f"\n{'='*60}")
            print(f"FLAG: {reply['flag']}")
            print(f"{'='*60}")
        else:
            print(f"[-] No flag in reply: {reply}")
        elapsed = time.time() - t_start
        print(f"\n[*] Total time: {elapsed:.1f}s")
# Script entry point: run the exploit only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment