Factor out most crypto bits into their own functions and modules, where they can sensibly be tested.

This commit is contained in:
William Grant
2011-03-13 18:55:30 +11:00
3 changed files with 383 additions and 212 deletions
+34 -212
View File
@@ -1,11 +1,11 @@
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the author be held liable for any damages
# arising from the use of this software.
#
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
@@ -13,8 +13,9 @@
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
#
#
# Copyright (c) 2008 Greg Hewgill http://hewgill.com
# Copyright (c) 2011 William Grant <me@williamgrant.id.au>
import base64
import hashlib
@@ -23,6 +24,14 @@ import time
import dns.resolver
from dkim.crypto import (
DigestTooLargeError,
parse_private_key,
parse_public_key,
RSASSA_PKCS1_v1_5_sign,
RSASSA_PKCS1_v1_5_verify,
UnparsableKeyError,
)
from dkim.util import (
InvalidTagValueList,
parse_tag_value,
@@ -99,29 +108,6 @@ def _remove(s, t):
assert i >= 0
return s[:i] + s[i+len(t):]
def EMSA_PKCS1_v1_5_encode(digest, modlen, hashid):
    """Encode a digest with EMSA-PKCS1-v1_5.

    Defined in RFC3447 section 9.2.

    @param digest: A digest value to encode.
    @param modlen: The desired message length.
    @param hashid: The ID of the hash used to generate the digest.
    @return: The encoded message, exactly modlen bytes long.
    @raise ParameterError: If modlen cannot hold the DigestInfo together
        with the mandatory padding.
    """
    dinfo = asn1_build(
        (SEQUENCE, [
            (SEQUENCE, [
                (OBJECT_IDENTIFIER, hashid),
                (NULL, None),
            ]),
            (OCTET_STRING, digest),
        ]))
    # RFC 3447 section 9.2 step 3: the 0xff padding string must be at least
    # eight octets long, so "00 01 PS 00" adds a minimum of 11 octets.
    if len(dinfo) + 11 > modlen:
        raise ParameterError("Hash too large for modulus")
    return "\x00\x01" + "\xff" * (modlen - len(dinfo) - 3) + "\x00" + dinfo
def hash_headers(hasher, canonicalize_headers, headers, include_headers,
sigheaders, sig):
"""Sign message header fields."""
@@ -146,22 +132,6 @@ def hash_headers(hasher, canonicalize_headers, headers, include_headers,
hasher.update(x[1])
def parse_public_key(data):
    """Extract the RSA public key numbers from DER data.

    @param data: A DER-encoded X.509 subjectPublicKeyInfo
        containing an RFC3447 RSAPublicKey.
    @return: A dict with the 'modulus' and 'publicExponent' integers.
    """
    key_info = asn1_parse(ASN1_Object, data)
    bit_string = key_info[0][1]
    # asn1_parse hands back the BIT STRING contents verbatim; their first
    # octet is the count of unused bits, so the nested key starts after it.
    rsa_key = asn1_parse(ASN1_RSAPublicKey, bit_string[1:])
    return {
        'modulus': rsa_key[0][0],
        'publicExponent': rsa_key[0][1],
    }
def validate_signature_fields(sig, debuglog=None):
"""Validate DKIM-Signature fields.
@@ -217,148 +187,10 @@ def validate_signature_fields(sig, debuglog=None):
return False
return True
# ASN.1 identifier octets for the types handled by asn1_parse/asn1_build.
INTEGER = 0x02
BIT_STRING = 0x03
OCTET_STRING = 0x04
NULL = 0x05
OBJECT_IDENTIFIER = 0x06
SEQUENCE = 0x30

# Templates for asn1_parse: a list of (tag,) tuples, where a SEQUENCE tuple
# carries the template of its members as a second element.

# Outer structure of a public key: an (OID, NULL) pair followed by a
# BIT STRING that wraps the actual key.
ASN1_Object = [
    (SEQUENCE, [
        (SEQUENCE, [
            (OBJECT_IDENTIFIER,),
            (NULL,),
        ]),
        (BIT_STRING,),
    ])
]

# Public key: modulus, publicExponent.
ASN1_RSAPublicKey = [
    (SEQUENCE, [
        (INTEGER,),
        (INTEGER,),
    ])
]

# Private key: version, modulus, publicExponent, privateExponent, prime1,
# prime2, exponent1, exponent2, coefficient.
ASN1_RSAPrivateKey = [
    (SEQUENCE, [
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
    ])
]
def asn1_parse(template, data):
    """Parse a data structure according to ASN.1 template.

    Only definite-length INTEGER, BIT STRING, NULL, OBJECT IDENTIFIER and
    SEQUENCE items are understood.

    @param template: A list of tuples comprising the ASN.1 template.
    @param data: A list of bytes to parse.
    @return: A list with one entry per template item: an integer, a byte
        string, None, or a nested list for a SEQUENCE.
    @raise KeyFormatError: If the data is truncated or does not match the
        template.
    """
    def octet_at(pos):
        # Running off the end of the data is a format problem; report it as
        # one instead of leaking an IndexError to the caller.
        if pos >= len(data):
            raise KeyFormatError("Unexpected end of data")
        return ord(data[pos])

    r = []
    i = 0
    for t in template:
        tag = octet_at(i)
        i += 1
        if tag != t[0]:
            raise KeyFormatError("Unexpected tag (got %02x, expecting %02x)" % (tag, t[0]))
        length = octet_at(i)
        i += 1
        if length & 0x80:
            # Long form: the low seven bits count the length octets.
            n = length & 0x7f
            length = 0
            for j in range(n):
                length = (length << 8) | octet_at(i)
                i += 1
        if length > len(data) - i:
            raise KeyFormatError(
                "Content length %d exceeds available data" % length)
        if tag == INTEGER:
            # Decoded as an unsigned big-endian number.
            n = 0
            for j in range(length):
                n = (n << 8) | ord(data[i])
                i += 1
            r.append(n)
        elif tag == BIT_STRING:
            r.append(data[i:i+length])
            i += length
        elif tag == NULL:
            if length != 0:
                raise KeyFormatError("NULL must have zero length")
            r.append(None)
        elif tag == OBJECT_IDENTIFIER:
            r.append(data[i:i+length])
            i += length
        elif tag == SEQUENCE:
            r.append(asn1_parse(t[1], data[i:i+length]))
            i += length
        else:
            raise KeyFormatError("Unexpected tag in template: %02x" % tag)
    return r
def asn1_length(n):
    """Return a string representing a field length in ASN.1 format.

    Lengths below 0x80 use the single-octet short form.  Larger lengths use
    the long form: one octet holding 0x80 | (number of length octets),
    followed by the length itself in big-endian order.

    @param n: Non-negative integer field length.
    @return: The encoded length octets.
    """
    assert n >= 0
    if n < 0x80:
        return chr(n)
    octets = ""
    while n > 0:
        octets = chr(n & 0xff) + octets
        n >>= 8
    # Without this prefix a decoder would read the first length octet as a
    # short-form length (or as a bogus octet count).
    return chr(0x80 | len(octets)) + octets
def asn1_build(node):
    """Serialise a (type, data) pair into its ASN.1 byte string.

    A SEQUENCE node carries a list of child nodes as its data, a NULL node
    carries None, and OCTET STRING / OBJECT IDENTIFIER nodes carry the
    content bytes to emit.
    """
    kind = node[0]
    if kind == SEQUENCE:
        # Encode every member first; the header needs their total size.
        body = "".join([asn1_build(child) for child in node[1]])
        return chr(SEQUENCE) + asn1_length(len(body)) + body
    if kind == NULL:
        assert node[1] is None
        return chr(NULL) + asn1_length(0)
    if kind == OCTET_STRING:
        payload = node[1]
        return chr(OCTET_STRING) + asn1_length(len(payload)) + payload
    if kind == OBJECT_IDENTIFIER:
        payload = node[1]
        return chr(OBJECT_IDENTIFIER) + asn1_length(len(payload)) + payload
    raise InternalError("Unexpected tag in template: %02x" % kind)
# Content octets of the hash algorithm OBJECT IDENTIFIERs; asn1_build adds
# the tag and length when they are placed in a DigestInfo.
# These values come from RFC 3447, section 9.2 Notes, page 43.
HASHID_SHA1 = "\x2b\x0e\x03\x02\x1a"  # 1.3.14.3.2.26
HASHID_SHA256 = "\x60\x86\x48\x01\x65\x03\x04\x02\x01"  # 2.16.840.1.101.3.4.2.1
def str2int(s):
    """Interpret an octet string as an unsigned big-endian integer.

    An empty string yields 0.
    """
    value = 0
    for octet in s:
        # Shift the running total up one byte and add the next octet.
        value = value * 256 + ord(octet)
    return value
def int2str(n, length = -1):
    """Split a non-negative integer into big-endian octets.

    @param n: Number to convert.
    @param length: Number of octets to produce, or -1 to produce the smallest
        number of octets that represent the integer.  When a length is
        given, only the low-order octets are kept if n needs more.
    @return: A list of single-character strings, most significant first.
    """
    assert n >= 0
    octets = []
    if length < 0:
        # Minimal form: always emit at least one octet, even for zero.
        while True:
            octets.append(chr(n & 0xff))
            n >>= 8
            if n == 0:
                break
    else:
        for _ in range(length):
            octets.append(chr(n & 0xff))
            n >>= 8
    octets.reverse()
    return octets
def rfc822_parse(message):
"""Parse a message in RFC822 format.
@@ -443,18 +275,10 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple
raise KeyFormatError(str(e))
if debuglog is not None:
print >>debuglog, " ".join("%02x" % ord(x) for x in pkdata)
pka = asn1_parse(ASN1_RSAPrivateKey, pkdata)
pk = {
'version': pka[0][0],
'modulus': pka[0][1],
'publicExponent': pka[0][2],
'privateExponent': pka[0][3],
'prime1': pka[0][4],
'prime2': pka[0][5],
'exponent1': pka[0][6],
'exponent2': pka[0][7],
'coefficient': pka[0][8],
}
try:
pk = parse_private_key(pkdata)
except UnparsableKeyError, e:
raise KeyFormatError(str(e))
if identity is not None and not identity.endswith(domain):
raise ParameterError("identity must end with domain")
@@ -503,10 +327,12 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple
if debuglog is not None:
print >>debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d)
modlen = len(int2str(pk['modulus']))
encoded = EMSA_PKCS1_v1_5_encode(d, modlen, HASHID_SHA256)
sig2 = int2str(pow(str2int(encoded), pk['privateExponent'], pk['modulus']), modlen)
sig += base64.b64encode(''.join(sig2))
try:
sig2 = RSASSA_PKCS1_v1_5_sign(
d, HASHID_SHA256, pk['privateExponent'], pk['modulus'])
except DigestTooLargeError:
raise ParameterError("digest too large for modulus")
sig += base64.b64encode(sig2)
return sig + "\r\n"
@@ -597,10 +423,12 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
pub = parse_tag_value(s)
except InvalidTagValueList:
return False
pk = parse_public_key(base64.b64decode(pub['p']))
modlen = len(int2str(pk['modulus']))
if debuglog is not None:
print >>debuglog, "modlen:", modlen
try:
pk = parse_public_key(base64.b64decode(pub['p']))
except UnparsableKeyError, e:
if debuglog is not None:
print >>debuglog, "could not parse public key: %s" % e
return False
include_headers = re.split(r"\s*:\s*", sig['h'])
h = hasher()
@@ -609,17 +437,11 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
d = h.digest()
if debuglog is not None:
print >>debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d)
signature = base64.b64decode(re.sub(r"\s+", "", sig['b']))
try:
sig2 = EMSA_PKCS1_v1_5_encode(d, modlen, hashid)
except ParameterError:
return RSASSA_PKCS1_v1_5_verify(
d, hashid, signature, pk['publicExponent'], pk['modulus'])
except DigestTooLargeError:
if debuglog is not None:
print >>debuglog, "digest too large for modulus"
return False
if debuglog is not None:
print >>debuglog, "sig2:", " ".join("%02x" % ord(x) for x in sig2)
print >>debuglog, sig['b']
print >>debuglog, re.sub(r"\s+", "", sig['b'])
v = int2str(pow(str2int(base64.b64decode(re.sub(r"\s+", "", sig['b']))), pk['publicExponent'], pk['modulus']), modlen)
if debuglog is not None:
print >>debuglog, "v:", " ".join("%02x" % ord(x) for x in v)
assert len(v) == len(sig2)
# Byte-by-byte compare of signatures
return not [1 for x in zip(v, sig2) if x[0] != x[1]]
+128
View File
@@ -0,0 +1,128 @@
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the author be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
# appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
#
# Copyright (c) 2008 Greg Hewgill http://hewgill.com
# Copyright (c) 2011 William Grant <me@williamgrant.id.au>
# Names exported by a star-import of this module.
__all__ = [
    'asn1_build',
    'asn1_parse',
    'ASN1FormatError',
    'BIT_STRING',
    'INTEGER',
    'SEQUENCE',
    'OBJECT_IDENTIFIER',
    'OCTET_STRING',
    'NULL',
]

# ASN.1 identifier octets for the types handled by asn1_parse/asn1_build.
INTEGER = 0x02
BIT_STRING = 0x03
OCTET_STRING = 0x04
NULL = 0x05
OBJECT_IDENTIFIER = 0x06
SEQUENCE = 0x30  # universal tag 0x10 with the "constructed" bit (0x20) set
class ASN1FormatError(Exception):
    """The data or template could not be handled as ASN.1."""
def asn1_parse(template, data):
    """Parse a data structure according to an ASN.1 template.

    Only definite-length INTEGER, BIT STRING, NULL, OBJECT IDENTIFIER and
    SEQUENCE items are understood.

    @param template: tuples comprising the ASN.1 template; each tuple holds
        the expected tag and, for a SEQUENCE, the template of its members
    @param data: byte string data to parse
    @return: decoded structure: a list with one entry per template item (an
        integer, a byte string, None, or a nested list for a SEQUENCE)
    @raise ASN1FormatError: if the data is truncated or does not match the
        template
    """
    def octet_at(pos):
        # Running off the end of the data is a format problem; report it as
        # one instead of leaking an IndexError to the caller.
        if pos >= len(data):
            raise ASN1FormatError("Unexpected end of data")
        return ord(data[pos])

    r = []
    i = 0
    for t in template:
        tag = octet_at(i)
        i += 1
        if tag != t[0]:
            raise ASN1FormatError(
                "Unexpected tag (got %02x, expecting %02x)" % (tag, t[0]))
        length = octet_at(i)
        i += 1
        if length & 0x80:
            # Long form: the low seven bits count the length octets.
            n = length & 0x7f
            length = 0
            for j in range(n):
                length = (length << 8) | octet_at(i)
                i += 1
        if length > len(data) - i:
            raise ASN1FormatError(
                "Content length %d exceeds available data" % length)
        if tag == INTEGER:
            # Decoded as an unsigned big-endian number.
            n = 0
            for j in range(length):
                n = (n << 8) | ord(data[i])
                i += 1
            r.append(n)
        elif tag == BIT_STRING:
            r.append(data[i:i+length])
            i += length
        elif tag == NULL:
            if length != 0:
                raise ASN1FormatError("NULL must have zero length")
            r.append(None)
        elif tag == OBJECT_IDENTIFIER:
            r.append(data[i:i+length])
            i += length
        elif tag == SEQUENCE:
            r.append(asn1_parse(t[1], data[i:i+length]))
            i += length
        else:
            raise ASN1FormatError(
                "Unexpected tag in template: %02x" % tag)
    return r
def asn1_length(n):
    """Return a string representing a field length in ASN.1 format.

    Lengths below 0x80 use the single-octet short form.  Larger lengths use
    the long form: one octet holding 0x80 | (number of length octets),
    followed by the length itself in big-endian order.

    @param n: non-negative integer field length
    @return: ASN.1 field length
    """
    assert n >= 0
    if n < 0x80:
        return chr(n)
    octets = ""
    while n > 0:
        octets = chr(n & 0xff) + octets
        n >>= 8
    # Without this prefix a decoder would read the first length octet as a
    # short-form length (or as a bogus octet count).
    return chr(0x80 | len(octets)) + octets
def asn1_build(node):
    """Serialise an ASN.1 node into a DER byte string.

    A SEQUENCE node carries a list of child nodes as its data, a NULL node
    carries None, and OCTET STRING / OBJECT IDENTIFIER nodes carry the
    content bytes to emit.

    @param node: (type, data) tuples comprising the ASN.1 structure
    @return: DER-encoded ASN.1 byte string
    """
    kind = node[0]
    if kind == SEQUENCE:
        # Encode every member first; the header needs their total size.
        body = "".join([asn1_build(child) for child in node[1]])
        return chr(SEQUENCE) + asn1_length(len(body)) + body
    if kind == NULL:
        assert node[1] is None
        return chr(NULL) + asn1_length(0)
    if kind == OCTET_STRING:
        payload = node[1]
        return chr(OCTET_STRING) + asn1_length(len(payload)) + payload
    if kind == OBJECT_IDENTIFIER:
        payload = node[1]
        return chr(OBJECT_IDENTIFIER) + asn1_length(len(payload)) + payload
    raise ASN1FormatError("Unexpected tag in template: %02x" % kind)
+221
View File
@@ -0,0 +1,221 @@
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the author be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
# appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
#
# Copyright (c) 2008 Greg Hewgill http://hewgill.com
# Copyright (c) 2011 William Grant <me@williamgrant.id.au>
# Names exported by a star-import of this module; the encoding and integer
# conversion helpers defined below are not part of this list.
__all__ = [
    'DigestTooLargeError',
    'parse_private_key',
    'parse_public_key',
    'RSASSA_PKCS1_v1_5_sign',
    'RSASSA_PKCS1_v1_5_verify',
    'UnparsableKeyError',
]
from dkim.asn1 import (
ASN1FormatError,
asn1_build,
asn1_parse,
BIT_STRING,
INTEGER,
SEQUENCE,
OBJECT_IDENTIFIER,
OCTET_STRING,
NULL,
)
# Templates for asn1_parse: a list of (tag,) tuples, where a SEQUENCE tuple
# carries the template of its members as a second element.

# Outer structure of an X.509 subjectPublicKeyInfo: an (OID, NULL) pair
# followed by a BIT STRING that wraps the actual key.
ASN1_Object = [
    (SEQUENCE, [
        (SEQUENCE, [
            (OBJECT_IDENTIFIER,),
            (NULL,),
        ]),
        (BIT_STRING,),
    ])
]

# RSAPublicKey: modulus, publicExponent.
ASN1_RSAPublicKey = [
    (SEQUENCE, [
        (INTEGER,),
        (INTEGER,),
    ])
]

# RSAPrivateKey: version, modulus, publicExponent, privateExponent, prime1,
# prime2, exponent1, exponent2, coefficient.
ASN1_RSAPrivateKey = [
    (SEQUENCE, [
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
    ])
]
class DigestTooLargeError(Exception):
    """Raised when an encoded digest does not fit in the requested length."""
class UnparsableKeyError(Exception):
    """Raised when key data cannot be parsed."""
def parse_public_key(data):
    """Parse an RSA public key.

    @param data: DER-encoded X.509 subjectPublicKeyInfo
        containing an RFC3447 RSAPublicKey.
    @return: RSA public key, as a dict holding the 'modulus' and
        'publicExponent' integers
    @raise UnparsableKeyError: if the data cannot be parsed as a key
    """
    try:
        x = asn1_parse(ASN1_Object, data)
        # asn1_parse hands back the BIT STRING contents verbatim; their
        # first octet is the count of unused bits, so the nested key starts
        # after it.
        pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
    except (ASN1FormatError, IndexError) as e:
        # IndexError covers truncated input that the parser indexes past;
        # callers only expect UnparsableKeyError for bad key data.
        raise UnparsableKeyError(str(e))
    pk = {
        'modulus': pkd[0][0],
        'publicExponent': pkd[0][1],
    }
    return pk
def parse_private_key(data):
    """Parse an RSA private key.

    @param data: DER-encoded RFC3447 RSAPrivateKey.
    @return: RSA private key, as a dict holding the nine integers of the
        structure keyed by their RFC3447 field names
    @raise UnparsableKeyError: if the data cannot be parsed as a key
    """
    try:
        pka = asn1_parse(ASN1_RSAPrivateKey, data)
    except (ASN1FormatError, IndexError) as e:
        # IndexError covers truncated input that the parser indexes past;
        # callers only expect UnparsableKeyError for bad key data.
        raise UnparsableKeyError(str(e))
    pk = {
        'version': pka[0][0],
        'modulus': pka[0][1],
        'publicExponent': pka[0][2],
        'privateExponent': pka[0][3],
        'prime1': pka[0][4],
        'prime2': pka[0][5],
        'exponent1': pka[0][6],
        'exponent2': pka[0][7],
        'coefficient': pka[0][8],
    }
    return pk
def EMSA_PKCS1_v1_5_encode(digest, mlen, hashid):
    """Encode a digest with RFC3447 EMSA-PKCS1-v1_5.

    The result is 0x00 0x01, a run of 0xff padding octets, 0x00, and then
    the DER-encoded DigestInfo for the digest.

    @param digest: digest byte string to encode
    @param mlen: desired message length
    @param hashid: ID of the hash used to generate the digest
    @return: encoded digest byte string, exactly mlen bytes long
    @raise DigestTooLargeError: if mlen cannot hold the DigestInfo together
        with the mandatory padding
    """
    dinfo = asn1_build(
        (SEQUENCE, [
            (SEQUENCE, [
                (OBJECT_IDENTIFIER, hashid),
                (NULL, None),
            ]),
            (OCTET_STRING, digest),
        ]))
    # RFC 3447 section 9.2 step 3: the 0xff padding string must be at least
    # eight octets long, so "00 01 PS 00" adds a minimum of 11 octets.
    if len(dinfo) + 11 > mlen:
        raise DigestTooLargeError()
    return "\x00\x01" + "\xff" * (mlen - len(dinfo) - 3) + "\x00" + dinfo
def str2int(s):
    """Interpret a byte string as an unsigned big-endian integer.

    @param s: byte string representing a positive integer to convert
    @return: converted integer (0 for an empty string)
    """
    value = 0
    for octet in s:
        # Shift the running total up one byte and add the next octet.
        value = value * 256 + ord(octet)
    return value
def int2str(n, length=-1):
    """Convert a non-negative integer to a big-endian byte string.

    @param n: positive integer to convert
    @param length: number of bytes to produce, or a negative value to
        produce the smallest number of bytes that represent the integer
    @return: converted byte string; when a length is given the result is
        exactly that long, keeping only the low-order bytes if n needs more
    """
    assert n >= 0
    octets = []
    if length < 0:
        # Minimal form: always emit at least one byte, even for zero.
        while True:
            octets.append(chr(n & 0xff))
            n >>= 8
            if n == 0:
                break
    else:
        for _ in range(length):
            octets.append(chr(n & 0xff))
            n >>= 8
    octets.reverse()
    return ''.join(octets)
def perform_rsa(message, exponent, modulus, mlen):
    """Apply the raw RSA operation used for signing or verification.

    The message is read as a big-endian integer, raised to the exponent
    modulo the modulus, and written back out as a byte string.

    @param message: byte string to operate on
    @param exponent: public or private key exponent
    @param modulus: key modulus
    @param mlen: desired output length
    @return: byte string result of the operation
    """
    representative = str2int(message)
    result = pow(representative, exponent, modulus)
    return int2str(result, mlen)
def RSASSA_PKCS1_v1_5_sign(digest, hashid, private_exponent, modulus):
    """Produce an RFC3447 RSASSA-PKCS1-v1_5 signature over a digest.

    @param digest: digest byte string to sign
    @param hashid: ID of the hash used to generate the digest
    @param private_exponent: private key exponent
    @param modulus: key modulus
    @return: signed digest byte string, as long as the modulus
    """
    # The encoded message and the signature both span the modulus length.
    key_octets = len(int2str(modulus))
    return perform_rsa(
        EMSA_PKCS1_v1_5_encode(digest, key_octets, hashid),
        private_exponent, modulus, key_octets)
def RSASSA_PKCS1_v1_5_verify(digest, hashid, signature, public_exponent,
                             modulus):
    """Verify a digest signed with RFC3447 RSASSA-PKCS1-v1_5.

    @param digest: digest byte string to check
    @param hashid: ID of the hash used to generate the digest
    @param signature: signed digest byte string
    @param public_exponent: public key exponent
    @param modulus: key modulus
    @return: True if the signature is valid, False otherwise
    @raise DigestTooLargeError: if the encoded digest does not fit within
        the modulus length
    """
    modlen = len(int2str(modulus))
    encoded_digest = EMSA_PKCS1_v1_5_encode(digest, modlen, hashid)
    # RFC 3447 section 8.2.2 step 1: a signature that is not exactly as
    # long as the modulus is invalid.
    if len(signature) != modlen:
        return False
    # RFC 3447 section 5.2.2 (RSAVP1): the signature representative must be
    # below the modulus; otherwise s and s + modulus would both verify.
    if str2int(signature) >= modulus:
        return False
    signed_digest = perform_rsa(signature, public_exponent, modulus, modlen)
    return encoded_digest == signed_digest