# NOTE: package-index page residue removed (repository URL / version selector).
# Captured package version: 0.6.6.dev0
""" Copyright (C) Sarus Technologies SAS - All Rights Reserved
Unauthorized copying of this file, via any medium is strictly prohibited
Proprietary and confidential
Write to contact@sarus.tech for more information about purchasing a licence
"""
import json
import os
import pickle
import typing as t
from urllib.parse import quote_plus, urlparse
from uuid import uuid4
import boto3
import fsspec.implementations.dirfs
from flask import abort
from flask_login import UserMixin
from fsspec import filesystem as fsspec_filesystem
from fsspec.core import OpenFile as open_file
from gcsfs import GCSFileSystem
from s3fs import S3FileSystem
from sqlalchemy import MetaData, create_engine, inspect
from sqlalchemy.dialects import registry
from sqlalchemy.orm import Session
from sqlalchemy.sql import expression
from sqlalchemy.sql.functions import now
from werkzeug.security import check_password_hash, generate_password_hash
from sarus_flask_auth.ability_enum import AbilityEnum
# Major-version tuple identifying SQL Server 2012 (engine version 11.x);
# presumably compared against a dialect's reported server version — TODO confirm usage.
MS_2012_VERSION = (11,)
def create_auth_models(db, bind_key=None):
    """Build and return the authentication/authorization SQLAlchemy models.

    The models are declared inside a factory so they can be attached to the
    caller's Flask-SQLAlchemy ``db`` instance and, optionally, to a secondary
    database bind.

    Args:
        db: Flask-SQLAlchemy database object providing ``Model``, ``Column``...
        bind_key: optional ``__bind_key__`` applied to every model.

    Returns:
        Tuple of model classes:
        (User, Group, UserGroups, UserInvitation, Organization, Role,
        RoleAbility, UserRole, Ability).
    """
    class User(db.Model, UserMixin):
        """Model for user accounts"""
        __tablename__ = "users"
        __bind_key__ = bind_key
        id = db.Column(db.Integer, primary_key=True)
        # Opaque session token used instead of the primary key (see get_id).
        login_id = db.Column(
            db.String(36), unique=True, default=lambda: str(uuid4())
        )
        active = db.Column(
            "is_active", db.Boolean(), nullable=False, server_default="1"
        )
        email = db.Column(db.String(255), nullable=False, unique=True)
        email_confirmed_at = db.Column(db.DateTime())
        password = db.Column(db.String(255), nullable=False, server_default="")
        oidc_id = db.Column(db.String(255), nullable=True)
        # True until the user sets a password of their own (see set_password).
        initial_password = db.Column(
            db.Boolean(), nullable=False, server_default=expression.true()
        )
        # User information
        username = db.Column(db.String(100), nullable=False, server_default="")
        # Terms & conditions read & accepted
        terms_readed = db.Column(db.LargeBinary, nullable=True)
        terms_readed_at = db.Column(db.DateTime(), nullable=True)
        # No Personal Data Acceptance
        nopersonaldata_acceptance_at = db.Column(db.DateTime(), nullable=True)
        # Define the relationship to Group via UserGroups
        groups = db.relationship("Group", secondary="user_groups")
        roles = db.relationship("Role", secondary="user_role", backref="users")
        organization_id = db.Column(
            db.Integer, db.ForeignKey("organization.id"), nullable=False
        )
        organization = db.relationship("Organization", backref="users")
        is_super_admin = db.Column(
            db.Boolean, nullable=False, server_default=expression.false()
        )
        def set_password(self, password):
            """Create hashed password."""
            self.password = generate_password_hash(
                password, method="pbkdf2:sha256:600000"
            )
            # Choosing a password clears the "initial password" flag.
            if self.initial_password:
                self.initial_password = False
        def check_password(self, password):
            """Check hashed password."""
            return check_password_hash(self.password, password)
        def get_id(self):
            # https://flask-login.readthedocs.io/en/latest/#alternative-tokens
            return self.login_id
        def __repr__(self):
            return "<User {}>".format(self.email)
        def has_ability(self, ability):
            """Return True if any of the user's roles grants ``ability``."""
            desired_ability = Ability.query.filter_by(name=ability).first()
            current_user_abilities = [
                current_user_ability
                for role in self.roles
                for current_user_ability in role.abilities
            ]
            return desired_ability in current_user_abilities
        def is_role(self, role):
            """Return True if the user holds the role named ``role``."""
            desired_role = Role.query.filter_by(name=role).first()
            return desired_role in self.roles
        def to_dict(self):
            """Serialize the user for API responses (password excluded)."""
            response = {
                "id": self.id,
                "username": self.username,
                "email": self.email,
                "groups": ",".join([r.name for r in self.groups]),
                "groupId": [r.id for r in self.groups],
                "roles": [role.name for role in self.roles],
                "initial_password": self.initial_password,
            }
            # Password-less accounts expose their login token instead.
            if not self.password:
                response["token"] = self.login_id
            return response
    class Group(db.Model):
        """A named group of users within an organization."""
        __tablename__ = "groups"
        __bind_key__ = bind_key
        id = db.Column(db.Integer(), primary_key=True)
        name = db.Column(db.String(50), unique=False)
        description = db.Column(db.String(250), unique=False)
        # singleton: presumably marks single-member groups — TODO confirm.
        singleton = db.Column(db.Boolean, unique=False, default=False)
        organization_id = db.Column(
            db.Integer, db.ForeignKey("organization.id"), nullable=False
        )
        organization = db.relationship("Organization", backref="groups")
        def to_dict(self):
            """Serialize the group for API responses."""
            return dict(
                id=self.id,
                name=self.name,
                description=self.description,
                singleton=self.singleton,
            )
    class UserGroups(db.Model):
        """Association table between users and groups."""
        __tablename__ = "user_groups"
        __bind_key__ = bind_key
        id = db.Column(db.Integer(), primary_key=True)
        user_id = db.Column(
            db.Integer(), db.ForeignKey("users.id", ondelete="CASCADE")
        )
        group_id = db.Column(
            db.Integer(), db.ForeignKey("groups.id", ondelete="CASCADE")
        )
    class UserInvitation(db.Model):
        """Invitation token scoped to an organization."""
        __tablename__ = "user_invitation"
        __bind_key__ = bind_key
        id = db.Column(db.Integer(), primary_key=True)
        token = db.Column(
            db.String(36), unique=True, default=lambda: str(uuid4())
        )
        owner_id = db.Column(
            db.Integer, db.ForeignKey("users.id"), nullable=True
        )
        owner = db.relationship("User")
        # amount: presumably the number of sign-ups the token allows — TODO confirm.
        amount = db.Column(db.Integer(), nullable=False)
        organization_id = db.Column(
            db.Integer, db.ForeignKey("organization.id"), nullable=False
        )
        organization = db.relationship(
            "Organization", backref="user_invitations"
        )
    class Organization(db.Model):
        """Tenant grouping users, groups, invitations and data connections."""
        __tablename__ = "organization"
        __bind_key__ = bind_key
        # Marker consumed elsewhere to locate the organization id column.
        __organization_id_field__ = "id"
        id = db.Column(db.Integer, primary_key=True)
        name = db.Column(db.String(length=64), nullable=False)
    class Role(db.Model):
        """A named role aggregating abilities."""
        __tablename__ = "role"
        __bind_key__ = bind_key
        id = db.Column(db.Integer, primary_key=True)
        name = db.Column(db.String(128), unique=True)
        abilities = db.relationship("Ability", secondary="role_ability")
        def __repr__(self):
            return f"{self.__class__.__name__} <{self.name}>"
    class RoleAbility(db.Model):
        """Association table between roles and abilities."""
        __tablename__ = "role_ability"
        __bind_key__ = bind_key
        __table_args__ = (db.UniqueConstraint("role_id", "ability_id"),)
        id = db.Column(db.Integer, primary_key=True)
        role_id = db.Column(db.Integer, db.ForeignKey("role.id"))
        ability_id = db.Column(db.Integer, db.ForeignKey("ability.id"))
    class Ability(db.Model):
        """A single permission, drawn from AbilityEnum."""
        __tablename__ = "ability"
        __bind_key__ = bind_key
        id = db.Column(db.Integer, primary_key=True)
        name = db.Column(db.Enum(AbilityEnum, native_enum=False), unique=True)
        def __repr__(self):
            return f"{self.__class__.__name__} <{self.name}>"
    class UserRole(db.Model):
        """Association table between users and roles."""
        __tablename__ = "user_role"
        __bind_key__ = bind_key
        id = db.Column(db.Integer, primary_key=True)
        user_id = db.Column(db.Integer, db.ForeignKey("users.id"))
        role_id = db.Column(db.Integer, db.ForeignKey("role.id"))
    return (
        User,
        Group,
        UserGroups,
        UserInvitation,
        Organization,
        Role,
        RoleAbility,
        UserRole,
        Ability,
    )
def create_dataconnection_models( # noqa: C901
db,
bind_key=None,
):
def _health_check_sql(dataconnection):  # pragma: no cover
    """Probe a SQL dataconnection by connecting and listing its schemas.

    A short (20 s) timeout is injected per connector family so that an
    unreachable server fails fast instead of hanging.

    Args:
        dataconnection: a DataConnection of a SQL connector type.

    Returns:
        ``(True, None)`` on success, ``(False, error_message)`` on failure.
    """
    connection_string = dataconnection.get_sqlalchemy_string()
    extra_args = dataconnection.get_sqlalchemy_kwargs()
    if (
        dataconnection.connector_type
        == DataConnection.ConnectorType.BIGQUERY
    ):
        # BigQuery has no driver connect_timeout; bound the pool wait instead.
        extra_args = dict(**extra_args, pool_timeout=20)
    elif dataconnection.connector_type in [
        DataConnection.ConnectorType.HIVE,
        DataConnection.ConnectorType.MONGODB,
        DataConnection.ConnectorType.REDSHIFT,
    ]:
        # These drivers do not accept the generic timeout kwargs.
        extra_args = {}
    # parameters for Synapse are determined by get_sqlalchemy_kwargs()
    elif (
        dataconnection.connector_type
        == DataConnection.ConnectorType.SYNAPSE
    ):
        pass
    elif "connect_args" in extra_args:
        # Membership test, not truthiness: previously an existing-but-empty
        # connect_args dict fell through to the else branch, where
        # dict(**extra_args, connect_args=...) raised a TypeError for the
        # duplicated keyword argument.
        extra_args["connect_args"]["connect_timeout"] = 20
    else:
        extra_args = dict(
            **extra_args, connect_args={"connect_timeout": 20}
        )
    engine = create_engine(connection_string, **extra_args)
    try:
        inspector = inspect(engine)
        inspector.get_schema_names()
        return True, None
    except Exception as error:  # pragma: no cover
        # SQLAlchemy wraps driver errors; surface the driver's original one.
        message = error.orig if hasattr(error, "orig") else error
        return False, str(message)
def _health_check_fs(dataconnection): # pragma: no cover
path = dataconnection.get_path()
bucket = dataconnection.get_bucket()
try:
source_fs = dataconnection.get_filesystem()
if path != "":
source_fs.ls(f"{bucket}/{path}")
else:
source_fs.ls(bucket)
return True, None
except Exception as error: # pragma: no cover
message = error.orig if hasattr(error, "orig") else error
return False, str(message)
class DataConnection(db.Model):  # the actual object being store in the DB
    """Polymorphic base model for all data connections.

    Subclasses share this single table (``extend_existing``) and are
    discriminated by the ``connector_type`` column via
    ``polymorphic_on``/``polymorphic_identity``.
    """
    __tablename__ = "dataconnection"
    __bind_key__ = bind_key
    class ConnectorType:
        """Namespace to define constants."""
        AZUREBLOB = "azureblob"
        AZURESQL = "azuresql"
        BIGQUERY = "bigquery"
        DATABRICKS = "databricks"
        GCS = "gcs"
        HIVE = "hive"
        MONGODB = "mongodb"
        MYSQL = "mysql"
        POSTGRESQL = "postgres"
        REDSHIFT = "redshift"
        S3 = "s3"
        SQLSERVER = "sqlserver"
        SYNAPSE = "synapse"
        LOCAL = "local"
    # Connectors accessed through a SQLAlchemy engine.
    SQL_CONNECTOR_TYPES = [
        ConnectorType.AZURESQL,
        ConnectorType.BIGQUERY,
        ConnectorType.DATABRICKS,
        ConnectorType.HIVE,
        ConnectorType.MONGODB,
        ConnectorType.MYSQL,
        ConnectorType.POSTGRESQL,
        ConnectorType.REDSHIFT,
        ConnectorType.SQLSERVER,
        ConnectorType.SYNAPSE,
    ]
    # Connectors accessed through an fsspec-style filesystem.
    FS_CONNECTOR_TYPES = [
        ConnectorType.AZUREBLOB,
        ConnectorType.GCS,
        ConnectorType.S3,
        ConnectorType.LOCAL,
    ]
    id = db.Column(db.Integer(), primary_key=True)
    name = db.Column(db.String(50), unique=True)
    owner_id = db.Column(
        db.Integer, db.ForeignKey("users.id"), nullable=True
    )
    owner = db.relationship("User")
    created = db.Column(db.DateTime, default=now())
    last_modified = db.Column(db.DateTime, default=now(), onupdate=now())
    # datasets = db.relationship("Dataset")
    organization_id = db.Column(
        db.Integer, db.ForeignKey("organization.id"), nullable=False
    )
    organization = db.relationship(
        "Organization", backref="dataconnections"
    )
    # Discriminator column; also used by to_dict to pick the family.
    connector_type = db.Column(
        db.Enum(
            *SQL_CONNECTOR_TYPES,
            *FS_CONNECTOR_TYPES,
            name="connector_type",
            native_enum=False,
        ),
        unique=False,
    )
    is_public = db.Column(db.Boolean, default=True)
    allow_copy = db.Column(db.Boolean, default=True)
    credentials = db.Column(
        db.String(20000), unique=False
    )  # a JSON containing the credentials, sensitive
    params = db.Column(
        db.String(20000), unique=False
    )  # a JSON containing the connections params, non sensitive
    __mapper_args__ = {
        "polymorphic_identity": "",
        "polymorphic_on": connector_type,
    }
    def to_dict(self):
        """Serialize the connection for API responses, masking credentials."""
        ret = {}
        for c in self.__table__.columns:
            if c.name == "credentials":
                # Never leak secrets through the API.
                ret[c.name] = "****"
            elif c.name == "params":
                ret[c.name] = self.get_params()
            else:
                ret[c.name] = getattr(self, c.name)
        if self.connector_type in DataConnection.SQL_CONNECTOR_TYPES:
            ret["connector_family"] = "sql"
        elif self.connector_type in DataConnection.FS_CONNECTOR_TYPES:
            ret["connector_family"] = "filesystem"
        else:  # pragma: no cover
            ret["connector_family"] = "not supported"
        return ret
    def health_check(self):
        """Works only for subtypes of DataConnection"""
        polymorphic_identity = self.__mapper_args__.get(
            "polymorphic_identity"
        )
        if polymorphic_identity in DataConnection.SQL_CONNECTOR_TYPES:
            return _health_check_sql(self)
        elif polymorphic_identity in DataConnection.FS_CONNECTOR_TYPES:
            return _health_check_fs(self)
        else:
            message = f"{polymorphic_identity} is not supported"
            return False, str(message)
    def get_data_rows_bytes(self, source_name: str) -> t.Tuple[int, int]:
        """
        Give the source size as a tuple ``(size_lines, size_bytes)``:
        - SQL connectors: exact row count, and a byte size extrapolated
          from a pickled 100-row sample (approximation);
        - FS connectors: newline count of the object, and its exact size
          in bytes as reported by the filesystem;
        - any other connector type: ``(None, None)``.
        """
        if self.connector_type in DataConnection.SQL_CONNECTOR_TYPES:
            connection_string = self.get_sqlalchemy_string()
            extra_args = self.get_sqlalchemy_kwargs()
            engine = create_engine(connection_string, **extra_args)
            if "." in source_name:
                # "schema.table" form: reflect inside that schema.
                sql_schema, table_name = source_name.split(".")
                metadata = MetaData(schema=sql_schema)
                metadata.reflect(bind=engine, only=[table_name])
            else:
                metadata = MetaData()
                metadata.reflect(bind=engine, only=[source_name])
            with Session(engine) as session:
                table = metadata.tables[source_name]
                size_lines = session.query(table).count()
                sample = session.query(table).limit(100).all()
                # Extrapolate total bytes from the pickled sample size.
                size_bytes = int(
                    (size_lines / 100) * len(pickle.dumps(sample))
                )
        elif self.connector_type in DataConnection.FS_CONNECTOR_TYPES:
            source_fs = self.get_filesystem()
            size_bytes = source_fs.size(source_name)
            # NOTE(review): reads the whole object into memory to count
            # newlines — potentially costly on large files.
            with open_file(fs=source_fs, path=source_name, mode='r') as f:
                size_lines = f.read().count("\n")
        else:
            size_lines = size_bytes = None
        return size_lines, size_bytes
class LocalFSConnection(DataConnection):  # pragma: no cover
    """Connector for a directory on the local filesystem (no credentials)."""
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.LOCAL,
    }
    __bind_key__ = bind_key
    class DirFileSystem(fsspec.implementations.dirfs.DirFileSystem):
        """
        DirFileSystem restricts the access of any FS outside a certain dir.
        All paths given to it are considered relative to this root.
        Something is off with fsspec:
        - `asynchronous` is not set for pure subclasses of
        AbstractFileSystem
        - tests of DirFileSystem are missing the fact that the sub-fs is
        not asyn
        - `pipe_file` is using nonetheless a mirrored impl and
        `_pipe_file` is not implemented in DirFileSystem
        TODO: propose a patch to fsspec ?
        """
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
        async def _pipe_file(self, path, value, **kwargs):
            """Write ``value`` at ``path`` (relative to the root dir),
            bridging fsspec's async API onto a possibly synchronous
            underlying filesystem (see class docstring)."""
            path = self._join(path)
            if self.fs.async_impl:
                return await self.fs._pipe_file(path, value, **kwargs)
            else:
                return self.fs.pipe_file(path, value, **kwargs)
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    def set_credentials(self, credentials: dict):
        """Local filesystem needs no credentials; intentionally a no-op."""
        pass
    def get_credentials(self):
        """Local filesystem has no credentials; always an empty dict."""
        return {}
    def set_params(self, params_json):
        """Store the params dict as JSON."""
        self.params = json.dumps(params_json)
    def get_params(self):
        """Return the params decoded from JSON."""
        return json.loads(self.params)
    def get_filesystem(self):  # pragma: no cover
        """Return a DirFileSystem rooted at the configured ``FS_PATH``."""
        params = self.get_params()
        root_dir = params["FS_PATH"]
        fs = fsspec_filesystem("file", auto_mkdir=True)
        return LocalFSConnection.DirFileSystem(root_dir, fs)
class AzureBlobConnection(DataConnection):  # pragma: no cover
    """Connector for Azure Blob Storage, accessed through an abfs filesystem."""
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.AZUREBLOB,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/param form fields expected by the front-end.
        dataconnection.set_credentials(
            Account_name=AZUREBLOB_ACCOUNT_NAME,
            Account_key=AZUREBLOB_ACCOUNT_KEY
        )
        dataconnection.set_params(
            {
                "Container": AZUREBLOB_CONTAINER,
                "Path": AZUREBLOB_PATH,
            }
        )
        """
        fields_credentials = [
            {
                "fieldName": "Account_name",
                "fieldType": "secret",
                "fieldLabel": "Storage account",
            },
            {
                "fieldName": "Account_key",
                "fieldType": "secret",
                "fieldLabel": "Access key",
            },
        ]
        fields_params = [
            {
                "fieldName": "Container",
                "fieldType": "text",
                "fieldLabel": "Container",
            },
            {
                "fieldName": "Path",
                "fieldType": "text",
                "fieldLabel": "Path in container",
                "optional": True,
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields:  # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            # Fill the remaining front-end defaults.
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
            if "dependency" not in field:
                field["dependency"] = None
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        """Store the credentials dict as JSON."""
        self.credentials = json.dumps(credentials)
    def get_credentials(self):
        """Return the credentials decoded from JSON."""
        return json.loads(self.credentials)
    def get_account_name(self):
        """Return the storage account name from the credentials."""
        credentials = self.get_credentials()
        account_name = credentials.get("Account_name")
        return account_name
    def get_account_key(self):
        """Return the storage access key from the credentials."""
        credentials = self.get_credentials()
        account_key = credentials.get("Account_key")
        return account_key
    def get_params(self):
        """Return the params decoded from JSON."""
        return json.loads(self.params)
    def set_params(self, params_json):
        """Store the params dict as JSON."""
        self.params = json.dumps(params_json)
    def get_bucket(self):
        """Return the container name (the FS-API 'bucket')."""
        params = self.get_params()
        bucket = params.get("Container")
        return bucket
    def get_path(self):
        """Return the optional path inside the container ("" if unset)."""
        params = json.loads(self.params)
        if "Path" in params.keys():
            path = params.get("Path")
        else:
            path = ""
        return path
    def uri_http_to_abfs(self, http_string: str) -> str:
        """
        Tranform Azure Blob URI which starts with https or wasbs or az to a
        URI with abfs protocol
        """
        AZUREBLOB_ENDPOINT = os.environ.get("AZUREBLOB_ENDPOINT")
        AZUREBLOB_PROTOCOL = os.environ.get("AZUREBLOB_PROTOCOL")
        parsed_string = urlparse(http_string)
        # Hostname is "<account>.<endpoint>"; splitting on the endpoint
        # leaves "<account>." (trailing dot included).
        parsed_account_name = urlparse(http_string).hostname.split(
            AZUREBLOB_ENDPOINT
        )[0]
        # Reject URIs that do not belong to this connection's account.
        if parsed_account_name != self.get_account_name() + ".":
            abort(400, "URI is not from DataConnection")
        container_name = self.get_bucket()
        # Reject URIs pointing at a different container.
        if not parsed_string.path.startswith("/" + container_name):
            abort(400, "URI is not from DataConnection")
        abfs_string = (
            AZUREBLOB_PROTOCOL
            + "://"
            + container_name
            + "@"
            + parsed_account_name
            + AZUREBLOB_ENDPOINT
            + urlparse(http_string).path.split(container_name)[1]
        )
        return abfs_string
    def get_filesystem(self):  # pragma: no cover
        """Return an fsspec 'abfs' filesystem authenticated with the account key."""
        account_name = self.get_account_name()
        account_key = self.get_account_key()
        storage_options = {
            "account_name": account_name,
            "account_key": account_key,
        }
        source_fs = fsspec_filesystem("abfs", **storage_options)
        return source_fs
class BigQueryConnection(DataConnection):  # pragma: no cover # TODO
    """Connector for Google BigQuery (SQL family)."""
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.BIGQUERY,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/param form fields expected by the front-end.
        dataconnection.set_credentials(
            {
                "Key_file": BIGQUERY_CREDENTIALS,
            }
        )
        dataconnection.set_params(
            {
                "Project": BIGQUERY_PROJECT,
                "Dataset": BIGQUERY_DATASET,
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
            fieldName: string,
            fieldLabel: string,
            fieldType: FieldType,
            errorMessage: string,
            helperText?: string,
            defaultValue?: string,
            optional?: boolean,
            disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "Key_file",
                "fieldType": "file",
                "fieldLabel": "Service account JSON file",
            },
        ]
        fields_params = [
            {
                "fieldName": "Project",
                "fieldType": "text",
                "fieldLabel": "Project ID",
            },
            {
                "fieldName": "Dataset",
                "fieldType": "text",
                "fieldLabel": "Dataset ID",
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields:  # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            # Fixed key: was "optional?", which no other connector uses and
            # which does not match the documented BaseDataForm contract
            # (TS "optional?: boolean" declares a field NAMED "optional").
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        """Store the credentials dict as JSON."""
        self.credentials = json.dumps(credentials)
    def get_credentials(self):  # pragma: no cover # TODO
        """Return the credentials decoded from JSON."""
        return json.loads(self.credentials)
    def get_key_file(self):
        """Return the service-account key file, accepting both the legacy
        plain-string form and the ``{"Key_file": ...}`` dict form."""
        credentials = json.loads(self.credentials)
        if isinstance(credentials, str):
            key_file = credentials
        else:
            key_file = credentials.get("Key_file")
        return key_file
    def get_params(self):
        """Return the params decoded from JSON."""
        return json.loads(self.params)
    def set_params(self, params_json):
        """Store the params dict as JSON."""
        self.params = json.dumps(params_json)
    def get_project(self):
        """Return the project name, accepting both the legacy plain-string
        form and the ``{"Project": ...}`` dict form."""
        params = json.loads(self.params)
        if isinstance(params, str):
            project_name = params
        else:
            project_name = params.get("Project")
        return project_name
    def get_sqlalchemy_string(self):
        """Build the ``bigquery://<project>/<dataset>`` SQLAlchemy URL."""
        params = self.get_params()
        datasetbase = params.get("Dataset")
        project = params.get("Project")
        key_file = self.get_key_file()
        sqlalchemy_string = (
            f"bigquery://{project}/{datasetbase}"
            + f"?credentials_path={key_file}"
        )
        return sqlalchemy_string
    def get_sqlalchemy_kwargs(self):
        """No extra engine kwargs for BigQuery."""
        return {}
class SynapseConnection(DataConnection):
    """Connector for Azure Synapse (SQL family, ODBC or JDBC driver)."""
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.SYNAPSE,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/param form fields expected by the front-end.
        dataconnection.set_credentials(
            {"User": SYNAPSE_USERNAME, "Password": SYNAPSE_PASSWORD}
        )
        dataconnection.set_params(
            {
                "Server": SYNAPSE_SERVER,
                "Port": SYNAPSE_PORT,
                "Database": SYNAPSE_DATABASE,
                "Workspace": SYNAPSE_WORKSPACE,
                "Parameters": {
                    "Encrypt": "yes",
                    "TrustServerCertificate": "no",
                    "Connection Timeout": 30,
                },
                "ActiveDirectory": SYNAPSE_ACTIVEDIRECTORYPASSWORD,
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
            fieldName: string,
            fieldLabel: string,
            fieldType: FieldType,
            errorMessage: string,
            helperText?: string,
            defaultValue: 'string' | 'underfined',
            optional: boolean,
            disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "User",
                "fieldType": "text",
                "fieldLabel": "Username",
            },
            {"fieldName": "Password", "fieldType": "secret"},
        ]
        fields_params = [
            {
                "fieldName": "ActiveDirectory",
                "fieldType": "toggle",
                "fieldLabel": "Use Active Directory authentication",
            },
            {"fieldName": "Server", "fieldType": "text"},
            {
                "fieldName": "Port",
                "fieldType": "text",
                "defaultValue": "1433",
            },
            {
                "fieldName": "Database",
                "fieldType": "text",
                "fieldLabel": "Source Database",
            },
            {
                "fieldName": "Workspace",
                "fieldType": "text",
                "fieldLabel": "Workspace",
                # Workspace is only shown when AD authentication is off.
                "dependency": {
                    "field": "ActiveDirectory",
                    "condition": False,
                },
            },
            {
                "fieldName": "Parameters",
                "fieldType": "text",
                "fieldLabel": "Advanced parameters",
                "defaultValue": "Encrypt=yes;TrustServerCertificate=no;"
                + "Connection Timeout=30;",
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields:  # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        """Store the credentials dict as JSON."""
        self.credentials = json.dumps(credentials)
    def get_credentials(self):  # pragma: no cover # TODO
        """Return the credentials decoded from JSON."""
        return json.loads(self.credentials)
    def get_params(self):  # pragma: no cover
        """Return the params decoded from JSON."""
        return json.loads(self.params)
    def set_params(self, params_json):
        """Store the params dict as JSON."""
        self.params = json.dumps(params_json)
    def parameters_to_string(self):  # pragma: no cover
        """Render the "Parameters" entry as a ``key=value;`` string.

        Accepts a ready-made string, a dict, or nothing ("" fallback).
        """
        parameters = self.get_params().get("Parameters")
        if isinstance(parameters, str):
            return parameters
        elif isinstance(parameters, dict):
            return "".join(
                [f"{key}={value};" for key, value in parameters.items()]
            )
        else:
            return ""
    def get_sqlalchemy_string(self):  # pragma: no cover
        """Build the SQLAlchemy URL for the configured driver type.

        Reads AZURE_DRIVER and SYNAPSE_DRIVER_TYPE from the environment;
        "odbc" produces an mssql+pyodbc URL, "jdbc" registers and uses the
        custom synapse_jdbc dialect.

        Raises:
            Exception: if SYNAPSE_DRIVER_TYPE is neither "odbc" nor "jdbc".
        """
        params = self.get_params()
        driver = os.environ.get("AZURE_DRIVER")
        database = params.get("Database")
        server = params.get("Server")
        workspace = params.get("Workspace")
        port = params.get("Port")
        credentials = json.loads(self.credentials)
        user = credentials.get("User")
        password = credentials.get("Password")
        driver_type = os.environ.get("SYNAPSE_DRIVER_TYPE")
        if params.get(
            "ActiveDirectory"
        ):  # if self.activedirectory: # pragma: no cover # TODO
            active_directory = "Authentication=ActiveDirectoryPassword;"
        else:
            active_directory = ""
        if driver_type == "odbc":
            odbc_string = (
                f"Driver={driver};Server=tcp:{server},{port};"
                + f"Database={database};"
                + f"Uid={user};Pwd={password};"
                + self.parameters_to_string()
                + active_directory
            )
            # The raw ODBC string must be URL-escaped inside the URL.
            quoted = quote_plus(odbc_string)
            sqlalchemy_string = "mssql+pyodbc:///?odbc_connect={}".format(
                quoted
            )
        elif driver_type == "jdbc":
            # Register the in-house JDBC dialect before building its URL.
            registry.register(
                "synapse_jdbc",
                "sarus_flask_auth.dialects.MSJDBCDialect",
                "MSJDBCDialect",
            )
            if params.get("ActiveDirectory"):
                sqlalchemy_string = (
                    f"synapse_jdbc://{server}:{port};"
                    + f"database={database};user={user};"
                    + f"password={password};"
                    + "encrypt=true;trustServerCertificate=false;"
                    + "hostNameInCertificate=*.sql.azuresynapse.net;"
                    + "loginTimeout=30;"
                    + f"{active_directory}"
                )
            else:
                # Without AD the user is qualified by the workspace name.
                sqlalchemy_string = (
                    f"synapse_jdbc://{server}:{port};"
                    + f"database={database};"
                    + f"user={user}@{workspace};"
                    + f"password={password};"
                    + "encrypt=true;trustServerCertificate=false;"
                    + "hostNameInCertificate=*.sql.azuresynapse.net;"
                    + "loginTimeout=30;"
                )
        else:
            raise Exception("Synapse driver type is not supported")
        return sqlalchemy_string
    def get_sqlalchemy_kwargs(self):  # pragma: no cover
        """Return engine kwargs matching SYNAPSE_DRIVER_TYPE.

        Raises:
            Exception: if SYNAPSE_DRIVER_TYPE is neither "odbc" nor "jdbc".
        """
        SYNAPSE_DRIVER_TYPE = os.environ.get("SYNAPSE_DRIVER_TYPE")
        if SYNAPSE_DRIVER_TYPE == "odbc":
            return dict(
                fast_executemany=True,
                connect_args={"autocommit": True, "connect_timeout": 20},
            )
        elif SYNAPSE_DRIVER_TYPE == "jdbc":
            return {}
        else:
            raise Exception("Synapse driver type is not supported")
class AzureSQLConnection(DataConnection):
    """Connector for Azure SQL Database (SQL family, pyodbc)."""
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.AZURESQL,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():  # pragma: no cover
        """
        Describe the credential/param form fields expected by the front-end.
        dataconnection.set_credentials(
            {"User": AZURESQL_USERNAME, "Password": AZURESQL_PASSWORD}
        )
        dataconnection.set_params(
            {
                "Database": AZURESQL_DATABASE,
                "Server": AZURESQL_SERVER,
                "Port": AZURESQL_PORT,
                "ActiveDirectory": AZURESQL_ACTIVEDIRECTORYPASSWORD,
                "Parameters": {
                    "Encrypt": "yes",
                    "TrustServerCertificate": "no",
                    "Connection Timeout": 30,
                },
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
            fieldName: string,
            fieldLabel: string,
            fieldType: FieldType,
            errorMessage: string,
            helperText?: string,
            defaultValue: 'string' | 'underfined',
            optional: boolean,
            disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "User",
                "fieldType": "text",
                "fieldLabel": "Username",
            },
            {"fieldName": "Password", "fieldType": "secret"},
        ]
        fields_params = [
            {
                "fieldName": "ActiveDirectory",
                "fieldType": "toggle",
                "fieldLabel": "Use Active Directory authentication",
            },
            {"fieldName": "Server", "fieldType": "text"},
            {
                "fieldName": "Port",
                "fieldType": "text",
                "defaultValue": "1433",
            },
            {
                "fieldName": "Database",
                "fieldType": "text",
                "fieldLabel": "Source Database",
            },
            {
                "fieldName": "Parameters",
                "fieldType": "text",
                "fieldLabel": "Advanced parameters",
                "defaultValue": "Encrypt=yes;TrustServerCertificate=no;"
                + "Connection Timeout=30;",
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields:  # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):  # pragma: no cover
        """Store the credentials dict as JSON."""
        self.credentials = json.dumps(credentials)
    def get_credentials(self):  # pragma: no cover # TODO
        """Return the credentials decoded from JSON."""
        return json.loads(self.credentials)
    def get_params(self):  # pragma: no cover
        """Return the params decoded from JSON."""
        return json.loads(self.params)
    def set_params(self, params_json):  # pragma: no cover
        """Store the params dict as JSON."""
        self.params = json.dumps(params_json)
    def parameters_to_string(self):  # pragma: no cover
        """Render the "Parameters" entry as a ``key=value;`` string.

        Accepts a ready-made string, a dict, or nothing ("" fallback).
        """
        parameters = self.get_params().get("Parameters")
        if isinstance(parameters, str):
            return parameters
        elif isinstance(parameters, dict):
            return "".join(
                [f"{key}={value};" for key, value in parameters.items()]
            )
        else:
            return ""
    def get_odbc_string(self):  # pragma: no cover
        """Build the raw ODBC connection string (driver from AZURE_DRIVER)."""
        params = self.get_params()
        if params.get(
            "ActiveDirectory"
        ):  # if self.activedirectory: # pragma: no cover # TODO
            active_directory = "Authentication=ActiveDirectoryPassword;"
        else:
            active_directory = ""
        driver = os.environ.get("AZURE_DRIVER")
        database = params.get("Database")
        server = params.get("Server")
        port = params.get("Port")
        credentials = json.loads(self.credentials)
        user = credentials.get("User")
        password = credentials.get("Password")
        odbc_string = (
            f"Driver={driver};Server=tcp:{server},{port};"
            + f"Database={database};"
            + f"Uid={user};Pwd={password};"
            + self.parameters_to_string()
            + active_directory
        )
        return odbc_string
    def get_sqlalchemy_string(self):  # pragma: no cover
        """Build the mssql+pyodbc URL wrapping the escaped ODBC string."""
        composed_odbc_string = self.get_odbc_string()
        quoted = quote_plus(composed_odbc_string)
        sqlalchemy_string = "mssql+pyodbc:///?odbc_connect={}".format(
            quoted
        )
        return sqlalchemy_string
    def get_sqlalchemy_kwargs(self):  # pragma: no cover
        """Engine kwargs for pyodbc (batched inserts, autocommit)."""
        return dict(
            fast_executemany=True,
            connect_args={"autocommit": True},
        )
class PostgreSQLConnection(DataConnection):  # pragma: no cover
    """Connector for PostgreSQL (SQL family, psycopg2)."""
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.POSTGRESQL,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/param form fields expected by the front-end.
        dataconnection.set_credentials(
            {
                "User": POSTGRES_CONNECTION_USERNAME,
                "Password": POSTGRES_CONNECTION_PASSWORD
            }
        )
        dataconnection.set_params(
            {
                "Database": POSTGRES_CONNECTION_DATABASE,
                "Server": POSTGRES_CONNECTION_SERVER,
                "Port": POSTGRES_CONNECTION_PORT,
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
            fieldName: string,
            fieldLabel: string,
            fieldType: FieldType,
            errorMessage: string,
            helperText?: string,
            defaultValue: 'string' | 'undefined',
            optional: boolean,
            disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "User",
                "fieldType": "text",
                "fieldLabel": "Username",
            },
            {"fieldName": "Password", "fieldType": "secret"},
        ]
        fields_params = [
            {"fieldName": "Server", "fieldType": "text"},
            {
                "fieldName": "Port",
                "fieldType": "text",
                "defaultValue": "5432",
            },
            {
                "fieldName": "Database",
                "fieldType": "text",
                "fieldLabel": "Source Database",
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields:  # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        """Store the credentials dict as JSON."""
        self.credentials = json.dumps(credentials)
    def get_credentials(self):
        """Return the credentials decoded from JSON."""
        return json.loads(self.credentials)
    def get_params(self):
        """Return the params decoded from JSON."""
        return json.loads(self.params)
    def set_params(self, params_json):
        """Store the params dict as JSON."""
        self.params = json.dumps(params_json)
    def get_sqlalchemy_string(self):
        """Build the ``postgresql+psycopg2`` SQLAlchemy URL.

        User and password are URL-escaped with ``quote_plus`` (SQLAlchemy
        URLs reserve ``@ : / ?`` ...), so credentials containing special
        characters no longer yield an unparsable URL; plain credentials
        are unchanged by the escaping.
        """
        params = self.get_params()
        database = params.get("Database")
        server = params.get("Server")
        port = params.get("Port")
        credentials = self.get_credentials()
        user = quote_plus(credentials.get("User") or "")
        password = quote_plus(credentials.get("Password") or "")
        connection_string = (
            f"postgresql+psycopg2://{user}:{password}@"
            + f"{server}:{port}/{database}"
        )
        return connection_string
    def get_sqlalchemy_kwargs(self):  # pragma: no cover
        """No extra engine kwargs for PostgreSQL."""
        return {}
class GCSConnection(DataConnection):  # pragma: no cover # TODO
    """Connector for Google Cloud Storage, accessed through gcsfs."""
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.GCS,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/param form fields expected by the front-end.
        dataconnection.set_credentials(
            {
                "Key_file": GOOGLE_APPLICATION_CREDENTIALS,
            }
        )
        dataconnection.set_params(
            {
                "Project": GCS_PROJECT,
                "Bucket": GCS_BUCKET,
                "Path": GCS_PATH,
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
            fieldName: string,
            fieldLabel: string,
            fieldType: FieldType,
            errorMessage: string,
            helperText?: string,
            defaultValue: 'string' | 'undefined',
            optional: boolean,
            disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "Key_file",
                "fieldType": "file",
                "fieldLabel": "Service account JSON file",
            },
        ]
        fields_params = [
            {
                "fieldName": "Project",
                "fieldType": "text",
                "fieldLabel": "Project ID",
            },
            {
                "fieldName": "Bucket",
                "fieldType": "text",
                "fieldLabel": "Source bucket",
            },
            {
                "fieldName": "Path",
                "fieldType": "text",
                "fieldLabel": "Path in bucket",
                "optional": True,
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields:  # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        """Store the credentials dict as JSON."""
        self.credentials = json.dumps(credentials)
    def get_credentials(self):  # pragma: no cover # TODO
        """Return the credentials decoded from JSON."""
        return json.loads(self.credentials)
    def get_key_file(self):
        """Return the service-account key file, accepting both the legacy
        plain-string form and the ``{"Key_file": ...}`` dict form."""
        credentials = json.loads(self.credentials)
        if isinstance(credentials, str):
            key_file = credentials
        else:
            key_file = credentials.get("Key_file")
        return key_file
    def get_params(self):
        """Return the params decoded from JSON."""
        return json.loads(self.params)
    def set_params(self, params_json):
        """Store the params dict as JSON."""
        self.params = json.dumps(params_json)
    def get_project(self):
        """Return the project name, accepting both the legacy plain-string
        form and the ``{"Project": ...}`` dict form."""
        params = json.loads(self.params)
        if isinstance(params, str):
            project_name = params
        else:
            project_name = params.get("Project")
        return project_name
    def get_bucket(self):
        """Return the configured bucket name."""
        params = json.loads(self.params)
        bucket_name = params.get("Bucket")
        return bucket_name
    def get_path(self):
        """Return the optional path inside the bucket ("" if unset)."""
        params = json.loads(self.params)
        if "Path" in params.keys():
            path = params.get("Path")
        else:
            path = ""
        return path
    def get_filesystem(self):  # pragma: no cover
        """Return a GCSFileSystem authenticated with the key file."""
        project = self.get_project()
        token = self.get_key_file()
        return GCSFileSystem(
            project=project,
            token=token,
        )
class DatabricksConnection(DataConnection): # pragma: no cover
    """Connector model for Databricks SQL warehouses (access-token auth)."""
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.DATABRICKS,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/parameter form fields for the frontend.

        dataconnection.set_credentials(
            {"Token": DATABRICKS_TOKEN}
        )
        dataconnection.set_params(
            {
                "Catalog": DATABRICKS_CATALOG,
                "HTTP_path": DATABRICKS_HTTP_PATH,
                "Server": DATABRICKS_SERVER,
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
        fieldName: string,
        fieldLabel: string,
        fieldType: FieldType,
        errorMessage: string,
        helperText?: string,
        defaultValue: 'string' | 'undefined',
        optional: boolean,
        disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "Token",
                "fieldType": "secret",
                "fieldLabel": "Access Token",
            },
        ]
        fields_params = [
            {"fieldName": "Server", "fieldType": "text"},
            {
                "fieldName": "HTTP_path",
                "fieldType": "text",
                "fieldLabel": "HTTP Path",
            },
            {"fieldName": "Catalog", "fieldType": "text"},
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields: # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        # Persist credentials as a JSON string on the row.
        self.credentials = json.dumps(credentials)
    def get_credentials(self):
        return json.loads(self.credentials)
    def get_params(self):
        return json.loads(self.params)
    def set_params(self, params_json):
        self.params = json.dumps(params_json)
    def get_sqlalchemy_string(self):
        """Build the URL for the custom sarus_databricks SQLAlchemy dialect."""
        # Register the custom dialect lazily, on first use.
        registry.register(
            "sarus_databricks",
            "sarus_flask_auth.dialects.DATABRICKSDialect",
            "DATABRICKSDialect",
        )
        params = self.get_params()
        server = params.get("Server")
        http_path = params.get("HTTP_path")
        catalog = params.get("Catalog")
        credentials = self.get_credentials()
        access_token = credentials.get("Token")
        connection_string = (
            f"sarus_databricks://token:{access_token}@"
            + f"{server}?http_path={http_path}&catalog={catalog}"
        )
        return connection_string
    def get_sqlalchemy_kwargs(self): # pragma: no cover
        # Extra connect args forwarded to the Databricks connector.
        return dict(
            connect_args={
                "_tls_verify_hostname": True,
                "_user_agent_entry": "Sarus Dataconnection Agent",
            },
        )
class HiveConnection(DataConnection):
    """Connector model for Apache Hive, accessed through the Hive JDBC driver.

    Credentials hold user/password; params hold server, port, database,
    an SSL toggle and free-form JDBC parameters, all persisted as JSON
    strings on the shared ``dataconnection`` row.
    """
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.HIVE,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/parameter form fields for the frontend.

        dataconnection.set_credentials(
            {"User": HIVE_USERNAME, "Password": HIVE_PASSWORD}
        )
        dataconnection.set_params(
            {
                "Server": HIVE_SERVER,
                "Port": HIVE_PORT,
                "Database": HIVE_DATABASE,
                "UseSSL": HIVE_USE_SSL,
                "Parameters": {
                    "transportMode": "http",
                    "httpPath": "cliservice",
                },
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
        fieldName: string,
        fieldLabel: string,
        fieldType: FieldType,
        errorMessage: string,
        helperText?: string,
        defaultValue: 'string' | 'undefined',
        optional: boolean,
        disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "User",
                "fieldType": "text",
                "fieldLabel": "Username",
            },
            {"fieldName": "Password", "fieldType": "secret"},
        ]
        fields_params = [
            {"fieldName": "Server", "fieldType": "text"},
            {
                "fieldName": "Port",
                "fieldType": "text",
                "defaultValue": "10000",
            },
            {
                "fieldName": "Database",
                "fieldType": "text",
                "defaultValue": "default",
            },
            {
                "fieldName": "UseSSL",
                "fieldType": "toggle",
                "fieldLabel": "Use SSL",
            },
            {
                "fieldName": "Parameters",
                "fieldType": "text",
                "fieldLabel": "Advanced parameters",
                "defaultValue": "transportMode=http;httpPath=cliservice",
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields: # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    @staticmethod
    def _with_trailing_semicolon(parameters):
        """Return *parameters* guaranteed to end with ';' (empty-safe).

        Fix: the previous inline ``parameters[-1]`` check raised
        IndexError on "" and TypeError on None when the optional
        "Parameters" value was empty or missing.
        """
        if not parameters:
            return ""
        if parameters[-1] != ";":
            return parameters + ";"
        return parameters
    def set_credentials(self, credentials: dict):
        # Persist credentials as a JSON string on the row.
        self.credentials = json.dumps(credentials)
    def get_credentials(self): # pragma: no cover # TODO
        return json.loads(self.credentials)
    def get_params(self): # pragma: no cover
        return json.loads(self.params)
    def set_params(self, params_json):
        self.params = json.dumps(params_json)
    def get_sqlalchemy_string(self): # pragma: no cover
        """Build the URL for the custom hive_jdbc SQLAlchemy dialect."""
        params = self.get_params()
        server = params.get("Server")
        port = params.get("Port")
        database = params.get("Database")
        advanced_parameters = self._with_trailing_semicolon(
            params.get("Parameters")
        )
        credentials = json.loads(self.credentials)
        user = credentials.get("User")
        password = credentials.get("Password")
        # Register the custom dialect lazily, on first use.
        registry.register(
            "hive_jdbc",
            "sarus_flask_auth.dialects.HIVEJDBCDialect",
            "HIVEJDBCDialect",
        )
        if params.get("UseSSL"):
            # self-signed SSL certificates
            HIVE_TRUSTSTORE_FILE = os.environ.get(
                "HIVE_TRUSTSTORE_FILE", "/app/truststore.jks"
            )
            HIVE_TRUSTSTORE_PASSWORD = os.environ.get(
                "HIVE_TRUSTSTORE_PASSWORD"
            )
            if HIVE_TRUSTSTORE_PASSWORD:
                sqlalchemy_string = (
                    "hive_jdbc://"
                    + f"{user}:{password}@{server}:{port}/"
                    + f"{database}?{advanced_parameters}ssl=true;"
                    + f"sslTrustStore={HIVE_TRUSTSTORE_FILE};"
                    + f"trustStorePassword={HIVE_TRUSTSTORE_PASSWORD};"
                )
            # network SSL certificates
            else:
                sqlalchemy_string = (
                    "hive_jdbc://"
                    + f"{user}:{password}@{server}:{port}/"
                    + f"{database}?{advanced_parameters}ssl=true;"
                )
        # no SSL
        else:
            sqlalchemy_string = (
                "hive_jdbc://"
                + f"{user}:{password}@{server}:{port}/"
                + f"{database}?{advanced_parameters}ssl=false;"
            )
        return sqlalchemy_string
    def get_driver_string(self):
        """
        Gives the connection string for Hive JDBC driver
        """
        params = self.get_params()
        server = params.get("Server")
        port = params.get("Port")
        database = params.get("Database")
        advanced_parameters = self._with_trailing_semicolon(
            params.get("Parameters")
        )
        if params.get("UseSSL"):
            # self-signed SSL certificates
            HIVE_TRUSTSTORE_FILE = os.environ.get(
                "HIVE_TRUSTSTORE_FILE", "/app/truststore.jks"
            )
            HIVE_TRUSTSTORE_PASSWORD = os.environ.get(
                "HIVE_TRUSTSTORE_PASSWORD"
            )
            if HIVE_TRUSTSTORE_PASSWORD:
                connection_string = (
                    f"jdbc:hive2://{server}:{port}/"
                    + f"{database};{advanced_parameters}ssl=true;"
                    + f"sslTrustStore={HIVE_TRUSTSTORE_FILE};"
                    + f"trustStorePassword={HIVE_TRUSTSTORE_PASSWORD};"
                )
            # network SSL certificates
            else:
                connection_string = (
                    f"jdbc:hive2://{server}:{port}/"
                    + f"{database};{advanced_parameters}ssl=true;"
                )
        # no SSL
        else:
            connection_string = (
                f"jdbc:hive2://{server}:{port}/{database};"
                + f"{advanced_parameters}"
            )
        return connection_string
    def get_sqlalchemy_kwargs(self): # pragma: no cover
        """No extra create_engine kwargs are needed for Hive."""
        return {}
class MongoDBConnection(DataConnection):
    """Connector model for MongoDB sources.

    Credentials hold user/password; params hold server, port, database
    and free-form URI parameters, persisted as JSON strings on the shared
    ``dataconnection`` row.
    """
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.MONGODB,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/parameter form fields for the frontend.

        dataconnection.set_credentials(
            {"User": MONGODB_USERNAME, "Password": MONGODB_PASSWORD}
        )
        dataconnection.set_params(
            {
                "Server": MONGODB_SERVER,
                "Port": MONGODB_PORT,
                "Database": MONGODB_DATABASE,
                "Parameters": {
                    "ssl": "true",
                    "authSource": "admin",
                },
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
        fieldName: string,
        fieldLabel: string,
        fieldType: FieldType,
        errorMessage: string,
        helperText?: string,
        defaultValue: 'string' | 'undefined',
        optional: boolean,
        disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "User",
                "fieldType": "text",
                "fieldLabel": "Username",
            },
            {"fieldName": "Password", "fieldType": "secret"},
        ]
        fields_params = [
            {"fieldName": "Server", "fieldType": "text"},
            {
                "fieldName": "Port",
                "fieldType": "text",
                "defaultValue": "27017",
            },
            {
                "fieldName": "Database",
                "fieldType": "text",
            },
            {
                "fieldName": "Parameters",
                "fieldType": "text",
                "fieldLabel": "Advanced parameters",
                "defaultValue": "ssl=true&authSource=admin",
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields: # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        # Persist credentials as a JSON string on the row.
        self.credentials = json.dumps(credentials)
    def get_credentials(self): # pragma: no cover # TODO
        return json.loads(self.credentials)
    def get_params(self): # pragma: no cover
        return json.loads(self.params)
    def set_params(self, params_json):
        self.params = json.dumps(params_json)
    def get_sqlalchemy_string(self): # pragma: no cover
        """Build the URL for the custom sarus_mongodb SQLAlchemy dialect."""
        params = self.get_params()
        server = params.get("Server")
        port = params.get("Port")
        database = params.get("Database")
        advanced_parameters = params.get("Parameters")
        credentials = json.loads(self.credentials)
        user = credentials.get("User")
        password = credentials.get("Password")
        # Register the custom dialect lazily, on first use.
        registry.register(
            "sarus_mongodb",
            "sarus_flask_auth.dialects.MONGODBDialect",
            "MONGODBDialect",
        )
        sqlalchemy_string = (
            "sarus_mongodb://"
            + f"{user}:{password}@{server}:{port}/"
            + f"{database}?{advanced_parameters}"
        )
        return sqlalchemy_string
    def get_driver_string(self):
        """
        Gives the connection string for MongoDB JDBC driver,
        """
        params = self.get_params()
        server = params.get("Server")
        port = params.get("Port")
        database = params.get("Database")
        advanced_parameters = params.get("Parameters")
        # Fix: guard against an empty/missing "Parameters" value; the old
        # advanced_parameters[-1] check raised IndexError on "" and
        # TypeError on None.
        if not advanced_parameters:
            advanced_parameters = ""
        elif advanced_parameters[-1] != ";":
            # NOTE(review): the default "Parameters" uses '&' separators
            # (ssl=true&authSource=admin) while a ';' is appended here —
            # confirm the JDBC driver accepts the trailing ';'.
            advanced_parameters = advanced_parameters + ";"
        connection_string = (
            f"jdbc:mongodb://{server}:{port}/{database}?"
            + f"{advanced_parameters}"
        )
        return connection_string
    def get_sqlalchemy_kwargs(self): # pragma: no cover
        """No extra create_engine kwargs are needed for MongoDB."""
        return {}
class S3Connection(DataConnection): # pragma: no cover
    """Connector model for Amazon S3 buckets.

    Supports either explicit access keys or an IAM role attached to the
    host (``UseIAMRole``); with a role, temporary credentials are fetched
    through boto3.
    """
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.S3,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/parameter form fields for the frontend.

        dataconnection.set_credentials(
            {
                "Access_key_id": S3_ACCESS_KEY_ID,
                "Secret_access_key": S3_SECRET_ACCESS_KEY,
            }
        )
        dataconnection.set_params(
            {
                "Bucket": S3_BUCKET,
                "Path": S3_PATH,
            }
        )
        """
        # Key fields are only relevant when UseIAMRole is off (see the
        # "dependency" entries consumed by the frontend).
        fields_credentials = [
            {
                "fieldName": "Access_key_id",
                "fieldType": "secret",
                "fieldLabel": "AccessKey",
                "dependency": {
                    "field": "UseIAMRole",
                    "condition": False,
                },
            },
            {
                "fieldName": "Secret_access_key",
                "fieldType": "secret",
                "fieldLabel": "SecretKey",
                "dependency": {
                    "field": "UseIAMRole",
                    "condition": False,
                },
            },
        ]
        fields_params = [
            {
                "fieldName": "UseIAMRole",
                "fieldType": "toggle",
                "fieldLabel": "Use IAM role",
            },
            {
                "fieldName": "Bucket",
                "fieldType": "text",
                "fieldLabel": "Source bucket",
            },
            {
                "fieldName": "Path",
                "fieldType": "text",
                "fieldLabel": "Path in bucket",
                "optional": True,
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields: # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
            if "dependency" not in field:
                field["dependency"] = None
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        # Persist credentials as a JSON string on the row.
        self.credentials = json.dumps(credentials)
    def get_credentials(self):
        # With an IAM role there are no stored credentials.
        if self.using_iam_role():
            return None
        else:
            return json.loads(self.credentials)
    def using_iam_role(self):
        params = json.loads(self.params)
        return params.get("UseIAMRole")
    def get_temp_credentials(self):
        """Fetch frozen credentials from the ambient boto3 session (IAM role path)."""
        session = boto3.Session()
        credentials = session.get_credentials()
        credentials = credentials.get_frozen_credentials()
        access_key = credentials.access_key
        secret_key = credentials.secret_key
        token = credentials.token
        return {
            "Access_key_id": access_key,
            "Secret_access_key": secret_key,
            "Token": token,
        }
    def get_access_key(self):
        # None when an IAM role is used (no stored key).
        if not self.using_iam_role():
            credentials = self.get_credentials()
            key = credentials.get("Access_key_id")
            return key
        else:
            return None
    def get_temp_access_key(self):
        temp_credentials = self.get_temp_credentials()
        key = temp_credentials.get("Access_key_id")
        return key
    def get_secret_key(self):
        # None when an IAM role is used (no stored key).
        if not self.using_iam_role():
            credentials = self.get_credentials()
            key = credentials.get("Secret_access_key")
            return key
        else:
            return None
    def get_temp_secret_key(self):
        temp_credentials = self.get_temp_credentials()
        key = temp_credentials.get("Secret_access_key")
        return key
    def get_temp_token(self):
        temp_credentials = self.get_temp_credentials()
        key = temp_credentials.get("Token")
        return key
    def get_params(self):
        return json.loads(self.params)
    def set_params(self, params_json):
        self.params = json.dumps(params_json)
    def get_bucket(self):
        params = self.get_params()
        bucket = params.get("Bucket")
        return bucket
    def get_path(self):
        # "Path" is optional; default to the bucket root.
        params = json.loads(self.params)
        if "Path" in params.keys():
            path = params.get("Path")
        else:
            path = ""
        return path
    def get_filesystem(self): # pragma: no cover
        """Build an s3fs filesystem, via the IAM role or explicit keys."""
        params = self.get_params()
        if params.get("UseIAMRole"):
            source_fs = S3FileSystem(anon=False)
        else:
            access_key = self.get_access_key()
            secret_key = self.get_secret_key()
            source_fs = S3FileSystem(
                anon=False,
                key=access_key,
                secret=secret_key,
            )
        return source_fs
class SQLServerConnection(AzureSQLConnection):
    """Connector model for Microsoft SQL Server.

    Inherits connection-string construction from AzureSQLConnection and
    only overrides the form description.
    """
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.SQLSERVER,
    }
    __bind_key__ = bind_key
    @staticmethod
    def get_format(): # pragma: no cover
        """
        Describe the credential/parameter form fields for the frontend.

        dataconnection.set_credentials(
            {
                "User": SQLSERVER_USERNAME,
                "Password": SQLSERVER_PASSWORD
            }
        )
        dataconnection.set_params(
            {
                "Database": SQLSERVER_DATABASE,
                "Server": SQLSERVER_SERVER,
                "Port": SQLSERVER_PORT,
                "ActiveDirectory": SQLSERVER_ACTIVEDIRECTORYPASSWORD,
                "Parameters": {
                    "Connection Timeout": 30,
                },
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
        fieldName: string,
        fieldLabel: string,
        fieldType: FieldType,
        errorMessage: string,
        helperText?: string,
        defaultValue: 'string' | 'undefined',
        optional: boolean,
        disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "User",
                "fieldType": "text",
                "fieldLabel": "Username",
            },
            {"fieldName": "Password", "fieldType": "secret"},
        ]
        fields_params = [
            {
                "fieldName": "ActiveDirectory",
                "fieldType": "toggle",
                "fieldLabel": "Use Active Directory authentication",
            },
            {"fieldName": "Server", "fieldType": "text"},
            {
                "fieldName": "Port",
                "fieldType": "text",
                "defaultValue": "1433",
            },
            {
                "fieldName": "Database",
                "fieldType": "text",
                "fieldLabel": "Source Database",
            },
            {
                "fieldName": "Parameters",
                "fieldType": "text",
                "fieldLabel": "Advanced parameters",
                "defaultValue": "Connection Timeout=30;",
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields: # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
class RedshiftConnection(DataConnection): # pragma: no cover
    """Connector model for Amazon Redshift clusters.

    Authenticates either with a stored user/password or, when
    ``UseIAMRole`` is set, with temporary credentials obtained from the
    Redshift API through boto3.
    """
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.REDSHIFT,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/parameter form fields for the frontend.

        dataconnection.set_credentials(
            {"User": REDSHIFT_USERNAME, "Password": REDSHIFT_PASSWORD}
        )
        dataconnection.set_params(
            {
                "Server": REDSHIFT_SERVER,
                "Port": REDSHIFT_PORT,
                "Database": REDSHIFT_DATABASE,
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
        fieldName: string,
        fieldLabel: string,
        fieldType: FieldType,
        errorMessage: string,
        helperText?: string,
        defaultValue: 'string' | 'undefined',
        optional: boolean,
        disabled?: boolean,
        dependency: dict,
        }
        """
        # User/Password only apply without an IAM role; Region/Cluster/DB_user
        # only apply with one (see the "dependency" entries).
        fields_credentials = [
            {
                "fieldName": "User",
                "fieldType": "text",
                "fieldLabel": "Username",
                "dependency": {
                    "field": "UseIAMRole",
                    "condition": False,
                },
            },
            {
                "fieldName": "Password",
                "fieldType": "secret",
                "dependency": {
                    "field": "UseIAMRole",
                    "condition": False,
                },
            },
        ]
        fields_params = [
            {
                "fieldName": "UseIAMRole",
                "fieldType": "toggle",
                "fieldLabel": "Use IAM role",
            },
            {
                "fieldName": "Region",
                "fieldType": "text",
                "fieldLabel": "Region name",
                "dependency": {
                    "field": "UseIAMRole",
                    "condition": True,
                },
            },
            {
                "fieldName": "Cluster",
                "fieldType": "text",
                "fieldLabel": "Cluster ID",
                "dependency": {
                    "field": "UseIAMRole",
                    "condition": True,
                },
            },
            {"fieldName": "Server", "fieldType": "text"},
            {
                "fieldName": "Port",
                "fieldType": "text",
                "defaultValue": "5439",
            },
            {
                "fieldName": "Database",
                "fieldType": "text",
                "fieldLabel": "Source Database",
            },
            {
                "fieldName": "DB_user",
                "fieldType": "text",
                "fieldLabel": "DB user",
                "dependency": {
                    "field": "UseIAMRole",
                    "condition": True,
                },
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields: # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
            if "dependency" not in field:
                field["dependency"] = None
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        # Persist credentials as a JSON string on the row.
        self.credentials = json.dumps(credentials)
    def get_credentials(self):
        # With an IAM role there are no stored credentials.
        if self.using_iam_role():
            return None
        else:
            return json.loads(self.credentials)
    def get_temp_credentials(self):
        """Fetch short-lived DB credentials from the Redshift API (IAM role path)."""
        params = self.get_params()
        region = params.get("Region")
        db_user = params.get("DB_user")
        database = params.get("Database")
        cluster = params.get("Cluster")
        redshift_client = boto3.client("redshift", region_name=region)
        redshift_creds = redshift_client.get_cluster_credentials(
            DbUser=db_user,
            DbName=database,
            ClusterIdentifier=cluster,
            AutoCreate=False,
        )
        # URL-quote: generated credentials may contain reserved characters
        # that would break the connection URL.
        redshift_user = quote_plus(redshift_creds["DbUser"])
        redshift_password = quote_plus(redshift_creds["DbPassword"])
        return {"User": redshift_user, "Password": redshift_password}
    def get_params(self):
        return json.loads(self.params)
    def set_params(self, params_json):
        self.params = json.dumps(params_json)
    def using_iam_role(self):
        params = json.loads(self.params)
        return params.get("UseIAMRole")
    def get_sqlalchemy_string(self):
        """Build the URL for the custom redshift_jdbc SQLAlchemy dialect."""
        params = self.get_params()
        database = params.get("Database")
        server = params.get("Server")
        port = params.get("Port")
        if self.using_iam_role():
            credentials = self.get_temp_credentials()
        else:
            credentials = json.loads(self.credentials)
        user = credentials.get("User")
        password = credentials.get("Password")
        # Register the custom dialect lazily, on first use.
        registry.register(
            "redshift_jdbc",
            "sarus_flask_auth.dialects.REDSHIFTJDBCDialect",
            "REDSHIFTJDBCDialect",
        )
        sqlalchemy_string = (
            f"redshift_jdbc://{user}:{password}@"
            + f"{server}:{port}/{database}"
        )
        return sqlalchemy_string
    def get_database(self):
        params = self.get_params()
        database = params.get("Database")
        return database
    def get_sqlalchemy_kwargs(self): # pragma: no cover
        """No extra create_engine kwargs are needed for Redshift."""
        return {}
class MysqlConnection(DataConnection): # pragma: no cover
    """Connector model for MySQL/MariaDB sources.

    Credentials hold user/password; params hold server, port and
    database, persisted as JSON strings on the shared ``dataconnection``
    row.
    """
    __tablename__ = "dataconnection"
    __table_args__ = {"extend_existing": True}
    __mapper_args__ = {
        "polymorphic_identity": DataConnection.ConnectorType.MYSQL,
    }
    __bind_key__ = bind_key
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
    @staticmethod
    def get_format():
        """
        Describe the credential/parameter form fields for the frontend.

        dataconnection.set_credentials(
            {"User": MYSQL_USERNAME, "Password": MYSQL_PASSWORD}
        )
        dataconnection.set_params(
            {
                "Database": MYSQL_DATABASE,
                "Port": MYSQL_PORT,
                "Server": MYSQL_SERVER,
            }
        )
        expected in JS:
        type FieldType = 'text' | 'secret' | 'check' | 'toggle' | 'file' ;
        interface BaseDataForm {
        fieldName: string,
        fieldLabel: string,
        fieldType: FieldType,
        errorMessage: string,
        helperText?: string,
        defaultValue: 'string' | 'undefined',
        optional: boolean,
        disabled?: boolean,
        }
        """
        fields_credentials = [
            {
                "fieldName": "User",
                "fieldType": "text",
                "fieldLabel": "Username",
            },
            {"fieldName": "Password", "fieldType": "secret"},
        ]
        fields_params = [
            {"fieldName": "Server", "fieldType": "text"},
            {
                "fieldName": "Port",
                "fieldType": "text",
                "defaultValue": "3306",
            },
            {
                "fieldName": "Database",
                "fieldType": "text",
                "fieldLabel": "Source Database",
            },
        ]
        all_fields = fields_params + fields_credentials
        for field in all_fields: # easy way to populate fieldLabel
            if "fieldLabel" not in field:
                field["fieldLabel"] = field["fieldName"]
            if "optional" not in field:
                field["optional"] = False
            if "disabled?" not in field:
                field["disabled?"] = False
        return {
            "params": fields_params,
            "credentials": fields_credentials,
        }
    def set_credentials(self, credentials: dict):
        # Persist credentials as a JSON string on the row.
        self.credentials = json.dumps(credentials)
    def get_credentials(self):
        return json.loads(self.credentials)
    def get_params(self):
        return json.loads(self.params)
    def set_params(self, params_json):
        self.params = json.dumps(params_json)
    def get_odbc_string(self):
        """Build the ODBC connection string (MariaDB driver by default)."""
        params = self.get_params()
        driver = os.environ.get("MYSQL_DRIVER", "{MariaDB Driver}")
        database = params.get("Database")
        port = params.get("Port")
        server = params.get("Server")
        credentials = json.loads(self.credentials)
        user = credentials.get("User")
        password = credentials.get("Password")
        odbc_string = (
            f"Driver={driver};Server={server};Port={port};"
            + f"Database={database};Uid={user};Pwd={password};"
        )
        return odbc_string
    def get_sqlalchemy_string(self):
        """Build the SQLAlchemy URL for this MySQL connection.

        Fix: the configured "Port" parameter was previously ignored here
        (while get_odbc_string used it), so servers listening on a
        non-default port were unreachable. Port is now included when
        present; when absent the URL falls back to the driver default.
        """
        params = self.get_params()
        database = params.get("Database")
        server = params.get("Server")
        port = params.get("Port")
        credentials = self.get_credentials()
        user = credentials.get("User")
        password = credentials.get("Password")
        host = f"{server}:{port}" if port else server
        connection_string = (
            f"mysql+mysqldb://{user}:{password}@{host}/{database}"
        )
        return connection_string
    def get_sqlalchemy_kwargs(self): # pragma: no cover
        """No extra create_engine kwargs are needed for MySQL."""
        return {}
# Expose every connector model built in this factory so the caller can
# bind them to the application's database.
return (
    DataConnection,
    AzureBlobConnection,
    AzureSQLConnection,
    BigQueryConnection,
    DatabricksConnection,
    GCSConnection,
    HiveConnection,
    MongoDBConnection,
    LocalFSConnection,
    MysqlConnection,
    PostgreSQLConnection,
    RedshiftConnection,
    S3Connection,
    SQLServerConnection,
    SynapseConnection,
)