Provide a class API so that selector, domain can be recovered on verify.
This commit is contained in:
+116
-94
@@ -205,28 +205,22 @@ def fold(header):
|
|||||||
header = header[j:]
|
header = header[j:]
|
||||||
return pre + header
|
return pre + header
|
||||||
|
|
||||||
|
class DKIM(object):
|
||||||
|
|
||||||
def sign(message, selector, domain, privkey, identity=None,
|
def __init__(self,message,logger=None,signature_algorithm=b'rsa-sha256'):
|
||||||
canonicalize=(b'simple', b'simple'),
|
(self.headers, self.body) = rfc822_parse(message)
|
||||||
signature_algorithm=b'rsa-sha256',
|
self.domain = None
|
||||||
include_headers=None, length=False, logger=None):
|
self.selector = 'default'
|
||||||
"""Sign an RFC822 message and return the DKIM-Signature header line.
|
|
||||||
|
|
||||||
@param message: an RFC822 formatted message (with either \\n or \\r\\n line endings)
|
|
||||||
@param selector: the DKIM selector value for the signature
|
|
||||||
@param domain: the DKIM domain value for the signature
|
|
||||||
@param privkey: a PKCS#1 private key in base64-encoded text form
|
|
||||||
@param identity: the DKIM identity value for the signature (default "@"+domain)
|
|
||||||
@param canonicalize: the canonicalization algorithms to use (default (Simple, Simple))
|
|
||||||
@param include_headers: a list of strings indicating which headers are to be signed (default all headers)
|
|
||||||
@param length: true if the l= tag should be included to indicate body length (default False)
|
|
||||||
@param logger: a logger to which debug info will be written (default None)
|
|
||||||
"""
|
|
||||||
if logger is None:
|
if logger is None:
|
||||||
logger = get_default_logger()
|
logger = get_default_logger()
|
||||||
|
self.logger = logger
|
||||||
|
if signature_algorithm not in HASH_ALGORITHMS:
|
||||||
|
raise ParameterError(
|
||||||
|
"Unsupported signature algorithm: "+signature_algorithm)
|
||||||
|
self.signature_algorithm = signature_algorithm
|
||||||
|
|
||||||
(headers, body) = rfc822_parse(message)
|
def sign(self, selector, domain, privkey, identity=None,
|
||||||
|
canonicalize=(b'simple',b'simple'), include_headers=None, length=False):
|
||||||
try:
|
try:
|
||||||
pk = parse_pem_private_key(privkey)
|
pk = parse_pem_private_key(privkey)
|
||||||
except UnparsableKeyError as e:
|
except UnparsableKeyError as e:
|
||||||
@@ -237,7 +231,7 @@ def sign(message, selector, domain, privkey, identity=None,
|
|||||||
|
|
||||||
canon_policy = CanonicalizationPolicy.from_c_value(
|
canon_policy = CanonicalizationPolicy.from_c_value(
|
||||||
b'/'.join(canonicalize))
|
b'/'.join(canonicalize))
|
||||||
headers = canon_policy.canonicalize_headers(headers)
|
headers = canon_policy.canonicalize_headers(self.headers)
|
||||||
|
|
||||||
if include_headers is None:
|
if include_headers is None:
|
||||||
include_headers = [x[0].lower() for x in headers]
|
include_headers = [x[0].lower() for x in headers]
|
||||||
@@ -245,15 +239,16 @@ def sign(message, selector, domain, privkey, identity=None,
|
|||||||
include_headers = [x.lower() for x in include_headers]
|
include_headers = [x.lower() for x in include_headers]
|
||||||
sign_headers = [x for x in headers if x[0].lower() in include_headers]
|
sign_headers = [x for x in headers if x[0].lower() in include_headers]
|
||||||
|
|
||||||
body = canon_policy.canonicalize_body(body)
|
body = canon_policy.canonicalize_body(self.body)
|
||||||
|
|
||||||
h = hashlib.sha256()
|
hasher = HASH_ALGORITHMS[self.signature_algorithm]
|
||||||
|
h = hasher()
|
||||||
h.update(body)
|
h.update(body)
|
||||||
bodyhash = base64.b64encode(h.digest())
|
bodyhash = base64.b64encode(h.digest())
|
||||||
|
|
||||||
sigfields = [x for x in [
|
sigfields = [x for x in [
|
||||||
(b'v', b"1"),
|
(b'v', b"1"),
|
||||||
(b'a', signature_algorithm),
|
(b'a', self.signature_algorithm),
|
||||||
(b'c', canon_policy.to_c_value()),
|
(b'c', canon_policy.to_c_value()),
|
||||||
(b'd', domain),
|
(b'd', domain),
|
||||||
(b'i', identity or b"@"+domain),
|
(b'i', identity or b"@"+domain),
|
||||||
@@ -275,7 +270,7 @@ def sign(message, selector, domain, privkey, identity=None,
|
|||||||
dkim_header = (dkim_header[0], dkim_header[1][:-2])
|
dkim_header = (dkim_header[0], dkim_header[1][:-2])
|
||||||
sign_headers.append(dkim_header)
|
sign_headers.append(dkim_header)
|
||||||
|
|
||||||
logger.debug("sign headers: %r" % sign_headers)
|
self.logger.debug("sign headers: %r" % sign_headers)
|
||||||
h = hashlib.sha256()
|
h = hashlib.sha256()
|
||||||
for x in sign_headers:
|
for x in sign_headers:
|
||||||
h.update(x[0])
|
h.update(x[0])
|
||||||
@@ -289,9 +284,102 @@ def sign(message, selector, domain, privkey, identity=None,
|
|||||||
raise ParameterError("digest too large for modulus")
|
raise ParameterError("digest too large for modulus")
|
||||||
sig_value += base64.b64encode(bytes(sig2))
|
sig_value += base64.b64encode(bytes(sig2))
|
||||||
|
|
||||||
|
self.domain = domain
|
||||||
|
self.selector = selector
|
||||||
return b'DKIM-Signature: ' + sig_value + b"\r\n"
|
return b'DKIM-Signature: ' + sig_value + b"\r\n"
|
||||||
|
|
||||||
|
|
||||||
|
def verify(self,dnsfunc=get_txt):
|
||||||
|
|
||||||
|
sigheaders = [x for x in self.headers if x[0].lower() == b"dkim-signature"]
|
||||||
|
if len(sigheaders) < 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Currently, we only validate the first DKIM-Signature line found.
|
||||||
|
try:
|
||||||
|
sig = parse_tag_value(sigheaders[0][1])
|
||||||
|
except InvalidTagValueList,e:
|
||||||
|
raise MessageFormatError(e)
|
||||||
|
|
||||||
|
sig = parse_tag_value(sigheaders[0][1])
|
||||||
|
logger = self.logger
|
||||||
|
logger.debug("sig: %r" % sig)
|
||||||
|
|
||||||
|
validate_signature_fields(sig)
|
||||||
|
|
||||||
|
try:
|
||||||
|
canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c'))
|
||||||
|
except InvalidCanonicalizationPolicyError as e:
|
||||||
|
raise MessageFormatError("invalid c= value: %s" % e.args[0])
|
||||||
|
headers = canon_policy.canonicalize_headers(self.headers)
|
||||||
|
body = canon_policy.canonicalize_body(self.body)
|
||||||
|
|
||||||
|
try:
|
||||||
|
hasher = HASH_ALGORITHMS[sig[b'a']]
|
||||||
|
except KeyError as e:
|
||||||
|
raise MessageFormatError("unknown signature algorithm: %s" % e.args[0])
|
||||||
|
|
||||||
|
self.domain = sig[b'd']
|
||||||
|
self.selector = sig[b's']
|
||||||
|
if b'l' in sig:
|
||||||
|
body = body[:int(sig[b'l'])]
|
||||||
|
|
||||||
|
h = hasher()
|
||||||
|
h.update(body)
|
||||||
|
bodyhash = h.digest()
|
||||||
|
logger.debug("bh: %s" % base64.b64encode(bodyhash))
|
||||||
|
try:
|
||||||
|
bh = base64.b64decode(re.sub(br"\s+", b"", sig[b'bh']))
|
||||||
|
except TypeError,e:
|
||||||
|
raise MessageFormatError(str(e))
|
||||||
|
if bodyhash != bh:
|
||||||
|
raise ValidationError(
|
||||||
|
"body hash mismatch (got %s, expected %s)" %
|
||||||
|
(base64.b64encode(bodyhash), sig[b'bh']))
|
||||||
|
|
||||||
|
name = sig[b's'] + b"._domainkey." + sig[b'd'] + b"."
|
||||||
|
s = dnsfunc(name)
|
||||||
|
if not s:
|
||||||
|
raise KeyFormatError("missing public key: %s"%name)
|
||||||
|
try:
|
||||||
|
pub = parse_tag_value(s)
|
||||||
|
except InvalidTagValueList:
|
||||||
|
raise KeyFormatError(e)
|
||||||
|
try:
|
||||||
|
pk = parse_public_key(base64.b64decode(pub[b'p']))
|
||||||
|
except (TypeError,UnparsableKeyError) as e:
|
||||||
|
raise KeyFormatError("could not parse public key: %s" % e)
|
||||||
|
|
||||||
|
include_headers = re.split(br"\s*:\s*", sig[b'h'])
|
||||||
|
h = hasher()
|
||||||
|
hash_headers(h, canon_policy, headers, include_headers, sigheaders, sig)
|
||||||
|
try:
|
||||||
|
signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b']))
|
||||||
|
return RSASSA_PKCS1_v1_5_verify(
|
||||||
|
h, signature, pk['publicExponent'], pk['modulus'])
|
||||||
|
except (TypeError,DigestTooLargeError) as e:
|
||||||
|
raise KeyFormatError("digest too large for modulus: %s"%e)
|
||||||
|
|
||||||
|
def sign(message, selector, domain, privkey, identity=None,
|
||||||
|
canonicalize=(b'simple', b'simple'),
|
||||||
|
signature_algorithm=b'rsa-sha256',
|
||||||
|
include_headers=None, length=False, logger=None):
|
||||||
|
"""Sign an RFC822 message and return the DKIM-Signature header line.
|
||||||
|
|
||||||
|
@param message: an RFC822 formatted message (with either \\n or \\r\\n line endings)
|
||||||
|
@param selector: the DKIM selector value for the signature
|
||||||
|
@param domain: the DKIM domain value for the signature
|
||||||
|
@param privkey: a PKCS#1 private key in base64-encoded text form
|
||||||
|
@param identity: the DKIM identity value for the signature (default "@"+domain)
|
||||||
|
@param canonicalize: the canonicalization algorithms to use (default (Simple, Simple))
|
||||||
|
@param include_headers: a list of strings indicating which headers are to be signed (default all headers)
|
||||||
|
@param length: true if the l= tag should be included to indicate body length (default False)
|
||||||
|
@param logger: a logger to which debug info will be written (default None)
|
||||||
|
"""
|
||||||
|
|
||||||
|
d = DKIM(message,logger=logger)
|
||||||
|
return d.sign(selector, domain, privkey, identity=identity, canonicalize=canonicalize, include_headers=include_headers, length=length)
|
||||||
|
|
||||||
def verify(message, logger=None, dnsfunc=get_txt):
|
def verify(message, logger=None, dnsfunc=get_txt):
|
||||||
"""Verify a DKIM signature on an RFC822 formatted message.
|
"""Verify a DKIM signature on an RFC822 formatted message.
|
||||||
|
|
||||||
@@ -299,76 +387,10 @@ def verify(message, logger=None, dnsfunc=get_txt):
|
|||||||
@param logger: a logger to which debug info will be written (default None)
|
@param logger: a logger to which debug info will be written (default None)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if logger is None:
|
d = DKIM(message,logger=logger)
|
||||||
logger = get_default_logger()
|
|
||||||
|
|
||||||
(headers, body) = rfc822_parse(message)
|
|
||||||
|
|
||||||
sigheaders = [x for x in headers if x[0].lower() == b"dkim-signature"]
|
|
||||||
if len(sigheaders) < 1:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Currently, we only validate the first DKIM-Signature line found.
|
|
||||||
try:
|
try:
|
||||||
sig = parse_tag_value(sigheaders[0][1])
|
return d.verify(dnsfunc=dnsfunc)
|
||||||
except InvalidTagValueList:
|
except DKIMException,x:
|
||||||
return False
|
if logger is not None:
|
||||||
logger.debug("sig: %r" % sig)
|
logger.error("%s" % e)
|
||||||
|
|
||||||
try:
|
|
||||||
validate_signature_fields(sig)
|
|
||||||
except ValidationError as e:
|
|
||||||
logger.error("signature fields failed to validate: %s" % e)
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c'))
|
|
||||||
except InvalidCanonicalizationPolicyError as e:
|
|
||||||
logger.error("invalid c= value: %s" % e.args[0])
|
|
||||||
return False
|
|
||||||
headers = canon_policy.canonicalize_headers(headers)
|
|
||||||
body = canon_policy.canonicalize_body(body)
|
|
||||||
|
|
||||||
try:
|
|
||||||
hasher = HASH_ALGORITHMS[sig[b'a']]
|
|
||||||
except KeyError as e:
|
|
||||||
logger.error("unknown signature algorithm: %s" % e.args[0])
|
|
||||||
return False
|
|
||||||
|
|
||||||
if b'l' in sig:
|
|
||||||
body = body[:int(sig[b'l'])]
|
|
||||||
|
|
||||||
h = hasher()
|
|
||||||
h.update(body)
|
|
||||||
bodyhash = h.digest()
|
|
||||||
logger.debug("bh: %s" % base64.b64encode(bodyhash))
|
|
||||||
if bodyhash != base64.b64decode(re.sub(br"\s+", b"", sig[b'bh'])):
|
|
||||||
logger.error(
|
|
||||||
"body hash mismatch (got %s, expected %s)" %
|
|
||||||
(base64.b64encode(bodyhash), sig[b'bh']))
|
|
||||||
return False
|
|
||||||
|
|
||||||
name = sig[b's'] + b"._domainkey." + sig[b'd'] + b"."
|
|
||||||
s = dnsfunc(name)
|
|
||||||
if not s:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
pub = parse_tag_value(s)
|
|
||||||
except InvalidTagValueList:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
pk = parse_public_key(base64.b64decode(pub[b'p']))
|
|
||||||
except UnparsableKeyError as e:
|
|
||||||
logger.error("could not parse public key: %s" % e)
|
|
||||||
return False
|
|
||||||
|
|
||||||
include_headers = re.split(br"\s*:\s*", sig[b'h'])
|
|
||||||
h = hasher()
|
|
||||||
hash_headers(h, canon_policy, headers, include_headers, sigheaders, sig)
|
|
||||||
signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b']))
|
|
||||||
try:
|
|
||||||
return RSASSA_PKCS1_v1_5_verify(
|
|
||||||
h, signature, pk['publicExponent'], pk['modulus'])
|
|
||||||
except DigestTooLargeError:
|
|
||||||
logger.error("digest too large for modulus")
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
Reference in New Issue
Block a user