Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
prefect / orion / database / orm_models.py
Size: Mime:
import datetime
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Hashable, List, Tuple, Union

import pendulum
import sqlalchemy as sa
from sqlalchemy import FetchedValue
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import as_declarative, declarative_mixin, declared_attr
from sqlalchemy.sql.functions import coalesce

import prefect
import prefect.orion.schemas as schemas
from prefect.orion.utilities.database import (
    JSON,
    UUID,
    GenerateUUID,
    Pydantic,
    Timestamp,
    camel_to_snake,
    date_diff,
    interval_add,
    now,
)
from prefect.orion.utilities.encryption import decrypt_fernet, encrypt_fernet
from prefect.utilities.names import generate_slug


class ORMBase:
    """
    Base SQLAlchemy model that automatically infers the table name
    and provides ID, created, and updated columns.

    Each of these columns has both a server-side default (applied by the
    database) and a client-side `default` (applied by SQLAlchemy when the
    ORM emits the INSERT), so values are available whether rows are created
    through the ORM or directly in SQL.
    """

    # required in order to access columns with server defaults
    # or SQL expression defaults, subsequent to a flush, without
    # triggering an expired load
    #
    # this allows us to load attributes with a server default after
    # an INSERT, for example
    #
    # https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html#preventing-implicit-io-when-using-asyncsession
    __mapper_args__ = {"eager_defaults": True}

    def __repr__(self):
        # Only the primary key is shown, e.g. "FlowRun(id=...)".
        return f"{self.__class__.__name__}(id={self.id})"

    @declared_attr
    def __tablename__(cls):
        """
        By default, turn the model's camel-case class name
        into a snake-case table name. Override by providing
        an explicit `__tablename__` class property.
        """
        # `camel_to_snake` is used as a compiled pattern here (`.sub`); its exact
        # regex lives in prefect.orion.utilities.database.
        return camel_to_snake.sub("_", cls.__name__).lower()

    # UUID primary key: generated by the database via GenerateUUID() when the
    # row is inserted outside the ORM, or by uuid.uuid4 on the client side.
    id = sa.Column(
        UUID(),
        primary_key=True,
        server_default=GenerateUUID(),
        default=uuid.uuid4,
    )

    # Creation time; the client-side default is a UTC pendulum timestamp.
    created = sa.Column(
        Timestamp(),
        nullable=False,
        server_default=now(),
        default=lambda: pendulum.now("UTC"),
    )

    # onupdate is only called when statements are actually issued
    # against the database. until COMMIT is issued, this column
    # will not be updated
    #
    # server_onupdate=FetchedValue() tells SQLAlchemy the database may change
    # this value on UPDATE, so it should be re-fetched rather than assumed.
    updated = sa.Column(
        Timestamp(),
        nullable=False,
        index=True,
        server_default=now(),
        default=lambda: pendulum.now("UTC"),
        onupdate=now(),
        server_onupdate=FetchedValue(),
    )


@declarative_mixin
class ORMFlow:
    """SQLAlchemy mixin of a flow.

    Flow names are unique (see `__table_args__`). A flow is linked to its flow
    runs and deployments through the relationships below.
    """

    name = sa.Column(sa.String, nullable=False)
    tags = sa.Column(JSON, server_default="[]", default=list, nullable=False)

    # lazy="raise": these collections are never loaded implicitly; accessing
    # them without an explicit loader option raises instead of emitting a query.
    @declared_attr
    def flow_runs(cls):
        return sa.orm.relationship("FlowRun", back_populates="flow", lazy="raise")

    @declared_attr
    def deployments(cls):
        return sa.orm.relationship("Deployment", back_populates="flow", lazy="raise")

    @declared_attr
    def __table_args__(cls):
        # Unique constraint on `name`, plus an index on the base `created` column.
        return (sa.UniqueConstraint("name"), sa.Index("ix_flow__created", "created"))


@declarative_mixin
class ORMFlowRunState:
    """
    SQLAlchemy mixin of a flow run state.

    Every state row belongs to one flow run. Its result payload is either
    stored inline in the legacy `data` column (kept for backwards
    compatibility) or referenced through `result_artifact_id`; the `data`
    hybrid property returns whichever is present.
    """

    # this column isn't explicitly indexed because it is included in
    # the unique compound index on (flow_run_id, timestamp)
    @declared_attr
    def flow_run_id(cls):
        return sa.Column(
            UUID(), sa.ForeignKey("flow_run.id", ondelete="cascade"), nullable=False
        )

    type = sa.Column(
        sa.Enum(schemas.states.StateType, name="state_type"), nullable=False, index=True
    )
    timestamp = sa.Column(
        Timestamp(),
        nullable=False,
        server_default=now(),
        default=lambda: pendulum.now("UTC"),
    )
    name = sa.Column(sa.String, nullable=False, index=True)
    message = sa.Column(sa.String)
    state_details = sa.Column(
        Pydantic(schemas.states.StateDetails),
        server_default="{}",
        default=schemas.states.StateDetails,
        nullable=False,
    )
    # Inline result payload. The Python attribute is `_data` so the public
    # name `data` can be the hybrid property below; the DB column is "data".
    _data = sa.Column(sa.JSON, nullable=True, name="data")

    @declared_attr
    def result_artifact_id(cls):
        # Deleting the artifact nulls this pointer instead of deleting the state.
        return sa.Column(
            UUID(),
            sa.ForeignKey(
                "artifact.id",
                ondelete="SET NULL",
                use_alter=True,
            ),
            index=True,
        )

    @declared_attr
    def _result_artifact(cls):
        # Eagerly joined so `data` can read the artifact without extra I/O.
        return sa.orm.relationship(
            "Artifact",
            lazy="joined",
            foreign_keys=[cls.result_artifact_id],
            primaryjoin="Artifact.id==%s.result_artifact_id" % cls.__name__,
        )

    @hybrid_property
    def data(self):
        """
        The result payload of this state.

        Returns the inline `data` column when it holds a value, otherwise the
        `data` of the linked result artifact, or None when there is neither.
        """
        if self._data is not None:
            # ensures backwards compatibility for results stored on state objects;
            # compare against None so falsy JSON payloads (0, false, "", [], {})
            # are returned instead of being silently discarded
            return self._data
        if not self.result_artifact_id:
            # do not try to load the relationship if there's no artifact id
            return None
        return self._result_artifact.data

    @declared_attr
    def flow_run(cls):
        return sa.orm.relationship(
            "FlowRun",
            lazy="raise",
            foreign_keys=[cls.flow_run_id],
        )

    def as_state(self) -> schemas.states.State:
        """Convert this ORM row into a `schemas.states.State` via `from_orm`."""
        return schemas.states.State.from_orm(self)

    @declared_attr
    def __table_args__(cls):
        # At most one state per (flow run, timestamp); ordered newest-first.
        return (
            sa.Index(
                "uq_flow_run_state__flow_run_id_timestamp_desc",
                "flow_run_id",
                sa.desc("timestamp"),
                unique=True,
            ),
        )


@declarative_mixin
class ORMTaskRunState:
    """
    SQLAlchemy mixin of a task run state.

    Every state row belongs to one task run. Its result payload is either
    stored inline in the legacy `data` column (kept for backwards
    compatibility) or referenced through `result_artifact_id`; the `data`
    hybrid property returns whichever is present.
    """

    # this column isn't explicitly indexed because it is included in
    # the unique compound index on (task_run_id, timestamp)
    @declared_attr
    def task_run_id(cls):
        return sa.Column(
            UUID(), sa.ForeignKey("task_run.id", ondelete="cascade"), nullable=False
        )

    type = sa.Column(
        sa.Enum(schemas.states.StateType, name="state_type"), nullable=False, index=True
    )
    timestamp = sa.Column(
        Timestamp(),
        nullable=False,
        server_default=now(),
        default=lambda: pendulum.now("UTC"),
    )
    name = sa.Column(sa.String, nullable=False, index=True)
    message = sa.Column(sa.String)
    state_details = sa.Column(
        Pydantic(schemas.states.StateDetails),
        server_default="{}",
        default=schemas.states.StateDetails,
        nullable=False,
    )
    # Inline result payload. The Python attribute is `_data` so the public
    # name `data` can be the hybrid property below; the DB column is "data".
    _data = sa.Column(sa.JSON, nullable=True, name="data")

    @declared_attr
    def result_artifact_id(cls):
        # Deleting the artifact nulls this pointer instead of deleting the state.
        return sa.Column(
            UUID(),
            sa.ForeignKey(
                "artifact.id",
                ondelete="SET NULL",
                use_alter=True,
            ),
            index=True,
        )

    @declared_attr
    def _result_artifact(cls):
        # Eagerly joined so `data` can read the artifact without extra I/O.
        return sa.orm.relationship(
            "Artifact",
            lazy="joined",
            foreign_keys=[cls.result_artifact_id],
            primaryjoin="Artifact.id==%s.result_artifact_id" % cls.__name__,
        )

    @hybrid_property
    def data(self):
        """
        The result payload of this state.

        Returns the inline `data` column when it holds a value, otherwise the
        `data` of the linked result artifact, or None when there is neither.
        """
        if self._data is not None:
            # ensures backwards compatibility for results stored on state objects;
            # compare against None so falsy JSON payloads (0, false, "", [], {})
            # are returned instead of being silently discarded
            return self._data
        if not self.result_artifact_id:
            # do not try to load the relationship if there's no artifact id
            return None
        return self._result_artifact.data

    @declared_attr
    def task_run(cls):
        return sa.orm.relationship(
            "TaskRun",
            lazy="raise",
            foreign_keys=[cls.task_run_id],
        )

    def as_state(self) -> schemas.states.State:
        """Convert this ORM row into a `schemas.states.State` via `from_orm`."""
        return schemas.states.State.from_orm(self)

    @declared_attr
    def __table_args__(cls):
        # At most one state per (task run, timestamp); ordered newest-first.
        return (
            sa.Index(
                "uq_task_run_state__task_run_id_timestamp_desc",
                "task_run_id",
                sa.desc("timestamp"),
                unique=True,
            ),
        )


@declarative_mixin
class ORMArtifact:
    """
    SQLAlchemy model of artifacts.

    An artifact stores a JSON `data` payload plus optional JSON metadata, and
    may be associated with a flow run and/or task run by id. Those id columns
    are plain indexed UUIDs with no foreign-key constraint.
    """

    # Optional user-facing key; unique when set (see `__table_args__`).
    # NOTE(review): `index=True` plus the unique constraint likely creates two
    # indexes on this column; dropping one would need a migration.
    key = sa.Column(
        sa.String,
        nullable=True,
        index=True,
    )

    @declared_attr
    def task_run_id(cls):
        return sa.Column(
            UUID(),
            nullable=True,
            index=True,
        )

    @declared_attr
    def flow_run_id(cls):
        return sa.Column(
            UUID(),
            nullable=True,
            index=True,
        )

    type = sa.Column(sa.String)
    data = sa.Column(sa.JSON, nullable=True)
    # Suffixed with underscore as attribute name 'metadata' is reserved for the MetaData instance when using a declarative base class.
    metadata_ = sa.Column(sa.JSON, nullable=True)

    @declared_attr
    def __table_args__(cls):
        return (sa.UniqueConstraint("key"),)


@declarative_mixin
class ORMTaskRunStateCache:
    """
    SQLAlchemy model of a task run state cache.

    Each row maps a `cache_key` to the id of a task run state. `cache_expiration`
    is nullable; presumably NULL means the entry never expires (confirm against
    the caching query).
    """

    cache_key = sa.Column(sa.String, nullable=False)
    cache_expiration = sa.Column(
        Timestamp(),
        nullable=True,
    )
    # Plain UUID column; there is no ForeignKey constraint to the state table.
    task_run_state_id = sa.Column(UUID(), nullable=False)

    @declared_attr
    def __table_args__(cls):
        # Composite index on (cache_key, created DESC); `created` is expected to
        # come from the ORM base model.
        return (
            sa.Index(
                "ix_task_run_state_cache__cache_key_created_desc",
                "cache_key",
                sa.desc("created"),
            ),
        )


@declarative_mixin
class ORMRun:
    """
    Common columns and logic for FlowRun and TaskRun models.

    Besides the shared columns, this mixin exposes two hybrid properties that
    work both on a loaded instance (evaluated in Python) and inside a query
    (compiled to SQL): `estimated_run_time` and `estimated_start_time_delta`.
    """

    name = sa.Column(
        sa.String,
        default=lambda: generate_slug(2),
        nullable=False,
        index=True,
    )
    state_type = sa.Column(sa.Enum(schemas.states.StateType, name="state_type"))
    state_name = sa.Column(sa.String, nullable=True)
    state_timestamp = sa.Column(Timestamp(), nullable=True)
    run_count = sa.Column(sa.Integer, server_default="0", default=0, nullable=False)
    expected_start_time = sa.Column(Timestamp())
    next_scheduled_start_time = sa.Column(Timestamp())
    start_time = sa.Column(Timestamp())
    end_time = sa.Column(Timestamp())
    total_run_time = sa.Column(
        sa.Interval(),
        server_default="0",
        default=datetime.timedelta(0),
        nullable=False,
    )

    @hybrid_property
    def estimated_run_time(self):
        """Total run time is incremented in the database whenever a RUNNING
        state is exited. To give up-to-date estimates, we estimate incremental
        run time for any runs currently in a RUNNING state."""
        if self.state_type == schemas.states.StateType.RUNNING:
            # still running: add the time elapsed since entering RUNNING
            return self.total_run_time + (pendulum.now("UTC") - self.state_timestamp)
        return self.total_run_time

    @estimated_run_time.expression
    def estimated_run_time(cls):
        return (
            sa.select(
                sa.case(
                    (
                        cls.state_type == schemas.states.StateType.RUNNING,
                        interval_add(
                            cls.total_run_time,
                            date_diff(now(), cls.state_timestamp),
                        ),
                    ),
                    else_=cls.total_run_time,
                )
            )
            # add a correlate statement so this can reuse the `FROM` clause
            # of any parent query
            .correlate(cls).label("estimated_run_time")
        )

    @hybrid_property
    def estimated_start_time_delta(self) -> datetime.timedelta:
        """The delta to the expected start time (or "lateness") is computed as
        the difference between the actual start time and expected start time. To
        give up-to-date estimates, we estimate lateness for any runs that don't
        have a start time and are not in a final state and were expected to
        start already.

        Runs without an `expected_start_time` have no lateness and get a zero
        delta. This matches the SQL expression, where comparisons against NULL
        are never true."""
        # take a single snapshot of "now" so both comparison and delta agree
        current_time = pendulum.now("UTC")
        if (
            self.start_time is not None
            and self.expected_start_time is not None
            and self.start_time > self.expected_start_time
        ):
            return (self.start_time - self.expected_start_time).as_interval()
        elif (
            self.start_time is None
            and self.expected_start_time is not None
            and self.expected_start_time < current_time
            # NOTE(review): a NULL `state_type` passes this check in Python,
            # but the SQL `NOT IN` of this property's expression yields NULL
            # for it, so the two forms can disagree for state-less runs.
            and self.state_type not in schemas.states.TERMINAL_STATES
        ):
            return (current_time - self.expected_start_time).as_interval()
        else:
            return datetime.timedelta(0)

    @estimated_start_time_delta.expression
    def estimated_start_time_delta(cls):
        return sa.case(
            (
                cls.start_time > cls.expected_start_time,
                date_diff(cls.start_time, cls.expected_start_time),
            ),
            (
                sa.and_(
                    cls.start_time.is_(None),
                    cls.state_type.not_in(schemas.states.TERMINAL_STATES),
                    cls.expected_start_time < now(),
                ),
                date_diff(now(), cls.expected_start_time),
            ),
            else_=datetime.timedelta(0),
        )


@declarative_mixin
class ORMFlowRun(ORMRun):
    """SQLAlchemy model of a flow run.

    Extends the shared run columns with flow/deployment linkage, run
    configuration (parameters, context, policy, tags), infrastructure and work
    queue references, and a pointer to the current flow run state.
    """

    # Deleting the flow deletes its runs (database-level cascade).
    @declared_attr
    def flow_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("flow.id", ondelete="cascade"),
            nullable=False,
            index=True,
        )

    # Deleting the deployment keeps the run but clears this reference.
    @declared_attr
    def deployment_id(cls):
        return sa.Column(
            UUID(), sa.ForeignKey("deployment.id", ondelete="set null"), index=True
        )

    work_queue_name = sa.Column(sa.String, index=True)
    flow_version = sa.Column(sa.String, index=True)
    parameters = sa.Column(JSON, server_default="{}", default=dict, nullable=False)
    # Unique per flow when set (see `uq_flow_run__flow_id_idempotency_key`).
    idempotency_key = sa.Column(sa.String)
    context = sa.Column(JSON, server_default="{}", default=dict, nullable=False)
    empirical_policy = sa.Column(
        Pydantic(schemas.core.FlowRunPolicy),
        server_default="{}",
        default=schemas.core.FlowRunPolicy,
        nullable=False,
    )
    tags = sa.Column(JSON, server_default="[]", default=list, nullable=False)
    created_by = sa.Column(
        Pydantic(schemas.core.CreatedBy),
        server_default=None,
        default=None,
        nullable=True,
    )

    infrastructure_pid = sa.Column(sa.String)

    # `UUID` is passed as a class here (other columns use `UUID()`); SQLAlchemy
    # instantiates type classes, so both forms behave the same.
    # NOTE: ondelete="CASCADE" means deleting the referenced block document
    # deletes this flow run row at the database level.
    @declared_attr
    def infrastructure_document_id(cls):
        return sa.Column(
            UUID,
            sa.ForeignKey("block_document.id", ondelete="CASCADE"),
            nullable=True,
            index=True,
        )

    # Task run that spawned this flow run as a subflow, if any.
    # use_alter=True: flow_run and task_run reference each other, so this FK is
    # emitted separately (ALTER TABLE) to break the table-creation cycle.
    @declared_attr
    def parent_task_run_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey(
                "task_run.id",
                ondelete="SET NULL",
                use_alter=True,
            ),
            index=True,
        )

    auto_scheduled = sa.Column(
        sa.Boolean, server_default="0", default=False, nullable=False
    )

    # Pointer to the current state row. use_alter=True because flow_run_state
    # also has a foreign key back to flow_run.
    # TODO remove this foreign key for significant delete performance gains
    @declared_attr
    def state_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey(
                "flow_run_state.id",
                ondelete="SET NULL",
                use_alter=True,
            ),
            index=True,
        )

    @declared_attr
    def work_queue_id(cls):
        return sa.Column(
            UUID,
            sa.ForeignKey("work_queue.id", ondelete="SET NULL"),
            nullable=True,
            index=True,
        )

    # -------------------------- relationships

    # current states are eagerly loaded unless otherwise specified
    @declared_attr
    def _state(cls):
        return sa.orm.relationship(
            "FlowRunState",
            lazy="joined",
            foreign_keys=[cls.state_id],
            primaryjoin="FlowRunState.id==%s.state_id" % cls.__name__,
        )

    @hybrid_property
    def state(self):
        """The current flow run state (the row referenced by `state_id`)."""
        return self._state

    @state.setter
    def state(self, value):
        # because this is a slightly non-standard SQLAlchemy relationship, we
        # prefer an explicit setter method to a setter property, because
        # user expectations about SQLAlchemy attribute assignment might not be
        # met, namely that an unrelated (from SQLAlchemy's perspective) field of
        # the provided state is also modified. However, property assignment
        # still works because the ORM model's __init__ depends on it.
        return self.set_state(value)

    def set_state(self, state):
        """
        If a state is assigned to this run, populate its run id.

        This would normally be handled by the back-populated SQLAlchemy
        relationship, but because this is a one-to-one pointer to a
        one-to-many relationship, SQLAlchemy can't figure it out.

        Args:
            state: the new current state row, or None to clear it
        """
        if state is not None:
            state.flow_run_id = self.id
        self._state = state

    @declared_attr
    def flow(cls):
        return sa.orm.relationship("Flow", back_populates="flow_runs", lazy="raise")

    # The join columns are given explicitly by `primaryjoin`; the commented-out
    # `foreign_keys` argument below is unused.
    @declared_attr
    def task_runs(cls):
        return sa.orm.relationship(
            "TaskRun",
            back_populates="flow_run",
            lazy="raise",
            # foreign_keys=lambda: [cls.flow_run_id],
            primaryjoin="TaskRun.flow_run_id==%s.id" % cls.__name__,
        )

    @declared_attr
    def parent_task_run(cls):
        return sa.orm.relationship(
            "TaskRun",
            back_populates="subflow_run",
            lazy="raise",
            foreign_keys=lambda: [cls.parent_task_run_id],
        )

    @declared_attr
    def work_queue(cls):
        return sa.orm.relationship(
            "WorkQueue",
            lazy="joined",
            foreign_keys=[cls.work_queue_id],
        )

    @declared_attr
    def __table_args__(cls):
        # Unique (flow_id, idempotency_key) plus ordering/lookup indexes on the
        # time and state columns, including both sort directions of
        # COALESCE(start_time, expected_start_time).
        return (
            sa.Index(
                "uq_flow_run__flow_id_idempotency_key",
                "flow_id",
                "idempotency_key",
                unique=True,
            ),
            sa.Index(
                "ix_flow_run__coalesce_start_time_expected_start_time_desc",
                sa.desc(coalesce("start_time", "expected_start_time")),
            ),
            sa.Index(
                "ix_flow_run__coalesce_start_time_expected_start_time_asc",
                sa.asc(coalesce("start_time", "expected_start_time")),
            ),
            sa.Index(
                "ix_flow_run__expected_start_time_desc",
                sa.desc("expected_start_time"),
            ),
            sa.Index(
                "ix_flow_run__next_scheduled_start_time_asc",
                sa.asc("next_scheduled_start_time"),
            ),
            sa.Index(
                "ix_flow_run__end_time_desc",
                sa.desc("end_time"),
            ),
            sa.Index(
                "ix_flow_run__start_time",
                "start_time",
            ),
            sa.Index(
                "ix_flow_run__state_type",
                "state_type",
            ),
            sa.Index(
                "ix_flow_run__state_name",
                "state_name",
            ),
            sa.Index(
                "ix_flow_run__state_timestamp",
                "state_timestamp",
            ),
        )


@declarative_mixin
class ORMTaskRun(ORMRun):
    """SQLAlchemy model of a task run.

    Extends the shared run columns with the owning flow run, task identity
    (`task_key` + `dynamic_key`), caching fields, retry policy, task inputs,
    and a pointer to the current task run state.
    """

    # Deleting the flow run deletes its task runs (database-level cascade).
    @declared_attr
    def flow_run_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("flow_run.id", ondelete="cascade"),
            nullable=False,
            index=True,
        )

    # (flow_run_id, task_key, dynamic_key) is unique; see `__table_args__`.
    task_key = sa.Column(sa.String, nullable=False)
    dynamic_key = sa.Column(sa.String, nullable=False)
    cache_key = sa.Column(sa.String)
    cache_expiration = sa.Column(Timestamp())
    task_version = sa.Column(sa.String)
    flow_run_run_count = sa.Column(
        sa.Integer, server_default="0", default=0, nullable=False
    )
    empirical_policy = sa.Column(
        Pydantic(schemas.core.TaskRunPolicy),
        server_default="{}",
        default=schemas.core.TaskRunPolicy,
        nullable=False,
    )
    # Maps each input name to a list of its upstream sources: task run results,
    # parameters, or constants.
    task_inputs = sa.Column(
        Pydantic(
            Dict[
                str,
                List[
                    Union[
                        schemas.core.TaskRunResult,
                        schemas.core.Parameter,
                        schemas.core.Constant,
                    ]
                ],
            ]
        ),
        server_default="{}",
        default=dict,
        nullable=False,
    )
    tags = sa.Column(JSON, server_default="[]", default=list, nullable=False)

    # Pointer to the current state row. use_alter=True because task_run_state
    # also has a foreign key back to task_run.
    # TODO remove this foreign key for significant delete performance gains
    @declared_attr
    def state_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey(
                "task_run_state.id",
                ondelete="SET NULL",
                use_alter=True,
            ),
            index=True,
        )

    # -------------------------- relationships

    # current states are eagerly loaded unless otherwise specified
    @declared_attr
    def _state(cls):
        return sa.orm.relationship(
            "TaskRunState",
            lazy="joined",
            foreign_keys=[cls.state_id],
            primaryjoin="TaskRunState.id==%s.state_id" % cls.__name__,
        )

    @hybrid_property
    def state(self):
        """The current task run state (the row referenced by `state_id`)."""
        return self._state

    @state.setter
    def state(self, value):
        # because this is a slightly non-standard SQLAlchemy relationship, we
        # prefer an explicit setter method to a setter property, because
        # user expectations about SQLAlchemy attribute assignment might not be
        # met, namely that an unrelated (from SQLAlchemy's perspective) field of
        # the provided state is also modified. However, property assignment
        # still works because the ORM model's __init__ depends on it.
        return self.set_state(value)

    def set_state(self, state):
        """
        If a state is assigned to this run, populate its run id.

        This would normally be handled by the back-populated SQLAlchemy
        relationship, but because this is a one-to-one pointer to a
        one-to-many relationship, SQLAlchemy can't figure it out.

        Args:
            state: the new current state row, or None to clear it
        """
        if state is not None:
            state.task_run_id = self.id
        self._state = state

    @declared_attr
    def flow_run(cls):
        return sa.orm.relationship(
            "FlowRun",
            back_populates="task_runs",
            lazy="raise",
            foreign_keys=[cls.flow_run_id],
        )

    # The flow run (if any) whose `parent_task_run_id` points at this task run.
    # uselist=False makes this a scalar rather than a collection; the
    # commented-out `foreign_keys` argument below is unused.
    @declared_attr
    def subflow_run(cls):
        return sa.orm.relationship(
            "FlowRun",
            back_populates="parent_task_run",
            lazy="raise",
            # foreign_keys=["FlowRun.parent_task_run_id"],
            primaryjoin="FlowRun.parent_task_run_id==%s.id" % cls.__name__,
            uselist=False,
        )

    @declared_attr
    def __table_args__(cls):
        # Unique task identity within a flow run, plus lookup/ordering indexes
        # on the time and state columns.
        return (
            sa.Index(
                "uq_task_run__flow_run_id_task_key_dynamic_key",
                "flow_run_id",
                "task_key",
                "dynamic_key",
                unique=True,
            ),
            sa.Index(
                "ix_task_run__expected_start_time_desc",
                sa.desc("expected_start_time"),
            ),
            sa.Index(
                "ix_task_run__next_scheduled_start_time_asc",
                sa.asc("next_scheduled_start_time"),
            ),
            sa.Index(
                "ix_task_run__end_time_desc",
                sa.desc("end_time"),
            ),
            sa.Index(
                "ix_task_run__start_time",
                "start_time",
            ),
            sa.Index(
                "ix_task_run__state_type",
                "state_type",
            ),
            sa.Index(
                "ix_task_run__state_name",
                "state_name",
            ),
            sa.Index(
                "ix_task_run__state_timestamp",
                "state_timestamp",
            ),
        )


@declarative_mixin
class ORMDeployment:
    """SQLAlchemy model of a deployment.

    A deployment belongs to a flow (deployment names are unique per flow) and
    carries scheduling, parameters, infrastructure/storage block references and
    an optional work queue.
    """

    name = sa.Column(sa.String, nullable=False)
    version = sa.Column(sa.String, nullable=True)
    description = sa.Column(sa.Text(), nullable=True)
    manifest_path = sa.Column(sa.String, nullable=True)
    work_queue_name = sa.Column(sa.String, nullable=True, index=True)
    infra_overrides = sa.Column(JSON, server_default="{}", default=dict, nullable=False)
    path = sa.Column(sa.String, nullable=True)
    entrypoint = sa.Column(sa.String, nullable=True)

    # Deleting the flow deletes its deployments (database-level cascade).
    @declared_attr
    def flow_id(cls):
        return sa.Column(
            UUID,
            sa.ForeignKey("flow.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        )

    @declared_attr
    def work_queue_id(cls):
        return sa.Column(
            UUID,
            sa.ForeignKey("work_queue.id", ondelete="SET NULL"),
            nullable=True,
            index=True,
        )

    # One of the schedule schemas in `schemas.schedules.SCHEDULE_TYPES`; nullable.
    schedule = sa.Column(Pydantic(schemas.schedules.SCHEDULE_TYPES))
    is_schedule_active = sa.Column(
        sa.Boolean, nullable=False, server_default="1", default=True
    )
    tags = sa.Column(JSON, server_default="[]", default=list, nullable=False)
    parameters = sa.Column(JSON, server_default="{}", default=dict, nullable=False)
    parameter_openapi_schema = sa.Column(JSON, default=dict, nullable=True)
    created_by = sa.Column(
        Pydantic(schemas.core.CreatedBy),
        server_default=None,
        default=None,
        nullable=True,
    )
    updated_by = sa.Column(
        Pydantic(schemas.core.UpdatedBy),
        server_default=None,
        default=None,
        nullable=True,
    )

    # NOTE: ondelete="CASCADE" on both block-document references means deleting
    # the referenced block document deletes the deployment row. Neither column
    # is indexed (index=False).
    @declared_attr
    def infrastructure_document_id(cls):
        return sa.Column(
            UUID,
            sa.ForeignKey("block_document.id", ondelete="CASCADE"),
            nullable=True,
            index=False,
        )

    @declared_attr
    def storage_document_id(cls):
        return sa.Column(
            UUID,
            sa.ForeignKey("block_document.id", ondelete="CASCADE"),
            nullable=True,
            index=False,
        )

    @declared_attr
    def flow(cls):
        return sa.orm.relationship("Flow", back_populates="deployments", lazy="raise")

    @declared_attr
    def work_queue(cls):
        return sa.orm.relationship(
            "WorkQueue", lazy="joined", foreign_keys=[cls.work_queue_id]
        )

    @declared_attr
    def __table_args__(cls):
        # Deployment names are unique within a flow; `created` is indexed too.
        return (
            sa.Index(
                "uq_deployment__flow_id_name",
                "flow_id",
                "name",
                unique=True,
            ),
            sa.Index(
                "ix_deployment__created",
                "created",
            ),
        )


@declarative_mixin
class ORMLog:
    """
    SQLAlchemy model of a logging statement.

    `flow_run_id` and `task_run_id` are plain indexed UUID columns with no
    foreign-key constraint, so the database does not remove log rows when the
    referenced runs are deleted.
    """

    name = sa.Column(sa.String, nullable=False)
    # Numeric log level; presumably the standard `logging` level numbers.
    level = sa.Column(sa.SmallInteger, nullable=False, index=True)
    flow_run_id = sa.Column(UUID(), nullable=False, index=True)
    task_run_id = sa.Column(UUID(), nullable=True, index=True)
    message = sa.Column(sa.Text, nullable=False)

    # The client-side timestamp of this logged statement.
    timestamp = sa.Column(Timestamp(), nullable=False, index=True)


@declarative_mixin
class ORMConcurrencyLimit:
    """SQLAlchemy mixin of a tag-based concurrency limit.

    Each tag has at most one limit (unique index on `tag`).
    """

    tag = sa.Column(sa.String, nullable=False)
    concurrency_limit = sa.Column(sa.Integer, nullable=False)
    # JSON list, empty by default; presumably the ids of runs currently holding
    # a slot (confirm against the concurrency-limit models).
    active_slots = sa.Column(JSON, server_default="[]", default=list, nullable=False)

    @declared_attr
    def __table_args__(cls):
        return (sa.Index("uq_concurrency_limit__tag", "tag", unique=True),)


@declarative_mixin
class ORMBlockType:
    """SQLAlchemy mixin of a block type.

    Block types are identified by a unique `slug` and carry display metadata
    (logo, documentation URL, description, code example).
    """

    name = sa.Column(sa.String, nullable=False)
    slug = sa.Column(sa.String, nullable=False)
    logo_url = sa.Column(sa.String, nullable=True)
    documentation_url = sa.Column(sa.String, nullable=True)
    description = sa.Column(sa.String, nullable=True)
    code_example = sa.Column(sa.String, nullable=True)
    is_protected = sa.Column(
        sa.Boolean, nullable=False, server_default="0", default=False
    )

    @declared_attr
    def __table_args__(cls):
        return (
            sa.Index(
                "uq_block_type__slug",
                "slug",
                unique=True,
            ),
        )


@declarative_mixin
class ORMBlockSchema:
    """SQLAlchemy mixin of a block schema.

    A schema belongs to a block type and is unique per (checksum, version).
    """

    checksum = sa.Column(sa.String, nullable=False)
    fields = sa.Column(JSON, server_default="{}", default=dict, nullable=False)
    capabilities = sa.Column(JSON, server_default="[]", default=list, nullable=False)
    # Only a server-side default is set; with eager_defaults on the base model
    # the value is fetched after INSERT.
    version = sa.Column(
        sa.String,
        server_default=schemas.core.DEFAULT_BLOCK_SCHEMA_VERSION,
        nullable=False,
    )

    # Deleting the block type deletes its schemas (database-level cascade).
    @declared_attr
    def block_type_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("block_type.id", ondelete="cascade"),
            nullable=False,
            index=True,
        )

    @declared_attr
    def block_type(cls):
        return sa.orm.relationship("BlockType", lazy="joined")

    @declared_attr
    def __table_args__(cls):
        return (
            sa.Index(
                "uq_block_schema__checksum_version",
                "checksum",
                "version",
                unique=True,
            ),
            sa.Index("ix_block_schema__created", "created"),
        )


@declarative_mixin
class ORMBlockSchemaReference:
    """SQLAlchemy mixin of a named link from a parent block schema to a
    referenced (nested) block schema.

    Deleting either schema deletes the reference row (database-level cascade).
    """

    name = sa.Column(sa.String, nullable=False)

    @declared_attr
    def parent_block_schema_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("block_schema.id", ondelete="cascade"),
            nullable=False,
        )

    @declared_attr
    def reference_block_schema_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("block_schema.id", ondelete="cascade"),
            nullable=False,
        )


@declarative_mixin
class ORMBlockDocument:
    """SQLAlchemy mixin of a block document.

    A block document is an instance of a block type/schema whose `data` is
    stored encrypted (see `encrypt_data` / `decrypt_data`). Names are unique
    per block type.
    """

    name = sa.Column(sa.String, nullable=False, index=True)
    # Stored in encrypted form when written through `encrypt_data`.
    data = sa.Column(JSON, server_default="{}", default=dict, nullable=False)
    # Only a server-side default is set for this flag (no client `default`).
    is_anonymous = sa.Column(sa.Boolean, server_default="0", index=True, nullable=False)

    @declared_attr
    def block_type_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("block_type.id", ondelete="cascade"),
            nullable=False,
        )

    @declared_attr
    def block_type(cls):
        return sa.orm.relationship("BlockType", lazy="joined")

    @declared_attr
    def block_schema_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("block_schema.id", ondelete="cascade"),
            nullable=False,
        )

    @declared_attr
    def block_schema(cls):
        return sa.orm.relationship("BlockSchema", lazy="joined")

    @declared_attr
    def __table_args__(cls):
        return (
            sa.Index(
                "uq_block__type_id_name",
                "block_type_id",
                "name",
                unique=True,
            ),
        )

    async def encrypt_data(self, session, data):
        """
        Store encrypted data on the ORM model

        Note: will only succeed if the caller has sufficient permission.

        Args:
            session: database session passed through to `encrypt_fernet`
            data: the plaintext payload to encrypt and assign to `self.data`
        """
        self.data = await encrypt_fernet(session, data)

    async def decrypt_data(self, session):
        """
        Retrieve decrypted data from the ORM model.

        Note: will only succeed if the caller has sufficient permission.

        Args:
            session: database session passed through to `decrypt_fernet`

        Returns:
            the result of `decrypt_fernet` applied to `self.data`
        """
        return await decrypt_fernet(session, self.data)


@declarative_mixin
class ORMBlockDocumentReference:
    """SQLAlchemy mixin of a named link from a parent block document to a
    referenced (nested) block document.

    Deleting either document deletes the reference row (database-level cascade).
    """

    name = sa.Column(sa.String, nullable=False)

    @declared_attr
    def parent_block_document_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("block_document.id", ondelete="cascade"),
            nullable=False,
        )

    @declared_attr
    def reference_block_document_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("block_document.id", ondelete="cascade"),
            nullable=False,
        )


@declarative_mixin
class ORMConfiguration:
    """SQLAlchemy model of a configuration entry: a unique string key with a JSON value."""

    key = sa.Column(sa.String, index=True, nullable=False)
    value = sa.Column(JSON, nullable=False)

    @declared_attr
    def __table_args__(cls):
        # NOTE(review): `key` is also declared with index=True, so this unique
        # constraint likely overlaps that index; presumably kept to match the
        # existing migrations — confirm before changing.
        unique_key = sa.UniqueConstraint("key")
        return (unique_key,)


@declarative_mixin
class ORMSavedSearch:
    """SQLAlchemy model of a saved search, identified by a unique name."""

    name = sa.Column(sa.String, nullable=False)
    # JSON array of filters; new rows start with an empty list
    filters = sa.Column(JSON, nullable=False, default=list, server_default="[]")

    @declared_attr
    def __table_args__(cls):
        unique_name = sa.UniqueConstraint("name")
        return (unique_name,)


@declarative_mixin
class ORMWorkQueue:
    """SQLAlchemy model of a work queue belonging to a work pool"""

    name = sa.Column(sa.String, nullable=False)

    # optional Pydantic-serialized queue filter; defaults to None
    filter = sa.Column(
        Pydantic(schemas.core.QueueFilter),
        nullable=True,
        default=None,
        server_default=None,
    )
    description = sa.Column(sa.String, nullable=False, server_default="", default="")
    is_paused = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0")
    concurrency_limit = sa.Column(sa.Integer, nullable=True)
    priority = sa.Column(sa.Integer, nullable=False, index=True)
    last_polled = sa.Column(Timestamp(), nullable=True)

    @declared_attr
    def __table_args__(cls):
        # queue names are unique within a single work pool
        unique_name_per_pool = sa.UniqueConstraint("work_pool_id", "name")
        return (unique_name_per_pool,)

    @declared_attr
    def work_pool_id(cls):
        work_pool_fk = sa.ForeignKey("work_pool.id", ondelete="cascade")
        return sa.Column(UUID, work_pool_fk, index=True, nullable=False)

    @declared_attr
    def work_pool(cls):
        # must stay declared after `work_pool_id`, which it references
        return sa.orm.relationship(
            "WorkPool", foreign_keys=[cls.work_pool_id], lazy="joined"
        )


@declarative_mixin
class ORMWorkPool:
    """SQLAlchemy model of a work pool"""

    name = sa.Column(sa.String, nullable=False)
    description = sa.Column(sa.String)
    type = sa.Column(sa.String)
    # Use the `dict` factory rather than a literal `{}`: a literal is created once,
    # when the class body runs, and would be reused as the default value for every
    # inserted row. The factory gives each row its own fresh dict, consistent with
    # the other JSON columns in this module.
    base_job_template = sa.Column(
        JSON, nullable=False, server_default="{}", default=dict
    )
    is_paused = sa.Column(sa.Boolean, nullable=False, server_default="0", default=False)
    # NOTE(review): declared without a ForeignKey; presumably refers to a
    # work_queue id — confirm.
    default_queue_id = sa.Column(UUID, nullable=True)
    concurrency_limit = sa.Column(
        sa.Integer,
        nullable=True,
    )

    @declared_attr
    def __table_args__(cls):
        # work pool names are globally unique
        return (sa.UniqueConstraint("name"),)


@declarative_mixin
class ORMWorker:
    """SQLAlchemy model of a worker registered with a work pool"""

    @declared_attr
    def work_pool_id(cls):
        work_pool_fk = sa.ForeignKey("work_pool.id", ondelete="cascade")
        return sa.Column(UUID, work_pool_fk, index=True, nullable=False)

    name = sa.Column(sa.String, nullable=False)
    # defaults to "now": UTC via pendulum on the Python side, `now()` on the server
    last_heartbeat_time = sa.Column(
        Timestamp(),
        default=lambda: pendulum.now("UTC"),
        server_default=now(),
        nullable=False,
        index=True,
    )

    @declared_attr
    def __table_args__(cls):
        # worker names are unique within a single work pool
        unique_name_per_pool = sa.UniqueConstraint("work_pool_id", "name")
        return (unique_name_per_pool,)


@declarative_mixin
class ORMAgent:
    """SQLAlchemy model of an agent attached to a work queue"""

    name = sa.Column(sa.String, nullable=False)

    @declared_attr
    def work_queue_id(cls):
        # unlike most foreign keys in this module, no ondelete rule is declared
        work_queue_fk = sa.ForeignKey("work_queue.id")
        return sa.Column(UUID, work_queue_fk, index=True, nullable=False)

    # defaults to "now": UTC via pendulum on the Python side, `now()` on the server
    last_activity_time = sa.Column(
        Timestamp(),
        default=lambda: pendulum.now("UTC"),
        server_default=now(),
        nullable=False,
    )

    @declared_attr
    def __table_args__(cls):
        unique_name = sa.UniqueConstraint("name")
        return (unique_name,)


@declarative_mixin
class ORMFlowRunNotificationPolicy:
    """
    SQLAlchemy model of a flow run notification policy.

    It holds an active flag, the state names and tags it matches (JSON lists),
    an optional message template, and a reference to a block document
    (presumably the notification destination — confirm).
    """

    is_active = sa.Column(sa.Boolean, server_default="1", default=True, nullable=False)
    # Use the `list` factory rather than a literal `[]`: a literal list is created
    # once, when the class body runs, and would be reused as the default value for
    # every inserted row. The factory gives each row its own fresh list.
    state_names = sa.Column(JSON, server_default="[]", default=list, nullable=False)
    tags = sa.Column(JSON, server_default="[]", default=list, nullable=False)
    message_template = sa.Column(sa.String, nullable=True)

    @declared_attr
    def block_document_id(cls):
        return sa.Column(
            UUID(),
            sa.ForeignKey("block_document.id", ondelete="cascade"),
            nullable=False,
        )

    @declared_attr
    def block_document(cls):
        # eagerly joined; must stay declared after `block_document_id`, which it references
        return sa.orm.relationship(
            "BlockDocument",
            lazy="joined",
            foreign_keys=[cls.block_document_id],
        )


@declarative_mixin
class ORMFlowRunNotificationQueue:
    """SQLAlchemy model of a queued flow run notification."""

    # Both ids logically point at rows in other tables, but no foreign key
    # constraint is declared on purpose: this table only feeds service workers,
    # and an entry whose ids no longer match when it is pulled can be discarded.
    flow_run_notification_policy_id = sa.Column(UUID, nullable=False)
    flow_run_state_id = sa.Column(UUID, nullable=False)


class BaseORMConfiguration(ABC):
    """
    Abstract base class used to inject database-specific ORM configuration into Orion.

    Modifications to core Orion data structures can have unintended consequences.
    Use with caution.

    Args:
        base_metadata: sqlalchemy.schema.MetaData used to create the Base orm class;
            when omitted, a MetaData with Orion's naming conventions is created
        base_model_mixins: a list of mixins to add to the Base orm model
        flow_mixin: flow orm mixin, combined with Base orm class
        flow_run_mixin: flow run orm mixin, combined with Base orm class
        flow_run_state_mixin: flow run state orm mixin, combined with Base orm class
        task_run_mixin: task run orm mixin, combined with Base orm class
        task_run_state_mixin: task run state orm mixin, combined with Base orm class
        artifact_mixin: artifact orm mixin, combined with Base orm class
        task_run_state_cache_mixin: task run state cache orm mixin, combined with Base orm class
        deployment_mixin: deployment orm mixin, combined with Base orm class
        saved_search_mixin: saved search orm mixin, combined with Base orm class
        log_mixin: log orm mixin, combined with Base orm class
        concurrency_limit_mixin: concurrency limit orm mixin, combined with Base orm class
        work_pool_mixin: work pool orm mixin, combined with Base orm class
        worker_mixin: worker orm mixin, combined with Base orm class
        block_type_mixin: block_type orm mixin, combined with Base orm class
        block_schema_mixin: block_schema orm mixin, combined with Base orm class
        block_schema_reference_mixin: block_schema_reference orm mixin, combined with Base orm class
        block_document_mixin: block_document orm mixin, combined with Base orm class
        block_document_reference_mixin: block_document_reference orm mixin, combined with Base orm class
        work_queue_mixin: work queue orm mixin, combined with Base orm class
        agent_mixin: agent orm mixin, combined with Base orm class
        configuration_mixin: configuration orm mixin, combined with Base orm class
        flow_run_notification_policy_mixin: flow run notification policy orm mixin,
            combined with Base orm class
        flow_run_notification_queue_mixin: flow run notification queue orm mixin,
            combined with Base orm class

    """

    def __init__(
        self,
        base_metadata: Union[sa.schema.MetaData, None] = None,
        base_model_mixins: Union[List, None] = None,
        flow_mixin=ORMFlow,
        flow_run_mixin=ORMFlowRun,
        flow_run_state_mixin=ORMFlowRunState,
        task_run_mixin=ORMTaskRun,
        task_run_state_mixin=ORMTaskRunState,
        artifact_mixin=ORMArtifact,
        task_run_state_cache_mixin=ORMTaskRunStateCache,
        deployment_mixin=ORMDeployment,
        saved_search_mixin=ORMSavedSearch,
        log_mixin=ORMLog,
        concurrency_limit_mixin=ORMConcurrencyLimit,
        work_pool_mixin=ORMWorkPool,
        worker_mixin=ORMWorker,
        block_type_mixin=ORMBlockType,
        block_schema_mixin=ORMBlockSchema,
        block_schema_reference_mixin=ORMBlockSchemaReference,
        block_document_mixin=ORMBlockDocument,
        block_document_reference_mixin=ORMBlockDocumentReference,
        work_queue_mixin=ORMWorkQueue,
        agent_mixin=ORMAgent,
        configuration_mixin=ORMConfiguration,
        # appended last so existing positional callers are unaffected; previously
        # these mixins could not be customized through the constructor at all
        flow_run_notification_policy_mixin=ORMFlowRunNotificationPolicy,
        flow_run_notification_queue_mixin=ORMFlowRunNotificationQueue,
    ):
        if base_metadata is None:
            base_metadata = sa.schema.MetaData(
                # define naming conventions for our Base class to use
                # sqlalchemy will use the following templated strings
                # to generate the names of indices, constraints, and keys
                #
                # we offset the table name with two underscores (__) to
                # help differentiate, for example, between "flow_run.state_type"
                # and "flow_run_state.type".
                #
                # more information on this templating and available
                # customization can be found here
                # https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.MetaData
                #
                # this also allows us to avoid having to specify names explicitly
                # when using sa.ForeignKey.use_alter = True
                # https://docs.sqlalchemy.org/en/14/core/constraints.html
                naming_convention={
                    "ix": "ix_%(table_name)s__%(column_0_N_name)s",
                    "uq": "uq_%(table_name)s__%(column_0_N_name)s",
                    "ck": "ck_%(table_name)s__%(constraint_name)s",
                    "fk": "fk_%(table_name)s__%(column_0_N_name)s__%(referred_table_name)s",
                    "pk": "pk_%(table_name)s",
                }
            )
        self.base_metadata = base_metadata
        self.base_model_mixins = base_model_mixins or []

        self._create_base_model()
        self._create_orm_models(
            flow_mixin=flow_mixin,
            flow_run_mixin=flow_run_mixin,
            flow_run_state_mixin=flow_run_state_mixin,
            task_run_mixin=task_run_mixin,
            task_run_state_mixin=task_run_state_mixin,
            artifact_mixin=artifact_mixin,
            task_run_state_cache_mixin=task_run_state_cache_mixin,
            deployment_mixin=deployment_mixin,
            saved_search_mixin=saved_search_mixin,
            log_mixin=log_mixin,
            concurrency_limit_mixin=concurrency_limit_mixin,
            work_pool_mixin=work_pool_mixin,
            worker_mixin=worker_mixin,
            work_queue_mixin=work_queue_mixin,
            agent_mixin=agent_mixin,
            block_type_mixin=block_type_mixin,
            block_schema_mixin=block_schema_mixin,
            block_schema_reference_mixin=block_schema_reference_mixin,
            block_document_mixin=block_document_mixin,
            block_document_reference_mixin=block_document_reference_mixin,
            flow_run_notification_policy_mixin=flow_run_notification_policy_mixin,
            flow_run_notification_queue_mixin=flow_run_notification_queue_mixin,
            configuration_mixin=configuration_mixin,
        )

    def _unique_key(self) -> Tuple[Hashable, ...]:
        """
        Returns a key used to determine whether to instantiate a new DB interface.

        The base model mixins are converted to a tuple so the key is hashable.
        NOTE(review): the per-model mixins are not part of this key — confirm
        that is intended.
        """
        return (self.__class__, self.base_metadata, tuple(self.base_model_mixins))

    def _create_base_model(self):
        """
        Defines the base ORM model and binds it to `self`. The base model will be
        extended by mixins specified in the database configuration. This method only
        runs on instantiation.
        """

        @as_declarative(metadata=self.base_metadata)
        class Base(*self.base_model_mixins, ORMBase):
            pass

        self.Base = Base

    def _create_orm_models(
        self,
        flow_mixin=ORMFlow,
        flow_run_mixin=ORMFlowRun,
        flow_run_state_mixin=ORMFlowRunState,
        task_run_mixin=ORMTaskRun,
        task_run_state_mixin=ORMTaskRunState,
        artifact_mixin=ORMArtifact,
        task_run_state_cache_mixin=ORMTaskRunStateCache,
        deployment_mixin=ORMDeployment,
        saved_search_mixin=ORMSavedSearch,
        log_mixin=ORMLog,
        concurrency_limit_mixin=ORMConcurrencyLimit,
        work_pool_mixin=ORMWorkPool,
        worker_mixin=ORMWorker,
        block_type_mixin=ORMBlockType,
        block_schema_mixin=ORMBlockSchema,
        block_schema_reference_mixin=ORMBlockSchemaReference,
        block_document_mixin=ORMBlockDocument,
        block_document_reference_mixin=ORMBlockDocumentReference,
        flow_run_notification_policy_mixin=ORMFlowRunNotificationPolicy,
        flow_run_notification_queue_mixin=ORMFlowRunNotificationQueue,
        work_queue_mixin=ORMWorkQueue,
        agent_mixin=ORMAgent,
        configuration_mixin=ORMConfiguration,
    ):
        """
        Defines the ORM models used in Orion and binds them to the `self`. This method
        only runs on instantiation.

        Each model is a class combining its mixin with `self.Base`, so it is
        registered on `self.base_metadata`.
        """

        class Flow(flow_mixin, self.Base):
            pass

        class FlowRunState(flow_run_state_mixin, self.Base):
            pass

        class TaskRunState(task_run_state_mixin, self.Base):
            pass

        class Artifact(artifact_mixin, self.Base):
            pass

        class TaskRunStateCache(task_run_state_cache_mixin, self.Base):
            pass

        class FlowRun(flow_run_mixin, self.Base):
            pass

        class TaskRun(task_run_mixin, self.Base):
            pass

        class Deployment(deployment_mixin, self.Base):
            pass

        class SavedSearch(saved_search_mixin, self.Base):
            pass

        class Log(log_mixin, self.Base):
            pass

        class ConcurrencyLimit(concurrency_limit_mixin, self.Base):
            pass

        class WorkPool(work_pool_mixin, self.Base):
            pass

        class Worker(worker_mixin, self.Base):
            pass

        class WorkQueue(work_queue_mixin, self.Base):
            pass

        class Agent(agent_mixin, self.Base):
            pass

        class BlockType(block_type_mixin, self.Base):
            pass

        class BlockSchema(block_schema_mixin, self.Base):
            pass

        class BlockSchemaReference(block_schema_reference_mixin, self.Base):
            pass

        class BlockDocument(block_document_mixin, self.Base):
            pass

        class BlockDocumentReference(block_document_reference_mixin, self.Base):
            pass

        class FlowRunNotificationPolicy(flow_run_notification_policy_mixin, self.Base):
            pass

        class FlowRunNotificationQueue(flow_run_notification_queue_mixin, self.Base):
            pass

        class Configuration(configuration_mixin, self.Base):
            pass

        self.Flow = Flow
        self.FlowRunState = FlowRunState
        self.TaskRunState = TaskRunState
        self.Artifact = Artifact
        self.TaskRunStateCache = TaskRunStateCache
        self.FlowRun = FlowRun
        self.TaskRun = TaskRun
        self.Deployment = Deployment
        self.SavedSearch = SavedSearch
        self.Log = Log
        self.ConcurrencyLimit = ConcurrencyLimit
        self.WorkPool = WorkPool
        self.Worker = Worker
        self.WorkQueue = WorkQueue
        self.Agent = Agent
        self.BlockType = BlockType
        self.BlockSchema = BlockSchema
        self.BlockSchemaReference = BlockSchemaReference
        self.BlockDocument = BlockDocument
        self.BlockDocumentReference = BlockDocumentReference
        self.FlowRunNotificationPolicy = FlowRunNotificationPolicy
        self.FlowRunNotificationQueue = FlowRunNotificationQueue
        self.Configuration = Configuration

    @property
    @abstractmethod
    def versions_dir(self):
        """Directory containing migrations"""
        ...

    @property
    def deployment_unique_upsert_columns(self):
        """Unique columns for upserting a Deployment"""
        return [self.Deployment.flow_id, self.Deployment.name]

    @property
    def concurrency_limit_unique_upsert_columns(self):
        """Unique columns for upserting a ConcurrencyLimit"""
        return [self.ConcurrencyLimit.tag]

    @property
    def flow_run_unique_upsert_columns(self):
        """Unique columns for upserting a FlowRun"""
        return [self.FlowRun.flow_id, self.FlowRun.idempotency_key]

    @property
    def block_type_unique_upsert_columns(self):
        """Unique columns for upserting a BlockType"""
        return [self.BlockType.slug]

    @property
    def block_schema_unique_upsert_columns(self):
        """Unique columns for upserting a BlockSchema"""
        return [self.BlockSchema.checksum, self.BlockSchema.version]

    @property
    def flow_unique_upsert_columns(self):
        """Unique columns for upserting a Flow"""
        return [self.Flow.name]

    @property
    def saved_search_unique_upsert_columns(self):
        """Unique columns for upserting a SavedSearch"""
        return [self.SavedSearch.name]

    @property
    def task_run_unique_upsert_columns(self):
        """Unique columns for upserting a TaskRun"""
        return [
            self.TaskRun.flow_run_id,
            self.TaskRun.task_key,
            self.TaskRun.dynamic_key,
        ]

    @property
    def block_document_unique_upsert_columns(self):
        """Unique columns for upserting a BlockDocument"""
        return [self.BlockDocument.block_type_id, self.BlockDocument.name]


class AsyncPostgresORMConfiguration(BaseORMConfiguration):
    """Postgres specific orm configuration"""

    @property
    def versions_dir(self) -> Path:
        """
        Path to the PostgreSQL migration versions, located under
        `migrations/versions/postgresql` next to `prefect.orion.database`.
        """
        database_package_dir = Path(prefect.orion.database.__file__).parent
        return database_package_dir.joinpath("migrations", "versions", "postgresql")


class AioSqliteORMConfiguration(BaseORMConfiguration):
    """SQLite specific orm configuration"""

    @property
    def versions_dir(self) -> Path:
        """
        Path to the SQLite migration versions, located under
        `migrations/versions/sqlite` next to `prefect.orion.database`.
        """
        database_package_dir = Path(prefect.orion.database.__file__).parent
        return database_package_dir.joinpath("migrations", "versions", "sqlite")