Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
prefect / orion / database / query_components.py
Size: Mime:
import datetime
from abc import ABC, abstractmethod, abstractproperty
from typing import TYPE_CHECKING, Dict, Hashable, List, Optional, Tuple
from uuid import UUID

import pendulum
import sqlalchemy as sa
from cachetools import TTLCache
from jinja2 import Environment, PackageLoader, select_autoescape
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy.ext.asyncio import AsyncSession

from prefect.orion import schemas
from prefect.orion.utilities.database import UUID as UUIDTypeDecorator
from prefect.orion.utilities.database import Timestamp, json_has_any_key

if TYPE_CHECKING:
    from prefect.orion.database.interface import OrionDBInterface

# Number of seconds in one hour; used as the TTL of the configuration cache.
ONE_HOUR = 3600


# Jinja environment for the dialect-specific SQL templates, which are loaded from
# the `sql` directory of the `prefect.orion.database` package.
# NOTE(review): `select_autoescape()` with default arguments only enables escaping
# for html/htm/xml template names, so the `.sql.jinja` templates presumably render
# unescaped — confirm that is intended.
jinja_env = Environment(
    loader=PackageLoader("prefect.orion.database", package_path="sql"),
    trim_blocks=True,
    autoescape=select_autoescape(),
)


class BaseQueryComponents(ABC):
    """
    Abstract base class used to inject dialect-specific SQL operations into Orion.

    Subclasses supply the dialect-specific SQLAlchemy bindings, JSON helpers, and
    optimized subqueries declared abstract below; the concrete methods on this
    class build shared queries on top of those bindings.
    """

    # Class-level cache of configuration values keyed by configuration key. It is a
    # single object shared by every instance (and every subclass that does not
    # override it). At most 100 keys are kept, each expiring after ONE_HOUR.
    CONFIGURATION_CACHE = TTLCache(maxsize=100, ttl=ONE_HOUR)

    def _unique_key(self) -> Tuple[Hashable, ...]:
        """
        Returns a key used to determine whether to instantiate a new DB interface.

        Only the concrete class takes part in the key.
        """
        return (self.__class__,)

    # --- dialect-specific SqlAlchemy bindings

    @abstractmethod
    def insert(self, obj):
        """dialect-specific insert statement for the model/table `obj`"""

    @abstractmethod
    def greatest(self, *values):
        """
        dialect-specific SqlAlchemy binding for the largest of `values`

        NULL handling may differ between dialects.
        """

    @abstractmethod
    def least(self, *values):
        """
        dialect-specific SqlAlchemy binding for the smallest of `values`

        NULL handling may differ between dialects.
        """

    # --- dialect-specific JSON handling

    @abstractproperty
    def uses_json_strings(self) -> bool:
        """specifies whether the configured dialect returns JSON as strings"""

    @abstractmethod
    def cast_to_json(self, json_obj):
        """casts to JSON object if necessary"""

    @abstractmethod
    def build_json_object(self, *args):
        """builds a JSON object from sequential key-value pairs"""

    @abstractmethod
    def json_arr_agg(self, json_array):
        """aggregates a JSON array"""

    # --- dialect-optimized subqueries

    @abstractmethod
    def make_timestamp_intervals(
        self,
        start_time: datetime.datetime,
        end_time: datetime.datetime,
        interval: datetime.timedelta,
    ):
        """
        Returns a selectable of consecutive windows of width `interval`, exposed
        as `interval_start` / `interval_end` columns, between `start_time` and
        `end_time`.
        """

    @abstractmethod
    def set_state_id_on_inserted_flow_runs_statement(
        self,
        fr_model,
        frs_model,
        inserted_flow_run_ids,
        insert_flow_run_states,
    ):
        """
        Returns an UPDATE statement that sets `state_id` on the flow runs in
        `inserted_flow_run_ids` to the id of their matching state row among
        `insert_flow_run_states`.
        """

    @abstractmethod
    async def get_flow_run_notifications_from_queue(
        self, session: AsyncSession, db: "OrionDBInterface", limit: int
    ):
        """Database-specific implementation of reading notifications from the queue and deleting them"""

    async def queue_flow_run_notifications(
        self,
        session: AsyncSession,
        flow_run: schemas.core.FlowRun,
        db: "OrionDBInterface",
    ):
        """
        Queue notifications for the current state of `flow_run`.

        Inserts one `<policy, state>` row into the notification queue for every
        active notification policy whose state names are unset or match the flow
        run's current state name, and whose tags are unset or match the flow run's
        tags.
        """
        # insert a <policy, state> pair into the notification queue
        stmt = (await db.insert(db.FlowRunNotificationQueue)).from_select(
            [
                db.FlowRunNotificationQueue.flow_run_notification_policy_id,
                db.FlowRunNotificationQueue.flow_run_state_id,
            ],
            # ... by selecting from any notification policy that matches the criteria
            sa.select(
                db.FlowRunNotificationPolicy.id,
                sa.cast(sa.literal(str(flow_run.state_id)), UUIDTypeDecorator),
            )
            .select_from(db.FlowRunNotificationPolicy)
            .where(
                sa.and_(
                    # the policy is active
                    db.FlowRunNotificationPolicy.is_active.is_(True),
                    # the policy state names aren't set or match the current state name
                    sa.or_(
                        db.FlowRunNotificationPolicy.state_names == [],
                        json_has_any_key(
                            db.FlowRunNotificationPolicy.state_names,
                            [flow_run.state_name],
                        ),
                    ),
                    # the policy tags aren't set, or the tags match the flow run tags
                    sa.or_(
                        db.FlowRunNotificationPolicy.tags == [],
                        json_has_any_key(
                            db.FlowRunNotificationPolicy.tags, flow_run.tags
                        ),
                    ),
                )
            ),
            # don't send python defaults as part of the insert statement, because they are
            # evaluated once per statement and create unique constraint violations on each row
            include_defaults=False,
        )
        await session.execute(stmt)

    def get_scheduled_flow_runs_from_work_queues(
        self,
        db: "OrionDBInterface",
        limit_per_queue: Optional[int] = None,
        work_queue_ids: Optional[List[UUID]] = None,
        scheduled_before: Optional[datetime.datetime] = None,
    ):
        """
        Returns all scheduled runs in work queues, subject to provided parameters.

        This query returns a `(db.FlowRun, db.WorkQueue.id)` pair; calling
        `result.all()` will return both; calling `result.scalars().unique().all()`
        will return only the flow run because it grabs the first result.

        Paused work queues are excluded. For queues with a concurrency limit, at
        most `limit - (RUNNING + PENDING + CANCELLING runs)` scheduled runs are
        returned per queue, further capped by `limit_per_queue` when given.
        """

        # get any work queues that have a concurrency limit, and compute available
        # slots as their limit less the number of running flows
        concurrency_queues = (
            sa.select(
                db.WorkQueue.id,
                self.greatest(
                    0, db.WorkQueue.concurrency_limit - sa.func.count(db.FlowRun.id)
                ).label("available_slots"),
            )
            .select_from(db.WorkQueue)
            .join(
                db.FlowRun,
                sa.and_(
                    self._flow_run_work_queue_join_clause(db.FlowRun, db.WorkQueue),
                    db.FlowRun.state_type.in_(["RUNNING", "PENDING", "CANCELLING"]),
                ),
                isouter=True,
            )
            .where(db.WorkQueue.concurrency_limit.is_not(None))
            .group_by(db.WorkQueue.id)
            .cte("concurrency_queues")
        )

        # use the available slots information to generate a join
        # for all scheduled runs
        scheduled_flow_runs, join_criteria = self._get_scheduled_flow_runs_join(
            db=db,
            work_queue_query=concurrency_queues,
            limit_per_queue=limit_per_queue,
            scheduled_before=scheduled_before,
        )

        # starting with the work queue table, join the limited queues to get the
        # concurrency information and the scheduled flow runs to load all applicable
        # runs. this will return all the scheduled runs allowed by the parameters
        query = (
            # return a flow run and work queue id
            sa.select(
                sa.orm.aliased(db.FlowRun, scheduled_flow_runs),
                db.WorkQueue.id.label("wq_id"),
            )
            .select_from(db.WorkQueue)
            .join(
                concurrency_queues,
                db.WorkQueue.id == concurrency_queues.c.id,
                isouter=True,
            )
            .join(scheduled_flow_runs, join_criteria)
            .where(
                db.WorkQueue.is_paused.is_(False),
                db.WorkQueue.id.in_(work_queue_ids) if work_queue_ids else True,
            )
            .order_by(
                scheduled_flow_runs.c.next_scheduled_start_time,
                scheduled_flow_runs.c.id,
            )
        )

        return query

    def _get_scheduled_flow_runs_join(
        self,
        db: "OrionDBInterface",
        work_queue_query,
        limit_per_queue: Optional[int],
        scheduled_before: Optional[datetime.datetime],
    ):
        """Used by self.get_scheduled_flow_runs_from_work_queues, allowing just
        this function to be changed on a per-dialect basis.

        Returns a `(scheduled_flow_runs, join_criteria)` pair: a LATERAL subquery
        of scheduled runs limited per queue by the available slots, and a
        trivially-true join condition (the lateral subquery does the correlation).
        """

        # precompute for readability
        scheduled_before_clause = (
            db.FlowRun.next_scheduled_start_time <= scheduled_before
            if scheduled_before is not None
            else True
        )

        # get scheduled flow runs with lateral join where the limit is the
        # available slots per queue
        scheduled_flow_runs = (
            sa.select(db.FlowRun)
            .where(
                self._flow_run_work_queue_join_clause(db.FlowRun, db.WorkQueue),
                db.FlowRun.state_type == "SCHEDULED",
                scheduled_before_clause,
            )
            .with_for_update(skip_locked=True)
            # priority given to runs with earlier next_scheduled_start_time
            .order_by(db.FlowRun.next_scheduled_start_time)
            # if null, no limit will be applied
            .limit(self.least(limit_per_queue, work_queue_query.c.available_slots))
            .lateral("scheduled_flow_runs")
        )

        join_criteria = True

        return scheduled_flow_runs, join_criteria

    def _flow_run_work_queue_join_clause(self, flow_run, work_queue):
        """
        On clause for joining flow runs to work queues (by work queue name)

        Used by self.get_scheduled_flow_runs_from_work_queues, allowing just
        this function to be changed on a per-dialect basis
        """
        return sa.and_(flow_run.work_queue_name == work_queue.name)

    # -------------------------------------------------------
    # Workers
    # -------------------------------------------------------

    @abstractproperty
    def _get_scheduled_flow_runs_from_work_pool_template_path(self):
        """
        Template for the query to get scheduled flow runs from a work pool
        """

    async def get_scheduled_flow_runs_from_work_pool(
        self,
        session,
        db: "OrionDBInterface",
        limit: Optional[int] = None,
        worker_limit: Optional[int] = None,
        queue_limit: Optional[int] = None,
        work_pool_ids: Optional[List[UUID]] = None,
        work_queue_ids: Optional[List[UUID]] = None,
        scheduled_before: Optional[datetime.datetime] = None,
        scheduled_after: Optional[datetime.datetime] = None,
        respect_queue_priorities: bool = False,
    ) -> List[schemas.responses.WorkerFlowRunResponse]:
        """
        Returns scheduled flow runs from work pools, each paired with the work pool
        and work queue ids reported by the query.

        The SQL is rendered from the dialect-specific Jinja template named by
        `_get_scheduled_flow_runs_from_work_pool_template_path`. The optional
        filters are passed to the template (presumably to toggle its clauses) and,
        when provided, bound as query parameters.

        Args:
            session: session used to execute the query
            db: the database interface providing the ORM models
            limit, worker_limit, queue_limit: limits bound into the template;
                each defaults to 1000 when not provided
            work_pool_ids, work_queue_ids: optional lists of `UUID`s to filter by
            scheduled_before, scheduled_after: optional scheduled-time bounds
            respect_queue_priorities: passed through to the template

        Returns:
            a list of `WorkerFlowRunResponse` objects
        """

        template = jinja_env.get_template(
            self._get_scheduled_flow_runs_from_work_pool_template_path
        )

        raw_query = sa.text(
            template.render(
                work_pool_ids=work_pool_ids,
                work_queue_ids=work_queue_ids,
                respect_queue_priorities=respect_queue_priorities,
                scheduled_before=scheduled_before,
                scheduled_after=scheduled_after,
            )
        )

        bindparams = []

        if scheduled_before:
            bindparams.append(
                sa.bindparam("scheduled_before", scheduled_before, type_=Timestamp)
            )

        if scheduled_after:
            bindparams.append(
                sa.bindparam("scheduled_after", scheduled_after, type_=Timestamp)
            )

        # if work pool IDs were provided, bind them
        if work_pool_ids:
            assert all(isinstance(i, UUID) for i in work_pool_ids)
            bindparams.append(
                sa.bindparam(
                    "work_pool_ids",
                    work_pool_ids,
                    expanding=True,
                    type_=UUIDTypeDecorator,
                )
            )

        # if work queue IDs were provided, bind them
        if work_queue_ids:
            assert all(isinstance(i, UUID) for i in work_queue_ids)
            bindparams.append(
                sa.bindparam(
                    "work_queue_ids",
                    work_queue_ids,
                    expanding=True,
                    type_=UUIDTypeDecorator,
                )
            )

        query = raw_query.bindparams(
            *bindparams,
            limit=1000 if limit is None else limit,
            worker_limit=1000 if worker_limit is None else worker_limit,
            queue_limit=1000 if queue_limit is None else queue_limit,
        )

        orm_query = (
            sa.select(
                sa.column("run_work_pool_id"),
                sa.column("run_work_queue_id"),
                db.FlowRun,
            ).from_statement(query)
            # indicate that the state relationship isn't being loaded
            .options(sa.orm.noload(db.FlowRun.state))
        )

        result = await session.execute(orm_query)

        return [
            schemas.responses.WorkerFlowRunResponse(
                work_pool_id=r.run_work_pool_id,
                work_queue_id=r.run_work_queue_id,
                flow_run=schemas.core.FlowRun.from_orm(r.FlowRun),
            )
            for r in result
        ]

    async def read_block_documents(
        self,
        session: sa.orm.Session,
        db: "OrionDBInterface",
        block_document_filter: Optional[schemas.filters.BlockDocumentFilter] = None,
        block_type_filter: Optional[schemas.filters.BlockTypeFilter] = None,
        block_schema_filter: Optional[schemas.filters.BlockSchemaFilter] = None,
        include_secrets: bool = False,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
    ):
        """
        Builds (but does not execute) a query for the block documents matching the
        filters, together with every block document they reference, recursively.

        Each result row is `(BlockDocument, reference_name,
        reference_parent_block_document_id)`; both reference columns are NULL for
        the filtered top-level documents. Rows are ordered by block document name.

        If no `block_document_filter` is given, anonymous block documents are
        excluded. `offset` and `limit` page through the filtered top-level
        documents (ordered by name, then id) only; referenced documents are
        added on top of each page.

        Note: `session` and `include_secrets` are not used by this method.
        """

        # if no filter is provided, one is created that excludes anonymous blocks
        if block_document_filter is None:
            block_document_filter = schemas.filters.BlockDocumentFilter(
                is_anonymous=schemas.filters.BlockDocumentFilterIsAnonymous(eq_=False)
            )

        # --- Query for Parent Block Documents
        # begin by building a query for only those block documents that are selected
        # by the provided filters
        filtered_block_documents_query = sa.select(db.BlockDocument.id).where(
            block_document_filter.as_sql_filter(db)
        )

        if block_type_filter is not None:
            block_type_exists_clause = sa.select(db.BlockType).where(
                db.BlockType.id == db.BlockDocument.block_type_id,
                block_type_filter.as_sql_filter(db),
            )
            filtered_block_documents_query = filtered_block_documents_query.where(
                block_type_exists_clause.exists()
            )

        if block_schema_filter is not None:
            block_schema_exists_clause = sa.select(db.BlockSchema).where(
                db.BlockSchema.id == db.BlockDocument.block_schema_id,
                block_schema_filter.as_sql_filter(db),
            )
            filtered_block_documents_query = filtered_block_documents_query.where(
                block_schema_exists_clause.exists()
            )

        if offset is not None or limit is not None:
            # LIMIT/OFFSET without ORDER BY selects an unpredictable subset of rows,
            # so order the candidates by name (the same order as the final result
            # below), with the id as a tie-breaker, to keep pages stable
            filtered_block_documents_query = filtered_block_documents_query.order_by(
                db.BlockDocument.name, db.BlockDocument.id
            )

        if offset is not None:
            filtered_block_documents_query = filtered_block_documents_query.offset(
                offset
            )

        if limit is not None:
            filtered_block_documents_query = filtered_block_documents_query.limit(limit)

        filtered_block_documents_query = filtered_block_documents_query.cte(
            "filtered_block_documents"
        )

        # --- Query for Referenced Block Documents
        # next build a recursive query for (potentially nested) block documents
        # that reference the filtered block documents
        block_document_references_query = (
            sa.select(db.BlockDocumentReference)
            .filter(
                db.BlockDocumentReference.parent_block_document_id.in_(
                    sa.select(filtered_block_documents_query.c.id)
                )
            )
            .cte("block_document_references", recursive=True)
        )
        block_document_references_join = sa.select(db.BlockDocumentReference).join(
            block_document_references_query,
            db.BlockDocumentReference.parent_block_document_id
            == block_document_references_query.c.reference_block_document_id,
        )
        recursive_block_document_references_cte = (
            block_document_references_query.union_all(block_document_references_join)
        )

        # --- Final Query for All Block Documents
        # build a query that unions:
        # - the filtered block documents
        # - with any block documents that are discovered as (potentially nested) references
        all_block_documents_query = sa.union_all(
            # first select the parent block
            sa.select(
                [
                    db.BlockDocument,
                    sa.null().label("reference_name"),
                    sa.null().label("reference_parent_block_document_id"),
                ]
            )
            .select_from(db.BlockDocument)
            .where(
                db.BlockDocument.id.in_(sa.select(filtered_block_documents_query.c.id))
            ),
            #
            # then select any referenced blocks
            sa.select(
                [
                    db.BlockDocument,
                    recursive_block_document_references_cte.c.name,
                    recursive_block_document_references_cte.c.parent_block_document_id,
                ]
            )
            .select_from(db.BlockDocument)
            .join(
                recursive_block_document_references_cte,
                db.BlockDocument.id
                == recursive_block_document_references_cte.c.reference_block_document_id,
            ),
        ).cte("all_block_documents_query")

        # the final union query needs to be `aliased` for proper ORM unpacking
        # and also be sorted
        return (
            sa.select(
                sa.orm.aliased(db.BlockDocument, all_block_documents_query),
                all_block_documents_query.c.reference_name,
                all_block_documents_query.c.reference_parent_block_document_id,
            )
            .select_from(all_block_documents_query)
            .order_by(all_block_documents_query.c.name)
        )

    async def read_configuration_value(
        self, db: "OrionDBInterface", session: AsyncSession, key: str
    ) -> Optional[Dict]:
        """
        Read a configuration value by key.

        Configuration values should not be changed at run time, so retrieved
        values are cached in memory (in the class-level `CONFIGURATION_CACHE`).

        The main use of configurations is encrypting blocks; caching speeds up
        nested block document queries.

        Returns:
            the stored configuration value, or None when no configuration row
            exists for `key` (misses are not cached)
        """
        try:
            return self.CONFIGURATION_CACHE[key]
        except KeyError:
            # cache miss or expired entry: fall through to the database. Querying
            # outside the handler keeps DB errors from being chained to a KeyError.
            pass

        query = sa.select(db.Configuration).where(db.Configuration.key == key)
        result = await session.execute(query)
        configuration = result.scalar()
        if configuration is None:
            return None

        self.CONFIGURATION_CACHE[key] = configuration.value
        return configuration.value

    def clear_configuration_value_cache_for_key(self, key: str):
        """Removes a configuration key from the cache (no-op if it is absent)."""
        self.CONFIGURATION_CACHE.pop(key, None)


class AsyncPostgresQueryComponents(BaseQueryComponents):
    """PostgreSQL implementation of the dialect-specific query components."""

    # --- Postgres-specific SqlAlchemy bindings

    def insert(self, obj):
        """Build a PostgreSQL-dialect INSERT statement for `obj`."""
        return postgresql.insert(obj)

    def greatest(self, *values):
        """Render the native `GREATEST(...)` function."""
        return sa.func.greatest(*values)

    def least(self, *values):
        """Render the native `LEAST(...)` function."""
        return sa.func.least(*values)

    # --- Postgres-specific JSON handling

    @property
    def uses_json_strings(self):
        """This dialect does not return JSON values as strings."""
        return False

    def cast_to_json(self, json_obj):
        """No cast is applied on PostgreSQL; `json_obj` is returned unchanged."""
        return json_obj

    def build_json_object(self, *args):
        """Build a JSONB object from alternating key/value arguments."""
        return sa.func.jsonb_build_object(*args)

    def json_arr_agg(self, json_array):
        """Aggregate values into a JSONB array."""
        return sa.func.jsonb_agg(json_array)

    # --- Postgres-optimized subqueries

    def make_timestamp_intervals(
        self,
        start_time: datetime.datetime,
        end_time: datetime.datetime,
        interval: datetime.timedelta,
    ):
        """
        Select windows of width `interval` produced by `generate_series`, keeping
        only those that start before `end_time`; at most 500 rows are returned.

        Columns: `interval_start` and `interval_end` (= start + interval).
        """
        # normalize / validate the inputs before building any SQL
        start_time = pendulum.instance(start_time)
        end_time = pendulum.instance(end_time)
        assert isinstance(interval, datetime.timedelta)

        series = sa.func.generate_series(start_time, end_time, interval).alias("dt")
        dt = sa.literal_column("dt")

        windows = sa.select(
            dt.label("interval_start"),
            (dt + interval).label("interval_end"),
        ).select_from(series)

        # drop the window starting exactly at the end time and cap the row count
        return windows.where(dt < end_time).limit(500)

    def set_state_id_on_inserted_flow_runs_statement(
        self,
        fr_model,
        frs_model,
        inserted_flow_run_ids,
        insert_flow_run_states,
    ):
        """Given a list of flow run ids and associated states, set the state_id
        to the appropriate state for all flow runs"""
        new_state_ids = [state["id"] for state in insert_flow_run_states]

        # postgres supports `UPDATE ... FROM`, so the state table can be
        # referenced directly from the WHERE clause
        criteria = (
            fr_model.id.in_(inserted_flow_run_ids),
            frs_model.flow_run_id == fr_model.id,
            frs_model.id.in_(new_state_ids),
        )
        update_stmt = (
            sa.update(fr_model).where(*criteria).values(state_id=frs_model.id)
        )

        # the flow runs are brand new, so there is no session state to synchronize
        return update_stmt.execution_options(synchronize_session=False)

    async def get_flow_run_notifications_from_queue(
        self, session: AsyncSession, db: "OrionDBInterface", limit: int
    ) -> List:
        """
        Read and delete up to `limit` queued notifications (oldest `updated`
        first) in a single statement, returning one row per notification with its
        policy, flow, flow run, and state details.

        Rows locked by another transaction are skipped (`FOR UPDATE SKIP LOCKED`).
        """
        queue = db.FlowRunNotificationQueue
        policy = db.FlowRunNotificationPolicy

        # Keep the id selection in its own CTE: embedding it as a subquery in the
        # DELETE's WHERE clause can lead to the limit being ignored when it is 1.
        # see: https://www.postgresql.org/message-id/16497.1553640836%40sss.pgh.pa.us
        claimable_ids = (
            sa.select(queue.id)
            .select_from(queue)
            .order_by(queue.updated)
            .limit(limit)
            .with_for_update(skip_locked=True)
        ).cte("queued_notifications_ids")

        # DELETE ... RETURNING, wrapped in a CTE so the removed rows can be joined
        deleted = (
            sa.delete(queue)
            .returning(
                queue.id,
                queue.flow_run_notification_policy_id,
                queue.flow_run_state_id,
            )
            .where(queue.id.in_(sa.select(claimable_ids)))
            .cte("queued_notifications")
        )

        detail_columns = [
            deleted.c.id.label("queue_id"),
            policy.id.label("flow_run_notification_policy_id"),
            policy.message_template.label(
                "flow_run_notification_policy_message_template"
            ),
            policy.block_document_id,
            db.Flow.id.label("flow_id"),
            db.Flow.name.label("flow_name"),
            db.FlowRun.id.label("flow_run_id"),
            db.FlowRun.name.label("flow_run_name"),
            db.FlowRun.parameters.label("flow_run_parameters"),
            db.FlowRunState.type.label("flow_run_state_type"),
            db.FlowRunState.name.label("flow_run_state_name"),
            db.FlowRunState.timestamp.label("flow_run_state_timestamp"),
            db.FlowRunState.message.label("flow_run_state_message"),
        ]

        details_stmt = (
            sa.select(*detail_columns)
            .select_from(deleted)
            .join(policy, deleted.c.flow_run_notification_policy_id == policy.id)
            .join(db.FlowRunState, deleted.c.flow_run_state_id == db.FlowRunState.id)
            .join(db.FlowRun, db.FlowRunState.flow_run_id == db.FlowRun.id)
            .join(db.Flow, db.FlowRun.flow_id == db.Flow.id)
        )

        rows = await session.execute(details_stmt)
        return rows.fetchall()

    @property
    def _get_scheduled_flow_runs_from_work_pool_template_path(self):
        """
        Template for the query to get scheduled flow runs from a work pool
        (resolved relative to the `sql` directory used by the Jinja loader)
        """
        template_path = "postgres/get-runs-from-worker-queues.sql.jinja"
        return template_path


class AioSqliteQueryComponents(BaseQueryComponents):
    """SQLite implementation of the dialect-specific query components."""

    # --- Sqlite-specific SqlAlchemy bindings

    def insert(self, obj):
        """SQLite-dialect INSERT statement for `obj`."""
        return sqlite.insert(obj)

    def greatest(self, *values):
        """SQLite has no GREATEST(); the multi-argument scalar `max()` is used."""
        return sa.func.max(*values)

    def least(self, *values):
        """
        SQLite has no LEAST(); the multi-argument scalar `min()` is used.

        Unlike Postgres's LEAST, it does not ignore NULL arguments (see
        `_get_scheduled_flow_runs_join`).
        """
        return sa.func.min(*values)

    # --- Sqlite-specific JSON handling

    @property
    def uses_json_strings(self):
        """This dialect returns JSON values as strings."""
        return True

    def cast_to_json(self, json_obj):
        """Wrap the value in SQLite's `json()` function."""
        return sa.func.json(json_obj)

    def build_json_object(self, *args):
        """Build a JSON object from alternating key/value arguments."""
        return sa.func.json_object(*args)

    def json_arr_agg(self, json_array):
        """Aggregate values into a JSON array with `json_group_array`."""
        return sa.func.json_group_array(json_array)

    # --- Sqlite-optimized subqueries

    def make_timestamp_intervals(
        self,
        start_time: datetime.datetime,
        end_time: datetime.datetime,
        interval: datetime.timedelta,
    ):
        """
        Select consecutive `interval_start` / `interval_end` windows of width
        `interval`. The first window starts at `start_time`; further windows are
        added while their start is before `end_time`, up to 500 windows total.

        SQLite has no built-in `generate_series`, so a recursive CTE is used and
        the interval is applied as a "+N seconds" `strftime` modifier.
        """
        # validate inputs
        start_time = pendulum.instance(start_time)
        end_time = pendulum.instance(end_time)
        assert isinstance(interval, datetime.timedelta)

        return (
            sa.text(
                r"""
                -- recursive CTE to mimic the behavior of `generate_series`,
                -- which is only available as a compiled extension
                WITH RECURSIVE intervals(interval_start, interval_end, counter) AS (
                    VALUES(
                        strftime('%Y-%m-%d %H:%M:%f000', :start_time),
                        strftime('%Y-%m-%d %H:%M:%f000', :start_time, :interval),
                        1
                        )

                    UNION ALL

                    SELECT interval_end, strftime('%Y-%m-%d %H:%M:%f000', interval_end, :interval), counter + 1
                    FROM intervals
                    -- subtract interval because recursive where clauses are effectively evaluated on a t-1 lag
                    WHERE
                        interval_start < strftime('%Y-%m-%d %H:%M:%f000', :end_time, :negative_interval)
                        -- don't compute more than 500 intervals
                        AND counter < 500
                )
                SELECT * FROM intervals
                """
            )
            .bindparams(
                start_time=str(start_time),
                end_time=str(end_time),
                interval=f"+{interval.total_seconds()} seconds",
                negative_interval=f"-{interval.total_seconds()} seconds",
            )
            # `Timestamp` is the module-level import; no local re-import is needed
            .columns(interval_start=Timestamp(), interval_end=Timestamp())
        )

    def set_state_id_on_inserted_flow_runs_statement(
        self,
        fr_model,
        frs_model,
        inserted_flow_run_ids,
        insert_flow_run_states,
    ):
        """Given a list of flow run ids and associated states, set the state_id
        to the appropriate state for all flow runs"""
        # sqlite requires a correlated subquery to update from another table
        subquery = (
            sa.select(frs_model.id)
            .where(
                frs_model.flow_run_id == fr_model.id,
                frs_model.id.in_([r["id"] for r in insert_flow_run_states]),
            )
            .limit(1)
            .scalar_subquery()
        )
        stmt = (
            sa.update(fr_model)
            .where(
                fr_model.id.in_(inserted_flow_run_ids),
            )
            .values(state_id=subquery)
            # no need to synchronize as these flow runs are entirely new
            .execution_options(synchronize_session=False)
        )
        return stmt

    async def get_flow_run_notifications_from_queue(
        self, session: AsyncSession, db: "OrionDBInterface", limit: int
    ) -> List:
        """
        Sqlalchemy has no support for DELETE RETURNING in sqlite (as of May 2022)
        so instead we issue two queries; one to get queued notifications and a second to delete
        them. This *could* introduce race conditions if multiple queue workers are
        running.

        Up to `limit` notifications are read, oldest `updated` first. When none are
        queued, no DELETE is issued and an empty list is returned.
        """

        notification_details_stmt = (
            sa.select(
                db.FlowRunNotificationQueue.id.label("queue_id"),
                db.FlowRunNotificationPolicy.id.label(
                    "flow_run_notification_policy_id"
                ),
                db.FlowRunNotificationPolicy.message_template.label(
                    "flow_run_notification_policy_message_template"
                ),
                db.FlowRunNotificationPolicy.block_document_id,
                db.Flow.id.label("flow_id"),
                db.Flow.name.label("flow_name"),
                db.FlowRun.id.label("flow_run_id"),
                db.FlowRun.name.label("flow_run_name"),
                db.FlowRun.parameters.label("flow_run_parameters"),
                db.FlowRunState.type.label("flow_run_state_type"),
                db.FlowRunState.name.label("flow_run_state_name"),
                db.FlowRunState.timestamp.label("flow_run_state_timestamp"),
                db.FlowRunState.message.label("flow_run_state_message"),
            )
            .select_from(db.FlowRunNotificationQueue)
            .join(
                db.FlowRunNotificationPolicy,
                db.FlowRunNotificationQueue.flow_run_notification_policy_id
                == db.FlowRunNotificationPolicy.id,
            )
            .join(
                db.FlowRunState,
                db.FlowRunNotificationQueue.flow_run_state_id == db.FlowRunState.id,
            )
            .join(
                db.FlowRun,
                db.FlowRunState.flow_run_id == db.FlowRun.id,
            )
            .join(
                db.Flow,
                db.FlowRun.flow_id == db.Flow.id,
            )
            .order_by(db.FlowRunNotificationQueue.updated)
            .limit(limit)
        )

        result = await session.execute(notification_details_stmt)
        notifications = result.fetchall()

        # nothing was read, so skip the DELETE (which, with the "fetch"
        # synchronization strategy, would cost extra round trips on every poll)
        if not notifications:
            return notifications

        # delete the notifications
        delete_stmt = (
            sa.delete(db.FlowRunNotificationQueue)
            .where(
                db.FlowRunNotificationQueue.id.in_([n.queue_id for n in notifications])
            )
            .execution_options(synchronize_session="fetch")
        )

        await session.execute(delete_stmt)

        return notifications

    async def _handle_filtered_block_document_ids(
        self, session, filtered_block_documents_query
    ):
        """
        On SQLite, including the filtered block document parameters confuses the
        compiler and it passes positional parameters in the wrong order (it is
        unclear why; SQLalchemy manual compilation works great. Switching to
        `named` paramstyle also works but fails elsewhere in the codebase). To
        resolve this, we materialize the filtered id query into a literal set of
        IDs rather than leaving it as a SQL select.

        NOTE(review): this helper is not called from any code visible in this
        module; confirm whether it is still used.
        """
        result = await session.execute(filtered_block_documents_query)
        return result.scalars().all()

    def _get_scheduled_flow_runs_join(
        self,
        db: "OrionDBInterface",
        work_queue_query,
        limit_per_queue: Optional[int],
        scheduled_before: Optional[datetime.datetime],
    ):
        """
        SQLite variant of the scheduled-runs join (no LATERAL support is used).

        Scheduled runs are ranked per work queue name by next scheduled start
        time; the returned join criteria keep runs whose rank is within the
        queue's available slots and `limit_per_queue`.

        Returns:
            a `(scheduled_flow_runs, join_criteria)` pair
        """

        # precompute for readability
        scheduled_before_clause = (
            db.FlowRun.next_scheduled_start_time <= scheduled_before
            if scheduled_before is not None
            else True
        )

        # select scheduled flow runs, ordered by scheduled start time per queue
        scheduled_flow_runs = (
            sa.select(
                (
                    sa.func.row_number()
                    .over(
                        partition_by=[db.FlowRun.work_queue_name],
                        order_by=db.FlowRun.next_scheduled_start_time,
                    )
                    .label("rank")
                ),
                db.FlowRun,
            )
            .where(
                db.FlowRun.state_type == "SCHEDULED",
                scheduled_before_clause,
            )
            .subquery("scheduled_flow_runs")
        )

        # sqlite short-circuits the `min` comparison on nulls, so we use `999999`
        # as an "unlimited" limit.
        limit = 999999 if limit_per_queue is None else limit_per_queue

        # in the join, only keep flow runs whose rank is less than or equal to the
        # available slots for each queue
        join_criteria = sa.and_(
            self._flow_run_work_queue_join_clause(scheduled_flow_runs.c, db.WorkQueue),
            scheduled_flow_runs.c.rank
            <= sa.func.min(
                sa.func.coalesce(work_queue_query.c.available_slots, limit), limit
            ),
        )
        return scheduled_flow_runs, join_criteria

    # -------------------------------------------------------
    # Workers
    # -------------------------------------------------------

    @property
    def _get_scheduled_flow_runs_from_work_pool_template_path(self):
        """
        Template for the query to get scheduled flow runs from a work pool
        """
        return "sqlite/get-runs-from-worker-queues.sql.jinja"