Implement RSASSA-PKCS1-v1_5 in dkim.crypto, and use that in verify() and sign(). Move str2int/int2str into dkim.crypto. Verification no longer does a constant-time string compare; there is no private key involved on which a timing attack could be performed.

This commit is contained in:
William Grant
2011-03-10 00:03:15 +11:00
parent 5898094fe1
commit c82703cea9
2 changed files with 57 additions and 46 deletions
+7 -46
View File
@@ -9,6 +9,8 @@ from dkim.crypto import (
EMSA_PKCS1_v1_5_encode,
parse_private_key,
parse_public_key,
RSASSA_PKCS1_v1_5_sign,
RSASSA_PKCS1_v1_5_verify,
)
__all__ = [
@@ -152,38 +154,10 @@ def validate_signature_fields(sig, debuglog=None):
return False
return True
def perform_rsa(input, exponent, modulus, modlen):
return int2str(pow(str2int(input), exponent, modulus), modlen)
# These values come from RFC 3447, section 9.2 Notes, page 43.
HASHID_SHA1 = "\x2b\x0e\x03\x02\x1a"
HASHID_SHA256 = "\x60\x86\x48\x01\x65\x03\x04\x02\x01"
def str2int(s):
"""Convert an octet string to an integer. Octet string assumed to represent a positive integer."""
r = 0
for c in s:
r = (r << 8) | ord(c)
return r
def int2str(n, length = -1):
"""Convert an integer to an octet string. Number must be positive.
@param n: Number to convert.
@param length: Minimum length, or -1 to return the smallest number of bytes that represent the integer.
"""
assert n >= 0
r = []
while length < 0 or len(r) < length:
r.append(chr(n & 0xff))
n >>= 8
if length < 0 and n == 0: break
r.reverse()
assert length < 0 or len(r) == length
return r
def rfc822_parse(message):
"""Parse a message in RFC822 format.
@@ -316,9 +290,8 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple
if debuglog is not None:
print >>debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d)
modlen = len(int2str(pk['modulus']))
encoded = EMSA_PKCS1_v1_5_encode(d, modlen, HASHID_SHA256)
sig2 = perform_rsa(encoded, pk['privateExponent'], pk['modulus'], modlen)
sig2 = RSASSA_PKCS1_v1_5_sign(
d, HASHID_SHA256, pk['privateExponent'], pk['modulus'])
sig += base64.b64encode(''.join(sig2))
return sig + "\r\n"
@@ -430,9 +403,6 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
print >>debuglog, "invalid format in _domainkey txt record"
return False
pk = parse_public_key(base64.b64decode(pub['p']))
modlen = len(int2str(pk['modulus']))
if debuglog is not None:
print >>debuglog, "modlen:", modlen
include_headers = re.split(r"\s*:\s*", sig['h'])
h = hasher()
@@ -441,18 +411,9 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
d = h.digest()
if debuglog is not None:
print >>debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d)
signature = base64.b64decode(re.sub(r"\s+", "", sig['b']))
try:
sig2 = EMSA_PKCS1_v1_5_encode(d, modlen, hashid)
return RSASSA_PKCS1_v1_5_verify(
d, hashid, signature, pk['publicExponent'], pk['modulus'])
except ParameterError:
return False
if debuglog is not None:
print >>debuglog, "sig2:", " ".join("%02x" % ord(x) for x in sig2)
print >>debuglog, sig['b']
print >>debuglog, re.sub(r"\s+", "", sig['b'])
signature = base64.b64decode(re.sub(r"\s+", "", sig['b']))
v = perform_rsa(signature, pk['publicExponent'], pk['modulus'], modlen)
if debuglog is not None:
print >>debuglog, "v:", " ".join("%02x" % ord(x) for x in v)
assert len(v) == len(sig2)
# Byte-by-byte compare of signatures
return not [1 for x in zip(v, sig2) if x[0] != x[1]]