Skip to content

Instantly share code, notes, and snippets.

@Rexicon226
Created January 19, 2026 00:48
Show Gist options
  • Select an option

  • Save Rexicon226/c4b82731ee96923668e0063103cb8b40 to your computer and use it in GitHub Desktop.

Select an option

Save Rexicon226/c4b82731ee96923668e0063103cb8b40 to your computer and use it in GitHub Desktop.
6 x 43-bit curve25519 implementation
const std = @import("std");
const V = @Vector(8, u64);
const S = @Vector(8, i64);
const u64x4 = @Vector(4, u64);
const avx512 = @import("src/curves/ed25519/avx512.zig");
extern fn @"llvm.x86.avx512.vpmadd52l.uq.512"(V, V, V) V;
extern fn @"llvm.x86.avx512.vpmadd52h.uq.512"(V, V, V) V;
/// Forwards to the AVX-512 IFMA `vpmadd52luq` LLVM intrinsic, 8 x u64 lanes.
/// Per Intel's description of the instruction, each lane yields
/// `x + low52(y[51:0] * z[51:0])`.
inline fn madd52lo(x: V, y: V, z: V) V {
    const acc_lo = @"llvm.x86.avx512.vpmadd52l.uq.512"(x, y, z);
    return acc_lo;
}
/// Forwards to the AVX-512 IFMA `vpmadd52huq` LLVM intrinsic, 8 x u64 lanes.
/// Per Intel's description of the instruction, each lane yields
/// `x + high52(y[51:0] * z[51:0])`, i.e. bits 52..103 of the product.
inline fn madd52hi(x: V, y: V, z: V) V {
    const acc_hi = @"llvm.x86.avx512.vpmadd52h.uq.512"(x, y, z);
    return acc_hi;
}
const Fe = struct {
/// Limbs 0-4 are in [0, 2^43), limb 5 is in [0, 2^40). Can represent [0, p).
const Reduced = struct {
    limbs: V,

    const zero: Reduced = .{ .limbs = @splat(0) };
    // `Reduced` has a single `limbs` field, so the vector must be nested
    // inside a field initializer.
    const one: Reduced = .{ .limbs = .{ 1, 0, 0, 0, 0, 0, 0, 0 } };

    /// Packs the six radix-2^43 limbs into four 64-bit words (word 0 holds
    /// bits 0..63 of the value).
    ///
    /// Limb i starts at bit 43*i, so:
    ///   word0 = x0       | x1 << 43
    ///   word1 = x1 >> 21 | x2 << 22
    ///   word2 = x2 >> 42 | x3 << 1  | x4 << 44
    ///   word3 = x4 >> 20 | x5 << 23
    fn pack(r: Reduced) @Vector(4, u64) {
        const x = r.limbs;
        // Right-shifted contributions. Word 3 takes the upper bits of limb 4
        // (bits 192..214 of the value), so lane 3 must select x[4].
        const t0 = @shuffle(u64, x, undefined, @Vector(8, i32){ 0, 1, 2, 4, 0, 0, 0, 0 }) >> V{ 0, 21, 42, 20, 0, 0, 0, 0 };
        // Left-shifted contributions; the shift discards bits that belong to the next word.
        const t1 = @shuffle(u64, x, undefined, @Vector(8, i32){ 1, 2, 3, 5, 0, 0, 0, 0 }) << V{ 43, 22, 1, 23, 0, 0, 0, 0 };
        // Only lane 2 (x4 << 44) is new here; the other lanes repeat t1 and are harmless under OR.
        const t2 = @shuffle(u64, x, undefined, @Vector(8, i32){ 1, 2, 4, 5, 0, 0, 0, 0 }) << V{ 43, 22, 44, 23, 0, 0, 0, 0 };
        return @shuffle(u64, t0 | t1 | t2, undefined, @Vector(4, i32){ 0, 1, 2, 3 });
    }

    /// Writes the packed words to `out` in native byte order.
    fn toBytes(r: Reduced, out: *[32]u8) void {
        const p: [4]u64 = r.pack();
        @memcpy(out, std.mem.asBytes(&p));
    }
};
/// Limbs 0-4 are in [0, 2^43), limb 5 is in [0, 2^41). Can represent [0, 2*p).
const NearlyReduced = struct {
limbs: V,
/// Loads 32 bytes as four 64-bit words, clears bit 255 and unpacks the
/// remaining 255 bits into radix-2^43 limbs.
fn fromBytes(bytes: *const [32]u8) NearlyReduced {
    const all_ones = std.math.maxInt(u64);
    // Keep 255 bits: every bit of words 0-2 and all but the top bit of word 3.
    const keep_255: u64x4 = .{ all_ones, all_ones, all_ones, 0x7fffffffffffffff };
    // Doing the load and mask on the whole vector (rather than per element)
    // is what lets LLVM emit a single vmovups + vandps.
    const words: u64x4 = @bitCast(bytes.*);
    return unpack(words & keep_255);
}
/// Splits four 64-bit words into six radix-2^43 limbs (lanes 6 and 7 are zero).
///
/// For each limb a byte-level shuffle gathers the bytes that contain it into
/// one 64-bit lane, a per-lane right shift aligns the limb's first bit to bit 0,
/// and a 43-bit mask drops the bits that belong to the next limb.
fn unpack(v: u64x4) NearlyReduced {
const z: V = @splat(0);
const zz = 63; // Index used in the perm mask to indicate a zero element.
// zig fmt: off
const perm: @Vector(64, i32) = .{
0, 1, 2, 3, 4, 5, zz, zz, // r0 = bits 0: 42 (43 bits, zero extend to 64 bits)
5, 6, 7, 8, 9, 10, zz, zz, // r1 = bits 43: 85 (43 bits, zero extend to 64 bits)
10, 11, 12, 13, 14, 15, 16, zz, // r2 = bits 86:128 (43 bits, zero extend to 64 bits)
16, 17, 18, 19, 20, 21, zz, zz, // r3 = bits 129:171 (43 bits, zero extend to 64 bits)
21, 22, 23, 24, 25, 26, zz, zz, // r4 = bits 172:214 (43 bits, zero extend to 64 bits)
26, 27, 28, 29, 30, 31, zz, zz, // r5 = bits 215:255 (41 bits, zero extend to 64 bits)
zz, zz, zz, zz, zz, zz, zz, zz, // r6 = zero
zz, zz, zz, zz, zz, zz, zz, zz, // r7 = zero
};
// zig fmt: on
// Expand our vector from u64x4, to u64x8, filling the top 4 elements with zeroes.
// Byte 63 of `extend` is therefore zero, which is what makes `zz` work.
const extend = @shuffle(u64, v, z, @Vector(8, i32){ 0, 1, 2, 3, ~@as(i32, 0), ~@as(i32, 0), ~@as(i32, 0), ~@as(i32, 0) });
const permute: V = @bitCast(@shuffle(
u8,
@as(@Vector(64, u8), @bitCast(extend)),
@as(@Vector(64, u8), @splat(0)),
perm,
));
// Bit offset of each limb inside its first gathered byte: (43 * i) mod 8.
const rshift: V = .{ 0, 3, 6, 1, 4, 7, 0, 0 };
return .{ .limbs = (permute >> rshift) & @as(V, @splat((1 << 43) - 1)) };
}
/// Reduces a value x in [0, 2*p) to its canonical representative in [0, p).
///
/// First computes q = (x + 19) >> 255 with a read-only carry chain; q is 1
/// exactly when x >= p. Then returns (x + 19*q) mod 2^255, which is x - q*p.
///
/// This never subtracts, so no borrow can occur. Adding 19 and then
/// conditionally subtracting it again goes wrong when limb 0 is >= 2^43 - 19
/// and x < p: the wrapping subtraction followed by a logical shift adds
/// 2^21 - 1 to limb 1 instead of borrowing 1 from it.
fn reduce(nr: NearlyReduced) Reduced {
    const m43 = (1 << 43) - 1;
    const m40 = (1 << 40) - 1;
    var y0, var y1, var y2, var y3, var y4, var y5, _, _ = nr.limbs;

    // q = (x + 19) >> 255, computed without modifying the limbs.
    var c: u64 = (y0 +% 19) >> 43;
    c = (y1 +% c) >> 43;
    c = (y2 +% c) >> 43;
    c = (y3 +% c) >> 43;
    c = (y4 +% c) >> 43;
    const q = (y5 +% c) >> 40;

    // x += 19 * q, propagate carries, then drop bit 255 (i.e. subtract q * 2^255).
    y0 +%= 19 * q;
    c = y0 >> 43;
    y0 &= m43;
    y1 +%= c;
    c = y1 >> 43;
    y1 &= m43;
    y2 +%= c;
    c = y2 >> 43;
    y2 &= m43;
    y3 +%= c;
    c = y3 >> 43;
    y3 &= m43;
    y4 +%= c;
    c = y4 >> 43;
    y4 &= m43;
    y5 +%= c;
    y5 &= m40;

    return .{ .limbs = .{ y0, y1, y2, y3, y4, y5, 0, 0 } };
}
const add = Unreduced.add;
const sq = Unreduced.sq;
const mul = Unreduced.mul;
const toBytes = generic.toBytes;
const cast = generic.cast;
};
/// All limbs are in [0, 2^47).
const Unreduced = struct {
limbs: V,
// 0x0100000000000000000000000000000000000000000000000000000000000000
const one: Unreduced = .{ .limbs = .{ 1, 0, 0, 0, 0, 0, 0, 0 } };
// 0xa3785913ca4deb75abd841414d0a700098e879777940c78c73fe6f2bee6c0352
const d: Unreduced = .{ .limbs = .{ 0x000005ca135978a3, 0x0000003b156ebd69, 0x00000001c0293505, 0x0000003cbbbcf44c, 0x000006ffe738cc74, 0x000000a406d9dc56, 0x0, 0x0 } };
const p: Unreduced = .{ .limbs = .{ 8796093022189, 8796093022207, 8796093022207, 8796093022207, 8796093022207, 1099511627775, 0, 0 } };
// 0xb0a00e4a271beec478e42fad0618432fa7d7fb3d99004d2b0bdfc14f8024832b
const sqrtm1: Unreduced = .{ .limbs = .{ 0x000003274a0ea0b0, 0x000005fc8f189dc3, 0x000004bd0c601ab4, 0x0000004c9efdebd3, 0x0000041df0b2b4d0, 0x000000570649009f, 0x0, 0x0 } };
const d2: Unreduced = .{ .limbs = .{ 3934839304537, 507525298899, 15037786634, 521695520920, 6596238350568, 309467527341, 0, 0 } };
/// Returns `x + y`, folded back into the `Unreduced` limb range.
fn add(x: anytype, y: Elem(@TypeOf(x))) Unreduced {
    const wide_sum = addFast(x, y);
    return wide_sum.fold();
}
/// Lane-wise wrapping sum of the limbs, with no carry handling.
/// [0, 2^47) + [0, 2^47) stays well inside the `Unsigned` range [0, 2^62).
fn addFast(x: anytype, y: Elem(@TypeOf(x))) Unsigned {
    const lane_sum = x.limbs +% y.limbs;
    return .{ .limbs = lane_sum };
}
/// Returns `x + x`. Dereferences `x`, so it must be passed as a pointer.
fn dbl(x: anytype) Unreduced {
    const operand = x.*;
    return add(x, operand);
}
/// Returns `x - y`, folded back into the `Unreduced` limb range.
fn sub(x: anytype, y: Elem(@TypeOf(x))) Unreduced {
    const signed_diff = subFast(x, y);
    return signed_diff.fold();
}
/// Computes `x + (p - y)` lane-wise, with no carry handling.
fn subFast(x: anytype, y: Elem(@TypeOf(x))) Signed {
    const minus_y = y.negFast();
    return .{ .limbs = x.limbs +% minus_y.limbs };
}
/// Returns `x * y`, folded back into the `Unreduced` limb range.
fn mul(x: anytype, y: Elem(@TypeOf(x))) Unreduced {
    const wide_product = mulFast(x, y);
    return wide_product.fold();
}
/// Schoolbook 6x6-limb product of `a` and `b`, with the upper half folded back
/// into the low six limbs. Carries are not propagated; `fold` does that.
///
/// Row i multiplies the broadcast limb x[i] by every limb of y. The 52-bit
/// high half of a partial product has weight 2^52 = 2^43 * 2^9 relative to its
/// low half, so it is shifted left by 9 and added to the next row, same lane.
fn mulFast(a: anytype, b: Elem(@TypeOf(a))) Unsigned {
const z: V = @splat(0);
const x = a.limbs;
const y = b.limbs;
const x0: V = @splat(x[0]);
const x1: V = @splat(x[1]);
const x2: V = @splat(x[2]);
const x3: V = @splat(x[3]);
const x4: V = @splat(x[4]);
const x5: V = @splat(x[5]);
// zig fmt: off
const t0 = madd52lo( z, x0, y);
const t1 = madd52lo(madd52hi(z, x0, y) << @splat(9), x1, y);
const t2 = madd52lo(madd52hi(z, x1, y) << @splat(9), x2, y);
const t3 = madd52lo(madd52hi(z, x2, y) << @splat(9), x3, y);
const t4 = madd52lo(madd52hi(z, x3, y) << @splat(9), x4, y);
const t5 = madd52lo(madd52hi(z, x4, y) << @splat(9), x5, y);
const t6 = madd52hi(z, x5, y) << @splat(9);
// Row i is moved up by i lanes so lane k holds a coefficient of 2^(43*k).
// pNj keeps the lanes that stay inside the vector, qNj the lanes that spill past lane 7.
const p0j = t0;
const p1j = alignr(z, t1, 7);
const p2j = alignr(z, t2, 6);
const p3j = alignr(z, t3, 5); const q3j = alignr(t3, z, 5);
const p4j = alignr(z, t4, 4); const q4j = alignr(t4, z, 4);
const p5j = alignr(z, t5, 3); const q5j = alignr(t5, z, 3);
const p6j = alignr(z, t6, 2); const q6j = alignr(t6, z, 2);
// zig fmt: on
const zl = p0j +% p1j +% p2j +% p3j +% p4j +% p5j +% p6j;
const zh = q3j +% q4j +% q5j +% q6j;
const max = std.math.maxInt(u64);
// za = coefficients 0..5, zb = coefficients 6.. (lanes 6,7 of zl followed by zh).
const za = zl & V{ max, max, max, max, max, max, 0, 0 };
const zb = alignr(zl, zh, 6);
// 2^(43*6) = 2^258 = 8 * 2^255 == 8 * 19 = 152 = 128 + 16 + 8 (mod p).
return .{ .limbs = za +% (zb << @splat(7)) +% (zb << @splat(4)) +% (zb << @splat(3)) };
}
/// Returns `-x` (computed as `p - x`), folded back into the `Unreduced` limb range.
fn neg(x: anytype) Unreduced {
    const negated = negFast(x);
    return negated.fold();
}
/// Lane-wise wrapping `p - x`, with no borrow handling; a lane wraps when the
/// corresponding limb of `x` exceeds the limb of `p`.
fn negFast(x: anytype) Signed {
    const modulus = p.limbs;
    return .{ .limbs = modulus -% x.limbs };
}
/// Returns `x * x`, folded back into the `Unreduced` limb range.
fn sq(x: anytype) Unreduced {
    const wide_square = sqFast(x);
    return wide_square.fold();
}
/// Squares `q`, returning the result with the upper half folded back into the
/// low six limbs. Carries are not propagated.
///
/// NOTE(review): the three operand pairs (x0,x1), (x2,x3), (x4,x5) appear to
/// enumerate each distinct limb product once, with the `<<= 1` lanes doubling
/// the off-diagonal terms. The exact lane assignment should be confirmed
/// against a reference squaring.
fn sqFast(q: anytype) Unsigned {
const z: V = @splat(0);
const x = q.limbs;
var x0 = @shuffle(u64, x, undefined, @Vector(8, i32){ 0, 0, 0, 0, 0, 0, 3, 3 });
const x1 = @shuffle(u64, x, undefined, @Vector(8, i32){ 0, 1, 2, 3, 4, 5, 3, 4 });
var x2 = @shuffle(u64, x, undefined, @Vector(8, i32){ 4, 4, 1, 1, 1, 1, 1, 7 });
const x3 = @shuffle(u64, x, undefined, @Vector(8, i32){ 4, 5, 1, 2, 3, 4, 5, 7 });
var x4 = @shuffle(u64, x, undefined, @Vector(8, i32){ 3, 7, 5, 7, 2, 2, 2, 2 });
const x5 = @shuffle(u64, x, undefined, @Vector(8, i32){ 5, 7, 5, 7, 2, 3, 4, 5 });
// Per-lane doubling of one operand (shift by 1) or no change (shift by 0).
x0 <<= V{ 0, 1, 1, 1, 1, 1, 0, 1 };
x2 <<= V{ 0, 1, 0, 1, 1, 1, 1, 1 };
x4 <<= V{ 1, 1, 0, 1, 0, 1, 1, 1 };
const p0l = madd52lo(z, x0, x1);
const p1l = madd52lo(z, x2, x3);
const p2l = madd52lo(z, x4, x5);
// High halves carry weight 2^52 = 2^43 * 2^9, hence the shift by 9.
const p0h = madd52hi(z, x0, x1) << @splat(9);
const p1h = madd52hi(z, x2, x3) << @splat(9);
const p2h = madd52hi(z, x4, x5) << @splat(9);
const max = std.math.maxInt(u64);
// mask1/mask2 pick the lanes of p1*/p2* that are summed into the upper vector (zh) rather than the lower one (zl).
const mask1: V = .{ max, max, 0, 0, 0, 0, 0, 0 };
const mask2: V = .{ max, 0, max, 0, 0, 0, 0, 0 };
// zig fmt: off
const zll = p0l +% (~mask1 & p1l) +% (~mask2 & p2l);
const zlh = p0h +% (~mask1 & p1h) +% (~mask2 & p2h);
const zhl = ( mask1 & p1l) +% ( mask2 & p2l);
const zhh = ( mask1 & p1h) +% ( mask2 & p2h);
// zig fmt: on
// High halves move up one lane; the lane that falls off zlh lands in lane 0 of zh.
const zl = zll +% alignr(z, zlh, 7);
const zh = zhl +% alignr(z, zhh, 7) +% alignr(zlh, z, 7);
const za = zl & V{ max, max, max, max, max, max, 0, 0 };
const zb = alignr(zl, zh, 6);
// 2^(43*6) = 2^258 == 152 = 128 + 16 + 8 (mod p).
return .{ .limbs = za +% (zb << @splat(7)) +% (zb << @splat(4)) +% (zb << @splat(3)) };
}
/// Returns `-x` when `x.isNegative()`, otherwise a copy of `x`.
/// Dereferences `x`, so it must be passed as a pointer.
fn abs(x: anytype) Elem(@TypeOf(x)) {
    if (x.isNegative()) {
        return x.neg();
    }
    return x.*;
}
/// Serialises `x` with `toBytes` and reports whether the lowest bit of the
/// first byte is set.
fn isNegative(x: anytype) bool {
    var encoded: [32]u8 = undefined;
    x.toBytes(&encoded);
    const low_bit = encoded[0] & 1;
    return low_bit != 0;
}
/// Reports whether `x - y` passes `isZero`.
fn eql(x: anytype, y: Elem(@TypeOf(x))) bool {
    return x.sub(y).isZero();
}
/// Reports whether the limbs of `x` are exactly the limbs of 0 or exactly the
/// limbs of `p` (the two encodings of zero with in-range limbs).
///
/// The two patterns are tested separately. Testing each lane for "0 or p"
/// would also accept mixtures such as limb 0 = p0 with all other limbs zero,
/// which lies strictly between 0 and p and so is not zero.
///
/// NOTE(review): encodings of zero with out-of-range limbs (e.g. 2*p) are not
/// recognised; callers are presumably expected to pass folded values.
fn isZero(x: anytype) bool {
    const y = x.limbs;
    const zeros: V = @splat(0);
    const all_zero = @reduce(.And, y == zeros);
    const all_p = @reduce(.And, y == p.limbs);
    // Bitwise combine instead of `or` to avoid an extra conditional branch.
    return (@intFromBool(all_zero) | @intFromBool(all_p)) != 0;
}
/// Returns a^((p-5)/8) = a^(2^252 - 3), via a fixed addition chain.
/// The exponent reached after each step is noted on the right.
fn pow2523(a: Unreduced) Unreduced {
var t0 = a.mul(a.sq()); // a^(2^2 - 1)
var t1 = t0.mul(t0.sqn(2)).sq().mul(a); // a^(2^5 - 1)
t0 = t1.sqn(5).mul(t1); // a^(2^10 - 1)
var t2 = t0.sqn(5).mul(t1); // a^(2^15 - 1)
t1 = t2.sqn(15).mul(t2); // a^(2^30 - 1)
t2 = t1.sqn(30).mul(t1); // a^(2^60 - 1)
t1 = t2.sqn(60).mul(t2); // a^(2^120 - 1)
// a^(2^240 - 1) -> a^(2^250 - 1) -> a^(2^252 - 3)
return t1.sqn(120).mul(t1).sqn(10).mul(t0).sqn(2).mul(a);
}
/// Returns a^(p-2) = a^(2^255 - 21), the multiplicative inverse of `a` for
/// non-zero `a` by Fermat's little theorem. The exponent reached after each
/// step is noted on the right.
fn invert(a: Unreduced) Unreduced {
var t0 = a.sq(); // a^2
var t1 = t0.sqn(2).mul(a); // a^9
t0 = t0.mul(t1); // a^11
t1 = t1.mul(t0.sq()); // a^31 = a^(2^5 - 1)
t1 = t1.mul(t1.sqn(5)); // a^(2^10 - 1)
var t2 = t1.sqn(10).mul(t1); // a^(2^20 - 1)
t2 = t2.mul(t2.sqn(20)).sqn(10); // a^((2^40 - 1) * 2^10)
t1 = t1.mul(t2); // a^(2^50 - 1)
t2 = t1.sqn(50).mul(t1); // a^(2^100 - 1)
// a^(2^250 - 1) -> a^(2^255 - 32) -> a^(2^255 - 21)
return t1.mul(t2.mul(t2.sqn(100)).sqn(50)).sqn(5).mul(t0);
}
/// Runs one sequential carry pass over the six limbs.
///
/// The bits of limb 5 above bit 39 are multiplied by 19 and added to limb 0
/// (2^255 == 19 mod p). Carries then ripple from limb 0 up to limb 5, leaving
/// limbs 0-4 masked to 43 bits.
fn approx(x: Unreduced) NearlyReduced {
    const m43 = (1 << 43) - 1;
    const m40 = (1 << 40) - 1;
    var lane: [8]u64 = x.limbs;

    // Wrap the overflow of the top limb around to the bottom.
    var carry: u64 = lane[5] >> 40;
    lane[5] &= m40;
    lane[0] +%= 19 * carry;

    // Ripple carries upwards through limbs 0..4.
    inline for (0..5) |i| {
        carry = lane[i] >> 43;
        lane[i] &= m43;
        lane[i + 1] +%= carry;
    }

    return .{ .limbs = .{ lane[0], lane[1], lane[2], lane[3], lane[4], lane[5], 0, 0 } };
}
const sqn = generic.sqn;
const cast = generic.cast;
const toBytes = generic.toBytes;
};
/// All limbs are in [0, 2^62).
const Unsigned = struct {
limbs: V,
/// Performs one parallel carry step: each limb keeps its low 43 bits (40 for
/// limb 5) and receives the bits shifted out of the limb below it. The bits
/// shifted out of limb 5 are multiplied by 19 and added to limb 0, since
/// 2^255 == 19 (mod p). Carries produced by this step are not propagated further.
fn fold(u: Unsigned) Unreduced {
const x = u.limbs;
const m43 = (1 << 43) - 1;
const m40 = (1 << 40) - 1;
// madd52lo(acc, m, c) adds the low 52 bits of m * c to acc, so a single
// instruction both scales the wrapped carry by 19 and does the addition.
return .{ .limbs = madd52lo(
x & V{ m43, m43, m43, m43, m43, m40, 0, 0 },
.{ 19, 1, 1, 1, 1, 1, 0, 0 },
@shuffle(
u64,
x >> .{ 43, 43, 43, 43, 43, 40, 0, 0 },
undefined,
// Rotate the carries up by one limb: lane 0 takes the carry of limb 5.
@Vector(8, i32){ 5, 0, 1, 2, 3, 4, 6, 7 },
),
) };
}
};
/// All limbs are in [-2^62, 2^62).
const Signed = struct {
limbs: V,
/// One parallel carry step on limbs interpreted as signed 64-bit values.
/// Each limb keeps its low 43 bits (40 for limb 5) and receives the
/// arithmetically shifted-out part of the limb below it. The part shifted
/// out of limb 5 is added to limb 0 multiplied by 19.
fn carryPropagate(x: V) V {
const z: V = @splat(0);
const m43 = (1 << 43) - 1;
const m40 = (1 << 40) - 1;
const xl = x & V{ m43, m43, m43, m43, m43, m40, 0, 0 };
// Arithmetic shift (via the signed vector type) keeps the sign of negative limbs.
const xh: V = @bitCast(@as(S, @bitCast(x)) >> .{ 43, 43, 43, 43, 43, 40, 0, 0 });
// Rotate carries up by one limb; lanes 6 and 7 are taken from `z` (zero).
const c = @shuffle(u64, xh, z, @Vector(8, i32){ 5, 0, 1, 2, 3, 4, ~@as(i32, 0), ~@as(i32, 0) });
// TODO: simplify into c *% .{ 19, 0, 0, ..., 0 }
// d = 18 * c in lane 0 only, so lane 0 receives c + 18*c = 19*c in total.
const d = ((c << @splat(1)) +% (c << @splat(4))) & V{ std.math.maxInt(u64), 0, 0, 0, 0, 0, 0, 0 };
return xl +% c +% d;
}
/// Subtracts a per-limb bias, carry-propagates, then adds the bias back.
/// NOTE(review): the bias presumably keeps every limb of the result
/// non-negative after the signed carry step; confirm the bounds
/// (19 << 23 for limb 0, 1 << 20 for limbs 1-5).
fn fold(a: Signed) Unreduced {
const x = a.limbs;
const b: V = .{ 19 << 23, 1 << 20, 1 << 20, 1 << 20, 1 << 20, 1 << 20, 0, 0 };
return .{ .limbs = carryPropagate(x -% b) +% b };
}
};
/// All limbs are in [-2^63, 2^63).
/// Declares no operations of its own; it is the loosest representation and
/// the last entry of `generic.Order`.
const Arbitrary = struct {
limbs: V,
};
/// Helpers shared between the limb representations.
const generic = struct {
    /// Brings `a` to the `Reduced` form (via `reduce`, or `approx` then
    /// `reduce`) and writes its 32-byte encoding to `out`.
    fn toBytes(a: anytype, out: *[32]u8) void {
        const reduced = switch (Elem(@TypeOf(a))) {
            Reduced => a,
            NearlyReduced => a.reduce(),
            else => a.approx().reduce(),
        };
        reduced.toBytes(out);
    }

    /// Representations ordered from tightest to loosest limb bounds.
    /// A cast is only allowed towards an equal or larger tag value.
    const Order = enum(u8) {
        reduced = 0,
        nearly_reduced = 1,
        unreduced = 2,
        unsigned = 3,
        signed = 4,
        arbitrary = 5,

        /// Maps a tag to the struct type it names.
        fn Type(o: Order) type {
            return switch (o) {
                .reduced => Reduced,
                .nearly_reduced => NearlyReduced,
                .unreduced => Unreduced,
                .unsigned => Unsigned,
                .signed => Signed,
                .arbitrary => Arbitrary,
            };
        }
    };

    /// Maps a representation type (or whatever `Elem` resolves it to) to its tag.
    fn order(T: type) Order {
        return switch (Elem(T)) {
            Reduced => .reduced,
            NearlyReduced => .nearly_reduced,
            Unreduced => .unreduced,
            Unsigned => .unsigned,
            Signed => .signed,
            // Was misspelled `.arbtirary`, which is not a member of `Order`.
            Arbitrary => .arbitrary,
            else => @compileError("not a field element representation: " ++ @typeName(T)),
        };
    }

    /// Emits a compile error when casting `F` to `T` would tighten the limb bounds.
    fn checkSafe(F: type, T: type) void {
        if (@intFromEnum(order(F)) > @intFromEnum(order(T))) {
            @compileError("cannot cast from " ++ @typeName(F) ++ " to " ++ @typeName(T));
        }
    }

    /// Reinterprets `x` as the representation named by `to`, without touching
    /// the limbs. Only widening casts compile (see `checkSafe`).
    inline fn cast(x: anytype, comptime to: Order) to.Type() {
        comptime checkSafe(@TypeOf(x), to.Type());
        return .{ .limbs = x.limbs };
    }

    /// Squares `a` `n` times, i.e. returns a^(2^n). The loop is unrolled at comptime.
    fn sqn(a: anytype, n: comptime_int) Unreduced {
        var r = a.cast(.unreduced);
        inline for (0..n) |_| {
            r = r.sq();
        }
        return r;
    }
};
/// Returns `[ x[imm], ..., x[7], y[0], ..., y[imm-1] ]`: the top `8 - imm`
/// lanes of `x` followed by the bottom `imm` lanes of `y`.
inline fn alignr(x: V, y: V, comptime imm: u32) V {
// Negative (bit-inverted) indices select from the second @shuffle operand, which is `x` here.
const front: [8 - imm]i32 = ~(std.simd.iota(i32, 8 - imm) + (@as(@Vector(8 - imm, i32), @splat(imm))));
// Non-negative indices select from the first @shuffle operand, which is `y`.
const back: [imm]i32 = std.simd.iota(i32, imm);
const mask: @Vector(8, i32) = front ++ back;
return @shuffle(u64, y, x, mask);
}
};
const Point = struct {
/// [X0, Y0, Z0, T0 | X3, Y3, Z3, T3]
/// [X1, Y1, Z1, T1 | X4, Y4, Z4, T4]
/// [X2, Y2, Z2, T2 | X5, Y5, Z5, T5]
limbs: [3]V,
/// The neutral element (X, Y, Z, T) = (0, 1, 1, 0).
/// Lane 0 of the first row is limb 0 of X and must be 0: (1, 1, 1, 0) does not
/// satisfy T = X*Y/Z and is not a point on the curve.
const zero: Point = .{ .limbs = .{
    .{ 0, 1, 1, 0, 0, 0, 0, 0 },
    .{ 0, 0, 0, 0, 0, 0, 0, 0 },
    .{ 0, 0, 0, 0, 0, 0, 0, 0 },
} };
/// (1, 1, 1, 2*d)
const d1: Point = .{ .limbs = .{
.{ 1, 1, 1, 3934839304537, 0, 0, 0, 521695520920 },
.{ 0, 0, 0, 507525298899, 0, 0, 0, 6596238350568 },
.{ 0, 0, 0, 15037786634, 0, 0, 0, 309467527341 },
} };
/// (1, 1, 2, 2*d)
const d2: Point = .{ .limbs = .{
.{ 1, 1, 2, 3934839304537, 0, 0, 0, 521695520920 },
.{ 0, 0, 0, 507525298899, 0, 0, 0, 6596238350568 },
.{ 0, 0, 0, 15037786634, 0, 0, 0, 309467527341 },
} };
/// Decodes a 32-byte compressed point: the low 255 bits are y, bit 255 is the
/// sign (lowest bit) of x.
///
/// With u = y^2 - 1 and v = d*y^2 + 1, the candidate is
/// x = u * (u*v)^((p-5)/8). It is accepted when v*x^2 equals u or -u; in the
/// latter case x is multiplied by sqrt(-1).
/// Returns `error.InvalidEncoding` when neither holds.
pub fn fromBytes(bytes: *const [32]u8) error{InvalidEncoding}!Point {
const z: Fe.Unreduced = .one;
const y = Fe.NearlyReduced.fromBytes(bytes).cast(.unreduced);
const y2 = y.sq();
const u = y2.sub(z);
const v = y2.mul(.d).add(z);
var x = u.mul(v).pow2523().mul(u);
const vxx = x.sq().mul(v);
const has_m_root = vxx.sub(u).isZero(); // v*x^2 == u
const has_p_root = vxx.add(u).isZero(); // v*x^2 == -u
if ((@intFromBool(has_m_root) | @intFromBool(has_p_root)) == 0) { // best-effort to avoid two conditional branches
return error.InvalidEncoding;
}
// TODO: make these cmovs
if (1 - @intFromBool(has_m_root) != 0) x = x.mul(.sqrtm1);
// Negate x when its parity differs from the encoded sign bit.
if (@intFromBool(x.isNegative()) ^ (bytes[31] >> 7) != 0) x = x.neg();
const t = x.mul(y);
return .init(x, y, z, t);
}
/// Encodes the point as 32 bytes: the affine y coordinate, with the lowest
/// bit of the affine x coordinate XORed into the top bit of the last byte.
fn toBytes(p: Point, bytes: *[32]u8) void {
    const coords = p.split();
    // Convert to affine: divide X and Y by Z.
    const z_inv = coords[2].invert();
    const x_affine = coords[0].mul(z_inv);
    const y_affine = coords[1].mul(z_inv);
    y_affine.toBytes(bytes);
    const sign: u8 = @intFromBool(x_affine.isNegative());
    bytes[31] ^= sign << 7;
}
/// Interleaves four field elements (x, y, z, t) into the three-row point
/// layout, where row k holds limb k of all four coordinates in lanes 0-3 and
/// limb k+3 in lanes 4-7. This is a transpose done in two rounds of two-source
/// shuffles followed by `pack`.
fn init(r0: Fe.Unreduced, r1: Fe.Unreduced, r2: Fe.Unreduced, r3: Fe.Unreduced) Point {
// At the start:
// x = [x0, x1, x2, x3, x4, x5, __, __]
// y = [y0, y1, y2, y3, y4, y5, __, __]
// z = [z0, z1, z2, z3, z4, z5, __, __]
// t = [t0, t1, t2, t3, t4, t5, __, __]
const x = r0.limbs;
const y = r1.limbs;
const z = r2.limbs;
const t = r3.limbs;
// t0 = [x0, x1, z0, z1, x4, x5, z4, z5]
// t1 = [y0, y1, t0, t1, y4, y5, t4, t5]
// t2 = [x2, x3, z2, z3, __, __, __, __]; we don't care about the last 4 elements
// t3 = [y2, y3, t2, t3, __, __, __, __]
const t0 = @shuffle(u64, x, z, @Vector(8, i32){ 0, 1, ~@as(i32, 0), ~@as(i32, 1), 4, 5, ~@as(i32, 4), ~@as(i32, 5) });
const t1 = @shuffle(u64, y, t, @Vector(8, i32){ 0, 1, ~@as(i32, 0), ~@as(i32, 1), 4, 5, ~@as(i32, 4), ~@as(i32, 5) });
const t2 = @shuffle(u64, x, z, @Vector(8, i32){ 2, 3, ~@as(i32, 2), ~@as(i32, 3), 6, 7, ~@as(i32, 6), ~@as(i32, 7) });
const t3 = @shuffle(u64, y, t, @Vector(8, i32){ 2, 3, ~@as(i32, 2), ~@as(i32, 3), 6, 7, ~@as(i32, 6), ~@as(i32, 7) });
// c04 = [x0, y0, z0, t0, x4, y4, z4, t4]
// c15 = [x1, y1, z1, t1, x5, y5, z5, t5]
// c26 = [x2, y2, z2, t2, __, __, __, __]
// c37 = [x3, y3, z3, t3, __, __, __, __]
const c04 = @shuffle(u64, t0, t1, @Vector(8, i32){ 0, ~@as(i32, 0), 2, ~@as(i32, 2), 4, ~@as(i32, 4), 6, ~@as(i32, 6) });
const c15 = @shuffle(u64, t0, t1, @Vector(8, i32){ 1, ~@as(i32, 1), 3, ~@as(i32, 3), 5, ~@as(i32, 5), 7, ~@as(i32, 7) });
const c26 = @shuffle(u64, t2, t3, @Vector(8, i32){ 0, ~@as(i32, 0), 2, ~@as(i32, 2), 4, ~@as(i32, 4), 6, ~@as(i32, 6) });
const c37 = @shuffle(u64, t2, t3, @Vector(8, i32){ 1, ~@as(i32, 1), 3, ~@as(i32, 3), 5, ~@as(i32, 5), 7, ~@as(i32, 7) });
// NOTE(review): `pack(a, a_high, b, b_high)` is defined outside this view; from
// the lane comments below it appears to concatenate one 4-lane half of `a`
// (upper half when the flag is true) with one 4-lane half of `b`.
// zig fmt: off
return .{ .limbs = .{
// [x0, y0, z0, t0 | x3, y3, z3, t3]
pack(c04, false, c37, false),
// [x1, y1, z1, t1 | x4, y4, z4, t4]
pack(c15, false, c04, true),
// [x2, y2, z2, t2 | x5, y5, z5, t5]
pack(c26, false, c15, true),
} };
// zig fmt: on
}
/// Inverse of `init`: de-interleaves the three-row point layout into the four
/// field elements x, y, z, t (each with six limbs and two zero lanes).
fn split(p: Point) [4]Fe.Unreduced {
// r0 = [x0, y0, z0, t0, x3, y3, z3, t3]
// r1 = [x1, y1, z1, t1, x4, y4, z4, t4]
// r2 = [x2, y2, z2, t2, x5, y5, z5, t5]
// r3 = [ 0, 0, 0, 0, 0, 0, 0, 0]
const r0, const r1, const r2 = p.limbs;
const r3: V = @splat(0);
// c0 = [x0, x1, z0, z1, x3, x4, z3, z4]
// c1 = [y0, y1, t0, t1, y3, y4, t3, t4]
// c2 = [x2, 0, z2, 0, x5, 0, z5, 0]
// c3 = [y2, 0, t2, 0, y5, 0, t5, 0]
const c0 = @shuffle(u64, r0, r1, @Vector(8, i32){ 0, ~@as(i32, 0), 2, ~@as(i32, 2), 4, ~@as(i32, 4), 6, ~@as(i32, 6) });
const c1 = @shuffle(u64, r0, r1, @Vector(8, i32){ 1, ~@as(i32, 1), 3, ~@as(i32, 3), 5, ~@as(i32, 5), 7, ~@as(i32, 7) });
const c2 = @shuffle(u64, r2, r3, @Vector(8, i32){ 0, ~@as(i32, 0), 2, ~@as(i32, 2), 4, ~@as(i32, 4), 6, ~@as(i32, 6) });
const c3 = @shuffle(u64, r2, r3, @Vector(8, i32){ 1, ~@as(i32, 1), 3, ~@as(i32, 3), 5, ~@as(i32, 5), 7, ~@as(i32, 7) });
// In each shuffle below, the last two lanes read lane 1 of c2/c3, which is zero.
return .{
// x = [x0, x1, x2, x3, x4, x5, 0, 0]
.{ .limbs = @shuffle(u64, c0, c2, @Vector(8, i32){ 0, 1, ~@as(i32, 0), 4, 5, ~@as(i32, 4), ~@as(i32, 1), ~@as(i32, 1) }) },
// y = [y0, y1, y2, y3, y4, y5, 0, 0]
.{ .limbs = @shuffle(u64, c1, c3, @Vector(8, i32){ 0, 1, ~@as(i32, 0), 4, 5, ~@as(i32, 4), ~@as(i32, 1), ~@as(i32, 1) }) },
// z = [z0, z1, z2, z3, z4, z5, 0, 0]
.{ .limbs = @shuffle(u64, c0, c2, @Vector(8, i32){ 2, 3, ~@as(i32, 2), 6, 7, ~@as(i32, 6), ~@as(i32, 1), ~@as(i32, 1) }) },
// t = [t0, t1, t2, t3, t4, t5, 0, 0]
.{ .limbs = @shuffle(u64, c1, c3, @Vector(8, i32){ 2, 3, ~@as(i32, 2), 6, 7, ~@as(i32, 6), ~@as(i32, 1), ~@as(i32, 1) }) },
};
}
/// A lane permutation of the four coordinates: each field is the source lane
/// (X=0, Y=1, Z=2, T=3) for the corresponding output position.
const Shuffle = struct {
// u3 rather than u2: `Point.shuffle` adds 4 to these values to address the upper half.
first: u3,
second: u3,
third: u3,
fourth: u3,
const Lane = enum {
X,
Y,
Z,
T,
};
const TYTY: Shuffle = .parse("TYTY");
const XYYX: Shuffle = .parse("XYYX");
const XZZX: Shuffle = .parse("XZZX");
const YXTZ: Shuffle = .parse("YXTZ");
const YXZT: Shuffle = .parse("YXZT");
const YYYY: Shuffle = .parse("YYYY");
const YYZX: Shuffle = .parse("YYZX");
const ZTZT: Shuffle = .parse("ZTZT");
const TTTT: Shuffle = .parse("TTTT");
const XZXZ: Shuffle = .parse("XZXZ");
const YTYT: Shuffle = .parse("YTYT");
const YZTZ: Shuffle = .parse("YZTZ");
/// Builds a `Shuffle` from a four-letter string over {X, Y, Z, T}; the i-th
/// letter names the source lane of the i-th field. Every use in this struct
/// is in a constant initializer, so it runs at comptime.
fn parse(bytes: []const u8) Shuffle {
std.debug.assert(bytes.len == 4);
var out: Shuffle = undefined;
for (bytes, @typeInfo(Shuffle).@"struct".fields) |b, field| {
// Look the letter up as a `Lane` tag by name and store its integer value.
@field(out, field.name) = @intFromEnum(@field(Lane, &.{b}));
}
return out;
}
};
/// A set of active coordinates, used to pick which lanes an operation applies to.
const Lanes = struct {
    x: bool = false,
    y: bool = false,
    z: bool = false,
    t: bool = false,

    const X: Lanes = .{ .x = true };
    const Y: Lanes = .{ .y = true };
    const Z: Lanes = .{ .z = true };
    const YZ: Lanes = .{ .y = true, .z = true };
    const XY: Lanes = .{ .x = true, .y = true };
    const ZT: Lanes = .{ .z = true, .t = true };
    const XT: Lanes = .{ .x = true, .t = true };

    /// Returns a select mask which is "true" for the active lanes, and "false" for the inactive ones.
    ///
    /// E.g. `x = true, z = true`, returns a mask `{ true, false, true, false, true, false, true, false}`
    fn mask(l: Lanes) @Vector(8, bool) {
        // Both 4-lane halves of a row use the same (X, Y, Z, T) order, so the
        // pattern is simply written out twice.
        return .{ l.x, l.y, l.z, l.t, l.x, l.y, l.z, l.t };
    }
};
/// Permutes the four coordinates of `p` according to `c`, applying the same
/// permutation to both 4-lane halves of every row.
fn shuffle(p: Point, comptime c: Shuffle) Point {
    const sel = [4]i32{ c.first, c.second, c.third, c.fourth };
    const mask: @Vector(8, i32) = .{
        // lower half
        sel[0],     sel[1],     sel[2],     sel[3],
        // upper half: same pattern, offset by 4 lanes
        sel[0] + 4, sel[1] + 4, sel[2] + 4, sel[3] + 4,
    };
    return .{ .limbs = .{
        @shuffle(u64, p.limbs[0], undefined, mask),
        @shuffle(u64, p.limbs[1], undefined, mask),
        @shuffle(u64, p.limbs[2], undefined, mask),
    } };
}
/// Takes one coordinate (`.X` or `.Y`) from each of the two points and
/// returns it interleaved as (x.c, y.c, x.c, y.c) in every row.
fn blend(
    x: Point,
    y: Point,
    comptime l: enum { X, Y },
) Point {
    const lo: i32 = switch (l) {
        .X => 0,
        .Y => 1,
    };
    const hi: i32 = lo + 4;
    // Non-negative indices read from the first operand, bit-inverted ones from the second.
    const mask: @Vector(8, i32) = .{ lo, ~lo, lo, ~lo, hi, ~hi, hi, ~hi };
    return .{ .limbs = .{
        @shuffle(u64, x.limbs[0], y.limbs[0], mask),
        @shuffle(u64, x.limbs[1], y.limbs[1], mask),
        @shuffle(u64, x.limbs[2], y.limbs[2], mask),
    } };
}
/// Lane-wise select: active lanes (per `l`) become the wrapping sum
/// `x[i] + y[i]`, inactive lanes are copied from `z[i]`. No carry handling.
fn addMask(x: Point, y: Point, comptime l: Lanes, z: Point) Point {
    const active = comptime l.mask();
    var out: Point = undefined;
    inline for (0..3) |row| {
        const sum = x.limbs[row] +% y.limbs[row];
        out.limbs[row] = @select(u64, active, sum, z.limbs[row]);
    }
    return out;
}
/// Lane-wise select: active lanes (per `l`) become `x[i] + (p[i] - y[i])`,
/// i.e. `x - y` with the limbs of p added, inactive lanes are copied from
/// `z[i]`. All arithmetic wraps and no carries are handled.
fn subMask(x: Point, y: Point, comptime l: Lanes, z: Point) Point {
    const m43 = (1 << 43) - 1;
    const m40 = (1 << 40) - 1;
    const p0 = m43 - 18; // 2^43 - 19, the lowest limb of p
    // Limbs of p in the point layout: row k holds limb k | limb k+3.
    const modulus: [3]V = .{
        .{ p0, p0, p0, p0, m43, m43, m43, m43 },
        @splat(m43),
        .{ m43, m43, m43, m43, m40, m40, m40, m40 },
    };
    const active = comptime l.mask();
    var out: Point = undefined;
    inline for (0..3) |row| {
        const negated = modulus[row] -% y.limbs[row];
        out.limbs[row] = @select(u64, active, x.limbs[row] +% negated, z.limbs[row]);
    }
    return out;
}
/// One parallel carry step over all four coordinates at once.
/// Each limb keeps its low 43 bits (40 for limb 5) and receives the part
/// shifted out of the limb below; the part shifted out of limb 5 is multiplied
/// by 19 and added to limb 0. Note the shifts are arithmetic (signed vector type).
fn foldUnsigned(x: Point) Point {
const m43 = (1 << 43) - 1;
const m40 = (1 << 40) - 1;
const m43_m43: V = @splat(m43);
// Third row holds limb 2 (43 bits) in lanes 0-3 and limb 5 (40 bits) in lanes 4-7.
const m43_m40: V = .{ m43, m43, m43, m43, m40, m40, m40, m40 };
const s43_s40: V = .{ 43, 43, 43, 43, 40, 40, 40, 40 };
const p03, const p14, const p25 = x.limbs;
const ph03: V = @bitCast(@as(S, @bitCast(p03)) >> @splat(43));
const ph14: V = @bitCast(@as(S, @bitCast(p14)) >> @splat(43));
const ph25: V = @bitCast(@as(S, @bitCast(p25)) >> s43_s40);
const za = ph25 *% @as(V, @splat(19));
// NOTE(review): `pack` is defined outside this view; here it presumably yields
// [19 * carry(limb 5) | carry(limb 2)], the carries into limb 0 and limb 3.
return .{ .limbs = .{
(p03 & m43_m43) +% pack(za, true, ph25, false),
(p14 & m43_m43) +% ph03,
(p25 & m43_m40) +% ph14,
} };
}
/// Same carry step as `foldUnsigned`, but a per-limb bias is subtracted
/// before the arithmetic shifts and added back afterwards (19 << 23 for limb
/// 0, 1 << 20 for every other limb), mirroring `Fe.Signed.fold`.
fn foldSigned(x: Point) Point {
const b0 = 19 << 23;
const bb = 1 << 20;
const m43 = (1 << 43) - 1;
const m40 = (1 << 40) - 1;
// First row holds limb 0 in lanes 0-3 and limb 3 in lanes 4-7.
const bias03: V = .{ b0, b0, b0, b0, bb, bb, bb, bb };
const bias: V = @splat(bb);
const m43_m43: V = @splat(m43);
const m43_m40: V = .{ m43, m43, m43, m43, m40, m40, m40, m40 };
const s43_s40: V = .{ 43, 43, 43, 43, 40, 40, 40, 40 };
const p03 = x.limbs[0] -% bias03;
const p14 = x.limbs[1] -% bias;
const p25 = x.limbs[2] -% bias;
const ph03: V = @bitCast(@as(S, @bitCast(p03)) >> @splat(43));
const ph14: V = @bitCast(@as(S, @bitCast(p14)) >> @splat(43));
const ph25: V = @bitCast(@as(S, @bitCast(p25)) >> s43_s40);
const ph25_19 = ph25 *% @as(V, @splat(19));
// NOTE(review): `pack` is defined outside this view; here it presumably yields
// [19 * carry(limb 5) | carry(limb 2)], the carries into limb 0 and limb 3.
return .{ .limbs = .{
(p03 & m43_m43) +% pack(ph25_19, true, ph25, false) +% bias03,
(p14 & m43_m43) +% ph03 +% bias,
(p25 & m43_m40) +% ph14 +% bias,
} };
}
const Wide = [6]V;
/// https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.4
/// https://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html#addition-add-2008-hwcd-3
///
/// Returns `x + y` in extended coordinates.
///
/// The body contains three implementations of the same formula. The first two
/// (per-element and a first parallel version) end in `if (false) return`, so
/// their results are never returned; they are kept as a reference. Only the
/// last block ("Better parallel version") produces the result.
fn add(x: Point, y: Point) Point {
// // Per-element version:
{
const x1, const y1, const z1, const t1 = x.split();
const x2, const y2, const z2, const t2 = y.split();
const a = (y1.sub(x1)).mul(y2.sub(x2));
const b = (y1.add(x1)).mul(y2.add(x2));
const c = t1.mul(.d2).mul(t2);
const d = z1.add(z1).mul(z2);
const e = b.sub(a);
const f = d.sub(c);
const g = d.add(c);
const h = b.add(a);
if (false) return .init(
e.mul(f),
g.mul(h),
f.mul(g),
e.mul(h),
);
}
// // Parallel version:
{
// zig fmt: off
var a = x.shuffle(.YXZT); // a = [Y1, X1, Z1, T1]
var b = y.shuffle(.YXZT); // b = [Y2, X2, Z2, T2]
a = subMask(a, x, .X, a); // a = [Y1-X1, X1, Z1, T1]
b = subMask(b, y, .X, b); // b = [Y2-X2, X2, Z2, T2]
a = addMask(a, x, .YZ, a); // a = [Y1-X1, Y1+X1, Z1*2, T1]
b = addMask(b, y, .Y, b); // b = [Y2-X2, X2+Y2, Z2, T2]
a = a.foldSigned();
b = b.foldSigned();
a = quad.mul(a, .d1); // a = [Y1-X1, Y1+X1, Z1*2, T1*2d]
a = a.foldUnsigned();
// A = (Y1-X1)*(Y2-X2)
// B = (Y1+X1)*(Y2+X2)
// C = T1*2d*T2
// D = Z1*2*Z2
a = quad.mul(a, b); // a = [A, B, D, C]
b = a.shuffle(.YXTZ); // b = [B, A, C, D]
// E = B-A
// F = D-C
b = subMask(b, a, .XT, b); // b = [E, A, C, F]
// G = D+C
// H = B+A
b = addMask(b, a, .YZ, b); // b = [E, H, G, F]
b = b.foldSigned();
a = b.shuffle(.XZZX); // a = [E, G, G, E]
b = b.shuffle(.TYTY); // b = [F, H, F, H]
// X3 = E*F
// Y3 = G*H
// T3 = E*H
// Z3 = F*G
a = quad.mul(a, b); // a = [X3, Y3, Z3, T3]
// zig fmt: on
if (false) return a.foldUnsigned();
}
// Better parallel version:
{
const y4 = blend(x, y, .Y); // (Y1, Y2, Y1, Y2) u44|u44|u44|u44
const x4 = blend(x, y, .X); // (X1, X2, X1, X2) u44|u44|u44|u44
// (-X1, -X2, X1, X2)
// const nx: Point = nx: {
// // (M, M, 0, 0)
// var m03: V = .{ 0x7FFFFFFFFED, 0x7FFFFFFFFED, 0, 0, 0x7FFFFFFFFFF, 0x7FFFFFFFFFF, 0, 0 };
// var m14: V = .{ 0x7FFFFFFFFFF, 0x7FFFFFFFFFF, 0, 0, 0x7FFFFFFFFFF, 0x7FFFFFFFFFF, 0, 0 };
// var m25: V = .{ 0x7FFFFFFFFFF, 0x7FFFFFFFFFF, 0, 0, 0x0FFFFFFFFFF, 0x0FFFFFFFFFF, 0, 0 };
// const q03, const q14, const q25 = x4.limbs;
// m03 -%= q03;
// m14 -%= q14;
// m25 -%= q25;
// break :nx .{ .limbs = .{ m03, m14, m25 } };
// };
const nx = addMask(y4, x4, .ZT, y4); // (Y1, Y2, Y1+X1, Y2+X2)
const ysx = subMask(y4, x4, .XY, nx); // (Y1-X1, Y2-X2, Y1+X1, Y2+X2)
// wide-(Z1, T1, Z2, T2)
const t_wide: Wide = t: {
// To start:
// x03 = [__, __, Z10, T10 | __, __, Z13, T13 ]
// x14 = [__, __, Z11, T11 | __, __, Z14, T14 ]
// x25 = [__, __, Z12, T12 | __, __, Z15, T15 ]
// ---
// y03 = [__, __, Z20, T20 | __, __, Z23, T23 ]
// y14 = [__, __, Z21, T21 | __, __, Z24, T24 ]
// y25 = [__, __, Z22, T22 | __, __, Z25, T25 ]
//
// We combine the following transformations:
//
// [X0, Y0, Z0, T0 | X3, Y3, Z3, T3]
// ->
// [X0, Y0, Z0, T0 | X0, Y0, Z0, T0]
//
// as well as
//
// [__, __, Z10, T10 | __, __, Z13, T13]
// [__, __, Z20, T20 | __, __, Z23, T23]
// ->
// [Z10, T10, Z20, T20 | Z13, T13, Z23, T23]
//
// We want to get to:
// t00 = [Z10, T10, Z20, T20 | Z10, T10, Z20, T20]
// t11 = [Z11, T11, Z21, T21 | Z11, T11, Z21, T21]
// t22 = [Z12, T12, Z22, T22 | Z12, T12, Z22, T22]
// t33 = [Z13, T13, Z23, T23 | Z13, T13, Z23, T23]
// t44 = [Z14, T14, Z24, T24 | Z14, T14, Z24, T24]
// t55 = [Z15, T15, Z25, T25 | Z15, T15, Z25, T25]
const x03, const x14, const x25 = x.limbs;
const y03, const y14, const y25 = y.limbs;
const first: @Vector(8, i32) = .{
2, 3, ~@as(i32, 2), ~@as(i32, 3),
2, 3, ~@as(i32, 2), ~@as(i32, 3),
};
const second: @Vector(8, i32) = .{
6, 7, ~@as(i32, 6), ~@as(i32, 7),
6, 7, ~@as(i32, 6), ~@as(i32, 7),
};
break :t .{
@shuffle(u64, x03, y03, first),
@shuffle(u64, x14, y14, first),
@shuffle(u64, x25, y25, first),
@shuffle(u64, x03, y03, second),
@shuffle(u64, x14, y14, second),
@shuffle(u64, x25, y25, second),
};
};
const t = quad.mulWide(t_wide, .d2).foldUnsigned(); // (Z1, T1, Z2*2, T2*d*2)
// (Y1-X1, Y1+X1, T1, Z1)
const a: Wide = a: {
const x03, const x14, const x25 = ysx.limbs;
const y03, const y14, const y25 = t.limbs;
const first: @Vector(8, i32) = .{
// Y1-X1, Y1+X1, T1, Z1
0, 2, ~@as(i32, 1), ~@as(i32, 0),
0, 2, ~@as(i32, 1), ~@as(i32, 0),
};
const second: @Vector(8, i32) = .{
4, 6, ~@as(i32, 5), ~@as(i32, 4),
4, 6, ~@as(i32, 5), ~@as(i32, 4),
};
break :a .{
@shuffle(u64, x03, y03, first),
@shuffle(u64, x14, y14, first),
@shuffle(u64, x25, y25, first),
@shuffle(u64, x03, y03, second),
@shuffle(u64, x14, y14, second),
@shuffle(u64, x25, y25, second),
};
};
// ysx = (Y1-X1, Y2-X2, Y1+X1, Y2+X2)
// t = ( Z1, T1, Z2*2, T2*d*2)
// b = (Y2-X2, Y2+X2, T2*d*2, Z2*2)
const b: Point = b: {
const x03, const x14, const x25 = ysx.limbs;
const y03, const y14, const y25 = t.limbs;
const mask: @Vector(8, i32) = .{
1, 3, ~@as(i32, 3), ~@as(i32, 2), // first half
5, 7, ~@as(i32, 7), ~@as(i32, 6),
};
break :b .{ .limbs = .{
@shuffle(u64, x03, y03, mask),
@shuffle(u64, x14, y14, mask),
@shuffle(u64, x25, y25, mask),
} };
};
const abcd = quad.mulWide(a, b).foldUnsigned(); // (A, B, C, D)
const ac = abcd.shuffle(.XZXZ); // (A, C, A, C)
const bd = abcd.shuffle(.YTYT); // (B, D, B, D)
const nac = subMask(bd, ac, .XY, bd); // (B-A, D-C, B, D)
// (B-A, D-C, B+A, D+C)
// (E, F, H, G)
const efhg = addMask(bd, ac, .ZT, nac).foldUnsigned();
// (E, G, F, E)
const egfe: Wide = egef: {
const x03, const x14, const x25 = efhg.limbs;
const first: @Vector(8, i32) = .{
0, 3, 1, 0,
0, 3, 1, 0,
};
const second: @Vector(8, i32) = .{
4, 7, 5, 4,
4, 7, 5, 4,
};
break :egef .{
@shuffle(u64, x03, undefined, first),
@shuffle(u64, x14, undefined, first),
@shuffle(u64, x25, undefined, first),
@shuffle(u64, x03, undefined, second),
@shuffle(u64, x14, undefined, second),
@shuffle(u64, x25, undefined, second),
};
};
// X3 = E*F
// Y3 = G*H
// Z3 = F*G
// T3 = E*H
// (F, H, G, H)
const fhgh = efhg.shuffle(.YZTZ);
return quad.mulWide(egfe, fhgh).foldUnsigned();
}
}
/// https://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html#doubling-dbl-2008-hwcd
///
/// Returns `2 * x` in extended coordinates, computing all four output
/// coordinates in parallel. E, F, G, H below follow the per-element version
/// in the comment (H = A+B, E = H-(X1+Y1)^2, G = A-B, F = C+G).
fn dbl(x: Point) Point {
// Per-element version:
// const x1, const y1, const z1, _ = x.split();
// const a = x1.sq();
// const b = y1.sq();
// const c = z1.sq().add(z1.sq());
// const h = a.add(b);
// const e = h.sub(x1.add(y1).sq());
// const g = a.sub(b);
// const f = c.add(g);
// return .init(
// e.mul(f),
// g.mul(h),
// f.mul(g),
// e.mul(h),
// );
// Parallel version:
// zig fmt: off
var a = x.shuffle(.YYZX); // a = [ Y1, Y1, Z1, X1]
a = addMask(a, x, .X, a); // a = [ Y1+X1, Y1, Z1, X1]
// A = X1^2
// B = Y1^2
a = quad.sq(a); // a = [ (Y1+X1)^2, B, Z1^2, A]
// C = 2*Z1^2
a = addMask(a, a, .Z, a); // a = [ (Y1+X1)^2, B, C, A]
var b = a.shuffle(.TTTT); // b = [ A A, A, A]
const BB = a.shuffle(.YYYY); //BB = [ B B, B, B]
// H = A+B
// G = A-B
b = addMask(b, BB, .XT, b); // b = [ H, A, A, H]
b = subMask(b, BB, .YZ, b); // b = [ H, G, G, H]
// F = C+G
b = addMask(b, a, .Z, b); // b = [ H, G, F, H]
// E = H-(Y1+X1)^2
b = subMask(b, a, .X, b); // b = [ E G, F, H]
b = b.foldSigned();
a = b.shuffle(.XYYX); // a = [ E G, G, E]
b = b.shuffle(.ZTZT); // b = [ F, H, F, H]
a = quad.mul(a, b); // a = [X3, Y3, Z3, T3] = [E*F, G*H, G*F, E*H]
// zig fmt: on
return a.foldUnsigned();
}
/// Prints the limb vectors of the four coordinates, one coordinate per line.
pub fn format(p: Point, writer: *std.Io.Writer) !void {
    const coords = p.split();
    try writer.print(
        \\x: {}
        \\y: {}
        \\z: {}
        \\t: {}
    , .{ coords[0].limbs, coords[1].limbs, coords[2].limbs, coords[3].limbs });
}
const quad = struct {
pub fn mulWide(p: Wide, q: Point) Point {
const zz: V = @splat(0);
const x00, const x11, const x22, const x33, const x44, const x55 = p;
const y03, const y14, const y25 = q.limbs;
// zig fmt: off
const p0_q3 = madd52lo( zz, x00, y03);
var p1_q4 = madd52lo(madd52lo( zz, x11, y03), x00, y14);
var p2_q5 = madd52lo(madd52lo(madd52lo(zz, x22, y03), x11, y14), x00, y25);
var p3_q6 = madd52lo(madd52lo(madd52lo(zz, x33, y03), x22, y14), x11, y25);
var p4_q7 = madd52lo(madd52lo(madd52lo(zz, x44, y03), x33, y14), x22, y25);
var p5_q8 = madd52lo(madd52lo(madd52lo(zz, x55, y03), x44, y14), x33, y25);
var p6_q9 = madd52lo(madd52lo(zz, x55, y14), x44, y25);
var p7_qa = madd52lo(zz, x55, y25);
p1_q4 = p1_q4 +% (madd52hi( zz, x00, y03) << @splat(9));
p2_q5 = p2_q5 +% (madd52hi(madd52hi( zz, x11, y03), x00, y14) << @splat(9));
p3_q6 = p3_q6 +% (madd52hi(madd52hi(madd52hi(zz, x22, y03), x11, y14), x00, y25) << @splat(9));
p4_q7 = p4_q7 +% (madd52hi(madd52hi(madd52hi(zz, x33, y03), x22, y14), x11, y25) << @splat(9));
p5_q8 = p5_q8 +% (madd52hi(madd52hi(madd52hi(zz, x44, y03), x33, y14), x22, y25) << @splat(9));
p6_q9 = p6_q9 +% (madd52hi(madd52hi(madd52hi(zz, x55, y03), x44, y14), x33, y25) << @splat(9));
p7_qa = p7_qa +% ( madd52hi(madd52hi(zz, x55, y14), x44, y25) << @splat(9));
const p8_qb = ( madd52hi(zz, x55, y25) << @splat(9));
// zig fmt: on
const q6_p3 = pack(p3_q6, true, p3_q6, false);
const q7_p4 = pack(p4_q7, true, p4_q7, false);
const q8_p5 = pack(p5_q8, true, p5_q8, false);
const bottom: @Vector(8, bool) = .{ false, false, false, false, true, true, true, true };
const top: @Vector(8, bool) = .{ true, true, true, true, false, false, false, false };
const za03 = @select(u64, bottom, q6_p3 +% p0_q3, p0_q3);
const za14 = @select(u64, bottom, q7_p4 +% p1_q4, p1_q4);
const za25 = @select(u64, bottom, q8_p5 +% p2_q5, p2_q5);
const zb03 = @select(u64, top, q6_p3 +% p6_q9, p6_q9);
const zb14 = @select(u64, top, q7_p4 +% p7_qa, p7_qa);
const zb25 = @select(
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment