Extract c= value manipulation into a new (tested) CanonicalizationPolicy.
This commit is contained in:
+12
-27
@@ -25,7 +25,7 @@ import logging
|
||||
import re
|
||||
import time
|
||||
|
||||
from dkim.canonicalization import algorithms
|
||||
from dkim.canonicalization import CanonicalizationPolicy
|
||||
from dkim.crypto import (
|
||||
DigestTooLargeError,
|
||||
HASH_ALGORITHMS,
|
||||
@@ -231,7 +231,9 @@ def sign(message, selector, domain, privkey, identity=None,
|
||||
if identity is not None and not identity.endswith(domain):
|
||||
raise ParameterError("identity must end with domain")
|
||||
|
||||
headers = algorithms[canonicalize[0]].canonicalize_headers(headers)
|
||||
canon_policy = CanonicalizationPolicy.from_c_value(
|
||||
b'/'.join(canonicalize))
|
||||
headers = canon_policy.canonicalize_headers(headers)
|
||||
|
||||
if include_headers is None:
|
||||
include_headers = [x[0].lower() for x in headers]
|
||||
@@ -239,7 +241,7 @@ def sign(message, selector, domain, privkey, identity=None,
|
||||
include_headers = [x.lower() for x in include_headers]
|
||||
sign_headers = [x for x in headers if x[0].lower() in include_headers]
|
||||
|
||||
body = algorithms[canonicalize[1]].canonicalize_body(body)
|
||||
body = canon_policy.canonicalize_body(body)
|
||||
|
||||
h = hashlib.sha256()
|
||||
h.update(body)
|
||||
@@ -248,9 +250,7 @@ def sign(message, selector, domain, privkey, identity=None,
|
||||
sigfields = [x for x in [
|
||||
(b'v', b"1"),
|
||||
(b'a', signature_algorithm),
|
||||
(b'c', b"/".join(
|
||||
(algorithms[canonicalize[0]].name,
|
||||
algorithms[canonicalize[1]].name))),
|
||||
(b'c', canon_policy.to_c_value()),
|
||||
(b'd', domain),
|
||||
(b'i', identity or b"@"+domain),
|
||||
length and (b'l', len(body)),
|
||||
@@ -263,7 +263,7 @@ def sign(message, selector, domain, privkey, identity=None,
|
||||
] if x]
|
||||
|
||||
sig_value = fold(b"; ".join(b"=".join(x) for x in sigfields))
|
||||
dkim_header = algorithms[canonicalize[0]].canonicalize_headers([
|
||||
dkim_header = canon_policy.canonicalize_headers([
|
||||
[b'DKIM-Signature', b' ' + sig_value]])[0]
|
||||
# the dkim sig is hashed with no trailing crlf, even if the
|
||||
# canonicalization algorithm would add one.
|
||||
@@ -317,25 +317,11 @@ def verify(message, logger=None, dnsfunc=get_txt):
|
||||
logger.error("signature fields failed to validate: %s" % e)
|
||||
return False
|
||||
|
||||
m = re.match(b"(\w+)(?:/(\w+))?$", sig[b'c'])
|
||||
if m is None:
|
||||
logger.error(
|
||||
"c= value is not in format method/method (%s)" % sig[b'c'])
|
||||
canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c'), logger)
|
||||
if canon_policy is None:
|
||||
return False
|
||||
can_headers = m.group(1)
|
||||
if m.group(2) is not None:
|
||||
can_body = m.group(2)
|
||||
else:
|
||||
can_body = b"simple"
|
||||
|
||||
try:
|
||||
header_algorithm = algorithms[can_headers]
|
||||
body_algorithm = algorithms[can_body]
|
||||
except KeyError as e:
|
||||
logger.error("unknown canonicalization algorithm: %s" % e.message)
|
||||
return False
|
||||
headers = header_algorithm.canonicalize_headers(headers)
|
||||
body = body_algorithm.canonicalize_body(body)
|
||||
headers = canon_policy.canonicalize_headers(headers)
|
||||
body = canon_policy.canonicalize_body(body)
|
||||
|
||||
try:
|
||||
hasher = HASH_ALGORITHMS[sig[b'a']]
|
||||
@@ -372,8 +358,7 @@ def verify(message, logger=None, dnsfunc=get_txt):
|
||||
|
||||
include_headers = re.split(br"\s*:\s*", sig[b'h'])
|
||||
h = hasher()
|
||||
hash_headers(
|
||||
h, header_algorithm, headers, include_headers, sigheaders, sig)
|
||||
hash_headers(h, canon_policy, headers, include_headers, sigheaders, sig)
|
||||
signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b']))
|
||||
try:
|
||||
return RSASSA_PKCS1_v1_5_verify(
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
import re
|
||||
|
||||
__all__ = [
|
||||
'algorithms',
|
||||
'CanonicalizationPolicy',
|
||||
]
|
||||
|
||||
|
||||
@@ -83,4 +83,50 @@ class Relaxed:
|
||||
compress_whitespace(strip_trailing_whitespace(body)))
|
||||
|
||||
|
||||
algorithms = dict((c.name, c) for c in (Simple, Relaxed))
|
||||
class CanonicalizationPolicy:
|
||||
|
||||
def __init__(self, header_algorithm, body_algorithm):
|
||||
self.header_algorithm = header_algorithm
|
||||
self.body_algorithm = body_algorithm
|
||||
|
||||
@classmethod
|
||||
def from_c_value(cls, c, logger=None):
|
||||
"""Construct the canonicalization policy described by a c= value.
|
||||
|
||||
@param c: c= value from a DKIM-Signature header field
|
||||
@return: a C{CanonicalizationPolicy}, or C{None} if the value is
|
||||
invalid
|
||||
"""
|
||||
if c is None:
|
||||
c = b'simple/simple'
|
||||
m = c.split(b'/')
|
||||
if len(m) not in (1, 2):
|
||||
if logger:
|
||||
logger.error(
|
||||
"c= value is not in format method/method: %s" % c)
|
||||
return None
|
||||
if len(m) == 1:
|
||||
m.append(b'simple')
|
||||
can_headers, can_body = m
|
||||
try:
|
||||
header_algorithm = ALGORITHMS[can_headers]
|
||||
body_algorithm = ALGORITHMS[can_body]
|
||||
except KeyError as e:
|
||||
if logger:
|
||||
logger.error(
|
||||
"unknown canonicalization algorithm: %s" % e.message)
|
||||
return None
|
||||
return cls(header_algorithm, body_algorithm)
|
||||
|
||||
def to_c_value(self):
|
||||
return b'/'.join(
|
||||
(self.header_algorithm.name, self.body_algorithm.name))
|
||||
|
||||
def canonicalize_headers(self, headers):
|
||||
return self.header_algorithm.canonicalize_headers(headers)
|
||||
|
||||
def canonicalize_body(self, body):
|
||||
return self.body_algorithm.canonicalize_body(body)
|
||||
|
||||
|
||||
ALGORITHMS = dict((c.name, c) for c in (Simple, Relaxed))
|
||||
|
||||
@@ -18,7 +18,11 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from dkim.canonicalization import Simple, Relaxed
|
||||
from dkim.canonicalization import (
|
||||
CanonicalizationPolicy,
|
||||
Simple,
|
||||
Relaxed,
|
||||
)
|
||||
|
||||
|
||||
class BaseCanonicalizationTest(unittest.TestCase):
|
||||
@@ -94,6 +98,59 @@ class TestRelaxedAlgorithmBody(BaseCanonicalizationTest):
|
||||
b'Foo\r\nbar\r\n\r\n\r\n')
|
||||
|
||||
|
||||
class TestCanonicalizationPolicyFromCValue(unittest.TestCase):
|
||||
|
||||
def assertAlgorithms(self, header_algo, body_algo, c_value):
|
||||
p = CanonicalizationPolicy.from_c_value(c_value)
|
||||
self.assertEqual(
|
||||
(header_algo, body_algo),
|
||||
(p.header_algorithm, p.body_algorithm))
|
||||
|
||||
def assertValueDoesNotParse(self, c_value):
|
||||
self.assertIs(None, CanonicalizationPolicy.from_c_value(c_value))
|
||||
|
||||
def test_both_default_to_simple(self):
|
||||
self.assertAlgorithms(Simple, Simple, None)
|
||||
|
||||
def test_relaxed_headers(self):
|
||||
self.assertAlgorithms(Relaxed, Simple, b'relaxed')
|
||||
|
||||
def test_relaxed_body(self):
|
||||
self.assertAlgorithms(Simple, Relaxed, b'simple/relaxed')
|
||||
|
||||
def test_relaxed_both(self):
|
||||
self.assertAlgorithms(Relaxed, Relaxed, b'relaxed/relaxed')
|
||||
|
||||
def test_explict_simple_both(self):
|
||||
self.assertAlgorithms(Simple, Simple, b'simple/simple')
|
||||
|
||||
def test_corruption_is_ignored(self):
|
||||
self.assertValueDoesNotParse(b'')
|
||||
self.assertValueDoesNotParse(b'simple/simple/simple')
|
||||
self.assertValueDoesNotParse(b'relaxed/stressed')
|
||||
self.assertValueDoesNotParse(b'worried')
|
||||
|
||||
|
||||
class TestCanonicalizationpolicyToCValue(unittest.TestCase):
|
||||
|
||||
def assertCValue(self, c_value, header_algo, body_algo):
|
||||
self.assertEqual(
|
||||
c_value,
|
||||
CanonicalizationPolicy(header_algo, body_algo).to_c_value())
|
||||
|
||||
def test_both_simple(self):
|
||||
self.assertCValue(b'simple/simple', Simple, Simple)
|
||||
|
||||
def test_relaxed_body(self):
|
||||
self.assertCValue(b'simple/relaxed', Simple, Relaxed)
|
||||
|
||||
def test_both_relaxed(self):
|
||||
self.assertCValue(b'relaxed/relaxed', Relaxed, Relaxed)
|
||||
|
||||
def test_relaxed_headers(self):
|
||||
self.assertCValue(b'relaxed/simple', Relaxed, Simple)
|
||||
|
||||
|
||||
def test_suite():
|
||||
from unittest import TestLoader
|
||||
return TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
Reference in New Issue
Block a user