|
#!/usr/bin/env python3 |
|
|
|
# A python port of the SPHINCS+ team's outdated parameter exploration script available here: |
|
# https://sphincs.org/data/spx_parameter_exploration.sage |
|
# |
|
# This script works out of the box on any machine with a python3 runtime, uses no dependencies, |
|
# runs faster than the default script, and provides additional CLI-driven filtering options. |
|
|
|
from decimal import Decimal, getcontext |
|
from math import log2, log10, ceil, comb |
|
from collections.abc import Hashable |
|
from argparse import ArgumentParser |
|
|
|
parser = ArgumentParser()
parser.add_argument("--secbits", type=int, default=128,
                    help="target bits of FORS forgery security after --max-sigs")

parser.add_argument("--max-sigs", type=str, default="2**64",
                    help="Aim for --secbits after this many signatures")

parser.add_argument("--max-sig-size", type=int, default=8000,
                    help="Print only parameter sets with signatures up to this size")

parser.add_argument("--max-kilohashes", type=int, default=2000,
                    help="Print only parameter sets which require at most this "
                         "many thousands of hash invocations")

parser.add_argument("--max-layers", type=int, default=9999999,
                    help="Print only parameter sets with fewer than this many hypertree layers")

parser.add_argument("--no-cache", action="store_true",
                    help="Print only parameter sets which satisfy --max-kilohashes "
                         "without any caching")

parser.add_argument("--csv", action="store_true",
                    help="Print output parameter sets in CSV format")

args = parser.parse_args()

# --max-sigs is eval()'d below, so restrict it to plain arithmetic characters
# to ensure we don't accidentally evaluate evil code. parser.error() is used
# rather than assert because asserts are stripped when running under -O,
# which would leave the eval() unguarded.
if any(c not in '0123456789*+ ' for c in args.max_sigs):
    parser.error("invalid --max-sigs: '%s'" % args.max_sigs)

tsec = args.secbits            # Pr[one attacker hash call works] <= 1/2^tsec
maxsigs = eval(args.max_sigs)  # sanitized above; at most 2^72
maxhashes = args.max_kilohashes * 1000
|
|
|
|
|
#### Generic caching layer to save time |
|
class memoized(object):
    """Decorator that caches a function's results keyed by its positional args.

    Calls with unhashable arguments are passed straight through to the
    wrapped function without caching.
    """
    def __init__(self, func):
        self.func = func
        self.cache = {}  # maps args tuple -> cached result
        self.__name__ = 'memoized:' + func.__name__

    def __call__(self, *args):
        # EAFP: the args tuple itself is always an instance of Hashable, so
        # an isinstance check cannot detect unhashable *elements*; only the
        # dict lookup (raising TypeError) can.
        try:
            return self.cache[args]
        except KeyError:
            result = self.func(*args)
            self.cache[args] = result
            return result
        except TypeError:
            # An argument is unhashable; skip caching.
            return self.func(*args)
|
|
|
|
|
#### SPHINCS+ analysis |
|
|
|
# Give Decimal roughly tsec+100 bits of precision so the probability sums
# below stay accurate relative to the 2^-tsec comparison threshold.
getcontext().prec = int(log10(2**(tsec+100))) + 1

# Acceptable forgery probability: 2^-tsec.
sigmalimit = 1 / Decimal(2**tsec)
# Stop summing the hit-count distribution once all but a 2^-(tsec+20)
# sliver of its probability mass has been accounted for.
donelimit = 1 - sigmalimit / (2**20)
hashbytes = tsec // 8 # length of hashes in bytes
|
|
|
# Pr[exactly r sigs hit the leaf targeted by this forgery attempt] |
|
@memoized
def qhitprob(leaves, qs, r):
    """Pr[exactly r of qs signatures hit the leaf targeted by this forgery].

    Binomial pmf: each signature lands on the targeted leaf with
    probability 1/leaves, independently.
    """
    hit = 1 / Decimal(leaves)
    miss = 1 - hit
    return comb(qs, r) * hit**r * miss**(qs - r)
|
|
|
# Pr[FORS forgery given that exactly r sigs hit the leaf] = (1-(1-1/(2^a))^r)^k |
|
@memoized
def forgeryprob(a, r, k):
    """Pr[FORS forgery given exactly r sigs hit the leaf] = (1-(1-1/2^a)^r)^k.

    Computed iteratively as k repeated multiplications of the single-tree
    probability, preserving the same multiplication order (and therefore
    identical Decimal rounding) as a recursive formulation.
    """
    single_tree = 1 - (1 - 1 / Decimal(2**a))**r
    result = single_tree
    for _ in range(k - 1):
        result = single_tree * result
    return result
|
|
|
# Number of WOTS chains |
|
@memoized |
|
def wotschains(m, w): |
|
la = ceil(m / log2(w)) |
|
return la + int(log2(la * (w - 1)) // log2(w)) + 1 |
|
|
|
def numhashops(h, d, k, a, w, wots_len, cache=False):
    """Estimate the hash invocations needed to produce one signature.

    h: total hypertree height, d: number of hypertree layers,
    k/a: number/height of FORS trees, w: Winternitz parameter,
    wots_len: WOTS chains per one-time key.
    If cache is True, the top XMSS layer is assumed precomputed and its
    hashes are excluded from the count.
    """
    # FORS cost: k trees, each with 2**a leaves and 2**a internal nodes,
    # plus a single hash compressing the k roots together.
    fors_hashes = k * 2**(a + 1) + 1

    # One XMSS layer: 2**(h/d) leaves, each a WOTS key needing wots_len
    # chains of w hashes (+1 to hash the key into a leaf).
    per_layer = 2**(h // d) * (wots_len * w + 1)

    # The hypertree stacks d such layers (one fewer when the top is cached).
    layers = d - 1 if cache else d
    return fors_hashes + per_layer * layers
|
|
|
def compute_sigma(h, a, k, sigs):
    """Upper-bound the FORS forgery probability after `sigs` signatures.

    Sums, over r = number of signatures that hit the leaf targeted by the
    forgery attempt, Pr[r hits] * Pr[forgery | r hits], stopping once all
    but a negligible sliver (< 2^-(tsec+20)) of the hit-count distribution
    has been covered.
    """
    leaves = 2**h
    sigma = 0
    r = 1
    # r = 0 contributes no forgery probability, only covered mass.
    done = qhitprob(leaves, sigs, 0)
    while done < donelimit:
        t = qhitprob(leaves, sigs, r)
        sigma += t * forgeryprob(a, r, k)
        done += t
        r += 1
    # Conservatively count the uncovered tail of the distribution as certain
    # forgery (forgeryprob <= 1, so the tail contributes at most 1 - done).
    # Must be max(), not min(): min(0, 1 - done) adds nothing when done < 1,
    # silently dropping the tail bound.
    sigma += max(0, 1 - done)
    return sigma
|
|
|
if args.csv:
    # CSV header row.
    columns = ['h', 'd', 'a', 'k', 'w', 'secbits', 'sigsize', 'hashes', 'hashes_cached', 'cache_size']
    print(','.join(columns))

found = False  # whether any parameter set has been printed yet (text mode only)
s = int(log2(maxsigs))
for h in range(s-8, s+20): # Iterate over total tree height
    leaves = 2**h  # NOTE(review): unused -- compute_sigma derives this itself
    for a in range(3, 24): # Iterate over height of FORS trees
        for k in range(1, 64): # Iterate over number of FORS trees
            # Upper bound on FORS forgery probability after maxsigs signatures.
            sigma = compute_sigma(h, a, k, maxsigs)
            if sigma > sigmalimit: continue
            sec = log2(1 / sigma)  # forgery security in bits
            for d in range(3, min(h, args.max_layers+1)): # Iterate over number of hypertree layers
                # d must divide h evenly; the h <= 64 + h/d bound appears to
                # cap the leaf-index size -- TODO(review): confirm rationale.
                if h % d == 0 and h <= 64+(h/d): # Only valid hypertrees
                    for w in [16, 256]: # Try different Winternitz parameters
                        wots_len = wotschains(8*hashbytes, w)
                        # Signature bytes: k FORS trees at (a + 1) hash values
                        # each, h hypertree auth-path nodes, d WOTS signatures
                        # of wots_len values, plus one more hash value.
                        sigsize = ((a + 1)*k + h + wots_len*d + 1) * hashbytes

                        # Rough performance estimates based on #hashes
                        speed = numhashops(h, d, k, a, w, wots_len)
                        cached_speed = numhashops(h, d, k, a, w, wots_len, cache=True) # If the first XMSS tree is cached

                        # Keep only parameter sets within the size/speed
                        # budgets; --no-cache judges speed without caching.
                        if sigsize <= args.max_sig_size and (speed if args.no_cache else cached_speed) <= maxhashes:
                            # the size of a single precomputed XMSS tree (leaves only) in bytes.
                            cache_size = hashbytes * (2**(h//d))

                            if args.csv:
                                columns = [h, d, a, k, w, round(sec, 2), sigsize, speed, cached_speed, cache_size]
                                print(','.join((str(c) for c in columns)))
                                continue

                            # Human-readable output; blank line between entries.
                            if found: print()
                            print("h=%d d=%d a=%d k=%d w=%d" % (h,d,a,k,w)) # SPHINCS+ parameters
                            print("1 in 2^%.2f forgery probability" % sec) # FORS forgery probability
                            print("sigsize=%d" % sigsize)
                            print("%dk hashes" % (speed // 1000))
                            print("%dk hashes (with %.1fmb cache)" % (cached_speed // 1000, cache_size / 1024 / 1024))

                            found = True