const std = @import("std");
const crypto = std.crypto;
const debug = std.debug;
const mem = std.mem;
const meta = std.meta;
const NonCanonicalError = crypto.errors.NonCanonicalError;
const NotSquareError = crypto.errors.NotSquareError;
/// Compile-time parameters describing a prime field; consumed by `Field`.
pub const FieldParams = struct {
    /// Backend module (presumably fiat-crypto generated) providing the limb
    /// types and the low-level field arithmetic that `Field` delegates to.
    fiat: type,
    /// Order (modulus) of the field. The square-root and Legendre-symbol
    /// formulas in `Field` assume it is an odd prime.
    field_order: comptime_int,
    /// Bit width of `Field(...).IntRepr` and of the exponents used by
    /// `isSquare`/`sqrt`; presumably the bit length of `field_order`.
    field_bits: comptime_int,
    /// Forwarded verbatim as `Field(...).saturated_bits`; not otherwise used by `Field`.
    saturated_bits: comptime_int,
    /// Number of bytes in the serialized form of an element.
    encoded_length: comptime_int,
};
/// Return a prime-field element type for the field described by `params`.
///
/// Elements are stored as limbs in the Montgomery domain. All arithmetic is
/// delegated to the module `params.fiat` (presumably fiat-crypto generated),
/// which must provide the limb types `MontgomeryDomainFieldElement` and
/// `NonMontgomeryDomainFieldElement`, and the functions `setOne`, `fromBytes`,
/// `toBytes`, `toMontgomery`, `fromMontgomery`, `nonzero`, `selectznz`, `add`,
/// `sub`, `mul`, `square`, `opp`, `msat`, `divstep` and `divstepPrecomp`.
pub fn Field(comptime params: FieldParams) type {
    const fiat = params.fiat;
    const MontgomeryDomainFieldElement = fiat.MontgomeryDomainFieldElement;
    const NonMontgomeryDomainFieldElement = fiat.NonMontgomeryDomainFieldElement;

    return struct {
        const Fe = @This();

        /// Element limbs, stored in the Montgomery domain.
        limbs: MontgomeryDomainFieldElement,

        /// Order (modulus) of the field.
        pub const field_order = params.field_order;
        /// Bit width of `IntRepr` and of the exponents used by `isSquare`/`sqrt`.
        pub const field_bits = params.field_bits;
        /// Forwarded from `params.saturated_bits`; not used within this type.
        pub const saturated_bits = params.saturated_bits;
        /// Number of bytes in the serialized form of an element.
        pub const encoded_length = params.encoded_length;

        // Unsigned integer spanning exactly `encoded_length` bytes. All integer <-> byte
        // conversions go through it, so `field_bits` does not have to be a multiple of 8
        // nor equal to `encoded_length * 8` (e.g. a 521-bit field encoded in 66 bytes).
        const EncodedInt = meta.Int(.unsigned, encoded_length * 8);

        // Field orders that have dedicated addition chains in `isSquare`/`uncheckedSqrt`.
        // NIST P-256 base field.
        const p256_order = 115792089210356248762697446949407573530086143415290314195533631308867097853951;
        // NIST P-384 base field.
        const p384_order = 39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319;
        // secp256k1 base field.
        const secp256k1_order = 115792089237316195423570985008687907853269984665640564039457584007908834671663;

        /// The additive identity (all-zero limbs; 0 is also 0 in the Montgomery domain).
        pub const zero: Fe = Fe{ .limbs = mem.zeroes(MontgomeryDomainFieldElement) };

        /// The multiplicative identity, as produced by `fiat.setOne`.
        pub const one = one: {
            var fe: Fe = undefined;
            fiat.setOne(&fe.limbs);
            break :one fe;
        };

        /// Return `error.NonCanonical` unless `s_` (in byte order `endian`) encodes an
        /// integer strictly smaller than `field_order`.
        /// The comparison is done with `crypto.utils.timingSafeCompare`.
        pub fn rejectNonCanonical(s_: [encoded_length]u8, endian: std.builtin.Endian) NonCanonicalError!void {
            const s = if (endian == .Little) s_ else orderSwap(s_);
            // Little-endian encoding of `field_order`, computed at compile time.
            const field_order_s = comptime fos: {
                var fos: [encoded_length]u8 = undefined;
                mem.writeIntLittle(EncodedInt, &fos, field_order);
                break :fos fos;
            };
            if (crypto.utils.timingSafeCompare(u8, &s, &field_order_s, .Little) != .lt) {
                return error.NonCanonical;
            }
        }

        /// Return a copy of `s` with its byte order reversed (little <-> big endian).
        pub fn orderSwap(s: [encoded_length]u8) [encoded_length]u8 {
            var t = s;
            for (s) |x, i| t[t.len - 1 - i] = x;
            return t;
        }

        /// Decode a field element from `encoded_length` bytes in byte order `endian`.
        /// Returns `error.NonCanonical` if the encoded integer is not below `field_order`.
        pub fn fromBytes(s_: [encoded_length]u8, endian: std.builtin.Endian) NonCanonicalError!Fe {
            const s = if (endian == .Little) s_ else orderSwap(s_);
            try rejectNonCanonical(s, .Little);
            var limbs_z: NonMontgomeryDomainFieldElement = undefined;
            fiat.fromBytes(&limbs_z, s);
            var limbs: MontgomeryDomainFieldElement = undefined;
            fiat.toMontgomery(&limbs, limbs_z);
            return Fe{ .limbs = limbs };
        }

        /// Encode the field element as `encoded_length` bytes in byte order `endian`.
        pub fn toBytes(fe: Fe, endian: std.builtin.Endian) [encoded_length]u8 {
            var limbs_z: NonMontgomeryDomainFieldElement = undefined;
            fiat.fromMontgomery(&limbs_z, fe.limbs);
            var s: [encoded_length]u8 = undefined;
            fiat.toBytes(&s, limbs_z);
            return if (endian == .Little) s else orderSwap(s);
        }

        /// Unsigned integer type used to represent an element as an integer.
        pub const IntRepr = meta.Int(.unsigned, params.field_bits);

        /// Create a field element from a compile-time integer.
        /// Returns `error.NonCanonical` if `x >= field_order`.
        pub fn fromInt(comptime x: IntRepr) NonCanonicalError!Fe {
            var s: [encoded_length]u8 = undefined;
            mem.writeIntLittle(EncodedInt, &s, x);
            return fromBytes(s, .Little);
        }

        /// Return the field element as an integer.
        pub fn toInt(fe: Fe) IntRepr {
            const s = fe.toBytes(.Little);
            // Assumes `toBytes` emits a fully reduced value (< field_order), which then
            // fits in `IntRepr`; `@intCast` checks this in safe build modes.
            return @intCast(IntRepr, mem.readIntLittle(EncodedInt, &s));
        }

        /// Return true if the field element is zero.
        pub fn isZero(fe: Fe) bool {
            var z: @TypeOf(fe.limbs[0]) = undefined;
            fiat.nonzero(&z, fe.limbs);
            return z == 0;
        }

        /// Return true if `a` and `b` represent the same field element.
        pub fn equivalent(a: Fe, b: Fe) bool {
            return a.sub(b).isZero();
        }

        /// Return true if the integer value of the element (as encoded by `toBytes`) is odd.
        pub fn isOdd(fe: Fe) bool {
            const s = fe.toBytes(.Little);
            return @truncate(u1, s[0]) != 0;
        }

        /// Replace `fe` with `a` when `c == 1`; leave it unchanged when `c == 0`.
        /// Uses `fiat.selectznz` (presumably constant-time) instead of a branch.
        pub fn cMov(fe: *Fe, a: Fe, c: u1) void {
            fiat.selectznz(&fe.limbs, c, fe.limbs, a.limbs);
        }

        /// Return `a + b`.
        pub fn add(a: Fe, b: Fe) Fe {
            var fe: Fe = undefined;
            fiat.add(&fe.limbs, a.limbs, b.limbs);
            return fe;
        }

        /// Return `a - b`.
        pub fn sub(a: Fe, b: Fe) Fe {
            var fe: Fe = undefined;
            fiat.sub(&fe.limbs, a.limbs, b.limbs);
            return fe;
        }

        /// Return `2 * a`.
        pub fn dbl(a: Fe) Fe {
            var fe: Fe = undefined;
            fiat.add(&fe.limbs, a.limbs, a.limbs);
            return fe;
        }

        /// Return `a * b`.
        pub fn mul(a: Fe, b: Fe) Fe {
            var fe: Fe = undefined;
            fiat.mul(&fe.limbs, a.limbs, b.limbs);
            return fe;
        }

        /// Return `a^2`.
        pub fn sq(a: Fe) Fe {
            var fe: Fe = undefined;
            fiat.square(&fe.limbs, a.limbs);
            return fe;
        }

        /// Square `a` `n` times, i.e. return `a^(2^n)`.
        fn sqn(a: Fe, comptime n: comptime_int) Fe {
            var i: usize = 0;
            var fe = a;
            while (i < n) : (i += 1) {
                fe = fe.sq();
            }
            return fe;
        }

        /// Return `a^n` for a compile-time exponent `n` of type `T`, using right-to-left
        /// square-and-multiply. Because `n` is comptime-known, the sequence of
        /// multiplications depends only on the exponent, never on `a`.
        pub fn pow(a: Fe, comptime T: type, comptime n: T) Fe {
            var fe = one;
            var x: T = n;
            var t = a;
            while (true) {
                if (@truncate(u1, x) != 0) fe = fe.mul(t);
                x >>= 1;
                if (x == 0) break;
                t = t.sq();
            }
            return fe;
        }

        /// Return `-a`.
        pub fn neg(a: Fe) Fe {
            var fe: Fe = undefined;
            fiat.opp(&fe.limbs, a.limbs);
            return fe;
        }

        /// Return the multiplicative inverse of `a`.
        ///
        /// Runs a fixed number of fiat `divstep` iterations (safegcd-style inversion,
        /// see https://eprint.iacr.org/2021/549.pdf). For `a == 0` the result is
        /// presumably zero, since `v` starts at zero and `g` stays zero; confirm against
        /// the backend's `divstep` semantics.
        pub fn invert(a: Fe) Fe {
            // Divstep count that guarantees termination for `field_bits`-bit inputs.
            // Fields narrower than 46 bits need the larger additive constant, matching
            // fiat-crypto's inversion template.
            const iterations_bias = if (field_bits < 46) 80 else 57;
            const iterations = (49 * field_bits + iterations_bias) / 17;
            const Limbs = @TypeOf(a.limbs);
            const Word = @TypeOf(a.limbs[0]);
            // f and g carry one extra limb for the signed intermediate values.
            const XLimbs = [a.limbs.len + 1]Word;

            var d: Word = 1;
            // f starts as the field order (saturated form), computed at compile time.
            var f = comptime blk: {
                var f: XLimbs = undefined;
                fiat.msat(&f);
                break :blk f;
            };
            // g starts as `a` taken out of the Montgomery domain.
            var g: XLimbs = undefined;
            fiat.fromMontgomery(g[0..a.limbs.len], a.limbs);
            g[g.len - 1] = 0;

            var r = Fe.one.limbs;
            var v = Fe.zero.limbs;

            var out1: Word = undefined;
            var out2: XLimbs = undefined;
            var out3: XLimbs = undefined;
            var out4: Limbs = undefined;
            var out5: Limbs = undefined;

            // Ping-pong between the state and the out* buffers two steps at a time.
            var i: usize = 0;
            while (i < iterations - iterations % 2) : (i += 2) {
                fiat.divstep(&out1, &out2, &out3, &out4, &out5, d, f, g, v, r);
                fiat.divstep(&d, &f, &g, &v, &r, out1, out2, out3, out4, out5);
            }
            if (iterations % 2 != 0) {
                fiat.divstep(&out1, &out2, &out3, &out4, &out5, d, f, g, v, r);
                // Only v and f are needed after the loop.
                v = out4;
                f = out2;
            }
            // Negate v when f ended up negative (top bit of its most significant limb set).
            var v_opp: Limbs = undefined;
            fiat.opp(&v_opp, v);
            fiat.selectznz(&v, @truncate(u1, f[f.len - 1] >> (@bitSizeOf(Word) - 1)), v, v_opp);

            // Undo the scaling introduced by the divsteps.
            const precomp = blk: {
                var precomp: Limbs = undefined;
                fiat.divstepPrecomp(&precomp);
                break :blk precomp;
            };
            var fe: Fe = undefined;
            fiat.mul(&fe.limbs, v, precomp);
            return fe;
        }

        /// Return true if `x2` is a nonzero square in the field.
        ///
        /// Computes `x2^((field_order - 1) / 2)` and compares it with one. Dedicated
        /// addition chains are used for the P-256 and P-384 orders, and generic `pow`
        /// otherwise. Note: returns false for zero, since `0^e == 0 != 1`.
        pub fn isSquare(x2: Fe) bool {
            if (field_order == p256_order) {
                const t110 = x2.mul(x2.sq()).sq();
                const t111 = x2.mul(t110);
                const t111111 = t111.mul(x2.mul(t110).sqn(3));
                const x15 = t111111.sqn(6).mul(t111111).sqn(3).mul(t111);
                const x16 = x15.sq().mul(x2);
                const x53 = x16.sqn(16).mul(x16).sqn(15);
                const x47 = x15.mul(x53);
                const ls = x47.mul(((x53.sqn(17).mul(x2)).sqn(143).mul(x47)).sqn(47)).sq().mul(x2);
                return ls.equivalent(Fe.one);
            } else if (field_order == p384_order) {
                const t111 = x2.mul(x2.mul(x2.sq()).sq());
                const t111111 = t111.mul(t111.sqn(3));
                const t1111110 = t111111.sq();
                const t1111111 = x2.mul(t1111110);
                const x12 = t1111110.sqn(5).mul(t111111);
                const x31 = x12.sqn(12).mul(x12).sqn(7).mul(t1111111);
                const x32 = x31.sq().mul(x2);
                const x63 = x32.sqn(31).mul(x31);
                const x126 = x63.sqn(63).mul(x63);
                const ls = x126.sqn(126).mul(x126).sqn(3).mul(t111).sqn(33).mul(x32).sqn(95).mul(x31);
                return ls.equivalent(Fe.one);
            } else {
                // Legendre symbol via generic exponentiation.
                const ls = x2.pow(std.meta.Int(.unsigned, field_bits), (field_order - 1) / 2);
                return ls.equivalent(Fe.one);
            }
        }

        // Return `x2^((field_order + 1) / 4)`, a square root of `x2` when one exists.
        // The result is NOT checked; `sqrt` performs the check. Requires
        // `field_order % 4 == 3` (asserted at compile time).
        fn uncheckedSqrt(x2: Fe) Fe {
            comptime debug.assert(field_order % 4 == 3);
            if (field_order == p256_order) {
                const t11 = x2.mul(x2.sq());
                const t1111 = t11.mul(t11.sqn(2));
                const t11111111 = t1111.mul(t1111.sqn(4));
                const x16 = t11111111.sqn(8).mul(t11111111);
                return x16.sqn(16).mul(x16).sqn(32).mul(x2).sqn(96).mul(x2).sqn(94);
            } else if (field_order == p384_order) {
                const t111 = x2.mul(x2.mul(x2.sq()).sq());
                const t111111 = t111.mul(t111.sqn(3));
                const t1111110 = t111111.sq();
                const t1111111 = x2.mul(t1111110);
                const x12 = t1111110.sqn(5).mul(t111111);
                const x31 = x12.sqn(12).mul(x12).sqn(7).mul(t1111111);
                const x32 = x31.sq().mul(x2);
                const x63 = x32.sqn(31).mul(x31);
                const x126 = x63.sqn(63).mul(x63);
                return x126.sqn(126).mul(x126).sqn(3).mul(t111).sqn(33).mul(x32).sqn(64).mul(x2).sqn(30);
            } else if (field_order == secp256k1_order) {
                const t11 = x2.mul(x2.sq());
                const t1111 = t11.mul(t11.sqn(2));
                const t11111 = x2.mul(t1111.sq());
                const t1111111 = t11.mul(t11111.sqn(2));
                const x11 = t1111111.sqn(4).mul(t1111);
                const x22 = x11.sqn(11).mul(x11);
                const x27 = x22.sqn(5).mul(t11111);
                const x54 = x27.sqn(27).mul(x27);
                const x108 = x54.sqn(54).mul(x54);
                return x108.sqn(108).mul(x108).sqn(7).mul(t1111111).sqn(23).mul(x22).sqn(6).mul(t11).sqn(2);
            } else {
                return x2.pow(std.meta.Int(.unsigned, field_bits), (field_order + 1) / 4);
            }
        }

        /// Return a square root of `x2`, or `error.NotSquare` if `x2` has none.
        /// The candidate root is verified by squaring it and comparing with `x2`.
        pub fn sqrt(x2: Fe) NotSquareError!Fe {
            const x = x2.uncheckedSqrt();
            if (x.sq().equivalent(x2)) {
                return x;
            }
            return error.NotSquare;
        }
    };
}