Learn more » Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Bower components, Debian packages, RPM packages, and NuGet packages

agriconnect / psycopg2   python

Repository URL to install this package:

/ tests / test_connection.py

#!/usr/bin/env python

# test_connection.py - unit test for connection attributes
#
# Copyright (C) 2008-2011 James Henstridge  <james@jamesh.id.au>
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
# License for more details.

import re
import os
import sys
import time
import threading
import subprocess as sp
from operator import attrgetter

import psycopg2
import psycopg2.errorcodes
from psycopg2 import extensions as ext

from .testutils import (
    script_to_py3, unittest, decorate_all_tests, skip_if_no_superuser,
    skip_before_postgres, skip_after_postgres, skip_before_libpq,
    ConnectingTestCase, skip_if_tpc_disabled, skip_if_windows, slow)

from .testconfig import dsn, dbname


class ConnectionTests(ConnectingTestCase):
    """Tests for basic connection attributes and behaviour.

    Most tests work on ``self.conn``, which is not defined in this class and
    so comes from the ``ConnectingTestCase`` base class; additional
    connections are obtained through ``self.connect()``.
    """

    def test_closed_attribute(self):
        conn = self.conn
        self.assertEqual(conn.closed, False)
        conn.close()
        self.assertEqual(conn.closed, True)

    def test_close_idempotent(self):
        # Closing an already closed connection must not raise.
        conn = self.conn
        conn.close()
        conn.close()
        self.assertTrue(conn.closed)

    def test_cursor_closed_attribute(self):
        conn = self.conn
        curs = conn.cursor()
        self.assertEqual(curs.closed, False)
        curs.close()
        self.assertEqual(curs.closed, True)

        # Closing the connection closes the cursor:
        curs = conn.cursor()
        conn.close()
        self.assertEqual(curs.closed, True)

    @skip_before_postgres(8, 4)
    @skip_if_no_superuser
    @skip_if_windows
    def test_cleanup_on_badconn_close(self):
        # ticket #148
        # Terminating our own backend makes the query fail. The connection
        # is then expected to report closed == 2 (broken) until close() is
        # called explicitly, after which it reports closed == 1.
        conn = self.conn
        cur = conn.cursor()
        self.assertRaises(psycopg2.OperationalError,
            cur.execute, "select pg_terminate_backend(pg_backend_pid())")

        self.assertEqual(conn.closed, 2)
        conn.close()
        self.assertEqual(conn.closed, 1)

    def test_reset(self):
        conn = self.conn
        # switch session characteristics
        conn.autocommit = True
        conn.isolation_level = 'serializable'
        conn.readonly = True
        if self.conn.server_version >= 90100:
            conn.deferrable = False

        self.assertTrue(conn.autocommit)
        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE)
        self.assertTrue(conn.readonly is True)
        if self.conn.server_version >= 90100:
            self.assertTrue(conn.deferrable is False)

        conn.reset()
        # now the session characteristics should be reverted
        self.assertTrue(not conn.autocommit)
        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT)
        self.assertTrue(conn.readonly is None)
        if self.conn.server_version >= 90100:
            self.assertTrue(conn.deferrable is None)

    def test_notices(self):
        conn = self.conn
        cur = conn.cursor()
        # On 9.3+ the message level is lowered so that the table creation
        # below still produces notices (presumably they were demoted to
        # DEBUG level in that release).
        if self.conn.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")
        cur.execute("create temp table chatty (id serial primary key);")
        self.assertEqual("CREATE TABLE", cur.statusmessage)
        self.assertTrue(conn.notices)

    def test_notices_consistent_order(self):
        # Notices must be appended in the order the server sent them,
        # including across several statements in one execute() call.
        conn = self.conn
        cur = conn.cursor()
        if self.conn.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")
        cur.execute("""
            create temp table table1 (id serial);
            create temp table table2 (id serial);
            """)
        cur.execute("""
            create temp table table3 (id serial);
            create temp table table4 (id serial);
            """)
        self.assertEqual(4, len(conn.notices))
        self.assertTrue('table1' in conn.notices[0])
        self.assertTrue('table2' in conn.notices[1])
        self.assertTrue('table3' in conn.notices[2])
        self.assertTrue('table4' in conn.notices[3])

    @slow
    def test_notices_limited(self):
        # 100 tables are created, but the default notices list is expected
        # to be capped at 50 entries, keeping the most recent ones.
        conn = self.conn
        cur = conn.cursor()
        if self.conn.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")
        for i in range(0, 100, 10):
            sql = " ".join(["create temp table table%d (id serial);" % j
                            for j in range(i, i + 10)])
            cur.execute(sql)

        self.assertEqual(50, len(conn.notices))
        self.assertTrue('table99' in conn.notices[-1], conn.notices[-1])

    @slow
    def test_notices_deque(self):
        # A user-supplied container replacing conn.notices must be used as is
        # (append order preserved) and must not be truncated.
        from collections import deque

        conn = self.conn
        self.conn.notices = deque()
        cur = conn.cursor()
        if self.conn.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")

        cur.execute("""
            create temp table table1 (id serial);
            create temp table table2 (id serial);
            """)
        cur.execute("""
            create temp table table3 (id serial);
            create temp table table4 (id serial);""")
        self.assertEqual(len(conn.notices), 4)
        self.assertTrue('table1' in conn.notices.popleft())
        self.assertTrue('table2' in conn.notices.popleft())
        self.assertTrue('table3' in conn.notices.popleft())
        self.assertTrue('table4' in conn.notices.popleft())
        self.assertEqual(len(conn.notices), 0)

        # not limited, but no error
        for i in range(0, 100, 10):
            sql = " ".join(["create temp table table2_%d (id serial);" % j
                            for j in range(i, i + 10)])
            cur.execute(sql)

        self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]),
            100)

    def test_notices_noappend(self):
        conn = self.conn
        # With notices set to None, appending a notice fails; that error
        # must be swallowed so the query below still succeeds.
        self.conn.notices = None
        cur = conn.cursor()
        if self.conn.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")

        cur.execute("create temp table table1 (id serial);")

        self.assertEqual(self.conn.notices, None)

    def test_server_version(self):
        self.assertTrue(self.conn.server_version)

    def test_protocol_version(self):
        self.assertTrue(self.conn.protocol_version in (2, 3),
            self.conn.protocol_version)

    def test_tpc_unsupported(self):
        # Only meaningful against servers older than 8.1.
        cnn = self.conn
        if cnn.server_version >= 80100:
            return self.skipTest("tpc is supported")

        self.assertRaises(psycopg2.NotSupportedError,
            cnn.xid, 42, "foo", "bar")

    @slow
    @skip_before_postgres(8, 2)
    def test_concurrent_execution(self):
        # Two threads each run a 4 second query on their own connection. If
        # the queries really overlap, the total is well under 8 seconds.
        errors = []

        def worker():
            # Exceptions raised inside a thread do not reach the test, so
            # record them: a worker that fails immediately would otherwise
            # make the timing check below pass spuriously.
            try:
                cnn = self.connect()
                try:
                    cur = cnn.cursor()
                    cur.execute("select pg_sleep(4)")
                    cur.close()
                finally:
                    cnn.close()
            except Exception as e:
                errors.append(e)

        t1 = threading.Thread(target=worker)
        t2 = threading.Thread(target=worker)
        t0 = time.time()
        t1.start()
        t2.start()
        t1.join()
        t2.join()
        elapsed = time.time() - t0

        self.assertFalse(errors, "errors in worker threads: %r" % (errors,))
        self.assertTrue(elapsed < 7,
            "something broken in concurrency")

    def test_encoding_name(self):
        self.conn.set_client_encoding("EUC_JP")
        # conn.encoding is 'EUCJP' now.
        cur = self.conn.cursor()
        ext.register_type(ext.UNICODE, cur)
        cur.execute("select 'foo'::text;")
        self.assertEqual(cur.fetchone()[0], 'foo')

    def test_connect_nonnormal_envvar(self):
        # We must perform encoding normalization at connection time.
        # The test passes if connecting with the oddly spelled encoding in
        # the environment does not raise.
        self.conn.close()
        oldenc = os.environ.get('PGCLIENTENCODING')
        os.environ['PGCLIENTENCODING'] = 'utf-8'    # malformed spelling
        try:
            self.conn = self.connect()
        finally:
            if oldenc is not None:
                os.environ['PGCLIENTENCODING'] = oldenc
            else:
                del os.environ['PGCLIENTENCODING']

    def test_weakref(self):
        # Connections must support weak references and must not be kept
        # alive by hidden references once closed and deleted.
        from weakref import ref
        import gc
        conn = psycopg2.connect(dsn)
        w = ref(conn)
        conn.close()
        del conn
        gc.collect()
        self.assertTrue(w() is None)

    @slow
    def test_commit_concurrency(self):
        # The problem is the one reported in ticket #103. Because of bad
        # status check, we commit even when a commit is already on its way.
        # We can detect this condition by the warnings.
        conn = self.conn
        notices = []
        stop = []

        def committer():
            while not stop:
                conn.commit()
                while conn.notices:
                    notices.append((2, conn.notices.pop()))

        cur = conn.cursor()
        t1 = threading.Thread(target=committer)
        t1.start()
        try:
            for i in range(1000):
                cur.execute("select %s;", (i,))
                conn.commit()
                while conn.notices:
                    notices.append((1, conn.notices.pop()))
        finally:
            # Stop the committer thread and wait for it, even if the loop
            # above failed, so it is no longer using the connection when the
            # assertion runs or the test tears the connection down.
            stop.append(True)
            t1.join()

        self.assertTrue(not notices, "%d notices raised" % len(notices))

    def test_connect_cursor_factory(self):
        import psycopg2.extras
        conn = self.connect(cursor_factory=psycopg2.extras.DictCursor)
        cur = conn.cursor()
        cur.execute("select 1 as a")
        self.assertEqual(cur.fetchone()['a'], 1)

    def test_cursor_factory(self):
        # Import explicitly: relying on another test having imported the
        # submodule makes this test depend on execution order.
        import psycopg2.extras

        self.assertEqual(self.conn.cursor_factory, None)
        cur = self.conn.cursor()
        cur.execute("select 1 as a")
        self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone())

        self.conn.cursor_factory = psycopg2.extras.DictCursor
        self.assertEqual(self.conn.cursor_factory, psycopg2.extras.DictCursor)
        cur = self.conn.cursor()
        cur.execute("select 1 as a")
        self.assertEqual(cur.fetchone()['a'], 1)

        self.conn.cursor_factory = None
        self.assertEqual(self.conn.cursor_factory, None)
        cur = self.conn.cursor()
        cur.execute("select 1 as a")
        self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone())

    def test_cursor_factory_none(self):
        # issue #210
        # Passing cursor_factory=None to cursor() must fall back to the
        # connection's own factory rather than overriding it.
        import psycopg2.extras

        conn = self.connect()
        cur = conn.cursor(cursor_factory=None)
        self.assertEqual(type(cur), ext.cursor)

        conn = self.connect(cursor_factory=psycopg2.extras.DictCursor)
        cur = conn.cursor(cursor_factory=None)
        self.assertEqual(type(cur), psycopg2.extras.DictCursor)

    def test_failed_init_status(self):
        # A subclass that swallows the error from a failed connection attempt
        # must still be left closed, and its dsn must hide the password.
        class SubConnection(ext.connection):
            def __init__(self, dsn):
                try:
                    super(SubConnection, self).__init__(dsn)
                except Exception:
                    pass

        c = SubConnection("dbname=thereisnosuchdatabasemate password=foobar")
        self.assertTrue(c.closed, "connection failed so it must be closed")
        self.assertTrue('foobar' not in c.dsn, "password was not obscured")


class ParseDsnTestCase(ConnectingTestCase):
    def test_parse_dsn(self):
        from psycopg2 import ProgrammingError

        self.assertEqual(
            ext.parse_dsn('dbname=test user=tester password=secret'),
            dict(user='tester', password='secret', dbname='test'),
            "simple DSN parsed")
Loading ...