File CVE-2023-29483.patch of Package python-dnspython.34852

From 093c593624bcf55766c2a952c207e0b92920214e Mon Sep 17 00:00:00 2001
From: Bob Halley <halley@dnspython.org>
Date: Fri, 9 Feb 2024 10:36:08 -0800
Subject: [PATCH] Address DoS via the Tudoor mechanism (CVE-2023-29483)

---
 dns/asyncquery.py |  45 +++++++++++++------
 dns/nameserver.py |   2 +
 dns/query.py      | 110 +++++++++++++++++++++++++++++-----------------
 3 files changed, 103 insertions(+), 54 deletions(-)

Index: dnspython-1.12.0/dns/query.py
===================================================================
--- dnspython-1.12.0.orig/dns/query.py
+++ dnspython-1.12.0/dns/query.py
@@ -155,18 +155,38 @@ def _wait_for_writable(s, expiration):
     _wait_for(s, False, True, True, expiration)
 
 
+def _matches_destination(af, from_address, destination, ignore_unexpected):
+    # Check that from_address is appropriate for a response to a query
+    # sent to destination.
+    if not destination:
+        return True
+    if _addresses_equal(af, from_address, destination) or (
+        dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:]
+    ):
+        return True
+    elif ignore_unexpected:
+        return False
+    raise UnexpectedSource(
+        "got a response from %s instead of %s" % (from_address, destination)
+    )
+
+
 def _addresses_equal(af, a1, a2):
     # Convert the first value of the tuple, which is a textual format
     # address into binary form, so that we are not confused by different
     # textual representations of the same address
-    n1 = dns.inet.inet_pton(af, a1[0])
-    n2 = dns.inet.inet_pton(af, a2[0])
+    try:
+        n1 = dns.inet.inet_pton(af, a1[0])
+        n2 = dns.inet.inet_pton(af, a2[0])
+    except dns.exception.SyntaxError:
+        return False
     return n1 == n2 and a1[1:] == a2[1:]
 
 
 def _destination_and_source(af, where, port, source, source_port):
     # Apply defaults and compute destination and source tuples
     # suitable for use in connect(), sendto(), or bind().
+    destination = None
     if af is None:
         try:
             af = dns.inet.af_for_address(where)
@@ -187,6 +207,33 @@ def _destination_and_source(af, where, p
     return (af, destination, source)
 
 
+def _udp_recv(sock, max_size, expiration):
+    """Reads a datagram from the socket.
+    A Timeout exception will be raised if the operation is not completed
+    by the expiration time.
+    """
+    while True:
+        try:
+            return sock.recvfrom(max_size)
+        except BlockingIOError:
+            _wait_for_readable(sock, expiration)
+
+
+def _udp_send(sock, data, destination, expiration):
+    """Sends the specified datagram to destination over the socket.
+    A Timeout exception will be raised if the operation is not completed
+    by the expiration time.
+    """
+    while True:
+        try:
+            if destination:
+                return sock.sendto(data, destination)
+            else:
+                return sock.send(data)
+        except BlockingIOError:  # pragma: no cover
+            _wait_for_writable(sock, expiration)
+
+
 def send_udp(sock, what, destination, expiration=None):
     """Send a DNS message to the specified UDP socket.
 
@@ -203,15 +250,15 @@ def send_udp(sock, what, destination, ex
     """
     if isinstance(what, dns.message.Message):
         what = what.to_wire()
-    _wait_for_writable(sock, expiration)
     sent_time = time.time()
-    n = sock.sendto(what, destination)
+    n = _udp_send(sock, what, destination, expiration)
     return (n, sent_time)
 
 
 def receive_udp(sock, destination, expiration=None, af=None,
                 ignore_unexpected=False, one_rr_per_rrset=False,
-                keyring=None, request_mac=b''):
+                keyring=None, request_mac=b'',
+                ignore_errors=False, query=None):
     """Read a DNS message from a UDP socket.
 
     @param sock: the socket
@@ -229,6 +276,14 @@ def receive_udp(sock, destination, expir
     @type keyring: keyring dict
     @param request_mac: the MAC of the request (for TSIG)
     @type request_mac: bytes
+    @param ignore_errors: If various format errors or response
+    mismatches occur, ignore them and keep listening for a valid response.
+    The default is ``False``.
+    @type ignore_errors: bool
+    @param query: If not ``None`` and *ignore_errors* is ``True``,
+    check that the received message is a response to this query, and
+    if not keep listening for a valid response.
+    @type query: dns.message.Message or None
     @rtype: dns.message.Message object
     """
     if af is None:
@@ -238,23 +293,32 @@ def receive_udp(sock, destination, expir
             af = dns.inet.AF_INET
     wire = b''
     while 1:
-        _wait_for_readable(sock, expiration)
-        (wire, from_address) = sock.recvfrom(65535)
-        if _addresses_equal(af, from_address, destination) or \
-           (dns.inet.is_multicast(destination[0]) and
-            from_address[1:] == destination[1:]):
-            break
-        if not ignore_unexpected:
-            raise UnexpectedSource('got a response from '
-                                   '%s instead of %s' % (from_address,
-                                                         destination))
-    received_time = time.time()
-    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
-                              one_rr_per_rrset=one_rr_per_rrset)
-    return (r, received_time)
+        (wire, from_address) = _udp_recv(sock, 65535, expiration)
+        if not _matches_destination(
+             af, from_address, destination, ignore_unexpected
+         ):
+            continue
+
+        received_time = time.time()
+        try:
+            r = dns.message.from_wire(
+                wire,
+                keyring=keyring,
+                request_mac=request_mac,
+                one_rr_per_rrset=one_rr_per_rrset,
+            )
+        except Exception:
+            if ignore_errors:
+                continue
+            else:
+                raise
+        if ignore_errors and query is not None and not query.is_response(r):
+            continue
+        return (r, received_time)
 
 def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
-        ignore_unexpected=False, one_rr_per_rrset=False):
+        ignore_unexpected=False, one_rr_per_rrset=False,
+        ignore_errors=False, sock=None):
     """Return the response obtained after sending a query via UDP.
 
     @param q: the query
@@ -280,13 +344,20 @@ def udp(q, where, timeout=None, port=53,
     @type ignore_unexpected: bool
     @param one_rr_per_rrset: Put each RR into its own RRset
     @type one_rr_per_rrset: bool
+    @param ignore_errors: If various format errors or response
+    mismatches occur, ignore them and keep listening for a valid
+    response. The default is ``False``.
+    @type ignore_errors: bool
     @rtype: dns.message.Message object
     """
 
     wire = q.to_wire()
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
-    s = socket.socket(af, socket.SOCK_DGRAM, 0)
+    if sock:
+        s = sock
+    else:
+        s = socket.socket(af, socket.SOCK_DGRAM, 0)
     received_time = None
     sent_time = None
     try:
@@ -298,10 +369,13 @@ def udp(q, where, timeout=None, port=53,
         (_, sent_time) = send_udp(s, wire, destination, expiration)
         (r, received_time) = receive_udp(s, destination, expiration, af,
                                          ignore_unexpected, one_rr_per_rrset,
-                                         q.keyring, q.request_mac)
+                                         q.keyring, q.request_mac,
+                                         ignore_errors, q)
     finally:
         s.close()
-    if not q.is_response(r):
+    # We don't need to check q.is_response() if we are in ignore_errors mode
+    # as receive_udp() will have checked it.
+    if not (ignore_errors or q.is_response(r)):
         raise BadResponse
     return r
 
@@ -520,8 +594,7 @@ def xfr(where, zone, rdtype=dns.rdatatyp
     _connect(s, destination)
     l = len(wire)
     if use_udp:
-        _wait_for_writable(s, expiration)
-        s.send(wire)
+        _udp_send(s, wire, None, expiration)
     else:
         tcpmsg = struct.pack("!H", l) + wire
         _net_write(s, tcpmsg, expiration)
@@ -542,8 +615,7 @@ def xfr(where, zone, rdtype=dns.rdatatyp
         if mexpiration is None or mexpiration > expiration:
             mexpiration = expiration
         if use_udp:
-            _wait_for_readable(s, expiration)
-            (wire, from_address) = s.recvfrom(65535)
+            (wire, from_address) = _udp_recv(s, 65535, expiration)
         else:
             ldata = _net_read(s, 2, mexpiration)
             (l,) = struct.unpack("!H", ldata)
Index: dnspython-1.12.0/dns/resolver.py
===================================================================
--- dnspython-1.12.0.orig/dns/resolver.py
+++ dnspython-1.12.0/dns/resolver.py
@@ -850,7 +850,9 @@ class Resolver(object):
                             response = dns.query.udp(request, nameserver,
                                                      timeout, self.port,
                                                      source=source,
-                                                     source_port=source_port)
+                                                     source_port=source_port,
+                                                     ignore_errors=True,
+                                                     ignore_unexpected=True)
                             if response.flags & dns.flags.TC:
                                 # Response truncated; retry with TCP.
                                 timeout = self._compute_timeout(start)
Index: dnspython-1.12.0/tests/test_query.py
===================================================================
--- /dev/null
+++ dnspython-1.12.0/tests/test_query.py
@@ -0,0 +1,758 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import contextlib
+import socket
+import sys
+import time
+import unittest
+
+try:
+    import ssl
+
+    have_ssl = True
+except Exception:
+    have_ssl = False
+
+import dns.exception
+import dns.flags
+import dns.inet
+import dns.message
+import dns.name
+import dns.query
+import dns.rcode
+import dns.rdataclass
+import dns.rdatatype
+import dns.tsigkeyring
+import dns.zone
+
+_nanonameserver_available = False
+
+class Server(object):
+    pass
+
+
+query_addresses = []
+
+keyring = dns.tsigkeyring.from_text({"name": b"tDz6cfXXGtNivRpQ98hr6A=="})
+
+
+@unittest.skipIf(True, "Internet not reachable")
+class QueryTests(unittest.TestCase):
+    def testQueryUDP(self):
+        for address in query_addresses:
+            qname = dns.name.from_text("dns.google.")
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            response = dns.query.udp(q, address, timeout=2)
+            rrs = response.get_rrset(
+                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+            )
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue("8.8.8.8" in seen)
+            self.assertTrue("8.8.4.4" in seen)
+
+    def testQueryUDPWithSocket(self):
+        for address in query_addresses:
+            with socket.socket(
+                dns.inet.af_for_address(address), socket.SOCK_DGRAM
+            ) as s:
+                s.setblocking(0)
+                qname = dns.name.from_text("dns.google.")
+                q = dns.message.make_query(qname, dns.rdatatype.A)
+                response = dns.query.udp(q, address, sock=s, timeout=2)
+                rrs = response.get_rrset(
+                    response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+                )
+                self.assertTrue(rrs is not None)
+                seen = set([rdata.address for rdata in rrs])
+                self.assertTrue("8.8.8.8" in seen)
+                self.assertTrue("8.8.4.4" in seen)
+
+    def testQueryTCP(self):
+        for address in query_addresses:
+            qname = dns.name.from_text("dns.google.")
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            response = dns.query.tcp(q, address, timeout=2)
+            rrs = response.get_rrset(
+                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+            )
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue("8.8.8.8" in seen)
+            self.assertTrue("8.8.4.4" in seen)
+
+    def testQueryTCPWithSocket(self):
+        for address in query_addresses:
+            with socket.socket(
+                dns.inet.af_for_address(address), socket.SOCK_STREAM
+            ) as s:
+                ll = dns.inet.low_level_address_tuple((address, 53))
+                s.settimeout(2)
+                s.connect(ll)
+                s.setblocking(0)
+                qname = dns.name.from_text("dns.google.")
+                q = dns.message.make_query(qname, dns.rdatatype.A)
+                response = dns.query.tcp(q, None, sock=s, timeout=2)
+                rrs = response.get_rrset(
+                    response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+                )
+                self.assertTrue(rrs is not None)
+                seen = set([rdata.address for rdata in rrs])
+                self.assertTrue("8.8.8.8" in seen)
+                self.assertTrue("8.8.4.4" in seen)
+
+    @unittest.skipUnless(have_ssl, "No SSL support")
+    def testQueryTLS(self):
+        for address in query_addresses:
+            qname = dns.name.from_text("dns.google.")
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            response = dns.query.tls(q, address, timeout=2)
+            rrs = response.get_rrset(
+                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+            )
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue("8.8.8.8" in seen)
+            self.assertTrue("8.8.4.4" in seen)
+
+    @unittest.skipUnless(have_ssl, "No SSL support")
+    def testQueryTLSWithContext(self):
+        for address in query_addresses:
+            qname = dns.name.from_text("dns.google.")
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            ssl_context = ssl.create_default_context()
+            ssl_context.check_hostname = False
+            response = dns.query.tls(q, address, timeout=2, ssl_context=ssl_context)
+            rrs = response.get_rrset(
+                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+            )
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue("8.8.8.8" in seen)
+            self.assertTrue("8.8.4.4" in seen)
+
+    @unittest.skipUnless(have_ssl, "No SSL support")
+    def testQueryTLSWithSocket(self):
+        for address in query_addresses:
+            with socket.socket(
+                dns.inet.af_for_address(address), socket.SOCK_STREAM
+            ) as base_s:
+                ll = dns.inet.low_level_address_tuple((address, 853))
+                base_s.settimeout(2)
+                base_s.connect(ll)
+                ctx = ssl.create_default_context()
+                ctx.minimum_version = ssl.TLSVersion.TLSv1_2
+                with ctx.wrap_socket(
+                    base_s, server_hostname="dns.google"
+                ) as s:  # lgtm[py/insecure-protocol]
+                    s.setblocking(0)
+                    qname = dns.name.from_text("dns.google.")
+                    q = dns.message.make_query(qname, dns.rdatatype.A)
+                    response = dns.query.tls(q, None, sock=s, timeout=2)
+                    rrs = response.get_rrset(
+                        response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+                    )
+                    self.assertTrue(rrs is not None)
+                    seen = set([rdata.address for rdata in rrs])
+                    self.assertTrue("8.8.8.8" in seen)
+                    self.assertTrue("8.8.4.4" in seen)
+
+    @unittest.skipUnless(have_ssl, "No SSL support")
+    def testQueryTLSwithPadding(self):
+        for address in query_addresses:
+            qname = dns.name.from_text("dns.google.")
+            q = dns.message.make_query(qname, dns.rdatatype.A, use_edns=0, pad=128)
+            response = dns.query.tls(q, address, timeout=2)
+            rrs = response.get_rrset(
+                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+            )
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue("8.8.8.8" in seen)
+            self.assertTrue("8.8.4.4" in seen)
+            # the response should have a padding option
+            self.assertIsNotNone(response.opt)
+            has_pad = False
+            for o in response.opt[0].options:
+                if o.otype == dns.edns.OptionType.PADDING:
+                    has_pad = True
+            self.assertTrue(has_pad)
+
+    def testQueryUDPFallback(self):
+        for address in query_addresses:
+            qname = dns.name.from_text(".")
+            q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=4)
+            self.assertTrue(tcp)
+
+    def testQueryUDPFallbackWithSocket(self):
+        for address in query_addresses:
+            af = dns.inet.af_for_address(address)
+            with socket.socket(af, socket.SOCK_DGRAM) as udp_s:
+                udp_s.setblocking(0)
+                with socket.socket(af, socket.SOCK_STREAM) as tcp_s:
+                    ll = dns.inet.low_level_address_tuple((address, 53))
+                    tcp_s.settimeout(2)
+                    tcp_s.connect(ll)
+                    tcp_s.setblocking(0)
+                    qname = dns.name.from_text(".")
+                    q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+                    (_, tcp) = dns.query.udp_with_fallback(
+                        q, address, udp_sock=udp_s, tcp_sock=tcp_s, timeout=4
+                    )
+                    self.assertTrue(tcp)
+
+    def testQueryUDPFallbackNoFallback(self):
+        for address in query_addresses:
+            qname = dns.name.from_text("dns.google.")
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2)
+            self.assertFalse(tcp)
+
+    def testUDPReceiveQuery(self):
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener:
+            listener.bind(("127.0.0.1", 0))
+            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
+                sender.bind(("127.0.0.1", 0))
+                q = dns.message.make_query("dns.google", dns.rdatatype.A)
+                dns.query.send_udp(sender, q, listener.getsockname())
+                expiration = time.time() + 2
+                (q, _, addr) = dns.query.receive_udp(listener, expiration=expiration)
+                self.assertEqual(addr, sender.getsockname())
+
+
+axfr_zone = """
+$TTL 300
+@ SOA ns1 root 1 7200 900 1209600 86400
+@ NS ns1
+@ NS ns2
+ns1 A 10.0.0.1
+ns2 A 10.0.0.1
+"""
+
+
+class AXFRNanoNameserver(Server):
+    def handle(self, request):
+        self.zone = dns.zone.from_text(axfr_zone, origin=self.origin)
+        self.origin = self.zone.origin
+        items = []
+        soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA)
+        response = dns.message.make_response(request.message)
+        response.flags |= dns.flags.AA
+        response.answer.append(soa)
+        items.append(response)
+        response = dns.message.make_response(request.message)
+        response.question = []
+        response.flags |= dns.flags.AA
+        for name, rdataset in self.zone.iterate_rdatasets():
+            if rdataset.rdtype == dns.rdatatype.SOA and name == dns.name.empty:
+                continue
+            rrset = dns.rrset.RRset(
+                name, rdataset.rdclass, rdataset.rdtype, rdataset.covers
+            )
+            rrset.update(rdataset)
+            response.answer.append(rrset)
+        items.append(response)
+        response = dns.message.make_response(request.message)
+        response.question = []
+        response.flags |= dns.flags.AA
+        response.answer.append(soa)
+        items.append(response)
+        return items
+
+
+ixfr_message = """id 12345
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN IXFR
+;ANSWER
+example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
+example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
+deleted.example. 300 IN A 10.0.0.1
+changed.example. 300 IN A 10.0.0.2
+example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
+changed.example. 300 IN A 10.0.0.4
+added.example. 300 IN A 10.0.0.3
+example. 300 SOA ns1.example. root.example. 3 7200 900 1209600 86400
+example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
+added2.example. 300 IN A 10.0.0.5
+example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
+"""
+
+ixfr_trailing_junk = ixfr_message + "junk.example. 300 IN A 10.0.0.6"
+
+ixfr_up_to_date_message = """id 12345
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN IXFR
+;ANSWER
+example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
+"""
+
+axfr_trailing_junk = """id 12345
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN AXFR
+;ANSWER
+example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
+added.example. 300 IN A 10.0.0.3
+added2.example. 300 IN A 10.0.0.5
+changed.example. 300 IN A 10.0.0.4
+example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
+junk.example. 300 IN A 10.0.0.6
+"""
+
+
+class IXFRNanoNameserver(Server):
+    def __init__(self, response_text):
+        super().__init__()
+        self.response_text = response_text
+
+    def handle(self, request):
+        try:
+            r = dns.message.from_text(self.response_text, one_rr_per_rrset=True)
+            r.id = request.message.id
+            return r
+        except Exception:
+            pass
+
+
+@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
+class XfrTests(unittest.TestCase):
+    def test_axfr(self):
+        expected = dns.zone.from_text(axfr_zone, origin="example")
+        with AXFRNanoNameserver(origin="example") as ns:
+            xfr = dns.query.xfr(ns.tcp_address[0], "example", port=ns.tcp_address[1])
+            zone = dns.zone.from_xfr(xfr)
+            self.assertEqual(zone, expected)
+
+    def test_axfr_tsig(self):
+        expected = dns.zone.from_text(axfr_zone, origin="example")
+        with AXFRNanoNameserver(origin="example", keyring=keyring) as ns:
+            xfr = dns.query.xfr(
+                ns.tcp_address[0],
+                "example",
+                port=ns.tcp_address[1],
+                keyring=keyring,
+                keyname="name",
+            )
+            zone = dns.zone.from_xfr(xfr)
+            self.assertEqual(zone, expected)
+
+    def test_axfr_root_tsig(self):
+        expected = dns.zone.from_text(axfr_zone, origin=".")
+        with AXFRNanoNameserver(origin=".", keyring=keyring) as ns:
+            xfr = dns.query.xfr(
+                ns.tcp_address[0],
+                ".",
+                port=ns.tcp_address[1],
+                keyring=keyring,
+                keyname="name",
+            )
+            zone = dns.zone.from_xfr(xfr)
+            self.assertEqual(zone, expected)
+
+    def test_axfr_udp(self):
+        def bad():
+            with AXFRNanoNameserver(origin="example") as ns:
+                xfr = dns.query.xfr(
+                    ns.udp_address[0], "example", port=ns.udp_address[1], use_udp=True
+                )
+                l = list(xfr)
+
+        self.assertRaises(ValueError, bad)
+
+    def test_axfr_bad_rcode(self):
+        def bad():
+            # We just use Server here as by default it will refuse.
+            with Server() as ns:
+                xfr = dns.query.xfr(
+                    ns.tcp_address[0], "example", port=ns.tcp_address[1]
+                )
+                l = list(xfr)
+
+        self.assertRaises(dns.query.TransferError, bad)
+
+    def test_axfr_trailing_junk(self):
+        # we use the IXFR server here as it returns messages
+        def bad():
+            with IXFRNanoNameserver(axfr_trailing_junk) as ns:
+                xfr = dns.query.xfr(
+                    ns.tcp_address[0],
+                    "example",
+                    dns.rdatatype.AXFR,
+                    port=ns.tcp_address[1],
+                )
+                l = list(xfr)
+
+        self.assertRaises(dns.exception.FormError, bad)
+
+    def test_ixfr_tcp(self):
+        with IXFRNanoNameserver(ixfr_message) as ns:
+            xfr = dns.query.xfr(
+                ns.tcp_address[0],
+                "example",
+                dns.rdatatype.IXFR,
+                port=ns.tcp_address[1],
+                serial=2,
+                relativize=False,
+            )
+            l = list(xfr)
+            self.assertEqual(len(l), 1)
+            expected = dns.message.from_text(ixfr_message, one_rr_per_rrset=True)
+            expected.id = l[0].id
+            self.assertEqual(l[0], expected)
+
+    def test_ixfr_udp(self):
+        with IXFRNanoNameserver(ixfr_message) as ns:
+            xfr = dns.query.xfr(
+                ns.udp_address[0],
+                "example",
+                dns.rdatatype.IXFR,
+                port=ns.udp_address[1],
+                serial=2,
+                relativize=False,
+                use_udp=True,
+            )
+            l = list(xfr)
+            self.assertEqual(len(l), 1)
+            expected = dns.message.from_text(ixfr_message, one_rr_per_rrset=True)
+            expected.id = l[0].id
+            self.assertEqual(l[0], expected)
+
+    def test_ixfr_up_to_date(self):
+        with IXFRNanoNameserver(ixfr_up_to_date_message) as ns:
+            xfr = dns.query.xfr(
+                ns.tcp_address[0],
+                "example",
+                dns.rdatatype.IXFR,
+                port=ns.tcp_address[1],
+                serial=2,
+                relativize=False,
+            )
+            l = list(xfr)
+            self.assertEqual(len(l), 1)
+            expected = dns.message.from_text(
+                ixfr_up_to_date_message, one_rr_per_rrset=True
+            )
+            expected.id = l[0].id
+            self.assertEqual(l[0], expected)
+
+    def test_ixfr_trailing_junk(self):
+        def bad():
+            with IXFRNanoNameserver(ixfr_trailing_junk) as ns:
+                xfr = dns.query.xfr(
+                    ns.tcp_address[0],
+                    "example",
+                    dns.rdatatype.IXFR,
+                    port=ns.tcp_address[1],
+                    serial=2,
+                    relativize=False,
+                )
+                l = list(xfr)
+
+        self.assertRaises(dns.exception.FormError, bad)
+
+    def test_ixfr_base_serial_mismatch(self):
+        def bad():
+            with IXFRNanoNameserver(ixfr_message) as ns:
+                xfr = dns.query.xfr(
+                    ns.tcp_address[0],
+                    "example",
+                    dns.rdatatype.IXFR,
+                    port=ns.tcp_address[1],
+                    serial=1,
+                    relativize=False,
+                )
+                l = list(xfr)
+
+        self.assertRaises(dns.exception.FormError, bad)
+
+
+class TSIGNanoNameserver(Server):
+    def handle(self, request):
+        response = dns.message.make_response(request.message)
+        response.set_rcode(dns.rcode.REFUSED)
+        response.flags |= dns.flags.RA
+        try:
+            if request.qtype == dns.rdatatype.A and request.qclass == dns.rdataclass.IN:
+                rrs = dns.rrset.from_text(request.qname, 300, "IN", "A", "1.2.3.4")
+                response.answer.append(rrs)
+                response.set_rcode(dns.rcode.NOERROR)
+                response.flags |= dns.flags.AA
+        except Exception:
+            pass
+        return response
+
+
+@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
+class TsigTests(unittest.TestCase):
+    def test_tsig(self):
+        with TSIGNanoNameserver(keyring=keyring) as ns:
+            qname = dns.name.from_text("example.com")
+            q = dns.message.make_query(qname, "A")
+            q.use_tsig(keyring=keyring, keyname="name")
+            response = dns.query.udp(q, ns.udp_address[0], port=ns.udp_address[1])
+            self.assertTrue(response.had_tsig)
+            rrs = response.get_rrset(
+                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+            )
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue("1.2.3.4" in seen)
+
+
+@unittest.skipIf(sys.platform == "win32", "low level tests do not work on win32")
+class LowLevelWaitTests(unittest.TestCase):
+    def test_wait_for(self):
+        try:
+            (l, r) = socket.socketpair()
+            # already expired
+            with self.assertRaises(dns.exception.Timeout):
+                dns.query._wait_for(l, True, True, True, 0)
+            # simple timeout
+            with self.assertRaises(dns.exception.Timeout):
+                dns.query._wait_for(l, False, False, False, time.time() + 0.05)
+            # writable no timeout (not hanging is passing)
+            dns.query._wait_for(l, False, True, False, None)
+        finally:
+            l.close()
+            r.close()
+
+
+class MiscTests(unittest.TestCase):
+    def test_matches_destination(self):
+        self.assertTrue(
+            dns.query._matches_destination(
+                socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1234), True
+            )
+        )
+        self.assertTrue(
+            dns.query._matches_destination(
+                socket.AF_INET6, ("1::2", 1234), ("0001::2", 1234), True
+            )
+        )
+        self.assertTrue(
+            dns.query._matches_destination(
+                socket.AF_INET, ("10.0.0.1", 1234), None, True
+            )
+        )
+        self.assertFalse(
+            dns.query._matches_destination(
+                socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.2", 1234), True
+            )
+        )
+        self.assertFalse(
+            dns.query._matches_destination(
+                socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), True
+            )
+        )
+        with self.assertRaises(dns.query.UnexpectedSource):
+            dns.query._matches_destination(
+                socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False
+            )
+
+
+@contextlib.contextmanager
+def mock_udp_recv(wire1, from1, wire2, from2):
+    data = {
+        "saved": dns.query._udp_recv,
+        "first_time": True,
+    }
+
+    def mock(sock, max_size, expiration):
+        if data["first_time"]:
+            data["first_time"] = False
+            return wire1, from1
+        else:
+            return wire2, from2
+
+    try:
+        dns.query._udp_recv = mock
+        yield None
+    finally:
+        dns.query._udp_recv = data["saved"]
+
+
+class MockSock:
+    def __init__(self):
+        self.family = socket.AF_INET
+
+    def sendto(self, data, where):
+        return len(data)
+
+    def close(self):
+        pass
+
+    def setblocking(self, *args):
+        pass
+
+
+class IgnoreErrors(unittest.TestCase):
+    def setUp(self):
+        self.q = dns.message.make_query("example.", "A")
+        self.good_r = dns.message.make_response(self.q)
+        self.good_r.set_rcode(dns.rcode.NXDOMAIN)
+        self.good_r_wire = self.good_r.to_wire()
+
+    def mock_receive(
+        self,
+        wire1,
+        from1,
+        wire2,
+        from2,
+        ignore_unexpected=True,
+        ignore_errors=True,
+        good_r=None,
+    ):
+        if good_r is None:
+            good_r = self.good_r
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        try:
+            with mock_udp_recv(wire1, from1, wire2, from2):
+                (r, when) = dns.query.receive_udp(
+                    s,
+                    ("127.0.0.1", 53),
+                    time.time() + 2,
+                    ignore_unexpected=ignore_unexpected,
+                    ignore_errors=ignore_errors,
+                    query=self.q,
+                )
+                self.assertEqual(r, good_r)
+        finally:
+            s.close()
+
+    def test_good_mock(self):
+        self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)
+
+    def test_bad_address(self):
+        self.mock_receive(
+            self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
+        )
+
+    def test_bad_address_not_ignored(self):
+        def bad():
+            self.mock_receive(
+                self.good_r_wire,
+                ("127.0.0.2", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_unexpected=False,
+            )
+
+        self.assertRaises(dns.query.UnexpectedSource, bad)
+
+    def test_bad_id(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+        self.mock_receive(
+            bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+        )
+
+    def test_bad_id_not_ignored(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+
+        def bad():
+            (r, wire) = self.mock_receive(
+                bad_r_wire,
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_errors=False,
+            )
+
+        self.assertRaises(AssertionError, bad)
+
+    def test_not_response_not_ignored_udp_level(self):
+        def bad():
+            bad_r = dns.message.make_response(self.q)
+            bad_r.id += 1
+            bad_r_wire = bad_r.to_wire()
+            with mock_udp_recv(
+                bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+            ):
+                s = MockSock()
+                dns.query.udp(self.good_r, "127.0.0.1", sock=s)
+
+        self.assertRaises(dns.query.BadResponse, bad)
+
+    def test_bad_wire(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+        self.mock_receive(
+            bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+        )
+
+    def test_good_wire_with_truncation_flag_and_no_truncation_raise(self):
+        tc_r = dns.message.make_response(self.q)
+        tc_r.flags |= dns.flags.TC
+        tc_r_wire = tc_r.to_wire()
+        self.mock_receive(tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r)
+
+    def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r.flags |= dns.flags.TC
+        bad_r_wire = bad_r.to_wire()
+        self.mock_receive(
+            bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+        )
+
+    def test_bad_wire_not_ignored(self):
+        bad_r = dns.message.make_response(self.q)
+        bad_r.id += 1
+        bad_r_wire = bad_r.to_wire()
+
+        def bad():
+            self.mock_receive(
+                bad_r_wire[:10],
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_errors=False,
+            )
+
+        self.assertRaises(dns.message.ShortHeader, bad)
+
+    def test_trailing_wire(self):
+        wire = self.good_r_wire + b"abcd"
+        self.mock_receive(wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53))
+
+    def test_trailing_wire_not_ignored(self):
+        wire = self.good_r_wire + b"abcd"
+
+        def bad():
+            self.mock_receive(
+                wire,
+                ("127.0.0.1", 53),
+                self.good_r_wire,
+                ("127.0.0.1", 53),
+                ignore_errors=False,
+            )
+
+        self.assertRaises(dns.message.TrailingJunk, bad)
Index: dnspython-1.12.0/dns/inet.py
===================================================================
--- dnspython-1.12.0.orig/dns/inet.py
+++ dnspython-1.12.0/dns/inet.py
@@ -101,11 +101,11 @@ def is_multicast(text):
     @rtype: bool
     """
     try:
-        first = ord(dns.ipv4.inet_aton(text)[0])
+        first = dns.ipv4.inet_aton(text)[0]
         return (first >= 224 and first <= 239)
     except:
         try:
-            first = ord(dns.ipv6.inet_aton(text)[0])
+            first = dns.ipv6.inet_aton(text)[0]
             return (first == 255)
         except:
             raise ValueError
openSUSE Build Service is sponsored by