Factor out a few functions. Not yet tested.
This commit is contained in:
+128
-120
@@ -94,6 +94,125 @@ def _remove(s, t):
|
||||
assert i >= 0
|
||||
return s[:i] + s[i+len(t):]
|
||||
|
||||
|
||||
def EMSA_PKCS1_v1_5_encode(digest, modlen, hashid):
|
||||
"""Encode a digest with EMSA-PKCS1-v1_5.
|
||||
|
||||
Defined in RFC3447 section 9.2.
|
||||
|
||||
@param digest: A digest value to encode.
|
||||
@param modlen: The desired message length.
|
||||
@param hashid: The ID of the hash used to generate the digest.
|
||||
"""
|
||||
dinfo = asn1_build(
|
||||
(SEQUENCE, [
|
||||
(SEQUENCE, [
|
||||
(OBJECT_IDENTIFIER, hashid),
|
||||
(NULL, None),
|
||||
]),
|
||||
(OCTET_STRING, digest),
|
||||
]))
|
||||
if len(dinfo)+3 > modlen:
|
||||
raise ParameterError("Hash too large for modulus")
|
||||
return "\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo
|
||||
|
||||
|
||||
def hash_headers(hasher, canonicalize_headers, headers, include_headers,
|
||||
sigheaders, sig):
|
||||
"""Sign message header fields."""
|
||||
sign_headers = []
|
||||
lastindex = {}
|
||||
for h in include_headers:
|
||||
i = lastindex.get(h, len(headers))
|
||||
while i > 0:
|
||||
i -= 1
|
||||
if h.lower() == headers[i][0].lower():
|
||||
sign_headers.append(headers[i])
|
||||
break
|
||||
lastindex[h] = i
|
||||
# The call to _remove() assumes that the signature b= only appears
|
||||
# once in the signature header
|
||||
cheaders = canonicalize_headers.canonicalize_headers(
|
||||
[(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])
|
||||
sign_headers += [(x[0], x[1].rstrip()) for x in cheaders]
|
||||
for x in sign_headers:
|
||||
hasher.update(x[0])
|
||||
hasher.update(":")
|
||||
hasher.update(x[1])
|
||||
|
||||
|
||||
def parse_public_key(data):
|
||||
"""Parse an RSA public key.
|
||||
|
||||
@param data: A DER-encoded X.509 subjectPublicKeyInfo
|
||||
containing an RFC3447 RSAPublicKey.
|
||||
"""
|
||||
x = asn1_parse(ASN1_Object, data)
|
||||
# Not sure why the [1:] is necessary to skip a byte.
|
||||
pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
|
||||
pk = {
|
||||
'modulus': pkd[0][0],
|
||||
'publicExponent': pkd[0][1],
|
||||
}
|
||||
return pk
|
||||
|
||||
|
||||
def validate_signature_fields(sig, debuglog=None):
|
||||
"""Validate DKIM-Signature fields.
|
||||
|
||||
Basic checks for presence and correct formatting of mandatory fields.
|
||||
|
||||
@param sig: A dict mapping field keys to values.
|
||||
@param debuglog: A file-like object to which details will be written
|
||||
on error.
|
||||
"""
|
||||
mandatory_fields = ('v', 'a', 'b', 'bh', 'd', 'h', 's')
|
||||
for field in mandatory_fields:
|
||||
if field not in sig:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "signature missing %s=" % field
|
||||
return False
|
||||
|
||||
if sig['v'] != "1":
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "v= value is not 1 (%s)" % sig['v']
|
||||
return False
|
||||
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "b= value is not valid base64 (%s)" % sig['b']
|
||||
return False
|
||||
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "bh= value is not valid base64 (%s)" % sig['bh']
|
||||
return False
|
||||
if 'i' in sig and (not sig['i'].endswith(sig['d']) or sig['i'][-len(sig['d'])-1] not in "@."):
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "i= domain is not a subdomain of d= (i=%s d=%d)" % (sig['i'], sig['d'])
|
||||
return False
|
||||
if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "l= value is not a decimal integer (%s)" % sig['l']
|
||||
return False
|
||||
if 'q' in sig and sig['q'] != "dns/txt":
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "q= value is not dns/txt (%s)" % sig['q']
|
||||
return False
|
||||
if 't' in sig and re.match(r"\d+$", sig['t']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "t= value is not a decimal integer (%s)" % sig['t']
|
||||
return False
|
||||
if 'x' in sig:
|
||||
if re.match(r"\d+$", sig['x']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "x= value is not a decimal integer (%s)" % sig['x']
|
||||
return False
|
||||
if int(sig['x']) < int(sig['t']):
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "x= value is less than t= value (x=%s t=%s)" % (sig['x'], sig['t'])
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
INTEGER = 0x02
|
||||
BIT_STRING = 0x03
|
||||
OCTET_STRING = 0x04
|
||||
@@ -379,19 +498,9 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d)
|
||||
|
||||
dinfo = asn1_build(
|
||||
(SEQUENCE, [
|
||||
(SEQUENCE, [
|
||||
(OBJECT_IDENTIFIER, HASHID_SHA256),
|
||||
(NULL, None),
|
||||
]),
|
||||
(OCTET_STRING, d),
|
||||
])
|
||||
)
|
||||
modlen = len(int2str(pk['modulus']))
|
||||
if len(dinfo)+3 > modlen:
|
||||
raise ParameterError("Hash too large for modulus")
|
||||
sig2 = int2str(pow(str2int("\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo), pk['privateExponent'], pk['modulus']), modlen)
|
||||
encoded = EMSA_PKCS1_v1_5_encode(d, modlen, HASHID_SHA256)
|
||||
sig2 = int2str(pow(str2int(encoded), pk['privateExponent'], pk['modulus']), modlen)
|
||||
sig += base64.b64encode(''.join(sig2))
|
||||
|
||||
return sig + "\r\n"
|
||||
@@ -427,70 +536,7 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "sig:", sig
|
||||
|
||||
if 'v' not in sig:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "signature missing v="
|
||||
return False
|
||||
if sig['v'] != "1":
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "v= value is not 1 (%s)" % sig['v']
|
||||
return False
|
||||
if 'a' not in sig:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "signature missing a="
|
||||
return False
|
||||
if 'b' not in sig:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "signature missing b="
|
||||
return False
|
||||
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "b= value is not valid base64 (%s)" % sig['b']
|
||||
return False
|
||||
if 'bh' not in sig:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "signature missing bh="
|
||||
return False
|
||||
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "bh= value is not valid base64 (%s)" % sig['bh']
|
||||
return False
|
||||
if 'd' not in sig:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "signature missing d="
|
||||
return False
|
||||
if 'h' not in sig:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "signature missing h="
|
||||
return False
|
||||
if 'i' in sig and (not sig['i'].endswith(sig['d']) or sig['i'][-len(sig['d'])-1] not in "@."):
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "i= domain is not a subdomain of d= (i=%s d=%d)" % (sig['i'], sig['d'])
|
||||
return False
|
||||
if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "l= value is not a decimal integer (%s)" % sig['l']
|
||||
return False
|
||||
if 'q' in sig and sig['q'] != "dns/txt":
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "q= value is not dns/txt (%s)" % sig['q']
|
||||
return False
|
||||
if 's' not in sig:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "signature missing s="
|
||||
return False
|
||||
if 't' in sig and re.match(r"\d+$", sig['t']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "t= value is not a decimal integer (%s)" % sig['t']
|
||||
return False
|
||||
if 'x' in sig:
|
||||
if re.match(r"\d+$", sig['x']) is None:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "x= value is not a decimal integer (%s)" % sig['x']
|
||||
return False
|
||||
if int(sig['x']) < int(sig['t']):
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "x= value is less than t= value (x=%s t=%s)" % (sig['x'], sig['t'])
|
||||
if not validate_signature_fields(sig, debuglog):
|
||||
return False
|
||||
|
||||
m = re.match("(\w+)(?:/(\w+))?$", sig['c'])
|
||||
@@ -565,60 +611,22 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "invalid format in _domainkey txt record"
|
||||
return False
|
||||
x = asn1_parse(ASN1_Object, base64.b64decode(pub['p']))
|
||||
# Not sure why the [1:] is necessary to skip a byte.
|
||||
pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
|
||||
pk = {
|
||||
'modulus': pkd[0][0],
|
||||
'publicExponent': pkd[0][1],
|
||||
}
|
||||
pk = parse_public_key(base64.b64decode(pub['p']))
|
||||
modlen = len(int2str(pk['modulus']))
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "modlen:", modlen
|
||||
|
||||
include_headers = re.split(r"\s*:\s*", sig['h'])
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "include_headers:", include_headers
|
||||
sign_headers = []
|
||||
lastindex = {}
|
||||
for h in include_headers:
|
||||
i = lastindex.get(h, len(headers))
|
||||
while i > 0:
|
||||
i -= 1
|
||||
if h.lower() == headers[i][0].lower():
|
||||
sign_headers.append(headers[i])
|
||||
break
|
||||
lastindex[h] = i
|
||||
# The call to _remove() assumes that the signature b= only appears once in the signature header
|
||||
sign_headers += [(x[0], x[1].rstrip()) for x in canonicalize_headers.canonicalize_headers([(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])]
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "verify headers:", sign_headers
|
||||
|
||||
h = hasher()
|
||||
for x in sign_headers:
|
||||
h.update(x[0])
|
||||
h.update(":")
|
||||
h.update(x[1])
|
||||
hash_headers(
|
||||
h, canonicalize_headers, headers, include_headers, sigheaders, sig)
|
||||
d = h.digest()
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d)
|
||||
|
||||
dinfo = asn1_build(
|
||||
(SEQUENCE, [
|
||||
(SEQUENCE, [
|
||||
(OBJECT_IDENTIFIER, hashid),
|
||||
(NULL, None),
|
||||
]),
|
||||
(OCTET_STRING, d),
|
||||
])
|
||||
)
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "dinfo:", " ".join("%02x" % ord(x) for x in dinfo)
|
||||
if len(dinfo)+3 > modlen:
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "Hash too large for modulus"
|
||||
try:
|
||||
sig2 = EMSA_PKCS1_v1_5_encode(d, modlen, hashid)
|
||||
except ParameterError:
|
||||
return False
|
||||
sig2 = "\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo
|
||||
if debuglog is not None:
|
||||
print >>debuglog, "sig2:", " ".join("%02x" % ord(x) for x in sig2)
|
||||
print >>debuglog, sig['b']
|
||||
|
||||
Reference in New Issue
Block a user