Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
""" Copyright (C) Sarus Technologies SAS - All Rights Reserved
Unauthorized copying of this file, via any medium is strictly prohibited
Proprietary and confidential
Write to contact@sarus.tech for more information about purchasing a licence
"""
import base64
import hashlib
import hmac
import json
import logging
import os

import requests
from flask import Response, redirect, request, session
from flask.sessions import SecureCookieSessionInterface
from flask_login import LoginManager, current_user, login_user, logout_user
from oauthlib.oauth2 import WebApplicationClient
from packaging.specifiers import Specifier
from packaging.version import InvalidVersion, Version

from . import messages

logger = logging.getLogger(__name__)

# Routes for user authentication.


# Page shown at the end of a headless OIDC login. ``{token}`` is filled in
# with str.format, so the literal JavaScript braces are doubled.
_HEADLESS_TOKEN_PAGE = """<!DOCTYPE html>
<html>
<body>
<div><p>To login, paste the following code in the Sarus SDK.</p></div>
<div><p id="token" style="word-break: break-word;">{token}</p></div>
<button onclick="copyToken()">Copy text</button>
<div><p>This code is strictly confidential and personal, do not share it with anyone.</p></div>
<script>
function copyToken() {{
    var r = document.createRange();
    r.selectNode(document.getElementById("token"));
    window.getSelection().removeAllRanges();
    window.getSelection().addRange(r);
    document.execCommand('copy');
    window.getSelection().removeAllRanges();
}}
</script>
</body>
</html>"""

# JSON body returned for every failed credential check.
_LOGIN_FAILED = {"success": False, "reset_password": False}


def _json_response(payload, status):
    """Build a Flask ``(body, status, headers)`` tuple with a JSON body.

    Args:
        payload: JSON-serialisable object used as the response body.
        status: HTTP status code.

    Returns:
        A tuple Flask converts into a response whose ``Content-Type`` is
        ``application/json``.
    """
    return json.dumps(payload), status, {"Content-Type": "application/json"}


def _normalize_prefix(prefix):
    """Return *prefix* as ``""`` or ``"/segment[/...]"`` (no trailing slash).

    ``None`` is treated like the empty prefix.
    """
    prefix = (prefix or "").strip("/")
    return f"/{prefix}" if prefix else ""


def _apply_forwarded_scheme(url, headers):
    """Replace the scheme of *url* with the ``X-Scheme`` header, if sent.

    Presumably set by a TLS-terminating reverse proxy so that generated URLs
    use the scheme the client really used (TODO confirm deployment). The
    header value is trusted as-is.

    Args:
        url: Absolute URL such as ``request.base_url``.
        headers: Mapping with a ``get`` method (e.g. ``request.headers``).
    """
    scheme = headers.get("X-Scheme")
    if scheme is None:
        return url
    return f"{scheme}://{url.split('://', 1)[-1]}"


def _encode_oidc_state(state):
    """Serialise the OIDC ``state`` dict as URL-safe base64 of its JSON."""
    raw = json.dumps(state).encode("utf-8")
    return base64.urlsafe_b64encode(raw).decode("ascii")


def _decode_oidc_state(raw_state):
    """Inverse of :func:`_encode_oidc_state`.

    Returns:
        The decoded dict, or ``None`` when *raw_state* is missing, is not
        valid base64/JSON, or does not decode to a JSON object.
    """
    if not raw_state:
        return None
    try:
        state = json.loads(base64.urlsafe_b64decode(raw_state))
    except (TypeError, ValueError):
        # binascii.Error, JSONDecodeError and UnicodeDecodeError all derive
        # from ValueError.
        return None
    return state if isinstance(state, dict) else None


def _state_token_matches(received, expected):
    """Return True if the OIDC state token came back unchanged.

    Both tokens must be strings and *expected* must be non-empty. This means
    a request with no token stored in the session never passes; otherwise a
    forged state carrying ``null`` would equal the missing session value.
    The comparison is constant-time.
    """
    if not isinstance(received, str) or not isinstance(expected, str):
        return False
    if not expected:
        return False
    return hmac.compare_digest(
        received.encode("utf-8"), expected.encode("utf-8")
    )


def _sdk_version_warning(sdk_version_specifier, client_sdk_version):
    """Return a warning message for a missing/invalid/incompatible SDK version.

    Args:
        sdk_version_specifier: Parsed specifier the version must satisfy.
        client_sdk_version: Raw header value, or None when it was not sent.

    Returns:
        None when the version satisfies the specifier, otherwise the message
        built by the ``messages`` module.
    """
    if client_sdk_version is None:
        return messages.invalid_sdk_version(client_sdk_version=None)
    try:
        parsed = Version(client_sdk_version)
    except InvalidVersion:
        return messages.invalid_sdk_version(str(client_sdk_version))
    if sdk_version_specifier.contains(parsed):
        return None
    return messages.incompatible_sdk_version(str(parsed))


def _register_oidc_routes(
    app,
    db,
    User,
    Group,
    Role,
    Organization,
    prefix,
    oidc_client_id,
    oidc_client_secret,
    oidc_discovery_url,
    user_creation_init,
):
    """Register ``{prefix}/oidc_login`` and ``{prefix}/oidc_login/callback``.

    *prefix* must already be normalised with :func:`_normalize_prefix`.
    """

    def fetch_provider_config():
        # OIDC discovery document listing the provider's endpoints.
        return requests.get(oidc_discovery_url).json()

    def create_user(oidc_id, email):
        """Create and commit a user with its own organisation and groups."""
        organization = Organization(name=f"{email}_org")
        user = User(
            oidc_id=oidc_id,
            email=email,
            username=email.split("@")[0],
            organization=organization,
        )
        user.roles.append(Role.query.filter_by(name="admin").first())
        user.groups.append(
            Group(
                name="Admin",
                description="trusted administrator, can read all data",
                organization=organization,
            )
        )
        user.groups.append(
            Group(
                name=user.email,
                description="",
                organization=organization,
                singleton=True,
            )
        )
        db.session.add(organization)
        db.session.add(user)
        # Flush before the hook, presumably so database-generated ids are
        # populated for user.id / organization.id.
        db.session.flush()
        if user_creation_init is not None:
            user_creation_init(app, user.id, organization.id)
        db.session.commit()
        return user

    @app.route(f"{prefix}/oidc_login")
    def oidc_login():
        """Redirect the browser to the OIDC provider's authorization page."""
        authorization_endpoint = fetch_provider_config()[
            "authorization_endpoint"
        ]
        state_token = hashlib.sha256(os.urandom(1024)).hexdigest()
        state = dict(
            token=request.args.get("token"),
            state_token=state_token,
            headless=request.args.get("headless", "").lower() == "true",
        )
        session["state_token"] = state_token
        redirect_uri = _apply_forwarded_scheme(
            request.base_url + "/callback", request.headers
        )
        # A fresh client per request: oauthlib clients keep per-flow state.
        client = WebApplicationClient(oidc_client_id)
        request_uri = client.prepare_request_uri(
            authorization_endpoint,
            redirect_uri=redirect_uri,
            state=_encode_oidc_state(state),
            scope=["email"],
        )
        return redirect(request_uri)

    @app.route(f"{prefix}/oidc_login/callback")
    def callback():
        """Finish the OIDC flow, creating the user on first login."""
        state = _decode_oidc_state(request.args.get("state"))
        # The expected token is single-use: it leaves the session whether or
        # not this callback succeeds.
        expected_token = session.pop("state_token", None)
        if state is None or not _state_token_matches(
            state.get("state_token"), expected_token
        ):
            return "Invalid state token", 400
        headless = bool(state.get("headless"))

        provider_cfg = fetch_provider_config()
        base_url = _apply_forwarded_scheme(request.base_url, request.headers)
        url = _apply_forwarded_scheme(request.url, request.headers)

        # Per-request client: parse_request_body_response stores the access
        # token on the client, so a shared instance would leak tokens
        # between concurrent callbacks.
        client = WebApplicationClient(oidc_client_id)
        token_url, headers, body = client.prepare_token_request(
            provider_cfg["token_endpoint"],
            authorization_response=url,
            redirect_url=base_url,
            code=request.args.get("code"),
        )
        token_response = requests.post(
            token_url,
            headers=headers,
            data=body,
            auth=(oidc_client_id, oidc_client_secret),
        )
        client.parse_request_body_response(
            json.dumps(token_response.json())
        )

        uri, headers, body = client.add_token(
            provider_cfg["userinfo_endpoint"]
        )
        userinfo = requests.get(uri, headers=headers, data=body).json()
        users_email = userinfo.get("email")
        if not userinfo.get("email_verified") or not users_email:
            return (
                "User email not available or not verified by Google.",
                400,
            )
        oidc_id = userinfo["sub"]

        user = User.query.filter_by(oidc_id=oidc_id).first()
        if not user:
            user = create_user(oidc_id, users_email)

        login_user(user)
        if not headless:
            # Browser flow: send the user back to the application root.
            return redirect(
                _apply_forwarded_scheme(request.url_root, request.headers)
            )

        # Headless flow: serialise the freshly logged-in session and show it
        # to the user so it can be pasted into the SDK.
        response = Response()
        SecureCookieSessionInterface().save_session(app, session, response)
        # Everything after the cookie name (value plus cookie attributes).
        cookie = response.headers["Set-Cookie"].split("=", 1)[1]
        session.clear()
        token = base64.b64encode(cookie.encode("ascii")).decode("ascii")
        return _HEADLESS_TOKEN_PAGE.format(token=token)


def add_auth_routes(  # noqa: C901
    app,
    db,
    User,
    Group,
    Role,
    Organization,
    UserInvitation,
    prefix="",
    oidc_client_id=None,
    oidc_client_secret=None,
    oidc_discovery_url=None,
    user_creation_init=None,
    sdk_version_specifier=None,
    header_name_client_sdk_version=None,
):
    """Register login/logout (and optionally OIDC) routes on *app*.

    Args:
        app: Flask application the routes are registered on.
        db: Database handle whose ``session`` persists new OIDC users.
        User, Group, Role, Organization: Model classes used by the routes.
        UserInvitation: Accepted for interface compatibility; unused here.
        prefix: URL prefix for every route; surrounding slashes are ignored.
        oidc_client_id, oidc_client_secret, oidc_discovery_url: When all
            three are given, ``{prefix}/oidc_login`` and
            ``{prefix}/oidc_login/callback`` are registered.
        user_creation_init: Optional callable
            ``(app, user_id, organization_id)`` run after a new OIDC user is
            flushed and before it is committed.
        sdk_version_specifier: Optional ``Specifier`` (or string parsed into
            one) the client SDK version should satisfy. A mismatch only adds
            a ``warning_message`` to the login response.
        header_name_client_sdk_version: Request header carrying the client
            SDK version; required when ``sdk_version_specifier`` is given.

    Returns:
        The configured ``LoginManager``.

    Raises:
        ValueError: If a version specifier is given without a header name.
    """
    if (
        sdk_version_specifier is not None
        and header_name_client_sdk_version is None
    ):
        raise ValueError(
            "The header name for getting the Client SDK version has to be provided "
            "with SDK version specifier"
        )
    # Only parse a specifier that was actually supplied: Specifier(None)
    # raises, which would make the "no version check" setup unusable.
    if sdk_version_specifier is not None and not isinstance(
        sdk_version_specifier, Specifier
    ):
        sdk_version_specifier = Specifier(sdk_version_specifier)

    prefix = _normalize_prefix(prefix)

    if (
        oidc_client_id is not None
        and oidc_client_secret is not None
        and oidc_discovery_url is not None
    ):
        _register_oidc_routes(
            app,
            db,
            User,
            Group,
            Role,
            Organization,
            prefix,
            oidc_client_id,
            oidc_client_secret,
            oidc_discovery_url,
            user_creation_init,
        )

    login_manager = LoginManager()

    def lookup_user(payload):
        # Username takes precedence; with neither field there is no match.
        username = payload.get("username")
        if username:
            return User.query.filter_by(username=username).first()
        email = payload.get("email")
        if email:
            return User.query.filter_by(email=email).first()
        return None

    @app.route(f"{prefix}/login", methods=["GET", "POST"])
    def login():
        """
        User login page.
        GET: Serve Log-in page.
        POST: If form is valid and new user creation succeeds,
        redirect user to the logged-in homepage.
        ---
        post:
          parameters:
            - in: body
              name: login
              schema:
                type: object
                required:
                  - username
                  - email (username or email)
                  - password
                properties:
                  username:
                    type: string
                  email:
                    type: string
                  password:
                    type: string
        responses:
          200:
            description: login OK
        """
        logger.debug("Trying to login")
        if current_user.is_authenticated:
            return _json_response({"success": True}, 200)
        if request.method == "GET":
            return Response("Wrong method, you need to use POST", status=400)
        if request.method != "POST":
            # Any other method reaching this view is a failed login.
            return _json_response(_LOGIN_FAILED, 401)

        try:
            payload = request.json
        except Exception as e:
            # Request-boundary catch: request.json may raise on bad bodies.
            return Response(str(e), status=400)

        user = lookup_user(payload) if isinstance(payload, dict) else None
        if user is None or not user.check_password(
            password=payload.get("password")
        ):
            return _json_response(_LOGIN_FAILED, 401)
        if user.is_super_admin and user.initial_password:
            return _json_response(
                {"success": False, "reset_password": True}, 401
            )

        warning_message = None
        if sdk_version_specifier is not None:
            warning_message = _sdk_version_warning(
                sdk_version_specifier,
                request.headers.get(header_name_client_sdk_version),
            )
        login_user(user)
        success_dict = {"success": True, "reset_password": False}
        if warning_message:
            success_dict["warning_message"] = warning_message
        return _json_response(success_dict, 200)

    @app.route(f"{prefix}/logout")
    def logout():
        """
        Logout
        ---
        responses:
          200:
            description: user is now logged out
        """
        logout_user()
        return _json_response({"success": True}, 200)

    @login_manager.user_loader
    def load_user(user_id):
        """Check if user is logged-in on every page load."""
        # https://flask-login.readthedocs.io/en/latest/#alternative-tokens
        if user_id is not None:
            return User.query.filter_by(login_id=user_id).first()
        return None

    @login_manager.unauthorized_handler
    def unauthorized():
        return _json_response({"success": False}, 401)

    login_manager.init_app(app)

    return login_manager