Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
SQLAlchemy / testing / suite / test_cte.py
Size: Mime:
from .. import fixtures, config
from ..assertions import eq_

from sqlalchemy import Integer, String, select
from sqlalchemy import ForeignKey
from sqlalchemy import testing

from ..schema import Table, Column


class CTETest(fixtures.TablesTest):
    """Round-trip tests for common table expressions (CTEs).

    Each test builds a CTE over ``some_table``, executes a statement that
    uses it against the configured database, and compares the rows that
    come back.  The fixture rows loaded by ``insert_data`` form a small tree
    linked through ``parent_id``::

        d1 (id=1)
        +-- d2 (id=2)
        +-- d3 (id=3)
            +-- d4 (id=4)
            +-- d5 (id=5)
    """

    # NOTE(review): presumably opts this class into running against each
    # configured backend -- confirm in the fixtures module.
    __backend__ = True
    # NOTE(review): presumably the class is skipped on backends that do not
    # satisfy the "ctes" requirement -- confirm in the fixtures module.
    __requires__ = ("ctes",)

    # Presumably re-inserts fixture rows and clears the tables around every
    # test, so DML done by one test cannot leak into the next -- confirm
    # against fixtures.TablesTest.
    run_inserts = "each"
    run_deletes = "each"

    @classmethod
    def define_tables(cls, metadata):
        """Declare the two tables used by these tests on *metadata*.

        ``some_table.parent_id`` is a self-referential foreign key to
        ``some_table.id``.  ``some_other_table`` has the same column names
        but a plain Integer ``parent_id``.  It receives no fixture rows and
        serves as the target of the INSERT/UPDATE/DELETE tests.
        """
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", ForeignKey("some_table.id")),
        )

        Table(
            "some_other_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", Integer),
        )

    @classmethod
    def insert_data(cls):
        """Load the five-row d1..d5 tree into ``some_table``."""
        tree = [
            (1, "d1", None),
            (2, "d2", 1),
            (3, "d3", 1),
            (4, "d4", 3),
            (5, "d5", 3),
        ]
        config.db.execute(
            cls.tables.some_table.insert(),
            [
                {"id": pk, "data": label, "parent_id": parent}
                for pk, label, parent in tree
            ],
        )

    def _middle_rows_cte(self, **cte_kw):
        """Return a CTE named ``some_cte`` over the d2, d3 and d4 rows.

        Extra keyword arguments (e.g. ``recursive=True``) are passed
        straight through to ``.cte()``.
        """
        tbl = self.tables.some_table
        middle = select([tbl]).where(tbl.c.data.in_(["d2", "d3", "d4"]))
        return middle.cte("some_cte", **cte_kw)

    def _copy_into_other_table(self, conn):
        """INSERT every row of ``some_table`` into ``some_other_table``."""
        src = self.tables.some_table
        dest = self.tables.some_other_table
        conn.execute(
            dest.insert().from_select(
                ["id", "data", "parent_id"], select([src])
            )
        )

    def _other_table_rows(self, conn):
        """Fetch every row of ``some_other_table``, ordered by ``id``."""
        dest = self.tables.some_other_table
        ordered = select([dest]).order_by(dest.c.id)
        return conn.execute(ordered).fetchall()

    def test_select_nonrecursive_round_trip(self):
        """SELECT from a plain CTE, filtering on one of its columns."""
        cte = self._middle_rows_cte()
        stmt = select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"]))

        with config.db.connect() as conn:
            eq_(conn.execute(stmt).fetchall(), [("d4",)])

    def test_select_recursive_round_trip(self):
        """Walk from rows d2/d3/d4 up toward the root with a recursive CTE."""
        anchor = self._middle_rows_cte(recursive=True)
        previous = anchor.alias("c1")
        parent = self.tables.some_table.alias()

        # note that SQL Server requires this to be UNION ALL,
        # can't be UNION
        walk = anchor.union_all(
            select([parent]).where(parent.c.id == previous.c.parent_id)
        )
        stmt = (
            select([walk.c.data])
            .where(walk.c.data != "d2")
            .order_by(walk.c.data.desc())
        )

        # The anchor yields d2, d3 and d4.  The recursive step then adds d1
        # (parent of d2), d1 (parent of d3), d3 (parent of d4), and finally
        # d1 (parent of that d3).  Filtering out "d2" and sorting descending
        # leaves:
        expected = [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)]

        with config.db.connect() as conn:
            eq_(conn.execute(stmt).fetchall(), expected)

    def test_insert_from_select_round_trip(self):
        """INSERT ... SELECT where the SELECT reads from a CTE."""
        dest = self.tables.some_other_table
        cte = self._middle_rows_cte()

        with config.db.connect() as conn:
            conn.execute(
                dest.insert().from_select(
                    ["id", "data", "parent_id"], select([cte])
                )
            )
            eq_(
                self._other_table_rows(conn),
                [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
            )

    @testing.requires.ctes_with_update_delete
    @testing.requires.update_from
    def test_update_from_round_trip(self):
        """UPDATE ``some_other_table`` rows whose ``data`` matches the CTE.

        Guarded by the ``update_from`` requirement, because the WHERE clause
        refers to the CTE's columns.
        """
        dest = self.tables.some_other_table
        cte = self._middle_rows_cte()
        stmt = (
            dest.update()
            .values(parent_id=5)
            .where(dest.c.data == cte.c.data)
        )

        with config.db.connect() as conn:
            self._copy_into_other_table(conn)
            conn.execute(stmt)
            eq_(
                self._other_table_rows(conn),
                [
                    (1, "d1", None),
                    (2, "d2", 5),
                    (3, "d3", 5),
                    (4, "d4", 5),
                    (5, "d5", 3),
                ],
            )

    @testing.requires.ctes_with_update_delete
    @testing.requires.delete_from
    def test_delete_from_round_trip(self):
        """DELETE ``some_other_table`` rows whose ``data`` matches the CTE."""
        dest = self.tables.some_other_table
        cte = self._middle_rows_cte()
        stmt = dest.delete().where(dest.c.data == cte.c.data)

        with config.db.connect() as conn:
            self._copy_into_other_table(conn)
            conn.execute(stmt)
            eq_(
                self._other_table_rows(conn),
                [(1, "d1", None), (5, "d5", 3)],
            )

    @testing.requires.ctes_with_update_delete
    def test_delete_scalar_subq_round_trip(self):
        """DELETE using a correlated scalar subquery against the CTE."""
        dest = self.tables.some_other_table
        cte = self._middle_rows_cte()
        matching_data = select([cte.c.data]).where(cte.c.id == dest.c.id)
        stmt = dest.delete().where(dest.c.data == matching_data)

        with config.db.connect() as conn:
            self._copy_into_other_table(conn)
            conn.execute(stmt)
            eq_(
                self._other_table_rows(conn),
                [(1, "d1", None), (5, "d5", 3)],
            )