Extract c= value manipulation into a new (tested) CanonicalizationPolicy.

This commit is contained in:
William Grant
2011-06-04 14:40:06 +10:00
3 changed files with 118 additions and 30 deletions
+12 -27
View File
@@ -25,7 +25,7 @@ import logging
import re import re
import time import time
from dkim.canonicalization import algorithms from dkim.canonicalization import CanonicalizationPolicy
from dkim.crypto import ( from dkim.crypto import (
DigestTooLargeError, DigestTooLargeError,
HASH_ALGORITHMS, HASH_ALGORITHMS,
@@ -231,7 +231,9 @@ def sign(message, selector, domain, privkey, identity=None,
if identity is not None and not identity.endswith(domain): if identity is not None and not identity.endswith(domain):
raise ParameterError("identity must end with domain") raise ParameterError("identity must end with domain")
headers = algorithms[canonicalize[0]].canonicalize_headers(headers) canon_policy = CanonicalizationPolicy.from_c_value(
b'/'.join(canonicalize))
headers = canon_policy.canonicalize_headers(headers)
if include_headers is None: if include_headers is None:
include_headers = [x[0].lower() for x in headers] include_headers = [x[0].lower() for x in headers]
@@ -239,7 +241,7 @@ def sign(message, selector, domain, privkey, identity=None,
include_headers = [x.lower() for x in include_headers] include_headers = [x.lower() for x in include_headers]
sign_headers = [x for x in headers if x[0].lower() in include_headers] sign_headers = [x for x in headers if x[0].lower() in include_headers]
body = algorithms[canonicalize[1]].canonicalize_body(body) body = canon_policy.canonicalize_body(body)
h = hashlib.sha256() h = hashlib.sha256()
h.update(body) h.update(body)
@@ -248,9 +250,7 @@ def sign(message, selector, domain, privkey, identity=None,
sigfields = [x for x in [ sigfields = [x for x in [
(b'v', b"1"), (b'v', b"1"),
(b'a', signature_algorithm), (b'a', signature_algorithm),
(b'c', b"/".join( (b'c', canon_policy.to_c_value()),
(algorithms[canonicalize[0]].name,
algorithms[canonicalize[1]].name))),
(b'd', domain), (b'd', domain),
(b'i', identity or b"@"+domain), (b'i', identity or b"@"+domain),
length and (b'l', len(body)), length and (b'l', len(body)),
@@ -263,7 +263,7 @@ def sign(message, selector, domain, privkey, identity=None,
] if x] ] if x]
sig_value = fold(b"; ".join(b"=".join(x) for x in sigfields)) sig_value = fold(b"; ".join(b"=".join(x) for x in sigfields))
dkim_header = algorithms[canonicalize[0]].canonicalize_headers([ dkim_header = canon_policy.canonicalize_headers([
[b'DKIM-Signature', b' ' + sig_value]])[0] [b'DKIM-Signature', b' ' + sig_value]])[0]
# the dkim sig is hashed with no trailing crlf, even if the # the dkim sig is hashed with no trailing crlf, even if the
# canonicalization algorithm would add one. # canonicalization algorithm would add one.
@@ -317,25 +317,11 @@ def verify(message, logger=None, dnsfunc=get_txt):
logger.error("signature fields failed to validate: %s" % e) logger.error("signature fields failed to validate: %s" % e)
return False return False
m = re.match(b"(\w+)(?:/(\w+))?$", sig[b'c']) canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c'), logger)
if m is None: if canon_policy is None:
logger.error(
"c= value is not in format method/method (%s)" % sig[b'c'])
return False return False
can_headers = m.group(1) headers = canon_policy.canonicalize_headers(headers)
if m.group(2) is not None: body = canon_policy.canonicalize_body(body)
can_body = m.group(2)
else:
can_body = b"simple"
try:
header_algorithm = algorithms[can_headers]
body_algorithm = algorithms[can_body]
except KeyError as e:
logger.error("unknown canonicalization algorithm: %s" % e.message)
return False
headers = header_algorithm.canonicalize_headers(headers)
body = body_algorithm.canonicalize_body(body)
try: try:
hasher = HASH_ALGORITHMS[sig[b'a']] hasher = HASH_ALGORITHMS[sig[b'a']]
@@ -372,8 +358,7 @@ def verify(message, logger=None, dnsfunc=get_txt):
include_headers = re.split(br"\s*:\s*", sig[b'h']) include_headers = re.split(br"\s*:\s*", sig[b'h'])
h = hasher() h = hasher()
hash_headers( hash_headers(h, canon_policy, headers, include_headers, sigheaders, sig)
h, header_algorithm, headers, include_headers, sigheaders, sig)
signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b'])) signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b']))
try: try:
return RSASSA_PKCS1_v1_5_verify( return RSASSA_PKCS1_v1_5_verify(
+48 -2
View File
@@ -22,7 +22,7 @@
import re import re
__all__ = [ __all__ = [
'algorithms', 'CanonicalizationPolicy',
] ]
@@ -83,4 +83,50 @@ class Relaxed:
compress_whitespace(strip_trailing_whitespace(body))) compress_whitespace(strip_trailing_whitespace(body)))
algorithms = dict((c.name, c) for c in (Simple, Relaxed)) class CanonicalizationPolicy:
def __init__(self, header_algorithm, body_algorithm):
self.header_algorithm = header_algorithm
self.body_algorithm = body_algorithm
@classmethod
def from_c_value(cls, c, logger=None):
"""Construct the canonicalization policy described by a c= value.
@param c: c= value from a DKIM-Signature header field
@return: a C{CanonicalizationPolicy}, or C{None} if the value is
invalid
"""
if c is None:
c = b'simple/simple'
m = c.split(b'/')
if len(m) not in (1, 2):
if logger:
logger.error(
"c= value is not in format method/method: %s" % c)
return None
if len(m) == 1:
m.append(b'simple')
can_headers, can_body = m
try:
header_algorithm = ALGORITHMS[can_headers]
body_algorithm = ALGORITHMS[can_body]
except KeyError as e:
if logger:
logger.error(
"unknown canonicalization algorithm: %s" % e.message)
return None
return cls(header_algorithm, body_algorithm)
def to_c_value(self):
return b'/'.join(
(self.header_algorithm.name, self.body_algorithm.name))
def canonicalize_headers(self, headers):
return self.header_algorithm.canonicalize_headers(headers)
def canonicalize_body(self, body):
return self.body_algorithm.canonicalize_body(body)
ALGORITHMS = dict((c.name, c) for c in (Simple, Relaxed))
+58 -1
View File
@@ -18,7 +18,11 @@
import unittest import unittest
from dkim.canonicalization import Simple, Relaxed from dkim.canonicalization import (
CanonicalizationPolicy,
Simple,
Relaxed,
)
class BaseCanonicalizationTest(unittest.TestCase): class BaseCanonicalizationTest(unittest.TestCase):
@@ -94,6 +98,59 @@ class TestRelaxedAlgorithmBody(BaseCanonicalizationTest):
b'Foo\r\nbar\r\n\r\n\r\n') b'Foo\r\nbar\r\n\r\n\r\n')
class TestCanonicalizationPolicyFromCValue(unittest.TestCase):
def assertAlgorithms(self, header_algo, body_algo, c_value):
p = CanonicalizationPolicy.from_c_value(c_value)
self.assertEqual(
(header_algo, body_algo),
(p.header_algorithm, p.body_algorithm))
def assertValueDoesNotParse(self, c_value):
self.assertIs(None, CanonicalizationPolicy.from_c_value(c_value))
def test_both_default_to_simple(self):
self.assertAlgorithms(Simple, Simple, None)
def test_relaxed_headers(self):
self.assertAlgorithms(Relaxed, Simple, b'relaxed')
def test_relaxed_body(self):
self.assertAlgorithms(Simple, Relaxed, b'simple/relaxed')
def test_relaxed_both(self):
self.assertAlgorithms(Relaxed, Relaxed, b'relaxed/relaxed')
def test_explict_simple_both(self):
self.assertAlgorithms(Simple, Simple, b'simple/simple')
def test_corruption_is_ignored(self):
self.assertValueDoesNotParse(b'')
self.assertValueDoesNotParse(b'simple/simple/simple')
self.assertValueDoesNotParse(b'relaxed/stressed')
self.assertValueDoesNotParse(b'worried')
class TestCanonicalizationpolicyToCValue(unittest.TestCase):
def assertCValue(self, c_value, header_algo, body_algo):
self.assertEqual(
c_value,
CanonicalizationPolicy(header_algo, body_algo).to_c_value())
def test_both_simple(self):
self.assertCValue(b'simple/simple', Simple, Simple)
def test_relaxed_body(self):
self.assertCValue(b'simple/relaxed', Simple, Relaxed)
def test_both_relaxed(self):
self.assertCValue(b'relaxed/relaxed', Relaxed, Relaxed)
def test_relaxed_headers(self):
self.assertCValue(b'relaxed/simple', Relaxed, Simple)
def test_suite(): def test_suite():
from unittest import TestLoader from unittest import TestLoader
return TestLoader().loadTestsFromName(__name__) return TestLoader().loadTestsFromName(__name__)