added dependencies

This commit is contained in:
Ayush Saini 2023-06-24 18:08:00 +05:30
parent a070124d1d
commit 48f1f4be1a
32 changed files with 16613 additions and 0 deletions

View file

@ -0,0 +1,90 @@
# while we don't use six in this file, we did bundle it for a long time, so
# keep as part of module in a virtual way (through __all__)
import six
from .keys import (
SigningKey,
VerifyingKey,
BadSignatureError,
BadDigestError,
MalformedPointError,
)
from .curves import (
NIST192p,
NIST224p,
NIST256p,
NIST384p,
NIST521p,
SECP256k1,
BRAINPOOLP160r1,
BRAINPOOLP192r1,
BRAINPOOLP224r1,
BRAINPOOLP256r1,
BRAINPOOLP320r1,
BRAINPOOLP384r1,
BRAINPOOLP512r1,
SECP112r1,
SECP112r2,
SECP128r1,
SECP160r1,
Ed25519,
Ed448,
)
from .ecdh import (
ECDH,
NoKeyError,
NoCurveError,
InvalidCurveError,
InvalidSharedSecretError,
)
from .der import UnexpectedDER
from . import _version
# This code comes from http://github.com/tlsfuzzer/python-ecdsa
__all__ = [
"curves",
"der",
"ecdsa",
"ellipticcurve",
"keys",
"numbertheory",
"test_pyecdsa",
"util",
"six",
]
_hush_pyflakes = [
SigningKey,
VerifyingKey,
BadSignatureError,
BadDigestError,
MalformedPointError,
UnexpectedDER,
InvalidCurveError,
NoKeyError,
InvalidSharedSecretError,
ECDH,
NoCurveError,
NIST192p,
NIST224p,
NIST256p,
NIST384p,
NIST521p,
SECP256k1,
BRAINPOOLP160r1,
BRAINPOOLP192r1,
BRAINPOOLP224r1,
BRAINPOOLP256r1,
BRAINPOOLP320r1,
BRAINPOOLP384r1,
BRAINPOOLP512r1,
SECP112r1,
SECP112r2,
SECP128r1,
SECP160r1,
Ed25519,
Ed448,
six.b(""),
]
del _hush_pyflakes
__version__ = _version.get_versions()["version"]

View file

@ -0,0 +1,153 @@
"""
Common functions for providing cross-python version compatibility.
"""
import sys
import re
import binascii
from six import integer_types
def str_idx_as_int(string, index):
"""Take index'th byte from string, return as integer"""
val = string[index]
if isinstance(val, integer_types):
return val
return ord(val)
if sys.version_info < (3, 0): # pragma: no branch
import platform
def normalise_bytes(buffer_object):
"""Cast the input into array of bytes."""
# flake8 runs on py3 where `buffer` indeed doesn't exist...
return buffer(buffer_object) # noqa: F821
def hmac_compat(ret):
return ret
if (
sys.version_info < (2, 7)
or sys.version_info < (2, 7, 4)
or platform.system() == "Java"
): # pragma: no branch
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
return re.sub(r"\s+", "", text)
def compat26_str(val):
return str(val)
def bit_length(val):
if val == 0:
return 0
return len(bin(val)) - 2
else:
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
return re.sub(r"\s+", "", text, flags=re.UNICODE)
def compat26_str(val):
return val
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
return val.bit_length()
def b2a_hex(val):
return binascii.b2a_hex(compat26_str(val))
def a2b_hex(val):
try:
return bytearray(binascii.a2b_hex(val))
except Exception as e:
raise ValueError("base16 error: %s" % e)
def bytes_to_int(val, byteorder):
"""Convert bytes to an int."""
if not val:
return 0
if byteorder == "big":
return int(b2a_hex(val), 16)
if byteorder == "little":
return int(b2a_hex(val[::-1]), 16)
raise ValueError("Only 'big' and 'little' endian supported")
def int_to_bytes(val, length=None, byteorder="big"):
"""Return number converted to bytes"""
if length is None:
length = byte_length(val)
if byteorder == "big":
return bytearray(
(val >> i) & 0xFF for i in reversed(range(0, length * 8, 8))
)
if byteorder == "little":
return bytearray(
(val >> i) & 0xFF for i in range(0, length * 8, 8)
)
raise ValueError("Only 'big' or 'little' endian supported")
else:
if sys.version_info < (3, 4): # pragma: no branch
# on python 3.3 hmac.hmac.update() accepts only bytes, on newer
# versions it does accept memoryview() also
def hmac_compat(data):
if not isinstance(data, bytes): # pragma: no branch
return bytes(data)
return data
def normalise_bytes(buffer_object):
"""Cast the input into array of bytes."""
if not buffer_object:
return b""
return memoryview(buffer_object).cast("B")
else:
def hmac_compat(data):
return data
def normalise_bytes(buffer_object):
"""Cast the input into array of bytes."""
return memoryview(buffer_object).cast("B")
def compat26_str(val):
return val
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
return re.sub(r"\s+", "", text, flags=re.UNICODE)
def a2b_hex(val):
try:
return bytearray(binascii.a2b_hex(bytearray(val, "ascii")))
except Exception as e:
raise ValueError("base16 error: %s" % e)
# pylint: disable=invalid-name
# pylint is stupid here and doesn't notice it's a function, not
# constant
bytes_to_int = int.from_bytes
# pylint: enable=invalid-name
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
return val.bit_length()
def int_to_bytes(val, length=None, byteorder="big"):
"""Convert integer to bytes."""
if length is None:
length = byte_length(val)
# for gmpy we need to convert back to native int
if type(val) != int:
val = int(val)
return bytearray(val.to_bytes(length=length, byteorder=byteorder))
def byte_length(val):
"""Return number of bytes necessary to represent an integer."""
length = bit_length(val)
return (length + 7) // 8

View file

@ -0,0 +1,86 @@
# Copyright Mateusz Kobos, (c) 2011
# https://code.activestate.com/recipes/577803-reader-writer-lock-with-priority-for-writers/
# released under the MIT licence
import threading
__author__ = "Mateusz Kobos"
class RWLock:
"""
Read-Write locking primitive
Synchronization object used in a solution of so-called second
readers-writers problem. In this problem, many readers can simultaneously
access a share, and a writer has an exclusive access to this share.
Additionally, the following constraints should be met:
1) no reader should be kept waiting if the share is currently opened for
reading unless a writer is also waiting for the share,
2) no writer should be kept waiting for the share longer than absolutely
necessary.
The implementation is based on [1, secs. 4.2.2, 4.2.6, 4.2.7]
with a modification -- adding an additional lock (C{self.__readers_queue})
-- in accordance with [2].
Sources:
[1] A.B. Downey: "The little book of semaphores", Version 2.1.5, 2008
[2] P.J. Courtois, F. Heymans, D.L. Parnas:
"Concurrent Control with 'Readers' and 'Writers'",
Communications of the ACM, 1971 (via [3])
[3] http://en.wikipedia.org/wiki/Readers-writers_problem
"""
def __init__(self):
"""
A lock giving an even higher priority to the writer in certain
cases (see [2] for a discussion).
"""
self.__read_switch = _LightSwitch()
self.__write_switch = _LightSwitch()
self.__no_readers = threading.Lock()
self.__no_writers = threading.Lock()
self.__readers_queue = threading.Lock()
def reader_acquire(self):
self.__readers_queue.acquire()
self.__no_readers.acquire()
self.__read_switch.acquire(self.__no_writers)
self.__no_readers.release()
self.__readers_queue.release()
def reader_release(self):
self.__read_switch.release(self.__no_writers)
def writer_acquire(self):
self.__write_switch.acquire(self.__no_readers)
self.__no_writers.acquire()
def writer_release(self):
self.__no_writers.release()
self.__write_switch.release(self.__no_readers)
class _LightSwitch:
"""An auxiliary "light switch"-like object. The first thread turns on the
"switch", the last one turns it off (see [1, sec. 4.2.2] for details)."""
def __init__(self):
self.__counter = 0
self.__mutex = threading.Lock()
def acquire(self, lock):
self.__mutex.acquire()
self.__counter += 1
if self.__counter == 1:
lock.acquire()
self.__mutex.release()
def release(self, lock):
self.__mutex.acquire()
self.__counter -= 1
if self.__counter == 0:
lock.release()
self.__mutex.release()

View file

@ -0,0 +1,181 @@
"""
Implementation of the SHAKE-256 algorithm for Ed448
"""
try:
import hashlib
hashlib.new("shake256").digest(64)
def shake_256(msg, outlen):
return hashlib.new("shake256", msg).digest(outlen)
except (TypeError, ValueError):
from ._compat import bytes_to_int, int_to_bytes
# From little endian.
def _from_le(s):
return bytes_to_int(s, byteorder="little")
# Rotate a word x by b places to the left.
def _rol(x, b):
return ((x << b) | (x >> (64 - b))) & (2**64 - 1)
# Do the SHA-3 state transform on state s.
def _sha3_transform(s):
ROTATIONS = [
0,
1,
62,
28,
27,
36,
44,
6,
55,
20,
3,
10,
43,
25,
39,
41,
45,
15,
21,
8,
18,
2,
61,
56,
14,
]
PERMUTATION = [
1,
6,
9,
22,
14,
20,
2,
12,
13,
19,
23,
15,
4,
24,
21,
8,
16,
5,
3,
18,
17,
11,
7,
10,
]
RC = [
0x0000000000000001,
0x0000000000008082,
0x800000000000808A,
0x8000000080008000,
0x000000000000808B,
0x0000000080000001,
0x8000000080008081,
0x8000000000008009,
0x000000000000008A,
0x0000000000000088,
0x0000000080008009,
0x000000008000000A,
0x000000008000808B,
0x800000000000008B,
0x8000000000008089,
0x8000000000008003,
0x8000000000008002,
0x8000000000000080,
0x000000000000800A,
0x800000008000000A,
0x8000000080008081,
0x8000000000008080,
0x0000000080000001,
0x8000000080008008,
]
for rnd in range(0, 24):
# AddColumnParity (Theta)
c = [0] * 5
d = [0] * 5
for i in range(0, 25):
c[i % 5] ^= s[i]
for i in range(0, 5):
d[i] = c[(i + 4) % 5] ^ _rol(c[(i + 1) % 5], 1)
for i in range(0, 25):
s[i] ^= d[i % 5]
# RotateWords (Rho)
for i in range(0, 25):
s[i] = _rol(s[i], ROTATIONS[i])
# PermuteWords (Pi)
t = s[PERMUTATION[0]]
for i in range(0, len(PERMUTATION) - 1):
s[PERMUTATION[i]] = s[PERMUTATION[i + 1]]
s[PERMUTATION[-1]] = t
# NonlinearMixRows (Chi)
for i in range(0, 25, 5):
t = [
s[i],
s[i + 1],
s[i + 2],
s[i + 3],
s[i + 4],
s[i],
s[i + 1],
]
for j in range(0, 5):
s[i + j] = t[j] ^ ((~t[j + 1]) & (t[j + 2]))
# AddRoundConstant (Iota)
s[0] ^= RC[rnd]
# Reinterpret octet array b to word array and XOR it to state s.
def _reinterpret_to_words_and_xor(s, b):
for j in range(0, len(b) // 8):
s[j] ^= _from_le(b[8 * j : 8 * j + 8])
# Reinterpret word array w to octet array and return it.
def _reinterpret_to_octets(w):
mp = bytearray()
for j in range(0, len(w)):
mp += int_to_bytes(w[j], 8, byteorder="little")
return mp
def _sha3_raw(msg, r_w, o_p, e_b):
"""Semi-generic SHA-3 implementation"""
r_b = 8 * r_w
s = [0] * 25
# Handle whole blocks.
idx = 0
blocks = len(msg) // r_b
for i in range(0, blocks):
_reinterpret_to_words_and_xor(s, msg[idx : idx + r_b])
idx += r_b
_sha3_transform(s)
# Handle last block padding.
m = bytearray(msg[idx:])
m.append(o_p)
while len(m) < r_b:
m.append(0)
m[len(m) - 1] |= 128
# Handle padded last block.
_reinterpret_to_words_and_xor(s, m)
_sha3_transform(s)
# Output.
out = bytearray()
while len(out) < e_b:
out += _reinterpret_to_octets(s[:r_w])
_sha3_transform(s)
return out[:e_b]
def shake_256(msg, outlen):
return _sha3_raw(msg, 17, 31, outlen)

View file

@ -0,0 +1,21 @@
# This file was generated by 'versioneer.py' (0.21) from
# revision-control system data, or from the parent directory name of an
# unpacked source archive. Distribution tarballs contain a pre-generated copy
# of this file.
import json
version_json = '''
{
"date": "2022-07-09T14:49:17+0200",
"dirty": false,
"error": null,
"full-revisionid": "341e0d8be9fedf66fbc9a95630b4ed2138343380",
"version": "0.18.0"
}
''' # END VERSION_JSON
def get_versions():
return json.loads(version_json)

View file

@ -0,0 +1,513 @@
from __future__ import division
from six import PY2
from . import der, ecdsa, ellipticcurve, eddsa
from .util import orderlen, number_to_string, string_to_number
from ._compat import normalise_bytes, bit_length
# orderlen was defined in this module previously, so keep it in __all__,
# will need to mark it as deprecated later
__all__ = [
"UnknownCurveError",
"orderlen",
"Curve",
"SECP112r1",
"SECP112r2",
"SECP128r1",
"SECP160r1",
"NIST192p",
"NIST224p",
"NIST256p",
"NIST384p",
"NIST521p",
"curves",
"find_curve",
"curve_by_name",
"SECP256k1",
"BRAINPOOLP160r1",
"BRAINPOOLP192r1",
"BRAINPOOLP224r1",
"BRAINPOOLP256r1",
"BRAINPOOLP320r1",
"BRAINPOOLP384r1",
"BRAINPOOLP512r1",
"PRIME_FIELD_OID",
"CHARACTERISTIC_TWO_FIELD_OID",
"Ed25519",
"Ed448",
]
PRIME_FIELD_OID = (1, 2, 840, 10045, 1, 1)
CHARACTERISTIC_TWO_FIELD_OID = (1, 2, 840, 10045, 1, 2)
class UnknownCurveError(Exception):
pass
class Curve:
def __init__(self, name, curve, generator, oid, openssl_name=None):
self.name = name
self.openssl_name = openssl_name # maybe None
self.curve = curve
self.generator = generator
self.order = generator.order()
if isinstance(curve, ellipticcurve.CurveEdTw):
# EdDSA keys are special in that both private and public
# are the same size (as it's defined only with compressed points)
# +1 for the sign bit and then round up
self.baselen = (bit_length(curve.p()) + 1 + 7) // 8
self.verifying_key_length = self.baselen
else:
self.baselen = orderlen(self.order)
self.verifying_key_length = 2 * orderlen(curve.p())
self.signature_length = 2 * self.baselen
self.oid = oid
if oid:
self.encoded_oid = der.encode_oid(*oid)
def __eq__(self, other):
if isinstance(other, Curve):
return (
self.curve == other.curve and self.generator == other.generator
)
return NotImplemented
def __ne__(self, other):
return not self == other
def __repr__(self):
return self.name
def to_der(self, encoding=None, point_encoding="uncompressed"):
"""Serialise the curve parameters to binary string.
:param str encoding: the format to save the curve parameters in.
Default is ``named_curve``, with fallback being the ``explicit``
if the OID is not set for the curve.
:param str point_encoding: the point encoding of the generator when
explicit curve encoding is used. Ignored for ``named_curve``
format.
:return: DER encoded ECParameters structure
:rtype: bytes
"""
if encoding is None:
if self.oid:
encoding = "named_curve"
else:
encoding = "explicit"
if encoding not in ("named_curve", "explicit"):
raise ValueError(
"Only 'named_curve' and 'explicit' encodings supported"
)
if encoding == "named_curve":
if not self.oid:
raise UnknownCurveError(
"Can't encode curve using named_curve encoding without "
"associated curve OID"
)
return der.encode_oid(*self.oid)
elif isinstance(self.curve, ellipticcurve.CurveEdTw):
assert encoding == "explicit"
raise UnknownCurveError(
"Twisted Edwards curves don't support explicit encoding"
)
# encode the ECParameters sequence
curve_p = self.curve.p()
version = der.encode_integer(1)
field_id = der.encode_sequence(
der.encode_oid(*PRIME_FIELD_OID), der.encode_integer(curve_p)
)
curve = der.encode_sequence(
der.encode_octet_string(
number_to_string(self.curve.a() % curve_p, curve_p)
),
der.encode_octet_string(
number_to_string(self.curve.b() % curve_p, curve_p)
),
)
base = der.encode_octet_string(self.generator.to_bytes(point_encoding))
order = der.encode_integer(self.generator.order())
seq_elements = [version, field_id, curve, base, order]
if self.curve.cofactor():
cofactor = der.encode_integer(self.curve.cofactor())
seq_elements.append(cofactor)
return der.encode_sequence(*seq_elements)
def to_pem(self, encoding=None, point_encoding="uncompressed"):
"""
Serialise the curve parameters to the :term:`PEM` format.
:param str encoding: the format to save the curve parameters in.
Default is ``named_curve``, with fallback being the ``explicit``
if the OID is not set for the curve.
:param str point_encoding: the point encoding of the generator when
explicit curve encoding is used. Ignored for ``named_curve``
format.
:return: PEM encoded ECParameters structure
:rtype: str
"""
return der.topem(
self.to_der(encoding, point_encoding), "EC PARAMETERS"
)
@staticmethod
def from_der(data, valid_encodings=None):
"""Decode the curve parameters from DER file.
:param data: the binary string to decode the parameters from
:type data: :term:`bytes-like object`
:param valid_encodings: set of names of allowed encodings, by default
all (set by passing ``None``), supported ones are ``named_curve``
and ``explicit``
:type valid_encodings: :term:`set-like object`
"""
if not valid_encodings:
valid_encodings = set(("named_curve", "explicit"))
if not all(i in ["named_curve", "explicit"] for i in valid_encodings):
raise ValueError(
"Only named_curve and explicit encodings supported"
)
data = normalise_bytes(data)
if not der.is_sequence(data):
if "named_curve" not in valid_encodings:
raise der.UnexpectedDER(
"named_curve curve parameters not allowed"
)
oid, empty = der.remove_object(data)
if empty:
raise der.UnexpectedDER("Unexpected data after OID")
return find_curve(oid)
if "explicit" not in valid_encodings:
raise der.UnexpectedDER("explicit curve parameters not allowed")
seq, empty = der.remove_sequence(data)
if empty:
raise der.UnexpectedDER(
"Unexpected data after ECParameters structure"
)
# decode the ECParameters sequence
version, rest = der.remove_integer(seq)
if version != 1:
raise der.UnexpectedDER("Unknown parameter encoding format")
field_id, rest = der.remove_sequence(rest)
curve, rest = der.remove_sequence(rest)
base_bytes, rest = der.remove_octet_string(rest)
order, rest = der.remove_integer(rest)
cofactor = None
if rest:
# the ASN.1 specification of ECParameters allows for future
# extensions of the sequence, so ignore the remaining bytes
cofactor, _ = der.remove_integer(rest)
# decode the ECParameters.fieldID sequence
field_type, rest = der.remove_object(field_id)
if field_type == CHARACTERISTIC_TWO_FIELD_OID:
raise UnknownCurveError("Characteristic 2 curves unsupported")
if field_type != PRIME_FIELD_OID:
raise UnknownCurveError(
"Unknown field type: {0}".format(field_type)
)
prime, empty = der.remove_integer(rest)
if empty:
raise der.UnexpectedDER(
"Unexpected data after ECParameters.fieldID.Prime-p element"
)
# decode the ECParameters.curve sequence
curve_a_bytes, rest = der.remove_octet_string(curve)
curve_b_bytes, rest = der.remove_octet_string(rest)
# seed can be defined here, but we don't parse it, so ignore `rest`
curve_a = string_to_number(curve_a_bytes)
curve_b = string_to_number(curve_b_bytes)
curve_fp = ellipticcurve.CurveFp(prime, curve_a, curve_b, cofactor)
# decode the ECParameters.base point
base = ellipticcurve.PointJacobi.from_bytes(
curve_fp,
base_bytes,
valid_encodings=("uncompressed", "compressed", "hybrid"),
order=order,
generator=True,
)
tmp_curve = Curve("unknown", curve_fp, base, None)
# if the curve matches one of the well-known ones, use the well-known
# one in preference, as it will have the OID and name associated
for i in curves:
if tmp_curve == i:
return i
return tmp_curve
@classmethod
def from_pem(cls, string, valid_encodings=None):
"""Decode the curve parameters from PEM file.
:param str string: the text string to decode the parameters from
:param valid_encodings: set of names of allowed encodings, by default
all (set by passing ``None``), supported ones are ``named_curve``
and ``explicit``
:type valid_encodings: :term:`set-like object`
"""
if not PY2 and isinstance(string, str): # pragma: no branch
string = string.encode()
ec_param_index = string.find(b"-----BEGIN EC PARAMETERS-----")
if ec_param_index == -1:
raise der.UnexpectedDER("EC PARAMETERS PEM header not found")
return cls.from_der(
der.unpem(string[ec_param_index:]), valid_encodings
)
# the SEC curves
SECP112r1 = Curve(
"SECP112r1",
ecdsa.curve_112r1,
ecdsa.generator_112r1,
(1, 3, 132, 0, 6),
"secp112r1",
)
SECP112r2 = Curve(
"SECP112r2",
ecdsa.curve_112r2,
ecdsa.generator_112r2,
(1, 3, 132, 0, 7),
"secp112r2",
)
SECP128r1 = Curve(
"SECP128r1",
ecdsa.curve_128r1,
ecdsa.generator_128r1,
(1, 3, 132, 0, 28),
"secp128r1",
)
SECP160r1 = Curve(
"SECP160r1",
ecdsa.curve_160r1,
ecdsa.generator_160r1,
(1, 3, 132, 0, 8),
"secp160r1",
)
# the NIST curves
NIST192p = Curve(
"NIST192p",
ecdsa.curve_192,
ecdsa.generator_192,
(1, 2, 840, 10045, 3, 1, 1),
"prime192v1",
)
NIST224p = Curve(
"NIST224p",
ecdsa.curve_224,
ecdsa.generator_224,
(1, 3, 132, 0, 33),
"secp224r1",
)
NIST256p = Curve(
"NIST256p",
ecdsa.curve_256,
ecdsa.generator_256,
(1, 2, 840, 10045, 3, 1, 7),
"prime256v1",
)
NIST384p = Curve(
"NIST384p",
ecdsa.curve_384,
ecdsa.generator_384,
(1, 3, 132, 0, 34),
"secp384r1",
)
NIST521p = Curve(
"NIST521p",
ecdsa.curve_521,
ecdsa.generator_521,
(1, 3, 132, 0, 35),
"secp521r1",
)
SECP256k1 = Curve(
"SECP256k1",
ecdsa.curve_secp256k1,
ecdsa.generator_secp256k1,
(1, 3, 132, 0, 10),
"secp256k1",
)
BRAINPOOLP160r1 = Curve(
"BRAINPOOLP160r1",
ecdsa.curve_brainpoolp160r1,
ecdsa.generator_brainpoolp160r1,
(1, 3, 36, 3, 3, 2, 8, 1, 1, 1),
"brainpoolP160r1",
)
BRAINPOOLP192r1 = Curve(
"BRAINPOOLP192r1",
ecdsa.curve_brainpoolp192r1,
ecdsa.generator_brainpoolp192r1,
(1, 3, 36, 3, 3, 2, 8, 1, 1, 3),
"brainpoolP192r1",
)
BRAINPOOLP224r1 = Curve(
"BRAINPOOLP224r1",
ecdsa.curve_brainpoolp224r1,
ecdsa.generator_brainpoolp224r1,
(1, 3, 36, 3, 3, 2, 8, 1, 1, 5),
"brainpoolP224r1",
)
BRAINPOOLP256r1 = Curve(
"BRAINPOOLP256r1",
ecdsa.curve_brainpoolp256r1,
ecdsa.generator_brainpoolp256r1,
(1, 3, 36, 3, 3, 2, 8, 1, 1, 7),
"brainpoolP256r1",
)
BRAINPOOLP320r1 = Curve(
"BRAINPOOLP320r1",
ecdsa.curve_brainpoolp320r1,
ecdsa.generator_brainpoolp320r1,
(1, 3, 36, 3, 3, 2, 8, 1, 1, 9),
"brainpoolP320r1",
)
BRAINPOOLP384r1 = Curve(
"BRAINPOOLP384r1",
ecdsa.curve_brainpoolp384r1,
ecdsa.generator_brainpoolp384r1,
(1, 3, 36, 3, 3, 2, 8, 1, 1, 11),
"brainpoolP384r1",
)
BRAINPOOLP512r1 = Curve(
"BRAINPOOLP512r1",
ecdsa.curve_brainpoolp512r1,
ecdsa.generator_brainpoolp512r1,
(1, 3, 36, 3, 3, 2, 8, 1, 1, 13),
"brainpoolP512r1",
)
Ed25519 = Curve(
"Ed25519",
eddsa.curve_ed25519,
eddsa.generator_ed25519,
(1, 3, 101, 112),
)
Ed448 = Curve(
"Ed448",
eddsa.curve_ed448,
eddsa.generator_ed448,
(1, 3, 101, 113),
)
# no order in particular, but keep previously added curves first
curves = [
NIST192p,
NIST224p,
NIST256p,
NIST384p,
NIST521p,
SECP256k1,
BRAINPOOLP160r1,
BRAINPOOLP192r1,
BRAINPOOLP224r1,
BRAINPOOLP256r1,
BRAINPOOLP320r1,
BRAINPOOLP384r1,
BRAINPOOLP512r1,
SECP112r1,
SECP112r2,
SECP128r1,
SECP160r1,
Ed25519,
Ed448,
]
def find_curve(oid_curve):
"""Select a curve based on its OID
:param tuple[int,...] oid_curve: ASN.1 Object Identifier of the
curve to return, like ``(1, 2, 840, 10045, 3, 1, 7)`` for ``NIST256p``.
:raises UnknownCurveError: When the oid doesn't match any of the supported
curves
:rtype: ~ecdsa.curves.Curve
"""
for c in curves:
if c.oid == oid_curve:
return c
raise UnknownCurveError(
"I don't know about the curve with oid %s."
"I only know about these: %s" % (oid_curve, [c.name for c in curves])
)
def curve_by_name(name):
"""Select a curve based on its name.
Returns a :py:class:`~ecdsa.curves.Curve` object with a ``name`` name.
Note that ``name`` is case-sensitve.
:param str name: Name of the curve to return, like ``NIST256p`` or
``prime256v1``
:raises UnknownCurveError: When the name doesn't match any of the supported
curves
:rtype: ~ecdsa.curves.Curve
"""
for c in curves:
if name == c.name or (c.openssl_name and name == c.openssl_name):
return c
raise UnknownCurveError(
"Curve with name {0!r} unknown, only curves supported: {1}".format(
name, [c.name for c in curves]
)
)

View file

@ -0,0 +1,409 @@
from __future__ import division
import binascii
import base64
import warnings
from itertools import chain
from six import int2byte, b, text_type
from ._compat import str_idx_as_int
class UnexpectedDER(Exception):
pass
def encode_constructed(tag, value):
return int2byte(0xA0 + tag) + encode_length(len(value)) + value
def encode_integer(r):
assert r >= 0 # can't support negative numbers yet
h = ("%x" % r).encode()
if len(h) % 2:
h = b("0") + h
s = binascii.unhexlify(h)
num = str_idx_as_int(s, 0)
if num <= 0x7F:
return b("\x02") + encode_length(len(s)) + s
else:
# DER integers are two's complement, so if the first byte is
# 0x80-0xff then we need an extra 0x00 byte to prevent it from
# looking negative.
return b("\x02") + encode_length(len(s) + 1) + b("\x00") + s
# sentry object to check if an argument was specified (used to detect
# deprecated calling convention)
_sentry = object()
def encode_bitstring(s, unused=_sentry):
"""
Encode a binary string as a BIT STRING using :term:`DER` encoding.
Note, because there is no native Python object that can encode an actual
bit string, this function only accepts byte strings as the `s` argument.
The byte string is the actual bit string that will be encoded, padded
on the right (least significant bits, looking from big endian perspective)
to the first full byte. If the bit string has a bit length that is multiple
of 8, then the padding should not be included. For correct DER encoding
the padding bits MUST be set to 0.
Number of bits of padding need to be provided as the `unused` parameter.
In case they are specified as None, it means the number of unused bits
is already encoded in the string as the first byte.
The deprecated call convention specifies just the `s` parameters and
encodes the number of unused bits as first parameter (same convention
as with None).
Empty string must be encoded with `unused` specified as 0.
Future version of python-ecdsa will make specifying the `unused` argument
mandatory.
:param s: bytes to encode
:type s: bytes like object
:param unused: number of bits at the end of `s` that are unused, must be
between 0 and 7 (inclusive)
:type unused: int or None
:raises ValueError: when `unused` is too large or too small
:return: `s` encoded using DER
:rtype: bytes
"""
encoded_unused = b""
len_extra = 0
if unused is _sentry:
warnings.warn(
"Legacy call convention used, unused= needs to be specified",
DeprecationWarning,
)
elif unused is not None:
if not 0 <= unused <= 7:
raise ValueError("unused must be integer between 0 and 7")
if unused:
if not s:
raise ValueError("unused is non-zero but s is empty")
last = str_idx_as_int(s, -1)
if last & (2**unused - 1):
raise ValueError("unused bits must be zeros in DER")
encoded_unused = int2byte(unused)
len_extra = 1
return b("\x03") + encode_length(len(s) + len_extra) + encoded_unused + s
def encode_octet_string(s):
return b("\x04") + encode_length(len(s)) + s
def encode_oid(first, second, *pieces):
assert 0 <= first < 2 and 0 <= second <= 39 or first == 2 and 0 <= second
body = b"".join(
chain(
[encode_number(40 * first + second)],
(encode_number(p) for p in pieces),
)
)
return b"\x06" + encode_length(len(body)) + body
def encode_sequence(*encoded_pieces):
total_len = sum([len(p) for p in encoded_pieces])
return b("\x30") + encode_length(total_len) + b("").join(encoded_pieces)
def encode_number(n):
b128_digits = []
while n:
b128_digits.insert(0, (n & 0x7F) | 0x80)
n = n >> 7
if not b128_digits:
b128_digits.append(0)
b128_digits[-1] &= 0x7F
return b("").join([int2byte(d) for d in b128_digits])
def is_sequence(string):
return string and string[:1] == b"\x30"
def remove_constructed(string):
s0 = str_idx_as_int(string, 0)
if (s0 & 0xE0) != 0xA0:
raise UnexpectedDER(
"wanted type 'constructed tag' (0xa0-0xbf), got 0x%02x" % s0
)
tag = s0 & 0x1F
length, llen = read_length(string[1:])
body = string[1 + llen : 1 + llen + length]
rest = string[1 + llen + length :]
return tag, body, rest
def remove_sequence(string):
if not string:
raise UnexpectedDER("Empty string does not encode a sequence")
if string[:1] != b"\x30":
n = str_idx_as_int(string, 0)
raise UnexpectedDER("wanted type 'sequence' (0x30), got 0x%02x" % n)
length, lengthlength = read_length(string[1:])
if length > len(string) - 1 - lengthlength:
raise UnexpectedDER("Length longer than the provided buffer")
endseq = 1 + lengthlength + length
return string[1 + lengthlength : endseq], string[endseq:]
def remove_octet_string(string):
if string[:1] != b"\x04":
n = str_idx_as_int(string, 0)
raise UnexpectedDER("wanted type 'octetstring' (0x04), got 0x%02x" % n)
length, llen = read_length(string[1:])
body = string[1 + llen : 1 + llen + length]
rest = string[1 + llen + length :]
return body, rest
def remove_object(string):
if not string:
raise UnexpectedDER(
"Empty string does not encode an object identifier"
)
if string[:1] != b"\x06":
n = str_idx_as_int(string, 0)
raise UnexpectedDER("wanted type 'object' (0x06), got 0x%02x" % n)
length, lengthlength = read_length(string[1:])
body = string[1 + lengthlength : 1 + lengthlength + length]
rest = string[1 + lengthlength + length :]
if not body:
raise UnexpectedDER("Empty object identifier")
if len(body) != length:
raise UnexpectedDER(
"Length of object identifier longer than the provided buffer"
)
numbers = []
while body:
n, ll = read_number(body)
numbers.append(n)
body = body[ll:]
n0 = numbers.pop(0)
if n0 < 80:
first = n0 // 40
else:
first = 2
second = n0 - (40 * first)
numbers.insert(0, first)
numbers.insert(1, second)
return tuple(numbers), rest
def remove_integer(string):
if not string:
raise UnexpectedDER(
"Empty string is an invalid encoding of an integer"
)
if string[:1] != b"\x02":
n = str_idx_as_int(string, 0)
raise UnexpectedDER("wanted type 'integer' (0x02), got 0x%02x" % n)
length, llen = read_length(string[1:])
if length > len(string) - 1 - llen:
raise UnexpectedDER("Length longer than provided buffer")
if length == 0:
raise UnexpectedDER("0-byte long encoding of integer")
numberbytes = string[1 + llen : 1 + llen + length]
rest = string[1 + llen + length :]
msb = str_idx_as_int(numberbytes, 0)
if not msb < 0x80:
raise UnexpectedDER("Negative integers are not supported")
# check if the encoding is the minimal one (DER requirement)
if length > 1 and not msb:
# leading zero byte is allowed if the integer would have been
# considered a negative number otherwise
smsb = str_idx_as_int(numberbytes, 1)
if smsb < 0x80:
raise UnexpectedDER(
"Invalid encoding of integer, unnecessary "
"zero padding bytes"
)
return int(binascii.hexlify(numberbytes), 16), rest
def read_number(string):
number = 0
llen = 0
if str_idx_as_int(string, 0) == 0x80:
raise UnexpectedDER("Non minimal encoding of OID subidentifier")
# base-128 big endian, with most significant bit set in all but the last
# byte
while True:
if llen >= len(string):
raise UnexpectedDER("ran out of length bytes")
number = number << 7
d = str_idx_as_int(string, llen)
number += d & 0x7F
llen += 1
if not d & 0x80:
break
return number, llen
def encode_length(l):
assert l >= 0
if l < 0x80:
return int2byte(l)
s = ("%x" % l).encode()
if len(s) % 2:
s = b("0") + s
s = binascii.unhexlify(s)
llen = len(s)
return int2byte(0x80 | llen) + s
def read_length(string):
if not string:
raise UnexpectedDER("Empty string can't encode valid length value")
num = str_idx_as_int(string, 0)
if not (num & 0x80):
# short form
return (num & 0x7F), 1
# else long-form: b0&0x7f is number of additional base256 length bytes,
# big-endian
llen = num & 0x7F
if not llen:
raise UnexpectedDER("Invalid length encoding, length of length is 0")
if llen > len(string) - 1:
raise UnexpectedDER("Length of length longer than provided buffer")
# verify that the encoding is minimal possible (DER requirement)
msb = str_idx_as_int(string, 1)
if not msb or llen == 1 and msb < 0x80:
raise UnexpectedDER("Not minimal encoding of length")
return int(binascii.hexlify(string[1 : 1 + llen]), 16), 1 + llen
def remove_bitstring(string, expect_unused=_sentry):
"""
Remove a BIT STRING object from `string` following :term:`DER`.
The `expect_unused` can be used to specify if the bit string should
have the amount of unused bits decoded or not. If it's an integer, any
read BIT STRING that has number of unused bits different from specified
value will cause UnexpectedDER exception to be raised (this is especially
useful when decoding BIT STRINGS that have DER encoded object in them;
DER encoding is byte oriented, so the unused bits will always equal 0).
If the `expect_unused` is specified as None, the first element returned
will be a tuple, with the first value being the extracted bit string
while the second value will be the decoded number of unused bits.
If the `expect_unused` is unspecified, the decoding of byte with
number of unused bits will not be attempted and the bit string will be
returned as-is, the callee will be required to decode it and verify its
correctness.
Future version of python will require the `expected_unused` parameter
to be specified.
:param string: string of bytes to extract the BIT STRING from
:type string: bytes like object
:param expect_unused: number of bits that should be unused in the BIT
STRING, or None, to return it to caller
:type expect_unused: int or None
:raises UnexpectedDER: when the encoding does not follow DER.
:return: a tuple with first element being the extracted bit string and
the second being the remaining bytes in the string (if any); if the
`expect_unused` is specified as None, the first element of the returned
tuple will be a tuple itself, with first element being the bit string
as bytes and the second element being the number of unused bits at the
end of the byte array as an integer
:rtype: tuple
"""
if not string:
raise UnexpectedDER("Empty string does not encode a bitstring")
if expect_unused is _sentry:
warnings.warn(
"Legacy call convention used, expect_unused= needs to be"
" specified",
DeprecationWarning,
)
num = str_idx_as_int(string, 0)
if string[:1] != b"\x03":
raise UnexpectedDER("wanted bitstring (0x03), got 0x%02x" % num)
length, llen = read_length(string[1:])
if not length:
raise UnexpectedDER("Invalid length of bit string, can't be 0")
body = string[1 + llen : 1 + llen + length]
rest = string[1 + llen + length :]
if expect_unused is not _sentry:
unused = str_idx_as_int(body, 0)
if not 0 <= unused <= 7:
raise UnexpectedDER("Invalid encoding of unused bits")
if expect_unused is not None and expect_unused != unused:
raise UnexpectedDER("Unexpected number of unused bits")
body = body[1:]
if unused:
if not body:
raise UnexpectedDER("Invalid encoding of empty bit string")
last = str_idx_as_int(body, -1)
# verify that all the unused bits are set to zero (DER requirement)
if last & (2**unused - 1):
raise UnexpectedDER("Non zero padding bits in bit string")
if expect_unused is None:
body = (body, unused)
return body, rest
# SEQUENCE([1, STRING(secexp), cont[0], OBJECT(curvename), cont[1], BINTSTRING)
# signatures: (from RFC3279)
# ansi-X9-62 OBJECT IDENTIFIER ::= {
# iso(1) member-body(2) us(840) 10045 }
#
# id-ecSigType OBJECT IDENTIFIER ::= {
# ansi-X9-62 signatures(4) }
# ecdsa-with-SHA1 OBJECT IDENTIFIER ::= {
# id-ecSigType 1 }
# so 1,2,840,10045,4,1
# so 0x42, .. ..
# Ecdsa-Sig-Value ::= SEQUENCE {
# r INTEGER,
# s INTEGER }
# id-public-key-type OBJECT IDENTIFIER ::= { ansi-X9.62 2 }
#
# id-ecPublicKey OBJECT IDENTIFIER ::= { id-publicKeyType 1 }
# I think the secp224r1 identifier is (t=06,l=05,v=2b81040021)
# secp224r1 OBJECT IDENTIFIER ::= {
# iso(1) identified-organization(3) certicom(132) curve(0) 33 }
# and the secp384r1 is (t=06,l=05,v=2b81040022)
# secp384r1 OBJECT IDENTIFIER ::= {
# iso(1) identified-organization(3) certicom(132) curve(0) 34 }
def unpem(pem):
if isinstance(pem, text_type): # pragma: no branch
pem = pem.encode()
d = b("").join(
[
l.strip()
for l in pem.split(b("\n"))
if l and not l.startswith(b("-----"))
]
)
return base64.b64decode(d)
def topem(der, name):
b64 = base64.b64encode(der)
lines = [("-----BEGIN %s-----\n" % name).encode()]
lines.extend(
[b64[start : start + 64] + b("\n") for start in range(0, len(b64), 64)]
)
lines.append(("-----END %s-----\n" % name).encode())
return b("").join(lines)

View file

@ -0,0 +1,336 @@
"""
Class for performing Elliptic-curve Diffie-Hellman (ECDH) operations.
"""
from .util import number_to_string
from .ellipticcurve import INFINITY
from .keys import SigningKey, VerifyingKey
__all__ = [
"ECDH",
"NoKeyError",
"NoCurveError",
"InvalidCurveError",
"InvalidSharedSecretError",
]
class NoKeyError(Exception):
"""ECDH. Key not found but it is needed for operation."""
pass
class NoCurveError(Exception):
"""ECDH. Curve not set but it is needed for operation."""
pass
class InvalidCurveError(Exception):
"""
ECDH. Raised in case the public and private keys use different curves.
"""
pass
class InvalidSharedSecretError(Exception):
"""ECDH. Raised in case the shared secret we obtained is an INFINITY."""
pass
class ECDH(object):
"""
Elliptic-curve Diffie-Hellman (ECDH). A key agreement protocol.
Allows two parties, each having an elliptic-curve public-private key
pair, to establish a shared secret over an insecure channel
"""
def __init__(self, curve=None, private_key=None, public_key=None):
"""
ECDH init.
Call can be initialised without parameters, then the first operation
(loading either key) will set the used curve.
All parameters must be ultimately set before shared secret
calculation will be allowed.
:param curve: curve for operations
:type curve: Curve
:param private_key: `my` private key for ECDH
:type private_key: SigningKey
:param public_key: `their` public key for ECDH
:type public_key: VerifyingKey
"""
self.curve = curve
self.private_key = None
self.public_key = None
if private_key:
self.load_private_key(private_key)
if public_key:
self.load_received_public_key(public_key)
def _get_shared_secret(self, remote_public_key):
if not self.private_key:
raise NoKeyError(
"Private key needs to be set to create shared secret"
)
if not self.public_key:
raise NoKeyError(
"Public key needs to be set to create shared secret"
)
if not (
self.private_key.curve == self.curve == remote_public_key.curve
):
raise InvalidCurveError(
"Curves for public key and private key is not equal."
)
# shared secret = PUBKEYtheirs * PRIVATEKEYours
result = (
remote_public_key.pubkey.point
* self.private_key.privkey.secret_multiplier
)
if result == INFINITY:
raise InvalidSharedSecretError("Invalid shared secret (INFINITY).")
return result.x()
def set_curve(self, key_curve):
"""
Set the working curve for ecdh operations.
:param key_curve: curve from `curves` module
:type key_curve: Curve
"""
self.curve = key_curve
def generate_private_key(self):
"""
Generate local private key for ecdh operation with curve that was set.
:raises NoCurveError: Curve must be set before key generation.
:return: public (verifying) key from this private key.
:rtype: VerifyingKey
"""
if not self.curve:
raise NoCurveError("Curve must be set prior to key generation.")
return self.load_private_key(SigningKey.generate(curve=self.curve))
def load_private_key(self, private_key):
"""
Load private key from SigningKey (keys.py) object.
Needs to have the same curve as was set with set_curve method.
If curve is not set - it sets from this SigningKey
:param private_key: Initialised SigningKey class
:type private_key: SigningKey
:raises InvalidCurveError: private_key curve not the same as self.curve
:return: public (verifying) key from this private key.
:rtype: VerifyingKey
"""
if not self.curve:
self.curve = private_key.curve
if self.curve != private_key.curve:
raise InvalidCurveError("Curve mismatch.")
self.private_key = private_key
return self.private_key.get_verifying_key()
def load_private_key_bytes(self, private_key):
"""
Load private key from byte string.
Uses current curve and checks if the provided key matches
the curve of ECDH key agreement.
Key loads via from_string method of SigningKey class
:param private_key: private key in bytes string format
:type private_key: :term:`bytes-like object`
:raises NoCurveError: Curve must be set before loading.
:return: public (verifying) key from this private key.
:rtype: VerifyingKey
"""
if not self.curve:
raise NoCurveError("Curve must be set prior to key load.")
return self.load_private_key(
SigningKey.from_string(private_key, curve=self.curve)
)
def load_private_key_der(self, private_key_der):
"""
Load private key from DER byte string.
Compares the curve of the DER-encoded key with the ECDH set curve,
uses the former if unset.
Note, the only DER format supported is the RFC5915
Look at keys.py:SigningKey.from_der()
:param private_key_der: string with the DER encoding of private ECDSA
key
:type private_key_der: string
:raises InvalidCurveError: private_key curve not the same as self.curve
:return: public (verifying) key from this private key.
:rtype: VerifyingKey
"""
return self.load_private_key(SigningKey.from_der(private_key_der))
def load_private_key_pem(self, private_key_pem):
"""
Load private key from PEM string.
Compares the curve of the DER-encoded key with the ECDH set curve,
uses the former if unset.
Note, the only PEM format supported is the RFC5915
Look at keys.py:SigningKey.from_pem()
it needs to have `EC PRIVATE KEY` section
:param private_key_pem: string with PEM-encoded private ECDSA key
:type private_key_pem: string
:raises InvalidCurveError: private_key curve not the same as self.curve
:return: public (verifying) key from this private key.
:rtype: VerifyingKey
"""
return self.load_private_key(SigningKey.from_pem(private_key_pem))
def get_public_key(self):
"""
Provides a public key that matches the local private key.
Needs to be sent to the remote party.
:return: public (verifying) key from local private key.
:rtype: VerifyingKey
"""
return self.private_key.get_verifying_key()
def load_received_public_key(self, public_key):
"""
Load public key from VerifyingKey (keys.py) object.
Needs to have the same curve as set as current for ecdh operation.
If curve is not set - it sets it from VerifyingKey.
:param public_key: Initialised VerifyingKey class
:type public_key: VerifyingKey
:raises InvalidCurveError: public_key curve not the same as self.curve
"""
if not self.curve:
self.curve = public_key.curve
if self.curve != public_key.curve:
raise InvalidCurveError("Curve mismatch.")
self.public_key = public_key
def load_received_public_key_bytes(
self, public_key_str, valid_encodings=None
):
"""
Load public key from byte string.
Uses current curve and checks if key length corresponds to
the current curve.
Key loads via from_string method of VerifyingKey class
:param public_key_str: public key in bytes string format
:type public_key_str: :term:`bytes-like object`
:param valid_encodings: list of acceptable point encoding formats,
supported ones are: :term:`uncompressed`, :term:`compressed`,
:term:`hybrid`, and :term:`raw encoding` (specified with ``raw``
name). All formats by default (specified with ``None``).
:type valid_encodings: :term:`set-like object`
"""
return self.load_received_public_key(
VerifyingKey.from_string(
public_key_str, self.curve, valid_encodings
)
)
def load_received_public_key_der(self, public_key_der):
"""
Load public key from DER byte string.
Compares the curve of the DER-encoded key with the ECDH set curve,
uses the former if unset.
Note, the only DER format supported is the RFC5912
Look at keys.py:VerifyingKey.from_der()
:param public_key_der: string with the DER encoding of public ECDSA key
:type public_key_der: string
:raises InvalidCurveError: public_key curve not the same as self.curve
"""
return self.load_received_public_key(
VerifyingKey.from_der(public_key_der)
)
def load_received_public_key_pem(self, public_key_pem):
"""
Load public key from PEM string.
Compares the curve of the PEM-encoded key with the ECDH set curve,
uses the former if unset.
Note, the only PEM format supported is the RFC5912
Look at keys.py:VerifyingKey.from_pem()
:param public_key_pem: string with PEM-encoded public ECDSA key
:type public_key_pem: string
:raises InvalidCurveError: public_key curve not the same as self.curve
"""
return self.load_received_public_key(
VerifyingKey.from_pem(public_key_pem)
)
def generate_sharedsecret_bytes(self):
"""
Generate shared secret from local private key and remote public key.
The objects needs to have both private key and received public key
before generation is allowed.
:raises InvalidCurveError: public_key curve not the same as self.curve
:raises NoKeyError: public_key or private_key is not set
:return: shared secret
:rtype: bytes
"""
return number_to_string(
self.generate_sharedsecret(), self.private_key.curve.curve.p()
)
def generate_sharedsecret(self):
"""
Generate shared secret from local private key and remote public key.
The objects needs to have both private key and received public key
before generation is allowed.
It's the same for local and remote party,
shared secret(local private key, remote public key) ==
shared secret(local public key, remote private key)
:raises InvalidCurveError: public_key curve not the same as self.curve
:raises NoKeyError: public_key or private_key is not set
:return: shared secret
:rtype: int
"""
return self._get_shared_secret(self.public_key)

View file

@ -0,0 +1,859 @@
#! /usr/bin/env python
"""
Low level implementation of Elliptic-Curve Digital Signatures.
.. note ::
You're most likely looking for the :py:class:`~ecdsa.keys` module.
This is a low-level implementation of the ECDSA that operates on
integers, not byte strings.
NOTE: This a low level implementation of ECDSA, for normal applications
you should be looking at the keys.py module.
Classes and methods for elliptic-curve signatures:
private keys, public keys, signatures,
and definitions of prime-modulus curves.
Example:
.. code-block:: python
# (In real-life applications, you would probably want to
# protect against defects in SystemRandom.)
from random import SystemRandom
randrange = SystemRandom().randrange
# Generate a public/private key pair using the NIST Curve P-192:
g = generator_192
n = g.order()
secret = randrange( 1, n )
pubkey = Public_key( g, g * secret )
privkey = Private_key( pubkey, secret )
# Signing a hash value:
hash = randrange( 1, n )
signature = privkey.sign( hash, randrange( 1, n ) )
# Verifying a signature for a hash value:
if pubkey.verifies( hash, signature ):
print_("Demo verification succeeded.")
else:
print_("*** Demo verification failed.")
# Verification fails if the hash value is modified:
if pubkey.verifies( hash-1, signature ):
print_("**** Demo verification failed to reject tampered hash.")
else:
print_("Demo verification correctly rejected tampered hash.")
Revision history:
2005.12.31 - Initial version.
2008.11.25 - Substantial revisions introducing new classes.
2009.05.16 - Warn against using random.randrange in real applications.
2009.05.17 - Use random.SystemRandom by default.
Originally written in 2005 by Peter Pearson and placed in the public domain,
modified as part of the python-ecdsa package.
"""
from six import int2byte, b
from . import ellipticcurve
from . import numbertheory
from .util import bit_length
from ._compat import remove_whitespace
class RSZeroError(RuntimeError):
pass
class InvalidPointError(RuntimeError):
pass
class Signature(object):
"""
ECDSA signature.
:ivar int r: the ``r`` element of the ECDSA signature
:ivar int s: the ``s`` element of the ECDSA signature
"""
def __init__(self, r, s):
self.r = r
self.s = s
def recover_public_keys(self, hash, generator):
"""
Returns two public keys for which the signature is valid
:param int hash: signed hash
:param AbstractPoint generator: is the generator used in creation
of the signature
:rtype: tuple(Public_key, Public_key)
:return: a pair of public keys that can validate the signature
"""
curve = generator.curve()
n = generator.order()
r = self.r
s = self.s
e = hash
x = r
# Compute the curve point with x as x-coordinate
alpha = (
pow(x, 3, curve.p()) + (curve.a() * x) + curve.b()
) % curve.p()
beta = numbertheory.square_root_mod_prime(alpha, curve.p())
y = beta if beta % 2 == 0 else curve.p() - beta
# Compute the public key
R1 = ellipticcurve.PointJacobi(curve, x, y, 1, n)
Q1 = numbertheory.inverse_mod(r, n) * (s * R1 + (-e % n) * generator)
Pk1 = Public_key(generator, Q1)
# And the second solution
R2 = ellipticcurve.PointJacobi(curve, x, -y, 1, n)
Q2 = numbertheory.inverse_mod(r, n) * (s * R2 + (-e % n) * generator)
Pk2 = Public_key(generator, Q2)
return [Pk1, Pk2]
class Public_key(object):
"""Public key for ECDSA."""
def __init__(self, generator, point, verify=True):
"""Low level ECDSA public key object.
:param generator: the Point that generates the group (the base point)
:param point: the Point that defines the public key
:param bool verify: if True check if point is valid point on curve
:raises InvalidPointError: if the point parameters are invalid or
point does not lay on the curve
"""
self.curve = generator.curve()
self.generator = generator
self.point = point
n = generator.order()
p = self.curve.p()
if not (0 <= point.x() < p) or not (0 <= point.y() < p):
raise InvalidPointError(
"The public point has x or y out of range."
)
if verify and not self.curve.contains_point(point.x(), point.y()):
raise InvalidPointError("Point does not lay on the curve")
if not n:
raise InvalidPointError("Generator point must have order.")
# for curve parameters with base point with cofactor 1, all points
# that are on the curve are scalar multiples of the base point, so
# verifying that is not necessary. See Section 3.2.2.1 of SEC 1 v2
if (
verify
and self.curve.cofactor() != 1
and not n * point == ellipticcurve.INFINITY
):
raise InvalidPointError("Generator point order is bad.")
def __eq__(self, other):
"""Return True if the keys are identical, False otherwise.
Note: for comparison, only placement on the same curve and point
equality is considered, use of the same generator point is not
considered.
"""
if isinstance(other, Public_key):
return self.curve == other.curve and self.point == other.point
return NotImplemented
def __ne__(self, other):
"""Return False if the keys are identical, True otherwise."""
return not self == other
def verifies(self, hash, signature):
"""Verify that signature is a valid signature of hash.
Return True if the signature is valid.
"""
# From X9.62 J.3.1.
G = self.generator
n = G.order()
r = signature.r
s = signature.s
if r < 1 or r > n - 1:
return False
if s < 1 or s > n - 1:
return False
c = numbertheory.inverse_mod(s, n)
u1 = (hash * c) % n
u2 = (r * c) % n
if hasattr(G, "mul_add"):
xy = G.mul_add(u1, self.point, u2)
else:
xy = u1 * G + u2 * self.point
v = xy.x() % n
return v == r
class Private_key(object):
"""Private key for ECDSA."""
def __init__(self, public_key, secret_multiplier):
"""public_key is of class Public_key;
secret_multiplier is a large integer.
"""
self.public_key = public_key
self.secret_multiplier = secret_multiplier
def __eq__(self, other):
"""Return True if the points are identical, False otherwise."""
if isinstance(other, Private_key):
return (
self.public_key == other.public_key
and self.secret_multiplier == other.secret_multiplier
)
return NotImplemented
def __ne__(self, other):
"""Return False if the points are identical, True otherwise."""
return not self == other
def sign(self, hash, random_k):
"""Return a signature for the provided hash, using the provided
random nonce. It is absolutely vital that random_k be an unpredictable
number in the range [1, self.public_key.point.order()-1]. If
an attacker can guess random_k, he can compute our private key from a
single signature. Also, if an attacker knows a few high-order
bits (or a few low-order bits) of random_k, he can compute our private
key from many signatures. The generation of nonces with adequate
cryptographic strength is very difficult and far beyond the scope
of this comment.
May raise RuntimeError, in which case retrying with a new
random value k is in order.
"""
G = self.public_key.generator
n = G.order()
k = random_k % n
# Fix the bit-length of the random nonce,
# so that it doesn't leak via timing.
# This does not change that ks = k mod n
ks = k + n
kt = ks + n
if bit_length(ks) == bit_length(n):
p1 = kt * G
else:
p1 = ks * G
r = p1.x() % n
if r == 0:
raise RSZeroError("amazingly unlucky random number r")
s = (
numbertheory.inverse_mod(k, n)
* (hash + (self.secret_multiplier * r) % n)
) % n
if s == 0:
raise RSZeroError("amazingly unlucky random number s")
return Signature(r, s)
def int_to_string(x):
"""Convert integer x into a string of bytes, as per X9.62."""
assert x >= 0
if x == 0:
return b("\0")
result = []
while x:
ordinal = x & 0xFF
result.append(int2byte(ordinal))
x >>= 8
result.reverse()
return b("").join(result)
def string_to_int(s):
"""Convert a string of bytes into an integer, as per X9.62."""
result = 0
for c in s:
if not isinstance(c, int):
c = ord(c)
result = 256 * result + c
return result
def digest_integer(m):
"""Convert an integer into a string of bytes, compute
its SHA-1 hash, and convert the result to an integer."""
#
# I don't expect this function to be used much. I wrote
# it in order to be able to duplicate the examples
# in ECDSAVS.
#
from hashlib import sha1
return string_to_int(sha1(int_to_string(m)).digest())
def point_is_valid(generator, x, y):
"""Is (x,y) a valid public key based on the specified generator?"""
# These are the tests specified in X9.62.
n = generator.order()
curve = generator.curve()
p = curve.p()
if not (0 <= x < p) or not (0 <= y < p):
return False
if not curve.contains_point(x, y):
return False
if (
curve.cofactor() != 1
and not n * ellipticcurve.PointJacobi(curve, x, y, 1)
== ellipticcurve.INFINITY
):
return False
return True
# secp112r1 curve
_p = int(remove_whitespace("DB7C 2ABF62E3 5E668076 BEAD208B"), 16)
# s = 00F50B02 8E4D696E 67687561 51752904 72783FB1
_a = int(remove_whitespace("DB7C 2ABF62E3 5E668076 BEAD2088"), 16)
_b = int(remove_whitespace("659E F8BA0439 16EEDE89 11702B22"), 16)
_Gx = int(remove_whitespace("09487239 995A5EE7 6B55F9C2 F098"), 16)
_Gy = int(remove_whitespace("A89C E5AF8724 C0A23E0E 0FF77500"), 16)
_r = int(remove_whitespace("DB7C 2ABF62E3 5E7628DF AC6561C5"), 16)
_h = 1
curve_112r1 = ellipticcurve.CurveFp(_p, _a, _b, _h)
generator_112r1 = ellipticcurve.PointJacobi(
curve_112r1, _Gx, _Gy, 1, _r, generator=True
)
# secp112r2 curve
_p = int(remove_whitespace("DB7C 2ABF62E3 5E668076 BEAD208B"), 16)
# s = 022757A1 114D69E 67687561 51755316 C05E0BD4
_a = int(remove_whitespace("6127 C24C05F3 8A0AAAF6 5C0EF02C"), 16)
_b = int(remove_whitespace("51DE F1815DB5 ED74FCC3 4C85D709"), 16)
_Gx = int(remove_whitespace("4BA30AB5 E892B4E1 649DD092 8643"), 16)
_Gy = int(remove_whitespace("ADCD 46F5882E 3747DEF3 6E956E97"), 16)
_r = int(remove_whitespace("36DF 0AAFD8B8 D7597CA1 0520D04B"), 16)
_h = 4
curve_112r2 = ellipticcurve.CurveFp(_p, _a, _b, _h)
generator_112r2 = ellipticcurve.PointJacobi(
curve_112r2, _Gx, _Gy, 1, _r, generator=True
)
# secp128r1 curve
_p = int(remove_whitespace("FFFFFFFD FFFFFFFF FFFFFFFF FFFFFFFF"), 16)
# S = 000E0D4D 69E6768 75615175 0CC03A44 73D03679
# a and b are mod p, so a is equal to p-3, or simply -3
# _a = -3
_b = int(remove_whitespace("E87579C1 1079F43D D824993C 2CEE5ED3"), 16)
_Gx = int(remove_whitespace("161FF752 8B899B2D 0C28607C A52C5B86"), 16)
_Gy = int(remove_whitespace("CF5AC839 5BAFEB13 C02DA292 DDED7A83"), 16)
_r = int(remove_whitespace("FFFFFFFE 00000000 75A30D1B 9038A115"), 16)
_h = 1
curve_128r1 = ellipticcurve.CurveFp(_p, -3, _b, _h)
generator_128r1 = ellipticcurve.PointJacobi(
curve_128r1, _Gx, _Gy, 1, _r, generator=True
)
# secp160r1
_p = int(remove_whitespace("FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF 7FFFFFFF"), 16)
# S = 1053CDE4 2C14D696 E6768756 1517533B F3F83345
# a and b are mod p, so a is equal to p-3, or simply -3
# _a = -3
_b = int(remove_whitespace("1C97BEFC 54BD7A8B 65ACF89F 81D4D4AD C565FA45"), 16)
_Gx = int(
remove_whitespace("4A96B568 8EF57328 46646989 68C38BB9 13CBFC82"),
16,
)
_Gy = int(
remove_whitespace("23A62855 3168947D 59DCC912 04235137 7AC5FB32"),
16,
)
_r = int(
remove_whitespace("01 00000000 00000000 0001F4C8 F927AED3 CA752257"),
16,
)
_h = 1
curve_160r1 = ellipticcurve.CurveFp(_p, -3, _b, _h)
generator_160r1 = ellipticcurve.PointJacobi(
curve_160r1, _Gx, _Gy, 1, _r, generator=True
)
# NIST Curve P-192:
_p = 6277101735386680763835789423207666416083908700390324961279
_r = 6277101735386680763835789423176059013767194773182842284081
# s = 0x3045ae6fc8422f64ed579528d38120eae12196d5L
# c = 0x3099d2bbbfcb2538542dcd5fb078b6ef5f3d6fe2c745de65L
_b = int(
remove_whitespace(
"""
64210519 E59C80E7 0FA7E9AB 72243049 FEB8DEEC C146B9B1"""
),
16,
)
_Gx = int(
remove_whitespace(
"""
188DA80E B03090F6 7CBF20EB 43A18800 F4FF0AFD 82FF1012"""
),
16,
)
_Gy = int(
remove_whitespace(
"""
07192B95 FFC8DA78 631011ED 6B24CDD5 73F977A1 1E794811"""
),
16,
)
curve_192 = ellipticcurve.CurveFp(_p, -3, _b, 1)
generator_192 = ellipticcurve.PointJacobi(
curve_192, _Gx, _Gy, 1, _r, generator=True
)
# NIST Curve P-224:
_p = int(
remove_whitespace(
"""
2695994666715063979466701508701963067355791626002630814351
0066298881"""
)
)
_r = int(
remove_whitespace(
"""
2695994666715063979466701508701962594045780771442439172168
2722368061"""
)
)
# s = 0xbd71344799d5c7fcdc45b59fa3b9ab8f6a948bc5L
# c = 0x5b056c7e11dd68f40469ee7f3c7a7d74f7d121116506d031218291fbL
_b = int(
remove_whitespace(
"""
B4050A85 0C04B3AB F5413256 5044B0B7 D7BFD8BA 270B3943
2355FFB4"""
),
16,
)
_Gx = int(
remove_whitespace(
"""
B70E0CBD 6BB4BF7F 321390B9 4A03C1D3 56C21122 343280D6
115C1D21"""
),
16,
)
_Gy = int(
remove_whitespace(
"""
BD376388 B5F723FB 4C22DFE6 CD4375A0 5A074764 44D58199
85007E34"""
),
16,
)
curve_224 = ellipticcurve.CurveFp(_p, -3, _b, 1)
generator_224 = ellipticcurve.PointJacobi(
curve_224, _Gx, _Gy, 1, _r, generator=True
)
# NIST Curve P-256:
_p = int(
remove_whitespace(
"""
1157920892103562487626974469494075735300861434152903141955
33631308867097853951"""
)
)
_r = int(
remove_whitespace(
"""
115792089210356248762697446949407573529996955224135760342
422259061068512044369"""
)
)
# s = 0xc49d360886e704936a6678e1139d26b7819f7e90L
# c = 0x7efba1662985be9403cb055c75d4f7e0ce8d84a9c5114abcaf3177680104fa0dL
_b = int(
remove_whitespace(
"""
5AC635D8 AA3A93E7 B3EBBD55 769886BC 651D06B0 CC53B0F6
3BCE3C3E 27D2604B"""
),
16,
)
_Gx = int(
remove_whitespace(
"""
6B17D1F2 E12C4247 F8BCE6E5 63A440F2 77037D81 2DEB33A0
F4A13945 D898C296"""
),
16,
)
_Gy = int(
remove_whitespace(
"""
4FE342E2 FE1A7F9B 8EE7EB4A 7C0F9E16 2BCE3357 6B315ECE
CBB64068 37BF51F5"""
),
16,
)
curve_256 = ellipticcurve.CurveFp(_p, -3, _b, 1)
generator_256 = ellipticcurve.PointJacobi(
curve_256, _Gx, _Gy, 1, _r, generator=True
)
# NIST Curve P-384:
_p = int(
remove_whitespace(
"""
3940200619639447921227904010014361380507973927046544666794
8293404245721771496870329047266088258938001861606973112319"""
)
)
_r = int(
remove_whitespace(
"""
3940200619639447921227904010014361380507973927046544666794
6905279627659399113263569398956308152294913554433653942643"""
)
)
# s = 0xa335926aa319a27a1d00896a6773a4827acdac73L
# c = int(remove_whitespace(
# """
# 79d1e655 f868f02f ff48dcde e14151dd b80643c1 406d0ca1
# 0dfe6fc5 2009540a 495e8042 ea5f744f 6e184667 cc722483"""
# ), 16)
_b = int(
remove_whitespace(
"""
B3312FA7 E23EE7E4 988E056B E3F82D19 181D9C6E FE814112
0314088F 5013875A C656398D 8A2ED19D 2A85C8ED D3EC2AEF"""
),
16,
)
_Gx = int(
remove_whitespace(
"""
AA87CA22 BE8B0537 8EB1C71E F320AD74 6E1D3B62 8BA79B98
59F741E0 82542A38 5502F25D BF55296C 3A545E38 72760AB7"""
),
16,
)
_Gy = int(
remove_whitespace(
"""
3617DE4A 96262C6F 5D9E98BF 9292DC29 F8F41DBD 289A147C
E9DA3113 B5F0B8C0 0A60B1CE 1D7E819D 7A431D7C 90EA0E5F"""
),
16,
)
curve_384 = ellipticcurve.CurveFp(_p, -3, _b, 1)
generator_384 = ellipticcurve.PointJacobi(
curve_384, _Gx, _Gy, 1, _r, generator=True
)
# NIST Curve P-521:
_p = int(
"686479766013060971498190079908139321726943530014330540939"
"446345918554318339765605212255964066145455497729631139148"
"0858037121987999716643812574028291115057151"
)
_r = int(
"686479766013060971498190079908139321726943530014330540939"
"446345918554318339765539424505774633321719753296399637136"
"3321113864768612440380340372808892707005449"
)
# s = 0xd09e8800291cb85396cc6717393284aaa0da64baL
# c = int(remove_whitespace(
# """
# 0b4 8bfa5f42 0a349495 39d2bdfc 264eeeeb 077688e4
# 4fbf0ad8 f6d0edb3 7bd6b533 28100051 8e19f1b9 ffbe0fe9
# ed8a3c22 00b8f875 e523868c 70c1e5bf 55bad637"""
# ), 16)
_b = int(
remove_whitespace(
"""
051 953EB961 8E1C9A1F 929A21A0 B68540EE A2DA725B
99B315F3 B8B48991 8EF109E1 56193951 EC7E937B 1652C0BD
3BB1BF07 3573DF88 3D2C34F1 EF451FD4 6B503F00"""
),
16,
)
_Gx = int(
remove_whitespace(
"""
C6 858E06B7 0404E9CD 9E3ECB66 2395B442 9C648139
053FB521 F828AF60 6B4D3DBA A14B5E77 EFE75928 FE1DC127
A2FFA8DE 3348B3C1 856A429B F97E7E31 C2E5BD66"""
),
16,
)
_Gy = int(
remove_whitespace(
"""
118 39296A78 9A3BC004 5C8A5FB4 2C7D1BD9 98F54449
579B4468 17AFBD17 273E662C 97EE7299 5EF42640 C550B901
3FAD0761 353C7086 A272C240 88BE9476 9FD16650"""
),
16,
)
curve_521 = ellipticcurve.CurveFp(_p, -3, _b, 1)
generator_521 = ellipticcurve.PointJacobi(
curve_521, _Gx, _Gy, 1, _r, generator=True
)
# Certicom secp256-k1
_a = 0x0000000000000000000000000000000000000000000000000000000000000000
_b = 0x0000000000000000000000000000000000000000000000000000000000000007
_p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
_Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
_Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8
_r = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
curve_secp256k1 = ellipticcurve.CurveFp(_p, _a, _b, 1)
generator_secp256k1 = ellipticcurve.PointJacobi(
curve_secp256k1, _Gx, _Gy, 1, _r, generator=True
)
# Brainpool P-160-r1
_a = 0x340E7BE2A280EB74E2BE61BADA745D97E8F7C300
_b = 0x1E589A8595423412134FAA2DBDEC95C8D8675E58
_p = 0xE95E4A5F737059DC60DFC7AD95B3D8139515620F
_Gx = 0xBED5AF16EA3F6A4F62938C4631EB5AF7BDBCDBC3
_Gy = 0x1667CB477A1A8EC338F94741669C976316DA6321
_q = 0xE95E4A5F737059DC60DF5991D45029409E60FC09
curve_brainpoolp160r1 = ellipticcurve.CurveFp(_p, _a, _b, 1)
generator_brainpoolp160r1 = ellipticcurve.PointJacobi(
curve_brainpoolp160r1, _Gx, _Gy, 1, _q, generator=True
)
# Brainpool P-192-r1
_a = 0x6A91174076B1E0E19C39C031FE8685C1CAE040E5C69A28EF
_b = 0x469A28EF7C28CCA3DC721D044F4496BCCA7EF4146FBF25C9
_p = 0xC302F41D932A36CDA7A3463093D18DB78FCE476DE1A86297
_Gx = 0xC0A0647EAAB6A48753B033C56CB0F0900A2F5C4853375FD6
_Gy = 0x14B690866ABD5BB88B5F4828C1490002E6773FA2FA299B8F
_q = 0xC302F41D932A36CDA7A3462F9E9E916B5BE8F1029AC4ACC1
curve_brainpoolp192r1 = ellipticcurve.CurveFp(_p, _a, _b, 1)
generator_brainpoolp192r1 = ellipticcurve.PointJacobi(
curve_brainpoolp192r1, _Gx, _Gy, 1, _q, generator=True
)
# Brainpool P-224-r1
_a = 0x68A5E62CA9CE6C1C299803A6C1530B514E182AD8B0042A59CAD29F43
_b = 0x2580F63CCFE44138870713B1A92369E33E2135D266DBB372386C400B
_p = 0xD7C134AA264366862A18302575D1D787B09F075797DA89F57EC8C0FF
_Gx = 0x0D9029AD2C7E5CF4340823B2A87DC68C9E4CE3174C1E6EFDEE12C07D
_Gy = 0x58AA56F772C0726F24C6B89E4ECDAC24354B9E99CAA3F6D3761402CD
_q = 0xD7C134AA264366862A18302575D0FB98D116BC4B6DDEBCA3A5A7939F
curve_brainpoolp224r1 = ellipticcurve.CurveFp(_p, _a, _b, 1)
generator_brainpoolp224r1 = ellipticcurve.PointJacobi(
curve_brainpoolp224r1, _Gx, _Gy, 1, _q, generator=True
)
# Brainpool P-256-r1
_a = 0x7D5A0975FC2C3057EEF67530417AFFE7FB8055C126DC5C6CE94A4B44F330B5D9
_b = 0x26DC5C6CE94A4B44F330B5D9BBD77CBF958416295CF7E1CE6BCCDC18FF8C07B6
_p = 0xA9FB57DBA1EEA9BC3E660A909D838D726E3BF623D52620282013481D1F6E5377
_Gx = 0x8BD2AEB9CB7E57CB2C4B482FFC81B7AFB9DE27E1E3BD23C23A4453BD9ACE3262
_Gy = 0x547EF835C3DAC4FD97F8461A14611DC9C27745132DED8E545C1D54C72F046997
_q = 0xA9FB57DBA1EEA9BC3E660A909D838D718C397AA3B561A6F7901E0E82974856A7
curve_brainpoolp256r1 = ellipticcurve.CurveFp(_p, _a, _b, 1)
generator_brainpoolp256r1 = ellipticcurve.PointJacobi(
curve_brainpoolp256r1, _Gx, _Gy, 1, _q, generator=True
)
# Brainpool P-320-r1
_a = int(
remove_whitespace(
"""
3EE30B568FBAB0F883CCEBD46D3F3BB8A2A73513F5EB79DA66190EB085FFA9
F492F375A97D860EB4"""
),
16,
)
_b = int(
remove_whitespace(
"""
520883949DFDBC42D3AD198640688A6FE13F41349554B49ACC31DCCD884539
816F5EB4AC8FB1F1A6"""
),
16,
)
_p = int(
remove_whitespace(
"""
D35E472036BC4FB7E13C785ED201E065F98FCFA6F6F40DEF4F92B9EC7893EC
28FCD412B1F1B32E27"""
),
16,
)
_Gx = int(
remove_whitespace(
"""
43BD7E9AFB53D8B85289BCC48EE5BFE6F20137D10A087EB6E7871E2A10A599
C710AF8D0D39E20611"""
),
16,
)
_Gy = int(
remove_whitespace(
"""
14FDD05545EC1CC8AB4093247F77275E0743FFED117182EAA9C77877AAAC6A
C7D35245D1692E8EE1"""
),
16,
)
_q = int(
remove_whitespace(
"""
D35E472036BC4FB7E13C785ED201E065F98FCFA5B68F12A32D482EC7EE8658
E98691555B44C59311"""
),
16,
)
curve_brainpoolp320r1 = ellipticcurve.CurveFp(_p, _a, _b, 1)
generator_brainpoolp320r1 = ellipticcurve.PointJacobi(
curve_brainpoolp320r1, _Gx, _Gy, 1, _q, generator=True
)
# Brainpool P-384-r1
_a = int(
remove_whitespace(
"""
7BC382C63D8C150C3C72080ACE05AFA0C2BEA28E4FB22787139165EFBA91F9
0F8AA5814A503AD4EB04A8C7DD22CE2826"""
),
16,
)
_b = int(
remove_whitespace(
"""
04A8C7DD22CE28268B39B55416F0447C2FB77DE107DCD2A62E880EA53EEB62
D57CB4390295DBC9943AB78696FA504C11"""
),
16,
)
_p = int(
remove_whitespace(
"""
8CB91E82A3386D280F5D6F7E50E641DF152F7109ED5456B412B1DA197FB711
23ACD3A729901D1A71874700133107EC53"""
),
16,
)
_Gx = int(
remove_whitespace(
"""
1D1C64F068CF45FFA2A63A81B7C13F6B8847A3E77EF14FE3DB7FCAFE0CBD10
E8E826E03436D646AAEF87B2E247D4AF1E"""
),
16,
)
_Gy = int(
remove_whitespace(
"""
8ABE1D7520F9C2A45CB1EB8E95CFD55262B70B29FEEC5864E19C054FF991292
80E4646217791811142820341263C5315"""
),
16,
)
_q = int(
remove_whitespace(
"""
8CB91E82A3386D280F5D6F7E50E641DF152F7109ED5456B31F166E6CAC0425
A7CF3AB6AF6B7FC3103B883202E9046565"""
),
16,
)
curve_brainpoolp384r1 = ellipticcurve.CurveFp(_p, _a, _b, 1)
generator_brainpoolp384r1 = ellipticcurve.PointJacobi(
curve_brainpoolp384r1, _Gx, _Gy, 1, _q, generator=True
)
# Brainpool P-512-r1
_a = int(
remove_whitespace(
"""
7830A3318B603B89E2327145AC234CC594CBDD8D3DF91610A83441CAEA9863
BC2DED5D5AA8253AA10A2EF1C98B9AC8B57F1117A72BF2C7B9E7C1AC4D77FC94CA"""
),
16,
)
_b = int(
remove_whitespace(
"""
3DF91610A83441CAEA9863BC2DED5D5AA8253AA10A2EF1C98B9AC8B57F1117
A72BF2C7B9E7C1AC4D77FC94CADC083E67984050B75EBAE5DD2809BD638016F723"""
),
16,
)
_p = int(
remove_whitespace(
"""
AADD9DB8DBE9C48B3FD4E6AE33C9FC07CB308DB3B3C9D20ED6639CCA703308
717D4D9B009BC66842AECDA12AE6A380E62881FF2F2D82C68528AA6056583A48F3"""
),
16,
)
_Gx = int(
remove_whitespace(
"""
81AEE4BDD82ED9645A21322E9C4C6A9385ED9F70B5D916C1B43B62EEF4D009
8EFF3B1F78E2D0D48D50D1687B93B97D5F7C6D5047406A5E688B352209BCB9F822"""
),
16,
)
_Gy = int(
remove_whitespace(
"""
7DDE385D566332ECC0EABFA9CF7822FDF209F70024A57B1AA000C55B881F81
11B2DCDE494A5F485E5BCA4BD88A2763AED1CA2B2FA8F0540678CD1E0F3AD80892"""
),
16,
)
_q = int(
remove_whitespace(
"""
AADD9DB8DBE9C48B3FD4E6AE33C9FC07CB308DB3B3C9D20ED6639CCA703308
70553E5C414CA92619418661197FAC10471DB1D381085DDADDB58796829CA90069"""
),
16,
)
curve_brainpoolp512r1 = ellipticcurve.CurveFp(_p, _a, _b, 1)
generator_brainpoolp512r1 = ellipticcurve.PointJacobi(
curve_brainpoolp512r1, _Gx, _Gy, 1, _q, generator=True
)

View file

@ -0,0 +1,252 @@
"""Implementation of Edwards Digital Signature Algorithm."""
import hashlib
from ._sha3 import shake_256
from . import ellipticcurve
from ._compat import (
remove_whitespace,
bit_length,
bytes_to_int,
int_to_bytes,
compat26_str,
)
# edwards25519, defined in RFC7748
_p = 2**255 - 19
_a = -1
_d = int(
remove_whitespace(
"370957059346694393431380835087545651895421138798432190163887855330"
"85940283555"
)
)
_h = 8
_Gx = int(
remove_whitespace(
"151122213495354007725011514095885315114540126930418572060461132"
"83949847762202"
)
)
_Gy = int(
remove_whitespace(
"463168356949264781694283940034751631413079938662562256157830336"
"03165251855960"
)
)
_r = 2**252 + 0x14DEF9DEA2F79CD65812631A5CF5D3ED
def _sha512(data):
return hashlib.new("sha512", compat26_str(data)).digest()
curve_ed25519 = ellipticcurve.CurveEdTw(_p, _a, _d, _h, _sha512)
generator_ed25519 = ellipticcurve.PointEdwards(
curve_ed25519, _Gx, _Gy, 1, _Gx * _Gy % _p, _r, generator=True
)
# edwards448, defined in RFC7748
_p = 2**448 - 2**224 - 1
_a = 1
_d = -39081 % _p
_h = 4
_Gx = int(
remove_whitespace(
"224580040295924300187604334099896036246789641632564134246125461"
"686950415467406032909029192869357953282578032075146446173674602635"
"247710"
)
)
_Gy = int(
remove_whitespace(
"298819210078481492676017930443930673437544040154080242095928241"
"372331506189835876003536878655418784733982303233503462500531545062"
"832660"
)
)
_r = 2**446 - 0x8335DC163BB124B65129C96FDE933D8D723A70AADC873D6D54A7BB0D
def _shake256(data):
return shake_256(data, 114)
curve_ed448 = ellipticcurve.CurveEdTw(_p, _a, _d, _h, _shake256)
generator_ed448 = ellipticcurve.PointEdwards(
curve_ed448, _Gx, _Gy, 1, _Gx * _Gy % _p, _r, generator=True
)
class PublicKey(object):
"""Public key for the Edwards Digital Signature Algorithm."""
def __init__(self, generator, public_key, public_point=None):
self.generator = generator
self.curve = generator.curve()
self.__encoded = public_key
# plus one for the sign bit and round up
self.baselen = (bit_length(self.curve.p()) + 1 + 7) // 8
if len(public_key) != self.baselen:
raise ValueError(
"Incorrect size of the public key, expected: {0} bytes".format(
self.baselen
)
)
if public_point:
self.__point = public_point
else:
self.__point = ellipticcurve.PointEdwards.from_bytes(
self.curve, public_key
)
def __eq__(self, other):
if isinstance(other, PublicKey):
return (
self.curve == other.curve and self.__encoded == other.__encoded
)
return NotImplemented
def __ne__(self, other):
return not self == other
@property
def point(self):
return self.__point
@point.setter
def point(self, other):
if self.__point != other:
raise ValueError("Can't change the coordinates of the point")
self.__point = other
def public_point(self):
return self.__point
def public_key(self):
return self.__encoded
def verify(self, data, signature):
"""Verify a Pure EdDSA signature over data."""
data = compat26_str(data)
if len(signature) != 2 * self.baselen:
raise ValueError(
"Invalid signature length, expected: {0} bytes".format(
2 * self.baselen
)
)
R = ellipticcurve.PointEdwards.from_bytes(
self.curve, signature[: self.baselen]
)
S = bytes_to_int(signature[self.baselen :], "little")
if S >= self.generator.order():
raise ValueError("Invalid signature")
dom = bytearray()
if self.curve == curve_ed448:
dom = bytearray(b"SigEd448" + b"\x00\x00")
k = bytes_to_int(
self.curve.hash_func(dom + R.to_bytes() + self.__encoded + data),
"little",
)
if self.generator * S != self.__point * k + R:
raise ValueError("Invalid signature")
return True
class PrivateKey(object):
"""Private key for the Edwards Digital Signature Algorithm."""
def __init__(self, generator, private_key):
self.generator = generator
self.curve = generator.curve()
# plus one for the sign bit and round up
self.baselen = (bit_length(self.curve.p()) + 1 + 7) // 8
if len(private_key) != self.baselen:
raise ValueError(
"Incorrect size of private key, expected: {0} bytes".format(
self.baselen
)
)
self.__private_key = bytes(private_key)
self.__h = bytearray(self.curve.hash_func(private_key))
self.__public_key = None
a = self.__h[: self.baselen]
a = self._key_prune(a)
scalar = bytes_to_int(a, "little")
self.__s = scalar
@property
def private_key(self):
return self.__private_key
def __eq__(self, other):
if isinstance(other, PrivateKey):
return (
self.curve == other.curve
and self.__private_key == other.__private_key
)
return NotImplemented
def __ne__(self, other):
return not self == other
def _key_prune(self, key):
# make sure the key is not in a small subgroup
h = self.curve.cofactor()
if h == 4:
h_log = 2
elif h == 8:
h_log = 3
else:
raise ValueError("Only cofactor 4 and 8 curves supported")
key[0] &= ~((1 << h_log) - 1)
# ensure the highest bit is set but no higher
l = bit_length(self.curve.p())
if l % 8 == 0:
key[-1] = 0
key[-2] |= 0x80
else:
key[-1] = key[-1] & (1 << (l % 8)) - 1 | 1 << (l % 8) - 1
return key
def public_key(self):
"""Generate the public key based on the included private key"""
if self.__public_key:
return self.__public_key
public_point = self.generator * self.__s
self.__public_key = PublicKey(
self.generator, public_point.to_bytes(), public_point
)
return self.__public_key
def sign(self, data):
"""Perform a Pure EdDSA signature over data."""
data = compat26_str(data)
A = self.public_key().public_key()
prefix = self.__h[self.baselen :]
dom = bytearray()
if self.curve == curve_ed448:
dom = bytearray(b"SigEd448" + b"\x00\x00")
r = bytes_to_int(self.curve.hash_func(dom + prefix + data), "little")
R = (self.generator * r).to_bytes()
k = bytes_to_int(self.curve.hash_func(dom + R + A + data), "little")
k %= self.generator.order()
S = (r + k * self.__s) % self.generator.order()
return R + int_to_bytes(S, self.baselen, "little")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,4 @@
class MalformedPointError(AssertionError):
"""Raised in case the encoding of private or public key is malformed."""
pass

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,825 @@
#! /usr/bin/env python
#
# Provide some simple capabilities from number theory.
#
# Version of 2008.11.14.
#
# Written in 2005 and 2006 by Peter Pearson and placed in the public domain.
# Revision history:
# 2008.11.14: Use pow(base, exponent, modulus) for modular_exp.
# Make gcd and lcm accept arbitrarily many arguments.
from __future__ import division
import sys
from six import integer_types, PY2
from six.moves import reduce
try:
xrange
except NameError:
xrange = range
try:
from gmpy2 import powmod
GMPY2 = True
GMPY = False
except ImportError:
GMPY2 = False
try:
from gmpy import mpz
GMPY = True
except ImportError:
GMPY = False
import math
import warnings
class Error(Exception):
"""Base class for exceptions in this module."""
pass
class JacobiError(Error):
pass
class SquareRootError(Error):
pass
class NegativeExponentError(Error):
pass
def modular_exp(base, exponent, modulus): # pragma: no cover
"""Raise base to exponent, reducing by modulus"""
# deprecated in 0.14
warnings.warn(
"Function is unused in library code. If you use this code, "
"change to pow() builtin.",
DeprecationWarning,
)
if exponent < 0:
raise NegativeExponentError(
"Negative exponents (%d) not allowed" % exponent
)
return pow(base, exponent, modulus)
def polynomial_reduce_mod(poly, polymod, p):
"""Reduce poly by polymod, integer arithmetic modulo p.
Polynomials are represented as lists of coefficients
of increasing powers of x."""
# This module has been tested only by extensive use
# in calculating modular square roots.
# Just to make this easy, require a monic polynomial:
assert polymod[-1] == 1
assert len(polymod) > 1
while len(poly) >= len(polymod):
if poly[-1] != 0:
for i in xrange(2, len(polymod) + 1):
poly[-i] = (poly[-i] - poly[-1] * polymod[-i]) % p
poly = poly[0:-1]
return poly
def polynomial_multiply_mod(m1, m2, polymod, p):
"""Polynomial multiplication modulo a polynomial over ints mod p.
Polynomials are represented as lists of coefficients
of increasing powers of x."""
# This is just a seat-of-the-pants implementation.
# This module has been tested only by extensive use
# in calculating modular square roots.
# Initialize the product to zero:
prod = (len(m1) + len(m2) - 1) * [0]
# Add together all the cross-terms:
for i in xrange(len(m1)):
for j in xrange(len(m2)):
prod[i + j] = (prod[i + j] + m1[i] * m2[j]) % p
return polynomial_reduce_mod(prod, polymod, p)
def polynomial_exp_mod(base, exponent, polymod, p):
"""Polynomial exponentiation modulo a polynomial over ints mod p.
Polynomials are represented as lists of coefficients
of increasing powers of x."""
# Based on the Handbook of Applied Cryptography, algorithm 2.227.
# This module has been tested only by extensive use
# in calculating modular square roots.
assert exponent < p
if exponent == 0:
return [1]
G = base
k = exponent
if k % 2 == 1:
s = G
else:
s = [1]
while k > 1:
k = k // 2
G = polynomial_multiply_mod(G, G, polymod, p)
if k % 2 == 1:
s = polynomial_multiply_mod(G, s, polymod, p)
return s
def jacobi(a, n):
"""Jacobi symbol"""
# Based on the Handbook of Applied Cryptography (HAC), algorithm 2.149.
# This function has been tested by comparison with a small
# table printed in HAC, and by extensive use in calculating
# modular square roots.
if not n >= 3:
raise JacobiError("n must be larger than 2")
if not n % 2 == 1:
raise JacobiError("n must be odd")
a = a % n
if a == 0:
return 0
if a == 1:
return 1
a1, e = a, 0
while a1 % 2 == 0:
a1, e = a1 // 2, e + 1
if e % 2 == 0 or n % 8 == 1 or n % 8 == 7:
s = 1
else:
s = -1
if a1 == 1:
return s
if n % 4 == 3 and a1 % 4 == 3:
s = -s
return s * jacobi(n % a1, a1)
def square_root_mod_prime(a, p):
"""Modular square root of a, mod p, p prime."""
# Based on the Handbook of Applied Cryptography, algorithms 3.34 to 3.39.
# This module has been tested for all values in [0,p-1] for
# every prime p from 3 to 1229.
assert 0 <= a < p
assert 1 < p
if a == 0:
return 0
if p == 2:
return a
jac = jacobi(a, p)
if jac == -1:
raise SquareRootError("%d has no square root modulo %d" % (a, p))
if p % 4 == 3:
return pow(a, (p + 1) // 4, p)
if p % 8 == 5:
d = pow(a, (p - 1) // 4, p)
if d == 1:
return pow(a, (p + 3) // 8, p)
assert d == p - 1
return (2 * a * pow(4 * a, (p - 5) // 8, p)) % p
if PY2:
# xrange on python2 can take integers representable as C long only
range_top = min(0x7FFFFFFF, p)
else:
range_top = p
for b in xrange(2, range_top):
if jacobi(b * b - 4 * a, p) == -1:
f = (a, -b, 1)
ff = polynomial_exp_mod((0, 1), (p + 1) // 2, f, p)
if ff[1]:
raise SquareRootError("p is not prime")
return ff[0]
raise RuntimeError("No b found.")
# because all the inverse_mod code is arch/environment specific, and coveralls
# expects it to execute equal number of times, we need to waive it by
# adding the "no branch" pragma to all branches
if GMPY2: # pragma: no branch
def inverse_mod(a, m):
"""Inverse of a mod m."""
if a == 0: # pragma: no branch
return 0
return powmod(a, -1, m)
elif GMPY: # pragma: no branch
def inverse_mod(a, m):
"""Inverse of a mod m."""
# while libgmp does support inverses modulo, it is accessible
# only using the native `pow()` function, and `pow()` in gmpy sanity
# checks the parameters before passing them on to underlying
# implementation
if a == 0: # pragma: no branch
return 0
a = mpz(a)
m = mpz(m)
lm, hm = mpz(1), mpz(0)
low, high = a % m, m
while low > 1: # pragma: no branch
r = high // low
lm, low, hm, high = hm - lm * r, high - low * r, lm, low
return lm % m
elif sys.version_info >= (3, 8): # pragma: no branch
def inverse_mod(a, m):
"""Inverse of a mod m."""
if a == 0: # pragma: no branch
return 0
return pow(a, -1, m)
else: # pragma: no branch
def inverse_mod(a, m):
"""Inverse of a mod m."""
if a == 0: # pragma: no branch
return 0
lm, hm = 1, 0
low, high = a % m, m
while low > 1: # pragma: no branch
r = high // low
lm, low, hm, high = hm - lm * r, high - low * r, lm, low
return lm % m
try:
gcd2 = math.gcd
except AttributeError:
def gcd2(a, b):
"""Greatest common divisor using Euclid's algorithm."""
while a:
a, b = b % a, a
return b
def gcd(*a):
"""Greatest common divisor.
Usage: gcd([ 2, 4, 6 ])
or: gcd(2, 4, 6)
"""
if len(a) > 1:
return reduce(gcd2, a)
if hasattr(a[0], "__iter__"):
return reduce(gcd2, a[0])
return a[0]
def lcm2(a, b):
"""Least common multiple of two integers."""
return (a * b) // gcd(a, b)
def lcm(*a):
"""Least common multiple.
Usage: lcm([ 3, 4, 5 ])
or: lcm(3, 4, 5)
"""
if len(a) > 1:
return reduce(lcm2, a)
if hasattr(a[0], "__iter__"):
return reduce(lcm2, a[0])
return a[0]
def factorization(n):
"""Decompose n into a list of (prime,exponent) pairs."""
assert isinstance(n, integer_types)
if n < 2:
return []
result = []
# Test the small primes:
for d in smallprimes:
if d > n:
break
q, r = divmod(n, d)
if r == 0:
count = 1
while d <= n:
n = q
q, r = divmod(n, d)
if r != 0:
break
count = count + 1
result.append((d, count))
# If n is still greater than the last of our small primes,
# it may require further work:
if n > smallprimes[-1]:
if is_prime(n): # If what's left is prime, it's easy:
result.append((n, 1))
else: # Ugh. Search stupidly for a divisor:
d = smallprimes[-1]
while 1:
d = d + 2 # Try the next divisor.
q, r = divmod(n, d)
if q < d: # n < d*d means we're done, n = 1 or prime.
break
if r == 0: # d divides n. How many times?
count = 1
n = q
while d <= n: # As long as d might still divide n,
q, r = divmod(n, d) # see if it does.
if r != 0:
break
n = q # It does. Reduce n, increase count.
count = count + 1
result.append((d, count))
if n > 1:
result.append((n, 1))
return result
def phi(n): # pragma: no cover
"""Return the Euler totient function of n."""
# deprecated in 0.14
warnings.warn(
"Function is unused by library code. If you use this code, "
"please open an issue in "
"https://github.com/tlsfuzzer/python-ecdsa",
DeprecationWarning,
)
assert isinstance(n, integer_types)
if n < 3:
return 1
result = 1
ff = factorization(n)
for f in ff:
e = f[1]
if e > 1:
result = result * f[0] ** (e - 1) * (f[0] - 1)
else:
result = result * (f[0] - 1)
return result
def carmichael(n): # pragma: no cover
"""Return Carmichael function of n.
Carmichael(n) is the smallest integer x such that
m**x = 1 mod n for all m relatively prime to n.
"""
# deprecated in 0.14
warnings.warn(
"Function is unused by library code. If you use this code, "
"please open an issue in "
"https://github.com/tlsfuzzer/python-ecdsa",
DeprecationWarning,
)
return carmichael_of_factorized(factorization(n))
def carmichael_of_factorized(f_list): # pragma: no cover
"""Return the Carmichael function of a number that is
represented as a list of (prime,exponent) pairs.
"""
# deprecated in 0.14
warnings.warn(
"Function is unused by library code. If you use this code, "
"please open an issue in "
"https://github.com/tlsfuzzer/python-ecdsa",
DeprecationWarning,
)
if len(f_list) < 1:
return 1
result = carmichael_of_ppower(f_list[0])
for i in xrange(1, len(f_list)):
result = lcm(result, carmichael_of_ppower(f_list[i]))
return result
def carmichael_of_ppower(pp): # pragma: no cover
"""Carmichael function of the given power of the given prime."""
# deprecated in 0.14
warnings.warn(
"Function is unused by library code. If you use this code, "
"please open an issue in "
"https://github.com/tlsfuzzer/python-ecdsa",
DeprecationWarning,
)
p, a = pp
if p == 2 and a > 2:
return 2 ** (a - 2)
else:
return (p - 1) * p ** (a - 1)
def order_mod(x, m): # pragma: no cover
"""Return the order of x in the multiplicative group mod m."""
# deprecated in 0.14
warnings.warn(
"Function is unused by library code. If you use this code, "
"please open an issue in "
"https://github.com/tlsfuzzer/python-ecdsa",
DeprecationWarning,
)
# Warning: this implementation is not very clever, and will
# take a long time if m is very large.
if m <= 1:
return 0
assert gcd(x, m) == 1
z = x
result = 1
while z != 1:
z = (z * x) % m
result = result + 1
return result
def largest_factor_relatively_prime(a, b): # pragma: no cover
"""Return the largest factor of a relatively prime to b."""
# deprecated in 0.14
warnings.warn(
"Function is unused by library code. If you use this code, "
"please open an issue in "
"https://github.com/tlsfuzzer/python-ecdsa",
DeprecationWarning,
)
while 1:
d = gcd(a, b)
if d <= 1:
break
b = d
while 1:
q, r = divmod(a, d)
if r > 0:
break
a = q
return a
def kinda_order_mod(x, m): # pragma: no cover
"""Return the order of x in the multiplicative group mod m',
where m' is the largest factor of m relatively prime to x.
"""
# deprecated in 0.14
warnings.warn(
"Function is unused by library code. If you use this code, "
"please open an issue in "
"https://github.com/tlsfuzzer/python-ecdsa",
DeprecationWarning,
)
return order_mod(x, largest_factor_relatively_prime(m, x))
def is_prime(n):
"""Return True if x is prime, False otherwise.
We use the Miller-Rabin test, as given in Menezes et al. p. 138.
This test is not exact: there are composite values n for which
it returns True.
In testing the odd numbers from 10000001 to 19999999,
about 66 composites got past the first test,
5 got past the second test, and none got past the third.
Since factors of 2, 3, 5, 7, and 11 were detected during
preliminary screening, the number of numbers tested by
Miller-Rabin was (19999999 - 10000001)*(2/3)*(4/5)*(6/7)
= 4.57 million.
"""
# (This is used to study the risk of false positives:)
global miller_rabin_test_count
miller_rabin_test_count = 0
if n <= smallprimes[-1]:
if n in smallprimes:
return True
else:
return False
if gcd(n, 2 * 3 * 5 * 7 * 11) != 1:
return False
# Choose a number of iterations sufficient to reduce the
# probability of accepting a composite below 2**-80
# (from Menezes et al. Table 4.4):
t = 40
n_bits = 1 + int(math.log(n, 2))
for k, tt in (
(100, 27),
(150, 18),
(200, 15),
(250, 12),
(300, 9),
(350, 8),
(400, 7),
(450, 6),
(550, 5),
(650, 4),
(850, 3),
(1300, 2),
):
if n_bits < k:
break
t = tt
# Run the test t times:
s = 0
r = n - 1
while (r % 2) == 0:
s = s + 1
r = r // 2
for i in xrange(t):
a = smallprimes[i]
y = pow(a, r, n)
if y != 1 and y != n - 1:
j = 1
while j <= s - 1 and y != n - 1:
y = pow(y, 2, n)
if y == 1:
miller_rabin_test_count = i + 1
return False
j = j + 1
if y != n - 1:
miller_rabin_test_count = i + 1
return False
return True
def next_prime(starting_value):
"""Return the smallest prime larger than the starting value."""
if starting_value < 2:
return 2
result = (starting_value + 1) | 1
while not is_prime(result):
result = result + 2
return result
smallprimes = [
2,
3,
5,
7,
11,
13,
17,
19,
23,
29,
31,
37,
41,
43,
47,
53,
59,
61,
67,
71,
73,
79,
83,
89,
97,
101,
103,
107,
109,
113,
127,
131,
137,
139,
149,
151,
157,
163,
167,
173,
179,
181,
191,
193,
197,
199,
211,
223,
227,
229,
233,
239,
241,
251,
257,
263,
269,
271,
277,
281,
283,
293,
307,
311,
313,
317,
331,
337,
347,
349,
353,
359,
367,
373,
379,
383,
389,
397,
401,
409,
419,
421,
431,
433,
439,
443,
449,
457,
461,
463,
467,
479,
487,
491,
499,
503,
509,
521,
523,
541,
547,
557,
563,
569,
571,
577,
587,
593,
599,
601,
607,
613,
617,
619,
631,
641,
643,
647,
653,
659,
661,
673,
677,
683,
691,
701,
709,
719,
727,
733,
739,
743,
751,
757,
761,
769,
773,
787,
797,
809,
811,
821,
823,
827,
829,
839,
853,
857,
859,
863,
877,
881,
883,
887,
907,
911,
919,
929,
937,
941,
947,
953,
967,
971,
977,
983,
991,
997,
1009,
1013,
1019,
1021,
1031,
1033,
1039,
1049,
1051,
1061,
1063,
1069,
1087,
1091,
1093,
1097,
1103,
1109,
1117,
1123,
1129,
1151,
1153,
1163,
1171,
1181,
1187,
1193,
1201,
1213,
1217,
1223,
1229,
]
miller_rabin_test_count = 0

View file

@ -0,0 +1,113 @@
"""
RFC 6979:
Deterministic Usage of the Digital Signature Algorithm (DSA) and
Elliptic Curve Digital Signature Algorithm (ECDSA)
http://tools.ietf.org/html/rfc6979
Many thanks to Coda Hale for his implementation in Go language:
https://github.com/codahale/rfc6979
"""
import hmac
from binascii import hexlify
from .util import number_to_string, number_to_string_crop, bit_length
from ._compat import hmac_compat
# bit_length was defined in this module previously so keep it for backwards
# compatibility, will need to deprecate and remove it later
__all__ = ["bit_length", "bits2int", "bits2octets", "generate_k"]
def bits2int(data, qlen):
x = int(hexlify(data), 16)
l = len(data) * 8
if l > qlen:
return x >> (l - qlen)
return x
def bits2octets(data, order):
z1 = bits2int(data, bit_length(order))
z2 = z1 - order
if z2 < 0:
z2 = z1
return number_to_string_crop(z2, order)
# https://tools.ietf.org/html/rfc6979#section-3.2
def generate_k(order, secexp, hash_func, data, retry_gen=0, extra_entropy=b""):
"""
Generate the ``k`` value - the nonce for DSA.
:param int order: order of the DSA generator used in the signature
:param int secexp: secure exponent (private key) in numeric form
:param hash_func: reference to the same hash function used for generating
hash, like :py:class:`hashlib.sha1`
:param bytes data: hash in binary form of the signing data
:param int retry_gen: how many good 'k' values to skip before returning
:param bytes extra_entropy: additional added data in binary form as per
section-3.6 of rfc6979
:rtype: int
"""
qlen = bit_length(order)
holen = hash_func().digest_size
rolen = (qlen + 7) // 8
bx = (
hmac_compat(number_to_string(secexp, order)),
hmac_compat(bits2octets(data, order)),
hmac_compat(extra_entropy),
)
# Step B
v = b"\x01" * holen
# Step C
k = b"\x00" * holen
# Step D
k = hmac.new(k, digestmod=hash_func)
k.update(v + b"\x00")
for i in bx:
k.update(i)
k = k.digest()
# Step E
v = hmac.new(k, v, hash_func).digest()
# Step F
k = hmac.new(k, digestmod=hash_func)
k.update(v + b"\x01")
for i in bx:
k.update(i)
k = k.digest()
# Step G
v = hmac.new(k, v, hash_func).digest()
# Step H
while True:
# Step H1
t = b""
# Step H2
while len(t) < rolen:
v = hmac.new(k, v, hash_func).digest()
t += v
# Step H3
secret = bits2int(t, qlen)
if 1 <= secret < order:
if retry_gen <= 0:
return secret
retry_gen -= 1
k = hmac.new(k, v + b"\x00", hash_func).digest()
v = hmac.new(k, v, hash_func).digest()

View file

@ -0,0 +1,361 @@
try:
import unittest2 as unittest
except ImportError:
import unittest
import base64
import pytest
from .curves import (
Curve,
NIST256p,
curves,
UnknownCurveError,
PRIME_FIELD_OID,
curve_by_name,
)
from .ellipticcurve import CurveFp, PointJacobi, CurveEdTw
from . import der
from .util import number_to_string
class TestParameterEncoding(unittest.TestCase):
@classmethod
def setUpClass(cls):
# minimal, but with cofactor (excludes seed when compared to
# OpenSSL output)
cls.base64_params = (
"MIHgAgEBMCwGByqGSM49AQECIQD/////AAAAAQAAAAAAAAAAAAAAAP/////////"
"//////zBEBCD/////AAAAAQAAAAAAAAAAAAAAAP///////////////AQgWsY12K"
"o6k+ez671VdpiGvGUdBrDMU7D2O848PifSYEsEQQRrF9Hy4SxCR/i85uVjpEDyd"
"wN9gS3rM6D0oTlF2JjClk/jQuL+Gn+bjufrSnwPnhYrzjNXazFezsu2QGg3v1H1"
"AiEA/////wAAAAD//////////7zm+q2nF56E87nKwvxjJVECAQE="
)
def test_from_pem(self):
pem_params = (
"-----BEGIN EC PARAMETERS-----\n"
"MIHgAgEBMCwGByqGSM49AQECIQD/////AAAAAQAAAAAAAAAAAAAAAP/////////\n"
"//////zBEBCD/////AAAAAQAAAAAAAAAAAAAAAP///////////////AQgWsY12K\n"
"o6k+ez671VdpiGvGUdBrDMU7D2O848PifSYEsEQQRrF9Hy4SxCR/i85uVjpEDyd\n"
"wN9gS3rM6D0oTlF2JjClk/jQuL+Gn+bjufrSnwPnhYrzjNXazFezsu2QGg3v1H1\n"
"AiEA/////wAAAAD//////////7zm+q2nF56E87nKwvxjJVECAQE=\n"
"-----END EC PARAMETERS-----\n"
)
curve = Curve.from_pem(pem_params)
self.assertIs(curve, NIST256p)
def test_from_pem_with_explicit_when_explicit_disabled(self):
pem_params = (
"-----BEGIN EC PARAMETERS-----\n"
"MIHgAgEBMCwGByqGSM49AQECIQD/////AAAAAQAAAAAAAAAAAAAAAP/////////\n"
"//////zBEBCD/////AAAAAQAAAAAAAAAAAAAAAP///////////////AQgWsY12K\n"
"o6k+ez671VdpiGvGUdBrDMU7D2O848PifSYEsEQQRrF9Hy4SxCR/i85uVjpEDyd\n"
"wN9gS3rM6D0oTlF2JjClk/jQuL+Gn+bjufrSnwPnhYrzjNXazFezsu2QGg3v1H1\n"
"AiEA/////wAAAAD//////////7zm+q2nF56E87nKwvxjJVECAQE=\n"
"-----END EC PARAMETERS-----\n"
)
with self.assertRaises(der.UnexpectedDER) as e:
Curve.from_pem(pem_params, ["named_curve"])
self.assertIn("explicit curve parameters not", str(e.exception))
def test_from_pem_with_named_curve_with_named_curve_disabled(self):
pem_params = (
"-----BEGIN EC PARAMETERS-----\n"
"BggqhkjOPQMBBw==\n"
"-----END EC PARAMETERS-----\n"
)
with self.assertRaises(der.UnexpectedDER) as e:
Curve.from_pem(pem_params, ["explicit"])
self.assertIn("named_curve curve parameters not", str(e.exception))
def test_from_pem_with_wrong_header(self):
pem_params = (
"-----BEGIN PARAMETERS-----\n"
"MIHgAgEBMCwGByqGSM49AQECIQD/////AAAAAQAAAAAAAAAAAAAAAP/////////\n"
"//////zBEBCD/////AAAAAQAAAAAAAAAAAAAAAP///////////////AQgWsY12K\n"
"o6k+ez671VdpiGvGUdBrDMU7D2O848PifSYEsEQQRrF9Hy4SxCR/i85uVjpEDyd\n"
"wN9gS3rM6D0oTlF2JjClk/jQuL+Gn+bjufrSnwPnhYrzjNXazFezsu2QGg3v1H1\n"
"AiEA/////wAAAAD//////////7zm+q2nF56E87nKwvxjJVECAQE=\n"
"-----END PARAMETERS-----\n"
)
with self.assertRaises(der.UnexpectedDER) as e:
Curve.from_pem(pem_params)
self.assertIn("PARAMETERS PEM header", str(e.exception))
def test_to_pem(self):
pem_params = (
b"-----BEGIN EC PARAMETERS-----\n"
b"BggqhkjOPQMBBw==\n"
b"-----END EC PARAMETERS-----\n"
)
encoding = NIST256p.to_pem()
self.assertEqual(pem_params, encoding)
def test_compare_with_different_object(self):
self.assertNotEqual(NIST256p, 256)
def test_named_curve_params_der(self):
encoded = NIST256p.to_der()
# just the encoding of the NIST256p OID (prime256v1)
self.assertEqual(b"\x06\x08\x2a\x86\x48\xce\x3d\x03\x01\x07", encoded)
def test_verify_that_default_is_named_curve_der(self):
encoded_default = NIST256p.to_der()
encoded_named = NIST256p.to_der("named_curve")
self.assertEqual(encoded_default, encoded_named)
def test_encoding_to_explicit_params(self):
encoded = NIST256p.to_der("explicit")
self.assertEqual(encoded, bytes(base64.b64decode(self.base64_params)))
def test_encoding_to_unsupported_type(self):
with self.assertRaises(ValueError) as e:
NIST256p.to_der("unsupported")
self.assertIn("Only 'named_curve'", str(e.exception))
def test_encoding_to_explicit_compressed_params(self):
encoded = NIST256p.to_der("explicit", "compressed")
compressed_base_point = (
"MIHAAgEBMCwGByqGSM49AQECIQD/////AAAAAQAAAAAAAAAAAAAAAP//////////"
"/////zBEBCD/////AAAAAQAAAAAAAAAAAAAAAP///////////////AQgWsY12Ko6"
"k+ez671VdpiGvGUdBrDMU7D2O848PifSYEsEIQNrF9Hy4SxCR/i85uVjpEDydwN9"
"gS3rM6D0oTlF2JjClgIhAP////8AAAAA//////////+85vqtpxeehPO5ysL8YyVR"
"AgEB"
)
self.assertEqual(
encoded, bytes(base64.b64decode(compressed_base_point))
)
def test_decoding_explicit_from_openssl(self):
# generated with openssl 1.1.1k using
# openssl ecparam -name P-256 -param_enc explicit -out /tmp/file.pem
p256_explicit = (
"MIH3AgEBMCwGByqGSM49AQECIQD/////AAAAAQAAAAAAAAAAAAAAAP//////////"
"/////zBbBCD/////AAAAAQAAAAAAAAAAAAAAAP///////////////AQgWsY12Ko6"
"k+ez671VdpiGvGUdBrDMU7D2O848PifSYEsDFQDEnTYIhucEk2pmeOETnSa3gZ9+"
"kARBBGsX0fLhLEJH+Lzm5WOkQPJ3A32BLeszoPShOUXYmMKWT+NC4v4af5uO5+tK"
"fA+eFivOM1drMV7Oy7ZAaDe/UfUCIQD/////AAAAAP//////////vOb6racXnoTz"
"ucrC/GMlUQIBAQ=="
)
decoded = Curve.from_der(bytes(base64.b64decode(p256_explicit)))
self.assertEqual(NIST256p, decoded)
def test_decoding_well_known_from_explicit_params(self):
curve = Curve.from_der(bytes(base64.b64decode(self.base64_params)))
self.assertIs(curve, NIST256p)
def test_decoding_with_incorrect_valid_encodings(self):
with self.assertRaises(ValueError) as e:
Curve.from_der(b"", ["explicitCA"])
self.assertIn("Only named_curve", str(e.exception))
def test_compare_curves_with_different_generators(self):
curve_fp = CurveFp(23, 1, 7)
base_a = PointJacobi(curve_fp, 13, 3, 1, 9, generator=True)
base_b = PointJacobi(curve_fp, 1, 20, 1, 9, generator=True)
curve_a = Curve("unknown", curve_fp, base_a, None)
curve_b = Curve("unknown", curve_fp, base_b, None)
self.assertNotEqual(curve_a, curve_b)
def test_default_encode_for_custom_curve(self):
curve_fp = CurveFp(23, 1, 7)
base_point = PointJacobi(curve_fp, 13, 3, 1, 9, generator=True)
curve = Curve("unknown", curve_fp, base_point, None)
encoded = curve.to_der()
decoded = Curve.from_der(encoded)
self.assertEqual(curve, decoded)
expected = "MCECAQEwDAYHKoZIzj0BAQIBFzAGBAEBBAEHBAMEDQMCAQk="
self.assertEqual(encoded, bytes(base64.b64decode(expected)))
def test_named_curve_encode_for_custom_curve(self):
curve_fp = CurveFp(23, 1, 7)
base_point = PointJacobi(curve_fp, 13, 3, 1, 9, generator=True)
curve = Curve("unknown", curve_fp, base_point, None)
with self.assertRaises(UnknownCurveError) as e:
curve.to_der("named_curve")
self.assertIn("Can't encode curve", str(e.exception))
def test_try_decoding_binary_explicit(self):
sect113r1_explicit = (
"MIGRAgEBMBwGByqGSM49AQIwEQIBcQYJKoZIzj0BAgMCAgEJMDkEDwAwiCUMpufH"
"/mSc6Fgg9wQPAOi+5NPiJgdEGIvg6ccjAxUAEOcjqxTWluZ2h1YVF1b+v4/LSakE"
"HwQAnXNhbzX0qxQH1zViwQ8ApSgwJ3lY7oTRMV7TGIYCDwEAAAAAAAAA2czsijnl"
"bwIBAg=="
)
with self.assertRaises(UnknownCurveError) as e:
Curve.from_der(base64.b64decode(sect113r1_explicit))
self.assertIn("Characteristic 2 curves unsupported", str(e.exception))
def test_decode_malformed_named_curve(self):
bad_der = der.encode_oid(*NIST256p.oid) + der.encode_integer(1)
with self.assertRaises(der.UnexpectedDER) as e:
Curve.from_der(bad_der)
self.assertIn("Unexpected data after OID", str(e.exception))
def test_decode_malformed_explicit_garbage_after_ECParam(self):
bad_der = bytes(
base64.b64decode(self.base64_params)
) + der.encode_integer(1)
with self.assertRaises(der.UnexpectedDER) as e:
Curve.from_der(bad_der)
self.assertIn("Unexpected data after ECParameters", str(e.exception))
def test_decode_malformed_unknown_version_number(self):
bad_der = der.encode_sequence(der.encode_integer(2))
with self.assertRaises(der.UnexpectedDER) as e:
Curve.from_der(bad_der)
self.assertIn("Unknown parameter encoding format", str(e.exception))
def test_decode_malformed_unknown_field_type(self):
curve_p = NIST256p.curve.p()
bad_der = der.encode_sequence(
der.encode_integer(1),
der.encode_sequence(
der.encode_oid(1, 2, 3), der.encode_integer(curve_p)
),
der.encode_sequence(
der.encode_octet_string(
number_to_string(NIST256p.curve.a() % curve_p, curve_p)
),
der.encode_octet_string(
number_to_string(NIST256p.curve.b(), curve_p)
),
),
der.encode_octet_string(
NIST256p.generator.to_bytes("uncompressed")
),
der.encode_integer(NIST256p.generator.order()),
)
with self.assertRaises(UnknownCurveError) as e:
Curve.from_der(bad_der)
self.assertIn("Unknown field type: (1, 2, 3)", str(e.exception))
def test_decode_malformed_garbage_after_prime(self):
curve_p = NIST256p.curve.p()
bad_der = der.encode_sequence(
der.encode_integer(1),
der.encode_sequence(
der.encode_oid(*PRIME_FIELD_OID),
der.encode_integer(curve_p),
der.encode_integer(1),
),
der.encode_sequence(
der.encode_octet_string(
number_to_string(NIST256p.curve.a() % curve_p, curve_p)
),
der.encode_octet_string(
number_to_string(NIST256p.curve.b(), curve_p)
),
),
der.encode_octet_string(
NIST256p.generator.to_bytes("uncompressed")
),
der.encode_integer(NIST256p.generator.order()),
)
with self.assertRaises(der.UnexpectedDER) as e:
Curve.from_der(bad_der)
self.assertIn("Prime-p element", str(e.exception))
class TestCurveSearching(unittest.TestCase):
def test_correct_name(self):
c = curve_by_name("NIST256p")
self.assertIs(c, NIST256p)
def test_openssl_name(self):
c = curve_by_name("prime256v1")
self.assertIs(c, NIST256p)
def test_unknown_curve(self):
with self.assertRaises(UnknownCurveError) as e:
curve_by_name("foo bar")
self.assertIn(
"name 'foo bar' unknown, only curves supported: "
"['NIST192p', 'NIST224p'",
str(e.exception),
)
def test_with_None_as_parameter(self):
with self.assertRaises(UnknownCurveError) as e:
curve_by_name(None)
self.assertIn(
"name None unknown, only curves supported: "
"['NIST192p', 'NIST224p'",
str(e.exception),
)
@pytest.mark.parametrize("curve", curves, ids=[i.name for i in curves])
def test_curve_params_encode_decode_named(curve):
ret = Curve.from_der(curve.to_der("named_curve"))
assert curve == ret
@pytest.mark.parametrize("curve", curves, ids=[i.name for i in curves])
def test_curve_params_encode_decode_explicit(curve):
if isinstance(curve.curve, CurveEdTw):
with pytest.raises(UnknownCurveError):
curve.to_der("explicit")
else:
ret = Curve.from_der(curve.to_der("explicit"))
assert curve == ret
@pytest.mark.parametrize("curve", curves, ids=[i.name for i in curves])
def test_curve_params_encode_decode_default(curve):
ret = Curve.from_der(curve.to_der())
assert curve == ret
@pytest.mark.parametrize("curve", curves, ids=[i.name for i in curves])
def test_curve_params_encode_decode_explicit_compressed(curve):
if isinstance(curve.curve, CurveEdTw):
with pytest.raises(UnknownCurveError):
curve.to_der("explicit", "compressed")
else:
ret = Curve.from_der(curve.to_der("explicit", "compressed"))
assert curve == ret

View file

@ -0,0 +1,476 @@
# compatibility with Python 2.6, for that we need unittest2 package,
# which is not available on 3.3 or 3.4
import warnings
from binascii import hexlify
try:
import unittest2 as unittest
except ImportError:
import unittest
from six import b
import hypothesis.strategies as st
from hypothesis import given
import pytest
from ._compat import str_idx_as_int
from .curves import NIST256p, NIST224p
from .der import (
remove_integer,
UnexpectedDER,
read_length,
encode_bitstring,
remove_bitstring,
remove_object,
encode_oid,
remove_constructed,
remove_octet_string,
remove_sequence,
)
class TestRemoveInteger(unittest.TestCase):
# DER requires the integers to be 0-padded only if they would be
# interpreted as negative, check if those errors are detected
def test_non_minimal_encoding(self):
with self.assertRaises(UnexpectedDER):
remove_integer(b("\x02\x02\x00\x01"))
def test_negative_with_high_bit_set(self):
with self.assertRaises(UnexpectedDER):
remove_integer(b("\x02\x01\x80"))
def test_minimal_with_high_bit_set(self):
val, rem = remove_integer(b("\x02\x02\x00\x80"))
self.assertEqual(val, 0x80)
self.assertEqual(rem, b"")
def test_two_zero_bytes_with_high_bit_set(self):
with self.assertRaises(UnexpectedDER):
remove_integer(b("\x02\x03\x00\x00\xff"))
def test_zero_length_integer(self):
with self.assertRaises(UnexpectedDER):
remove_integer(b("\x02\x00"))
def test_empty_string(self):
with self.assertRaises(UnexpectedDER):
remove_integer(b(""))
def test_encoding_of_zero(self):
val, rem = remove_integer(b("\x02\x01\x00"))
self.assertEqual(val, 0)
self.assertEqual(rem, b"")
def test_encoding_of_127(self):
val, rem = remove_integer(b("\x02\x01\x7f"))
self.assertEqual(val, 127)
self.assertEqual(rem, b"")
def test_encoding_of_128(self):
val, rem = remove_integer(b("\x02\x02\x00\x80"))
self.assertEqual(val, 128)
self.assertEqual(rem, b"")
def test_wrong_tag(self):
with self.assertRaises(UnexpectedDER) as e:
remove_integer(b"\x01\x02\x00\x80")
self.assertIn("wanted type 'integer'", str(e.exception))
def test_wrong_length(self):
with self.assertRaises(UnexpectedDER) as e:
remove_integer(b"\x02\x03\x00\x80")
self.assertIn("Length longer", str(e.exception))
class TestReadLength(unittest.TestCase):
# DER requires the lengths between 0 and 127 to be encoded using the short
# form and lengths above that encoded with minimal number of bytes
# necessary
def test_zero_length(self):
self.assertEqual((0, 1), read_length(b("\x00")))
def test_two_byte_zero_length(self):
with self.assertRaises(UnexpectedDER):
read_length(b("\x81\x00"))
def test_two_byte_small_length(self):
with self.assertRaises(UnexpectedDER):
read_length(b("\x81\x7f"))
def test_long_form_with_zero_length(self):
with self.assertRaises(UnexpectedDER):
read_length(b("\x80"))
def test_smallest_two_byte_length(self):
self.assertEqual((128, 2), read_length(b("\x81\x80")))
def test_zero_padded_length(self):
with self.assertRaises(UnexpectedDER):
read_length(b("\x82\x00\x80"))
def test_two_three_byte_length(self):
self.assertEqual((256, 3), read_length(b"\x82\x01\x00"))
def test_empty_string(self):
with self.assertRaises(UnexpectedDER):
read_length(b(""))
def test_length_overflow(self):
with self.assertRaises(UnexpectedDER):
read_length(b("\x83\x01\x00"))
class TestEncodeBitstring(unittest.TestCase):
# DER requires BIT STRINGS to include a number of padding bits in the
# encoded byte string, that padding must be between 0 and 7
def test_old_call_convention(self):
"""This is the old way to use the function."""
warnings.simplefilter("always")
with pytest.warns(DeprecationWarning) as warns:
der = encode_bitstring(b"\x00\xff")
self.assertEqual(len(warns), 1)
self.assertIn(
"unused= needs to be specified", warns[0].message.args[0]
)
self.assertEqual(der, b"\x03\x02\x00\xff")
def test_new_call_convention(self):
"""This is how it should be called now."""
warnings.simplefilter("always")
with pytest.warns(None) as warns:
der = encode_bitstring(b"\xff", 0)
# verify that new call convention doesn't raise Warnings
self.assertEqual(len(warns), 0)
self.assertEqual(der, b"\x03\x02\x00\xff")
def test_implicit_unused_bits(self):
"""
Writing bit string with already included the number of unused bits.
"""
warnings.simplefilter("always")
with pytest.warns(None) as warns:
der = encode_bitstring(b"\x00\xff", None)
# verify that new call convention doesn't raise Warnings
self.assertEqual(len(warns), 0)
self.assertEqual(der, b"\x03\x02\x00\xff")
def test_explicit_unused_bits(self):
der = encode_bitstring(b"\xff\xf0", 4)
self.assertEqual(der, b"\x03\x03\x04\xff\xf0")
def test_empty_string(self):
self.assertEqual(encode_bitstring(b"", 0), b"\x03\x01\x00")
def test_invalid_unused_count(self):
with self.assertRaises(ValueError):
encode_bitstring(b"\xff\x00", 8)
def test_invalid_unused_with_empty_string(self):
with self.assertRaises(ValueError):
encode_bitstring(b"", 1)
def test_non_zero_padding_bits(self):
with self.assertRaises(ValueError):
encode_bitstring(b"\xff", 2)
class TestRemoveBitstring(unittest.TestCase):
def test_old_call_convention(self):
"""This is the old way to call the function."""
warnings.simplefilter("always")
with pytest.warns(DeprecationWarning) as warns:
bits, rest = remove_bitstring(b"\x03\x02\x00\xff")
self.assertEqual(len(warns), 1)
self.assertIn(
"expect_unused= needs to be specified", warns[0].message.args[0]
)
self.assertEqual(bits, b"\x00\xff")
self.assertEqual(rest, b"")
def test_new_call_convention(self):
warnings.simplefilter("always")
with pytest.warns(None) as warns:
bits, rest = remove_bitstring(b"\x03\x02\x00\xff", 0)
self.assertEqual(len(warns), 0)
self.assertEqual(bits, b"\xff")
self.assertEqual(rest, b"")
def test_implicit_unexpected_unused(self):
warnings.simplefilter("always")
with pytest.warns(None) as warns:
bits, rest = remove_bitstring(b"\x03\x02\x00\xff", None)
self.assertEqual(len(warns), 0)
self.assertEqual(bits, (b"\xff", 0))
self.assertEqual(rest, b"")
def test_with_padding(self):
ret, rest = remove_bitstring(b"\x03\x02\x04\xf0", None)
self.assertEqual(ret, (b"\xf0", 4))
self.assertEqual(rest, b"")
def test_not_a_bitstring(self):
with self.assertRaises(UnexpectedDER):
remove_bitstring(b"\x02\x02\x00\xff", None)
def test_empty_encoding(self):
with self.assertRaises(UnexpectedDER):
remove_bitstring(b"\x03\x00", None)
def test_empty_string(self):
with self.assertRaises(UnexpectedDER):
remove_bitstring(b"", None)
def test_no_length(self):
with self.assertRaises(UnexpectedDER):
remove_bitstring(b"\x03", None)
def test_unexpected_number_of_unused_bits(self):
with self.assertRaises(UnexpectedDER):
remove_bitstring(b"\x03\x02\x00\xff", 1)
def test_invalid_encoding_of_unused_bits(self):
with self.assertRaises(UnexpectedDER):
remove_bitstring(b"\x03\x03\x08\xff\x00", None)
def test_invalid_encoding_of_empty_string(self):
with self.assertRaises(UnexpectedDER):
remove_bitstring(b"\x03\x01\x01", None)
def test_invalid_padding_bits(self):
with self.assertRaises(UnexpectedDER):
remove_bitstring(b"\x03\x02\x01\xff", None)
class TestStrIdxAsInt(unittest.TestCase):
def test_str(self):
self.assertEqual(115, str_idx_as_int("str", 0))
def test_bytes(self):
self.assertEqual(115, str_idx_as_int(b"str", 0))
def test_bytearray(self):
self.assertEqual(115, str_idx_as_int(bytearray(b"str"), 0))
class TestEncodeOid(unittest.TestCase):
def test_pub_key_oid(self):
oid_ecPublicKey = encode_oid(1, 2, 840, 10045, 2, 1)
self.assertEqual(hexlify(oid_ecPublicKey), b("06072a8648ce3d0201"))
def test_nist224p_oid(self):
self.assertEqual(hexlify(NIST224p.encoded_oid), b("06052b81040021"))
def test_nist256p_oid(self):
self.assertEqual(
hexlify(NIST256p.encoded_oid), b"06082a8648ce3d030107"
)
def test_large_second_subid(self):
# from X.690, section 8.19.5
oid = encode_oid(2, 999, 3)
self.assertEqual(oid, b"\x06\x03\x88\x37\x03")
def test_with_two_subids(self):
oid = encode_oid(2, 999)
self.assertEqual(oid, b"\x06\x02\x88\x37")
def test_zero_zero(self):
oid = encode_oid(0, 0)
self.assertEqual(oid, b"\x06\x01\x00")
def test_with_wrong_types(self):
with self.assertRaises((TypeError, AssertionError)):
encode_oid(0, None)
def test_with_small_first_large_second(self):
with self.assertRaises(AssertionError):
encode_oid(1, 40)
def test_small_first_max_second(self):
oid = encode_oid(1, 39)
self.assertEqual(oid, b"\x06\x01\x4f")
def test_with_invalid_first(self):
with self.assertRaises(AssertionError):
encode_oid(3, 39)
class TestRemoveObject(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.oid_ecPublicKey = encode_oid(1, 2, 840, 10045, 2, 1)
def test_pub_key_oid(self):
oid, rest = remove_object(self.oid_ecPublicKey)
self.assertEqual(rest, b"")
self.assertEqual(oid, (1, 2, 840, 10045, 2, 1))
def test_with_extra_bytes(self):
oid, rest = remove_object(self.oid_ecPublicKey + b"more")
self.assertEqual(rest, b"more")
self.assertEqual(oid, (1, 2, 840, 10045, 2, 1))
def test_with_large_second_subid(self):
# from X.690, section 8.19.5
oid, rest = remove_object(b"\x06\x03\x88\x37\x03")
self.assertEqual(rest, b"")
self.assertEqual(oid, (2, 999, 3))
def test_with_padded_first_subid(self):
with self.assertRaises(UnexpectedDER):
remove_object(b"\x06\x02\x80\x00")
def test_with_padded_second_subid(self):
with self.assertRaises(UnexpectedDER):
remove_object(b"\x06\x04\x88\x37\x80\x01")
def test_with_missing_last_byte_of_multi_byte(self):
with self.assertRaises(UnexpectedDER):
remove_object(b"\x06\x03\x88\x37\x83")
def test_with_two_subids(self):
oid, rest = remove_object(b"\x06\x02\x88\x37")
self.assertEqual(rest, b"")
self.assertEqual(oid, (2, 999))
def test_zero_zero(self):
oid, rest = remove_object(b"\x06\x01\x00")
self.assertEqual(rest, b"")
self.assertEqual(oid, (0, 0))
def test_empty_string(self):
with self.assertRaises(UnexpectedDER):
remove_object(b"")
def test_missing_length(self):
with self.assertRaises(UnexpectedDER):
remove_object(b"\x06")
def test_empty_oid(self):
with self.assertRaises(UnexpectedDER):
remove_object(b"\x06\x00")
def test_empty_oid_overflow(self):
with self.assertRaises(UnexpectedDER):
remove_object(b"\x06\x01")
def test_with_wrong_type(self):
with self.assertRaises(UnexpectedDER):
remove_object(b"\x04\x02\x88\x37")
def test_with_too_long_length(self):
with self.assertRaises(UnexpectedDER):
remove_object(b"\x06\x03\x88\x37")
class TestRemoveConstructed(unittest.TestCase):
def test_simple(self):
data = b"\xa1\x02\xff\xaa"
tag, body, rest = remove_constructed(data)
self.assertEqual(tag, 0x01)
self.assertEqual(body, b"\xff\xaa")
self.assertEqual(rest, b"")
def test_with_malformed_tag(self):
data = b"\x01\x02\xff\xaa"
with self.assertRaises(UnexpectedDER) as e:
remove_constructed(data)
self.assertIn("constructed tag", str(e.exception))
class TestRemoveOctetString(unittest.TestCase):
def test_simple(self):
data = b"\x04\x03\xaa\xbb\xcc"
body, rest = remove_octet_string(data)
self.assertEqual(body, b"\xaa\xbb\xcc")
self.assertEqual(rest, b"")
def test_with_malformed_tag(self):
data = b"\x03\x03\xaa\xbb\xcc"
with self.assertRaises(UnexpectedDER) as e:
remove_octet_string(data)
self.assertIn("octetstring", str(e.exception))
class TestRemoveSequence(unittest.TestCase):
def test_simple(self):
data = b"\x30\x02\xff\xaa"
body, rest = remove_sequence(data)
self.assertEqual(body, b"\xff\xaa")
self.assertEqual(rest, b"")
def test_with_empty_string(self):
with self.assertRaises(UnexpectedDER) as e:
remove_sequence(b"")
self.assertIn("Empty string", str(e.exception))
def test_with_wrong_tag(self):
data = b"\x20\x02\xff\xaa"
with self.assertRaises(UnexpectedDER) as e:
remove_sequence(data)
self.assertIn("wanted type 'sequence'", str(e.exception))
def test_with_wrong_length(self):
data = b"\x30\x03\xff\xaa"
with self.assertRaises(UnexpectedDER) as e:
remove_sequence(data)
self.assertIn("Length longer", str(e.exception))
@st.composite
def st_oid(draw, max_value=2**512, max_size=50):
"""
Hypothesis strategy that returns valid OBJECT IDENTIFIERs as tuples
:param max_value: maximum value of any single sub-identifier
:param max_size: maximum length of the generated OID
"""
first = draw(st.integers(min_value=0, max_value=2))
if first < 2:
second = draw(st.integers(min_value=0, max_value=39))
else:
second = draw(st.integers(min_value=0, max_value=max_value))
rest = draw(
st.lists(
st.integers(min_value=0, max_value=max_value), max_size=max_size
)
)
return (first, second) + tuple(rest)
@given(st_oid())
def test_oids(ids):
encoded_oid = encode_oid(*ids)
decoded_oid, rest = remove_object(encoded_oid)
assert rest == b""
assert decoded_oid == ids

View file

@ -0,0 +1,441 @@
import os
import shutil
import subprocess
import pytest
from binascii import unhexlify
try:
import unittest2 as unittest
except ImportError:
import unittest
from .curves import (
NIST192p,
NIST224p,
NIST256p,
NIST384p,
NIST521p,
BRAINPOOLP160r1,
)
from .curves import curves
from .ecdh import (
ECDH,
InvalidCurveError,
InvalidSharedSecretError,
NoKeyError,
NoCurveError,
)
from .keys import SigningKey, VerifyingKey
from .ellipticcurve import CurveEdTw
@pytest.mark.parametrize(
"vcurve",
curves,
ids=[curve.name for curve in curves],
)
def test_ecdh_each(vcurve):
if isinstance(vcurve.curve, CurveEdTw):
pytest.skip("ECDH is not supported for Edwards curves")
ecdh1 = ECDH(curve=vcurve)
ecdh2 = ECDH(curve=vcurve)
ecdh2.generate_private_key()
ecdh1.load_received_public_key(ecdh2.get_public_key())
ecdh2.load_received_public_key(ecdh1.generate_private_key())
secret1 = ecdh1.generate_sharedsecret_bytes()
secret2 = ecdh2.generate_sharedsecret_bytes()
assert secret1 == secret2
def test_ecdh_both_keys_present():
key1 = SigningKey.generate(BRAINPOOLP160r1)
key2 = SigningKey.generate(BRAINPOOLP160r1)
ecdh1 = ECDH(BRAINPOOLP160r1, key1, key2.verifying_key)
ecdh2 = ECDH(private_key=key2, public_key=key1.verifying_key)
secret1 = ecdh1.generate_sharedsecret_bytes()
secret2 = ecdh2.generate_sharedsecret_bytes()
assert secret1 == secret2
def test_ecdh_no_public_key():
ecdh1 = ECDH(curve=NIST192p)
with pytest.raises(NoKeyError):
ecdh1.generate_sharedsecret_bytes()
ecdh1.generate_private_key()
with pytest.raises(NoKeyError):
ecdh1.generate_sharedsecret_bytes()
class TestECDH(unittest.TestCase):
def test_load_key_from_wrong_curve(self):
ecdh1 = ECDH()
ecdh1.set_curve(NIST192p)
key1 = SigningKey.generate(BRAINPOOLP160r1)
with self.assertRaises(InvalidCurveError) as e:
ecdh1.load_private_key(key1)
self.assertIn("Curve mismatch", str(e.exception))
def test_generate_without_curve(self):
ecdh1 = ECDH()
with self.assertRaises(NoCurveError) as e:
ecdh1.generate_private_key()
self.assertIn("Curve must be set", str(e.exception))
def test_load_bytes_without_curve_set(self):
ecdh1 = ECDH()
with self.assertRaises(NoCurveError) as e:
ecdh1.load_private_key_bytes(b"\x01" * 32)
self.assertIn("Curve must be set", str(e.exception))
def test_set_curve_from_received_public_key(self):
ecdh1 = ECDH()
key1 = SigningKey.generate(BRAINPOOLP160r1)
ecdh1.load_received_public_key(key1.verifying_key)
self.assertEqual(ecdh1.curve, BRAINPOOLP160r1)
def test_ecdh_wrong_public_key_curve():
ecdh1 = ECDH(curve=NIST192p)
ecdh1.generate_private_key()
ecdh2 = ECDH(curve=NIST256p)
ecdh2.generate_private_key()
with pytest.raises(InvalidCurveError):
ecdh1.load_received_public_key(ecdh2.get_public_key())
with pytest.raises(InvalidCurveError):
ecdh2.load_received_public_key(ecdh1.get_public_key())
ecdh1.public_key = ecdh2.get_public_key()
ecdh2.public_key = ecdh1.get_public_key()
with pytest.raises(InvalidCurveError):
ecdh1.generate_sharedsecret_bytes()
with pytest.raises(InvalidCurveError):
ecdh2.generate_sharedsecret_bytes()
def test_ecdh_invalid_shared_secret_curve():
ecdh1 = ECDH(curve=NIST256p)
ecdh1.generate_private_key()
ecdh1.load_received_public_key(
SigningKey.generate(NIST256p).get_verifying_key()
)
ecdh1.private_key.privkey.secret_multiplier = ecdh1.private_key.curve.order
with pytest.raises(InvalidSharedSecretError):
ecdh1.generate_sharedsecret_bytes()
# https://github.com/scogliani/ecc-test-vectors/blob/master/ecdh_kat/secp192r1.txt
# https://github.com/scogliani/ecc-test-vectors/blob/master/ecdh_kat/secp256r1.txt
# https://github.com/coruus/nist-testvectors/blob/master/csrc.nist.gov/groups/STM/cavp/documents/components/ecccdhtestvectors/KAS_ECC_CDH_PrimitiveTest.txt
@pytest.mark.parametrize(
"curve,privatekey,pubkey,secret",
[
pytest.param(
NIST192p,
"f17d3fea367b74d340851ca4270dcb24c271f445bed9d527",
"42ea6dd9969dd2a61fea1aac7f8e98edcc896c6e55857cc0"
"dfbe5d7c61fac88b11811bde328e8a0d12bf01a9d204b523",
"803d8ab2e5b6e6fca715737c3a82f7ce3c783124f6d51cd0",
id="NIST192p-1",
),
pytest.param(
NIST192p,
"56e853349d96fe4c442448dacb7cf92bb7a95dcf574a9bd5",
"deb5712fa027ac8d2f22c455ccb73a91e17b6512b5e030e7"
"7e2690a02cc9b28708431a29fb54b87b1f0c14e011ac2125",
"c208847568b98835d7312cef1f97f7aa298283152313c29d",
id="NIST192p-2",
),
pytest.param(
NIST192p,
"c6ef61fe12e80bf56f2d3f7d0bb757394519906d55500949",
"4edaa8efc5a0f40f843663ec5815e7762dddc008e663c20f"
"0a9f8dc67a3e60ef6d64b522185d03df1fc0adfd42478279",
"87229107047a3b611920d6e3b2c0c89bea4f49412260b8dd",
id="NIST192p-3",
),
pytest.param(
NIST192p,
"e6747b9c23ba7044f38ff7e62c35e4038920f5a0163d3cda",
"8887c276edeed3e9e866b46d58d895c73fbd80b63e382e88"
"04c5097ba6645e16206cfb70f7052655947dd44a17f1f9d5",
"eec0bed8fc55e1feddc82158fd6dc0d48a4d796aaf47d46c",
id="NIST192p-4",
),
pytest.param(
NIST192p,
"beabedd0154a1afcfc85d52181c10f5eb47adc51f655047d",
"0d045f30254adc1fcefa8a5b1f31bf4e739dd327cd18d594"
"542c314e41427c08278a08ce8d7305f3b5b849c72d8aff73",
"716e743b1b37a2cd8479f0a3d5a74c10ba2599be18d7e2f4",
id="NIST192p-5",
),
pytest.param(
NIST192p,
"cf70354226667321d6e2baf40999e2fd74c7a0f793fa8699",
"fb35ca20d2e96665c51b98e8f6eb3d79113508d8bccd4516"
"368eec0d5bfb847721df6aaff0e5d48c444f74bf9cd8a5a7",
"f67053b934459985a315cb017bf0302891798d45d0e19508",
id="NIST192p-6",
),
pytest.param(
NIST224p,
"8346a60fc6f293ca5a0d2af68ba71d1dd389e5e40837942df3e43cbd",
"af33cd0629bc7e996320a3f40368f74de8704fa37b8fab69abaae280"
"882092ccbba7930f419a8a4f9bb16978bbc3838729992559a6f2e2d7",
"7d96f9a3bd3c05cf5cc37feb8b9d5209d5c2597464dec3e9983743e8",
id="NIST224p",
),
pytest.param(
NIST256p,
"7d7dc5f71eb29ddaf80d6214632eeae03d9058af1fb6d22ed80badb62bc1a534",
"700c48f77f56584c5cc632ca65640db91b6bacce3a4df6b42ce7cc838833d287"
"db71e509e3fd9b060ddb20ba5c51dcc5948d46fbf640dfe0441782cab85fa4ac",
"46fc62106420ff012e54a434fbdd2d25ccc5852060561e68040dd7778997bd7b",
id="NIST256p-1",
),
pytest.param(
NIST256p,
"38f65d6dce47676044d58ce5139582d568f64bb16098d179dbab07741dd5caf5",
"809f04289c64348c01515eb03d5ce7ac1a8cb9498f5caa50197e58d43a86a7ae"
"b29d84e811197f25eba8f5194092cb6ff440e26d4421011372461f579271cda3",
"057d636096cb80b67a8c038c890e887d1adfa4195e9b3ce241c8a778c59cda67",
id="NIST256p-2",
),
pytest.param(
NIST256p,
"1accfaf1b97712b85a6f54b148985a1bdc4c9bec0bd258cad4b3d603f49f32c8",
"a2339c12d4a03c33546de533268b4ad667debf458b464d77443636440ee7fec3"
"ef48a3ab26e20220bcda2c1851076839dae88eae962869a497bf73cb66faf536",
"2d457b78b4614132477618a5b077965ec90730a8c81a1c75d6d4ec68005d67ec",
id="NIST256p-3",
),
pytest.param(
NIST256p,
"207c43a79bfee03db6f4b944f53d2fb76cc49ef1c9c4d34d51b6c65c4db6932d",
"df3989b9fa55495719b3cf46dccd28b5153f7808191dd518eff0c3cff2b705ed"
"422294ff46003429d739a33206c8752552c8ba54a270defc06e221e0feaf6ac4",
"96441259534b80f6aee3d287a6bb17b5094dd4277d9e294f8fe73e48bf2a0024",
id="NIST256p-4",
),
pytest.param(
NIST256p,
"59137e38152350b195c9718d39673d519838055ad908dd4757152fd8255c09bf",
"41192d2813e79561e6a1d6f53c8bc1a433a199c835e141b05a74a97b0faeb922"
"1af98cc45e98a7e041b01cf35f462b7562281351c8ebf3ffa02e33a0722a1328",
"19d44c8d63e8e8dd12c22a87b8cd4ece27acdde04dbf47f7f27537a6999a8e62",
id="NIST256p-5",
),
pytest.param(
NIST256p,
"f5f8e0174610a661277979b58ce5c90fee6c9b3bb346a90a7196255e40b132ef",
"33e82092a0f1fb38f5649d5867fba28b503172b7035574bf8e5b7100a3052792"
"f2cf6b601e0a05945e335550bf648d782f46186c772c0f20d3cd0d6b8ca14b2f",
"664e45d5bba4ac931cd65d52017e4be9b19a515f669bea4703542a2c525cd3d3",
id="NIST256p-6",
),
pytest.param(
NIST384p,
"3cc3122a68f0d95027ad38c067916ba0eb8c38894d22e1b1"
"5618b6818a661774ad463b205da88cf699ab4d43c9cf98a1",
"a7c76b970c3b5fe8b05d2838ae04ab47697b9eaf52e76459"
"2efda27fe7513272734466b400091adbf2d68c58e0c50066"
"ac68f19f2e1cb879aed43a9969b91a0839c4c38a49749b66"
"1efedf243451915ed0905a32b060992b468c64766fc8437a",
"5f9d29dc5e31a163060356213669c8ce132e22f57c9a04f4"
"0ba7fcead493b457e5621e766c40a2e3d4d6a04b25e533f1",
id="NIST384p",
),
pytest.param(
NIST521p,
"017eecc07ab4b329068fba65e56a1f8890aa935e57134ae0ffcce802735151f4ea"
"c6564f6ee9974c5e6887a1fefee5743ae2241bfeb95d5ce31ddcb6f9edb4d6fc47",
"00685a48e86c79f0f0875f7bc18d25eb5fc8c0b07e5da4f4370f3a949034085433"
"4b1e1b87fa395464c60626124a4e70d0f785601d37c09870ebf176666877a2046d"
"01ba52c56fc8776d9e8f5db4f0cc27636d0b741bbe05400697942e80b739884a83"
"bde99e0f6716939e632bc8986fa18dccd443a348b6c3e522497955a4f3c302f676",
"005fc70477c3e63bc3954bd0df3ea0d1f41ee21746ed95fc5e1fdf90930d5e1366"
"72d72cc770742d1711c3c3a4c334a0ad9759436a4d3c5bf6e74b9578fac148c831",
id="NIST521p",
),
],
)
def test_ecdh_NIST(curve, privatekey, pubkey, secret):
ecdh = ECDH(curve=curve)
ecdh.load_private_key_bytes(unhexlify(privatekey))
ecdh.load_received_public_key_bytes(unhexlify(pubkey))
sharedsecret = ecdh.generate_sharedsecret_bytes()
assert sharedsecret == unhexlify(secret)
pem_local_private_key = (
"-----BEGIN EC PRIVATE KEY-----\n"
"MF8CAQEEGF7IQgvW75JSqULpiQQ8op9WH6Uldw6xxaAKBggqhkjOPQMBAaE0AzIA\n"
"BLiBd9CE7xf15FY5QIAoNg+fWbSk1yZOYtoGUdzkejWkxbRc9RWTQjqLVXucIJnz\n"
"bA==\n"
"-----END EC PRIVATE KEY-----\n"
)
der_local_private_key = (
"305f02010104185ec8420bd6ef9252a942e989043ca29f561fa525770eb1c5a00a06082a864"
"8ce3d030101a13403320004b88177d084ef17f5e45639408028360f9f59b4a4d7264e62da06"
"51dce47a35a4c5b45cf51593423a8b557b9c2099f36c"
)
pem_remote_public_key = (
"-----BEGIN PUBLIC KEY-----\n"
"MEkwEwYHKoZIzj0CAQYIKoZIzj0DAQEDMgAEuIF30ITvF/XkVjlAgCg2D59ZtKTX\n"
"Jk5i2gZR3OR6NaTFtFz1FZNCOotVe5wgmfNs\n"
"-----END PUBLIC KEY-----\n"
)
der_remote_public_key = (
"3049301306072a8648ce3d020106082a8648ce3d03010103320004b88177d084ef17f5e4563"
"9408028360f9f59b4a4d7264e62da0651dce47a35a4c5b45cf51593423a8b557b9c2099f36c"
)
gshared_secret = "8f457e34982478d1c34b9cd2d0c15911b72dd60d869e2cea"
def test_ecdh_pem():
ecdh = ECDH()
ecdh.load_private_key_pem(pem_local_private_key)
ecdh.load_received_public_key_pem(pem_remote_public_key)
sharedsecret = ecdh.generate_sharedsecret_bytes()
assert sharedsecret == unhexlify(gshared_secret)
def test_ecdh_der():
ecdh = ECDH()
ecdh.load_private_key_der(unhexlify(der_local_private_key))
ecdh.load_received_public_key_der(unhexlify(der_remote_public_key))
sharedsecret = ecdh.generate_sharedsecret_bytes()
assert sharedsecret == unhexlify(gshared_secret)
# Exception classes used by run_openssl.
class RunOpenSslError(Exception):
pass
def run_openssl(cmd):
OPENSSL = "openssl"
p = subprocess.Popen(
[OPENSSL] + cmd.split(),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
stdout, ignored = p.communicate()
if p.returncode != 0:
raise RunOpenSslError(
"cmd '%s %s' failed: rc=%s, stdout/err was %s"
% (OPENSSL, cmd, p.returncode, stdout)
)
return stdout.decode()
OPENSSL_SUPPORTED_CURVES = set(
c.split(":")[0].strip()
for c in run_openssl("ecparam -list_curves").split("\n")
)
@pytest.mark.parametrize(
"vcurve",
curves,
ids=[curve.name for curve in curves],
)
def test_ecdh_with_openssl(vcurve):
if isinstance(vcurve.curve, CurveEdTw):
pytest.skip("Edwards curves are not supported for ECDH")
assert vcurve.openssl_name
if vcurve.openssl_name not in OPENSSL_SUPPORTED_CURVES:
pytest.skip("system openssl does not support " + vcurve.openssl_name)
try:
hlp = run_openssl("pkeyutl -help")
if hlp.find("-derive") == 0: # pragma: no cover
pytest.skip("system openssl does not support `pkeyutl -derive`")
except RunOpenSslError: # pragma: no cover
pytest.skip("system openssl could not be executed")
if os.path.isdir("t"): # pragma: no branch
shutil.rmtree("t")
os.mkdir("t")
run_openssl(
"ecparam -name %s -genkey -out t/privkey1.pem" % vcurve.openssl_name
)
run_openssl(
"ecparam -name %s -genkey -out t/privkey2.pem" % vcurve.openssl_name
)
run_openssl("ec -in t/privkey1.pem -pubout -out t/pubkey1.pem")
ecdh1 = ECDH(curve=vcurve)
ecdh2 = ECDH(curve=vcurve)
with open("t/privkey1.pem") as e:
key = e.read()
ecdh1.load_private_key_pem(key)
with open("t/privkey2.pem") as e:
key = e.read()
ecdh2.load_private_key_pem(key)
with open("t/pubkey1.pem") as e:
key = e.read()
vk1 = VerifyingKey.from_pem(key)
assert vk1.to_string() == ecdh1.get_public_key().to_string()
vk2 = ecdh2.get_public_key()
with open("t/pubkey2.pem", "wb") as e:
e.write(vk2.to_pem())
ecdh1.load_received_public_key(vk2)
ecdh2.load_received_public_key(vk1)
secret1 = ecdh1.generate_sharedsecret_bytes()
secret2 = ecdh2.generate_sharedsecret_bytes()
assert secret1 == secret2
run_openssl(
"pkeyutl -derive -inkey t/privkey1.pem -peerkey t/pubkey2.pem -out t/secret1"
)
run_openssl(
"pkeyutl -derive -inkey t/privkey2.pem -peerkey t/pubkey1.pem -out t/secret2"
)
with open("t/secret1", "rb") as e:
ssl_secret1 = e.read()
with open("t/secret1", "rb") as e:
ssl_secret2 = e.read()
assert len(ssl_secret1) == vk1.curve.verifying_key_length // 2
assert len(secret1) == vk1.curve.verifying_key_length // 2
assert ssl_secret1 == ssl_secret2
assert secret1 == ssl_secret1

View file

@ -0,0 +1,661 @@
from __future__ import print_function
import sys
import hypothesis.strategies as st
from hypothesis import given, settings, note, example
try:
import unittest2 as unittest
except ImportError:
import unittest
import pytest
from .ecdsa import (
Private_key,
Public_key,
Signature,
generator_192,
digest_integer,
ellipticcurve,
point_is_valid,
generator_224,
generator_256,
generator_384,
generator_521,
generator_secp256k1,
curve_192,
InvalidPointError,
curve_112r2,
generator_112r2,
int_to_string,
)
HYP_SETTINGS = {}
# old hypothesis doesn't have the "deadline" setting
if sys.version_info > (2, 7): # pragma: no branch
# SEC521p is slow, allow long execution for it
HYP_SETTINGS["deadline"] = 5000
class TestP192FromX9_62(unittest.TestCase):
"""Check test vectors from X9.62"""
@classmethod
def setUpClass(cls):
cls.d = 651056770906015076056810763456358567190100156695615665659
cls.Q = cls.d * generator_192
cls.k = 6140507067065001063065065565667405560006161556565665656654
cls.R = cls.k * generator_192
cls.msg = 968236873715988614170569073515315707566766479517
cls.pubk = Public_key(generator_192, generator_192 * cls.d)
cls.privk = Private_key(cls.pubk, cls.d)
cls.sig = cls.privk.sign(cls.msg, cls.k)
def test_point_multiplication(self):
assert self.Q.x() == 0x62B12D60690CDCF330BABAB6E69763B471F994DD702D16A5
def test_point_multiplication_2(self):
assert self.R.x() == 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD
assert self.R.y() == 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835
def test_mult_and_addition(self):
u1 = 2563697409189434185194736134579731015366492496392189760599
u2 = 6266643813348617967186477710235785849136406323338782220568
temp = u1 * generator_192 + u2 * self.Q
assert temp.x() == 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD
assert temp.y() == 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835
def test_signature(self):
r, s = self.sig.r, self.sig.s
assert r == 3342403536405981729393488334694600415596881826869351677613
assert s == 5735822328888155254683894997897571951568553642892029982342
def test_verification(self):
assert self.pubk.verifies(self.msg, self.sig)
def test_rejection(self):
assert not self.pubk.verifies(self.msg - 1, self.sig)
class TestPublicKey(unittest.TestCase):
def test_equality_public_keys(self):
gen = generator_192
x = 0xC58D61F88D905293BCD4CD0080BCB1B7F811F2FFA41979F6
y = 0x8804DC7A7C4C7F8B5D437F5156F3312CA7D6DE8A0E11867F
point = ellipticcurve.Point(gen.curve(), x, y)
pub_key1 = Public_key(gen, point)
pub_key2 = Public_key(gen, point)
self.assertEqual(pub_key1, pub_key2)
def test_inequality_public_key(self):
gen = generator_192
x1 = 0xC58D61F88D905293BCD4CD0080BCB1B7F811F2FFA41979F6
y1 = 0x8804DC7A7C4C7F8B5D437F5156F3312CA7D6DE8A0E11867F
point1 = ellipticcurve.Point(gen.curve(), x1, y1)
x2 = 0x6A223D00BD22C52833409A163E057E5B5DA1DEF2A197DD15
y2 = 0x7B482604199367F1F303F9EF627F922F97023E90EAE08ABF
point2 = ellipticcurve.Point(gen.curve(), x2, y2)
pub_key1 = Public_key(gen, point1)
pub_key2 = Public_key(gen, point2)
self.assertNotEqual(pub_key1, pub_key2)
def test_inequality_different_curves(self):
gen = generator_192
x1 = 0xC58D61F88D905293BCD4CD0080BCB1B7F811F2FFA41979F6
y1 = 0x8804DC7A7C4C7F8B5D437F5156F3312CA7D6DE8A0E11867F
point1 = ellipticcurve.Point(gen.curve(), x1, y1)
x2 = 0x722BA0FB6B8FC8898A4C6AB49E66
y2 = 0x2B7344BB57A7ABC8CA0F1A398C7D
point2 = ellipticcurve.Point(generator_112r2.curve(), x2, y2)
pub_key1 = Public_key(gen, point1)
pub_key2 = Public_key(generator_112r2, point2)
self.assertNotEqual(pub_key1, pub_key2)
def test_inequality_public_key_not_implemented(self):
gen = generator_192
x = 0xC58D61F88D905293BCD4CD0080BCB1B7F811F2FFA41979F6
y = 0x8804DC7A7C4C7F8B5D437F5156F3312CA7D6DE8A0E11867F
point = ellipticcurve.Point(gen.curve(), x, y)
pub_key = Public_key(gen, point)
self.assertNotEqual(pub_key, None)
def test_public_key_with_generator_without_order(self):
gen = ellipticcurve.PointJacobi(
generator_192.curve(), generator_192.x(), generator_192.y(), 1
)
x = 0xC58D61F88D905293BCD4CD0080BCB1B7F811F2FFA41979F6
y = 0x8804DC7A7C4C7F8B5D437F5156F3312CA7D6DE8A0E11867F
point = ellipticcurve.Point(gen.curve(), x, y)
with self.assertRaises(InvalidPointError) as e:
Public_key(gen, point)
self.assertIn("Generator point must have order", str(e.exception))
def test_public_point_on_curve_not_scalar_multiple_of_base_point(self):
x = 2
y = 0xBE6AA4938EF7CFE6FE29595B6B00
# we need a curve with cofactor != 1
point = ellipticcurve.PointJacobi(curve_112r2, x, y, 1)
self.assertTrue(curve_112r2.contains_point(x, y))
with self.assertRaises(InvalidPointError) as e:
Public_key(generator_112r2, point)
self.assertIn("Generator point order", str(e.exception))
def test_point_is_valid_with_not_scalar_multiple_of_base_point(self):
x = 2
y = 0xBE6AA4938EF7CFE6FE29595B6B00
self.assertFalse(point_is_valid(generator_112r2, x, y))
# the tests to verify the extensiveness of tests in ecdsa.ecdsa
# if PointJacobi gets modified to calculate the x and y mod p the tests
# below will need to use a fake/mock object
def test_invalid_point_x_negative(self):
pt = ellipticcurve.PointJacobi(curve_192, -1, 0, 1)
with self.assertRaises(InvalidPointError) as e:
Public_key(generator_192, pt)
self.assertIn("The public point has x or y", str(e.exception))
def test_invalid_point_x_equal_p(self):
pt = ellipticcurve.PointJacobi(curve_192, curve_192.p(), 0, 1)
with self.assertRaises(InvalidPointError) as e:
Public_key(generator_192, pt)
self.assertIn("The public point has x or y", str(e.exception))
def test_invalid_point_y_negative(self):
pt = ellipticcurve.PointJacobi(curve_192, 0, -1, 1)
with self.assertRaises(InvalidPointError) as e:
Public_key(generator_192, pt)
self.assertIn("The public point has x or y", str(e.exception))
def test_invalid_point_y_equal_p(self):
pt = ellipticcurve.PointJacobi(curve_192, 0, curve_192.p(), 1)
with self.assertRaises(InvalidPointError) as e:
Public_key(generator_192, pt)
self.assertIn("The public point has x or y", str(e.exception))
class TestPublicKeyVerifies(unittest.TestCase):
# test all the different ways that a signature can be publicly invalid
@classmethod
def setUpClass(cls):
gen = generator_192
x = 0xC58D61F88D905293BCD4CD0080BCB1B7F811F2FFA41979F6
y = 0x8804DC7A7C4C7F8B5D437F5156F3312CA7D6DE8A0E11867F
point = ellipticcurve.Point(gen.curve(), x, y)
cls.pub_key = Public_key(gen, point)
def test_sig_with_r_zero(self):
sig = Signature(0, 1)
self.assertFalse(self.pub_key.verifies(1, sig))
def test_sig_with_r_order(self):
sig = Signature(generator_192.order(), 1)
self.assertFalse(self.pub_key.verifies(1, sig))
def test_sig_with_s_zero(self):
sig = Signature(1, 0)
self.assertFalse(self.pub_key.verifies(1, sig))
def test_sig_with_s_order(self):
sig = Signature(1, generator_192.order())
self.assertFalse(self.pub_key.verifies(1, sig))
class TestPrivateKey(unittest.TestCase):
@classmethod
def setUpClass(cls):
gen = generator_192
x = 0xC58D61F88D905293BCD4CD0080BCB1B7F811F2FFA41979F6
y = 0x8804DC7A7C4C7F8B5D437F5156F3312CA7D6DE8A0E11867F
point = ellipticcurve.Point(gen.curve(), x, y)
cls.pub_key = Public_key(gen, point)
def test_equality_private_keys(self):
pr_key1 = Private_key(self.pub_key, 100)
pr_key2 = Private_key(self.pub_key, 100)
self.assertEqual(pr_key1, pr_key2)
def test_inequality_private_keys(self):
pr_key1 = Private_key(self.pub_key, 100)
pr_key2 = Private_key(self.pub_key, 200)
self.assertNotEqual(pr_key1, pr_key2)
def test_inequality_private_keys_not_implemented(self):
pr_key = Private_key(self.pub_key, 100)
self.assertNotEqual(pr_key, None)
# Testing point validity, as per ECDSAVS.pdf B.2.2:
P192_POINTS = [
(
generator_192,
0xCD6D0F029A023E9AACA429615B8F577ABEE685D8257CC83A,
0x00019C410987680E9FB6C0B6ECC01D9A2647C8BAE27721BACDFC,
False,
),
(
generator_192,
0x00017F2FCE203639E9EAF9FB50B81FC32776B30E3B02AF16C73B,
0x95DA95C5E72DD48E229D4748D4EEE658A9A54111B23B2ADB,
False,
),
(
generator_192,
0x4F77F8BC7FCCBADD5760F4938746D5F253EE2168C1CF2792,
0x000147156FF824D131629739817EDB197717C41AAB5C2A70F0F6,
False,
),
(
generator_192,
0xC58D61F88D905293BCD4CD0080BCB1B7F811F2FFA41979F6,
0x8804DC7A7C4C7F8B5D437F5156F3312CA7D6DE8A0E11867F,
True,
),
(
generator_192,
0xCDF56C1AA3D8AFC53C521ADF3FFB96734A6A630A4A5B5A70,
0x97C1C44A5FB229007B5EC5D25F7413D170068FFD023CAA4E,
True,
),
(
generator_192,
0x89009C0DC361C81E99280C8E91DF578DF88CDF4B0CDEDCED,
0x27BE44A529B7513E727251F128B34262A0FD4D8EC82377B9,
True,
),
(
generator_192,
0x6A223D00BD22C52833409A163E057E5B5DA1DEF2A197DD15,
0x7B482604199367F1F303F9EF627F922F97023E90EAE08ABF,
True,
),
(
generator_192,
0x6DCCBDE75C0948C98DAB32EA0BC59FE125CF0FB1A3798EDA,
0x0001171A3E0FA60CF3096F4E116B556198DE430E1FBD330C8835,
False,
),
(
generator_192,
0xD266B39E1F491FC4ACBBBC7D098430931CFA66D55015AF12,
0x193782EB909E391A3148B7764E6B234AA94E48D30A16DBB2,
False,
),
(
generator_192,
0x9D6DDBCD439BAA0C6B80A654091680E462A7D1D3F1FFEB43,
0x6AD8EFC4D133CCF167C44EB4691C80ABFFB9F82B932B8CAA,
False,
),
(
generator_192,
0x146479D944E6BDA87E5B35818AA666A4C998A71F4E95EDBC,
0xA86D6FE62BC8FBD88139693F842635F687F132255858E7F6,
False,
),
(
generator_192,
0xE594D4A598046F3598243F50FD2C7BD7D380EDB055802253,
0x509014C0C4D6B536E3CA750EC09066AF39B4C8616A53A923,
False,
),
]
@pytest.mark.parametrize("generator,x,y,expected", P192_POINTS)
def test_point_validity(generator, x, y, expected):
"""
`generator` defines the curve; is `(x, y)` a point on
this curve? `expected` is True if the right answer is Yes.
"""
assert point_is_valid(generator, x, y) == expected
# Trying signature-verification tests from ECDSAVS.pdf B.2.4:
CURVE_192_KATS = [
(
generator_192,
int(
"0x84ce72aa8699df436059f052ac51b6398d2511e49631bcb7e71f89c499b9ee"
"425dfbc13a5f6d408471b054f2655617cbbaf7937b7c80cd8865cf02c8487d30"
"d2b0fbd8b2c4e102e16d828374bbc47b93852f212d5043c3ea720f086178ff79"
"8cc4f63f787b9c2e419efa033e7644ea7936f54462dc21a6c4580725f7f0e7d1"
"58",
16,
),
0xD9DBFB332AA8E5FF091E8CE535857C37C73F6250FFB2E7AC,
0x282102E364FEDED3AD15DDF968F88D8321AA268DD483EBC4,
0x64DCA58A20787C488D11D6DD96313F1B766F2D8EFE122916,
0x1ECBA28141E84AB4ECAD92F56720E2CC83EB3D22DEC72479,
True,
),
(
generator_192,
int(
"0x94bb5bacd5f8ea765810024db87f4224ad71362a3c28284b2b9f39fab86db1"
"2e8beb94aae899768229be8fdb6c4f12f28912bb604703a79ccff769c1607f5a"
"91450f30ba0460d359d9126cbd6296be6d9c4bb96c0ee74cbb44197c207f6db3"
"26ab6f5a659113a9034e54be7b041ced9dcf6458d7fb9cbfb2744d999f7dfd63"
"f4",
16,
),
0x3E53EF8D3112AF3285C0E74842090712CD324832D4277AE7,
0xCC75F8952D30AEC2CBB719FC6AA9934590B5D0FF5A83ADB7,
0x8285261607283BA18F335026130BAB31840DCFD9C3E555AF,
0x356D89E1B04541AFC9704A45E9C535CE4A50929E33D7E06C,
True,
),
(
generator_192,
int(
"0xf6227a8eeb34afed1621dcc89a91d72ea212cb2f476839d9b4243c66877911"
"b37b4ad6f4448792a7bbba76c63bdd63414b6facab7dc71c3396a73bd7ee14cd"
"d41a659c61c99b779cecf07bc51ab391aa3252386242b9853ea7da67fd768d30"
"3f1b9b513d401565b6f1eb722dfdb96b519fe4f9bd5de67ae131e64b40e78c42"
"dd",
16,
),
0x16335DBE95F8E8254A4E04575D736BEFB258B8657F773CB7,
0x421B13379C59BC9DCE38A1099CA79BBD06D647C7F6242336,
0x4141BD5D64EA36C5B0BD21EF28C02DA216ED9D04522B1E91,
0x159A6AA852BCC579E821B7BB0994C0861FB08280C38DAA09,
False,
),
(
generator_192,
int(
"0x16b5f93afd0d02246f662761ed8e0dd9504681ed02a253006eb36736b56309"
"7ba39f81c8e1bce7a16c1339e345efabbc6baa3efb0612948ae51103382a8ee8"
"bc448e3ef71e9f6f7a9676694831d7f5dd0db5446f179bcb737d4a526367a447"
"bfe2c857521c7f40b6d7d7e01a180d92431fb0bbd29c04a0c420a57b3ed26ccd"
"8a",
16,
),
0xFD14CDF1607F5EFB7B1793037B15BDF4BAA6F7C16341AB0B,
0x83FA0795CC6C4795B9016DAC928FD6BAC32F3229A96312C4,
0x8DFDB832951E0167C5D762A473C0416C5C15BC1195667DC1,
0x1720288A2DC13FA1EC78F763F8FE2FF7354A7E6FDDE44520,
False,
),
(
generator_192,
int(
"0x08a2024b61b79d260e3bb43ef15659aec89e5b560199bc82cf7c65c77d3919"
"2e03b9a895d766655105edd9188242b91fbde4167f7862d4ddd61e5d4ab55196"
"683d4f13ceb90d87aea6e07eb50a874e33086c4a7cb0273a8e1c4408f4b846bc"
"eae1ebaac1b2b2ea851a9b09de322efe34cebe601653efd6ddc876ce8c2f2072"
"fb",
16,
),
0x674F941DC1A1F8B763C9334D726172D527B90CA324DB8828,
0x65ADFA32E8B236CB33A3E84CF59BFB9417AE7E8EDE57A7FF,
0x9508B9FDD7DAF0D8126F9E2BC5A35E4C6D800B5B804D7796,
0x36F2BF6B21B987C77B53BB801B3435A577E3D493744BFAB0,
False,
),
(
generator_192,
int(
"0x1843aba74b0789d4ac6b0b8923848023a644a7b70afa23b1191829bbe4397c"
"e15b629bf21a8838298653ed0c19222b95fa4f7390d1b4c844d96e645537e0aa"
"e98afb5c0ac3bd0e4c37f8daaff25556c64e98c319c52687c904c4de7240a1cc"
"55cd9756b7edaef184e6e23b385726e9ffcba8001b8f574987c1a3fedaaa83ca"
"6d",
16,
),
0x10ECCA1AAD7220B56A62008B35170BFD5E35885C4014A19F,
0x04EB61984C6C12ADE3BC47F3C629ECE7AA0A033B9948D686,
0x82BFA4E82C0DFE9274169B86694E76CE993FD83B5C60F325,
0xA97685676C59A65DBDE002FE9D613431FB183E8006D05633,
False,
),
(
generator_192,
int(
"0x5a478f4084ddd1a7fea038aa9732a822106385797d02311aeef4d0264f824f"
"698df7a48cfb6b578cf3da416bc0799425bb491be5b5ecc37995b85b03420a98"
"f2c4dc5c31a69a379e9e322fbe706bbcaf0f77175e05cbb4fa162e0da82010a2"
"78461e3e974d137bc746d1880d6eb02aa95216014b37480d84b87f717bb13f76"
"e1",
16,
),
0x6636653CB5B894CA65C448277B29DA3AD101C4C2300F7C04,
0xFDF1CBB3FC3FD6A4F890B59E554544175FA77DBDBEB656C1,
0xEAC2DDECDDFB79931A9C3D49C08DE0645C783A24CB365E1C,
0x3549FEE3CFA7E5F93BC47D92D8BA100E881A2A93C22F8D50,
False,
),
(
generator_192,
int(
"0xc598774259a058fa65212ac57eaa4f52240e629ef4c310722088292d1d4af6"
"c39b49ce06ba77e4247b20637174d0bd67c9723feb57b5ead232b47ea452d5d7"
"a089f17c00b8b6767e434a5e16c231ba0efa718a340bf41d67ea2d295812ff1b"
"9277daacb8bc27b50ea5e6443bcf95ef4e9f5468fe78485236313d53d1c68f6b"
"a2",
16,
),
0xA82BD718D01D354001148CD5F69B9EBF38FF6F21898F8AAA,
0xE67CEEDE07FC2EBFAFD62462A51E4B6C6B3D5B537B7CAF3E,
0x4D292486C620C3DE20856E57D3BB72FCDE4A73AD26376955,
0xA85289591A6081D5728825520E62FF1C64F94235C04C7F95,
False,
),
(
generator_192,
int(
"0xca98ed9db081a07b7557f24ced6c7b9891269a95d2026747add9e9eb80638a"
"961cf9c71a1b9f2c29744180bd4c3d3db60f2243c5c0b7cc8a8d40a3f9a7fc91"
"0250f2187136ee6413ffc67f1a25e1c4c204fa9635312252ac0e0481d89b6d53"
"808f0c496ba87631803f6c572c1f61fa049737fdacce4adff757afed4f05beb6"
"58",
16,
),
0x7D3B016B57758B160C4FCA73D48DF07AE3B6B30225126C2F,
0x4AF3790D9775742BDE46F8DA876711BE1B65244B2B39E7EC,
0x95F778F5F656511A5AB49A5D69DDD0929563C29CBC3A9E62,
0x75C87FC358C251B4C83D2DD979FAAD496B539F9F2EE7A289,
False,
),
(
generator_192,
int(
"0x31dd9a54c8338bea06b87eca813d555ad1850fac9742ef0bbe40dad400e102"
"88acc9c11ea7dac79eb16378ebea9490e09536099f1b993e2653cd50240014c9"
"0a9c987f64545abc6a536b9bd2435eb5e911fdfde2f13be96ea36ad38df4ae9e"
"a387b29cced599af777338af2794820c9cce43b51d2112380a35802ab7e396c9"
"7a",
16,
),
0x9362F28C4EF96453D8A2F849F21E881CD7566887DA8BEB4A,
0xE64D26D8D74C48A024AE85D982EE74CD16046F4EE5333905,
0xF3923476A296C88287E8DE914B0B324AD5A963319A4FE73B,
0xF0BAEED7624ED00D15244D8BA2AEDE085517DBDEC8AC65F5,
True,
),
(
generator_192,
int(
"0xb2b94e4432267c92f9fdb9dc6040c95ffa477652761290d3c7de312283f645"
"0d89cc4aabe748554dfb6056b2d8e99c7aeaad9cdddebdee9dbc099839562d90"
"64e68e7bb5f3a6bba0749ca9a538181fc785553a4000785d73cc207922f63e8c"
"e1112768cb1de7b673aed83a1e4a74592f1268d8e2a4e9e63d414b5d442bd045"
"6d",
16,
),
0xCC6FC032A846AAAC25533EB033522824F94E670FA997ECEF,
0xE25463EF77A029ECCDA8B294FD63DD694E38D223D30862F1,
0x066B1D07F3A40E679B620EDA7F550842A35C18B80C5EBE06,
0xA0B0FB201E8F2DF65E2C4508EF303BDC90D934016F16B2DC,
False,
),
(
generator_192,
int(
"0x4366fcadf10d30d086911de30143da6f579527036937007b337f7282460eae"
"5678b15cccda853193ea5fc4bc0a6b9d7a31128f27e1214988592827520b214e"
"ed5052f7775b750b0c6b15f145453ba3fee24a085d65287e10509eb5d5f602c4"
"40341376b95c24e5c4727d4b859bfe1483d20538acdd92c7997fa9c614f0f839"
"d7",
16,
),
0x955C908FE900A996F7E2089BEE2F6376830F76A19135E753,
0xBA0C42A91D3847DE4A592A46DC3FDAF45A7CC709B90DE520,
0x1F58AD77FC04C782815A1405B0925E72095D906CBF52A668,
0xF2E93758B3AF75EDF784F05A6761C9B9A6043C66B845B599,
False,
),
(
generator_192,
int(
"0x543f8af57d750e33aa8565e0cae92bfa7a1ff78833093421c2942cadf99866"
"70a5ff3244c02a8225e790fbf30ea84c74720abf99cfd10d02d34377c3d3b412"
"69bea763384f372bb786b5846f58932defa68023136cd571863b304886e95e52"
"e7877f445b9364b3f06f3c28da12707673fecb4b8071de06b6e0a3c87da160ce"
"f3",
16,
),
0x31F7FA05576D78A949B24812D4383107A9A45BB5FCCDD835,
0x8DC0EB65994A90F02B5E19BD18B32D61150746C09107E76B,
0xBE26D59E4E883DDE7C286614A767B31E49AD88789D3A78FF,
0x8762CA831C1CE42DF77893C9B03119428E7A9B819B619068,
False,
),
(
generator_192,
int(
"0xd2e8454143ce281e609a9d748014dcebb9d0bc53adb02443a6aac2ffe6cb009f"
"387c346ecb051791404f79e902ee333ad65e5c8cb38dc0d1d39a8dc90add502357"
"2720e5b94b190d43dd0d7873397504c0c7aef2727e628eb6a74411f2e400c65670"
"716cb4a815dc91cbbfeb7cfe8c929e93184c938af2c078584da045e8f8d1",
16,
),
0x66AA8EDBBDB5CF8E28CEB51B5BDA891CAE2DF84819FE25C0,
0x0C6BC2F69030A7CE58D4A00E3B3349844784A13B8936F8DA,
0xA4661E69B1734F4A71B788410A464B71E7FFE42334484F23,
0x738421CF5E049159D69C57A915143E226CAC8355E149AFE9,
False,
),
(
generator_192,
int(
"0x6660717144040f3e2f95a4e25b08a7079c702a8b29babad5a19a87654bc5c5af"
"a261512a11b998a4fb36b5d8fe8bd942792ff0324b108120de86d63f65855e5461"
"184fc96a0a8ffd2ce6d5dfb0230cbbdd98f8543e361b3205f5da3d500fdc8bac6d"
"b377d75ebef3cb8f4d1ff738071ad0938917889250b41dd1d98896ca06fb",
16,
),
0xBCFACF45139B6F5F690A4C35A5FFFA498794136A2353FC77,
0x6F4A6C906316A6AFC6D98FE1F0399D056F128FE0270B0F22,
0x9DB679A3DAFE48F7CCAD122933ACFE9DA0970B71C94C21C1,
0x984C2DB99827576C0A41A5DA41E07D8CC768BC82F18C9DA9,
False,
),
]
@pytest.mark.parametrize("gen,msg,qx,qy,r,s,expected", CURVE_192_KATS)
def test_signature_validity(gen, msg, qx, qy, r, s, expected):
"""
`msg` = message, `qx` and `qy` represent the base point on
elliptic curve of `gen`, `r` and `s` are the signature, and
`expected` is True iff the signature is expected to be valid."""
pubk = Public_key(gen, ellipticcurve.Point(gen.curve(), qx, qy))
assert expected == pubk.verifies(digest_integer(msg), Signature(r, s))
@pytest.mark.parametrize(
"gen,msg,qx,qy,r,s,expected", [x for x in CURVE_192_KATS if x[6]]
)
def test_pk_recovery(gen, msg, r, s, qx, qy, expected):
del expected
sign = Signature(r, s)
pks = sign.recover_public_keys(digest_integer(msg), gen)
assert pks
# Test if the signature is valid for all found public keys
for pk in pks:
q = pk.point
test_signature_validity(gen, msg, q.x(), q.y(), r, s, True)
# Test if the original public key is in the set of found keys
original_q = ellipticcurve.Point(gen.curve(), qx, qy)
points = [pk.point for pk in pks]
assert original_q in points
@st.composite
def st_random_gen_key_msg_nonce(draw):
"""Hypothesis strategy for test_sig_verify()."""
name_gen = {
"generator_192": generator_192,
"generator_224": generator_224,
"generator_256": generator_256,
"generator_secp256k1": generator_secp256k1,
"generator_384": generator_384,
"generator_521": generator_521,
}
name = draw(st.sampled_from(sorted(name_gen.keys())))
note("Generator used: {0}".format(name))
generator = name_gen[name]
order = int(generator.order())
key = draw(st.integers(min_value=1, max_value=order))
msg = draw(st.integers(min_value=1, max_value=order))
nonce = draw(
st.integers(min_value=1, max_value=order + 1)
| st.integers(min_value=order >> 1, max_value=order)
)
return generator, key, msg, nonce
SIG_VER_SETTINGS = dict(HYP_SETTINGS)
SIG_VER_SETTINGS["max_examples"] = 10
@settings(**SIG_VER_SETTINGS)
@example((generator_224, 4, 1, 1))
@given(st_random_gen_key_msg_nonce())
def test_sig_verify(args):
"""
Check if signing and verification works for arbitrary messages and
that signatures for other messages are rejected.
"""
generator, sec_mult, msg, nonce = args
pubkey = Public_key(generator, generator * sec_mult)
privkey = Private_key(pubkey, sec_mult)
signature = privkey.sign(msg, nonce)
assert pubkey.verifies(msg, signature)
assert not pubkey.verifies(msg - 1, signature)
def test_int_to_string_with_zero():
assert int_to_string(0) == b"\x00"

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,199 @@
import pytest
try:
import unittest2 as unittest
except ImportError:
import unittest
from hypothesis import given, settings
import hypothesis.strategies as st
try:
from hypothesis import HealthCheck
HC_PRESENT = True
except ImportError: # pragma: no cover
HC_PRESENT = False
from .numbertheory import inverse_mod
from .ellipticcurve import CurveFp, INFINITY, Point
HYP_SETTINGS = {}
if HC_PRESENT: # pragma: no branch
HYP_SETTINGS["suppress_health_check"] = [HealthCheck.too_slow]
HYP_SETTINGS["deadline"] = 5000
# NIST Curve P-192:
p = 6277101735386680763835789423207666416083908700390324961279
r = 6277101735386680763835789423176059013767194773182842284081
# s = 0x3045ae6fc8422f64ed579528d38120eae12196d5
# c = 0x3099d2bbbfcb2538542dcd5fb078b6ef5f3d6fe2c745de65
b = 0x64210519E59C80E70FA7E9AB72243049FEB8DEECC146B9B1
Gx = 0x188DA80EB03090F67CBF20EB43A18800F4FF0AFD82FF1012
Gy = 0x07192B95FFC8DA78631011ED6B24CDD573F977A11E794811
c192 = CurveFp(p, -3, b)
p192 = Point(c192, Gx, Gy, r)
c_23 = CurveFp(23, 1, 1)
g_23 = Point(c_23, 13, 7, 7)
HYP_SLOW_SETTINGS = dict(HYP_SETTINGS)
HYP_SLOW_SETTINGS["max_examples"] = 10
@settings(**HYP_SLOW_SETTINGS)
@given(st.integers(min_value=1, max_value=r + 1))
def test_p192_mult_tests(multiple):
inv_m = inverse_mod(multiple, r)
p1 = p192 * multiple
assert p1 * inv_m == p192
def add_n_times(point, n):
ret = INFINITY
i = 0
while i <= n:
yield ret
ret = ret + point
i += 1
# From X9.62 I.1 (p. 96):
@pytest.mark.parametrize(
"p, m, check",
[(g_23, n, exp) for n, exp in enumerate(add_n_times(g_23, 8))],
ids=["g_23 test with mult {0}".format(i) for i in range(9)],
)
def test_add_and_mult_equivalence(p, m, check):
assert p * m == check
class TestCurve(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.c_23 = CurveFp(23, 1, 1)
def test_equality_curves(self):
self.assertEqual(self.c_23, CurveFp(23, 1, 1))
def test_inequality_curves(self):
c192 = CurveFp(p, -3, b)
self.assertNotEqual(self.c_23, c192)
def test_usability_in_a_hashed_collection_curves(self):
{self.c_23: None}
def test_hashability_curves(self):
hash(self.c_23)
def test_conflation_curves(self):
ne1, ne2, ne3 = CurveFp(24, 1, 1), CurveFp(23, 2, 1), CurveFp(23, 1, 2)
eq1, eq2, eq3 = CurveFp(23, 1, 1), CurveFp(23, 1, 1), self.c_23
self.assertEqual(len(set((c_23, eq1, eq2, eq3))), 1)
self.assertEqual(len(set((c_23, ne1, ne2, ne3))), 4)
self.assertDictEqual({c_23: None}, {eq1: None})
self.assertIn(eq2, {eq3: None})
class TestPoint(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.c_23 = CurveFp(23, 1, 1)
cls.g_23 = Point(cls.c_23, 13, 7, 7)
p = 6277101735386680763835789423207666416083908700390324961279
r = 6277101735386680763835789423176059013767194773182842284081
# s = 0x3045ae6fc8422f64ed579528d38120eae12196d5
# c = 0x3099d2bbbfcb2538542dcd5fb078b6ef5f3d6fe2c745de65
b = 0x64210519E59C80E70FA7E9AB72243049FEB8DEECC146B9B1
Gx = 0x188DA80EB03090F67CBF20EB43A18800F4FF0AFD82FF1012
Gy = 0x07192B95FFC8DA78631011ED6B24CDD573F977A11E794811
cls.c192 = CurveFp(p, -3, b)
cls.p192 = Point(cls.c192, Gx, Gy, r)
def test_p192(self):
# Checking against some sample computations presented
# in X9.62:
d = 651056770906015076056810763456358567190100156695615665659
Q = d * self.p192
self.assertEqual(
Q.x(), 0x62B12D60690CDCF330BABAB6E69763B471F994DD702D16A5
)
k = 6140507067065001063065065565667405560006161556565665656654
R = k * self.p192
self.assertEqual(
R.x(), 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD
)
self.assertEqual(
R.y(), 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835
)
u1 = 2563697409189434185194736134579731015366492496392189760599
u2 = 6266643813348617967186477710235785849136406323338782220568
temp = u1 * self.p192 + u2 * Q
self.assertEqual(
temp.x(), 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD
)
self.assertEqual(
temp.y(), 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835
)
def test_double_infinity(self):
p1 = INFINITY
p3 = p1.double()
self.assertEqual(p1, p3)
self.assertEqual(p3.x(), p1.x())
self.assertEqual(p3.y(), p3.y())
def test_double(self):
x1, y1, x3, y3 = (3, 10, 7, 12)
p1 = Point(self.c_23, x1, y1)
p3 = p1.double()
self.assertEqual(p3.x(), x3)
self.assertEqual(p3.y(), y3)
def test_multiply(self):
x1, y1, m, x3, y3 = (3, 10, 2, 7, 12)
p1 = Point(self.c_23, x1, y1)
p3 = p1 * m
self.assertEqual(p3.x(), x3)
self.assertEqual(p3.y(), y3)
# Trivial tests from X9.62 B.3:
def test_add(self):
"""We expect that on curve c, (x1,y1) + (x2, y2 ) = (x3, y3)."""
x1, y1, x2, y2, x3, y3 = (3, 10, 9, 7, 17, 20)
p1 = Point(self.c_23, x1, y1)
p2 = Point(self.c_23, x2, y2)
p3 = p1 + p2
self.assertEqual(p3.x(), x3)
self.assertEqual(p3.y(), y3)
def test_add_as_double(self):
"""We expect that on curve c, (x1,y1) + (x2, y2 ) = (x3, y3)."""
x1, y1, x2, y2, x3, y3 = (3, 10, 3, 10, 7, 12)
p1 = Point(self.c_23, x1, y1)
p2 = Point(self.c_23, x2, y2)
p3 = p1 + p2
self.assertEqual(p3.x(), x3)
self.assertEqual(p3.y(), y3)
def test_equality_points(self):
self.assertEqual(self.g_23, Point(self.c_23, 13, 7, 7))
def test_inequality_points(self):
c = CurveFp(100, -3, 100)
p = Point(c, 100, 100, 100)
self.assertNotEqual(self.g_23, p)
def test_inequality_points_diff_types(self):
c = CurveFp(100, -3, 100)
self.assertNotEqual(self.g_23, c)

View file

@ -0,0 +1,657 @@
import pickle
try:
import unittest2 as unittest
except ImportError:
import unittest
import os
import sys
import signal
import pytest
import threading
import platform
import hypothesis.strategies as st
from hypothesis import given, assume, settings, example
from .ellipticcurve import CurveFp, PointJacobi, INFINITY
from .ecdsa import (
generator_256,
curve_256,
generator_224,
generator_brainpoolp160r1,
curve_brainpoolp160r1,
generator_112r2,
)
from .numbertheory import inverse_mod
from .util import randrange
NO_OLD_SETTINGS = {}
if sys.version_info > (2, 7): # pragma: no branch
NO_OLD_SETTINGS["deadline"] = 5000
class TestJacobi(unittest.TestCase):
def test___init__(self):
curve = object()
x = 2
y = 3
z = 1
order = 4
pj = PointJacobi(curve, x, y, z, order)
self.assertEqual(pj.order(), order)
self.assertIs(pj.curve(), curve)
self.assertEqual(pj.x(), x)
self.assertEqual(pj.y(), y)
def test_add_with_different_curves(self):
p_a = PointJacobi.from_affine(generator_256)
p_b = PointJacobi.from_affine(generator_224)
with self.assertRaises(ValueError):
p_a + p_b
def test_compare_different_curves(self):
self.assertNotEqual(generator_256, generator_224)
def test_equality_with_non_point(self):
pj = PointJacobi.from_affine(generator_256)
self.assertNotEqual(pj, "value")
def test_conversion(self):
pj = PointJacobi.from_affine(generator_256)
pw = pj.to_affine()
self.assertEqual(generator_256, pw)
def test_single_double(self):
pj = PointJacobi.from_affine(generator_256)
pw = generator_256.double()
pj = pj.double()
self.assertEqual(pj.x(), pw.x())
self.assertEqual(pj.y(), pw.y())
def test_double_with_zero_point(self):
pj = PointJacobi(curve_256, 0, 0, 1)
pj = pj.double()
self.assertIs(pj, INFINITY)
def test_double_with_zero_equivalent_point(self):
pj = PointJacobi(curve_256, 0, curve_256.p(), 1)
pj = pj.double()
self.assertIs(pj, INFINITY)
def test_double_with_zero_equivalent_point_non_1_z(self):
pj = PointJacobi(curve_256, 0, curve_256.p(), 2)
pj = pj.double()
self.assertIs(pj, INFINITY)
def test_compare_with_affine_point(self):
pj = PointJacobi.from_affine(generator_256)
pa = pj.to_affine()
self.assertEqual(pj, pa)
self.assertEqual(pa, pj)
def test_to_affine_with_zero_point(self):
pj = PointJacobi(curve_256, 0, 0, 1)
pa = pj.to_affine()
self.assertIs(pa, INFINITY)
def test_add_with_affine_point(self):
pj = PointJacobi.from_affine(generator_256)
pa = pj.to_affine()
s = pj + pa
self.assertEqual(s, pj.double())
def test_radd_with_affine_point(self):
pj = PointJacobi.from_affine(generator_256)
pa = pj.to_affine()
s = pa + pj
self.assertEqual(s, pj.double())
def test_add_with_infinity(self):
pj = PointJacobi.from_affine(generator_256)
s = pj + INFINITY
self.assertEqual(s, pj)
def test_add_zero_point_to_affine(self):
pa = PointJacobi.from_affine(generator_256).to_affine()
pj = PointJacobi(curve_256, 0, 0, 1)
s = pj + pa
self.assertIs(s, pa)
def test_multiply_by_zero(self):
pj = PointJacobi.from_affine(generator_256)
pj = pj * 0
self.assertIs(pj, INFINITY)
def test_zero_point_multiply_by_one(self):
pj = PointJacobi(curve_256, 0, 0, 1)
pj = pj * 1
self.assertIs(pj, INFINITY)
def test_multiply_by_one(self):
pj = PointJacobi.from_affine(generator_256)
pw = generator_256 * 1
pj = pj * 1
self.assertEqual(pj.x(), pw.x())
self.assertEqual(pj.y(), pw.y())
def test_multiply_by_two(self):
pj = PointJacobi.from_affine(generator_256)
pw = generator_256 * 2
pj = pj * 2
self.assertEqual(pj.x(), pw.x())
self.assertEqual(pj.y(), pw.y())
def test_rmul_by_two(self):
pj = PointJacobi.from_affine(generator_256)
pw = generator_256 * 2
pj = 2 * pj
self.assertEqual(pj, pw)
def test_compare_non_zero_with_infinity(self):
pj = PointJacobi.from_affine(generator_256)
self.assertNotEqual(pj, INFINITY)
def test_compare_zero_point_with_infinity(self):
pj = PointJacobi(curve_256, 0, 0, 1)
self.assertEqual(pj, INFINITY)
def test_compare_double_with_multiply(self):
pj = PointJacobi.from_affine(generator_256)
dbl = pj.double()
mlpl = pj * 2
self.assertEqual(dbl, mlpl)
@settings(max_examples=10)
@given(
st.integers(
min_value=0, max_value=int(generator_brainpoolp160r1.order())
)
)
def test_multiplications(self, mul):
pj = PointJacobi.from_affine(generator_brainpoolp160r1)
pw = pj.to_affine() * mul
pj = pj * mul
self.assertEqual((pj.x(), pj.y()), (pw.x(), pw.y()))
self.assertEqual(pj, pw)
@settings(max_examples=10)
@given(
st.integers(
min_value=0, max_value=int(generator_brainpoolp160r1.order())
)
)
@example(0)
@example(int(generator_brainpoolp160r1.order()))
def test_precompute(self, mul):
precomp = generator_brainpoolp160r1
self.assertTrue(precomp._PointJacobi__precompute)
pj = PointJacobi.from_affine(generator_brainpoolp160r1)
a = precomp * mul
b = pj * mul
self.assertEqual(a, b)
@settings(max_examples=10)
@given(
st.integers(
min_value=1, max_value=int(generator_brainpoolp160r1.order())
),
st.integers(
min_value=1, max_value=int(generator_brainpoolp160r1.order())
),
)
@example(3, 3)
def test_add_scaled_points(self, a_mul, b_mul):
j_g = PointJacobi.from_affine(generator_brainpoolp160r1)
a = PointJacobi.from_affine(j_g * a_mul)
b = PointJacobi.from_affine(j_g * b_mul)
c = a + b
self.assertEqual(c, j_g * (a_mul + b_mul))
@settings(max_examples=10)
@given(
st.integers(
min_value=1, max_value=int(generator_brainpoolp160r1.order())
),
st.integers(
min_value=1, max_value=int(generator_brainpoolp160r1.order())
),
st.integers(min_value=1, max_value=int(curve_brainpoolp160r1.p() - 1)),
)
def test_add_one_scaled_point(self, a_mul, b_mul, new_z):
j_g = PointJacobi.from_affine(generator_brainpoolp160r1)
a = PointJacobi.from_affine(j_g * a_mul)
b = PointJacobi.from_affine(j_g * b_mul)
p = curve_brainpoolp160r1.p()
assume(inverse_mod(new_z, p))
new_zz = new_z * new_z % p
b = PointJacobi(
curve_brainpoolp160r1,
b.x() * new_zz % p,
b.y() * new_zz * new_z % p,
new_z,
)
c = a + b
self.assertEqual(c, j_g * (a_mul + b_mul))
@settings(max_examples=10)
@given(
st.integers(
min_value=1, max_value=int(generator_brainpoolp160r1.order())
),
st.integers(
min_value=1, max_value=int(generator_brainpoolp160r1.order())
),
st.integers(min_value=1, max_value=int(curve_brainpoolp160r1.p() - 1)),
)
@example(1, 1, 1)
@example(3, 3, 3)
@example(2, int(generator_brainpoolp160r1.order() - 2), 1)
@example(2, int(generator_brainpoolp160r1.order() - 2), 3)
def test_add_same_scale_points(self, a_mul, b_mul, new_z):
j_g = PointJacobi.from_affine(generator_brainpoolp160r1)
a = PointJacobi.from_affine(j_g * a_mul)
b = PointJacobi.from_affine(j_g * b_mul)
p = curve_brainpoolp160r1.p()
assume(inverse_mod(new_z, p))
new_zz = new_z * new_z % p
a = PointJacobi(
curve_brainpoolp160r1,
a.x() * new_zz % p,
a.y() * new_zz * new_z % p,
new_z,
)
b = PointJacobi(
curve_brainpoolp160r1,
b.x() * new_zz % p,
b.y() * new_zz * new_z % p,
new_z,
)
c = a + b
self.assertEqual(c, j_g * (a_mul + b_mul))
def test_add_same_scale_points_static(self):
j_g = generator_brainpoolp160r1
p = curve_brainpoolp160r1.p()
a = j_g * 11
a.scale()
z1 = 13
x = PointJacobi(
curve_brainpoolp160r1,
a.x() * z1**2 % p,
a.y() * z1**3 % p,
z1,
)
y = PointJacobi(
curve_brainpoolp160r1,
a.x() * z1**2 % p,
a.y() * z1**3 % p,
z1,
)
c = a + a
self.assertEqual(c, x + y)
@settings(max_examples=14)
@given(
st.integers(
min_value=1, max_value=int(generator_brainpoolp160r1.order())
),
st.integers(
min_value=1, max_value=int(generator_brainpoolp160r1.order())
),
st.lists(
st.integers(
min_value=1, max_value=int(curve_brainpoolp160r1.p() - 1)
),
min_size=2,
max_size=2,
unique=True,
),
)
@example(2, 2, [2, 1])
@example(2, 2, [2, 3])
@example(2, int(generator_brainpoolp160r1.order() - 2), [2, 3])
@example(2, int(generator_brainpoolp160r1.order() - 2), [2, 1])
def test_add_different_scale_points(self, a_mul, b_mul, new_z):
j_g = PointJacobi.from_affine(generator_brainpoolp160r1)
a = PointJacobi.from_affine(j_g * a_mul)
b = PointJacobi.from_affine(j_g * b_mul)
p = curve_brainpoolp160r1.p()
assume(inverse_mod(new_z[0], p))
assume(inverse_mod(new_z[1], p))
new_zz0 = new_z[0] * new_z[0] % p
new_zz1 = new_z[1] * new_z[1] % p
a = PointJacobi(
curve_brainpoolp160r1,
a.x() * new_zz0 % p,
a.y() * new_zz0 * new_z[0] % p,
new_z[0],
)
b = PointJacobi(
curve_brainpoolp160r1,
b.x() * new_zz1 % p,
b.y() * new_zz1 * new_z[1] % p,
new_z[1],
)
c = a + b
self.assertEqual(c, j_g * (a_mul + b_mul))
def test_add_different_scale_points_static(self):
j_g = generator_brainpoolp160r1
p = curve_brainpoolp160r1.p()
a = j_g * 11
a.scale()
z1 = 13
x = PointJacobi(
curve_brainpoolp160r1,
a.x() * z1**2 % p,
a.y() * z1**3 % p,
z1,
)
z2 = 29
y = PointJacobi(
curve_brainpoolp160r1,
a.x() * z2**2 % p,
a.y() * z2**3 % p,
z2,
)
c = a + a
self.assertEqual(c, x + y)
def test_add_point_3_times(self):
j_g = PointJacobi.from_affine(generator_256)
self.assertEqual(j_g * 3, j_g + j_g + j_g)
def test_mul_without_order(self):
j_g = PointJacobi(curve_256, generator_256.x(), generator_256.y(), 1)
self.assertEqual(j_g * generator_256.order(), INFINITY)
def test_mul_add_inf(self):
j_g = PointJacobi.from_affine(generator_256)
self.assertEqual(j_g, j_g.mul_add(1, INFINITY, 1))
def test_mul_add_same(self):
j_g = PointJacobi.from_affine(generator_256)
self.assertEqual(j_g * 2, j_g.mul_add(1, j_g, 1))
def test_mul_add_precompute(self):
j_g = PointJacobi.from_affine(generator_brainpoolp160r1, True)
b = PointJacobi.from_affine(j_g * 255, True)
self.assertEqual(j_g * 256, j_g + b)
self.assertEqual(j_g * (5 + 255 * 7), j_g * 5 + b * 7)
self.assertEqual(j_g * (5 + 255 * 7), j_g.mul_add(5, b, 7))
def test_mul_add_precompute_large(self):
j_g = PointJacobi.from_affine(generator_brainpoolp160r1, True)
b = PointJacobi.from_affine(j_g * 255, True)
self.assertEqual(j_g * 256, j_g + b)
self.assertEqual(
j_g * (0xFF00 + 255 * 0xF0F0), j_g * 0xFF00 + b * 0xF0F0
)
self.assertEqual(
j_g * (0xFF00 + 255 * 0xF0F0), j_g.mul_add(0xFF00, b, 0xF0F0)
)
def test_mul_add_to_mul(self):
j_g = PointJacobi.from_affine(generator_256)
a = j_g * 3
b = j_g.mul_add(2, j_g, 1)
self.assertEqual(a, b)
def test_mul_add_differnt(self):
j_g = PointJacobi.from_affine(generator_256)
w_a = j_g * 2
self.assertEqual(j_g.mul_add(1, w_a, 1), j_g * 3)
def test_mul_add_slightly_different(self):
j_g = PointJacobi.from_affine(generator_256)
w_a = j_g * 2
w_b = j_g * 3
self.assertEqual(w_a.mul_add(1, w_b, 3), w_a * 1 + w_b * 3)
def test_mul_add(self):
j_g = PointJacobi.from_affine(generator_256)
w_a = generator_256 * 255
w_b = generator_256 * (0xA8 * 0xF0)
j_b = j_g * 0xA8
ret = j_g.mul_add(255, j_b, 0xF0)
self.assertEqual(ret.to_affine(), w_a + w_b)
def test_mul_add_large(self):
j_g = PointJacobi.from_affine(generator_256)
b = PointJacobi.from_affine(j_g * 255)
self.assertEqual(j_g * 256, j_g + b)
self.assertEqual(
j_g * (0xFF00 + 255 * 0xF0F0), j_g * 0xFF00 + b * 0xF0F0
)
self.assertEqual(
j_g * (0xFF00 + 255 * 0xF0F0), j_g.mul_add(0xFF00, b, 0xF0F0)
)
def test_mul_add_with_infinity_as_result(self):
j_g = PointJacobi.from_affine(generator_256)
order = generator_256.order()
b = PointJacobi.from_affine(generator_256 * 256)
self.assertEqual(j_g.mul_add(order % 256, b, order // 256), INFINITY)
def test_mul_add_without_order(self):
j_g = PointJacobi(curve_256, generator_256.x(), generator_256.y(), 1)
order = generator_256.order()
w_b = generator_256 * 34
w_b.scale()
b = PointJacobi(curve_256, w_b.x(), w_b.y(), 1)
self.assertEqual(j_g.mul_add(order % 34, b, order // 34), INFINITY)
def test_mul_add_with_doubled_negation_of_itself(self):
j_g = PointJacobi.from_affine(generator_256 * 17)
dbl_neg = 2 * (-j_g)
self.assertEqual(j_g.mul_add(4, dbl_neg, 2), INFINITY)
def test_equality(self):
pj1 = PointJacobi(curve=CurveFp(23, 1, 1, 1), x=2, y=3, z=1, order=1)
pj2 = PointJacobi(curve=CurveFp(23, 1, 1, 1), x=2, y=3, z=1, order=1)
self.assertEqual(pj1, pj2)
def test_equality_with_invalid_object(self):
j_g = PointJacobi.from_affine(generator_256)
self.assertNotEqual(j_g, 12)
def test_equality_with_wrong_curves(self):
p_a = PointJacobi.from_affine(generator_256)
p_b = PointJacobi.from_affine(generator_224)
self.assertNotEqual(p_a, p_b)
def test_pickle(self):
pj = PointJacobi(curve=CurveFp(23, 1, 1, 1), x=2, y=3, z=1, order=1)
self.assertEqual(pickle.loads(pickle.dumps(pj)), pj)
@settings(**NO_OLD_SETTINGS)
@given(st.integers(min_value=1, max_value=10))
def test_multithreading(self, thread_num):
# ensure that generator's precomputation table is filled
generator_112r2 * 2
# create a fresh point that doesn't have a filled precomputation table
gen = generator_112r2
gen = PointJacobi(gen.curve(), gen.x(), gen.y(), 1, gen.order(), True)
self.assertEqual(gen._PointJacobi__precompute, [])
def runner(generator):
order = generator.order()
for _ in range(10):
generator * randrange(order)
threads = []
for _ in range(thread_num):
threads.append(threading.Thread(target=runner, args=(gen,)))
for t in threads:
t.start()
runner(gen)
for t in threads:
t.join()
self.assertEqual(
gen._PointJacobi__precompute,
generator_112r2._PointJacobi__precompute,
)
@pytest.mark.skipif(
platform.system() == "Windows",
reason="there are no signals on Windows",
)
def test_multithreading_with_interrupts(self):
thread_num = 10
# ensure that generator's precomputation table is filled
generator_112r2 * 2
# create a fresh point that doesn't have a filled precomputation table
gen = generator_112r2
gen = PointJacobi(gen.curve(), gen.x(), gen.y(), 1, gen.order(), True)
self.assertEqual(gen._PointJacobi__precompute, [])
def runner(generator):
order = generator.order()
for _ in range(50):
generator * randrange(order)
def interrupter(barrier_start, barrier_end, lock_exit):
# wait until MainThread can handle KeyboardInterrupt
barrier_start.release()
barrier_end.acquire()
os.kill(os.getpid(), signal.SIGINT)
lock_exit.release()
threads = []
for _ in range(thread_num):
threads.append(threading.Thread(target=runner, args=(gen,)))
barrier_start = threading.Lock()
barrier_start.acquire()
barrier_end = threading.Lock()
barrier_end.acquire()
lock_exit = threading.Lock()
lock_exit.acquire()
threads.append(
threading.Thread(
target=interrupter,
args=(barrier_start, barrier_end, lock_exit),
)
)
for t in threads:
t.start()
with self.assertRaises(KeyboardInterrupt):
# signal to interrupter that we can now handle the signal
barrier_start.acquire()
barrier_end.release()
runner(gen)
# use the lock to ensure we never go past the scope of
# assertRaises before the os.kill is called
lock_exit.acquire()
for t in threads:
t.join()
self.assertEqual(
gen._PointJacobi__precompute,
generator_112r2._PointJacobi__precompute,
)

View file

@ -0,0 +1,959 @@
try:
import unittest2 as unittest
except ImportError:
import unittest
try:
buffer
except NameError:
buffer = memoryview
import os
import array
import pytest
import hashlib
from .keys import VerifyingKey, SigningKey, MalformedPointError
from .der import (
unpem,
UnexpectedDER,
encode_sequence,
encode_oid,
encode_bitstring,
)
from .util import (
sigencode_string,
sigencode_der,
sigencode_strings,
sigdecode_string,
sigdecode_der,
sigdecode_strings,
)
from .curves import NIST256p, Curve, BRAINPOOLP160r1, Ed25519, Ed448
from .ellipticcurve import Point, PointJacobi, CurveFp, INFINITY
from .ecdsa import generator_brainpoolp160r1
class TestVerifyingKeyFromString(unittest.TestCase):
"""
Verify that ecdsa.keys.VerifyingKey.from_string() can be used with
bytes-like objects
"""
@classmethod
def setUpClass(cls):
cls.key_bytes = (
b"\x04L\xa2\x95\xdb\xc7Z\xd7\x1f\x93\nz\xcf\x97\xcf"
b"\xd7\xc2\xd9o\xfe8}X!\xae\xd4\xfah\xfa^\rpI\xba\xd1"
b"Y\xfb\x92xa\xebo+\x9cG\xfav\xca"
)
cls.vk = VerifyingKey.from_string(cls.key_bytes)
def test_bytes(self):
self.assertIsNotNone(self.vk)
self.assertIsInstance(self.vk, VerifyingKey)
self.assertEqual(
self.vk.pubkey.point.x(),
105419898848891948935835657980914000059957975659675736097,
)
self.assertEqual(
self.vk.pubkey.point.y(),
4286866841217412202667522375431381222214611213481632495306,
)
def test_bytes_memoryview(self):
vk = VerifyingKey.from_string(buffer(self.key_bytes))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_bytearray(self):
vk = VerifyingKey.from_string(bytearray(self.key_bytes))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_bytesarray_memoryview(self):
vk = VerifyingKey.from_string(buffer(bytearray(self.key_bytes)))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_array_array_of_bytes(self):
arr = array.array("B", self.key_bytes)
vk = VerifyingKey.from_string(arr)
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_array_array_of_bytes_memoryview(self):
arr = array.array("B", self.key_bytes)
vk = VerifyingKey.from_string(buffer(arr))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_array_array_of_ints(self):
arr = array.array("I", self.key_bytes)
vk = VerifyingKey.from_string(arr)
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_array_array_of_ints_memoryview(self):
arr = array.array("I", self.key_bytes)
vk = VerifyingKey.from_string(buffer(arr))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_bytes_uncompressed(self):
vk = VerifyingKey.from_string(b"\x04" + self.key_bytes)
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_bytearray_uncompressed(self):
vk = VerifyingKey.from_string(bytearray(b"\x04" + self.key_bytes))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_bytes_compressed(self):
vk = VerifyingKey.from_string(b"\x02" + self.key_bytes[:24])
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_bytearray_compressed(self):
vk = VerifyingKey.from_string(bytearray(b"\x02" + self.key_bytes[:24]))
self.assertEqual(self.vk.to_string(), vk.to_string())
class TestVerifyingKeyFromDer(unittest.TestCase):
"""
Verify that ecdsa.keys.VerifyingKey.from_der() can be used with
bytes-like objects.
"""
@classmethod
def setUpClass(cls):
prv_key_str = (
"-----BEGIN EC PRIVATE KEY-----\n"
"MF8CAQEEGF7IQgvW75JSqULpiQQ8op9WH6Uldw6xxaAKBggqhkjOPQMBAaE0AzIA\n"
"BLiBd9CE7xf15FY5QIAoNg+fWbSk1yZOYtoGUdzkejWkxbRc9RWTQjqLVXucIJnz\n"
"bA==\n"
"-----END EC PRIVATE KEY-----\n"
)
key_str = (
"-----BEGIN PUBLIC KEY-----\n"
"MEkwEwYHKoZIzj0CAQYIKoZIzj0DAQEDMgAEuIF30ITvF/XkVjlAgCg2D59ZtKTX\n"
"Jk5i2gZR3OR6NaTFtFz1FZNCOotVe5wgmfNs\n"
"-----END PUBLIC KEY-----\n"
)
cls.key_pem = key_str
cls.key_bytes = unpem(key_str)
assert isinstance(cls.key_bytes, bytes)
cls.vk = VerifyingKey.from_pem(key_str)
cls.sk = SigningKey.from_pem(prv_key_str)
key_str = (
"-----BEGIN PUBLIC KEY-----\n"
"MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4H3iRbG4TSrsSRb/gusPQB/4YcN8\n"
"Poqzgjau4kfxBPyZimeRfuY/9g/wMmPuhGl4BUve51DsnKJFRr8psk0ieA==\n"
"-----END PUBLIC KEY-----\n"
)
cls.vk2 = VerifyingKey.from_pem(key_str)
cls.sk2 = SigningKey.generate(vk.curve)
def test_load_key_with_explicit_parameters(self):
pub_key_str = (
"-----BEGIN PUBLIC KEY-----\n"
"MIIBSzCCAQMGByqGSM49AgEwgfcCAQEwLAYHKoZIzj0BAQIhAP////8AAAABAAAA\n"
"AAAAAAAAAAAA////////////////MFsEIP////8AAAABAAAAAAAAAAAAAAAA////\n"
"///////////8BCBaxjXYqjqT57PrvVV2mIa8ZR0GsMxTsPY7zjw+J9JgSwMVAMSd\n"
"NgiG5wSTamZ44ROdJreBn36QBEEEaxfR8uEsQkf4vOblY6RA8ncDfYEt6zOg9KE5\n"
"RdiYwpZP40Li/hp/m47n60p8D54WK84zV2sxXs7LtkBoN79R9QIhAP////8AAAAA\n"
"//////////+85vqtpxeehPO5ysL8YyVRAgEBA0IABIr1UkgYs5jmbFc7it1/YI2X\n"
"T//IlaEjMNZft1owjqpBYH2ErJHk4U5Pp4WvWq1xmHwIZlsH7Ig4KmefCfR6SmU=\n"
"-----END PUBLIC KEY-----"
)
pk = VerifyingKey.from_pem(pub_key_str)
pk_exp = VerifyingKey.from_string(
b"\x04\x8a\xf5\x52\x48\x18\xb3\x98\xe6\x6c\x57\x3b\x8a\xdd\x7f"
b"\x60\x8d\x97\x4f\xff\xc8\x95\xa1\x23\x30\xd6\x5f\xb7\x5a\x30"
b"\x8e\xaa\x41\x60\x7d\x84\xac\x91\xe4\xe1\x4e\x4f\xa7\x85\xaf"
b"\x5a\xad\x71\x98\x7c\x08\x66\x5b\x07\xec\x88\x38\x2a\x67\x9f"
b"\x09\xf4\x7a\x4a\x65",
curve=NIST256p,
)
self.assertEqual(pk, pk_exp)
def test_load_key_with_explicit_with_explicit_disabled(self):
pub_key_str = (
"-----BEGIN PUBLIC KEY-----\n"
"MIIBSzCCAQMGByqGSM49AgEwgfcCAQEwLAYHKoZIzj0BAQIhAP////8AAAABAAAA\n"
"AAAAAAAAAAAA////////////////MFsEIP////8AAAABAAAAAAAAAAAAAAAA////\n"
"///////////8BCBaxjXYqjqT57PrvVV2mIa8ZR0GsMxTsPY7zjw+J9JgSwMVAMSd\n"
"NgiG5wSTamZ44ROdJreBn36QBEEEaxfR8uEsQkf4vOblY6RA8ncDfYEt6zOg9KE5\n"
"RdiYwpZP40Li/hp/m47n60p8D54WK84zV2sxXs7LtkBoN79R9QIhAP////8AAAAA\n"
"//////////+85vqtpxeehPO5ysL8YyVRAgEBA0IABIr1UkgYs5jmbFc7it1/YI2X\n"
"T//IlaEjMNZft1owjqpBYH2ErJHk4U5Pp4WvWq1xmHwIZlsH7Ig4KmefCfR6SmU=\n"
"-----END PUBLIC KEY-----"
)
with self.assertRaises(UnexpectedDER):
VerifyingKey.from_pem(
pub_key_str, valid_curve_encodings=["named_curve"]
)
def test_load_key_with_disabled_format(self):
with self.assertRaises(MalformedPointError) as e:
VerifyingKey.from_der(self.key_bytes, valid_encodings=["raw"])
self.assertIn("enabled (raw) encodings", str(e.exception))
def test_custom_hashfunc(self):
vk = VerifyingKey.from_der(self.key_bytes, hashlib.sha256)
self.assertIs(vk.default_hashfunc, hashlib.sha256)
def test_from_pem_with_custom_hashfunc(self):
vk = VerifyingKey.from_pem(self.key_pem, hashlib.sha256)
self.assertIs(vk.default_hashfunc, hashlib.sha256)
def test_bytes(self):
vk = VerifyingKey.from_der(self.key_bytes)
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_bytes_memoryview(self):
vk = VerifyingKey.from_der(buffer(self.key_bytes))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_bytearray(self):
vk = VerifyingKey.from_der(bytearray(self.key_bytes))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_bytesarray_memoryview(self):
vk = VerifyingKey.from_der(buffer(bytearray(self.key_bytes)))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_array_array_of_bytes(self):
arr = array.array("B", self.key_bytes)
vk = VerifyingKey.from_der(arr)
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_array_array_of_bytes_memoryview(self):
arr = array.array("B", self.key_bytes)
vk = VerifyingKey.from_der(buffer(arr))
self.assertEqual(self.vk.to_string(), vk.to_string())
def test_equality_on_verifying_keys(self):
self.assertEqual(self.vk, self.sk.get_verifying_key())
def test_inequality_on_verifying_keys(self):
self.assertNotEqual(self.vk, self.vk2)
def test_inequality_on_verifying_keys_not_implemented(self):
self.assertNotEqual(self.vk, None)
def test_VerifyingKey_inequality_on_same_curve(self):
self.assertNotEqual(self.vk, self.sk2.verifying_key)
def test_SigningKey_inequality_on_same_curve(self):
self.assertNotEqual(self.sk, self.sk2)
def test_inequality_on_wrong_types(self):
self.assertNotEqual(self.vk, self.sk)
def test_from_public_point_old(self):
pj = self.vk.pubkey.point
point = Point(pj.curve(), pj.x(), pj.y())
vk = VerifyingKey.from_public_point(point, self.vk.curve)
self.assertEqual(vk, self.vk)
def test_ed25519_VerifyingKey_repr__(self):
sk = SigningKey.from_string(Ed25519.generator.to_bytes(), Ed25519)
string = repr(sk.verifying_key)
self.assertEqual(
"VerifyingKey.from_string("
"bytearray(b'K\\x0c\\xfbZH\\x8e\\x8c\\x8c\\x07\\xee\\xda\\xfb"
"\\xe1\\x97\\xcd\\x90\\x18\\x02\\x15h]\\xfe\\xbe\\xcbB\\xba\\xe6r"
"\\x10\\xae\\xf1P'), Ed25519, None)",
string,
)
def test_edwards_from_public_point(self):
point = Ed25519.generator
with self.assertRaises(ValueError) as e:
VerifyingKey.from_public_point(point, Ed25519)
self.assertIn("incompatible with Edwards", str(e.exception))
def test_edwards_precompute_no_side_effect(self):
sk = SigningKey.from_string(Ed25519.generator.to_bytes(), Ed25519)
vk = sk.verifying_key
vk2 = VerifyingKey.from_string(vk.to_string(), Ed25519)
vk.precompute()
self.assertEqual(vk, vk2)
def test_parse_malfomed_eddsa_der_pubkey(self):
der_str = encode_sequence(
encode_sequence(encode_oid(*Ed25519.oid)),
encode_bitstring(bytes(Ed25519.generator.to_bytes()), 0),
encode_bitstring(b"\x00", 0),
)
with self.assertRaises(UnexpectedDER) as e:
VerifyingKey.from_der(der_str)
self.assertIn("trailing junk after public key", str(e.exception))
def test_edwards_from_public_key_recovery(self):
with self.assertRaises(ValueError) as e:
VerifyingKey.from_public_key_recovery(b"", b"", Ed25519)
self.assertIn("unsupported for Edwards", str(e.exception))
def test_edwards_from_public_key_recovery_with_digest(self):
with self.assertRaises(ValueError) as e:
VerifyingKey.from_public_key_recovery_with_digest(
b"", b"", Ed25519
)
self.assertIn("unsupported for Edwards", str(e.exception))
def test_load_ed25519_from_pem(self):
vk_pem = (
"-----BEGIN PUBLIC KEY-----\n"
"MCowBQYDK2VwAyEAIwBQ0NZkIiiO41WJfm5BV42u3kQm7lYnvIXmCy8qy2U=\n"
"-----END PUBLIC KEY-----\n"
)
vk = VerifyingKey.from_pem(vk_pem)
self.assertIsInstance(vk.curve, Curve)
self.assertIs(vk.curve, Ed25519)
vk_str = (
b"\x23\x00\x50\xd0\xd6\x64\x22\x28\x8e\xe3\x55\x89\x7e\x6e\x41\x57"
b"\x8d\xae\xde\x44\x26\xee\x56\x27\xbc\x85\xe6\x0b\x2f\x2a\xcb\x65"
)
vk_2 = VerifyingKey.from_string(vk_str, Ed25519)
self.assertEqual(vk, vk_2)
def test_export_ed255_to_pem(self):
vk_str = (
b"\x23\x00\x50\xd0\xd6\x64\x22\x28\x8e\xe3\x55\x89\x7e\x6e\x41\x57"
b"\x8d\xae\xde\x44\x26\xee\x56\x27\xbc\x85\xe6\x0b\x2f\x2a\xcb\x65"
)
vk = VerifyingKey.from_string(vk_str, Ed25519)
vk_pem = (
b"-----BEGIN PUBLIC KEY-----\n"
b"MCowBQYDK2VwAyEAIwBQ0NZkIiiO41WJfm5BV42u3kQm7lYnvIXmCy8qy2U=\n"
b"-----END PUBLIC KEY-----\n"
)
self.assertEqual(vk_pem, vk.to_pem())
def test_ed25519_export_import(self):
sk = SigningKey.generate(Ed25519)
vk = sk.verifying_key
vk2 = VerifyingKey.from_pem(vk.to_pem())
self.assertEqual(vk, vk2)
def test_ed25519_sig_verify(self):
vk_pem = (
"-----BEGIN PUBLIC KEY-----\n"
"MCowBQYDK2VwAyEAIwBQ0NZkIiiO41WJfm5BV42u3kQm7lYnvIXmCy8qy2U=\n"
"-----END PUBLIC KEY-----\n"
)
vk = VerifyingKey.from_pem(vk_pem)
data = b"data\n"
# signature created by OpenSSL 3.0.0 beta1
sig = (
b"\x64\x47\xab\x6a\x33\xcd\x79\x45\xad\x98\x11\x6c\xb9\xf2\x20\xeb"
b"\x90\xd6\x50\xe3\xc7\x8f\x9f\x60\x10\xec\x75\xe0\x2f\x27\xd3\x96"
b"\xda\xe8\x58\x7f\xe0\xfe\x46\x5c\x81\xef\x50\xec\x29\x9f\xae\xd5"
b"\xad\x46\x3c\x91\x68\x83\x4d\xea\x8d\xa8\x19\x04\x04\x79\x03\x0b"
)
self.assertTrue(vk.verify(sig, data))
def test_ed448_from_pem(self):
pem_str = (
"-----BEGIN PUBLIC KEY-----\n"
"MEMwBQYDK2VxAzoAeQtetSu7CMEzE+XWB10Bg47LCA0giNikOxHzdp+tZ/eK/En0\n"
"dTdYD2ll94g58MhSnBiBQB9A1MMA\n"
"-----END PUBLIC KEY-----\n"
)
vk = VerifyingKey.from_pem(pem_str)
self.assertIsInstance(vk.curve, Curve)
self.assertIs(vk.curve, Ed448)
vk_str = (
b"\x79\x0b\x5e\xb5\x2b\xbb\x08\xc1\x33\x13\xe5\xd6\x07\x5d\x01\x83"
b"\x8e\xcb\x08\x0d\x20\x88\xd8\xa4\x3b\x11\xf3\x76\x9f\xad\x67\xf7"
b"\x8a\xfc\x49\xf4\x75\x37\x58\x0f\x69\x65\xf7\x88\x39\xf0\xc8\x52"
b"\x9c\x18\x81\x40\x1f\x40\xd4\xc3\x00"
)
vk2 = VerifyingKey.from_string(vk_str, Ed448)
self.assertEqual(vk, vk2)
def test_ed448_to_pem(self):
vk_str = (
b"\x79\x0b\x5e\xb5\x2b\xbb\x08\xc1\x33\x13\xe5\xd6\x07\x5d\x01\x83"
b"\x8e\xcb\x08\x0d\x20\x88\xd8\xa4\x3b\x11\xf3\x76\x9f\xad\x67\xf7"
b"\x8a\xfc\x49\xf4\x75\x37\x58\x0f\x69\x65\xf7\x88\x39\xf0\xc8\x52"
b"\x9c\x18\x81\x40\x1f\x40\xd4\xc3\x00"
)
vk = VerifyingKey.from_string(vk_str, Ed448)
vk_pem = (
b"-----BEGIN PUBLIC KEY-----\n"
b"MEMwBQYDK2VxAzoAeQtetSu7CMEzE+XWB10Bg47LCA0giNikOxHzdp+tZ/eK/En0\n"
b"dTdYD2ll94g58MhSnBiBQB9A1MMA\n"
b"-----END PUBLIC KEY-----\n"
)
self.assertEqual(vk_pem, vk.to_pem())
def test_ed448_export_import(self):
sk = SigningKey.generate(Ed448)
vk = sk.verifying_key
vk2 = VerifyingKey.from_pem(vk.to_pem())
self.assertEqual(vk, vk2)
def test_ed448_sig_verify(self):
pem_str = (
"-----BEGIN PUBLIC KEY-----\n"
"MEMwBQYDK2VxAzoAeQtetSu7CMEzE+XWB10Bg47LCA0giNikOxHzdp+tZ/eK/En0\n"
"dTdYD2ll94g58MhSnBiBQB9A1MMA\n"
"-----END PUBLIC KEY-----\n"
)
vk = VerifyingKey.from_pem(pem_str)
data = b"data\n"
# signature created by OpenSSL 3.0.0 beta1
sig = (
b"\x68\xed\x2c\x70\x35\x22\xca\x1c\x35\x03\xf3\xaa\x51\x33\x3d\x00"
b"\xc0\xae\xb0\x54\xc5\xdc\x7f\x6f\x30\x57\xb4\x1d\xcb\xe9\xec\xfa"
b"\xc8\x45\x3e\x51\xc1\xcb\x60\x02\x6a\xd0\x43\x11\x0b\x5f\x9b\xfa"
b"\x32\x88\xb2\x38\x6b\xed\xac\x09\x00\x78\xb1\x7b\x5d\x7e\xf8\x16"
b"\x31\xdd\x1b\x3f\x98\xa0\xce\x19\xe7\xd8\x1c\x9f\x30\xac\x2f\xd4"
b"\x1e\x55\xbf\x21\x98\xf6\x4c\x8c\xbe\x81\xa5\x2d\x80\x4c\x62\x53"
b"\x91\xd5\xee\x03\x30\xc6\x17\x66\x4b\x9e\x0c\x8d\x40\xd0\xad\xae"
b"\x0a\x00"
)
self.assertTrue(vk.verify(sig, data))
class TestSigningKey(unittest.TestCase):
"""
Verify that ecdsa.keys.SigningKey.from_der() can be used with
bytes-like objects.
"""
@classmethod
def setUpClass(cls):
prv_key_str = (
"-----BEGIN EC PRIVATE KEY-----\n"
"MF8CAQEEGF7IQgvW75JSqULpiQQ8op9WH6Uldw6xxaAKBggqhkjOPQMBAaE0AzIA\n"
"BLiBd9CE7xf15FY5QIAoNg+fWbSk1yZOYtoGUdzkejWkxbRc9RWTQjqLVXucIJnz\n"
"bA==\n"
"-----END EC PRIVATE KEY-----\n"
)
cls.sk1 = SigningKey.from_pem(prv_key_str)
prv_key_str = (
"-----BEGIN PRIVATE KEY-----\n"
"MG8CAQAwEwYHKoZIzj0CAQYIKoZIzj0DAQEEVTBTAgEBBBheyEIL1u+SUqlC6YkE\n"
"PKKfVh+lJXcOscWhNAMyAAS4gXfQhO8X9eRWOUCAKDYPn1m0pNcmTmLaBlHc5Ho1\n"
"pMW0XPUVk0I6i1V7nCCZ82w=\n"
"-----END PRIVATE KEY-----\n"
)
cls.sk1_pkcs8 = SigningKey.from_pem(prv_key_str)
prv_key_str = (
"-----BEGIN EC PRIVATE KEY-----\n"
"MHcCAQEEIKlL2EAm5NPPZuXwxRf4nXMk0A80y6UUbiQ17be/qFhRoAoGCCqGSM49\n"
"AwEHoUQDQgAE4H3iRbG4TSrsSRb/gusPQB/4YcN8Poqzgjau4kfxBPyZimeRfuY/\n"
"9g/wMmPuhGl4BUve51DsnKJFRr8psk0ieA==\n"
"-----END EC PRIVATE KEY-----\n"
)
cls.sk2 = SigningKey.from_pem(prv_key_str)
def test_decoding_explicit_curve_parameters(self):
prv_key_str = (
"-----BEGIN PRIVATE KEY-----\n"
"MIIBeQIBADCCAQMGByqGSM49AgEwgfcCAQEwLAYHKoZIzj0BAQIhAP////8AAAAB\n"
"AAAAAAAAAAAAAAAA////////////////MFsEIP////8AAAABAAAAAAAAAAAAAAAA\n"
"///////////////8BCBaxjXYqjqT57PrvVV2mIa8ZR0GsMxTsPY7zjw+J9JgSwMV\n"
"AMSdNgiG5wSTamZ44ROdJreBn36QBEEEaxfR8uEsQkf4vOblY6RA8ncDfYEt6zOg\n"
"9KE5RdiYwpZP40Li/hp/m47n60p8D54WK84zV2sxXs7LtkBoN79R9QIhAP////8A\n"
"AAAA//////////+85vqtpxeehPO5ysL8YyVRAgEBBG0wawIBAQQgIXtREfUmR16r\n"
"ZbmvDGD2lAEFPZa2DLPyz0czSja58yChRANCAASK9VJIGLOY5mxXO4rdf2CNl0//\n"
"yJWhIzDWX7daMI6qQWB9hKyR5OFOT6eFr1qtcZh8CGZbB+yIOCpnnwn0ekpl\n"
"-----END PRIVATE KEY-----\n"
)
sk = SigningKey.from_pem(prv_key_str)
sk2 = SigningKey.from_string(
b"\x21\x7b\x51\x11\xf5\x26\x47\x5e\xab\x65\xb9\xaf\x0c\x60\xf6"
b"\x94\x01\x05\x3d\x96\xb6\x0c\xb3\xf2\xcf\x47\x33\x4a\x36\xb9"
b"\xf3\x20",
curve=NIST256p,
)
self.assertEqual(sk, sk2)
def test_decoding_explicit_curve_parameters_with_explicit_disabled(self):
prv_key_str = (
"-----BEGIN PRIVATE KEY-----\n"
"MIIBeQIBADCCAQMGByqGSM49AgEwgfcCAQEwLAYHKoZIzj0BAQIhAP////8AAAAB\n"
"AAAAAAAAAAAAAAAA////////////////MFsEIP////8AAAABAAAAAAAAAAAAAAAA\n"
"///////////////8BCBaxjXYqjqT57PrvVV2mIa8ZR0GsMxTsPY7zjw+J9JgSwMV\n"
"AMSdNgiG5wSTamZ44ROdJreBn36QBEEEaxfR8uEsQkf4vOblY6RA8ncDfYEt6zOg\n"
"9KE5RdiYwpZP40Li/hp/m47n60p8D54WK84zV2sxXs7LtkBoN79R9QIhAP////8A\n"
"AAAA//////////+85vqtpxeehPO5ysL8YyVRAgEBBG0wawIBAQQgIXtREfUmR16r\n"
"ZbmvDGD2lAEFPZa2DLPyz0czSja58yChRANCAASK9VJIGLOY5mxXO4rdf2CNl0//\n"
"yJWhIzDWX7daMI6qQWB9hKyR5OFOT6eFr1qtcZh8CGZbB+yIOCpnnwn0ekpl\n"
"-----END PRIVATE KEY-----\n"
)
with self.assertRaises(UnexpectedDER):
SigningKey.from_pem(
prv_key_str, valid_curve_encodings=["named_curve"]
)
def test_equality_on_signing_keys(self):
sk = SigningKey.from_secret_exponent(
self.sk1.privkey.secret_multiplier, self.sk1.curve
)
self.assertEqual(self.sk1, sk)
self.assertEqual(self.sk1_pkcs8, sk)
def test_verify_with_empty_message(self):
sig = self.sk1.sign(b"")
self.assertTrue(sig)
vk = self.sk1.verifying_key
self.assertTrue(vk.verify(sig, b""))
def test_verify_with_precompute(self):
sig = self.sk1.sign(b"message")
vk = self.sk1.verifying_key
vk.precompute()
self.assertTrue(vk.verify(sig, b"message"))
def test_compare_verifying_key_with_precompute(self):
vk1 = self.sk1.verifying_key
vk1.precompute()
vk2 = self.sk1_pkcs8.verifying_key
self.assertEqual(vk1, vk2)
def test_verify_with_lazy_precompute(self):
sig = self.sk2.sign(b"other message")
vk = self.sk2.verifying_key
vk.precompute(lazy=True)
self.assertTrue(vk.verify(sig, b"other message"))
def test_inequality_on_signing_keys(self):
self.assertNotEqual(self.sk1, self.sk2)
def test_inequality_on_signing_keys_not_implemented(self):
self.assertNotEqual(self.sk1, None)
def test_ed25519_from_pem(self):
pem_str = (
"-----BEGIN PRIVATE KEY-----\n"
"MC4CAQAwBQYDK2VwBCIEIDS6x9FO1PG8T4xIPg8Zd0z8uL6sVGZFEZrX17gHC/XU\n"
"-----END PRIVATE KEY-----\n"
)
sk = SigningKey.from_pem(pem_str)
sk_str = SigningKey.from_string(
b"\x34\xBA\xC7\xD1\x4E\xD4\xF1\xBC\x4F\x8C\x48\x3E\x0F\x19\x77\x4C"
b"\xFC\xB8\xBE\xAC\x54\x66\x45\x11\x9A\xD7\xD7\xB8\x07\x0B\xF5\xD4",
Ed25519,
)
self.assertEqual(sk, sk_str)
def test_ed25519_to_pem(self):
sk = SigningKey.from_string(
b"\x34\xBA\xC7\xD1\x4E\xD4\xF1\xBC\x4F\x8C\x48\x3E\x0F\x19\x77\x4C"
b"\xFC\xB8\xBE\xAC\x54\x66\x45\x11\x9A\xD7\xD7\xB8\x07\x0B\xF5\xD4",
Ed25519,
)
pem_str = (
b"-----BEGIN PRIVATE KEY-----\n"
b"MC4CAQAwBQYDK2VwBCIEIDS6x9FO1PG8T4xIPg8Zd0z8uL6sVGZFEZrX17gHC/XU\n"
b"-----END PRIVATE KEY-----\n"
)
self.assertEqual(sk.to_pem(format="pkcs8"), pem_str)
def test_ed25519_to_and_from_pem(self):
sk = SigningKey.generate(Ed25519)
decoded = SigningKey.from_pem(sk.to_pem(format="pkcs8"))
self.assertEqual(sk, decoded)
def test_ed448_from_pem(self):
pem_str = (
"-----BEGIN PRIVATE KEY-----\n"
"MEcCAQAwBQYDK2VxBDsEOTyFuXqFLXgJlV8uDqcOw9nG4IqzLiZ/i5NfBDoHPzmP\n"
"OP0JMYaLGlTzwovmvCDJ2zLaezu9NLz9aQ==\n"
"-----END PRIVATE KEY-----\n"
)
sk = SigningKey.from_pem(pem_str)
sk_str = SigningKey.from_string(
b"\x3C\x85\xB9\x7A\x85\x2D\x78\x09\x95\x5F\x2E\x0E\xA7\x0E\xC3\xD9"
b"\xC6\xE0\x8A\xB3\x2E\x26\x7F\x8B\x93\x5F\x04\x3A\x07\x3F\x39\x8F"
b"\x38\xFD\x09\x31\x86\x8B\x1A\x54\xF3\xC2\x8B\xE6\xBC\x20\xC9\xDB"
b"\x32\xDA\x7B\x3B\xBD\x34\xBC\xFD\x69",
Ed448,
)
self.assertEqual(sk, sk_str)
def test_ed448_to_pem(self):
sk = SigningKey.from_string(
b"\x3C\x85\xB9\x7A\x85\x2D\x78\x09\x95\x5F\x2E\x0E\xA7\x0E\xC3\xD9"
b"\xC6\xE0\x8A\xB3\x2E\x26\x7F\x8B\x93\x5F\x04\x3A\x07\x3F\x39\x8F"
b"\x38\xFD\x09\x31\x86\x8B\x1A\x54\xF3\xC2\x8B\xE6\xBC\x20\xC9\xDB"
b"\x32\xDA\x7B\x3B\xBD\x34\xBC\xFD\x69",
Ed448,
)
pem_str = (
b"-----BEGIN PRIVATE KEY-----\n"
b"MEcCAQAwBQYDK2VxBDsEOTyFuXqFLXgJlV8uDqcOw9nG4IqzLiZ/i5NfBDoHPzmP\n"
b"OP0JMYaLGlTzwovmvCDJ2zLaezu9NLz9aQ==\n"
b"-----END PRIVATE KEY-----\n"
)
self.assertEqual(sk.to_pem(format="pkcs8"), pem_str)
def test_ed448_encode_decode(self):
sk = SigningKey.generate(Ed448)
decoded = SigningKey.from_pem(sk.to_pem(format="pkcs8"))
self.assertEqual(decoded, sk)
class TestTrivialCurve(unittest.TestCase):
@classmethod
def setUpClass(cls):
# To test what happens with r or s in signing happens to be zero we
# need to find a scalar that creates one of the points on a curve that
# has x coordinate equal to zero.
# Even for secp112r2 curve that's non trivial so use this toy
# curve, for which we can iterate over all points quickly
curve = CurveFp(163, 84, 58)
gen = PointJacobi(curve, 2, 87, 1, 167, generator=True)
cls.toy_curve = Curve("toy_p8", curve, gen, (1, 2, 0))
cls.sk = SigningKey.from_secret_exponent(
140,
cls.toy_curve,
hashfunc=hashlib.sha1,
)
def test_generator_sanity(self):
gen = self.toy_curve.generator
self.assertEqual(gen * gen.order(), INFINITY)
def test_public_key_sanity(self):
self.assertEqual(self.sk.verifying_key.to_string(), b"\x98\x1e")
def test_deterministic_sign(self):
sig = self.sk.sign_deterministic(b"message")
self.assertEqual(sig, b"-.")
self.assertTrue(self.sk.verifying_key.verify(sig, b"message"))
def test_deterministic_sign_random_message(self):
msg = os.urandom(32)
sig = self.sk.sign_deterministic(msg)
self.assertEqual(len(sig), 2)
self.assertTrue(self.sk.verifying_key.verify(sig, msg))
def test_deterministic_sign_that_rises_R_zero_error(self):
# the raised RSZeroError is caught and handled internally by
# sign_deterministic methods
msg = b"\x00\x4f"
sig = self.sk.sign_deterministic(msg)
self.assertEqual(sig, b"\x36\x9e")
self.assertTrue(self.sk.verifying_key.verify(sig, msg))
def test_deterministic_sign_that_rises_S_zero_error(self):
msg = b"\x01\x6d"
sig = self.sk.sign_deterministic(msg)
self.assertEqual(sig, b"\x49\x6c")
self.assertTrue(self.sk.verifying_key.verify(sig, msg))
# test VerifyingKey.verify()
prv_key_str = (
"-----BEGIN EC PRIVATE KEY-----\n"
"MF8CAQEEGF7IQgvW75JSqULpiQQ8op9WH6Uldw6xxaAKBggqhkjOPQMBAaE0AzIA\n"
"BLiBd9CE7xf15FY5QIAoNg+fWbSk1yZOYtoGUdzkejWkxbRc9RWTQjqLVXucIJnz\n"
"bA==\n"
"-----END EC PRIVATE KEY-----\n"
)
key_bytes = unpem(prv_key_str)
assert isinstance(key_bytes, bytes)
sk = SigningKey.from_der(key_bytes)
vk = sk.verifying_key
data = (
b"some string for signing"
b"contents don't really matter"
b"but do include also some crazy values: "
b"\x00\x01\t\r\n\x00\x00\x00\xff\xf0"
)
assert len(data) % 4 == 0
sha1 = hashlib.sha1()
sha1.update(data)
data_hash = sha1.digest()
assert isinstance(data_hash, bytes)
sig_raw = sk.sign(data, sigencode=sigencode_string)
assert isinstance(sig_raw, bytes)
sig_der = sk.sign(data, sigencode=sigencode_der)
assert isinstance(sig_der, bytes)
sig_strings = sk.sign(data, sigencode=sigencode_strings)
assert isinstance(sig_strings[0], bytes)
verifiers = []
for modifier, fun in [
("bytes", lambda x: x),
("bytes memoryview", lambda x: buffer(x)),
("bytearray", lambda x: bytearray(x)),
("bytearray memoryview", lambda x: buffer(bytearray(x))),
("array.array of bytes", lambda x: array.array("B", x)),
("array.array of bytes memoryview", lambda x: buffer(array.array("B", x))),
("array.array of ints", lambda x: array.array("I", x)),
("array.array of ints memoryview", lambda x: buffer(array.array("I", x))),
]:
if "ints" in modifier:
conv = lambda x: x
else:
conv = fun
for sig_format, signature, decoder, mod_apply in [
("raw", sig_raw, sigdecode_string, lambda x: conv(x)),
("der", sig_der, sigdecode_der, lambda x: conv(x)),
(
"strings",
sig_strings,
sigdecode_strings,
lambda x: tuple(conv(i) for i in x),
),
]:
for method_name, vrf_mthd, vrf_data in [
("verify", vk.verify, data),
("verify_digest", vk.verify_digest, data_hash),
]:
verifiers.append(
pytest.param(
signature,
decoder,
mod_apply,
fun,
vrf_mthd,
vrf_data,
id="{2}-{0}-{1}".format(modifier, sig_format, method_name),
)
)
@pytest.mark.parametrize(
"signature,decoder,mod_apply,fun,vrf_mthd,vrf_data", verifiers
)
def test_VerifyingKey_verify(
signature, decoder, mod_apply, fun, vrf_mthd, vrf_data
):
sig = mod_apply(signature)
assert vrf_mthd(sig, fun(vrf_data), sigdecode=decoder)
# test SigningKey.from_string()
prv_key_bytes = (
b"^\xc8B\x0b\xd6\xef\x92R\xa9B\xe9\x89\x04<\xa2"
b"\x9fV\x1f\xa5%w\x0e\xb1\xc5"
)
assert len(prv_key_bytes) == 24
converters = []
for modifier, convert in [
("bytes", lambda x: x),
("bytes memoryview", buffer),
("bytearray", bytearray),
("bytearray memoryview", lambda x: buffer(bytearray(x))),
("array.array of bytes", lambda x: array.array("B", x)),
("array.array of bytes memoryview", lambda x: buffer(array.array("B", x))),
("array.array of ints", lambda x: array.array("I", x)),
("array.array of ints memoryview", lambda x: buffer(array.array("I", x))),
]:
converters.append(pytest.param(convert, id=modifier))
@pytest.mark.parametrize("convert", converters)
def test_SigningKey_from_string(convert):
key = convert(prv_key_bytes)
sk = SigningKey.from_string(key)
assert sk.to_string() == prv_key_bytes
# test SigningKey.from_der()
prv_key_str = (
"-----BEGIN EC PRIVATE KEY-----\n"
"MF8CAQEEGF7IQgvW75JSqULpiQQ8op9WH6Uldw6xxaAKBggqhkjOPQMBAaE0AzIA\n"
"BLiBd9CE7xf15FY5QIAoNg+fWbSk1yZOYtoGUdzkejWkxbRc9RWTQjqLVXucIJnz\n"
"bA==\n"
"-----END EC PRIVATE KEY-----\n"
)
key_bytes = unpem(prv_key_str)
assert isinstance(key_bytes, bytes)
# last two converters are for array.array of ints, those require input
# that's multiple of 4, which no curve we support produces
@pytest.mark.parametrize("convert", converters[:-2])
def test_SigningKey_from_der(convert):
key = convert(key_bytes)
sk = SigningKey.from_der(key)
assert sk.to_string() == prv_key_bytes
# test SigningKey.sign_deterministic()
extra_entropy = b"\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11"
@pytest.mark.parametrize("convert", converters)
def test_SigningKey_sign_deterministic(convert):
sig = sk.sign_deterministic(
convert(data), extra_entropy=convert(extra_entropy)
)
vk.verify(sig, data)
# test SigningKey.sign_digest_deterministic()
@pytest.mark.parametrize("convert", converters)
def test_SigningKey_sign_digest_deterministic(convert):
sig = sk.sign_digest_deterministic(
convert(data_hash), extra_entropy=convert(extra_entropy)
)
vk.verify(sig, data)
@pytest.mark.parametrize("convert", converters)
def test_SigningKey_sign(convert):
sig = sk.sign(convert(data))
vk.verify(sig, data)
@pytest.mark.parametrize("convert", converters)
def test_SigningKey_sign_digest(convert):
sig = sk.sign_digest(convert(data_hash))
vk.verify(sig, data)
def test_SigningKey_with_unlikely_value():
sk = SigningKey.from_secret_exponent(NIST256p.order - 1, curve=NIST256p)
vk = sk.verifying_key
sig = sk.sign(b"hello")
assert vk.verify(sig, b"hello")
def test_SigningKey_with_custom_curve_old_point():
generator = generator_brainpoolp160r1
generator = Point(
generator.curve(),
generator.x(),
generator.y(),
generator.order(),
)
curve = Curve(
"BRAINPOOLP160r1",
generator.curve(),
generator,
(1, 3, 36, 3, 3, 2, 8, 1, 1, 1),
)
sk = SigningKey.from_secret_exponent(12, curve)
sk2 = SigningKey.from_secret_exponent(12, BRAINPOOLP160r1)
assert sk.privkey == sk2.privkey
def test_VerifyingKey_inequality_with_different_curves():
sk1 = SigningKey.from_secret_exponent(2, BRAINPOOLP160r1)
sk2 = SigningKey.from_secret_exponent(2, NIST256p)
assert sk1.verifying_key != sk2.verifying_key
def test_VerifyingKey_inequality_with_different_secret_points():
sk1 = SigningKey.from_secret_exponent(2, BRAINPOOLP160r1)
sk2 = SigningKey.from_secret_exponent(3, BRAINPOOLP160r1)
assert sk1.verifying_key != sk2.verifying_key
def test_SigningKey_from_pem_pkcs8v2_EdDSA():
pem = """-----BEGIN PRIVATE KEY-----
MFMCAQEwBQYDK2VwBCIEICc2F2ag1n1QP0jY+g9qWx5sDkx0s/HdNi3cSRHw+zsI
oSMDIQA+HQ2xCif8a/LMWR2m5HaCm5I2pKe/cc8OiRANMHxjKQ==
-----END PRIVATE KEY-----"""
sk = SigningKey.from_pem(pem)
assert sk.curve == Ed25519

View file

@ -0,0 +1,370 @@
from __future__ import with_statement, division
import hashlib
try:
from hashlib import algorithms_available
except ImportError: # pragma: no cover
algorithms_available = [
"md5",
"sha1",
"sha224",
"sha256",
"sha384",
"sha512",
]
# skip algorithms broken by change to OpenSSL 3.0 and early versions
# of hashlib that list algorithms that require the legacy provider to work
# https://bugs.python.org/issue38820
algorithms_available = [
i
for i in algorithms_available
if i not in ("mdc2", "md2", "md4", "whirlpool", "ripemd160")
]
from functools import partial
import pytest
import sys
import hypothesis.strategies as st
from hypothesis import note, assume, given, settings, example
from .keys import SigningKey
from .keys import BadSignatureError
from .util import sigencode_der, sigencode_string
from .util import sigdecode_der, sigdecode_string
from .curves import curves
from .der import (
encode_integer,
encode_bitstring,
encode_octet_string,
encode_oid,
encode_sequence,
encode_constructed,
)
from .ellipticcurve import CurveEdTw
example_data = b"some data to sign"
"""Since the data is hashed for processing, really any string will do."""
hash_and_size = [
(name, hashlib.new(name).digest_size) for name in algorithms_available
]
"""Pairs of hash names and their output sizes.
Needed for pairing with curves as we don't support hashes
bigger than order sizes of curves."""
keys_and_sigs = []
"""Name of the curve+hash combination, VerifyingKey and DER signature."""
# for hypothesis strategy shrinking we want smallest curves and hashes first
for curve in sorted(curves, key=lambda x: x.baselen):
for hash_alg in [
name
for name, size in sorted(hash_and_size, key=lambda x: x[1])
if 0 < size <= curve.baselen
]:
sk = SigningKey.generate(
curve, hashfunc=partial(hashlib.new, hash_alg)
)
keys_and_sigs.append(
(
"{0} {1}".format(curve, hash_alg),
sk.verifying_key,
sk.sign(example_data, sigencode=sigencode_der),
)
)
# first make sure that the signatures can be verified
@pytest.mark.parametrize(
"verifying_key,signature",
[pytest.param(vk, sig, id=name) for name, vk, sig in keys_and_sigs],
)
def test_signatures(verifying_key, signature):
assert verifying_key.verify(
signature, example_data, sigdecode=sigdecode_der
)
@st.composite
def st_fuzzed_sig(draw, keys_and_sigs):
"""
Hypothesis strategy that generates pairs of VerifyingKey and malformed
signatures created by fuzzing of a valid signature.
"""
name, verifying_key, old_sig = draw(st.sampled_from(keys_and_sigs))
note("Configuration: {0}".format(name))
sig = bytearray(old_sig)
# decide which bytes should be removed
to_remove = draw(
st.lists(st.integers(min_value=0, max_value=len(sig) - 1), unique=True)
)
to_remove.sort()
for i in reversed(to_remove):
del sig[i]
note("Remove bytes: {0}".format(to_remove))
# decide which bytes of the original signature should be changed
if sig: # pragma: no branch
xors = draw(
st.dictionaries(
st.integers(min_value=0, max_value=len(sig) - 1),
st.integers(min_value=1, max_value=255),
)
)
for i, val in xors.items():
sig[i] ^= val
note("xors: {0}".format(xors))
# decide where new data should be inserted
insert_pos = draw(st.integers(min_value=0, max_value=len(sig)))
# NIST521p signature is about 140 bytes long, test slightly longer
insert_data = draw(st.binary(max_size=256))
sig = sig[:insert_pos] + insert_data + sig[insert_pos:]
note(
"Inserted at position {0} bytes: {1!r}".format(insert_pos, insert_data)
)
sig = bytes(sig)
# make sure that there was performed at least one mutation on the data
assume(to_remove or xors or insert_data)
# and that the mutations didn't cancel each-other out
assume(sig != old_sig)
return verifying_key, sig
params = {}
# not supported in hypothesis 2.0.0
if sys.version_info >= (2, 7): # pragma: no branch
from hypothesis import HealthCheck
# deadline=5s because NIST521p are slow to verify
params["deadline"] = 5000
params["suppress_health_check"] = [
HealthCheck.data_too_large,
HealthCheck.filter_too_much,
HealthCheck.too_slow,
]
slow_params = dict(params)
slow_params["max_examples"] = 10
@settings(**params)
@given(st_fuzzed_sig(keys_and_sigs))
def test_fuzzed_der_signatures(args):
verifying_key, sig = args
with pytest.raises(BadSignatureError):
verifying_key.verify(sig, example_data, sigdecode=sigdecode_der)
@st.composite
def st_random_der_ecdsa_sig_value(draw):
"""
Hypothesis strategy for selecting random values and encoding them
to ECDSA-Sig-Value object::
ECDSA-Sig-Value ::= SEQUENCE {
r INTEGER,
s INTEGER
}
"""
name, verifying_key, _ = draw(st.sampled_from(keys_and_sigs))
note("Configuration: {0}".format(name))
order = int(verifying_key.curve.order)
# the encode_integer doesn't support negative numbers, would be nice
# to generate them too, but we have coverage for remove_integer()
# verifying that it doesn't accept them, so meh.
# Test all numbers around the ones that can show up (around order)
# way smaller and slightly bigger
r = draw(
st.integers(min_value=0, max_value=order << 4)
| st.integers(min_value=order >> 2, max_value=order + 1)
)
s = draw(
st.integers(min_value=0, max_value=order << 4)
| st.integers(min_value=order >> 2, max_value=order + 1)
)
sig = encode_sequence(encode_integer(r), encode_integer(s))
return verifying_key, sig
@settings(**slow_params)
@given(st_random_der_ecdsa_sig_value())
def test_random_der_ecdsa_sig_value(params):
"""
Check if random values encoded in ECDSA-Sig-Value structure are rejected
as signature.
"""
verifying_key, sig = params
with pytest.raises(BadSignatureError):
verifying_key.verify(sig, example_data, sigdecode=sigdecode_der)
def st_der_integer(*args, **kwargs):
"""
Hypothesis strategy that returns a random positive integer as DER
INTEGER.
Parameters are passed to hypothesis.strategy.integer.
"""
if "min_value" not in kwargs: # pragma: no branch
kwargs["min_value"] = 0
return st.builds(encode_integer, st.integers(*args, **kwargs))
@st.composite
def st_der_bit_string(draw, *args, **kwargs):
"""
Hypothesis strategy that returns a random DER BIT STRING.
Parameters are passed to hypothesis.strategy.binary.
"""
data = draw(st.binary(*args, **kwargs))
if data:
unused = draw(st.integers(min_value=0, max_value=7))
data = bytearray(data)
data[-1] &= -(2**unused)
data = bytes(data)
else:
unused = 0
return encode_bitstring(data, unused)
def st_der_octet_string(*args, **kwargs):
"""
Hypothesis strategy that returns a random DER OCTET STRING object.
Parameters are passed to hypothesis.strategy.binary
"""
return st.builds(encode_octet_string, st.binary(*args, **kwargs))
def st_der_null():
"""
Hypothesis strategy that returns DER NULL object.
"""
return st.just(b"\x05\x00")
@st.composite
def st_der_oid(draw):
"""
Hypothesis strategy that returns DER OBJECT IDENTIFIER objects.
"""
first = draw(st.integers(min_value=0, max_value=2))
if first < 2:
second = draw(st.integers(min_value=0, max_value=39))
else:
second = draw(st.integers(min_value=0, max_value=2**512))
rest = draw(
st.lists(st.integers(min_value=0, max_value=2**512), max_size=50)
)
return encode_oid(first, second, *rest)
def st_der():
"""
Hypothesis strategy that returns random DER structures.
A valid DER structure is any primitive object, an octet encoding
of a valid DER structure, sequence of valid DER objects or a constructed
encoding of any of the above.
"""
return st.recursive(
st.just(b"")
| st_der_integer(max_value=2**4096)
| st_der_bit_string(max_size=1024**2)
| st_der_octet_string(max_size=1024**2)
| st_der_null()
| st_der_oid(),
lambda children: st.builds(
lambda x: encode_octet_string(x), st.one_of(children)
)
| st.builds(lambda x: encode_bitstring(x, 0), st.one_of(children))
| st.builds(
lambda x: encode_sequence(*x), st.lists(children, max_size=200)
)
| st.builds(
lambda tag, x: encode_constructed(tag, x),
st.integers(min_value=0, max_value=0x3F),
st.one_of(children),
),
max_leaves=40,
)
@settings(**params)
@given(st.sampled_from(keys_and_sigs), st_der())
def test_random_der_as_signature(params, der):
"""Check if random DER structures are rejected as signature"""
name, verifying_key, _ = params
with pytest.raises(BadSignatureError):
verifying_key.verify(der, example_data, sigdecode=sigdecode_der)
@settings(**params)
@given(st.sampled_from(keys_and_sigs), st.binary(max_size=1024**2))
@example(
keys_and_sigs[0], encode_sequence(encode_integer(0), encode_integer(0))
)
@example(
keys_and_sigs[0],
encode_sequence(encode_integer(1), encode_integer(1)) + b"\x00",
)
@example(keys_and_sigs[0], encode_sequence(*[encode_integer(1)] * 3))
def test_random_bytes_as_signature(params, der):
"""Check if random bytes are rejected as signature"""
name, verifying_key, _ = params
with pytest.raises(BadSignatureError):
verifying_key.verify(der, example_data, sigdecode=sigdecode_der)
keys_and_string_sigs = [
(
name,
verifying_key,
sigencode_string(
*sigdecode_der(sig, verifying_key.curve.order),
order=verifying_key.curve.order
),
)
for name, verifying_key, sig in keys_and_sigs
if not isinstance(verifying_key.curve.curve, CurveEdTw)
]
"""
Name of the curve+hash combination, VerifyingKey and signature as a
byte string.
"""
keys_and_string_sigs += [
(
name,
verifying_key,
sig,
)
for name, verifying_key, sig in keys_and_sigs
if isinstance(verifying_key.curve.curve, CurveEdTw)
]
@settings(**params)
@given(st_fuzzed_sig(keys_and_string_sigs))
def test_fuzzed_string_signatures(params):
verifying_key, sig = params
with pytest.raises(BadSignatureError):
verifying_key.verify(sig, example_data, sigdecode=sigdecode_string)

View file

@ -0,0 +1,433 @@
import operator
from functools import reduce
try:
import unittest2 as unittest
except ImportError:
import unittest
import hypothesis.strategies as st
import pytest
from hypothesis import given, settings, example
try:
from hypothesis import HealthCheck
HC_PRESENT = True
except ImportError: # pragma: no cover
HC_PRESENT = False
from .numbertheory import (
SquareRootError,
JacobiError,
factorization,
gcd,
lcm,
jacobi,
inverse_mod,
is_prime,
next_prime,
smallprimes,
square_root_mod_prime,
)
BIGPRIMES = (
999671,
999683,
999721,
999727,
999749,
999763,
999769,
999773,
999809,
999853,
999863,
999883,
999907,
999917,
999931,
999953,
999959,
999961,
999979,
999983,
)
@pytest.mark.parametrize(
"prime, next_p", [(p, q) for p, q in zip(BIGPRIMES[:-1], BIGPRIMES[1:])]
)
def test_next_prime(prime, next_p):
assert next_prime(prime) == next_p
@pytest.mark.parametrize("val", [-1, 0, 1])
def test_next_prime_with_nums_less_2(val):
assert next_prime(val) == 2
@pytest.mark.parametrize("prime", smallprimes)
def test_square_root_mod_prime_for_small_primes(prime):
squares = set()
for num in range(0, 1 + prime // 2):
sq = num * num % prime
squares.add(sq)
root = square_root_mod_prime(sq, prime)
# tested for real with TestNumbertheory.test_square_root_mod_prime
assert root * root % prime == sq
for nonsquare in range(0, prime):
if nonsquare in squares:
continue
with pytest.raises(SquareRootError):
square_root_mod_prime(nonsquare, prime)
def test_square_root_mod_prime_for_2():
a = square_root_mod_prime(1, 2)
assert a == 1
def test_square_root_mod_prime_for_small_prime():
root = square_root_mod_prime(98**2 % 101, 101)
assert root * root % 101 == 9
def test_square_root_mod_prime_for_p_congruent_5():
p = 13
assert p % 8 == 5
root = square_root_mod_prime(3, p)
assert root * root % p == 3
def test_square_root_mod_prime_for_p_congruent_5_large_d():
p = 29
assert p % 8 == 5
root = square_root_mod_prime(4, p)
assert root * root % p == 4
class TestSquareRootModPrime(unittest.TestCase):
def test_power_of_2_p(self):
with self.assertRaises(JacobiError):
square_root_mod_prime(12, 32)
def test_no_square(self):
with self.assertRaises(SquareRootError) as e:
square_root_mod_prime(12, 31)
self.assertIn("no square root", str(e.exception))
def test_non_prime(self):
with self.assertRaises(SquareRootError) as e:
square_root_mod_prime(12, 33)
self.assertIn("p is not prime", str(e.exception))
def test_non_prime_with_negative(self):
with self.assertRaises(SquareRootError) as e:
square_root_mod_prime(697 - 1, 697)
self.assertIn("p is not prime", str(e.exception))
@st.composite
def st_two_nums_rel_prime(draw):
# 521-bit is the biggest curve we operate on, use 1024 for a bit
# of breathing space
mod = draw(st.integers(min_value=2, max_value=2**1024))
num = draw(
st.integers(min_value=1, max_value=mod - 1).filter(
lambda x: gcd(x, mod) == 1
)
)
return num, mod
@st.composite
def st_primes(draw, *args, **kwargs):
if "min_value" not in kwargs: # pragma: no branch
kwargs["min_value"] = 1
prime = draw(
st.sampled_from(smallprimes)
| st.integers(*args, **kwargs).filter(is_prime)
)
return prime
@st.composite
def st_num_square_prime(draw):
prime = draw(st_primes(max_value=2**1024))
num = draw(st.integers(min_value=0, max_value=1 + prime // 2))
sq = num * num % prime
return sq, prime
@st.composite
def st_comp_with_com_fac(draw):
"""
Strategy that returns lists of numbers, all having a common factor.
"""
primes = draw(
st.lists(st_primes(max_value=2**512), min_size=1, max_size=10)
)
# select random prime(s) that will make the common factor of composites
com_fac_primes = draw(
st.lists(st.sampled_from(primes), min_size=1, max_size=20)
)
com_fac = reduce(operator.mul, com_fac_primes, 1)
# select at most 20 lists (returned numbers),
# each having at most 30 primes (factors) including none (then the number
# will be 1)
comp_primes = draw(
st.integers(min_value=1, max_value=20).flatmap(
lambda n: st.lists(
st.lists(st.sampled_from(primes), max_size=30),
min_size=1,
max_size=n,
)
)
)
return [reduce(operator.mul, nums, 1) * com_fac for nums in comp_primes]
@st.composite
def st_comp_no_com_fac(draw):
"""
Strategy that returns lists of numbers that don't have a common factor.
"""
primes = draw(
st.lists(
st_primes(max_value=2**512), min_size=2, max_size=10, unique=True
)
)
# first select the primes that will create the uncommon factor
# between returned numbers
uncom_fac_primes = draw(
st.lists(
st.sampled_from(primes),
min_size=1,
max_size=len(primes) - 1,
unique=True,
)
)
uncom_fac = reduce(operator.mul, uncom_fac_primes, 1)
# then build composites from leftover primes
leftover_primes = [i for i in primes if i not in uncom_fac_primes]
assert leftover_primes
assert uncom_fac_primes
# select at most 20 lists, each having at most 30 primes
# selected from the leftover_primes list
number_primes = draw(
st.integers(min_value=1, max_value=20).flatmap(
lambda n: st.lists(
st.lists(st.sampled_from(leftover_primes), max_size=30),
min_size=1,
max_size=n,
)
)
)
numbers = [reduce(operator.mul, nums, 1) for nums in number_primes]
insert_at = draw(st.integers(min_value=0, max_value=len(numbers)))
numbers.insert(insert_at, uncom_fac)
return numbers
HYP_SETTINGS = {}
if HC_PRESENT: # pragma: no branch
HYP_SETTINGS["suppress_health_check"] = [
HealthCheck.filter_too_much,
HealthCheck.too_slow,
]
# the factorization() sometimes takes a long time to finish
HYP_SETTINGS["deadline"] = 5000
HYP_SLOW_SETTINGS = dict(HYP_SETTINGS)
HYP_SLOW_SETTINGS["max_examples"] = 10
class TestIsPrime(unittest.TestCase):
def test_very_small_prime(self):
assert is_prime(23)
def test_very_small_composite(self):
assert not is_prime(22)
def test_small_prime(self):
assert is_prime(123456791)
def test_special_composite(self):
assert not is_prime(10261)
def test_medium_prime_1(self):
# nextPrime[2^256]
assert is_prime(2**256 + 0x129)
def test_medium_prime_2(self):
# nextPrime(2^256+0x129)
assert is_prime(2**256 + 0x12D)
def test_medium_trivial_composite(self):
assert not is_prime(2**256 + 0x130)
def test_medium_non_trivial_composite(self):
assert not is_prime(2**256 + 0x12F)
def test_large_prime(self):
# nextPrime[2^2048]
assert is_prime(2**2048 + 0x3D5)
class TestNumbertheory(unittest.TestCase):
def test_gcd(self):
assert gcd(3 * 5 * 7, 3 * 5 * 11, 3 * 5 * 13) == 3 * 5
assert gcd([3 * 5 * 7, 3 * 5 * 11, 3 * 5 * 13]) == 3 * 5
assert gcd(3) == 3
@unittest.skipUnless(
HC_PRESENT,
"Hypothesis 2.0.0 can't be made tolerant of hard to "
"meet requirements (like `is_prime()`), the test "
"case times-out on it",
)
@settings(**HYP_SLOW_SETTINGS)
@given(st_comp_with_com_fac())
def test_gcd_with_com_factor(self, numbers):
n = gcd(numbers)
assert 1 in numbers or n != 1
for i in numbers:
assert i % n == 0
@unittest.skipUnless(
HC_PRESENT,
"Hypothesis 2.0.0 can't be made tolerant of hard to "
"meet requirements (like `is_prime()`), the test "
"case times-out on it",
)
@settings(**HYP_SLOW_SETTINGS)
@given(st_comp_no_com_fac())
def test_gcd_with_uncom_factor(self, numbers):
n = gcd(numbers)
assert n == 1
@given(
st.lists(
st.integers(min_value=1, max_value=2**8192),
min_size=1,
max_size=20,
)
)
def test_gcd_with_random_numbers(self, numbers):
n = gcd(numbers)
for i in numbers:
# check that at least it's a divider
assert i % n == 0
def test_lcm(self):
assert lcm(3, 5 * 3, 7 * 3) == 3 * 5 * 7
assert lcm([3, 5 * 3, 7 * 3]) == 3 * 5 * 7
assert lcm(3) == 3
@given(
st.lists(
st.integers(min_value=1, max_value=2**8192),
min_size=1,
max_size=20,
)
)
def test_lcm_with_random_numbers(self, numbers):
n = lcm(numbers)
for i in numbers:
assert n % i == 0
@unittest.skipUnless(
HC_PRESENT,
"Hypothesis 2.0.0 can't be made tolerant of hard to "
"meet requirements (like `is_prime()`), the test "
"case times-out on it",
)
@settings(**HYP_SETTINGS)
@given(st_num_square_prime())
def test_square_root_mod_prime(self, vals):
square, prime = vals
calc = square_root_mod_prime(square, prime)
assert calc * calc % prime == square
@settings(**HYP_SETTINGS)
@given(st.integers(min_value=1, max_value=10**12))
@example(265399 * 1526929)
@example(373297**2 * 553991)
def test_factorization(self, num):
factors = factorization(num)
mult = 1
for i in factors:
mult *= i[0] ** i[1]
assert mult == num
def test_factorisation_smallprimes(self):
exp = 101 * 103
assert 101 in smallprimes
assert 103 in smallprimes
factors = factorization(exp)
mult = 1
for i in factors:
mult *= i[0] ** i[1]
assert mult == exp
def test_factorisation_not_smallprimes(self):
exp = 1231 * 1237
assert 1231 not in smallprimes
assert 1237 not in smallprimes
factors = factorization(exp)
mult = 1
for i in factors:
mult *= i[0] ** i[1]
assert mult == exp
def test_jacobi_with_zero(self):
assert jacobi(0, 3) == 0
def test_jacobi_with_one(self):
assert jacobi(1, 3) == 1
@settings(**HYP_SETTINGS)
@given(st.integers(min_value=3, max_value=1000).filter(lambda x: x % 2))
def test_jacobi(self, mod):
if is_prime(mod):
squares = set()
for root in range(1, mod):
assert jacobi(root * root, mod) == 1
squares.add(root * root % mod)
for i in range(1, mod):
if i not in squares:
assert jacobi(i, mod) == -1
else:
factors = factorization(mod)
for a in range(1, mod):
c = 1
for i in factors:
c *= jacobi(a, i[0]) ** i[1]
assert c == jacobi(a, mod)
@given(st_two_nums_rel_prime())
def test_inverse_mod(self, nums):
num, mod = nums
inv = inverse_mod(num, mod)
assert 0 < inv < mod
assert num * inv % mod == 1
def test_inverse_mod_with_zero(self):
assert 0 == inverse_mod(0, 11)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,180 @@
# Copyright Mateusz Kobos, (c) 2011
# https://code.activestate.com/recipes/577803-reader-writer-lock-with-priority-for-writers/
# released under the MIT licence
try:
import unittest2 as unittest
except ImportError:
import unittest
import threading
import time
import copy
from ._rwlock import RWLock
class Writer(threading.Thread):
def __init__(
self, buffer_, rw_lock, init_sleep_time, sleep_time, to_write
):
"""
@param buffer_: common buffer_ shared by the readers and writers
@type buffer_: list
@type rw_lock: L{RWLock}
@param init_sleep_time: sleep time before doing any action
@type init_sleep_time: C{float}
@param sleep_time: sleep time while in critical section
@type sleep_time: C{float}
@param to_write: data that will be appended to the buffer
"""
threading.Thread.__init__(self)
self.__buffer = buffer_
self.__rw_lock = rw_lock
self.__init_sleep_time = init_sleep_time
self.__sleep_time = sleep_time
self.__to_write = to_write
self.entry_time = None
"""Time of entry to the critical section"""
self.exit_time = None
"""Time of exit from the critical section"""
def run(self):
time.sleep(self.__init_sleep_time)
self.__rw_lock.writer_acquire()
self.entry_time = time.time()
time.sleep(self.__sleep_time)
self.__buffer.append(self.__to_write)
self.exit_time = time.time()
self.__rw_lock.writer_release()
class Reader(threading.Thread):
def __init__(self, buffer_, rw_lock, init_sleep_time, sleep_time):
"""
@param buffer_: common buffer shared by the readers and writers
@type buffer_: list
@type rw_lock: L{RWLock}
@param init_sleep_time: sleep time before doing any action
@type init_sleep_time: C{float}
@param sleep_time: sleep time while in critical section
@type sleep_time: C{float}
"""
threading.Thread.__init__(self)
self.__buffer = buffer_
self.__rw_lock = rw_lock
self.__init_sleep_time = init_sleep_time
self.__sleep_time = sleep_time
self.buffer_read = None
"""a copy of a the buffer read while in critical section"""
self.entry_time = None
"""Time of entry to the critical section"""
self.exit_time = None
"""Time of exit from the critical section"""
def run(self):
time.sleep(self.__init_sleep_time)
self.__rw_lock.reader_acquire()
self.entry_time = time.time()
time.sleep(self.__sleep_time)
self.buffer_read = copy.deepcopy(self.__buffer)
self.exit_time = time.time()
self.__rw_lock.reader_release()
class RWLockTestCase(unittest.TestCase):
def test_readers_nonexclusive_access(self):
(buffer_, rw_lock, threads) = self.__init_variables()
threads.append(Reader(buffer_, rw_lock, 0, 0))
threads.append(Writer(buffer_, rw_lock, 0.2, 0.4, 1))
threads.append(Reader(buffer_, rw_lock, 0.3, 0.3))
threads.append(Reader(buffer_, rw_lock, 0.5, 0))
self.__start_and_join_threads(threads)
## The third reader should enter after the second one but it should
## exit before the second one exits
## (i.e. the readers should be in the critical section
## at the same time)
self.assertEqual([], threads[0].buffer_read)
self.assertEqual([1], threads[2].buffer_read)
self.assertEqual([1], threads[3].buffer_read)
self.assertTrue(threads[1].exit_time <= threads[2].entry_time)
self.assertTrue(threads[2].entry_time <= threads[3].entry_time)
self.assertTrue(threads[3].exit_time < threads[2].exit_time)
def test_writers_exclusive_access(self):
(buffer_, rw_lock, threads) = self.__init_variables()
threads.append(Writer(buffer_, rw_lock, 0, 0.4, 1))
threads.append(Writer(buffer_, rw_lock, 0.1, 0, 2))
threads.append(Reader(buffer_, rw_lock, 0.2, 0))
self.__start_and_join_threads(threads)
## The second writer should wait for the first one to exit
self.assertEqual([1, 2], threads[2].buffer_read)
self.assertTrue(threads[0].exit_time <= threads[1].entry_time)
self.assertTrue(threads[1].exit_time <= threads[2].exit_time)
def test_writer_priority(self):
(buffer_, rw_lock, threads) = self.__init_variables()
threads.append(Writer(buffer_, rw_lock, 0, 0, 1))
threads.append(Reader(buffer_, rw_lock, 0.1, 0.4))
threads.append(Writer(buffer_, rw_lock, 0.2, 0, 2))
threads.append(Reader(buffer_, rw_lock, 0.3, 0))
threads.append(Reader(buffer_, rw_lock, 0.3, 0))
self.__start_and_join_threads(threads)
## The second writer should go before the second and the third reader
self.assertEqual([1], threads[1].buffer_read)
self.assertEqual([1, 2], threads[3].buffer_read)
self.assertEqual([1, 2], threads[4].buffer_read)
self.assertTrue(threads[0].exit_time < threads[1].entry_time)
self.assertTrue(threads[1].exit_time <= threads[2].entry_time)
self.assertTrue(threads[2].exit_time <= threads[3].entry_time)
self.assertTrue(threads[2].exit_time <= threads[4].entry_time)
def test_many_writers_priority(self):
(buffer_, rw_lock, threads) = self.__init_variables()
threads.append(Writer(buffer_, rw_lock, 0, 0, 1))
threads.append(Reader(buffer_, rw_lock, 0.1, 0.6))
threads.append(Writer(buffer_, rw_lock, 0.2, 0.1, 2))
threads.append(Reader(buffer_, rw_lock, 0.3, 0))
threads.append(Reader(buffer_, rw_lock, 0.4, 0))
threads.append(Writer(buffer_, rw_lock, 0.5, 0.1, 3))
self.__start_and_join_threads(threads)
## The two last writers should go first -- after the first reader and
## before the second and the third reader
self.assertEqual([1], threads[1].buffer_read)
self.assertEqual([1, 2, 3], threads[3].buffer_read)
self.assertEqual([1, 2, 3], threads[4].buffer_read)
self.assertTrue(threads[0].exit_time < threads[1].entry_time)
self.assertTrue(threads[1].exit_time <= threads[2].entry_time)
self.assertTrue(threads[1].exit_time <= threads[5].entry_time)
self.assertTrue(threads[2].exit_time <= threads[3].entry_time)
self.assertTrue(threads[2].exit_time <= threads[4].entry_time)
self.assertTrue(threads[5].exit_time <= threads[3].entry_time)
self.assertTrue(threads[5].exit_time <= threads[4].entry_time)
@staticmethod
def __init_variables():
buffer_ = []
rw_lock = RWLock()
threads = []
return (buffer_, rw_lock, threads)
@staticmethod
def __start_and_join_threads(threads):
for t in threads:
t.start()
for t in threads:
t.join()

View file

@ -0,0 +1,111 @@
try:
import unittest2 as unittest
except ImportError:
import unittest
import pytest
try:
from gmpy2 import mpz
GMPY = True
except ImportError:
try:
from gmpy import mpz
GMPY = True
except ImportError:
GMPY = False
from ._sha3 import shake_256
from ._compat import bytes_to_int, int_to_bytes
B2I_VECTORS = [
(b"\x00\x01", "big", 1),
(b"\x00\x01", "little", 0x0100),
(b"", "big", 0),
(b"\x00", "little", 0),
]
@pytest.mark.parametrize("bytes_in,endian,int_out", B2I_VECTORS)
def test_bytes_to_int(bytes_in, endian, int_out):
out = bytes_to_int(bytes_in, endian)
assert out == int_out
class TestBytesToInt(unittest.TestCase):
def test_bytes_to_int_wrong_endian(self):
with self.assertRaises(ValueError):
bytes_to_int(b"\x00", "middle")
def test_int_to_bytes_wrong_endian(self):
with self.assertRaises(ValueError):
int_to_bytes(0, byteorder="middle")
@pytest.mark.skipif(GMPY == False, reason="requites gmpy or gmpy2")
def test_int_to_bytes_with_gmpy():
assert int_to_bytes(mpz(1)) == b"\x01"
I2B_VECTORS = [
(0, None, "big", b""),
(0, 1, "big", b"\x00"),
(1, None, "big", b"\x01"),
(0x0100, None, "little", b"\x00\x01"),
(0x0100, 4, "little", b"\x00\x01\x00\x00"),
(1, 4, "big", b"\x00\x00\x00\x01"),
]
@pytest.mark.parametrize("int_in,length,endian,bytes_out", I2B_VECTORS)
def test_int_to_bytes(int_in, length, endian, bytes_out):
out = int_to_bytes(int_in, length, endian)
assert out == bytes_out
SHAKE_256_VECTORS = [
(
b"Message.",
32,
b"\x78\xa1\x37\xbb\x33\xae\xe2\x72\xb1\x02\x4f\x39\x43\xe5\xcf\x0c"
b"\x4e\x9c\x72\x76\x2e\x34\x4c\xf8\xf9\xc3\x25\x9d\x4f\x91\x2c\x3a",
),
(
b"",
32,
b"\x46\xb9\xdd\x2b\x0b\xa8\x8d\x13\x23\x3b\x3f\xeb\x74\x3e\xeb\x24"
b"\x3f\xcd\x52\xea\x62\xb8\x1b\x82\xb5\x0c\x27\x64\x6e\xd5\x76\x2f",
),
(
b"message",
32,
b"\x86\x16\xe1\xe4\xcf\xd8\xb5\xf7\xd9\x2d\x43\xd8\x6e\x1b\x14\x51"
b"\xa2\xa6\x5a\xf8\x64\xfc\xb1\x26\xc2\x66\x0a\xb3\x46\x51\xb1\x75",
),
(
b"message",
16,
b"\x86\x16\xe1\xe4\xcf\xd8\xb5\xf7\xd9\x2d\x43\xd8\x6e\x1b\x14\x51",
),
(
b"message",
64,
b"\x86\x16\xe1\xe4\xcf\xd8\xb5\xf7\xd9\x2d\x43\xd8\x6e\x1b\x14\x51"
b"\xa2\xa6\x5a\xf8\x64\xfc\xb1\x26\xc2\x66\x0a\xb3\x46\x51\xb1\x75"
b"\x30\xd6\xba\x2a\x46\x65\xf1\x9d\xf0\x62\x25\xb1\x26\xd1\x3e\xed"
b"\x91\xd5\x0d\xe7\xb9\xcb\x65\xf3\x3a\x46\xae\xd3\x6c\x7d\xc5\xe8",
),
(
b"A" * 1024,
32,
b"\xa5\xef\x7e\x30\x8b\xe8\x33\x64\xe5\x9c\xf3\xb5\xf3\xba\x20\xa3"
b"\x5a\xe7\x30\xfd\xbc\x33\x11\xbf\x83\x89\x50\x82\xb4\x41\xe9\xb3",
),
]
@pytest.mark.parametrize("msg,olen,ohash", SHAKE_256_VECTORS)
def test_shake_256(msg, olen, ohash):
out = shake_256(msg, olen)
assert out == bytearray(ohash)

View file

@ -0,0 +1,433 @@
from __future__ import division
import os
import math
import binascii
import sys
from hashlib import sha256
from six import PY2, int2byte, b, next
from . import der
from ._compat import normalise_bytes
# RFC5480:
# The "unrestricted" algorithm identifier is:
# id-ecPublicKey OBJECT IDENTIFIER ::= {
# iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 }
oid_ecPublicKey = (1, 2, 840, 10045, 2, 1)
encoded_oid_ecPublicKey = der.encode_oid(*oid_ecPublicKey)
# RFC5480:
# The ECDH algorithm uses the following object identifier:
# id-ecDH OBJECT IDENTIFIER ::= {
# iso(1) identified-organization(3) certicom(132) schemes(1)
# ecdh(12) }
oid_ecDH = (1, 3, 132, 1, 12)
# RFC5480:
# The ECMQV algorithm uses the following object identifier:
# id-ecMQV OBJECT IDENTIFIER ::= {
# iso(1) identified-organization(3) certicom(132) schemes(1)
# ecmqv(13) }
oid_ecMQV = (1, 3, 132, 1, 13)
if sys.version_info >= (3,): # pragma: no branch
def entropy_to_bits(ent_256):
"""Convert a bytestring to string of 0's and 1's"""
return bin(int.from_bytes(ent_256, "big"))[2:].zfill(len(ent_256) * 8)
else:
def entropy_to_bits(ent_256):
"""Convert a bytestring to string of 0's and 1's"""
return "".join(bin(ord(x))[2:].zfill(8) for x in ent_256)
if sys.version_info < (2, 7): # pragma: no branch
# Can't add a method to a built-in type so we are stuck with this
def bit_length(x):
return len(bin(x)) - 2
else:
def bit_length(x):
return x.bit_length() or 1
def orderlen(order):
return (1 + len("%x" % order)) // 2 # bytes
def randrange(order, entropy=None):
"""Return a random integer k such that 1 <= k < order, uniformly
distributed across that range. Worst case should be a mean of 2 loops at
(2**k)+2.
Note that this function is not declared to be forwards-compatible: we may
change the behavior in future releases. The entropy= argument (which
should get a callable that behaves like os.urandom) can be used to
achieve stability within a given release (for repeatable unit tests), but
should not be used as a long-term-compatible key generation algorithm.
"""
assert order > 1
if entropy is None:
entropy = os.urandom
upper_2 = bit_length(order - 2)
upper_256 = upper_2 // 8 + 1
while True: # I don't think this needs a counter with bit-wise randrange
ent_256 = entropy(upper_256)
ent_2 = entropy_to_bits(ent_256)
rand_num = int(ent_2[:upper_2], base=2) + 1
if 0 < rand_num < order:
return rand_num
class PRNG:
# this returns a callable which, when invoked with an integer N, will
# return N pseudorandom bytes. Note: this is a short-term PRNG, meant
# primarily for the needs of randrange_from_seed__trytryagain(), which
# only needs to run it a few times per seed. It does not provide
# protection against state compromise (forward security).
def __init__(self, seed):
self.generator = self.block_generator(seed)
def __call__(self, numbytes):
a = [next(self.generator) for i in range(numbytes)]
if PY2: # pragma: no branch
return "".join(a)
else:
return bytes(a)
def block_generator(self, seed):
counter = 0
while True:
for byte in sha256(
("prng-%d-%s" % (counter, seed)).encode()
).digest():
yield byte
counter += 1
def randrange_from_seed__overshoot_modulo(seed, order):
# hash the data, then turn the digest into a number in [1,order).
#
# We use David-Sarah Hopwood's suggestion: turn it into a number that's
# sufficiently larger than the group order, then modulo it down to fit.
# This should give adequate (but not perfect) uniformity, and simple
# code. There are other choices: try-try-again is the main one.
base = PRNG(seed)(2 * orderlen(order))
number = (int(binascii.hexlify(base), 16) % (order - 1)) + 1
assert 1 <= number < order, (1, number, order)
return number
def lsb_of_ones(numbits):
return (1 << numbits) - 1
def bits_and_bytes(order):
bits = int(math.log(order - 1, 2) + 1)
bytes = bits // 8
extrabits = bits % 8
return bits, bytes, extrabits
# the following randrange_from_seed__METHOD() functions take an
# arbitrarily-sized secret seed and turn it into a number that obeys the same
# range limits as randrange() above. They are meant for deriving consistent
# signing keys from a secret rather than generating them randomly, for
# example a protocol in which three signing keys are derived from a master
# secret. You should use a uniformly-distributed unguessable seed with about
# curve.baselen bytes of entropy. To use one, do this:
# seed = os.urandom(curve.baselen) # or other starting point
# secexp = ecdsa.util.randrange_from_seed__trytryagain(sed, curve.order)
# sk = SigningKey.from_secret_exponent(secexp, curve)
def randrange_from_seed__truncate_bytes(seed, order, hashmod=sha256):
# hash the seed, then turn the digest into a number in [1,order), but
# don't worry about trying to uniformly fill the range. This will lose,
# on average, four bits of entropy.
bits, _bytes, extrabits = bits_and_bytes(order)
if extrabits:
_bytes += 1
base = hashmod(seed).digest()[:_bytes]
base = "\x00" * (_bytes - len(base)) + base
number = 1 + int(binascii.hexlify(base), 16)
assert 1 <= number < order
return number
def randrange_from_seed__truncate_bits(seed, order, hashmod=sha256):
# like string_to_randrange_truncate_bytes, but only lose an average of
# half a bit
bits = int(math.log(order - 1, 2) + 1)
maxbytes = (bits + 7) // 8
base = hashmod(seed).digest()[:maxbytes]
base = "\x00" * (maxbytes - len(base)) + base
topbits = 8 * maxbytes - bits
if topbits:
base = int2byte(ord(base[0]) & lsb_of_ones(topbits)) + base[1:]
number = 1 + int(binascii.hexlify(base), 16)
assert 1 <= number < order
return number
def randrange_from_seed__trytryagain(seed, order):
# figure out exactly how many bits we need (rounded up to the nearest
# bit), so we can reduce the chance of looping to less than 0.5 . This is
# specified to feed from a byte-oriented PRNG, and discards the
# high-order bits of the first byte as necessary to get the right number
# of bits. The average number of loops will range from 1.0 (when
# order=2**k-1) to 2.0 (when order=2**k+1).
assert order > 1
bits, bytes, extrabits = bits_and_bytes(order)
generate = PRNG(seed)
while True:
extrabyte = b("")
if extrabits:
extrabyte = int2byte(ord(generate(1)) & lsb_of_ones(extrabits))
guess = string_to_number(extrabyte + generate(bytes)) + 1
if 1 <= guess < order:
return guess
def number_to_string(num, order):
l = orderlen(order)
fmt_str = "%0" + str(2 * l) + "x"
string = binascii.unhexlify((fmt_str % num).encode())
assert len(string) == l, (len(string), l)
return string
def number_to_string_crop(num, order):
l = orderlen(order)
fmt_str = "%0" + str(2 * l) + "x"
string = binascii.unhexlify((fmt_str % num).encode())
return string[:l]
def string_to_number(string):
return int(binascii.hexlify(string), 16)
def string_to_number_fixedlen(string, order):
l = orderlen(order)
assert len(string) == l, (len(string), l)
return int(binascii.hexlify(string), 16)
# these methods are useful for the sigencode= argument to SK.sign() and the
# sigdecode= argument to VK.verify(), and control how the signature is packed
# or unpacked.
def sigencode_strings(r, s, order):
r_str = number_to_string(r, order)
s_str = number_to_string(s, order)
return (r_str, s_str)
def sigencode_string(r, s, order):
"""
Encode the signature to raw format (:term:`raw encoding`)
It's expected that this function will be used as a `sigencode=` parameter
in :func:`ecdsa.keys.SigningKey.sign` method.
:param int r: first parameter of the signature
:param int s: second parameter of the signature
:param int order: the order of the curve over which the signature was
computed
:return: raw encoding of ECDSA signature
:rtype: bytes
"""
# for any given curve, the size of the signature numbers is
# fixed, so just use simple concatenation
r_str, s_str = sigencode_strings(r, s, order)
return r_str + s_str
def sigencode_der(r, s, order):
"""
Encode the signature into the ECDSA-Sig-Value structure using :term:`DER`.
Encodes the signature to the following :term:`ASN.1` structure::
Ecdsa-Sig-Value ::= SEQUENCE {
r INTEGER,
s INTEGER
}
It's expected that this function will be used as a `sigencode=` parameter
in :func:`ecdsa.keys.SigningKey.sign` method.
:param int r: first parameter of the signature
:param int s: second parameter of the signature
:param int order: the order of the curve over which the signature was
computed
:return: DER encoding of ECDSA signature
:rtype: bytes
"""
return der.encode_sequence(der.encode_integer(r), der.encode_integer(s))
# canonical versions of sigencode methods
# these enforce low S values, by negating the value (modulo the order) if
# above order/2 see CECKey::Sign()
# https://github.com/bitcoin/bitcoin/blob/master/src/key.cpp#L214
def sigencode_strings_canonize(r, s, order):
if s > order / 2:
s = order - s
return sigencode_strings(r, s, order)
def sigencode_string_canonize(r, s, order):
if s > order / 2:
s = order - s
return sigencode_string(r, s, order)
def sigencode_der_canonize(r, s, order):
if s > order / 2:
s = order - s
return sigencode_der(r, s, order)
class MalformedSignature(Exception):
"""
Raised by decoding functions when the signature is malformed.
Malformed in this context means that the relevant strings or integers
do not match what a signature over provided curve would create. Either
because the byte strings have incorrect lengths or because the encoded
values are too large.
"""
pass
def sigdecode_string(signature, order):
"""
Decoder for :term:`raw encoding` of ECDSA signatures.
raw encoding is a simple concatenation of the two integers that comprise
the signature, with each encoded using the same amount of bytes depending
on curve size/order.
It's expected that this function will be used as the `sigdecode=`
parameter to the :func:`ecdsa.keys.VerifyingKey.verify` method.
:param signature: encoded signature
:type signature: bytes like object
:param order: order of the curve over which the signature was computed
:type order: int
:raises MalformedSignature: when the encoding of the signature is invalid
:return: tuple with decoded 'r' and 's' values of signature
:rtype: tuple of ints
"""
signature = normalise_bytes(signature)
l = orderlen(order)
if not len(signature) == 2 * l:
raise MalformedSignature(
"Invalid length of signature, expected {0} bytes long, "
"provided string is {1} bytes long".format(2 * l, len(signature))
)
r = string_to_number_fixedlen(signature[:l], order)
s = string_to_number_fixedlen(signature[l:], order)
return r, s
def sigdecode_strings(rs_strings, order):
"""
Decode the signature from two strings.
First string needs to be a big endian encoding of 'r', second needs to
be a big endian encoding of the 's' parameter of an ECDSA signature.
It's expected that this function will be used as the `sigdecode=`
parameter to the :func:`ecdsa.keys.VerifyingKey.verify` method.
:param list rs_strings: list of two bytes-like objects, each encoding one
parameter of signature
:param int order: order of the curve over which the signature was computed
:raises MalformedSignature: when the encoding of the signature is invalid
:return: tuple with decoded 'r' and 's' values of signature
:rtype: tuple of ints
"""
if not len(rs_strings) == 2:
raise MalformedSignature(
"Invalid number of strings provided: {0}, expected 2".format(
len(rs_strings)
)
)
(r_str, s_str) = rs_strings
r_str = normalise_bytes(r_str)
s_str = normalise_bytes(s_str)
l = orderlen(order)
if not len(r_str) == l:
raise MalformedSignature(
"Invalid length of first string ('r' parameter), "
"expected {0} bytes long, provided string is {1} "
"bytes long".format(l, len(r_str))
)
if not len(s_str) == l:
raise MalformedSignature(
"Invalid length of second string ('s' parameter), "
"expected {0} bytes long, provided string is {1} "
"bytes long".format(l, len(s_str))
)
r = string_to_number_fixedlen(r_str, order)
s = string_to_number_fixedlen(s_str, order)
return r, s
def sigdecode_der(sig_der, order):
"""
Decoder for DER format of ECDSA signatures.
DER format of signature is one that uses the :term:`ASN.1` :term:`DER`
rules to encode it as a sequence of two integers::
Ecdsa-Sig-Value ::= SEQUENCE {
r INTEGER,
s INTEGER
}
It's expected that this function will be used as as the `sigdecode=`
parameter to the :func:`ecdsa.keys.VerifyingKey.verify` method.
:param sig_der: encoded signature
:type sig_der: bytes like object
:param order: order of the curve over which the signature was computed
:type order: int
:raises UnexpectedDER: when the encoding of signature is invalid
:return: tuple with decoded 'r' and 's' values of signature
:rtype: tuple of ints
"""
sig_der = normalise_bytes(sig_der)
# return der.encode_sequence(der.encode_integer(r), der.encode_integer(s))
rs_strings, empty = der.remove_sequence(sig_der)
if empty != b"":
raise der.UnexpectedDER(
"trailing junk after DER sig: %s" % binascii.hexlify(empty)
)
r, rest = der.remove_integer(rs_strings)
s, empty = der.remove_integer(rest)
if empty != b"":
raise der.UnexpectedDER(
"trailing junk after DER numbers: %s" % binascii.hexlify(empty)
)
return r, s

View file

@ -0,0 +1,491 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import base64
from copy import deepcopy
import json
import os
import time
try:
from urllib.parse import urlparse
except ImportError: # pragma nocover
from urlparse import urlparse
import six
import http_ece
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from py_vapid import Vapid, Vapid01
class WebPushException(Exception):
"""Web Push failure.
This may contain the requests.Response
"""
def __init__(self, message, response=None):
self.message = message
self.response = response
def __str__(self):
extra = ""
if self.response:
try:
extra = ", Response {}".format(
self.response.text,
)
except AttributeError:
extra = ", Response {}".format(self.response)
return "WebPushException: {}{}".format(self.message, extra)
class CaseInsensitiveDict(dict):
"""A dictionary that has case-insensitive keys"""
def __init__(self, data={}, **kwargs):
for key in data:
dict.__setitem__(self, key.lower(), data[key])
self.update(kwargs)
def __contains__(self, key):
return dict.__contains__(self, key.lower())
def __setitem__(self, key, value):
dict.__setitem__(self, key.lower(), value)
def __getitem__(self, key):
return dict.__getitem__(self, key.lower())
def __delitem__(self, key):
dict.__delitem__(self, key.lower())
def get(self, key, default=None):
try:
return self.__getitem__(key)
except KeyError:
return default
def update(self, data):
for key in data:
self.__setitem__(key, data[key])
class WebPusher:
"""WebPusher encrypts a data block using HTTP Encrypted Content Encoding
for WebPush.
See https://tools.ietf.org/html/draft-ietf-webpush-protocol-04
for the current specification, and
https://developer.mozilla.org/en-US/docs/Web/API/Push_API for an
overview of Web Push.
Example of use:
The javascript promise handler for PushManager.subscribe()
receives a subscription_info object. subscription_info.getJSON()
will return a JSON representation.
(e.g.
.. code-block:: javascript
subscription_info.getJSON() ==
{"endpoint": "https://push.server.com/...",
"keys":{"auth": "...", "p256dh": "..."}
}
)
This subscription_info block can be stored.
To send a subscription update:
.. code-block:: python
# Optional
# headers = py_vapid.sign({"aud": "https://push.server.com/",
"sub": "mailto:your_admin@your.site.com"})
data = "Mary had a little lamb, with a nice mint jelly"
WebPusher(subscription_info).send(data, headers)
"""
subscription_info = {}
valid_encodings = [
# "aesgcm128", # this is draft-0, but DO NOT USE.
"aesgcm", # draft-httpbis-encryption-encoding-01
"aes128gcm" # RFC8188 Standard encoding
]
verbose = False
def __init__(self, subscription_info, requests_session=None,
verbose=False):
"""Initialize using the info provided by the client PushSubscription
object (See
https://developer.mozilla.org/en-US/docs/Web/API/PushManager/subscribe)
:param subscription_info: a dict containing the subscription_info from
the client.
:type subscription_info: dict
:param requests_session: a requests.Session object to optimize requests
to the same client.
:type requests_session: requests.Session
:param verbose: provide verbose feedback
:type verbose: bool
"""
self.verbose = verbose
if requests_session is None:
self.requests_method = requests
else:
self.requests_method = requests_session
if 'endpoint' not in subscription_info:
raise WebPushException("subscription_info missing endpoint URL")
self.subscription_info = deepcopy(subscription_info)
self.auth_key = self.receiver_key = None
if 'keys' in subscription_info:
keys = self.subscription_info['keys']
for k in ['p256dh', 'auth']:
if keys.get(k) is None:
raise WebPushException("Missing keys value: {}".format(k))
if isinstance(keys[k], six.text_type):
keys[k] = bytes(keys[k].encode('utf8'))
receiver_raw = base64.urlsafe_b64decode(
self._repad(keys['p256dh']))
if len(receiver_raw) != 65 and receiver_raw[0] != "\x04":
raise WebPushException("Invalid p256dh key specified")
self.receiver_key = receiver_raw
self.auth_key = base64.urlsafe_b64decode(
self._repad(keys['auth']))
def verb(self, msg, *args, **kwargs):
if self.verbose:
print(msg.format(*args, **kwargs))
def _repad(self, data):
"""Add base64 padding to the end of a string, if required"""
return data + b"===="[:len(data) % 4]
def encode(self, data, content_encoding="aes128gcm"):
"""Encrypt the data.
:param data: A serialized block of byte data (String, JSON, bit array,
etc.) Make sure that whatever you send, your client knows how
to understand it.
:type data: str
:param content_encoding: The content_encoding type to use to encrypt
the data. Defaults to RFC8188 "aes128gcm". The previous draft-01 is
"aesgcm", however this format is now deprecated.
:type content_encoding: enum("aesgcm", "aes128gcm")
"""
# Salt is a random 16 byte array.
if not data:
self.verb("No data found...")
return
if not self.auth_key or not self.receiver_key:
raise WebPushException("No keys specified in subscription info")
self.verb("Encoding data...")
salt = None
if content_encoding not in self.valid_encodings:
raise WebPushException("Invalid content encoding specified. "
"Select from " +
json.dumps(self.valid_encodings))
if content_encoding == "aesgcm":
self.verb("Generating salt for aesgcm...")
salt = os.urandom(16)
# The server key is an ephemeral ECDH key used only for this
# transaction
server_key = ec.generate_private_key(ec.SECP256R1, default_backend())
crypto_key = server_key.public_key().public_bytes(
encoding=serialization.Encoding.X962,
format=serialization.PublicFormat.UncompressedPoint
)
if isinstance(data, six.text_type):
data = bytes(data.encode('utf8'))
if content_encoding == "aes128gcm":
self.verb("Encrypting to aes128gcm...")
encrypted = http_ece.encrypt(
data,
salt=salt,
private_key=server_key,
dh=self.receiver_key,
auth_secret=self.auth_key,
version=content_encoding)
reply = CaseInsensitiveDict({
'body': encrypted
})
else:
self.verb("Encrypting to aesgcm...")
crypto_key = base64.urlsafe_b64encode(crypto_key).strip(b'=')
encrypted = http_ece.encrypt(
data,
salt=salt,
private_key=server_key,
keyid=crypto_key.decode(),
dh=self.receiver_key,
auth_secret=self.auth_key,
version=content_encoding)
reply = CaseInsensitiveDict({
'crypto_key': crypto_key,
'body': encrypted,
})
if salt:
reply['salt'] = base64.urlsafe_b64encode(salt).strip(b'=')
return reply
def as_curl(self, endpoint, encoded_data, headers):
"""Return the send as a curl command.
Useful for debugging. This will write out the encoded data to a local
file named `encrypted.data`
:param endpoint: Push service endpoint URL
:type endpoint: basestring
:param encoded_data: byte array of encoded data
:type encoded_data: bytearray
:param headers: Additional headers for the send
:type headers: dict
:returns string
"""
header_list = [
'-H "{}: {}" \\ \n'.format(
key.lower(), val) for key, val in headers.items()
]
data = ""
if encoded_data:
with open("encrypted.data", "wb") as f:
f.write(encoded_data)
data = "--data-binary @encrypted.data"
if 'content-length' not in headers:
self.verb("Generating content-length header...")
header_list.append(
'-H "content-length: {}" \\ \n'.format(len(encoded_data)))
return ("""curl -vX POST {url} \\\n{headers}{data}""".format(
url=endpoint, headers="".join(header_list), data=data))
def send(self, data=None, headers=None, ttl=0, gcm_key=None, reg_id=None,
content_encoding="aes128gcm", curl=False, timeout=None):
"""Encode and send the data to the Push Service.
:param data: A serialized block of data (see encode() ).
:type data: str
:param headers: A dictionary containing any additional HTTP headers.
:type headers: dict
:param ttl: The Time To Live in seconds for this message if the
recipient is not online. (Defaults to "0", which discards the
message immediately if the recipient is unavailable.)
:type ttl: int
:param gcm_key: API key obtained from the Google Developer Console.
Needed if endpoint is https://android.googleapis.com/gcm/send
:type gcm_key: string
:param reg_id: registration id of the recipient. If not provided,
it will be extracted from the endpoint.
:type reg_id: str
:param content_encoding: ECE content encoding (defaults to "aes128gcm")
:type content_encoding: str
:param curl: Display output as `curl` command instead of sending
:type curl: bool
:param timeout: POST requests timeout
:type timeout: float or tuple
"""
# Encode the data.
if headers is None:
headers = dict()
encoded = {}
headers = CaseInsensitiveDict(headers)
if data:
encoded = self.encode(data, content_encoding)
if "crypto_key" in encoded:
# Append the p256dh to the end of any existing crypto-key
crypto_key = headers.get("crypto-key", "")
if crypto_key:
# due to some confusion by a push service provider, we
# should use ';' instead of ',' to append the headers.
# see
# https://github.com/webpush-wg/webpush-encryption/issues/6
crypto_key += ';'
crypto_key += (
"dh=" + encoded["crypto_key"].decode('utf8'))
headers.update({
'crypto-key': crypto_key
})
if "salt" in encoded:
headers.update({
'encryption': "salt=" + encoded['salt'].decode('utf8')
})
headers.update({
'content-encoding': content_encoding,
})
if gcm_key:
# guess if it is a legacy GCM project key or actual FCM key
# gcm keys are all about 40 chars (use 100 for confidence),
# fcm keys are 153-175 chars
if len(gcm_key) < 100:
self.verb("Guessing this is legacy GCM...")
endpoint = 'https://android.googleapis.com/gcm/send'
else:
self.verb("Guessing this is FCM...")
endpoint = 'https://fcm.googleapis.com/fcm/send'
reg_ids = []
if not reg_id:
reg_id = self.subscription_info['endpoint'].rsplit('/', 1)[-1]
self.verb("Fetching out registration id: {}", reg_id)
reg_ids.append(reg_id)
gcm_data = dict()
gcm_data['registration_ids'] = reg_ids
if data:
gcm_data['raw_data'] = base64.b64encode(
encoded.get('body')).decode('utf8')
gcm_data['time_to_live'] = int(
headers['ttl'] if 'ttl' in headers else ttl)
encoded_data = json.dumps(gcm_data)
headers.update({
'Authorization': 'key='+gcm_key,
'Content-Type': 'application/json',
})
else:
encoded_data = encoded.get('body')
endpoint = self.subscription_info['endpoint']
if 'ttl' not in headers or ttl:
self.verb("Generating TTL of 0...")
headers['ttl'] = str(ttl or 0)
# Additionally useful headers:
# Authorization / Crypto-Key (VAPID headers)
if curl:
return self.as_curl(endpoint, encoded_data, headers)
self.verb("\nSending request to"
"\n\thost: {}\n\theaders: {}\n\tdata: {}",
endpoint, headers, encoded_data)
resp = self.requests_method.post(endpoint,
data=encoded_data,
headers=headers,
timeout=timeout)
self.verb("\nResponse:\n\tcode: {}\n\tbody: {}\n",
resp.status_code, resp.text or "Empty")
return resp
def webpush(subscription_info,
data=None,
vapid_private_key=None,
vapid_claims=None,
content_encoding="aes128gcm",
curl=False,
timeout=None,
ttl=0,
verbose=False,
headers=None,
requests_session=None):
"""
One call solution to endcode and send `data` to the endpoint
contained in `subscription_info` using optional VAPID auth headers.
in example:
.. code-block:: python
from pywebpush import python
webpush(
subscription_info={
"endpoint": "https://push.example.com/v1/abcd",
"keys": {"p256dh": "0123abcd...",
"auth": "001122..."}
},
data="Mary had a little lamb, with a nice mint jelly",
vapid_private_key="path/to/key.pem",
vapid_claims={"sub": "YourNameHere@example.com"}
)
No additional method call is required. Any non-success will throw a
`WebPushException`.
:param subscription_info: Provided by the client call
:type subscription_info: dict
:param data: Serialized data to send
:type data: str
:param vapid_private_key: Vapid instance or path to vapid private key PEM \
or encoded str
:type vapid_private_key: Union[Vapid, str]
:param vapid_claims: Dictionary of claims ('sub' required)
:type vapid_claims: dict
:param content_encoding: Optional content type string
:type content_encoding: str
:param curl: Return as "curl" string instead of sending
:type curl: bool
:param timeout: POST requests timeout
:type timeout: float or tuple
:param ttl: Time To Live
:type ttl: int
:param verbose: Provide verbose feedback
:type verbose: bool
:return requests.Response or string
:param headers: Dictionary of extra HTTP headers to include
:type headers: dict
"""
if headers is None:
headers = dict()
else:
# Ensure we don't leak VAPID headers by mutating the passed in dict.
headers = headers.copy()
vapid_headers = None
if vapid_claims:
if verbose:
print("Generating VAPID headers...")
if not vapid_claims.get('aud'):
url = urlparse(subscription_info.get('endpoint'))
aud = "{}://{}".format(url.scheme, url.netloc)
vapid_claims['aud'] = aud
# Remember, passed structures are mutable in python.
# It's possible that a previously set `exp` field is no longer valid.
if (not vapid_claims.get('exp')
or vapid_claims.get('exp') < int(time.time())):
# encryption lives for 12 hours
vapid_claims['exp'] = int(time.time()) + (12 * 60 * 60)
if verbose:
print("Setting VAPID expry to {}...".format(
vapid_claims['exp']))
if not vapid_private_key:
raise WebPushException("VAPID dict missing 'private_key'")
if isinstance(vapid_private_key, Vapid01):
vv = vapid_private_key
elif os.path.isfile(vapid_private_key):
# Presume that key from file is handled correctly by
# py_vapid.
vv = Vapid.from_file(
private_key_file=vapid_private_key) # pragma no cover
else:
vv = Vapid.from_string(private_key=vapid_private_key)
if verbose:
print("\t claims: {}".format(vapid_claims))
vapid_headers = vv.sign(vapid_claims)
if verbose:
print("\t headers: {}".format(vapid_headers))
headers.update(vapid_headers)
response = WebPusher(
subscription_info, requests_session=requests_session, verbose=verbose
).send(
data,
headers,
ttl=ttl,
content_encoding=content_encoding,
curl=curl,
timeout=timeout,
)
if not curl and response.status_code > 202:
raise WebPushException("Push failed: {} {}\nResponse body:{}".format(
response.status_code, response.reason, response.text),
response=response)
return response

View file

@ -0,0 +1,67 @@
import argparse
import os
import json
from pywebpush import webpush
def get_config():
parser = argparse.ArgumentParser(description="WebPush tool")
parser.add_argument("--data", '-d', help="Data file")
parser.add_argument("--info", "-i", help="Subscription Info JSON file")
parser.add_argument("--head", help="Header Info JSON file")
parser.add_argument("--claims", help="Vapid claim file")
parser.add_argument("--key", help="Vapid private key file path")
parser.add_argument("--curl", help="Don't send, display as curl command",
default=False, action="store_true")
parser.add_argument("--encoding", default="aes128gcm")
parser.add_argument("--verbose", "-v", help="Provide verbose feedback",
default=False, action="store_true")
args = parser.parse_args()
if not args.info:
raise Exception("Subscription Info argument missing.")
if not os.path.exists(args.info):
raise Exception("Subscription Info file missing.")
try:
with open(args.info) as r:
args.sub_info = json.loads(r.read())
if args.data:
with open(args.data) as r:
args.data = r.read()
if args.head:
with open(args.head) as r:
args.head = json.loads(r.read())
if args.claims:
if not args.key:
raise Exception("No private --key specified for claims")
with open(args.claims) as r:
args.claims = json.loads(r.read())
except Exception as ex:
print("Couldn't read input {}.".format(ex))
raise ex
return args
def main():
""" Send data """
try:
args = get_config()
result = webpush(
args.sub_info,
data=args.data,
vapid_private_key=args.key,
vapid_claims=args.claims,
curl=args.curl,
content_encoding=args.encoding,
verbose=args.verbose,
headers=args.head)
print(result)
except Exception as ex:
print("ERROR: {}".format(ex))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,390 @@
import base64
import json
import os
import unittest
import time
from mock import patch, Mock
import http_ece
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
import py_vapid
from pywebpush import WebPusher, WebPushException, CaseInsensitiveDict, webpush
class WebpushTestCase(unittest.TestCase):
# This is a exported DER formatted string of an ECDH public key
# This was lifted from the py_vapid tests.
vapid_key = (
"MHcCAQEEIPeN1iAipHbt8+/KZ2NIF8NeN24jqAmnMLFZEMocY8RboAoGCCqGSM49"
"AwEHoUQDQgAEEJwJZq/GN8jJbo1GGpyU70hmP2hbWAUpQFKDByKB81yldJ9GTklB"
"M5xqEwuPM7VuQcyiLDhvovthPIXx+gsQRQ=="
)
def _gen_subscription_info(self,
recv_key=None,
endpoint="https://example.com/"):
if not recv_key:
recv_key = ec.generate_private_key(ec.SECP256R1, default_backend())
return {
"endpoint": endpoint,
"keys": {
'auth': base64.urlsafe_b64encode(os.urandom(16)).strip(b'='),
'p256dh': self._get_pubkey_str(recv_key),
}
}
def _get_pubkey_str(self, priv_key):
return base64.urlsafe_b64encode(
priv_key.public_key().public_bytes(
encoding=serialization.Encoding.X962,
format=serialization.PublicFormat.UncompressedPoint
)).strip(b'=')
def test_init(self):
# use static values so we know what to look for in the reply
subscription_info = {
u"endpoint": u"https://example.com/",
u"keys": {
u"p256dh": (u"BOrnIslXrUow2VAzKCUAE4sIbK00daEZCswOcf8m3T"
"F8V82B-OpOg5JbmYLg44kRcvQC1E2gMJshsUYA-_zMPR8"),
u"auth": u"k8JV6sjdbhAi1n3_LDBLvA"
}
}
rk_decode = (b'\x04\xea\xe7"\xc9W\xadJ0\xd9P3(%\x00\x13\x8b'
b'\x08l\xad4u\xa1\x19\n\xcc\x0eq\xff&\xdd1'
b'|W\xcd\x81\xf8\xeaN\x83\x92[\x99\x82\xe0\xe3'
b'\x89\x11r\xf4\x02\xd4M\xa00\x9b!\xb1F\x00'
b'\xfb\xfc\xcc=\x1f')
self.assertRaises(
WebPushException,
WebPusher,
{"keys": {'p256dh': 'AAA=', 'auth': 'AAA='}})
self.assertRaises(
WebPushException,
WebPusher,
{"endpoint": "https://example.com", "keys": {'p256dh': 'AAA='}})
self.assertRaises(
WebPushException,
WebPusher,
{"endpoint": "https://example.com", "keys": {'auth': 'AAA='}})
self.assertRaises(
WebPushException,
WebPusher,
{"endpoint": "https://example.com",
"keys": {'p256dh': 'AAA=', 'auth': 'AAA='}})
push = WebPusher(subscription_info)
assert push.subscription_info != subscription_info
assert push.subscription_info['keys'] != subscription_info['keys']
assert push.subscription_info['endpoint'] == subscription_info['endpoint']
assert push.receiver_key == rk_decode
assert push.auth_key == b'\x93\xc2U\xea\xc8\xddn\x10"\xd6}\xff,0K\xbc'
def test_encode(self):
for content_encoding in ["aesgcm", "aes128gcm"]:
recv_key = ec.generate_private_key(
ec.SECP256R1, default_backend())
subscription_info = self._gen_subscription_info(recv_key)
data = "Mary had a little lamb, with some nice mint jelly"
push = WebPusher(subscription_info)
encoded = push.encode(data, content_encoding=content_encoding)
"""
crypto_key = base64.urlsafe_b64encode(
self._get_pubkey_str(recv_key)
).strip(b'=')
"""
# Convert these b64 strings into their raw, binary form.
raw_salt = None
if 'salt' in encoded:
raw_salt = base64.urlsafe_b64decode(
push._repad(encoded['salt']))
raw_dh = None
if content_encoding != "aes128gcm":
raw_dh = base64.urlsafe_b64decode(
push._repad(encoded['crypto_key']))
raw_auth = base64.urlsafe_b64decode(
push._repad(subscription_info['keys']['auth']))
decoded = http_ece.decrypt(
encoded['body'],
salt=raw_salt,
dh=raw_dh,
private_key=recv_key,
auth_secret=raw_auth,
version=content_encoding
)
assert decoded.decode('utf8') == data
def test_bad_content_encoding(self):
subscription_info = self._gen_subscription_info()
data = "Mary had a little lamb, with some nice mint jelly"
push = WebPusher(subscription_info)
self.assertRaises(WebPushException,
push.encode,
data,
content_encoding="aesgcm128")
@patch("requests.post")
def test_send(self, mock_post):
subscription_info = self._gen_subscription_info()
headers = {"Crypto-Key": "pre-existing",
"Authentication": "bearer vapid"}
data = "Mary had a little lamb"
WebPusher(subscription_info).send(data, headers)
assert subscription_info.get('endpoint') == mock_post.call_args[0][0]
pheaders = mock_post.call_args[1].get('headers')
assert pheaders.get('ttl') == '0'
assert pheaders.get('AUTHENTICATION') == headers.get('Authentication')
ckey = pheaders.get('crypto-key')
assert 'pre-existing' in ckey
assert pheaders.get('content-encoding') == 'aes128gcm'
@patch("requests.post")
def test_send_vapid(self, mock_post):
mock_post.return_value.status_code = 200
subscription_info = self._gen_subscription_info()
data = "Mary had a little lamb"
webpush(
subscription_info=subscription_info,
data=data,
vapid_private_key=self.vapid_key,
vapid_claims={"sub": "mailto:ops@example.com"},
content_encoding="aesgcm",
headers={"Test-Header": "test-value"}
)
assert subscription_info.get('endpoint') == mock_post.call_args[0][0]
pheaders = mock_post.call_args[1].get('headers')
assert pheaders.get('ttl') == '0'
def repad(str):
return str + "===="[:len(str) % 4]
auth = json.loads(
base64.urlsafe_b64decode(
repad(pheaders['authorization'].split('.')[1])
).decode('utf8')
)
assert subscription_info.get('endpoint').startswith(auth['aud'])
assert 'vapid' in pheaders.get('authorization')
ckey = pheaders.get('crypto-key')
assert 'dh=' in ckey
assert pheaders.get('content-encoding') == 'aesgcm'
assert pheaders.get('test-header') == 'test-value'
@patch.object(WebPusher, "send")
@patch.object(py_vapid.Vapid, "sign")
def test_webpush_vapid_instance(self, vapid_sign, pusher_send):
pusher_send.return_value.status_code = 200
subscription_info = self._gen_subscription_info()
data = "Mary had a little lamb"
vapid_key = py_vapid.Vapid.from_string(self.vapid_key)
claims = dict(sub="mailto:ops@example.com", aud="https://example.com")
webpush(
subscription_info=subscription_info,
data=data,
vapid_private_key=vapid_key,
vapid_claims=claims,
)
vapid_sign.assert_called_once_with(claims)
pusher_send.assert_called_once()
@patch.object(WebPusher, "send")
@patch.object(py_vapid.Vapid, "sign")
def test_webpush_vapid_exp(self, vapid_sign, pusher_send):
pusher_send.return_value.status_code = 200
subscription_info = self._gen_subscription_info()
data = "Mary had a little lamb"
vapid_key = py_vapid.Vapid.from_string(self.vapid_key)
claims = dict(sub="mailto:ops@example.com",
aud="https://example.com",
exp=int(time.time() - 48600))
webpush(
subscription_info=subscription_info,
data=data,
vapid_private_key=vapid_key,
vapid_claims=claims,
)
vapid_sign.assert_called_once_with(claims)
pusher_send.assert_called_once()
assert claims['exp'] > int(time.time())
@patch("requests.post")
def test_send_bad_vapid_no_key(self, mock_post):
mock_post.return_value.status_code = 200
subscription_info = self._gen_subscription_info()
data = "Mary had a little lamb"
self.assertRaises(
WebPushException,
webpush,
subscription_info=subscription_info,
data=data,
vapid_claims={
"aud": "https://example.com",
"sub": "mailto:ops@example.com"
})
@patch("requests.post")
def test_send_bad_vapid_bad_return(self, mock_post):
mock_post.return_value.status_code = 410
subscription_info = self._gen_subscription_info()
data = "Mary had a little lamb"
self.assertRaises(
WebPushException,
webpush,
subscription_info=subscription_info,
data=data,
vapid_claims={
"aud": "https://example.com",
"sub": "mailto:ops@example.com"
},
vapid_private_key=self.vapid_key)
@patch("requests.post")
def test_send_empty(self, mock_post):
subscription_info = self._gen_subscription_info()
headers = {"Crypto-Key": "pre-existing",
"Authentication": "bearer vapid"}
WebPusher(subscription_info).send('', headers)
assert subscription_info.get('endpoint') == mock_post.call_args[0][0]
pheaders = mock_post.call_args[1].get('headers')
assert pheaders.get('ttl') == '0'
assert 'encryption' not in pheaders
assert pheaders.get('AUTHENTICATION') == headers.get('Authentication')
ckey = pheaders.get('crypto-key')
assert 'pre-existing' in ckey
def test_encode_empty(self):
subscription_info = self._gen_subscription_info()
headers = {"Crypto-Key": "pre-existing",
"Authentication": "bearer vapid"}
encoded = WebPusher(subscription_info).encode('', headers)
assert encoded is None
def test_encode_no_crypto(self):
subscription_info = self._gen_subscription_info()
del(subscription_info['keys'])
headers = {"Crypto-Key": "pre-existing",
"Authentication": "bearer vapid"}
data = 'Something'
pusher = WebPusher(subscription_info)
self.assertRaises(
WebPushException,
pusher.encode,
data,
headers)
@patch("requests.post")
def test_send_no_headers(self, mock_post):
subscription_info = self._gen_subscription_info()
data = "Mary had a little lamb"
WebPusher(subscription_info).send(data)
assert subscription_info.get('endpoint') == mock_post.call_args[0][0]
pheaders = mock_post.call_args[1].get('headers')
assert pheaders.get('ttl') == '0'
assert pheaders.get('content-encoding') == 'aes128gcm'
@patch("pywebpush.open")
def test_as_curl(self, opener):
subscription_info = self._gen_subscription_info()
result = webpush(
subscription_info,
data="Mary had a little lamb",
vapid_claims={
"aud": "https://example.com",
"sub": "mailto:ops@example.com"
},
vapid_private_key=self.vapid_key,
curl=True
)
for s in [
"curl -vX POST https://example.com",
"-H \"content-encoding: aes128gcm\"",
"-H \"authorization: vapid ",
"-H \"ttl: 0\"",
"-H \"content-length:"
]:
assert s in result, "missing: {}".format(s)
def test_ci_dict(self):
ci = CaseInsensitiveDict({"Foo": "apple", "bar": "banana"})
assert 'apple' == ci["foo"]
assert 'apple' == ci.get("FOO")
assert 'apple' == ci.get("Foo")
del (ci['FOO'])
assert ci.get('Foo') is None
@patch("requests.post")
def test_gcm(self, mock_post):
subscription_info = self._gen_subscription_info(
None,
endpoint="https://android.googleapis.com/gcm/send/regid123")
headers = {"Crypto-Key": "pre-existing",
"Authentication": "bearer vapid"}
data = "Mary had a little lamb"
wp = WebPusher(subscription_info)
wp.send(data, headers, gcm_key="gcm_key_value")
pdata = json.loads(mock_post.call_args[1].get('data'))
pheaders = mock_post.call_args[1].get('headers')
assert pdata["registration_ids"][0] == "regid123"
assert pheaders.get("authorization") == "key=gcm_key_value"
assert pheaders.get("content-type") == "application/json"
@patch("requests.post")
def test_timeout(self, mock_post):
mock_post.return_value.status_code = 200
subscription_info = self._gen_subscription_info()
WebPusher(subscription_info).send(timeout=5.2)
assert mock_post.call_args[1].get('timeout') == 5.2
webpush(subscription_info, timeout=10.001)
assert mock_post.call_args[1].get('timeout') == 10.001
@patch("requests.Session")
def test_send_using_requests_session(self, mock_session):
subscription_info = self._gen_subscription_info()
headers = {"Crypto-Key": "pre-existing",
"Authentication": "bearer vapid"}
data = "Mary had a little lamb"
WebPusher(subscription_info,
requests_session=mock_session).send(data, headers)
assert subscription_info.get(
'endpoint') == mock_session.post.call_args[0][0]
pheaders = mock_session.post.call_args[1].get('headers')
assert pheaders.get('ttl') == '0'
assert pheaders.get('AUTHENTICATION') == headers.get('Authentication')
ckey = pheaders.get('crypto-key')
assert 'pre-existing' in ckey
assert pheaders.get('content-encoding') == 'aes128gcm'
class WebpushExceptionTestCase(unittest.TestCase):
def test_exception(self):
from requests import Response
exp = WebPushException("foo")
assert ("{}".format(exp) == "WebPushException: foo")
# Really should try to load the response to verify, but this mock
# covers what we need.
response = Mock(spec=Response)
response.text = (
'{"code": 401, "errno": 109, "error": '
'"Unauthorized", "more_info": "http://'
'autopush.readthedocs.io/en/latest/htt'
'p.html#error-codes", "message": "Requ'
'est did not validate missing authoriz'
'ation header"}')
response.json.return_value = json.loads(response.text)
response.status_code = 401
response.reason = "Unauthorized"
exp = WebPushException("foo", response)
assert "{}".format(exp) == "WebPushException: foo, Response {}".format(
response.text)
assert '{}'.format(exp.response), '<Response [401]>'
assert exp.response.json().get('errno') == 109
exp = WebPushException("foo", [1, 2, 3])
assert '{}'.format(exp) == "WebPushException: foo, Response [1, 2, 3]"