const std = @import("std");
const fmt = std.fmt;
const EncodingError = std.crypto.errors.EncodingError;
const IdentityElementError = std.crypto.errors.IdentityElementError;
const NonCanonicalError = std.crypto.errors.NonCanonicalError;
const WeakPublicKeyError = std.crypto.errors.WeakPublicKeyError;
/// An element of the ristretto255 group, represented by a point of the
/// underlying Edwards25519 curve.
///
/// Each value wraps a single curve point `p`. Values should be compared with
/// `equivalent`, which checks cross-multiplied coordinates instead of requiring
/// the stored coordinates to be identical.
pub const Ristretto255 = struct {
/// The underlying Edwards25519 curve implementation.
pub const Curve = @import("edwards25519.zig").Edwards25519;
/// Field element type of the underlying curve.
pub const Fe = Curve.Fe;
/// Re-export of the underlying curve's `scalar` declaration
/// (presumably scalar-field arithmetic; see edwards25519.zig).
pub const scalar = Curve.scalar;
/// Length in bytes of an encoded group element.
pub const encoded_length: usize = 32;
/// Curve point representing this group element.
p: Curve,
/// Computes a candidate square root of the ratio `u/v`.
///
/// With `x` the initial candidate, `ratio_is_square` is 1 when `v*x^2 == u`
/// or `v*x^2 == -u`, and 0 otherwise. When `v*x^2` equals `-u` or
/// `-u*sqrt(-1)`, `x` is multiplied by `sqrt(-1)`. `root` is the resulting
/// `x` passed through `abs()`.
/// Selection uses `cMov` and bitwise OR rather than branches, presumably so
/// that control flow does not depend on the (possibly secret) inputs.
fn sqrtRatioM1(u: Fe, v: Fe) struct { ratio_is_square: u32, root: Fe } {
// v^3
const v3 = v.sq().mul(v);
// x = u*v^3 * (u*v^7)^e, where `pow2523` presumably raises to e = 2^252 - 3.
var x = v3.sq().mul(u).mul(v).pow2523().mul(v3).mul(u);
// v*x^2, compared below against u, -u and -u*sqrt(-1).
const vxx = x.sq().mul(v);
// v*x^2 - u
const m_root_check = vxx.sub(u);
// v*x^2 + u
const p_root_check = vxx.add(u);
// v*x^2 + u*sqrt(-1)
const f_root_check = u.mul(Fe.sqrtm1).add(vxx);
const has_m_root = m_root_check.isZero();
const has_p_root = p_root_check.isZero();
const has_f_root = f_root_check.isZero();
// x*sqrt(-1) is selected when v*x^2 matched -u or -u*sqrt(-1).
const x_sqrtm1 = x.mul(Fe.sqrtm1);
x.cMov(x_sqrtm1, @boolToInt(has_p_root) | @boolToInt(has_f_root));
return .{ .ratio_is_square = @boolToInt(has_m_root) | @boolToInt(has_p_root), .root = x.abs() };
}
/// Rejects an encoding that is not in canonical form.
///
/// Fails with `error.NonCanonical` if bit 0 of `s[0]` is set. Otherwise it
/// propagates any error from `Fe.rejectNonCanonical(s, false)`.
/// NOTE(review): the meaning of the `false` flag is defined in the field module.
fn rejectNonCanonical(s: [encoded_length]u8) NonCanonicalError!void {
if ((s[0] & 1) != 0) {
return error.NonCanonical;
}
try Fe.rejectNonCanonical(s, false);
}
/// Fails with an `IdentityElementError` when the underlying curve point is
/// rejected by `Curve.rejectIdentity`.
pub inline fn rejectIdentity(p: Ristretto255) IdentityElementError!void {
return p.p.rejectIdentity();
}
/// Base point of the group, wrapping `Curve.basePoint`.
pub const basePoint = Ristretto255{ .p = Curve.basePoint };
/// Decodes a 32-byte encoding into a group element.
///
/// Fails with a `NonCanonicalError` if `rejectNonCanonical` rejects `s`.
/// Fails with `error.InvalidEncoding` if any of these hold: the inverse
/// square root computed below does not exist, `t = x*y` is negative, or `y`
/// is zero. The returned point has `z = 1` and `t = x*y`.
pub fn fromBytes(s: [encoded_length]u8) (NonCanonicalError || EncodingError)!Ristretto255 {
try rejectNonCanonical(s);
// s as a field element
const s_ = Fe.fromBytes(s);
// s^2
const ss = s_.sq();
// u1 = 1 - s^2
const u1_ = Fe.one.sub(ss);
// u1^2
const u1u1 = u1_.sq();
// u2 = 1 + s^2
const u2_ = Fe.one.add(ss);
// u2^2
const u2u2 = u2_.sq();
// v = -(d*u1^2) - u2^2, with d = `Fe.edwards25519d`
const v = Fe.edwards25519d.mul(u1u1).neg().sub(u2u2);
// v*u2^2
const v_u2u2 = v.mul(u2u2);
// root ~ 1/sqrt(v*u2^2); `ratio_is_square` reports whether it exists.
const inv_sqrt = sqrtRatioM1(Fe.one, v_u2u2);
// root*u2
var x = inv_sqrt.root.mul(u2_);
// y = root * (root*u2) * v * u1
const y = inv_sqrt.root.mul(x).mul(v).mul(u1_);
// x = |2 * s * root * u2|
x = x.mul(s_);
x = x.add(x).abs();
const t = x.mul(y);
// All rejection conditions are OR-ed into one flag so that only a single
// branch is taken on the combined result.
if ((1 - inv_sqrt.ratio_is_square) | @boolToInt(t.isNegative()) | @boolToInt(y.isZero()) != 0) {
return error.InvalidEncoding;
}
const p: Curve = .{
.x = x,
.y = y,
.z = Fe.one,
.t = t,
};
return Ristretto255{ .p = p };
}
/// Encodes this group element into its 32-byte representation.
///
/// The result is `|(z - y') * den_inv|` serialized with `Fe.toBytes`.
/// Here `y'` and `den_inv` are chosen with `cMov`, based on the signs of
/// `t*z_inv` and `x'*z_inv` computed below.
pub fn toBytes(e: Ristretto255) [encoded_length]u8 {
const p = &e.p;
// u1 = (z + y) * (z - y)
var u1_ = p.z.add(p.y);
const zmy = p.z.sub(p.y);
u1_ = u1_.mul(zmy);
// u2 = x*y
const u2_ = p.x.mul(p.y);
// u1*u2^2
const u1_u2u2 = u2_.sq().mul(u1_);
// root ~ 1/sqrt(u1*u2^2); the square-ness flag is not used here.
const inv_sqrt = sqrtRatioM1(Fe.one, u1_u2u2);
const den1 = inv_sqrt.root.mul(u1_);
const den2 = inv_sqrt.root.mul(u2_);
// z_inv = den1 * den2 * t
const z_inv = den1.mul(den2).mul(p.t);
// x*sqrt(-1) and y*sqrt(-1), used when the point is rotated.
const ix = p.x.mul(Fe.sqrtm1);
const iy = p.y.mul(Fe.sqrtm1);
// den1 * `Fe.edwards25519sqrtamd` (presumably 1/sqrt(a-d) by its name)
const eden = den1.mul(Fe.edwards25519sqrtamd);
const t_z_inv = p.t.mul(z_inv);
// When t*z_inv is negative, replace (x, y, den_inv) with (iy, ix, eden).
const rotate = @boolToInt(t_z_inv.isNegative());
var x = p.x;
var y = p.y;
var den_inv = den2;
x.cMov(iy, rotate);
y.cMov(ix, rotate);
den_inv.cMov(eden, rotate);
// Negate y when x*z_inv is negative.
const x_z_inv = x.mul(z_inv);
const yneg = y.neg();
y.cMov(yneg, @boolToInt(x_z_inv.isNegative()));
return p.z.sub(y).mul(den_inv).abs().toBytes();
}
/// Maps a field element `t` to a point of the underlying curve
/// (presumably the ristretto255 Elligator map, as the name suggests).
///
/// The non-square case is handled with `cMov` on `wasnt_square` rather than
/// with a branch. The result is returned as extended coordinates
/// `(x, y, z, t)`.
fn elligator(t: Fe) Curve {
// r = sqrt(-1) * t^2
const r = t.sq().mul(Fe.sqrtm1);
// u = (r + 1) * `Fe.edwards25519eonemsqd` (presumably 1 - d^2 by its name)
const u = r.add(Fe.one).mul(Fe.edwards25519eonemsqd);
// c = -1, replaced by r below in the non-square case.
var c = comptime Fe.one.neg();
// v = (c - r*d) * (r + d)
const v = c.sub(r.mul(Fe.edwards25519d)).mul(r.add(Fe.edwards25519d));
const ratio_sqrt = sqrtRatioM1(u, v);
const wasnt_square = 1 - ratio_sqrt.ratio_is_square;
var s = ratio_sqrt.root;
// s' = -|s*t|
const s_prime = s.mul(t).abs().neg();
s.cMov(s_prime, wasnt_square);
c.cMov(r, wasnt_square);
// n = c*(r - 1)*`Fe.edwards25519sqdmone` - v (constant presumably (d-1)^2)
const n = r.sub(Fe.one).mul(c).mul(Fe.edwards25519sqdmone).sub(v);
// w0 = 2*s*v
const w0 = s.add(s).mul(v);
// w1 = n * `Fe.edwards25519sqrtadm1` (presumably sqrt(a*d - 1))
const w1 = n.mul(Fe.edwards25519sqrtadm1);
const ss = s.sq();
// w2 = 1 - s^2, w3 = 1 + s^2
const w2 = Fe.one.sub(ss);
const w3 = Fe.one.add(ss);
return .{ .x = w0.mul(w3), .y = w2.mul(w1), .z = w1.mul(w3), .t = w0.mul(w2) };
}
/// Maps 64 bytes to a group element.
///
/// Each 32-byte half is decoded with `Fe.fromBytes` and mapped through
/// `elligator`, and the two resulting curve points are added.
/// NOTE(review): callers presumably pass uniformly random bytes (e.g. a wide
/// hash output), as the name suggests.
pub fn fromUniform(h: [64]u8) Ristretto255 {
const p0 = elligator(Fe.fromBytes(h[0..32].*));
const p1 = elligator(Fe.fromBytes(h[32..64].*));
return Ristretto255{ .p = p0.add(p1) };
}
/// Returns `p + p` by doubling the underlying curve point.
pub inline fn dbl(p: Ristretto255) Ristretto255 {
return .{ .p = p.p.dbl() };
}
/// Returns `p + q` by adding the underlying curve points.
pub inline fn add(p: Ristretto255, q: Ristretto255) Ristretto255 {
return .{ .p = p.p.add(q.p) };
}
/// Multiplies `p` by the scalar encoded in `s` (little-endian) via `Curve.mul`.
/// Propagates its `IdentityElementError` or `WeakPublicKeyError`.
pub inline fn mul(p: Ristretto255, s: [encoded_length]u8) (IdentityElementError || WeakPublicKeyError)!Ristretto255 {
return Ristretto255{ .p = try p.p.mul(s) };
}
/// Returns whether `p` and `q` represent the same group element.
///
/// Returns true if either `x_p*y_q == y_p*x_q` or `y_p*y_q == x_p*x_q`.
/// Both comparisons are always computed and combined with a bitwise OR
/// instead of a short-circuiting `or`.
pub fn equivalent(p: Ristretto255, q: Ristretto255) bool {
const p_ = &p.p;
const q_ = &q.p;
const a = p_.x.mul(q_.y).equivalent(p_.y.mul(q_.x));
const b = p_.y.mul(q_.y).equivalent(p_.x.mul(q_.x));
return (@boolToInt(a) | @boolToInt(b)) != 0;
}
};
test "ristretto255" {
    // Checks that `encoded` renders as the uppercase hex string `expected_hex`.
    // `std.testing.expectEqualStrings` takes (expected, actual); passing them in
    // this order makes failure diagnostics label the two strings correctly.
    const expectEncoding = struct {
        fn check(expected_hex: []const u8, encoded: [Ristretto255.encoded_length]u8) !void {
            var hex_buf: [2 * Ristretto255.encoded_length]u8 = undefined;
            const actual_hex = try std.fmt.bufPrint(&hex_buf, "{s}", .{std.fmt.fmtSliceHexUpper(&encoded)});
            try std.testing.expectEqualStrings(expected_hex, actual_hex);
        }
    }.check;

    // Encoding of the base point.
    const p = Ristretto255.basePoint;
    try expectEncoding("E2F2AE0A6ABC4E71A884A961C500515F58E30B6AA582DD8DB6A65945E08D2D76", p.toBytes());

    // Decode a known encoding, then compute 2*q + p and check its encoding.
    var r: [Ristretto255.encoded_length]u8 = undefined;
    _ = try fmt.hexToBytes(r[0..], "6a493210f7499cd17fecb510ae0cea23a110e8d5b901f8acadd3095c73a3b919");
    const q = (try Ristretto255.fromBytes(r)).dbl().add(p);
    try expectEncoding("E882B131016B52C1D3337080187CF768423EFCCBB517BB495AB812C4160FF44E", q.toBytes());

    // Decoding an encoding and re-encoding the result must round-trip.
    const q_bytes = q.toBytes();
    const q_again = try Ristretto255.fromBytes(q_bytes);
    try std.testing.expect(q_again.equivalent(q));
    try std.testing.expectEqualSlices(u8, &q_bytes, &q_again.toBytes());

    // An encoding whose lowest bit is set must be rejected as non-canonical.
    var odd = p.toBytes();
    odd[0] |= 1;
    try std.testing.expectError(error.NonCanonical, Ristretto255.fromBytes(odd));

    // Scalar multiplication by 15 (little-endian scalar encoding).
    const s = [_]u8{15} ++ [_]u8{0} ** 31;
    const w = try p.mul(s);
    try expectEncoding("E0C418F7C8D9C4CDD7395B93EA124F3AD99021BB681DFC3302A9D99A2E53E64E", w.toBytes());

    // 16*p via four doublings must equal 15*p + p.
    try std.testing.expect(p.dbl().dbl().dbl().dbl().equivalent(w.add(p)));

    // Mapping 64 fixed bytes to a group element.
    const h = [_]u8{69} ** 32 ++ [_]u8{42} ** 32;
    const ph = Ristretto255.fromUniform(h);
    try expectEncoding("DCCA54E037A4311EFBEEF413ACD21D35276518970B7A61DC88F8587B493D5E19", ph.toBytes());
}