Factor out a few functions. Not yet tested.

This commit is contained in:
William Grant
2011-03-12 12:16:48 +11:00
+128 -120
View File
@@ -94,6 +94,125 @@ def _remove(s, t):
assert i >= 0 assert i >= 0
return s[:i] + s[i+len(t):] return s[:i] + s[i+len(t):]
def EMSA_PKCS1_v1_5_encode(digest, modlen, hashid):
"""Encode a digest with EMSA-PKCS1-v1_5.
Defined in RFC3447 section 9.2.
@param digest: A digest value to encode.
@param modlen: The desired message length.
@param hashid: The ID of the hash used to generate the digest.
"""
dinfo = asn1_build(
(SEQUENCE, [
(SEQUENCE, [
(OBJECT_IDENTIFIER, hashid),
(NULL, None),
]),
(OCTET_STRING, digest),
]))
if len(dinfo)+3 > modlen:
raise ParameterError("Hash too large for modulus")
return "\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo
def hash_headers(hasher, canonicalize_headers, headers, include_headers,
sigheaders, sig):
"""Sign message header fields."""
sign_headers = []
lastindex = {}
for h in include_headers:
i = lastindex.get(h, len(headers))
while i > 0:
i -= 1
if h.lower() == headers[i][0].lower():
sign_headers.append(headers[i])
break
lastindex[h] = i
# The call to _remove() assumes that the signature b= only appears
# once in the signature header
cheaders = canonicalize_headers.canonicalize_headers(
[(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])
sign_headers += [(x[0], x[1].rstrip()) for x in cheaders]
for x in sign_headers:
hasher.update(x[0])
hasher.update(":")
hasher.update(x[1])
def parse_public_key(data):
"""Parse an RSA public key.
@param data: A DER-encoded X.509 subjectPublicKeyInfo
containing an RFC3447 RSAPublicKey.
"""
x = asn1_parse(ASN1_Object, data)
# Not sure why the [1:] is necessary to skip a byte.
pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
pk = {
'modulus': pkd[0][0],
'publicExponent': pkd[0][1],
}
return pk
def validate_signature_fields(sig, debuglog=None):
"""Validate DKIM-Signature fields.
Basic checks for presence and correct formatting of mandatory fields.
@param sig: A dict mapping field keys to values.
@param debuglog: A file-like object to which details will be written
on error.
"""
mandatory_fields = ('v', 'a', 'b', 'bh', 'd', 'h', 's')
for field in mandatory_fields:
if field not in sig:
if debuglog is not None:
print >>debuglog, "signature missing %s=" % field
return False
if sig['v'] != "1":
if debuglog is not None:
print >>debuglog, "v= value is not 1 (%s)" % sig['v']
return False
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
if debuglog is not None:
print >>debuglog, "b= value is not valid base64 (%s)" % sig['b']
return False
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
if debuglog is not None:
print >>debuglog, "bh= value is not valid base64 (%s)" % sig['bh']
return False
if 'i' in sig and (not sig['i'].endswith(sig['d']) or sig['i'][-len(sig['d'])-1] not in "@."):
if debuglog is not None:
print >>debuglog, "i= domain is not a subdomain of d= (i=%s d=%d)" % (sig['i'], sig['d'])
return False
if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None:
if debuglog is not None:
print >>debuglog, "l= value is not a decimal integer (%s)" % sig['l']
return False
if 'q' in sig and sig['q'] != "dns/txt":
if debuglog is not None:
print >>debuglog, "q= value is not dns/txt (%s)" % sig['q']
return False
if 't' in sig and re.match(r"\d+$", sig['t']) is None:
if debuglog is not None:
print >>debuglog, "t= value is not a decimal integer (%s)" % sig['t']
return False
if 'x' in sig:
if re.match(r"\d+$", sig['x']) is None:
if debuglog is not None:
print >>debuglog, "x= value is not a decimal integer (%s)" % sig['x']
return False
if int(sig['x']) < int(sig['t']):
if debuglog is not None:
print >>debuglog, "x= value is less than t= value (x=%s t=%s)" % (sig['x'], sig['t'])
return False
return True
INTEGER = 0x02 INTEGER = 0x02
BIT_STRING = 0x03 BIT_STRING = 0x03
OCTET_STRING = 0x04 OCTET_STRING = 0x04
@@ -379,19 +498,9 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple
if debuglog is not None: if debuglog is not None:
print >>debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d) print >>debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d)
dinfo = asn1_build(
(SEQUENCE, [
(SEQUENCE, [
(OBJECT_IDENTIFIER, HASHID_SHA256),
(NULL, None),
]),
(OCTET_STRING, d),
])
)
modlen = len(int2str(pk['modulus'])) modlen = len(int2str(pk['modulus']))
if len(dinfo)+3 > modlen: encoded = EMSA_PKCS1_v1_5_encode(d, modlen, HASHID_SHA256)
raise ParameterError("Hash too large for modulus") sig2 = int2str(pow(str2int(encoded), pk['privateExponent'], pk['modulus']), modlen)
sig2 = int2str(pow(str2int("\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo), pk['privateExponent'], pk['modulus']), modlen)
sig += base64.b64encode(''.join(sig2)) sig += base64.b64encode(''.join(sig2))
return sig + "\r\n" return sig + "\r\n"
@@ -427,71 +536,8 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
if debuglog is not None: if debuglog is not None:
print >>debuglog, "sig:", sig print >>debuglog, "sig:", sig
if 'v' not in sig: if not validate_signature_fields(sig, debuglog):
if debuglog is not None:
print >>debuglog, "signature missing v="
return False return False
if sig['v'] != "1":
if debuglog is not None:
print >>debuglog, "v= value is not 1 (%s)" % sig['v']
return False
if 'a' not in sig:
if debuglog is not None:
print >>debuglog, "signature missing a="
return False
if 'b' not in sig:
if debuglog is not None:
print >>debuglog, "signature missing b="
return False
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
if debuglog is not None:
print >>debuglog, "b= value is not valid base64 (%s)" % sig['b']
return False
if 'bh' not in sig:
if debuglog is not None:
print >>debuglog, "signature missing bh="
return False
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
if debuglog is not None:
print >>debuglog, "bh= value is not valid base64 (%s)" % sig['bh']
return False
if 'd' not in sig:
if debuglog is not None:
print >>debuglog, "signature missing d="
return False
if 'h' not in sig:
if debuglog is not None:
print >>debuglog, "signature missing h="
return False
if 'i' in sig and (not sig['i'].endswith(sig['d']) or sig['i'][-len(sig['d'])-1] not in "@."):
if debuglog is not None:
print >>debuglog, "i= domain is not a subdomain of d= (i=%s d=%d)" % (sig['i'], sig['d'])
return False
if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None:
if debuglog is not None:
print >>debuglog, "l= value is not a decimal integer (%s)" % sig['l']
return False
if 'q' in sig and sig['q'] != "dns/txt":
if debuglog is not None:
print >>debuglog, "q= value is not dns/txt (%s)" % sig['q']
return False
if 's' not in sig:
if debuglog is not None:
print >>debuglog, "signature missing s="
return False
if 't' in sig and re.match(r"\d+$", sig['t']) is None:
if debuglog is not None:
print >>debuglog, "t= value is not a decimal integer (%s)" % sig['t']
return False
if 'x' in sig:
if re.match(r"\d+$", sig['x']) is None:
if debuglog is not None:
print >>debuglog, "x= value is not a decimal integer (%s)" % sig['x']
return False
if int(sig['x']) < int(sig['t']):
if debuglog is not None:
print >>debuglog, "x= value is less than t= value (x=%s t=%s)" % (sig['x'], sig['t'])
return False
m = re.match("(\w+)(?:/(\w+))?$", sig['c']) m = re.match("(\w+)(?:/(\w+))?$", sig['c'])
if m is None: if m is None:
@@ -565,60 +611,22 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
if debuglog is not None: if debuglog is not None:
print >>debuglog, "invalid format in _domainkey txt record" print >>debuglog, "invalid format in _domainkey txt record"
return False return False
x = asn1_parse(ASN1_Object, base64.b64decode(pub['p'])) pk = parse_public_key(base64.b64decode(pub['p']))
# Not sure why the [1:] is necessary to skip a byte.
pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
pk = {
'modulus': pkd[0][0],
'publicExponent': pkd[0][1],
}
modlen = len(int2str(pk['modulus'])) modlen = len(int2str(pk['modulus']))
if debuglog is not None: if debuglog is not None:
print >>debuglog, "modlen:", modlen print >>debuglog, "modlen:", modlen
include_headers = re.split(r"\s*:\s*", sig['h']) include_headers = re.split(r"\s*:\s*", sig['h'])
if debuglog is not None:
print >>debuglog, "include_headers:", include_headers
sign_headers = []
lastindex = {}
for h in include_headers:
i = lastindex.get(h, len(headers))
while i > 0:
i -= 1
if h.lower() == headers[i][0].lower():
sign_headers.append(headers[i])
break
lastindex[h] = i
# The call to _remove() assumes that the signature b= only appears once in the signature header
sign_headers += [(x[0], x[1].rstrip()) for x in canonicalize_headers.canonicalize_headers([(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])]
if debuglog is not None:
print >>debuglog, "verify headers:", sign_headers
h = hasher() h = hasher()
for x in sign_headers: hash_headers(
h.update(x[0]) h, canonicalize_headers, headers, include_headers, sigheaders, sig)
h.update(":")
h.update(x[1])
d = h.digest() d = h.digest()
if debuglog is not None: if debuglog is not None:
print >>debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d) print >>debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d)
try:
dinfo = asn1_build( sig2 = EMSA_PKCS1_v1_5_encode(d, modlen, hashid)
(SEQUENCE, [ except ParameterError:
(SEQUENCE, [
(OBJECT_IDENTIFIER, hashid),
(NULL, None),
]),
(OCTET_STRING, d),
])
)
if debuglog is not None:
print >>debuglog, "dinfo:", " ".join("%02x" % ord(x) for x in dinfo)
if len(dinfo)+3 > modlen:
if debuglog is not None:
print >>debuglog, "Hash too large for modulus"
return False return False
sig2 = "\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo
if debuglog is not None: if debuglog is not None:
print >>debuglog, "sig2:", " ".join("%02x" % ord(x) for x in sig2) print >>debuglog, "sig2:", " ".join("%02x" % ord(x) for x in sig2)
print >>debuglog, sig['b'] print >>debuglog, sig['b']