From ac6d9a6bb394a65c09d81019860944a5bd90d333 Mon Sep 17 00:00:00 2001 From: Gene Shuman Date: Tue, 17 Jan 2017 13:20:20 -0800 Subject: [PATCH] refactoring/integrating ARC code --- dkim/__init__.py | 885 ++++++++++++++++------------------------- dkim/tests/test_arc.py | 2 +- dkim/util.py | 12 +- test.py | 3 + tests/__init__.py | 34 -- tests/test_arc.py | 148 ------- 6 files changed, 363 insertions(+), 721 deletions(-) delete mode 100644 tests/__init__.py delete mode 100644 tests/test_arc.py diff --git a/dkim/__init__.py b/dkim/__init__.py index 0d36477..3d07ace 100644 --- a/dkim/__init__.py +++ b/dkim/__init__.py @@ -26,6 +26,10 @@ # This has been modified from the original software. # Copyright (c) 2016 Scott Kitterman # +# This has been modified from the original software. +# Copyright (c) 2017 Valimail Inc +# Contact: Gene Shuman +# import base64 @@ -33,12 +37,13 @@ import hashlib import logging import re import time -import sys from dkim.canonicalization import ( CanonicalizationPolicy, InvalidCanonicalizationPolicyError, ) +from dkim.canonicalization import Relaxed as RelaxedCanonicalization + from dkim.crypto import ( DigestTooLargeError, HASH_ALGORITHMS, @@ -62,27 +67,52 @@ from dkim.util import ( __all__ = [ "DKIMException", "InternalError", - "CV_Pass", - "CV_Fail", - "CV_None", "KeyFormatError", "MessageFormatError", "ParameterError", + "ValidationError", + "CV_Pass", + "CV_Fail", + "CV_None", "Relaxed", "Simple", "DKIM", - "ValidationError", "ARC", "sign", "verify", + "dkim_sign", + "dkim_verify", + "arc_sign", + "arc_verify", ] Relaxed = b'relaxed' # for clients passing dkim.Relaxed Simple = b'simple' # for clients passing dkim.Simple + +# for ARC CV_Pass = b'pass' CV_Fail = b'fail' CV_None = b'none' +class HashThrough(object): + def __init__(self, hasher): + self.data = [] + self.hasher = hasher + self.name = hasher.name + + def update(self, data): + self.data.append(data) + return self.hasher.update(data) + + def digest(self): + return self.hasher.digest() + + def hexdigest(self): + return self.hasher.hexdigest() + + def hashed(self): + return b''.join(self.data) + def bitsize(x): """Return size of long in bits.""" return len(bin(x)) - 2 @@ -111,25 +141,6 @@ class ValidationError(DKIMException): """Validation error.""" pass -class HashThrough(object): - def __init__(self, hasher): - self.data = [] - self.hasher = hasher - self.name = hasher.name - - def update(self, data): - self.data.append(data) - return self.hasher.update(data) - - def digest(self): - return self.hasher.digest() - - def hexdigest(self): - return self.hasher.hexdigest() - - def hashed(self): - return b''.join(self.data) - def select_headers(headers, include_headers): """Select message header fields to be signed/verified. @@ -175,37 +186,40 @@ def hash_headers(hasher, canonicalize_headers, headers, include_headers, hasher.update(y) return sign_headers -def hashed(self): - return ''.join(self.data) - -def validate_signature_fields(sig,arc=False): - """Validate DKIM-Signature fields. +def validate_signature_fields(sig, mandatory_fields=[b'v', b'a', b'b', b'bh', b'd', b'h', b's'], arc=False): + """Validate DKIM or ARC Signature fields. Basic checks for presence and correct formatting of mandatory fields. Raises a ValidationError if checks fail, otherwise returns None. - @param sig: A dict mapping field keys to values. + @param mandatory_fields: A list of non-optional fields + @param arc: flag to differentiate between dkim & arc """ - if not arc: - mandatory_fields = (b'v', b'a', b'b', b'bh', b'd', b'h', b's') - sigtype = 'DKIM' - else: - mandatory_fields = (b'i', b'a', b'b', b'cv', b'd', b's', b't') - sigtype = 'ARC' for field in mandatory_fields: if field not in sig: - raise ValidationError("{0} signature missing {1}=".format(sigtype, field)) - if not arc: - if sig[b'v'] != b"1": - raise ValidationError("v= value is not 1 (%s)" % sig[b'v']) - if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'b']) is None: - raise ValidationError("b= value is not valid base64 (%s)" % sig[b'b']) - if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'bh']) is None: - raise ValidationError( - "bh= value is not valid base64 (%s)" % sig[b'bh']) + raise ValidationError("missing %s=" % field) + + if b'a' in sig and not sig[b'a'] in HASH_ALGORITHMS: + raise ValidationError("unknown signature algorithm: %s" % sig[b'a']) + + if b'b' in sig: + if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'b']) is None: + raise ValidationError("b= value is not valid base64 (%s)" % sig[b'b']) + if len(re.sub(br"\s+", b"", sig[b'b'])) % 4 != 0: + raise ValidationError("b= value is not valid base64 (%s)" % sig[b'b']) + + if b'bh' in sig: + if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'bh']) is None: + raise ValidationError("bh= value is not valid base64 (%s)" % sig[b'bh']) + if len(re.sub(br"\s+", b"", sig[b'bh'])) % 4 != 0: + raise ValidationError("bh= value is not valid base64 (%s)" % sig[b'bh']) + + if b'cv' in sig and sig[b'cv'] not in (CV_Pass, CV_Fail, CV_None): + raise ValidationError("cv= value is not valid (%s)" % sig[b'cv']) + # Nasty hack to support both str and bytes... check for both the # character and integer values. - if b'i' in sig and ( + if not arc and b'i' in sig and ( not sig[b'i'].lower().endswith(sig[b'd'].lower()) or sig[b'i'][-len(sig[b'd'])-1] not in ('@', '.', 64, 46)): raise ValidationError( @@ -216,91 +230,33 @@ def validate_signature_fields(sig,arc=False): "l= value is not a decimal integer (%s)" % sig[b'l']) if b'q' in sig and sig[b'q'] != b"dns/txt": raise ValidationError("q= value is not dns/txt (%s)" % sig[b'q']) - now = int(time.time()) - slop = 36000 # 10H leeway for mailers with inaccurate clocks - t_sign = 0 - if b't' in sig: - if re.match(br"\d+$", sig[b't']) is None: - raise ValidationError( - "t= value is not a decimal integer (%s)" % sig[b't']) - t_sign = int(sig[b't']) - if t_sign > now + slop: - raise ValidationError( - "t= value is in the future (%s)" % sig[b't']) - if b'x' in sig: - if re.match(br"\d+$", sig[b'x']) is None: - raise ValidationError( - "x= value is not a decimal integer (%s)" % sig[b'x']) - x_sign = int(sig[b'x']) - if x_sign < now - slop: - raise ValidationError( - "x= value is past (%s)" % sig[b'x']) - if x_sign < t_sign: - raise ValidationError( - "x= value is less than t= value (x=%s t=%s)" % - (sig[b'x'], sig[b't'])) - if arc: - if sig[b'cv'] not in (CV_Pass, CV_Fail, CV_None): - raise ValidationError("cv= value is not valid (%s)" % sig[b'cv']) -def validate_arc_signature_fields(sig): - """Validate ARC-Message-Signature fields. - - Basic checks for presence and correct formatting of mandatory fields. - Raises a ValidationError if checks fail, otherwise returns None. - - @param sig: A dict mapping field keys to values. - """ - mandatory_fields = (b'i', b'a', b'b', b'bh', b'd', b'h', b's') - for field in mandatory_fields: - if field not in sig: - raise ValidationError("arc-message-signature missing %s=" % field) - - if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'b']) is None: - raise ValidationError("b= value is not valid base64 (%s)" % sig[b'b']) - if re.match(br"[\s0-9A-Za-z/+]+=*$", sig[b'bh']) is None: - raise ValidationError( - "bh= value is not valid base64 (%s)" % sig[b'bh']) - now = int(time.time()) - slop = 36000 # 10H leeway for mailers with inaccurate clocks - t_sign = 0 - if b't' in sig: - if re.match(br"\d+$", sig[b't']) is None: - raise ValidationError( - "t= value is not a decimal integer (%s)" % sig[b't']) - t_sign = int(sig[b't']) - if t_sign > now + slop: - raise ValidationError( - "t= value is in the future (%s)" % sig[b't']) - -def validate_arc_seal_fields(sig): - """Validate ARC-Seal fields. - - Basic checks for presence and correct formatting of mandatory fields. - Raises a ValidationError if checks fail, otherwise returns None. - - @param sig: A dict mapping field keys to values. - """ - mandatory_fields = (b'i', b'a', b'b', b'cv', b'd', b's', b't') - for field in mandatory_fields: - if field not in sig: - raise ValidationError("arc-seal missing %s=" % field) - - if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'b']) is None: - raise ValidationError("b= value is not valid base64 (%s)" % sig[b'b']) - if sig[b'cv'] not in (CV_Pass, CV_Fail, CV_None): - raise ValidationError("cv= value is not valid (%s)" % sig[b'cv']) - now = int(time.time()) - slop = 36000 # 10H leeway for mailers with inaccurate clocks - t_sign = 0 if b't' in sig: if re.match(br"\d+$", sig[b't']) is None: raise ValidationError( "t= value is not a decimal integer (%s)" % sig[b't']) + now = int(time.time()) + slop = 36000 # 10H leeway for mailers with inaccurate clocks t_sign = int(sig[b't']) if t_sign > now + slop: raise ValidationError("t= value is in the future (%s)" % sig[b't']) + if b'v' in sig and sig[b'v'] != b"1": + raise ValidationError("v= value is not 1 (%s)" % sig[b'v']) + + if b'x' in sig: + if re.match(br"\d+$", sig[b'x']) is None: + raise ValidationError( + "x= value is not a decimal integer (%s)" % sig[b'x']) + x_sign = int(sig[b'x']) + if x_sign < now - slop: + raise ValidationError( + "x= value is past (%s)" % sig[b'x']) + if x_sign < t_sign: + raise ValidationError( + "x= value is less than t= value (x=%s t=%s)" % + (sig[b'x'], sig[b't'])) + def rfc822_parse(message): """Parse a message in RFC822 format. @@ -372,24 +328,59 @@ def fold(header): header = header[j:] return pre + header -#: Hold messages and options during DKIM signing and verification. -class DKIM(object): - # NOTE - the first 2 indentation levels are 2 instead of 4 +def load_pk_from_dns(name, dnsfunc=get_txt): + s = dnsfunc(name) + if not s: + raise KeyFormatError("missing public key: %s"%name) + try: + if type(s) is str: + s = s.encode('ascii') + pub = parse_tag_value(s) + except InvalidTagValueList as e: + raise KeyFormatError(e) + try: + pk = parse_public_key(base64.b64decode(pub[b'p'])) + keysize = bitsize(pk['modulus']) + except KeyError: + raise KeyFormatError("incomplete public key: %s" % s) + except (TypeError,UnparsableKeyError) as e: + raise KeyFormatError("could not parse public key (%s): %s" % (pub[b'p'],e)) + return pk, keysize + +#: Abstract base class for holding messages and options during DKIM/ARC signing and verification. +class DomainSigner(object): + # NOTE - the first 2 indentation levels are 2 instead of 4 # to minimize changed lines from the function only version. - #: The U{RFC5322} - #: complete list of singleton headers (which should - #: appear at most once). This can be used for a "paranoid" or - #: "strict" signing mode. - #: Bcc in this list is in the SHOULD NOT sign list, the rest could - #: be in the default FROZEN list, but that could also make signatures - #: more fragile than necessary. - #: @since: 0.5 - RFC5322_SINGLETON = (b'date',b'from',b'sender',b'reply-to',b'to',b'cc',b'bcc', - b'message-id',b'in-reply-to',b'references') + #: @param message: an RFC822 formatted message to be signed or verified + #: (with either \\n or \\r\\n line endings) + #: @param logger: a logger to which debug info will be written (default None) + #: @param signature_algorithm: the signing algorithm to use when signing + def __init__(self,message=None,logger=None,signature_algorithm=b'rsa-sha256', + minkey=1024): + self.set_message(message) + if logger is None: + logger = get_default_logger() + self.logger = logger + if signature_algorithm not in HASH_ALGORITHMS: + raise ParameterError( + "Unsupported signature algorithm: "+signature_algorithm) + self.signature_algorithm = signature_algorithm + #: Header fields which should be signed. Default from RFC4871 + self.should_sign = set(DKIM.SHOULD) + #: Header fields which should not be signed. The default is from RFC4871. + #: Attempting to sign these headers results in an exception. + #: If it is necessary to sign one of these, it must be removed + #: from this list first. + self.should_not_sign = set(DKIM.SHOULD_NOT) + #: Header fields to sign an extra time to prevent additions. + self.frozen_sign = set(DKIM.FROZEN) + #: Minimum public key size. Shorter keys raise KeyFormatError. The + #: default is 1024 + self.minkey = minkey #: Header fields to protect from additions by default. - #: + #: #: The short list below is the result more of instinct than logic. #: @since: 0.5 FROZEN = (b'from',b'date',b'subject') @@ -412,44 +403,17 @@ class DKIM(object): b'dkim-signature' ) - #: Create a DKIM instance to sign and verify rfc5322 messages. - #: - #: @param message: an RFC822 formatted message to be signed or verified - #: (with either \\n or \\r\\n line endings) - #: @param logger: a logger to which debug info will be written (default None) - #: @param signature_algorithm: the signing algorithm to use when signing - #: @param minkey: the minimum key size to accept - - #: Header fields used by ARC - ARC_HEADERS = (b'arc-seal', b'arc-message-signature', b'arc-authentication-results') - - #: Regex to extract i= value from ARC headers - INSTANCE_RE = re.compile(r'[\s;]?i\s*=\s*(\d+)', re.MULTILINE | re.IGNORECASE) - - def __init__(self,message=None,logger=None,signature_algorithm=b'rsa-sha256', - minkey=1024,arc=False): - self.set_message(message) - if logger is None: - logger = get_default_logger() - self.logger = logger - if signature_algorithm not in HASH_ALGORITHMS: - raise ParameterError( - "Unsupported signature algorithm: "+signature_algorithm) - self.signature_algorithm = signature_algorithm - #: Header fields which should be signed. Default from RFC4871 - self.should_sign = set(DKIM.SHOULD) - #: Header fields which should not be signed. The default is from RFC4871. - #: Attempting to sign these headers results in an exception. - #: If it is necessary to sign one of these, it must be removed - #: from this list first. - self.should_not_sign = set(DKIM.SHOULD_NOT) - #: Header fields to sign an extra time to prevent additions. - self.frozen_sign = set(DKIM.FROZEN) - #: Minimum public key size. Shorter keys raise KeyFormatError. The - #: default is 1024 - self.minkey = minkey - # Is this an ARC signature vice regular DKIM - self.arc = arc + # Doesn't seem to be used (GS) + #: The U{RFC5322} + #: complete list of singleton headers (which should + #: appear at most once). This can be used for a "paranoid" or + #: "strict" signing mode. + #: Bcc in this list is in the SHOULD NOT sign list, the rest could + #: be in the default FROZEN list, but that could also make signatures + #: more fragile than necessary. + #: @since: 0.5 + RFC5322_SINGLETON = (b'date',b'from',b'sender',b'reply-to',b'to',b'cc',b'bcc', + b'message-id',b'in-reply-to',b'references') def add_frozen(self,s): """ Add headers not in should_not_sign to frozen_sign. @@ -504,6 +468,118 @@ class DKIM(object): @since: 0.5""" return [x for x,y in self.headers if x.lower() not in self.should_not_sign] + + # Abstract helper method to generate a tag=value header from a list of fields + #: @param fields: A list of key value tuples to be included in the header + #: @param include_headers: A list message headers to include in the b= signature computation + #: @param canon_policy: A canonicialization policy for b= & bh= + #: @param header_name: The name of the generated header + #: @param pk: The private key used for signature generation + #: @param standardize: Flag to enable 'standard' header syntax + def gen_header(self, fields, include_headers, canon_policy, header_name, pk, standardize=False): + if standardize: + lower = [(x,y.lower().replace(b' ', b'')) for (x,y) in fields if x != b'bh'] + reg = [(x,y.replace(b' ', b'')) for (x,y) in fields if x == b'bh'] + fields = lower + reg + fields = sorted(fields, key=(lambda x: x[0])) + + header_value = b"; ".join(b"=".join(x) for x in fields) + if not standardize: + header_value = fold(header_value) + header_value = RE_BTAG.sub(b'\\1',header_value) + header = (header_name, b' ' + header_value) + h = HashThrough(self.hasher()) + sig = dict(fields) + + headers = canon_policy.canonicalize_headers(self.headers) + self.signed_headers = hash_headers( + h, canon_policy, headers, include_headers, header, sig) + self.logger.debug("sign %s headers: %r" % (header_name, h.hashed())) + + try: + sig2 = RSASSA_PKCS1_v1_5_sign(h, pk) + except DigestTooLargeError: + raise ParameterError("digest too large for modulus") + # Folding b= is explicity allowed, but yahoo and live.com are broken + #header_value += base64.b64encode(bytes(sig2)) + # Instead of leaving unfolded (which lets an MTA fold it later and still + # breaks yahoo and live.com), we change the default signing mode to + # relaxed/simple (for broken receivers), and fold now. + idx = [i for i in range(len(fields)) if fields[i][0] == b'b'][0] + fields[idx] = (b'b', base64.b64encode(bytes(sig2))) + header_value = b"; ".join(b"=".join(x) for x in fields) + b"\r\n" + + if not standardize: + header_value = fold(header_value) + + return header_value + + # Abstract helper method to verify a signed header + #: @param sig: List of (key, value) tuples containing tag=values of the header + #: @param include_headers: headers to validate b= signature against + #: @param sig_header: (header_name, header_value) + #: @param dnsfunc: interface to dns + def verify_sig(self, sig, include_headers, sig_header, dnsfunc): + name = sig[b's'] + b"._domainkey." + sig[b'd'] + b"." + try: + pk, self.keysize = load_pk_from_dns(name, dnsfunc) + except KeyFormatError as e: + self.logger.error("%s" % e) + return False + + try: + canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c', b'relaxed/relaxed')) + except InvalidCanonicalizationPolicyError as e: + raise MessageFormatError("invalid c= value: %s" % e.args[0]) + + hasher = HASH_ALGORITHMS[sig[b'a']] + + # validate body if present + if b'bh' in sig: + h = HashThrough(hasher()) + + body = canon_policy.canonicalize_body(self.body) + if b'l' in sig: + body = body[:int(sig[b'l'])] + h.update(body) + self.logger.debug("body hashed: %r" % h.hashed()) + bodyhash = h.digest() + + self.logger.debug("bh: %s" % base64.b64encode(bodyhash)) + try: + bh = base64.b64decode(re.sub(br"\s+", b"", sig[b'bh'])) + except TypeError as e: + raise MessageFormatError(str(e)) + if bodyhash != bh: + raise ValidationError( + "body hash mismatch (got %s, expected %s)" % + (base64.b64encode(bodyhash), sig[b'bh'])) + + # address bug#644046 by including any additional From header + # fields when verifying. Since there should be only one From header, + # this shouldn't break any legitimate messages. This could be + # generalized to check for extras of other singleton headers. + if b'from' in include_headers: + include_headers.append(b'from') + h = HashThrough(hasher()) + + headers = canon_policy.canonicalize_headers(self.headers) + self.signed_headers = hash_headers( + h, canon_policy, headers, include_headers, sig_header, sig) + self.logger.debug("signed for %s: %r" % (sig_header[0], h.hashed())) + + try: + signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b'])) + res = RSASSA_PKCS1_v1_5_verify(h, signature, pk) + self.logger.debug("%s valid: %s" % (sig_header[0], res)) + if res and self.keysize < self.minkey: + raise KeyFormatError("public key too small: %d" % self.keysize) + return res + except (TypeError,DigestTooLargeError) as e: + raise KeyFormatError("digest too large for modulus: %s"%e) + +#: Hold messages and options during DKIM signing and verification. +class DKIM(DomainSigner): #: Sign an RFC822 message and return the DKIM-Signature header line. #: #: The include_headers option gives full control over which header fields @@ -519,8 +595,8 @@ class DKIM(object): #: without breaking the signature. #: #: The default include_headers for this method differs from the backward - #: compatible sign function, which signs all headers not - #: in should_not_sign. The default list for this method can be modified + #: compatible sign function, which signs all headers not + #: in should_not_sign. The default list for this method can be modified #: by tweaking should_sign and frozen_sign (or even should_not_sign). #: It is only necessary to pass an include_headers list when precise control #: is needed. @@ -532,8 +608,6 @@ class DKIM(object): #: (default "@"+domain) #: @param canonicalize: the canonicalization algorithms to use #: (default (Simple, Simple)) - #: @param auth_results: RFC 7601 Authentication-Results header value for the message - #: @param chain_validation_status (ARC only): CV_Pass, CV_Fail, CV_None #: @param include_headers: a list of strings indicating which headers #: are to be signed (default rfc4871 recommended headers) #: @param length: true if the l= tag should be included to indicate @@ -542,8 +616,7 @@ class DKIM(object): #: @raise DKIMException: when the message, include_headers, or key are badly #: formed. def sign(self, selector, domain, privkey, identity=None, - canonicalize=(b'relaxed',b'simple'), include_headers=None, length=False, - auth_results=None, chain_validation_status=None): + canonicalize=(b'relaxed',b'simple'), include_headers=None, length=False): try: pk = parse_pem_private_key(privkey) except UnparsableKeyError as e: @@ -552,30 +625,28 @@ class DKIM(object): if identity is not None and not identity.endswith(domain): raise ParameterError("identity must end with domain") - if self.arc: - canon_policy = CanonicalizationPolicy.from_c_value(b'relaxed/relaxed') - else: - canon_policy = CanonicalizationPolicy.from_c_value( - b'/'.join(canonicalize)) - headers = canon_policy.canonicalize_headers(self.headers) + canon_policy = CanonicalizationPolicy.from_c_value(b'/'.join(canonicalize)) if include_headers is None: include_headers = self.default_sign_headers() + include_headers = tuple([x.lower() for x in include_headers]) + # record what verify should extract + self.include_headers = include_headers + # rfc4871 says FROM is required - if b'from' not in ( x.lower() for x in include_headers ): + if b'from' not in include_headers: raise ParameterError("The From header field MUST be signed") - # raise exception for any SHOULD_NOT headers, call can modify + # raise exception for any SHOULD_NOT headers, call can modify # SHOULD_NOT if really needed. - for x in include_headers: - if x.lower() in self.should_not_sign: - raise ParameterError("The %s header field SHOULD NOT be signed"%x) + for x in set(include_headers).intersection(self.should_not_sign): + raise ParameterError("The %s header field SHOULD NOT be signed"%x) body = canon_policy.canonicalize_body(self.body) - hasher = HASH_ALGORITHMS[self.signature_algorithm] - h = hasher() + self.hasher = HASH_ALGORITHMS[self.signature_algorithm] + h = self.hasher() h.update(body) bodyhash = base64.b64encode(h.digest()) @@ -593,36 +664,16 @@ class DKIM(object): (b'bh', bodyhash), # Force b= to fold onto it's own line so that refolding after # adding sig doesn't change whitespace for previous tags. - (b'b', b'0'*60), + (b'b', b'0'*60), ] if x] - include_headers = [x.lower() for x in include_headers] - # record what verify should extract - self.include_headers = tuple(include_headers) - sig_value = fold(b"; ".join(b"=".join(x) for x in sigfields)) - sig_value = RE_BTAG.sub(b'\\1',sig_value) - dkim_header = (b'DKIM-Signature', b' ' + sig_value) - h = hasher() - sig = dict(sigfields) - self.signed_headers = hash_headers( - h, canon_policy, headers, include_headers, dkim_header,sig) - self.logger.debug("sign headers: %r" % self.signed_headers) - - try: - sig2 = RSASSA_PKCS1_v1_5_sign(h, pk) - except DigestTooLargeError: - raise ParameterError("digest too large for modulus") - # Folding b= is explicity allowed, but yahoo and live.com are broken - #sig_value += base64.b64encode(bytes(sig2)) - # Instead of leaving unfolded (which lets an MTA fold it later and still - # breaks yahoo and live.com), we change the default signing mode to - # relaxed/simple (for broken receivers), and fold now. - sig_value = fold(sig_value + base64.b64encode(bytes(sig2))) + res = self.gen_header(sigfields, include_headers, canon_policy, + b"DKIM-Signature", pk) self.domain = domain self.selector = selector - self.signature_fields = sig - return b'DKIM-Signature: ' + sig_value + b"\r\n" + self.signature_fields = dict(sigfields) + return b'DKIM-Signature: ' + res #: Verify a DKIM signature. #: @type idx: int @@ -633,7 +684,6 @@ class DKIM(object): #: @return: True if signature verifies or False otherwise #: @raise DKIMException: when the message, signature, or key are badly formed def verify(self,idx=0,dnsfunc=get_txt): - sigheaders = [(x,y) for x,y in self.headers if x.lower() == b"dkim-signature"] if len(sigheaders) <= idx: return False @@ -645,155 +695,32 @@ class DKIM(object): except InvalidTagValueList as e: raise MessageFormatError(e) - logger = self.logger - logger.debug("sig: %r" % sig) + self.logger.debug("sig: %r" % sig) - validate_signature_fields(sig, self.arc) + validate_signature_fields(sig) self.domain = sig[b'd'] self.selector = sig[b's'] - try: - if self.arc: - canon_policy = CanonicalizationPolicy.from_c_value(b'relaxed/relaxed') - else: - canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c')) - except InvalidCanonicalizationPolicyError as e: - raise MessageFormatError("invalid c= value: %s" % e.args[0]) - headers = canon_policy.canonicalize_headers(self.headers) - body = canon_policy.canonicalize_body(self.body) - - try: - hasher = HASH_ALGORITHMS[sig[b'a']] - except KeyError as e: - raise MessageFormatError("unknown signature algorithm: %s" % e.args[0]) - - if b'l' in sig: - body = body[:int(sig[b'l'])] - - h = hasher() - h.update(body) - bodyhash = h.digest() - logger.debug("bh: %s" % base64.b64encode(bodyhash)) - try: - bh = base64.b64decode(re.sub(br"\s+", b"", sig[b'bh'])) - except TypeError as e: - raise MessageFormatError(str(e)) - if bodyhash != bh: - raise ValidationError( - "body hash mismatch (got %s, expected %s)" % - (base64.b64encode(bodyhash), sig[b'bh'])) - - name = sig[b's'] + b"._domainkey." + sig[b'd'] + b"." - s = dnsfunc(name) - if not s: - raise KeyFormatError("missing public key: %s"%name) - try: - if type(s) is str: - s = s.encode('ascii') - pub = parse_tag_value(s) - except InvalidTagValueList as e: - raise KeyFormatError(e) - try: - pk = parse_public_key(base64.b64decode(pub[b'p'])) - self.keysize = bitsize(pk['modulus']) - except KeyError: - raise KeyFormatError("incomplete public key: %s" % s) - except (TypeError,UnparsableKeyError) as e: - raise KeyFormatError("could not parse public key (%s): %s" % (pub[b'p'],e)) include_headers = [x.lower() for x in re.split(br"\s*:\s*", sig[b'h'])] self.include_headers = tuple(include_headers) - # address bug#644046 by including any additional From header - # fields when verifying. Since there should be only one From header, - # this shouldn't break any legitimate messages. This could be - # generalized to check for extras of other singleton headers. - if b'from' in include_headers: - include_headers.append(b'from') - h = hasher() - self.signed_headers = hash_headers( - h, canon_policy, headers, include_headers, sigheaders[idx], sig) - try: - signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b'])) - res = RSASSA_PKCS1_v1_5_verify(h, signature, pk) - if res and self.keysize < self.minkey: - raise KeyFormatError("public key too small: %d" % self.keysize) - return res - except (TypeError,DigestTooLargeError) as e: - raise KeyFormatError("digest too large for modulus: %s"%e) + + return self.verify_sig(sig, include_headers, sigheaders[idx], dnsfunc) #: Hold messages and options during ARC signing and verification. -class ARC(DKIM): +class ARC(DomainSigner): #: Header fields used by ARC ARC_HEADERS = (b'arc-seal', b'arc-message-signature', b'arc-authentication-results') #: Regex to extract i= value from ARC headers - INSTANCE_RE = re.compile(r'[\s;]?i\s*=\s*(\d+)', re.MULTILINE | re.IGNORECASE) - - #: Create an ARC instance to sign and verify rfc5322 messages. - #: - #: @param message: an RFC822 formatted message to be signed or verified - #: (with either \\n or \\r\\n line endings) - #: @param logger: a logger to which debug info will be written (default None) - #: @param signature_algorithm: the signing algorithm to use when signing - #: @param minkey: the minimum key size to accept - def __init__(self,message=None,logger=None,signature_algorithm=b'rsa-sha256', - minkey=1024): - self.set_message(message) - if logger is None: - logger = get_default_logger() - self.logger = logger - if signature_algorithm not in HASH_ALGORITHMS: - raise ParameterError( - "Unsupported signature algorithm: "+signature_algorithm) - self.signature_algorithm = signature_algorithm - #: Header fields which should be signed. Default from RFC4871 - self.should_sign = set(DKIM.SHOULD) - #: Header fields which should not be signed. The default is from RFC4871. - #: Attempting to sign these headers results in an exception. - #: If it is necessary to sign one of these, it must be removed - #: from this list first. - self.should_not_sign = set(DKIM.SHOULD_NOT) - #: Header fields to sign an extra time to prevent additions. - self.frozen_sign = set(DKIM.FROZEN) - #: Minimum public key size. Shorter keys raise KeyFormatError. The - #: default is 1024 - self.minkey = minkey - - #: Load a new message to be signed or verified. - #: @param message: an RFC822 formatted message to be signed or verified - #: (with either \\n or \\r\\n line endings) - def set_message(self,message): - if message: - self.headers, self.body = rfc822_parse(message) - else: - self.headers, self.body = [],'' - - # ARC only supports relaxed/relaxed, so canonicalize now. - canon_policy = CanonicalizationPolicy.from_c_value(b'relaxed/relaxed') - self.headers = canon_policy.canonicalize_headers(self.headers) - self.body = canon_policy.canonicalize_body(self.body) - - def default_sign_headers(self): - """Return the default list of headers to sign: those in should_sign or - frozen_sign, with those in frozen_sign signed an extra time to prevent - additions.""" - hset = self.should_sign | self.frozen_sign - include_headers = [ x for x,y in self.headers - if x.lower() in hset ] - return include_headers + [ x for x in include_headers - if x.lower() in self.frozen_sign] - - def all_sign_headers(self): - """Return header list of all existing headers not in should_not_sign. - @since: 0.5""" - return [x for x,y in self.headers if x.lower() not in self.should_not_sign] + INSTANCE_RE = re.compile(br'[\s;]?i\s*=\s*(\d+)', re.MULTILINE | re.IGNORECASE) def sorted_arc_headers(self): headers = [] - for x,y in self.headers: - print(x,y) + # Use relaxed canonicalization to unfold and clean up headers + relaxed_headers = RelaxedCanonicalization.canonicalize_headers(self.headers) + for x,y in relaxed_headers: if x.lower() in ARC.ARC_HEADERS: - m = ARC.INSTANCE_RE.search(str(y)) - print(m) + m = ARC.INSTANCE_RE.search(y) if m is not None: try: i = int(m.group(1)) @@ -806,16 +733,10 @@ class ARC(DKIM): if len(headers) == 0: return 0, [] - def arc_header_sort(a, b): - if a[0] != b[0]: - return cmp(a[0], b[0]) + def arc_header_key(a): + return [a[0], a[1][0].lower(), a[1][1].lower()] - if a[1][0].lower() != b[1][0].lower(): - return cmp(a[1][0].lower(), b[1][0].lower()) - - return cmp(a[1][1].lower(), b[1][1].lower()) - - headers.sort(arc_header_sort) + headers = sorted(headers, key=arc_header_key) headers.reverse() return headers[0][0], headers @@ -848,141 +769,108 @@ class ARC(DKIM): #: @raise DKIMException: when the message, include_headers, or key are badly #: formed. def sign(self, selector, domain, privkey, auth_results, chain_validation_status, - include_headers=None): + include_headers=None, timestamp=None, standardize=False): try: pk = parse_pem_private_key(privkey) except UnparsableKeyError as e: raise KeyFormatError(str(e)) + # Setup headers + if include_headers is None: + include_headers = self.default_sign_headers() + + if b'arc-authentication-results' not in include_headers: + include_headers.append(b'arc-authentication-results') + + include_headers = tuple([x.lower() for x in include_headers]) + + # record what verify should extract + self.include_headers = include_headers + + # rfc4871 says FROM is required + if b'from' not in include_headers: + raise ParameterError("The From header field MUST be signed") + + # raise exception for any SHOULD_NOT headers, call can modify + # SHOULD_NOT if really needed. + for x in set(include_headers).intersection(self.should_not_sign): + raise ParameterError("The %s header field SHOULD NOT be signed"%x) + max_instance, arc_headers_w_instance = self.sorted_arc_headers() instance = 1 if len(arc_headers_w_instance) != 0: instance = max_instance + 1 - arc_headers = [y for x,y in arc_headers_w_instance] if instance == 1 and chain_validation_status != CV_None: raise ParameterError("No existing chain found on message, cv should be none") elif instance != 1 and chain_validation_status == CV_None: raise ParameterError("cv=none not allowed on instance %d" % instance) new_arc_set = [] - if sys.version_info.major == 3: - aar_value = b"i=" + bytes(instance) + b"; " + bytes(auth_results, 'utf-8') - else: - aar_value = "i=%d; %s" % (instance, auth_results) + arc_headers = [y for x,y in arc_headers_w_instance] + + # Compute ARC-Authentication-Results + aar_value = b"i=%d; %s" % (instance, auth_results) if aar_value[-1] != b'\n': aar_value += b'\r\n' + new_arc_set.append(b"ARC-Authentication-Results: " + aar_value) self.headers.insert(0, (b"arc-authentication-results", aar_value)) arc_headers.insert(0, (b"ARC-Authentication-Results", aar_value)) - # Compute ARC-Message-Signature - + # Compute bh= canon_policy = CanonicalizationPolicy.from_c_value(b'relaxed/relaxed') - headers = canon_policy.canonicalize_headers(self.headers) - if include_headers is None: - include_headers = self.default_sign_headers() - - # rfc4871 says FROM is required - if b'from' not in ( x.lower() for x in include_headers ): - raise ParameterError("The From header field MUST be signed") - - if b'arc-authentication-results' not in ( x.lower() for x in include_headers ): - include_headers.append(b'arc-authentication-results') - - # raise exception for any SHOULD_NOT headers, call can modify - # SHOULD_NOT if really needed. - for x in include_headers: - if x.lower() in self.should_not_sign: - raise ParameterError("The %s header field SHOULD NOT be signed"%x) - - hasher = HASH_ALGORITHMS[self.signature_algorithm] - h = HashThrough(hasher()) - h.update(self.body) + self.hasher = HASH_ALGORITHMS[self.signature_algorithm] + h = HashThrough(self.hasher()) + h.update(canon_policy.canonicalize_body(self.body)) + self.logger.debug("sign ams body hashed: %r" % h.hashed()) bodyhash = base64.b64encode(h.digest()) + # Compute ARC-Message-Signature + timestamp = str(timestamp or int(time.time())).encode('ascii') ams_fields = [x for x in [ (b'i', str(instance).encode('ascii')), (b'a', self.signature_algorithm), + (b'c', b'relaxed/relaxed'), (b'd', domain), (b's', selector), - (b't', str(int(time.time())).encode('ascii')), + (b't', timestamp), (b'h', b" : ".join(include_headers)), (b'bh', bodyhash), # Force b= to fold onto it's own line so that refolding after # adding sig doesn't change whitespace for previous tags. (b'b', b'0'*60), ] if x] - include_headers = [x.lower() for x in include_headers] - # record what verify should extract - self.include_headers = tuple(include_headers) - ams_value = fold(b"; ".join(b"=".join(x) for x in ams_fields)) - ams_value = RE_BTAG.sub(b'\\1',ams_value) - ams_header = (b'ARC-Message-Signature', b' ' + ams_value) - h = HashThrough(hasher()) - sig = dict(ams_fields) - self.signed_headers = hash_headers( - h, canon_policy, headers, include_headers, ams_header,sig) - self.logger.debug("ams sign headers: %r" % self.signed_headers) - self.logger.debug("ams hashed: %r" % h.hashed()) + res = self.gen_header(ams_fields, include_headers, canon_policy, + b"ARC-Message-Signature", pk, standardize) - try: - sig2 = RSASSA_PKCS1_v1_5_sign(h, pk) - except DigestTooLargeError: - raise ParameterError("digest too large for modulus") - # Folding b= is explicity allowed, but yahoo and live.com are broken - #ams_value += base64.b64encode(bytes(sig2)) - # Instead of leaving unfolded (which lets an MTA fold it later and still - # breaks yahoo and live.com), we change the default signing mode to - # relaxed/simple (for broken receivers), and fold now. - ams_value = fold(ams_value + base64.b64encode(bytes(sig2))) + b"\r\n" - - new_arc_set.append(b"ARC-Message-Signature: " + ams_value) - self.headers.insert(0, (b"ARC-Message-Signature", ams_value)) - arc_headers.insert(0, (b"ARC-Message-Signature", ams_value)) + new_arc_set.append(b"ARC-Message-Signature: " + res) + self.headers.insert(0, (b"ARC-Message-Signature", res)) + arc_headers.insert(0, (b"ARC-Message-Signature", res)) # Compute ARC-Seal - as_fields = [x for x in [ (b'i', str(instance).encode('ascii')), (b'cv', chain_validation_status), (b'a', self.signature_algorithm), (b'd', domain), (b's', selector), - (b't', str(int(time.time())).encode('ascii')), + (b't', timestamp), # Force b= to fold onto it's own line so that refolding after # adding sig doesn't change whitespace for previous tags. (b'b', b'0'*60), ] if x] + as_include_headers = [x[0].lower() for x in arc_headers] as_include_headers.reverse() - as_headers = canon_policy.canonicalize_headers(arc_headers) - as_value = fold(b"; ".join(b"=".join(x) for x in as_fields)) - as_value = RE_BTAG.sub(b'\\1',as_value) - as_header = (b'ARC-Seal', b' ' + as_value) - h = HashThrough(hasher()) - sig = dict(as_fields) - as_signed_headers = hash_headers( - h, canon_policy, as_headers, as_include_headers, as_header,sig) - self.logger.debug("arc-seal sign headers: %r" % as_signed_headers) - self.logger.debug("arc-seal hashed: %r" % h.hashed()) + res = self.gen_header(as_fields, as_include_headers, canon_policy, + b"ARC-Seal", pk, standardize) - try: - sig2 = RSASSA_PKCS1_v1_5_sign(h, pk) - except DigestTooLargeError: - raise ParameterError("digest too large for modulus") - # Folding b= is explicity allowed, but yahoo and live.com are broken - #as_value += base64.b64encode(bytes(sig2)) - # Instead of leaving unfolded (which lets an MTA fold it later and still - # breaks yahoo and live.com), we change the default signing mode to - # relaxed/simple (for broken receivers), and fold now. - as_value = fold(as_value + base64.b64encode(bytes(sig2))) + b"\r\n" - - new_arc_set.append(b"ARC-Seal: " + as_value) - self.headers.insert(0, (b"ARC-Seal", as_value)) - arc_headers.insert(0, (b"ARC-Seal", as_value)) + new_arc_set.append(b"ARC-Seal: " + res) + self.headers.insert(0, (b"ARC-Seal", res)) + arc_headers.insert(0, (b"ARC-Seal", res)) new_arc_set.reverse() @@ -1025,25 +913,6 @@ class ARC(DKIM): return CV_Fail, result_data, "ARC-Seal[%d] reported invalid status %s" % (result['instance'], result['cv']) return CV_Pass, result_data, "success" - def load_pk_from_dns(self, name, dnsfunc=get_txt): - s = dnsfunc(name) - if not s: - raise KeyFormatError("missing public key: %s"%name) - try: - if type(s) is str: - s = s.encode('ascii') - pub = parse_tag_value(s) - except InvalidTagValueList as e: - raise KeyFormatError(e) - try: - pk = parse_public_key(base64.b64decode(pub[b'p'])) - keysize = bitsize(pk['modulus']) - except KeyError: - raise KeyFormatError("incomplete public key: %s" % s) - except (TypeError,UnparsableKeyError) as e: - raise KeyFormatError("could not parse public key (%s): %s" % (pub[b'p'],e)) - return pk, keysize - #: Verify an ARC set. #: @type arc_headers_w_instance: list #: @param arc_headers_w_instance: list of tuples, (instance, (name, value)) of @@ -1093,56 +962,21 @@ class ARC(DKIM): except InvalidTagValueList as e: raise MessageFormatError(e) - logger = self.logger - logger.debug("ams sig[%d]: %r" % (instance, sig)) + self.logger.debug("ams sig[%d]: %r" % (instance, sig)) - validate_arc_signature_fields(sig) + validate_signature_fields(sig, [b'i', b'a', b'b', b'c', b'bh', b'd', b'h', b's'], True) output['ams-domain'] = sig[b'd'] output['ams-selector'] = sig[b's'] - # TODO(blong): only hash the body once per algorithm - try: - hasher = HASH_ALGORITHMS[sig[b'a']] - except KeyError as e: - raise MessageFormatError("unknown signature algorithm: %s" % e.args[0]) - - h = hasher() - h.update(self.body) - bodyhash = h.digest() - logger.debug("bh: %s" % base64.b64encode(bodyhash)) - try: - bh = base64.b64decode(re.sub(br"\s+", b"", sig[b'bh'])) - except TypeError as e: - raise MessageFormatError(str(e)) - if bodyhash != bh: - raise ValidationError( - "body hash mismatch (got %s, expected %s)" % - (base64.b64encode(bodyhash), sig[b'bh'])) - - name = sig[b's'] + b"._domainkey." + sig[b'd'] + b"." - pk, keysize = self.load_pk_from_dns(name, dnsfunc) - output['ams-keysize'] = keysize include_headers = [x.lower() for x in re.split(br"\s*:\s*", sig[b'h'])] - # address bug#644046 by including any additional From header - # fields when verifying. Since there should be only one From header, - # this shouldn't break any legitimate messages. This could be - # generalized to check for extras of other singleton headers. - if b'from' in include_headers: - include_headers.append(b'from') - h = HashThrough(hasher()) - canon_policy = CanonicalizationPolicy.from_c_value(b'relaxed/relaxed') + if b'arc-seal' in include_headers: + raise ParameterError("The Arc-Message-Signature MUST NOT sign ARC-Seal") + ams_header = (b'ARC-Message-Signature', b' ' + ams_value) - hash_headers(h, canon_policy, self.headers, include_headers, ams_header, sig) - logger.debug("ams hashed: %r" % h.hashed()) - ams_valid = False - try: - signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b'])) - ams_valid = RSASSA_PKCS1_v1_5_verify(h, signature, pk) - if ams_valid and keysize < self.minkey: - raise KeyFormatError("public key too small: %d" % keysize) - except (TypeError,DigestTooLargeError) as e: - raise KeyFormatError("digest too large for modulus: %s"%e) + ams_valid = self.verify_sig(sig, include_headers, ams_header, dnsfunc) + output['ams-valid'] = ams_valid + self.logger.debug("ams valid: %r" % ams_valid) # Validate Arc-Seal try: @@ -1150,39 +984,21 @@ class ARC(DKIM): except InvalidTagValueList as e: raise MessageFormatError(e) - logger.debug("as sig[%d]: %r" % (instance, sig)) + self.logger.debug("as sig[%d]: %r" % (instance, sig)) - validate_arc_seal_fields(sig) + validate_signature_fields(sig, [b'i', b'a', b'b', b'cv', b'd', b's', b't'], True) output['as-domain'] = sig[b'd'] output['as-selector'] = sig[b's'] output['cv'] = sig[b'cv'] - try: - hasher = HASH_ALGORITHMS[sig[b'a']] - except KeyError as e: - raise MessageFormatError("unknown signature algorithm: %s" % e.args[0]) - - name = sig[b's'] + b"._domainkey." + sig[b'd'] + b"." - pk, keysize = self.load_pk_from_dns(name, dnsfunc) - output['as-keysize'] = keysize as_include_headers = [x[0].lower() for x in arc_headers] as_include_headers.reverse() as_header = (b'ARC-Seal', b' ' + as_value) - h = HashThrough(hasher()) - signed_headers = hash_headers( - h, canon_policy, arc_headers, as_include_headers[:-1], as_header, sig) - logger.debug("as hashed: %r" % h.hashed()) - as_valid = False - try: - signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b'])) - as_valid = RSASSA_PKCS1_v1_5_verify(h, signature, pk) - if as_valid and keysize < self.minkey: - raise KeyFormatError("public key too small: %d" % keysize) - except (TypeError,DigestTooLargeError) as e: - raise KeyFormatError("digest too large for modulus: %s"%e) - output['as-valid'] = as_valid - return output + as_valid = self.verify_sig(sig, as_include_headers[:-1], as_header, dnsfunc) + output['as-valid'] = as_valid + self.logger.debug("as valid: %r" % as_valid) + return output def sign(message, selector, domain, privkey, identity=None, canonicalize=(b'relaxed', b'simple'), @@ -1204,8 +1020,6 @@ def sign(message, selector, domain, privkey, identity=None, """ d = DKIM(message,logger=logger,signature_algorithm=signature_algorithm) - if not include_headers: - include_headers = d.default_sign_headers() return d.sign(selector, domain, privkey, identity=identity, canonicalize=canonicalize, include_headers=include_headers, length=length) def verify(message, logger=None, dnsfunc=get_txt, minkey=1024): @@ -1227,9 +1041,10 @@ dkim_sign = sign dkim_verify = verify def arc_sign(message, selector, domain, privkey, - auth_results, chain_validation_status, - signature_algorithm=b'rsa-sha256', - include_headers=None, logger=None): + auth_results, chain_validation_status, + signature_algorithm=b'rsa-sha256', + include_headers=None, timestamp=None, + logger=None, standardize=False): """Sign an RFC822 message and return the ARC set header lines for the next instance @param message: an RFC822 formatted message (with either \\n or \\r\\n line endings) @param selector: the DKIM selector value for the signature @@ -1246,7 +1061,8 @@ def arc_sign(message, selector, domain, privkey, a = ARC(message,logger=logger,signature_algorithm=signature_algorithm) if not include_headers: include_headers = a.default_sign_headers() - return a.sign(selector, domain, privkey, auth_results, chain_validation_status, include_headers=include_headers) + return a.sign(selector, domain, privkey, auth_results, chain_validation_status, + include_headers=include_headers, timestamp=timestamp, standardize=standardize) def arc_verify(message, logger=None, dnsfunc=get_txt, minkey=1024): """Verify the ARC chain on an RFC822 formatted message. @@ -1264,3 +1080,4 @@ def arc_verify(message, logger=None, dnsfunc=get_txt, minkey=1024): if logger is not None: logger.error("%s" % x) return CV_Fail, [], "%s" % x + diff --git a/dkim/tests/test_arc.py b/dkim/tests/test_arc.py index a7fdf42..1aea21f 100644 --- a/dkim/tests/test_arc.py +++ b/dkim/tests/test_arc.py @@ -72,7 +72,7 @@ Y+vtSBczUiKERHv1yRbcaQtZFh5wtiRrN04BLUTD21MycBX5jYchHjPY/wIDAQAB""" # A message verifies after being signed. sig_lines = dkim.arc_sign( self.message, b"test", b"example.com", self.key, - "test.domain: none", dkim.CV_None) + b"test.domain: none", dkim.CV_None) (cv, res, reason) = dkim.arc_verify(b''.join(sig_lines) + self.message, dnsfunc=self.dnsfunc) self.assertEquals(cv, dkim.CV_Pass) diff --git a/dkim/util.py b/dkim/util.py index 5332127..a0545b0 100644 --- a/dkim/util.py +++ b/dkim/util.py @@ -16,6 +16,8 @@ # # Copyright (c) 2011 William Grant +import re + import logging try: from logging import NullHandler @@ -61,12 +63,14 @@ def parse_tag_value(tag_list): tag_specs.pop() for tag_spec in tag_specs: try: - key, value = tag_spec.split(b'=', 1) + key, value = [x.strip() for x in tag_spec.split(b'=', 1)] except ValueError: raise InvalidTagSpec(tag_spec) - if key.strip() in tags: - raise DuplicateTag(key.strip()) - tags[key.strip()] = value.strip() + if re.match(br'^[a-zA-Z](\w)*', key) is None: + raise InvalidTagSpec(tag_spec) + if key in tags: + raise DuplicateTag(key) + tags[key] = value return tags diff --git a/test.py b/test.py index 1f1b190..3dd9baa 100644 --- a/test.py +++ b/test.py @@ -2,6 +2,9 @@ import unittest import doctest import dkim from dkim.tests import test_suite +from dkim.tests.test_arc import test_suite as arc_test_suite +import logging doctest.testmod(dkim) unittest.TextTestRunner().run(test_suite()) +unittest.TextTestRunner().run(arc_test_suite()) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index f10eb39..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# This software is provided 'as-is', without any express or implied -# warranty. In no event will the author be held liable for any damages -# arising from the use of this software. -# -# Permission is granted to anyone to use this software for any purpose, -# including commercial applications, and to alter it and redistribute it -# freely, subject to the following restrictions: -# -# 1. The origin of this software must not be misrepresented; you must not -# claim that you wrote the original software. If you use this software -# in a product, an acknowledgment in the product documentation would be -# appreciated but is not required. -# 2. Altered source versions must be plainly marked as such, and must not be -# misrepresented as being the original software. -# 3. This notice may not be removed or altered from any source distribution. -# -# Copyright (c) 2011 William Grant -# -# This has been modified from the original software. -# Copyright (c) 2016 Google, Inc. -# Contact: Brandon Long - -import unittest - - -def test_suite(): - from arc.tests import ( - test_arc, - ) - modules = [ - test_arc, - ] - suites = [x.test_suite() for x in modules] - return unittest.TestSuite(suites) diff --git a/tests/test_arc.py b/tests/test_arc.py deleted file mode 100644 index b3bc146..0000000 --- a/tests/test_arc.py +++ /dev/null @@ -1,148 +0,0 @@ -# This software is provided 'as-is', without any express or implied -# warranty. In no event will the author be held liable for any damages -# arising from the use of this software. -# -# Permission is granted to anyone to use this software for any purpose, -# including commercial applications, and to alter it and redistribute it -# freely, subject to the following restrictions: -# -# 1. The origin of this software must not be misrepresented; you must not -# claim that you wrote the original software. If you use this software -# in a product, an acknowledgment in the product documentation would be -# appreciated but is not required. -# 2. Altered source versions must be plainly marked as such, and must not be -# misrepresented as being the original software. -# 3. This notice may not be removed or altered from any source distribution. -# -# Copyright (c) 2011 William Grant -# -# This has been modified from the original software. -# Copyright (c) 2016 Google, Inc. -# Contact: Brandon Long - -import os.path -import unittest -import time - -import arc - - -def read_test_data(filename): - """Get the content of the given test data file. - - The files live in dkim/tests/data. - """ - path = os.path.join(os.path.dirname(__file__), '../../dkim/tests/data', filename) - with open(path, 'rb') as f: - return f.read() - - -class TestSignAndVerify(unittest.TestCase): - """End-to-end signature and verification tests.""" - - def setUp(self): - self.message = read_test_data("test.message") - self.key = read_test_data("test.private") - - def dnsfunc(self, domain): - sample_dns = """\ -k=rsa; \ -p=MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANmBe10IgY+u7h3enWTukkqtUD5PR52T\ -b/mPfjC0QJTocVBq6Za/PlzfV+Py92VaCak19F4WrbVTK5Gg5tW220MCAwEAAQ==""" - - _dns_responses = { - 'example._domainkey.canonical.com.': sample_dns, - 'test._domainkey.example.com.': read_test_data("test.txt"), - # dnsfunc returns empty if no txt record - 'missing._domainkey.example.com.': '', - '20120113._domainkey.gmail.com.': """k=rsa; \ -p=MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1Kd87/UeJjenpabgbFwh\ -+eBCsSTrqmwIYYvywlbhbqoo2DymndFkbjOVIPIldNs/m40KF+yzMn1skyoxcTUGCQ\ -s8g3FgD2Ap3ZB5DekAo5wMmk4wimDO+U8QzI3SD07y2+07wlNWwIt8svnxgdxGkVbb\ -hzY8i+RQ9DpSVpPbF7ykQxtKXkv/ahW3KjViiAH+ghvvIhkx4xYSIc9oSwVmAl5Oct\ -MEeWUwg8Istjqz8BZeTWbf41fbNhte7Y+YqZOwq1Sd0DbvYAD9NOZK9vlfuac0598H\ -Y+vtSBczUiKERHv1yRbcaQtZFh5wtiRrN04BLUTD21MycBX5jYchHjPY/wIDAQAB""" - } - try: - domain = domain.decode('ascii') - except UnicodeDecodeError: - return None - self.assertTrue(domain in _dns_responses,domain) - return _dns_responses[domain] - - def test_verifies(self): - # A message verifies after being signed. - sig_lines = arc.sign( - self.message, b"test", b"example.com", self.key, - "test.domain: none", arc.CV_None) - (cv, res, reason) = arc.verify(''.join(sig_lines) + self.message, dnsfunc=self.dnsfunc) - self.assertEquals(cv, arc.CV_Pass) - - def test_multiple_instances_verify(self): - # A message verifies after being signed multiple times. - message = self.message - sig_lines = arc.sign( - message, b"test", b"example.com", self.key, - "test.domain: none", arc.CV_None) - message = ''.join(sig_lines) + message - (cv, res, reason) = arc.verify(message, dnsfunc=self.dnsfunc) - self.assertEquals(cv, arc.CV_Pass) - - for x in range(10): - sig_lines = arc.sign( - message, b"test", b"example.com", self.key, - "test.domain: arc=pass", arc.CV_Pass) - message = ''.join(sig_lines) + message - (cv, res, reason) = arc.verify(message, dnsfunc=self.dnsfunc) - self.assertEquals(cv, arc.CV_Pass) - - def test_multiple_instances_verify_fail(self): - # A message return CV_Fail if signed as failure. - message = self.message - sig_lines = arc.sign( - message, b"test", b"example.com", self.key, - "test.domain: none", arc.CV_None) - message = ''.join(sig_lines) + message - (cv, res, reason) = arc.verify(message, dnsfunc=self.dnsfunc) - self.assertEquals(cv, arc.CV_Pass) - - sig_lines = arc.sign( - message, b"test", b"example.com", self.key, - "test.domain: arc=pass", arc.CV_Fail) - message = ''.join(sig_lines) + message - # A conforming signer wouldn't sign as pass after a fail. - sig_lines = arc.sign( - message, b"test", b"example.com", self.key, - "test.domain: arc=pass", arc.CV_Pass) - message = ''.join(sig_lines) + message - - (cv, res, reason) = arc.verify(message, dnsfunc=self.dnsfunc) - self.assertEquals(cv, arc.CV_Fail) - - def test_altered_body_fails(self): - # An altered body fails verification. - sig_lines = arc.sign( - self.message, b"test", b"example.com", self.key, - "test.domain: none", arc.CV_None) - (cv, res, reason) = arc.verify(''.join(sig_lines) + self.message + b"foo", dnsfunc=self.dnsfunc) - self.assertEquals(cv, arc.CV_Fail) - - def test_dns_pk_mismatch_fails(self): - # DNS public key doesn't match signing private key. - sig_lines = arc.sign( - self.message, b"example", b"canonical.com", self.key, - "test.domain: none", arc.CV_None) - (cv, res, reason) = arc.verify(''.join(sig_lines) + self.message, dnsfunc=self.dnsfunc) - self.assertEquals(cv, arc.CV_Fail) - - def test_dns_missing_fails(self): - # DNS public key missing fails verify - sig_lines = arc.sign( - self.message, b"missing", b"example.com", self.key, - "test.domain: none", arc.CV_None) - (cv, res, reason) = arc.verify(''.join(sig_lines) + self.message, dnsfunc=self.dnsfunc) - self.assertEquals(cv, arc.CV_Fail) - -def test_suite(): - from unittest import TestLoader - return TestLoader().loadTestsFromName(__name__)