diff --git a/dkim/__init__.py b/dkim/__init__.py index ab2347d..41829bb 100644 --- a/dkim/__init__.py +++ b/dkim/__init__.py @@ -21,6 +21,7 @@ import base64 import hashlib +import logging import re import time @@ -35,6 +36,7 @@ from dkim.crypto import ( UnparsableKeyError, ) from dkim.util import ( + get_default_logger, InvalidTagValueList, parse_tag_value, ) @@ -50,6 +52,7 @@ __all__ = [ "verify", ] + class Simple: """Class that represents the "simple" canonicalization algorithm.""" @@ -105,6 +108,10 @@ class ParameterError(DKIMException): """Input parameter error.""" pass +class ValidationError(DKIMException): + """Validation error.""" + pass + def _remove(s, t): i = s.find(t) assert i >= 0 @@ -134,60 +141,48 @@ def hash_headers(hasher, canonicalize_headers, headers, include_headers, hasher.update(x[1]) -def validate_signature_fields(sig, debuglog=None): +def validate_signature_fields(sig): """Validate DKIM-Signature fields. Basic checks for presence and correct formatting of mandatory fields. + Raises a ValidationError if checks fail, otherwise returns None. @param sig: A dict mapping field keys to values. - @param debuglog: A file-like object to which details will be written - on error. """ mandatory_fields = ('v', 'a', 'b', 'bh', 'd', 'h', 's') for field in mandatory_fields: if field not in sig: - if debuglog is not None: - print >>debuglog, "signature missing %s=" % field - return False + raise ValidationError("signature missing %s=" % field) if sig['v'] != "1": - if debuglog is not None: - print >>debuglog, "v= value is not 1 (%s)" % sig['v'] - return False + raise ValidationError("v= value is not 1 (%s)" % sig['v']) if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None: - if debuglog is not None: - print >>debuglog, "b= value is not valid base64 (%s)" % sig['b'] - return False + raise ValidationError("b= value is not valid base64 (%s)" % sig['b']) if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None: - if debuglog is not None: - print >>debuglog, "bh= value is not valid base64 (%s)" % sig['bh'] - return False - if 'i' in sig and (not sig['i'].endswith(sig['d']) or sig['i'][-len(sig['d'])-1] not in "@."): - if debuglog is not None: - print >>debuglog, "i= domain is not a subdomain of d= (i=%s d=%d)" % (sig['i'], sig['d']) - return False + raise ValidationError( + "bh= value is not valid base64 (%s)" % sig['bh']) + if 'i' in sig and ( + not sig['i'].endswith(sig['d']) or + sig['i'][-len(sig['d'])-1] not in "@."): + raise ValidationError( + "i= domain is not a subdomain of d= (i=%s d=%d)" % + (sig['i'], sig['d'])) if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None: - if debuglog is not None: - print >>debuglog, "l= value is not a decimal integer (%s)" % sig['l'] - return False + raise ValidationError( + "l= value is not a decimal integer (%s)" % sig['l']) if 'q' in sig and sig['q'] != "dns/txt": - if debuglog is not None: - print >>debuglog, "q= value is not dns/txt (%s)" % sig['q'] - return False + raise ValidationError("q= value is not dns/txt (%s)" % sig['q']) if 't' in sig and re.match(r"\d+$", sig['t']) is None: - if debuglog is not None: - print >>debuglog, "t= value is not a decimal integer (%s)" % sig['t'] - return False + raise ValidationError( + "t= value is not a decimal integer (%s)" % sig['t']) if 'x' in sig: if re.match(r"\d+$", sig['x']) is None: - if debuglog is not None: - print >>debuglog, "x= value is not a decimal integer (%s)" % sig['x'] - return False + raise ValidationError( + "x= value is not a decimal integer (%s)" % sig['x']) if int(sig['x']) < int(sig['t']): - if debuglog is not None: - print >>debuglog, "x= value is less than t= value (x=%s t=%s)" % (sig['x'], sig['t']) - return False - return True + raise ValidationError( + "x= value is less than t= value (x=%s t=%s)" % + (sig['x'], sig['t'])) def rfc822_parse(message): """Parse a message in RFC822 format. @@ -247,7 +242,10 @@ def fold(header): header = header[j:] return pre + header -def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple, Simple), include_headers=None, length=False, debuglog=None): + +def sign(message, selector, domain, privkey, identity=None, + canonicalize=(Simple, Simple), include_headers=None, length=False, + logger=None): """Sign an RFC822 message and return the DKIM-Signature header line. @param message: an RFC822 formatted message (with either \\n or \\r\\n line endings) @@ -258,9 +256,10 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple @param canonicalize: the canonicalization algorithms to use (default (Simple, Simple)) @param include_headers: a list of strings indicating which headers are to be signed (default all headers) @param length: true if the l= tag should be included to indicate body length (default False) - @param debuglog: a file-like object to which debug info will be written (default None) - + @param logger: a logger to which debug info will be written (default None) """ + if logger is None: + logger = get_default_logger() (headers, body) = rfc822_parse(message) @@ -310,8 +309,7 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple dkim_header = (dkim_header[0], dkim_header[1][:-2]) sign_headers.append(dkim_header) - if debuglog is not None: - print >>debuglog, "sign headers:", sign_headers + logger.debug("sign headers: %r" % sign_headers) h = hashlib.sha256() for x in sign_headers: h.update(x[0]) @@ -327,13 +325,16 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple return 'DKIM-Signature: ' + sig_value + "\r\n" -def verify(message, debuglog=None, dnsfunc=dnstxt): + +def verify(message, logger=None, dnsfunc=dnstxt): """Verify a DKIM signature on an RFC822 formatted message. @param message: an RFC822 formatted message (with either \\n or \\r\\n line endings) - @param debuglog: a file-like object to which debug info will be written (default None) + @param logger: a logger to which debug info will be written (default None) """ + if logger is None: + logger = get_default_logger() (headers, body) = rfc822_parse(message) @@ -346,16 +347,18 @@ def verify(message, debuglog=None, dnsfunc=dnstxt): sig = parse_tag_value(sigheaders[0][1]) except InvalidTagValueList: return False - if debuglog is not None: - print >>debuglog, "sig:", sig + logger.debug("sig: %r" % sig) - if not validate_signature_fields(sig, debuglog): + try: + validate_signature_fields(sig) + except ValidationError, e: + logger.error("signature fields failed to validate: %s" % e) return False m = re.match("(\w+)(?:/(\w+))?$", sig['c']) if m is None: - if debuglog is not None: - print >>debuglog, "c= value is not in format method/method (%s)" % sig['c'] + logger.error( + "c= value is not in format method/method (%s)" % sig['c']) return False can_headers = m.group(1) if m.group(2) is not None: @@ -368,8 +371,7 @@ def verify(message, debuglog=None, dnsfunc=dnstxt): elif can_headers == "relaxed": canonicalize_headers = Relaxed else: - if debuglog is not None: - print >>debuglog, "Unknown header canonicalization (%s)" % can_headers + logger.error("unknown header canonicalization (%s)" % can_headers) return False headers = canonicalize_headers.canonicalize_headers(headers) @@ -379,8 +381,7 @@ def verify(message, debuglog=None, dnsfunc=dnstxt): elif can_body == "relaxed": body = Relaxed.canonicalize_body(body) else: - if debuglog is not None: - print >>debuglog, "Unknown body canonicalization (%s)" % can_body + logger.error("unknown body canonicalization (%s)" % can_body) return False if sig['a'] == "rsa-sha1": @@ -388,8 +389,7 @@ def verify(message, debuglog=None, dnsfunc=dnstxt): elif sig['a'] == "rsa-sha256": hasher = hashlib.sha256 else: - if debuglog is not None: - print >>debuglog, "Unknown signature algorithm (%s)" % sig['a'] + logger.error("unknown signature algorithm (%s)" % sig['a']) return False if 'l' in sig: @@ -398,11 +398,11 @@ def verify(message, debuglog=None, dnsfunc=dnstxt): h = hasher() h.update(body) bodyhash = h.digest() - if debuglog is not None: - print >>debuglog, "bh:", base64.b64encode(bodyhash) + logger.debug("bh: %s" % base64.b64encode(bodyhash)) if bodyhash != base64.b64decode(re.sub(r"\s+", "", sig['bh'])): - if debuglog is not None: - print >>debuglog, "body hash mismatch (got %s, expected %s)" % (base64.b64encode(bodyhash), sig['bh']) + logger.error( + "body hash mismatch (got %s, expected %s)" % + (base64.b64encode(bodyhash), sig['bh'])) return False s = dnsfunc(sig['s']+"._domainkey."+sig['d']+".") @@ -415,8 +415,7 @@ def verify(message, debuglog=None, dnsfunc=dnstxt): try: pk = parse_public_key(base64.b64decode(pub['p'])) except UnparsableKeyError, e: - if debuglog is not None: - print >>debuglog, "could not parse public key: %s" % e + logger.error("could not parse public key: %s" % e) return False include_headers = re.split(r"\s*:\s*", sig['h']) @@ -428,6 +427,5 @@ def verify(message, debuglog=None, dnsfunc=dnstxt): return RSASSA_PKCS1_v1_5_verify( h, signature, pk['publicExponent'], pk['modulus']) except DigestTooLargeError: - if debuglog is not None: - print >>debuglog, "digest too large for modulus" + logger.error("digest too large for modulus") return False diff --git a/dkim/util.py b/dkim/util.py index 8511ca2..0cdda93 100644 --- a/dkim/util.py +++ b/dkim/util.py @@ -16,8 +16,18 @@ # # Copyright (c) 2011 William Grant +import logging +try: + from logging import NullHandler +except ImportError: + class NullHandler(logging.Handler): + def emit(self, record): + pass + + __all__ = [ 'DuplicateTag', + 'get_default_logger', 'InvalidTagSpec', 'InvalidTagValueList', 'parse_tag_value', @@ -58,3 +68,11 @@ def parse_tag_value(tag_list): raise DuplicateTag(key.strip()) tags[key.strip()] = value.strip() return tags + + +def get_default_logger(): + """Get the default pydkim logger.""" + logger = logging.getLogger('pydkim') + if not logger.handlers: + logger.addHandler(NullHandler()) + return logger