File CVE-2023-29483.patch of Package python-dnspython.34852
From 093c593624bcf55766c2a952c207e0b92920214e Mon Sep 17 00:00:00 2001
From: Bob Halley <halley@dnspython.org>
Date: Fri, 9 Feb 2024 10:36:08 -0800
Subject: [PATCH] Address DoS via the Tudoor mechanism (CVE-2023-29483)
---
dns/asyncquery.py | 45 +++++++++++++------
dns/nameserver.py | 2 +
dns/query.py | 110 +++++++++++++++++++++++++++++-----------------
3 files changed, 103 insertions(+), 54 deletions(-)
Index: dnspython-1.12.0/dns/query.py
===================================================================
--- dnspython-1.12.0.orig/dns/query.py
+++ dnspython-1.12.0/dns/query.py
@@ -155,18 +155,38 @@ def _wait_for_writable(s, expiration):
_wait_for(s, False, True, True, expiration)
+def _matches_destination(af, from_address, destination, ignore_unexpected):
+ # Check that from_address is appropriate for a response to a query
+ # sent to destination.
+ if not destination:
+ return True
+ if _addresses_equal(af, from_address, destination) or (
+ dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:]
+ ):
+ return True
+ elif ignore_unexpected:
+ return False
+ raise UnexpectedSource(
+ "got a response from %s instead of %s" % (from_address, destination)
+ )
+
+
def _addresses_equal(af, a1, a2):
# Convert the first value of the tuple, which is a textual format
# address into binary form, so that we are not confused by different
# textual representations of the same address
- n1 = dns.inet.inet_pton(af, a1[0])
- n2 = dns.inet.inet_pton(af, a2[0])
+ try:
+ n1 = dns.inet.inet_pton(af, a1[0])
+ n2 = dns.inet.inet_pton(af, a2[0])
+ except dns.exception.SyntaxError:
+ return False
return n1 == n2 and a1[1:] == a2[1:]
def _destination_and_source(af, where, port, source, source_port):
# Apply defaults and compute destination and source tuples
# suitable for use in connect(), sendto(), or bind().
+ destination = None
if af is None:
try:
af = dns.inet.af_for_address(where)
@@ -187,6 +207,33 @@ def _destination_and_source(af, where, p
return (af, destination, source)
+def _udp_recv(sock, max_size, expiration):
+ """Reads a datagram from the socket.
+ A Timeout exception will be raised if the operation is not completed
+ by the expiration time.
+ """
+ while True:
+ try:
+ return sock.recvfrom(max_size)
+ except BlockingIOError:
+ _wait_for_readable(sock, expiration)
+
+
+def _udp_send(sock, data, destination, expiration):
+ """Sends the specified datagram to destination over the socket.
+ A Timeout exception will be raised if the operation is not completed
+ by the expiration time.
+ """
+ while True:
+ try:
+ if destination:
+ return sock.sendto(data, destination)
+ else:
+ return sock.send(data)
+ except BlockingIOError: # pragma: no cover
+ _wait_for_writable(sock, expiration)
+
+
def send_udp(sock, what, destination, expiration=None):
"""Send a DNS message to the specified UDP socket.
@@ -203,15 +250,15 @@ def send_udp(sock, what, destination, ex
"""
if isinstance(what, dns.message.Message):
what = what.to_wire()
- _wait_for_writable(sock, expiration)
sent_time = time.time()
- n = sock.sendto(what, destination)
+ n = _udp_send(sock, what, destination, expiration)
return (n, sent_time)
def receive_udp(sock, destination, expiration=None, af=None,
ignore_unexpected=False, one_rr_per_rrset=False,
- keyring=None, request_mac=b''):
+ keyring=None, request_mac=b'',
+ ignore_errors=False, query=None):
"""Read a DNS message from a UDP socket.
@param sock: the socket
@@ -229,6 +276,14 @@ def receive_udp(sock, destination, expir
@type keyring: keyring dict
@param request_mac: the MAC of the request (for TSIG)
@type request_mac: bytes
+ @param ignore_errors: If various format errors or response
+ mismatches occur, ignore them and keep listening for a valid response.
+ The default is ``False``.
+ @type ignore_errors: bool
+ @param query: If not ``None`` and *ignore_errors* is ``True``,
+ check that the received message is a response to this query, and
+ if not keep listening for a valid response.
+ @type query: dns.message.Message or None
@rtype: dns.message.Message object
"""
if af is None:
@@ -238,23 +293,32 @@ def receive_udp(sock, destination, expir
af = dns.inet.AF_INET
wire = b''
while 1:
- _wait_for_readable(sock, expiration)
- (wire, from_address) = sock.recvfrom(65535)
- if _addresses_equal(af, from_address, destination) or \
- (dns.inet.is_multicast(destination[0]) and
- from_address[1:] == destination[1:]):
- break
- if not ignore_unexpected:
- raise UnexpectedSource('got a response from '
- '%s instead of %s' % (from_address,
- destination))
- received_time = time.time()
- r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
- one_rr_per_rrset=one_rr_per_rrset)
- return (r, received_time)
+ (wire, from_address) = _udp_recv(sock, 65535, expiration)
+ if not _matches_destination(
+ af, from_address, destination, ignore_unexpected
+ ):
+ continue
+
+ received_time = time.time()
+ try:
+ r = dns.message.from_wire(
+ wire,
+ keyring=keyring,
+ request_mac=request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ )
+ except Exception:
+ if ignore_errors:
+ continue
+ else:
+ raise
+ if ignore_errors and query is not None and not query.is_response(r):
+ continue
+ return (r, received_time)
def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
- ignore_unexpected=False, one_rr_per_rrset=False):
+ ignore_unexpected=False, one_rr_per_rrset=False,
+ ignore_errors=False, sock=None):
"""Return the response obtained after sending a query via UDP.
@param q: the query
@@ -280,13 +344,20 @@ def udp(q, where, timeout=None, port=53,
@type ignore_unexpected: bool
@param one_rr_per_rrset: Put each RR into its own RRset
@type one_rr_per_rrset: bool
+ @param ignore_errors: If various format errors or response
+ mismatches occur, ignore them and keep listening for a valid
+ response. The default is ``False``.
+ @type ignore_errors: bool
@rtype: dns.message.Message object
"""
wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
- s = socket.socket(af, socket.SOCK_DGRAM, 0)
+ if sock:
+ s = sock
+ else:
+ s = socket.socket(af, socket.SOCK_DGRAM, 0)
received_time = None
sent_time = None
try:
@@ -298,10 +369,13 @@ def udp(q, where, timeout=None, port=53,
(_, sent_time) = send_udp(s, wire, destination, expiration)
(r, received_time) = receive_udp(s, destination, expiration, af,
ignore_unexpected, one_rr_per_rrset,
- q.keyring, q.request_mac)
+ q.keyring, q.request_mac,
+ ignore_errors, q)
finally:
s.close()
- if not q.is_response(r):
+ # We don't need to check q.is_response() if we are in ignore_errors mode
+ # as receive_udp() will have checked it.
+ if not (ignore_errors or q.is_response(r)):
raise BadResponse
return r
@@ -520,8 +594,7 @@ def xfr(where, zone, rdtype=dns.rdatatyp
_connect(s, destination)
l = len(wire)
if use_udp:
- _wait_for_writable(s, expiration)
- s.send(wire)
+ _udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", l) + wire
_net_write(s, tcpmsg, expiration)
@@ -542,8 +615,7 @@ def xfr(where, zone, rdtype=dns.rdatatyp
if mexpiration is None or mexpiration > expiration:
mexpiration = expiration
if use_udp:
- _wait_for_readable(s, expiration)
- (wire, from_address) = s.recvfrom(65535)
+ (wire, from_address) = _udp_recv(s, 65535, expiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
Index: dnspython-1.12.0/dns/resolver.py
===================================================================
--- dnspython-1.12.0.orig/dns/resolver.py
+++ dnspython-1.12.0/dns/resolver.py
@@ -850,7 +850,9 @@ class Resolver(object):
response = dns.query.udp(request, nameserver,
timeout, self.port,
source=source,
- source_port=source_port)
+ source_port=source_port,
+ ignore_errors=True,
+ ignore_unexpected=True)
if response.flags & dns.flags.TC:
# Response truncated; retry with TCP.
timeout = self._compute_timeout(start)
Index: dnspython-1.12.0/tests/test_query.py
===================================================================
--- /dev/null
+++ dnspython-1.12.0/tests/test_query.py
@@ -0,0 +1,758 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import contextlib
+import socket
+import sys
+import time
+import unittest
+
+try:
+ import ssl
+
+ have_ssl = True
+except Exception:
+ have_ssl = False
+
+import dns.exception
+import dns.flags
+import dns.inet
+import dns.message
+import dns.name
+import dns.query
+import dns.rcode
+import dns.rdataclass
+import dns.rdatatype
+import dns.tsigkeyring
+import dns.zone
+
+_nanonameserver_available = False
+
+class Server(object):
+ pass
+
+
+query_addresses = []
+
+keyring = dns.tsigkeyring.from_text({"name": b"tDz6cfXXGtNivRpQ98hr6A=="})
+
+
+@unittest.skipIf(True, "Internet not reachable")
+class QueryTests(unittest.TestCase):
+ def testQueryUDP(self):
+ for address in query_addresses:
+ qname = dns.name.from_text("dns.google.")
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.udp(q, address, timeout=2)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+
+ def testQueryUDPWithSocket(self):
+ for address in query_addresses:
+ with socket.socket(
+ dns.inet.af_for_address(address), socket.SOCK_DGRAM
+ ) as s:
+ s.setblocking(0)
+ qname = dns.name.from_text("dns.google.")
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.udp(q, address, sock=s, timeout=2)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+
+ def testQueryTCP(self):
+ for address in query_addresses:
+ qname = dns.name.from_text("dns.google.")
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tcp(q, address, timeout=2)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+
+ def testQueryTCPWithSocket(self):
+ for address in query_addresses:
+ with socket.socket(
+ dns.inet.af_for_address(address), socket.SOCK_STREAM
+ ) as s:
+ ll = dns.inet.low_level_address_tuple((address, 53))
+ s.settimeout(2)
+ s.connect(ll)
+ s.setblocking(0)
+ qname = dns.name.from_text("dns.google.")
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tcp(q, None, sock=s, timeout=2)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+
+ @unittest.skipUnless(have_ssl, "No SSL support")
+ def testQueryTLS(self):
+ for address in query_addresses:
+ qname = dns.name.from_text("dns.google.")
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tls(q, address, timeout=2)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+
+ @unittest.skipUnless(have_ssl, "No SSL support")
+ def testQueryTLSWithContext(self):
+ for address in query_addresses:
+ qname = dns.name.from_text("dns.google.")
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ ssl_context = ssl.create_default_context()
+ ssl_context.check_hostname = False
+ response = dns.query.tls(q, address, timeout=2, ssl_context=ssl_context)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+
+ @unittest.skipUnless(have_ssl, "No SSL support")
+ def testQueryTLSWithSocket(self):
+ for address in query_addresses:
+ with socket.socket(
+ dns.inet.af_for_address(address), socket.SOCK_STREAM
+ ) as base_s:
+ ll = dns.inet.low_level_address_tuple((address, 853))
+ base_s.settimeout(2)
+ base_s.connect(ll)
+ ctx = ssl.create_default_context()
+ ctx.minimum_version = ssl.TLSVersion.TLSv1_2
+ with ctx.wrap_socket(
+ base_s, server_hostname="dns.google"
+ ) as s: # lgtm[py/insecure-protocol]
+ s.setblocking(0)
+ qname = dns.name.from_text("dns.google.")
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tls(q, None, sock=s, timeout=2)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+
+ @unittest.skipUnless(have_ssl, "No SSL support")
+ def testQueryTLSwithPadding(self):
+ for address in query_addresses:
+ qname = dns.name.from_text("dns.google.")
+ q = dns.message.make_query(qname, dns.rdatatype.A, use_edns=0, pad=128)
+ response = dns.query.tls(q, address, timeout=2)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+ # the response should have a padding option
+ self.assertIsNotNone(response.opt)
+ has_pad = False
+ for o in response.opt[0].options:
+ if o.otype == dns.edns.OptionType.PADDING:
+ has_pad = True
+ self.assertTrue(has_pad)
+
+ def testQueryUDPFallback(self):
+ for address in query_addresses:
+ qname = dns.name.from_text(".")
+ q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+ (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=4)
+ self.assertTrue(tcp)
+
+ def testQueryUDPFallbackWithSocket(self):
+ for address in query_addresses:
+ af = dns.inet.af_for_address(address)
+ with socket.socket(af, socket.SOCK_DGRAM) as udp_s:
+ udp_s.setblocking(0)
+ with socket.socket(af, socket.SOCK_STREAM) as tcp_s:
+ ll = dns.inet.low_level_address_tuple((address, 53))
+ tcp_s.settimeout(2)
+ tcp_s.connect(ll)
+ tcp_s.setblocking(0)
+ qname = dns.name.from_text(".")
+ q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+ (_, tcp) = dns.query.udp_with_fallback(
+ q, address, udp_sock=udp_s, tcp_sock=tcp_s, timeout=4
+ )
+ self.assertTrue(tcp)
+
+ def testQueryUDPFallbackNoFallback(self):
+ for address in query_addresses:
+ qname = dns.name.from_text("dns.google.")
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2)
+ self.assertFalse(tcp)
+
+ def testUDPReceiveQuery(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener:
+ listener.bind(("127.0.0.1", 0))
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
+ sender.bind(("127.0.0.1", 0))
+ q = dns.message.make_query("dns.google", dns.rdatatype.A)
+ dns.query.send_udp(sender, q, listener.getsockname())
+ expiration = time.time() + 2
+ (q, _, addr) = dns.query.receive_udp(listener, expiration=expiration)
+ self.assertEqual(addr, sender.getsockname())
+
+
+axfr_zone = """
+$TTL 300
+@ SOA ns1 root 1 7200 900 1209600 86400
+@ NS ns1
+@ NS ns2
+ns1 A 10.0.0.1
+ns2 A 10.0.0.1
+"""
+
+
+class AXFRNanoNameserver(Server):
+ def handle(self, request):
+ self.zone = dns.zone.from_text(axfr_zone, origin=self.origin)
+ self.origin = self.zone.origin
+ items = []
+ soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA)
+ response = dns.message.make_response(request.message)
+ response.flags |= dns.flags.AA
+ response.answer.append(soa)
+ items.append(response)
+ response = dns.message.make_response(request.message)
+ response.question = []
+ response.flags |= dns.flags.AA
+ for name, rdataset in self.zone.iterate_rdatasets():
+ if rdataset.rdtype == dns.rdatatype.SOA and name == dns.name.empty:
+ continue
+ rrset = dns.rrset.RRset(
+ name, rdataset.rdclass, rdataset.rdtype, rdataset.covers
+ )
+ rrset.update(rdataset)
+ response.answer.append(rrset)
+ items.append(response)
+ response = dns.message.make_response(request.message)
+ response.question = []
+ response.flags |= dns.flags.AA
+ response.answer.append(soa)
+ items.append(response)
+ return items
+
+
+ixfr_message = """id 12345
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN IXFR
+;ANSWER
+example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
+example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
+deleted.example. 300 IN A 10.0.0.1
+changed.example. 300 IN A 10.0.0.2
+example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
+changed.example. 300 IN A 10.0.0.4
+added.example. 300 IN A 10.0.0.3
+example. 300 SOA ns1.example. root.example. 3 7200 900 1209600 86400
+example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
+added2.example. 300 IN A 10.0.0.5
+example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
+"""
+
+ixfr_trailing_junk = ixfr_message + "junk.example. 300 IN A 10.0.0.6"
+
+ixfr_up_to_date_message = """id 12345
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN IXFR
+;ANSWER
+example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
+"""
+
+axfr_trailing_junk = """id 12345
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN AXFR
+;ANSWER
+example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
+added.example. 300 IN A 10.0.0.3
+added2.example. 300 IN A 10.0.0.5
+changed.example. 300 IN A 10.0.0.4
+example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
+junk.example. 300 IN A 10.0.0.6
+"""
+
+
+class IXFRNanoNameserver(Server):
+ def __init__(self, response_text):
+ super().__init__()
+ self.response_text = response_text
+
+ def handle(self, request):
+ try:
+ r = dns.message.from_text(self.response_text, one_rr_per_rrset=True)
+ r.id = request.message.id
+ return r
+ except Exception:
+ pass
+
+
+@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
+class XfrTests(unittest.TestCase):
+ def test_axfr(self):
+ expected = dns.zone.from_text(axfr_zone, origin="example")
+ with AXFRNanoNameserver(origin="example") as ns:
+ xfr = dns.query.xfr(ns.tcp_address[0], "example", port=ns.tcp_address[1])
+ zone = dns.zone.from_xfr(xfr)
+ self.assertEqual(zone, expected)
+
+ def test_axfr_tsig(self):
+ expected = dns.zone.from_text(axfr_zone, origin="example")
+ with AXFRNanoNameserver(origin="example", keyring=keyring) as ns:
+ xfr = dns.query.xfr(
+ ns.tcp_address[0],
+ "example",
+ port=ns.tcp_address[1],
+ keyring=keyring,
+ keyname="name",
+ )
+ zone = dns.zone.from_xfr(xfr)
+ self.assertEqual(zone, expected)
+
+ def test_axfr_root_tsig(self):
+ expected = dns.zone.from_text(axfr_zone, origin=".")
+ with AXFRNanoNameserver(origin=".", keyring=keyring) as ns:
+ xfr = dns.query.xfr(
+ ns.tcp_address[0],
+ ".",
+ port=ns.tcp_address[1],
+ keyring=keyring,
+ keyname="name",
+ )
+ zone = dns.zone.from_xfr(xfr)
+ self.assertEqual(zone, expected)
+
+ def test_axfr_udp(self):
+ def bad():
+ with AXFRNanoNameserver(origin="example") as ns:
+ xfr = dns.query.xfr(
+ ns.udp_address[0], "example", port=ns.udp_address[1], use_udp=True
+ )
+ l = list(xfr)
+
+ self.assertRaises(ValueError, bad)
+
+ def test_axfr_bad_rcode(self):
+ def bad():
+ # We just use Server here as by default it will refuse.
+ with Server() as ns:
+ xfr = dns.query.xfr(
+ ns.tcp_address[0], "example", port=ns.tcp_address[1]
+ )
+ l = list(xfr)
+
+ self.assertRaises(dns.query.TransferError, bad)
+
+ def test_axfr_trailing_junk(self):
+ # we use the IXFR server here as it returns messages
+ def bad():
+ with IXFRNanoNameserver(axfr_trailing_junk) as ns:
+ xfr = dns.query.xfr(
+ ns.tcp_address[0],
+ "example",
+ dns.rdatatype.AXFR,
+ port=ns.tcp_address[1],
+ )
+ l = list(xfr)
+
+ self.assertRaises(dns.exception.FormError, bad)
+
+ def test_ixfr_tcp(self):
+ with IXFRNanoNameserver(ixfr_message) as ns:
+ xfr = dns.query.xfr(
+ ns.tcp_address[0],
+ "example",
+ dns.rdatatype.IXFR,
+ port=ns.tcp_address[1],
+ serial=2,
+ relativize=False,
+ )
+ l = list(xfr)
+ self.assertEqual(len(l), 1)
+ expected = dns.message.from_text(ixfr_message, one_rr_per_rrset=True)
+ expected.id = l[0].id
+ self.assertEqual(l[0], expected)
+
+ def test_ixfr_udp(self):
+ with IXFRNanoNameserver(ixfr_message) as ns:
+ xfr = dns.query.xfr(
+ ns.udp_address[0],
+ "example",
+ dns.rdatatype.IXFR,
+ port=ns.udp_address[1],
+ serial=2,
+ relativize=False,
+ use_udp=True,
+ )
+ l = list(xfr)
+ self.assertEqual(len(l), 1)
+ expected = dns.message.from_text(ixfr_message, one_rr_per_rrset=True)
+ expected.id = l[0].id
+ self.assertEqual(l[0], expected)
+
+ def test_ixfr_up_to_date(self):
+ with IXFRNanoNameserver(ixfr_up_to_date_message) as ns:
+ xfr = dns.query.xfr(
+ ns.tcp_address[0],
+ "example",
+ dns.rdatatype.IXFR,
+ port=ns.tcp_address[1],
+ serial=2,
+ relativize=False,
+ )
+ l = list(xfr)
+ self.assertEqual(len(l), 1)
+ expected = dns.message.from_text(
+ ixfr_up_to_date_message, one_rr_per_rrset=True
+ )
+ expected.id = l[0].id
+ self.assertEqual(l[0], expected)
+
+ def test_ixfr_trailing_junk(self):
+ def bad():
+ with IXFRNanoNameserver(ixfr_trailing_junk) as ns:
+ xfr = dns.query.xfr(
+ ns.tcp_address[0],
+ "example",
+ dns.rdatatype.IXFR,
+ port=ns.tcp_address[1],
+ serial=2,
+ relativize=False,
+ )
+ l = list(xfr)
+
+ self.assertRaises(dns.exception.FormError, bad)
+
+ def test_ixfr_base_serial_mismatch(self):
+ def bad():
+ with IXFRNanoNameserver(ixfr_message) as ns:
+ xfr = dns.query.xfr(
+ ns.tcp_address[0],
+ "example",
+ dns.rdatatype.IXFR,
+ port=ns.tcp_address[1],
+ serial=1,
+ relativize=False,
+ )
+ l = list(xfr)
+
+ self.assertRaises(dns.exception.FormError, bad)
+
+
+class TSIGNanoNameserver(Server):
+ def handle(self, request):
+ response = dns.message.make_response(request.message)
+ response.set_rcode(dns.rcode.REFUSED)
+ response.flags |= dns.flags.RA
+ try:
+ if request.qtype == dns.rdatatype.A and request.qclass == dns.rdataclass.IN:
+ rrs = dns.rrset.from_text(request.qname, 300, "IN", "A", "1.2.3.4")
+ response.answer.append(rrs)
+ response.set_rcode(dns.rcode.NOERROR)
+ response.flags |= dns.flags.AA
+ except Exception:
+ pass
+ return response
+
+
+@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
+class TsigTests(unittest.TestCase):
+ def test_tsig(self):
+ with TSIGNanoNameserver(keyring=keyring) as ns:
+ qname = dns.name.from_text("example.com")
+ q = dns.message.make_query(qname, "A")
+ q.use_tsig(keyring=keyring, keyname="name")
+ response = dns.query.udp(q, ns.udp_address[0], port=ns.udp_address[1])
+ self.assertTrue(response.had_tsig)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("1.2.3.4" in seen)
+
+
+@unittest.skipIf(sys.platform == "win32", "low level tests do not work on win32")
+class LowLevelWaitTests(unittest.TestCase):
+ def test_wait_for(self):
+ try:
+ (l, r) = socket.socketpair()
+ # already expired
+ with self.assertRaises(dns.exception.Timeout):
+ dns.query._wait_for(l, True, True, True, 0)
+ # simple timeout
+ with self.assertRaises(dns.exception.Timeout):
+ dns.query._wait_for(l, False, False, False, time.time() + 0.05)
+ # writable no timeout (not hanging is passing)
+ dns.query._wait_for(l, False, True, False, None)
+ finally:
+ l.close()
+ r.close()
+
+
+class MiscTests(unittest.TestCase):
+ def test_matches_destination(self):
+ self.assertTrue(
+ dns.query._matches_destination(
+ socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1234), True
+ )
+ )
+ self.assertTrue(
+ dns.query._matches_destination(
+ socket.AF_INET6, ("1::2", 1234), ("0001::2", 1234), True
+ )
+ )
+ self.assertTrue(
+ dns.query._matches_destination(
+ socket.AF_INET, ("10.0.0.1", 1234), None, True
+ )
+ )
+ self.assertFalse(
+ dns.query._matches_destination(
+ socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.2", 1234), True
+ )
+ )
+ self.assertFalse(
+ dns.query._matches_destination(
+ socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), True
+ )
+ )
+ with self.assertRaises(dns.query.UnexpectedSource):
+ dns.query._matches_destination(
+ socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False
+ )
+
+
+@contextlib.contextmanager
+def mock_udp_recv(wire1, from1, wire2, from2):
+ data = {
+ "saved": dns.query._udp_recv,
+ "first_time": True,
+ }
+
+ def mock(sock, max_size, expiration):
+ if data["first_time"]:
+ data["first_time"] = False
+ return wire1, from1
+ else:
+ return wire2, from2
+
+ try:
+ dns.query._udp_recv = mock
+ yield None
+ finally:
+ dns.query._udp_recv = data["saved"]
+
+
+class MockSock:
+ def __init__(self):
+ self.family = socket.AF_INET
+
+ def sendto(self, data, where):
+ return len(data)
+
+ def close(self):
+ pass
+
+ def setblocking(self, *args):
+ pass
+
+
+class IgnoreErrors(unittest.TestCase):
+ def setUp(self):
+ self.q = dns.message.make_query("example.", "A")
+ self.good_r = dns.message.make_response(self.q)
+ self.good_r.set_rcode(dns.rcode.NXDOMAIN)
+ self.good_r_wire = self.good_r.to_wire()
+
+ def mock_receive(
+ self,
+ wire1,
+ from1,
+ wire2,
+ from2,
+ ignore_unexpected=True,
+ ignore_errors=True,
+ good_r=None,
+ ):
+ if good_r is None:
+ good_r = self.good_r
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ try:
+ with mock_udp_recv(wire1, from1, wire2, from2):
+ (r, when) = dns.query.receive_udp(
+ s,
+ ("127.0.0.1", 53),
+ time.time() + 2,
+ ignore_unexpected=ignore_unexpected,
+ ignore_errors=ignore_errors,
+ query=self.q,
+ )
+ self.assertEqual(r, good_r)
+ finally:
+ s.close()
+
+ def test_good_mock(self):
+ self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)
+
+ def test_bad_address(self):
+ self.mock_receive(
+ self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
+ )
+
+ def test_bad_address_not_ignored(self):
+ def bad():
+ self.mock_receive(
+ self.good_r_wire,
+ ("127.0.0.2", 53),
+ self.good_r_wire,
+ ("127.0.0.1", 53),
+ ignore_unexpected=False,
+ )
+
+ self.assertRaises(dns.query.UnexpectedSource, bad)
+
+ def test_bad_id(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r_wire = bad_r.to_wire()
+ self.mock_receive(
+ bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+ )
+
+ def test_bad_id_not_ignored(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r_wire = bad_r.to_wire()
+
+ def bad():
+ (r, wire) = self.mock_receive(
+ bad_r_wire,
+ ("127.0.0.1", 53),
+ self.good_r_wire,
+ ("127.0.0.1", 53),
+ ignore_errors=False,
+ )
+
+ self.assertRaises(AssertionError, bad)
+
+ def test_not_response_not_ignored_udp_level(self):
+ def bad():
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r_wire = bad_r.to_wire()
+ with mock_udp_recv(
+ bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+ ):
+ s = MockSock()
+ dns.query.udp(self.good_r, "127.0.0.1", sock=s)
+
+ self.assertRaises(dns.query.BadResponse, bad)
+
+ def test_bad_wire(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r_wire = bad_r.to_wire()
+ self.mock_receive(
+ bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+ )
+
+ def test_good_wire_with_truncation_flag_and_no_truncation_raise(self):
+ tc_r = dns.message.make_response(self.q)
+ tc_r.flags |= dns.flags.TC
+ tc_r_wire = tc_r.to_wire()
+ self.mock_receive(tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r)
+
+ def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r.flags |= dns.flags.TC
+ bad_r_wire = bad_r.to_wire()
+ self.mock_receive(
+ bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+ )
+
+ def test_bad_wire_not_ignored(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r_wire = bad_r.to_wire()
+
+ def bad():
+ self.mock_receive(
+ bad_r_wire[:10],
+ ("127.0.0.1", 53),
+ self.good_r_wire,
+ ("127.0.0.1", 53),
+ ignore_errors=False,
+ )
+
+ self.assertRaises(dns.message.ShortHeader, bad)
+
+ def test_trailing_wire(self):
+ wire = self.good_r_wire + b"abcd"
+ self.mock_receive(wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53))
+
+ def test_trailing_wire_not_ignored(self):
+ wire = self.good_r_wire + b"abcd"
+
+ def bad():
+ self.mock_receive(
+ wire,
+ ("127.0.0.1", 53),
+ self.good_r_wire,
+ ("127.0.0.1", 53),
+ ignore_errors=False,
+ )
+
+ self.assertRaises(dns.message.TrailingJunk, bad)
Index: dnspython-1.12.0/dns/inet.py
===================================================================
--- dnspython-1.12.0.orig/dns/inet.py
+++ dnspython-1.12.0/dns/inet.py
@@ -101,11 +101,11 @@ def is_multicast(text):
@rtype: bool
"""
try:
- first = ord(dns.ipv4.inet_aton(text)[0])
+ first = dns.ipv4.inet_aton(text)[0]
return (first >= 224 and first <= 239)
except:
try:
- first = ord(dns.ipv6.inet_aton(text)[0])
+ first = dns.ipv6.inet_aton(text)[0]
return (first == 255)
except:
raise ValueError