Factor out a few functions. Not yet tested.
This commit is contained in:
+128
-120
@@ -94,6 +94,125 @@ def _remove(s, t):
|
|||||||
assert i >= 0
|
assert i >= 0
|
||||||
return s[:i] + s[i+len(t):]
|
return s[:i] + s[i+len(t):]
|
||||||
|
|
||||||
|
|
||||||
|
def EMSA_PKCS1_v1_5_encode(digest, modlen, hashid):
|
||||||
|
"""Encode a digest with EMSA-PKCS1-v1_5.
|
||||||
|
|
||||||
|
Defined in RFC3447 section 9.2.
|
||||||
|
|
||||||
|
@param digest: A digest value to encode.
|
||||||
|
@param modlen: The desired message length.
|
||||||
|
@param hashid: The ID of the hash used to generate the digest.
|
||||||
|
"""
|
||||||
|
dinfo = asn1_build(
|
||||||
|
(SEQUENCE, [
|
||||||
|
(SEQUENCE, [
|
||||||
|
(OBJECT_IDENTIFIER, hashid),
|
||||||
|
(NULL, None),
|
||||||
|
]),
|
||||||
|
(OCTET_STRING, digest),
|
||||||
|
]))
|
||||||
|
if len(dinfo)+3 > modlen:
|
||||||
|
raise ParameterError("Hash too large for modulus")
|
||||||
|
return "\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo
|
||||||
|
|
||||||
|
|
||||||
|
def hash_headers(hasher, canonicalize_headers, headers, include_headers,
|
||||||
|
sigheaders, sig):
|
||||||
|
"""Sign message header fields."""
|
||||||
|
sign_headers = []
|
||||||
|
lastindex = {}
|
||||||
|
for h in include_headers:
|
||||||
|
i = lastindex.get(h, len(headers))
|
||||||
|
while i > 0:
|
||||||
|
i -= 1
|
||||||
|
if h.lower() == headers[i][0].lower():
|
||||||
|
sign_headers.append(headers[i])
|
||||||
|
break
|
||||||
|
lastindex[h] = i
|
||||||
|
# The call to _remove() assumes that the signature b= only appears
|
||||||
|
# once in the signature header
|
||||||
|
cheaders = canonicalize_headers.canonicalize_headers(
|
||||||
|
[(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])
|
||||||
|
sign_headers += [(x[0], x[1].rstrip()) for x in cheaders]
|
||||||
|
for x in sign_headers:
|
||||||
|
hasher.update(x[0])
|
||||||
|
hasher.update(":")
|
||||||
|
hasher.update(x[1])
|
||||||
|
|
||||||
|
|
||||||
|
def parse_public_key(data):
|
||||||
|
"""Parse an RSA public key.
|
||||||
|
|
||||||
|
@param data: A DER-encoded X.509 subjectPublicKeyInfo
|
||||||
|
containing an RFC3447 RSAPublicKey.
|
||||||
|
"""
|
||||||
|
x = asn1_parse(ASN1_Object, data)
|
||||||
|
# Not sure why the [1:] is necessary to skip a byte.
|
||||||
|
pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
|
||||||
|
pk = {
|
||||||
|
'modulus': pkd[0][0],
|
||||||
|
'publicExponent': pkd[0][1],
|
||||||
|
}
|
||||||
|
return pk
|
||||||
|
|
||||||
|
|
||||||
|
def validate_signature_fields(sig, debuglog=None):
|
||||||
|
"""Validate DKIM-Signature fields.
|
||||||
|
|
||||||
|
Basic checks for presence and correct formatting of mandatory fields.
|
||||||
|
|
||||||
|
@param sig: A dict mapping field keys to values.
|
||||||
|
@param debuglog: A file-like object to which details will be written
|
||||||
|
on error.
|
||||||
|
"""
|
||||||
|
mandatory_fields = ('v', 'a', 'b', 'bh', 'd', 'h', 's')
|
||||||
|
for field in mandatory_fields:
|
||||||
|
if field not in sig:
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "signature missing %s=" % field
|
||||||
|
return False
|
||||||
|
|
||||||
|
if sig['v'] != "1":
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "v= value is not 1 (%s)" % sig['v']
|
||||||
|
return False
|
||||||
|
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "b= value is not valid base64 (%s)" % sig['b']
|
||||||
|
return False
|
||||||
|
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "bh= value is not valid base64 (%s)" % sig['bh']
|
||||||
|
return False
|
||||||
|
if 'i' in sig and (not sig['i'].endswith(sig['d']) or sig['i'][-len(sig['d'])-1] not in "@."):
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "i= domain is not a subdomain of d= (i=%s d=%d)" % (sig['i'], sig['d'])
|
||||||
|
return False
|
||||||
|
if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None:
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "l= value is not a decimal integer (%s)" % sig['l']
|
||||||
|
return False
|
||||||
|
if 'q' in sig and sig['q'] != "dns/txt":
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "q= value is not dns/txt (%s)" % sig['q']
|
||||||
|
return False
|
||||||
|
if 't' in sig and re.match(r"\d+$", sig['t']) is None:
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "t= value is not a decimal integer (%s)" % sig['t']
|
||||||
|
return False
|
||||||
|
if 'x' in sig:
|
||||||
|
if re.match(r"\d+$", sig['x']) is None:
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "x= value is not a decimal integer (%s)" % sig['x']
|
||||||
|
return False
|
||||||
|
if int(sig['x']) < int(sig['t']):
|
||||||
|
if debuglog is not None:
|
||||||
|
print >>debuglog, "x= value is less than t= value (x=%s t=%s)" % (sig['x'], sig['t'])
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
INTEGER = 0x02
|
INTEGER = 0x02
|
||||||
BIT_STRING = 0x03
|
BIT_STRING = 0x03
|
||||||
OCTET_STRING = 0x04
|
OCTET_STRING = 0x04
|
||||||
@@ -379,19 +498,9 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple
|
|||||||
if debuglog is not None:
|
if debuglog is not None:
|
||||||
print >>debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d)
|
print >>debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d)
|
||||||
|
|
||||||
dinfo = asn1_build(
|
|
||||||
(SEQUENCE, [
|
|
||||||
(SEQUENCE, [
|
|
||||||
(OBJECT_IDENTIFIER, HASHID_SHA256),
|
|
||||||
(NULL, None),
|
|
||||||
]),
|
|
||||||
(OCTET_STRING, d),
|
|
||||||
])
|
|
||||||
)
|
|
||||||
modlen = len(int2str(pk['modulus']))
|
modlen = len(int2str(pk['modulus']))
|
||||||
if len(dinfo)+3 > modlen:
|
encoded = EMSA_PKCS1_v1_5_encode(d, modlen, HASHID_SHA256)
|
||||||
raise ParameterError("Hash too large for modulus")
|
sig2 = int2str(pow(str2int(encoded), pk['privateExponent'], pk['modulus']), modlen)
|
||||||
sig2 = int2str(pow(str2int("\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo), pk['privateExponent'], pk['modulus']), modlen)
|
|
||||||
sig += base64.b64encode(''.join(sig2))
|
sig += base64.b64encode(''.join(sig2))
|
||||||
|
|
||||||
return sig + "\r\n"
|
return sig + "\r\n"
|
||||||
@@ -427,71 +536,8 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
|
|||||||
if debuglog is not None:
|
if debuglog is not None:
|
||||||
print >>debuglog, "sig:", sig
|
print >>debuglog, "sig:", sig
|
||||||
|
|
||||||
if 'v' not in sig:
|
if not validate_signature_fields(sig, debuglog):
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "signature missing v="
|
|
||||||
return False
|
return False
|
||||||
if sig['v'] != "1":
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "v= value is not 1 (%s)" % sig['v']
|
|
||||||
return False
|
|
||||||
if 'a' not in sig:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "signature missing a="
|
|
||||||
return False
|
|
||||||
if 'b' not in sig:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "signature missing b="
|
|
||||||
return False
|
|
||||||
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "b= value is not valid base64 (%s)" % sig['b']
|
|
||||||
return False
|
|
||||||
if 'bh' not in sig:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "signature missing bh="
|
|
||||||
return False
|
|
||||||
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "bh= value is not valid base64 (%s)" % sig['bh']
|
|
||||||
return False
|
|
||||||
if 'd' not in sig:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "signature missing d="
|
|
||||||
return False
|
|
||||||
if 'h' not in sig:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "signature missing h="
|
|
||||||
return False
|
|
||||||
if 'i' in sig and (not sig['i'].endswith(sig['d']) or sig['i'][-len(sig['d'])-1] not in "@."):
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "i= domain is not a subdomain of d= (i=%s d=%d)" % (sig['i'], sig['d'])
|
|
||||||
return False
|
|
||||||
if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "l= value is not a decimal integer (%s)" % sig['l']
|
|
||||||
return False
|
|
||||||
if 'q' in sig and sig['q'] != "dns/txt":
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "q= value is not dns/txt (%s)" % sig['q']
|
|
||||||
return False
|
|
||||||
if 's' not in sig:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "signature missing s="
|
|
||||||
return False
|
|
||||||
if 't' in sig and re.match(r"\d+$", sig['t']) is None:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "t= value is not a decimal integer (%s)" % sig['t']
|
|
||||||
return False
|
|
||||||
if 'x' in sig:
|
|
||||||
if re.match(r"\d+$", sig['x']) is None:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "x= value is not a decimal integer (%s)" % sig['x']
|
|
||||||
return False
|
|
||||||
if int(sig['x']) < int(sig['t']):
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "x= value is less than t= value (x=%s t=%s)" % (sig['x'], sig['t'])
|
|
||||||
return False
|
|
||||||
|
|
||||||
m = re.match("(\w+)(?:/(\w+))?$", sig['c'])
|
m = re.match("(\w+)(?:/(\w+))?$", sig['c'])
|
||||||
if m is None:
|
if m is None:
|
||||||
@@ -565,60 +611,22 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
|
|||||||
if debuglog is not None:
|
if debuglog is not None:
|
||||||
print >>debuglog, "invalid format in _domainkey txt record"
|
print >>debuglog, "invalid format in _domainkey txt record"
|
||||||
return False
|
return False
|
||||||
x = asn1_parse(ASN1_Object, base64.b64decode(pub['p']))
|
pk = parse_public_key(base64.b64decode(pub['p']))
|
||||||
# Not sure why the [1:] is necessary to skip a byte.
|
|
||||||
pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
|
|
||||||
pk = {
|
|
||||||
'modulus': pkd[0][0],
|
|
||||||
'publicExponent': pkd[0][1],
|
|
||||||
}
|
|
||||||
modlen = len(int2str(pk['modulus']))
|
modlen = len(int2str(pk['modulus']))
|
||||||
if debuglog is not None:
|
if debuglog is not None:
|
||||||
print >>debuglog, "modlen:", modlen
|
print >>debuglog, "modlen:", modlen
|
||||||
|
|
||||||
include_headers = re.split(r"\s*:\s*", sig['h'])
|
include_headers = re.split(r"\s*:\s*", sig['h'])
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "include_headers:", include_headers
|
|
||||||
sign_headers = []
|
|
||||||
lastindex = {}
|
|
||||||
for h in include_headers:
|
|
||||||
i = lastindex.get(h, len(headers))
|
|
||||||
while i > 0:
|
|
||||||
i -= 1
|
|
||||||
if h.lower() == headers[i][0].lower():
|
|
||||||
sign_headers.append(headers[i])
|
|
||||||
break
|
|
||||||
lastindex[h] = i
|
|
||||||
# The call to _remove() assumes that the signature b= only appears once in the signature header
|
|
||||||
sign_headers += [(x[0], x[1].rstrip()) for x in canonicalize_headers.canonicalize_headers([(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])]
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "verify headers:", sign_headers
|
|
||||||
|
|
||||||
h = hasher()
|
h = hasher()
|
||||||
for x in sign_headers:
|
hash_headers(
|
||||||
h.update(x[0])
|
h, canonicalize_headers, headers, include_headers, sigheaders, sig)
|
||||||
h.update(":")
|
|
||||||
h.update(x[1])
|
|
||||||
d = h.digest()
|
d = h.digest()
|
||||||
if debuglog is not None:
|
if debuglog is not None:
|
||||||
print >>debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d)
|
print >>debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d)
|
||||||
|
try:
|
||||||
dinfo = asn1_build(
|
sig2 = EMSA_PKCS1_v1_5_encode(d, modlen, hashid)
|
||||||
(SEQUENCE, [
|
except ParameterError:
|
||||||
(SEQUENCE, [
|
|
||||||
(OBJECT_IDENTIFIER, hashid),
|
|
||||||
(NULL, None),
|
|
||||||
]),
|
|
||||||
(OCTET_STRING, d),
|
|
||||||
])
|
|
||||||
)
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "dinfo:", " ".join("%02x" % ord(x) for x in dinfo)
|
|
||||||
if len(dinfo)+3 > modlen:
|
|
||||||
if debuglog is not None:
|
|
||||||
print >>debuglog, "Hash too large for modulus"
|
|
||||||
return False
|
return False
|
||||||
sig2 = "\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo
|
|
||||||
if debuglog is not None:
|
if debuglog is not None:
|
||||||
print >>debuglog, "sig2:", " ".join("%02x" % ord(x) for x in sig2)
|
print >>debuglog, "sig2:", " ".join("%02x" % ord(x) for x in sig2)
|
||||||
print >>debuglog, sig['b']
|
print >>debuglog, sig['b']
|
||||||
|
|||||||
Reference in New Issue
Block a user