Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
dnspython / tests / test_query.py
Size: Mime:
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import socket
import sys
import time
import unittest

try:
    import ssl
    have_ssl = True
except Exception:
    have_ssl = False

import dns.exception
import dns.inet
import dns.message
import dns.name
import dns.rdataclass
import dns.rdatatype
import dns.query
import dns.tsigkeyring
import dns.zone

# Some tests require the internet to be available to run, so let's
# skip those if it's not there.
_network_available = True
try:
    socket.gethostbyname('dnspython.org')
except socket.gaierror:
    _network_available = False

# Some tests use a "nano nameserver" for testing.  It requires trio
# and threading, so try to import it and if it doesn't work, skip
# those tests.
try:
    from .nanonameserver import Server
    _nanonameserver_available = True
except ImportError:
    _nanonameserver_available = False
    class Server(object):
        pass

# Probe for IPv4 and IPv6
query_addresses = []
for (af, address) in ((socket.AF_INET, '8.8.8.8'),
                      (socket.AF_INET6, '2001:4860:4860::8888')):
    try:
        with socket.socket(af, socket.SOCK_DGRAM) as s:
            # Connecting a UDP socket is supposed to return ENETUNREACH if
            # no route to the network is present.
            s.connect((address, 53))
        query_addresses.append(address)
    except Exception:
        pass

keyring = dns.tsigkeyring.from_text({'name': 'tDz6cfXXGtNivRpQ98hr6A=='})

@unittest.skipIf(not _network_available, "Internet not reachable")
class QueryTests(unittest.TestCase):

    def testQueryUDP(self):
        for address in query_addresses:
            qname = dns.name.from_text('dns.google.')
            q = dns.message.make_query(qname, dns.rdatatype.A)
            response = dns.query.udp(q, address, timeout=2)
            rrs = response.get_rrset(response.answer, qname,
                                     dns.rdataclass.IN, dns.rdatatype.A)
            self.assertTrue(rrs is not None)
            seen = set([rdata.address for rdata in rrs])
            self.assertTrue('8.8.8.8' in seen)
            self.assertTrue('8.8.4.4' in seen)

    def testQueryUDPWithSocket(self):
        for address in query_addresses:
            with socket.socket(dns.inet.af_for_address(address),
                               socket.SOCK_DGRAM) as s:
                s.setblocking(0)
                qname = dns.name.from_text('dns.google.')
                q = dns.message.make_query(qname, dns.rdatatype.A)
                response = dns.query.udp(q, address, sock=s, timeout=2)
                rrs = response.get_rrset(response.answer, qname,
                                         dns.rdataclass.IN, dns.rdatatype.A)
                self.assertTrue(rrs is not None)
                seen = set([rdata.address for rdata in rrs])
                self.assertTrue('8.8.8.8' in seen)
                self.assertTrue('8.8.4.4' in seen)

    def testQueryTCP(self):
        for address in query_addresses:
            qname = dns.name.from_text('dns.google.')
            q = dns.message.make_query(qname, dns.rdatatype.A)
            response = dns.query.tcp(q, address, timeout=2)
            rrs = response.get_rrset(response.answer, qname,
                                     dns.rdataclass.IN, dns.rdatatype.A)
            self.assertTrue(rrs is not None)
            seen = set([rdata.address for rdata in rrs])
            self.assertTrue('8.8.8.8' in seen)
            self.assertTrue('8.8.4.4' in seen)

    def testQueryTCPWithSocket(self):
        for address in query_addresses:
            with socket.socket(dns.inet.af_for_address(address),
                               socket.SOCK_STREAM) as s:
                ll = dns.inet.low_level_address_tuple((address, 53))
                s.settimeout(2)
                s.connect(ll)
                s.setblocking(0)
                qname = dns.name.from_text('dns.google.')
                q = dns.message.make_query(qname, dns.rdatatype.A)
                response = dns.query.tcp(q, None, sock=s, timeout=2)
                rrs = response.get_rrset(response.answer, qname,
                                         dns.rdataclass.IN, dns.rdatatype.A)
                self.assertTrue(rrs is not None)
                seen = set([rdata.address for rdata in rrs])
                self.assertTrue('8.8.8.8' in seen)
                self.assertTrue('8.8.4.4' in seen)

    def testQueryTLS(self):
        for address in query_addresses:
            qname = dns.name.from_text('dns.google.')
            q = dns.message.make_query(qname, dns.rdatatype.A)
            response = dns.query.tls(q, address, timeout=2)
            rrs = response.get_rrset(response.answer, qname,
                                     dns.rdataclass.IN, dns.rdatatype.A)
            self.assertTrue(rrs is not None)
            seen = set([rdata.address for rdata in rrs])
            self.assertTrue('8.8.8.8' in seen)
            self.assertTrue('8.8.4.4' in seen)

    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLSWithSocket(self):
        for address in query_addresses:
            with socket.socket(dns.inet.af_for_address(address),
                               socket.SOCK_STREAM) as base_s:
                ll = dns.inet.low_level_address_tuple((address, 853))
                base_s.settimeout(2)
                base_s.connect(ll)
                ctx = ssl.create_default_context()
                with ctx.wrap_socket(base_s, server_hostname='dns.google') as s:
                    s.setblocking(0)
                    qname = dns.name.from_text('dns.google.')
                    q = dns.message.make_query(qname, dns.rdatatype.A)
                    response = dns.query.tls(q, None, sock=s, timeout=2)
                    rrs = response.get_rrset(response.answer, qname,
                                             dns.rdataclass.IN, dns.rdatatype.A)
                    self.assertTrue(rrs is not None)
                    seen = set([rdata.address for rdata in rrs])
                    self.assertTrue('8.8.8.8' in seen)
                    self.assertTrue('8.8.4.4' in seen)

    def testQueryUDPFallback(self):
        for address in query_addresses:
            qname = dns.name.from_text('.')
            q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2)
            self.assertTrue(tcp)

    def testQueryUDPFallbackWithSocket(self):
        for address in query_addresses:
            af = dns.inet.af_for_address(address)
            with socket.socket(af, socket.SOCK_DGRAM) as udp_s:
                udp_s.setblocking(0)
                with socket.socket(af, socket.SOCK_STREAM) as tcp_s:
                    ll = dns.inet.low_level_address_tuple((address, 53))
                    tcp_s.settimeout(2)
                    tcp_s.connect(ll)
                    tcp_s.setblocking(0)
                    qname = dns.name.from_text('.')
                    q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
                    (_, tcp) = dns.query.udp_with_fallback(q, address,
                                                           udp_sock=udp_s,
                                                           tcp_sock=tcp_s,
                                                           timeout=2)
                    self.assertTrue(tcp)

    def testQueryUDPFallbackNoFallback(self):
        for address in query_addresses:
            qname = dns.name.from_text('dns.google.')
            q = dns.message.make_query(qname, dns.rdatatype.A)
            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2)
            self.assertFalse(tcp)

    def testUDPReceiveQuery(self):
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener:
            listener.bind(('127.0.0.1', 0))
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
                sender.bind(('127.0.0.1', 0))
                q = dns.message.make_query('dns.google', dns.rdatatype.A)
                dns.query.send_udp(sender, q, listener.getsockname())
                expiration = time.time() + 2
                (q, _, addr) = dns.query.receive_udp(listener,
                                                     expiration=expiration)
                self.assertEqual(addr, sender.getsockname())


# for brevity
_d_and_s = dns.query._destination_and_source

class DestinationAndSourceTests(unittest.TestCase):

    def test_af_inferred_from_where(self):
        (af, d, s) = _d_and_s('1.2.3.4', 53, None, 0)
        self.assertEqual(af, socket.AF_INET)

    def test_af_inferred_from_where(self):
        (af, d, s) = _d_and_s('1::2', 53, None, 0)
        self.assertEqual(af, socket.AF_INET6)

    def test_af_inferred_from_source(self):
        (af, d, s) = _d_and_s('https://example/dns-query', 443,
                              '1.2.3.4', 0, False)
        self.assertEqual(af, socket.AF_INET)

    def test_af_mismatch(self):
        def bad():
            (af, d, s) = _d_and_s('1::2', 53, '1.2.3.4', 0)
        self.assertRaises(ValueError, bad)

    def test_source_port_but_no_af_inferred(self):
        def bad():
            (af, d, s) = _d_and_s('https://example/dns-query', 443,
                                  None, 12345, False)
        self.assertRaises(ValueError, bad)

    def test_where_must_be_an_address(self):
        def bad():
            (af, d, s) = _d_and_s('not a valid address', 53, '1.2.3.4', 0)
        self.assertRaises(ValueError, bad)

    def test_destination_is_none_of_where_url(self):
        (af, d, s) = _d_and_s('https://example/dns-query', 443, None, 0, False)
        self.assertEqual(d, None)

    def test_v4_wildcard_source_set(self):
        (af, d, s) = _d_and_s('1.2.3.4', 53, None, 12345)
        self.assertEqual(s, ('0.0.0.0', 12345))

    def test_v6_wildcard_source_set(self):
        (af, d, s) = _d_and_s('1::2', 53, None, 12345)
        self.assertEqual(s, ('::', 12345, 0, 0))


class AddressesEqualTestCase(unittest.TestCase):

    def test_v4(self):
        self.assertTrue(dns.query._addresses_equal(socket.AF_INET,
                                                   ('10.0.0.1', 53),
                                                   ('10.0.0.1', 53)))
        self.assertFalse(dns.query._addresses_equal(socket.AF_INET,
                                                    ('10.0.0.1', 53),
                                                    ('10.0.0.2', 53)))

    def test_v6(self):
        self.assertTrue(dns.query._addresses_equal(socket.AF_INET6,
                                                   ('1::1', 53),
                                                   ('0001:0000::1', 53)))
        self.assertFalse(dns.query._addresses_equal(socket.AF_INET6,
                                                   ('::1', 53),
                                                   ('::2', 53)))

    def test_mixed(self):
        self.assertFalse(dns.query._addresses_equal(socket.AF_INET,
                                                    ('10.0.0.1', 53),
                                                    ('::2', 53)))


axfr_zone = '''
$TTL 300
@ SOA ns1 root 1 7200 900 1209600 86400
@ NS ns1
@ NS ns2
ns1 A 10.0.0.1
ns2 A 10.0.0.1
'''

class AXFRNanoNameserver(Server):

    def handle(self, request):
        self.zone = dns.zone.from_text(axfr_zone, origin=self.origin)
        self.origin = self.zone.origin
        items = []
        soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA)
        response = dns.message.make_response(request.message)
        response.flags |= dns.flags.AA
        response.answer.append(soa)
        items.append(response)
        response = dns.message.make_response(request.message)
        response.question = []
        response.flags |= dns.flags.AA
        for (name, rdataset) in self.zone.iterate_rdatasets():
            if rdataset.rdtype == dns.rdatatype.SOA and \
               name == dns.name.empty:
                continue
            rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype,
                                    rdataset.covers)
            rrset.update(rdataset)
            response.answer.append(rrset)
        items.append(response)
        response = dns.message.make_response(request.message)
        response.question = []
        response.flags |= dns.flags.AA
        response.answer.append(soa)
        items.append(response)
        return items

ixfr_message = '''id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
deleted.example. 300 IN A 10.0.0.1
changed.example. 300 IN A 10.0.0.2
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
changed.example. 300 IN A 10.0.0.4
added.example. 300 IN A 10.0.0.3
example. 300 SOA ns1.example. root.example. 3 7200 900 1209600 86400
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
added2.example. 300 IN A 10.0.0.5
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
'''

ixfr_trailing_junk = ixfr_message + 'junk.example. 300 IN A 10.0.0.6'

ixfr_up_to_date_message = '''id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
'''

axfr_trailing_junk = '''id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN AXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
added.example. 300 IN A 10.0.0.3
added2.example. 300 IN A 10.0.0.5
changed.example. 300 IN A 10.0.0.4
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
junk.example. 300 IN A 10.0.0.6
'''

class IXFRNanoNameserver(Server):

    def __init__(self, response_text):
        super().__init__()
        self.response_text = response_text

    def handle(self, request):
        try:
            r = dns.message.from_text(self.response_text, one_rr_per_rrset=True)
            r.id = request.message.id
            return r
        except Exception:
            pass

@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
class XfrTests(unittest.TestCase):

    def test_axfr(self):
        expected = dns.zone.from_text(axfr_zone, origin='example')
        with AXFRNanoNameserver(origin='example') as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                port=ns.tcp_address[1])
            zone = dns.zone.from_xfr(xfr)
            self.assertEqual(zone, expected)

    def test_axfr_tsig(self):
        expected = dns.zone.from_text(axfr_zone, origin='example')
        with AXFRNanoNameserver(origin='example', keyring=keyring) as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                port=ns.tcp_address[1],
                                keyring=keyring, keyname='name')
            zone = dns.zone.from_xfr(xfr)
            self.assertEqual(zone, expected)

    def test_axfr_root_tsig(self):
        expected = dns.zone.from_text(axfr_zone, origin='.')
        with AXFRNanoNameserver(origin='.', keyring=keyring) as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], '.',
                                port=ns.tcp_address[1],
                                keyring=keyring, keyname='name')
            zone = dns.zone.from_xfr(xfr)
            self.assertEqual(zone, expected)

    def test_axfr_udp(self):
        def bad():
            with AXFRNanoNameserver(origin='example') as ns:
                xfr = dns.query.xfr(ns.udp_address[0], 'example',
                                    port=ns.udp_address[1], use_udp=True)
                l = list(xfr)
        self.assertRaises(ValueError, bad)

    def test_axfr_bad_rcode(self):
        def bad():
            # We just use Server here as by default it will refuse.
            with Server() as ns:
                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                    port=ns.tcp_address[1])
                l = list(xfr)
        self.assertRaises(dns.query.TransferError, bad)

    def test_axfr_trailing_junk(self):
        # we use the IXFR server here as it returns messages
        def bad():
            with IXFRNanoNameserver(axfr_trailing_junk) as ns:
                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                    dns.rdatatype.AXFR,
                                    port=ns.tcp_address[1])
                l = list(xfr)
        self.assertRaises(dns.exception.FormError, bad)

    def test_ixfr_tcp(self):
        with IXFRNanoNameserver(ixfr_message) as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                dns.rdatatype.IXFR,
                                port=ns.tcp_address[1],
                                serial=2,
                                relativize=False)
            l = list(xfr)
            self.assertEqual(len(l), 1)
            expected = dns.message.from_text(ixfr_message,
                                             one_rr_per_rrset=True)
            expected.id = l[0].id
            self.assertEqual(l[0], expected)

    def test_ixfr_udp(self):
        with IXFRNanoNameserver(ixfr_message) as ns:
            xfr = dns.query.xfr(ns.udp_address[0], 'example',
                                dns.rdatatype.IXFR,
                                port=ns.udp_address[1],
                                serial=2,
                                relativize=False, use_udp=True)
            l = list(xfr)
            self.assertEqual(len(l), 1)
            expected = dns.message.from_text(ixfr_message,
                                             one_rr_per_rrset=True)
            expected.id = l[0].id
            self.assertEqual(l[0], expected)

    def test_ixfr_up_to_date(self):
        with IXFRNanoNameserver(ixfr_up_to_date_message) as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                dns.rdatatype.IXFR,
                                port=ns.tcp_address[1],
                                serial=2,
                                relativize=False)
            l = list(xfr)
            self.assertEqual(len(l), 1)
            expected = dns.message.from_text(ixfr_up_to_date_message,
                                             one_rr_per_rrset=True)
            expected.id = l[0].id
            self.assertEqual(l[0], expected)

    def test_ixfr_trailing_junk(self):
        def bad():
            with IXFRNanoNameserver(ixfr_trailing_junk) as ns:
                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                    dns.rdatatype.IXFR,
                                    port=ns.tcp_address[1],
                                    serial=2,
                                    relativize=False)
                l = list(xfr)
        self.assertRaises(dns.exception.FormError, bad)

    def test_ixfr_base_serial_mismatch(self):
        def bad():
            with IXFRNanoNameserver(ixfr_message) as ns:
                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                    dns.rdatatype.IXFR,
                                    port=ns.tcp_address[1],
                                    serial=1,
                                    relativize=False)
                l = list(xfr)
        self.assertRaises(dns.exception.FormError, bad)

class TSIGNanoNameserver(Server):

    def handle(self, request):
        response = dns.message.make_response(request.message)
        response.set_rcode(dns.rcode.REFUSED)
        response.flags |= dns.flags.RA
        try:
            if request.qtype == dns.rdatatype.A and \
               request.qclass == dns.rdataclass.IN:
                rrs = dns.rrset.from_text(request.qname, 300,
                                          'IN', 'A', '1.2.3.4')
                response.answer.append(rrs)
                response.set_rcode(dns.rcode.NOERROR)
                response.flags |= dns.flags.AA
        except Exception:
            pass
        return response

@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
class TsigTests(unittest.TestCase):

    def test_tsig(self):
        with TSIGNanoNameserver(keyring=keyring) as ns:
            qname = dns.name.from_text('example.com')
            q = dns.message.make_query(qname, 'A')
            q.use_tsig(keyring=keyring, keyname='name')
            response = dns.query.udp(q, ns.udp_address[0],
                                     port=ns.udp_address[1])
            self.assertTrue(response.had_tsig)
            rrs = response.get_rrset(response.answer, qname,
                                     dns.rdataclass.IN, dns.rdatatype.A)
            self.assertTrue(rrs is not None)
            seen = set([rdata.address for rdata in rrs])
            self.assertTrue('1.2.3.4' in seen)

@unittest.skipIf(sys.platform == 'win32',
                 'low level tests do not work on win32')
class LowLevelWaitTests(unittest.TestCase):

    def test_wait_for(self):
        try:
            (l, r) = socket.socketpair()
            # already expired
            with self.assertRaises(dns.exception.Timeout):
                dns.query._wait_for(l, True, True, True, 0)
            # simple timeout
            with self.assertRaises(dns.exception.Timeout):
                dns.query._wait_for(l, False, False, False, time.time() + 0.05)
            # writable no timeout (not hanging is passing)
            dns.query._wait_for(l, False, True, False, None)
        finally:
            l.close()
            r.close()


class MiscTests(unittest.TestCase):
    def test_matches_destination(self):
        self.assertTrue(dns.query._matches_destination(socket.AF_INET,
                                                       ('10.0.0.1', 1234),
                                                       ('10.0.0.1', 1234),
                                                       True))
        self.assertTrue(dns.query._matches_destination(socket.AF_INET6,
                                                       ('1::2', 1234),
                                                       ('0001::2', 1234),
                                                       True))
        self.assertTrue(dns.query._matches_destination(socket.AF_INET,
                                                       ('10.0.0.1', 1234),
                                                       None,
                                                       True))
        self.assertFalse(dns.query._matches_destination(socket.AF_INET,
                                                        ('10.0.0.1', 1234),
                                                        ('10.0.0.2', 1234),
                                                        True))
        self.assertFalse(dns.query._matches_destination(socket.AF_INET,
                                                        ('10.0.0.1', 1234),
                                                        ('10.0.0.1', 1235),
                                                        True))
        with self.assertRaises(dns.query.UnexpectedSource):
            dns.query._matches_destination(socket.AF_INET,
                                           ('10.0.0.1', 1234),
                                           ('10.0.0.1', 1235),
                                           False)