Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 232 additions & 0 deletions orchard_pallas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
#!/usr/bin/env python3
# -*- coding: utf8 -*-
import sys; assert sys.version_info[0] >= 3, "Python 3 required."

from sapling_jubjub import FieldElement
from sapling_utils import leos2ip

p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001

pm1d2 = 0x2000000000000000000000000000000011234c7e04a67c8dcc96987680000000
assert (p - 1) // 2 == pm1d2

S = 32
T = 0x40000000000000000000000000000000224698fc094cf91b992d30ed
assert (p - 1) == (1 << S) * T

tm1d2 = 0x2000000000000000000000000000000011234c7e04a67c8dcc969876
assert (T - 1) // 2 == tm1d2

// 5^T (mod p)
ROOT_OF_UNITY = 0x2bce74deac30ebda362120830561f81aea322bf2b7bb7584bdad6fabd87ea32f


#
# Field arithmetic
#

class Fp(FieldElement):
@staticmethod
def from_bytes(buf):
return Fp(leos2ip(buf), strict=True)

def __init__(self, s, strict=False):
FieldElement.__init__(self, Fp, s, p, strict=strict)

def __str__(self):
return 'Fp(%s)' % self.s

def sqrt(self):
# Tonelli-Shank's algorithm for p mod 16 = 1
# https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
a = self.exp(pm1d2)
if a == self.ONE:
# z <- c^t
c = Fp(ROOT_OF_UNITY)
# x <- a \omega
x = self.exp(tm1d2 + 1)
# b <- x \omega = a \omega^2
b = self.exp(T)
y = S

# 7: while b != 1 do
while b != self.ONE:
# 8: Find least integer k >= 0 such that b^(2^k) == 1
k = 1
b2k = b * b
while b2k != self.ONE:
b2k = b2k * b2k
k += 1
assert k < y

# 9:
# w <- z^(2^(y-k-1))
for _ in range(0, y - k - 1):
c = c * c
# x <- xw
x = x * c
# z <- w^2
c = c * c
# b <- bz
b = b * c
# y <- k
y = k
assert x * x == self
return x
elif a == self.MINUS_ONE:
return None
return self.ZERO


class Scalar(FieldElement):
def __init__(self, s, strict=False):
FieldElement.__init__(self, Scalar, s, q, strict=strict)

def __str__(self):
return 'Scalar(%s)' % self.s

Fp.ZERO = Fp(0)
Fp.ONE = Fp(1)
Fp.MINUS_ONE = Fp(-1)

assert Fp.ZERO + Fp.ZERO == Fp.ZERO
assert Fp.ZERO + Fp.ONE == Fp.ONE
assert Fp.ONE + Fp.ZERO == Fp.ONE
assert Fp.ZERO - Fp.ONE == Fp.MINUS_ONE
assert Fp.ZERO * Fp.ONE == Fp.ZERO
assert Fp.ONE * Fp.ZERO == Fp.ZERO


#
# Point arithmetic
#

PALLAS_B = Fp(5)

class Point(object):
@staticmethod
def rand(rand):
while True:
data = rand.b(32)
p = Point.from_bytes(data)
if p is not None:
return p

@staticmethod
def from_bytes(buf):
assert len(buf) == 32
if buf == bytes([0]*32):
return Point.identity()

y_sign = buf[31] >> 7
buf = buf[:31] + bytes([buf[31] & 0b01111111])
try:
x = Fp.from_bytes(buf)
except ValueError:
return None

x3 = x * x * x
y2 = x3 + PALLAS_B

y = y2.sqrt()
if y is None:
return None

if y.s % 2 != y_sign:
y = Fp.ZERO - y

return Point(x, y)

def __init__(self, x, y):
self.x = x
self.y = y
self.is_identity = False
Comment on lines +141 to +144
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be cool to have an is_on_curve assertion here; y^2 == x^3 + b.


def identity():
p = Point(Fp.ZERO, Fp.ZERO)
p.is_identity = True
return p

def __neg__(self):
if self.is_identity:
return self
else:
return Point(Fp(self.x.s), -Fp(self.y.s))

def __add__(self, a):
if self.is_identity:
return a
elif a.is_identity:
return self
else:
(x1, y1) = (self.x, self.y)
(x2, y2) = (a.x, a.y)

if x1 != x2:
# <https://core.ac.uk/download/pdf/10898289.pdf> section 4.1
λ = (y1 - y2) / (x1 - x2)
x3 = λ*λ - x1 - x2
y3 = λ*(x1 - x3) - y1
return Point(x3, y3)
elif y1 == -y2:
return Point.identity()
else:
return self.double()

def __sub__(self, a):
return (-a) + self

def double(self):
if self.is_identity:
return self

# <https://core.ac.uk/download/pdf/10898289.pdf> section 4.1
λ = (Fp(3) * self.x * self.x) / (self.y + self.y)
x = λ*λ - self.x - self.x
y = λ*(self.x - x) - self.y
return Point(x, y)

def __mul__(self, s):
s = format(s.s, '0256b')
ret = self.ZERO
for c in s:
ret = ret.double()
if int(c):
ret = ret + self
return ret

def __bytes__(self):
if self.is_identity:
return bytes([0] * 32)

buf = bytes(self.x)
if self.y.s % 2 == 1:
buf = buf[:31] + bytes([buf[31] | (1 << 7)])
return buf

def __eq__(self, a):
if a is None:
return False
if not (self.is_identity or a.is_identity):
return self.x == a.x and self.y == a.y
else:
return self.is_identity == a.is_identity

def __str__(self):
if self.is_identity:
return 'Point(identity)'
else:
return 'Point(%s, %s)' % (self.x, self.y)


Point.ZERO = Point.identity()
Point.GENERATOR = Point(Fp.MINUS_ONE, Fp(2))

assert Point.ZERO + Point.ZERO == Point.ZERO
assert Point.GENERATOR - Point.GENERATOR == Point.ZERO
assert Point.GENERATOR + Point.GENERATOR + Point.GENERATOR == Point.GENERATOR * Scalar(3)
assert Point.GENERATOR + Point.GENERATOR - Point.GENERATOR == Point.GENERATOR
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A good test is assert (-Point.GENERATOR) == Point.GENERATOR * -Scalar(-1).


assert Point.from_bytes(bytes([0]*32)) == Point.ZERO
assert Point.from_bytes(bytes(Point.GENERATOR)) == Point.GENERATOR
3 changes: 3 additions & 0 deletions sapling_jubjub.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def __init__(self, t, s, modulus, strict=False):
self.s = s % modulus
self.m = modulus

def __neg__(self):
return self.t(-self.s)

def __add__(self, a):
return self.t(self.s + a.s)

Expand Down