#!/usr/bin/env python3
# coding: utf-8
# vi: tabstop=8 expandtab shiftwidth=4 softtabstop=4

import sys
import os
import argparse
import logging
import time
import socket
import re
import json
import uuid
import base64
import math
import urllib.parse
import functools as ft
from hashlib import sha256
from datetime import datetime as dt

# Third-party dependencies: collect every missing one so the user gets a
# single, complete install hint instead of failing on the first ImportError.
modules_to_install = []
try:
    import requests
except ModuleNotFoundError:
    modules_to_install.append("requests")

try:
    import schema as sch
except ModuleNotFoundError:
    modules_to_install.append("schema")

if modules_to_install:
    # Diagnostics belong on stderr so they never mix with the script's own
    # stdout output (e.g. the printed protection token).
    print("Please install the following Python modules:", ' '.join(modules_to_install),
          file=sys.stderr)
    sys.exit(1)


# Configuration
# Syslog facility code used by RFC5424Formatter to build the PRI field of each
# log line as ``log_facility * 8 + severity`` (RFC 5424 facility 1 is
# "user-level messages").
log_facility = 1

# End of configuration


class RFC5424Formatter(logging.Formatter):
    """Logging formatter exposing RFC 5424 (syslog) header fields to the format string.

    Every formatted record gets three extra attributes:

    * ``pri``: ``<PRI>1``, where PRI = ``log_facility * 8 + severity`` and ``1`` is
      the protocol version;
    * ``isotime``: the record's local timestamp in ISO 8601, followed by its
      numeric UTC offset (``+HH:MM``) or ``Z`` when the offset is zero;
    * ``hostname``: this machine's host name, or ``-`` if it cannot be determined.
    """

    # RFC 5424 severity indexed by ``levelno // 10``: NOTSET -> 0, DEBUG -> 7
    # (debug), INFO -> 6 (informational), WARNING -> 4 (warning),
    # ERROR -> 3 (error), CRITICAL -> 2 (critical).
    PRI = [0, 7, 6, 4, 3, 2]

    def __init__(self, *args, **kwargs):
        try:
            self.hn = socket.gethostname()
        except Exception:
            # '-' is the RFC 5424 NILVALUE for an unknown field.
            self.hn = '-'
        super().__init__(*args, **kwargs)

    def format(self, record):
        record.__dict__['hostname'] = self.hn

        # Convert the record's own timestamp to an aware local datetime so the
        # offset reflects DST at that instant (time.timezone ignores DST and
        # describes "now", not the record).
        stamp = dt.fromtimestamp(record.created).astimezone()
        if stamp.utcoffset():
            isotime = stamp.isoformat()
        else:
            isotime = stamp.replace(tzinfo=None).isoformat() + 'Z'
        record.__dict__['isotime'] = isotime

        # Clamp so custom levels above CRITICAL (>= 60) map to the last entry
        # instead of raising IndexError.
        idx = min(max(record.levelno // 10, 0), len(self.PRI) - 1)
        record.__dict__['pri'] = f"<{log_facility * 8 + self.PRI[idx]}>1"
        return super().format(record)


def handle_error(resp, step=None, exit=True):
    """Report a failed HTTP response and, by default, abort the program.

    Nothing happens when ``resp.ok`` is true. Otherwise one critical log line
    is emitted with the ``X-42c-Transactionid`` header, the optional *step*
    label, the status code and the error message. When the body is a JSON
    object its ``message`` is shown (plus ``code`` and ``details`` when the
    logger is more verbose than INFO); any other body is shown as text.

    :param resp: HTTP response (``requests.Response``-like: ``ok``,
        ``headers``, ``status_code``, ``json()``, ``text``).
    :param step: optional human-readable name of the failing operation.
    :param exit: when true (default), call ``sys.exit(1)`` after logging.
    """
    if resp.ok:
        return

    tx = resp.headers.get('X-42c-Transactionid', '')
    res = [f"error on transaction {tx}"]
    if step is not None:
        res.append(f"at step '{step}'")
    res.append(f"with status code {resp.status_code}")

    try:
        err = resp.json()
    except ValueError:
        # Body is not JSON at all (e.g. an HTML error page from a proxy).
        err = None

    if isinstance(err, dict):
        code = err.get('code')
        msg = err.get('message')
        details = err.get('details') or []
        if msg:
            res.append(f"message: {msg}")
        if logger.getEffectiveLevel() < logging.INFO:
            if code:
                res.append(f"code: {code}")
            if details:
                res.append(f"details: {details}")
    else:
        # Invalid JSON, or a JSON value that is not an object: show the
        # decoded body text rather than assuming an error object.
        res.append(f"message: {resp.text}")

    logger.critical(', '.join(res))
    if exit:
        sys.exit(1)


def gen_name():
    """Return a fresh random (version 4) UUID in its canonical string form."""
    return f"{uuid.uuid4()}"


def hash(obj: dict):
    """Return a deterministic SHA-256 hex digest of *obj*.

    *obj* is serialized as canonical JSON (keys sorted, ASCII-only output,
    values JSON cannot encode rendered through ``repr``), so equal
    dictionaries always produce the same digest regardless of key order.
    """
    canonical = json.dumps(obj, default=repr, ensure_ascii=True, sort_keys=True)
    digest = sha256()
    digest.update(canonical.encode())
    return digest.hexdigest()


def get_instance(instance_name):
    """Return the platform instance named *instance_name*.

    If several instances share that name, the last one listed by the
    platform is returned. Logs a critical message and exits with status 1
    when no instance matches.
    """
    resp = requests.get(url + '/api/v2/instances', headers=api_key)
    handle_error(resp, 'get instances')
    matches = [inst for inst in resp.json()['list'] if inst['name'] == instance_name]

    if not matches:
        logger.critical(f"instance {instance_name} not found, please create it with allianz-bootstrap.py")
        sys.exit(1)

    instance = matches[-1]
    logger.debug(f"Found instance: {instance}")
    return instance


def get_gateway_config(settings):
    """Return the gatewayConfig matching *settings*, creating it if needed.

    The gateway settings are ``maxClients`` (when present) and the inbound
    ``logLevel`` (sent as ``mainLogLevel``). The config is looked up by a name
    equal to the hash of those settings; if several match, the last one
    listed is used.
    """
    wanted = {}
    if 'maxClients' in settings:
        wanted['maxClients'] = settings['maxClients']
    log_level = settings['inbound'].get('logLevel')
    if log_level is not None:
        wanted['mainLogLevel'] = log_level

    name = hash(wanted)
    resp = requests.get(url + '/api/v2/gatewayConfigs', headers=api_key)
    handle_error(resp, 'get gatewayConfigs')
    existing = [conf for conf in resp.json()['list'] if conf['name'] == name]
    gw_conf = existing[-1] if existing else None

    if gw_conf is None:
        logger.info("gatewayConfig not found, creating it")
        payload = {'name': name, 'description': name, 'settings': wanted}
        resp = requests.post(url + '/api/v2/gatewayConfigs', headers=api_key, json=payload)
        handle_error(resp, 'create gatewayConfig')
        gw_conf = resp.json()

    logger.debug(f"Found gatewayConfig: {gw_conf}")
    return gw_conf


def update_gateway_config(instance, gw_conf):
    """Attach *gw_conf* to *instance* unless it is already the instance's gatewayConfig."""
    if instance['gatewayConfigId'] == gw_conf['id']:
        return

    logger.debug("updating gatewayConfig of instance")
    target = url + f"/api/v2/instances/{instance['id']}/gatewayConfigs/set"
    resp = requests.patch(target, headers=api_key, json={'gatewayConfigId': gw_conf['id']})
    handle_error(resp, 'set gatewayConfig to instance')


def get_profile(scope, settings, epoint):
    """Make *epoint* use a *scope* profile matching *settings*.

    :param scope: profile collection, e.g. ``'sslServerProfiles'``. Its first
        three characters select the id field on *epoint* (``sslProfileId`` /
        ``pkiProfileId``); its fourth character (``S`` or ``C``) says whether
        *epoint* is an entrypoint or an endpoint.
    :param settings: profile settings, or None to reset the profile of
        *epoint* (only done when its id field is not already ``''``).
    :param epoint: entrypoint or endpoint object.
    :return: the profile object (looked up by the hash of *settings*, created
        if missing), or None when *settings* is None.
    """
    id_field = f"{scope[:3]}ProfileId"
    kind = {'S': 'entrypoint', 'C': 'endpoint'}[scope[3]]

    if settings is None:
        if epoint[id_field] != '':
            logger.debug(f"resetting {id_field} of {kind}")
            resp = requests.patch(url + f"/api/v2/{kind}s/{epoint['id']}/{scope}/reset",
                                  headers=api_key)
            handle_error(resp, f'reset {scope} to {kind}')
        return None

    name = hash(settings)

    resp = requests.get(url + f'/api/v2/{scope}', headers=api_key)
    handle_error(resp, f'get {scope}')
    candidates = [prof for prof in resp.json()['list'] if prof['name'] == name]
    p_conf = candidates[-1] if candidates else None

    if p_conf is None:
        logger.info(f"{scope} not found, creating it")
        resp = requests.post(url + f'/api/v2/{scope}', headers=api_key,
                             json={'name': name, 'settings': settings})
        handle_error(resp, f'create {scope}')
        p_conf = resp.json()

    logger.debug(f"Found {scope}: {p_conf}")

    if epoint[id_field] != p_conf['id']:
        logger.debug(f"updating {id_field} of {kind}")
        resp = requests.patch(url + f"/api/v2/{kind}s/{epoint['id']}/{scope}/set",
                              headers=api_key, json={'profileId': p_conf['id']})
        handle_error(resp, f'set {scope} to {kind}')
    return p_conf


def set_ssl_server_profile(settings, epoint):
    """Apply the TLS profile *settings* (or reset it when None) to entrypoint *epoint*."""
    return get_profile('sslServerProfiles', settings, epoint)


def set_pki_server_profile(settings, epoint):
    """Apply the PKI profile *settings* (or reset it when None) to entrypoint *epoint*."""
    return get_profile('pkiServerProfiles', settings, epoint)


def set_ssl_client_profile(settings, epoint):
    """Apply the TLS profile *settings* (or reset it when None) to endpoint *epoint*."""
    return get_profile('sslClientProfiles', settings, epoint)


def set_pki_client_profile(settings, epoint):
    """Apply the PKI profile *settings* (or reset it when None) to endpoint *epoint*."""
    return get_profile('pkiClientProfiles', settings, epoint)


def get_endpoint(settings):
    """Return the endpoint matching a backend's *settings*, creating it if needed.

    The endpoint is identified by its URL plus a name equal to the hash of
    the backend settings without ``tlsProfile``, ``pkiProfile`` and ``apis``
    (the URL is part of that hash). If several match, the last one listed is
    used.
    """
    skipped = ('tlsProfile', 'pkiProfile', 'apis')
    ep_settings = {key: value for key, value in settings.items() if key not in skipped}
    name = hash(ep_settings)
    ep_url = ep_settings.pop('url')

    resp = requests.get(url + '/api/v2/endpoints', headers=api_key)
    handle_error(resp, 'get endpoints')
    found = [ep for ep in resp.json()['list'] if ep['url'] == ep_url and ep['name'] == name]
    endpoint = found[-1] if found else None

    if endpoint is None:
        logger.info("endpoint not found, creating it")
        payload = {'name': name, 'description': name, 'url': ep_url, 'settings': ep_settings}
        resp = requests.post(url + '/api/v2/endpoints', headers=api_key, json=payload)
        handle_error(resp, 'create endpoint')
        endpoint = resp.json()

    logger.debug(f"Found endpoint: {endpoint}")
    return endpoint


def get_entrypoint(instance, settings):
    """Return the entrypoint of *instance*, creating or updating it as needed.

    The entrypoint's name is the hash of the inbound *settings* without the
    TLS/PKI profiles and with ``ip`` defaulting to ``'*'``. Its URL is built
    from the vhost name (the part of ``name`` before any ``:``) and ``port``,
    using ``https`` when a TLS profile is configured. If several entrypoints
    are attached to the instance, the last one listed is used; it is updated
    when its name or URL differs from the expected ones.

    :param instance: instance object (``id`` and ``name`` are used).
    :param settings: validated ``inbound`` settings.
    :return: the entrypoint object as returned by the platform.
    """
    excluded = ('tlsProfile', 'pkiProfile')
    ep_settings = {key: value for key, value in settings.items() if key not in excluded}
    ep_settings.setdefault('ip', '*')
    ep_name = hash(ep_settings)
    # 'name' takes part in the hash but is not sent in the settings payload.
    del ep_settings['name']

    resp = requests.get(url + '/api/v2/entrypoints', headers=api_key)
    handle_error(resp, 'get entrypoints')
    attached = [ep for ep in resp.json()['list'] if ep['instanceId'] == instance['id']]
    entrypoint = attached[-1] if attached else None

    host = settings['name'].split(':', 1)[0]
    scheme = 'https' if 'tlsProfile' in settings else 'http'
    ep_url = f"{scheme}://{host}:{settings['port']}/"
    payload = {'name': ep_name, 'description': ep_name, 'url': ep_url, 'settings': ep_settings}

    if entrypoint is None:
        logger.info("entrypoint not found, creating it")
        resp = requests.post(url + '/api/v2/entrypoints', headers=api_key, json=payload)
        handle_error(resp, 'create entrypoint')
        entrypoint = resp.json()
        resp = requests.patch(url + f"/api/v2/entrypoints/{entrypoint['id']}/instances/set",
                              headers=api_key, json={'instanceId': instance['id']})
        handle_error(resp, 'set instance to entrypoint')
        logger.debug(f"entrypoint {entrypoint['id']} added to instance {instance['name']}")
    elif entrypoint['name'] != ep_name or entrypoint['url'] != ep_url:
        logger.debug("entrypoint has changed, updating it")
        ep_path = url + '/api/v2/entrypoints/' + entrypoint['id']
        resp = requests.put(ep_path, headers=api_key, json=payload)
        handle_error(resp, 'updating entrypoint')
        logger.debug('entrypoint updated')
        resp = requests.get(ep_path, headers=api_key)
        handle_error(resp, 'retrieving updated entrypoint')
        entrypoint = resp.json()

    logger.debug(f"Found entrypoint: {entrypoint}")
    return entrypoint


def _base_path_from_specs(specs):
    """Extract the API base path from the specs summary returned by the platform.

    For OpenAPI 3 and later, the path of the first ``endpoints`` URL is used.
    For Swagger 2, the ``basePath`` field is used. A missing, null or empty
    value yields ``'/'``.
    """
    major = int(specs['oasVersion'].split('.', 1)[0])
    if major > 2:
        endpoints = specs.get('endpoints') or []
        bp = urllib.parse.urlparse(endpoints[0]['url']).path if endpoints else None
    else:
        # basePath is optional in Swagger 2.0, so it may be absent.
        bp = specs.get('basePath')
    return bp or '/'


def get_api_base_path(api_id):
    """Return the base path declared by the specification of API *api_id* ('/' if none)."""
    r = requests.post(url + f'/api/v1/apis/{api_id}/specs', headers=api_key)
    handle_error(r, 'get api base_path')
    return _base_path_from_specs(r.json())


def set_entrypoint_endpoints(entrypoint, endpoints):
    """Send *endpoints* (a list of endpoint ids) to the ``endpoints/set`` call of *entrypoint*."""
    target = url + f"/api/v2/entrypoints/{entrypoint['id']}/endpoints/set"
    payload = {'endpointIds': endpoints}
    resp = requests.patch(target, headers=api_key, json=payload)
    handle_error(resp, 'setting entrypoint endpoints')


def _normalize_base_path_prefixes(prefixes):
    """Return *prefixes* as a list of ``{'in': ..., 'out': ...}`` mappings.

    A bare string ``p`` is shorthand for ``{'in': p, 'out': p}``; mappings
    are kept as they are.
    """
    return [{'in': p, 'out': p} if isinstance(p, str) else p for p in prefixes]


def update_endpoint_secured_rev(endpoint, apis):
    """Set the secured revisions of *endpoint* from the APIs listed in *apis*.

    :param endpoint: endpoint object (its ``id`` is used).
    :param apis: list whose items are either an API id string or a dict with
        an ``api`` id and optional ``basePathPrefixes``.

    APIs whose secured revision cannot be retrieved are skipped with a
    warning instead of aborting the run; the remaining ones are sent in a
    single ``securedRevisions/set`` call.
    """
    body = []
    for api in apis:
        api_id = api if isinstance(api, str) else api['api']

        r = requests.post(url + f"/api/v2/apis/{api_id}/securedRevisionOas", headers=api_key)
        # PPS-159: deliberately not routed through handle_error (which would
        # exit); a failed lookup skips this API but is now reported.
        if not r.ok:
            logger.warning(f"could not get secured revision of api {api_id} "
                           f"(status code {r.status_code}), skipping it")
            continue

        entry = {'securedRevisionId': r.json()['id'],
                 'basePath': get_api_base_path(api_id)}
        # String items are bare ids; only dict items can carry prefixes.
        if isinstance(api, dict) and 'basePathPrefixes' in api:
            entry['basePathPrefixes'] = _normalize_base_path_prefixes(api['basePathPrefixes'])
        body.append(entry)

    logger.debug("updating endpoint securedRevisions")
    r = requests.patch(url + f"/api/v2/endpoints/{endpoint['id']}/securedRevisions/set",
                       json={"list": body}, headers=api_key)
    handle_error(r, 'setting endpoint securedRevisions')


def create_deployment(instance, settings, wait_time: int = 30):
    """Create a new deployment for *instance* and wait until it can be fetched.

    :param instance: instance object (its ``id`` is used).
    :param settings: validated vhost settings; ``blockingLevel`` (default 5)
        is sent with the deployment request.
    :param wait_time: approximate maximum number of seconds to keep polling
        the deployment, every 2 seconds. Negative values are treated as 0,
        i.e. a single poll.
    :return: the id of the created deployment.

    Exits with status 1 if creation fails or the deployment still cannot be
    fetched once *wait_time* has elapsed. At DEBUG level, the instance report
    is logged before exiting to help diagnose the failure.
    """
    logger.debug('creating a new deployment')
    dn = gen_name()

    blocking_level = settings.get('blockingLevel', 5)
    r = requests.post(url + f"/api/v2/instances/{instance['id']}/deployments", headers=api_key,
                      json={'deploymentName': dn, 'blockingLevel': blocking_level})
    handle_error(r, 'create deployment')
    v = r.json()

    if not v['success']:
        logger.error(f"Error at deployment creation: {v}")
        sys.exit(1)

    # .get, as in handle_error: a missing header must not crash the run after
    # the deployment was created.
    tx = r.headers.get('X-42c-Transactionid', '')
    logger.info(f"created deployment {dn}: {v} ({tx})")
    logger.info("wait for deployment task...")

    # One initial poll plus ceil(wait_time / 2) retries spaced 2 seconds apart.
    # Clamping guarantees at least one GET, so ``r`` below is always the
    # deployment response and never the (successful) creation response.
    retries = math.ceil(max(wait_time, 0) / 2)
    while True:
        r = requests.get(url + f"/api/v2/deployments/{v['id']}", headers=api_key)
        if r.ok or retries <= 0:
            break
        retries -= 1
        time.sleep(2)

    if not r.ok and logger.getEffectiveLevel() <= logging.DEBUG:
        rp = requests.get(url + f'/api/v1/instances/{instance["id"]}/report', headers=api_key)
        handle_error(rp, 'retrieve deployment report', exit=False)
        if rp.ok:
            logger.debug(base64.b64decode(rp.json()['data']))

    handle_error(r, 'wait for deployment task')
    logger.info('deployment is ready')
    return v['id']


def get_ptf_version():
    """Return the ``X-42c-Platform-Version`` header of ``/api/v1/version``, or None if absent."""
    resp = requests.get(url + '/api/v1/version', headers=api_key)
    handle_error(resp, 'get platform version')
    headers = resp.headers
    return headers.get('X-42c-Platform-Version')


def reformat_key(key):
    """Convert a hyphenated config key to the camelCase name used by the API.

    A few keys with acronyms have fixed spellings (e.g. ``ca-type`` ->
    ``CAType``). Otherwise the first word is kept as is and each following
    word is capitalized (``str.capitalize``, so the rest of it is
    lower-cased), e.g. ``keep-alive-timeout`` -> ``keepAliveTimeout``.
    """
    irregular = {
        'ca-type': 'CAType',
        'exposed-ca-type': 'exposedCAType',
        'crl-type': 'CRLType',
        'add-xff-headers': 'addXFFHeaders',
    }
    special = irregular.get(key)
    if special is not None:
        return special

    head, *tail = key.split('-')
    return head + ''.join(word.capitalize() for word in tail)


def vhost_conf_to_settings(conf):
    """Load, validate and normalize a vhost configuration file.

    Reads JSON from the open text file *conf* and validates it against the
    schema below. While validating, hyphenated keys are renamed to camelCase
    via :func:`reformat_key`. Profile references in ``inbound`` and backends
    that are given as strings are replaced by the named entry of the matching
    ``*-profiles`` section.

    :param conf: readable text file object containing the JSON config.
    :return: the validated and normalized settings dict.

    Logs a critical message and exits with status 1 when the file is not
    valid JSON or does not match the schema.
    """
    try:
        vc = json.load(conf)
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError both derive from
        # ValueError; report them cleanly like schema errors below.
        logger.critical(f"invalid vhost configuration {getattr(conf, 'name', '')}: {e}")
        sys.exit(1)

    def K(s):
        # Required key `s`, renamed to camelCase in the output.
        return sch.And(s, sch.Use(reformat_key))

    def O(s):  # noqa
        # Optional key `s`, renamed to camelCase in the output.
        return sch.Optional(K(s))

    def replace_profile(scope, name):
        # A string is the name of a profile defined in the top-level `scope`
        # section; any other value is an inline profile and is kept as is.
        if not isinstance(name, str):
            return name
        if scope not in vc:
            raise sch.SchemaError(f"Key {scope} is missing")
        if name not in vc[scope]:
            raise sch.SchemaError(f"Profile {name} not found in {scope}")
        return vc[scope][name]

    def PROFILE(scope, s):
        # Resolve a named/inline profile, then validate it against schema `s`.
        return sch.And(sch.Use(ft.partial(replace_profile, scope)), s)

    def to_list(a):
        return [a]

    s_cert_type = sch.Schema(sch.Or("none", "file", "directory"))

    s_ssl_server_profile = sch.Schema({O("protocol"): str,
                                       O("cipher-suite"): str,
                                       O("session-ticket"): bool})

    s_pki_server_profile = sch.Schema({O("ocsp-stapling-enabled"): bool,
                                       O("mtls"): {O("ca-type"): s_cert_type,
                                                   O("exposed-ca-type"): s_cert_type,
                                                   O("crl-type"): s_cert_type,
                                                   O("verify"): int,
                                                   O("verify-depth"): int,
                                                   O("ocsp-enabled"): bool,
                                                   O("ocsp-default-responder"): str}})

    s_ssl_client_profile = sch.Schema({O("protocol"): str,
                                       O("cipher-suite"): str})

    s_pki_client_profile = sch.Schema({O("ca-type"): s_cert_type,
                                       O("crl-type"): s_cert_type,
                                       O("check-peer-name"): bool,
                                       O("check-peer-expire"): bool,
                                       O("mtls"): {O("certificate-type"): s_cert_type,
                                                   O("certificate-chain-type"): s_cert_type}})

    # An API entry: a UUID string under "api" plus optional base path
    # prefixes, always normalized to a list.
    s_api = sch.Schema({"api": sch.And(str, lambda u: uuid.UUID(u)),
                        O("base-path-prefixes"):
                            sch.Or(sch.And(str, sch.Use(to_list)),
                                   sch.And({"in": str, "out": str}, sch.Use(to_list)),
                                   [str],
                                   [{"in": str, "out": str}])
                        })
    # "apis" accepts a single id/entry (wrapped in a list) or a non-empty
    # list of ids and/or entries.
    s_apis = sch.Schema(sch.Or(sch.And(str, sch.Use(to_list)),
                               sch.And(s_api, sch.Use(to_list)),
                               sch.And([sch.Or(sch.And(str, lambda u: uuid.UUID(u)),
                                               s_api)], lambda a: len(a) > 0)))

    s_backend = sch.Schema({K("url"): str,
                            K("apis"): s_apis,
                            O("tls-profile"): PROFILE("backend-tls-profiles", s_ssl_client_profile),
                            O("pki-profile"): PROFILE("backend-pki-profiles", s_pki_client_profile),
                            O("timeout"): int,
                            O("connection-timeout"): int,
                            O("keep-alive"): bool,
                            O("keep-alive-timeout"): int,
                            O("preserve-host"): bool,
                            O("add-xff-headers"): bool})

    s = {O("blocking-level"): sch.And(int, lambda n: 5 >= n >= 0, error="Bad blocking-level, should be between 0 and 5"),
         K("inbound"): {O("tls-profile"): PROFILE("inbound-tls-profiles", s_ssl_server_profile),
                        O("pki-profile"): PROFILE("inbound-pki-profiles", s_pki_server_profile),
                        K("name"): str,
                        K("port"): sch.And(int, sch.Use(str)),
                        O("log-level"): str,
                        O("keep-alive"): bool,
                        O("keep-alive-timeout"): int,
                        O("timeout"): int,
                        O("ip"): str,
                        O('additional-server-names'): [str]},
         K("backends"): sch.Or(sch.And(s_backend, sch.Use(to_list)), [s_backend]),
         O("inbound-tls-profiles"): {str: s_ssl_server_profile},
         O("inbound-pki-profiles"): {str: s_pki_server_profile},
         O("backend-tls-profiles"): {str: s_ssl_client_profile},
         O("backend-pki-profiles"): {str: s_pki_client_profile}}

    try:
        settings = sch.Schema(s).validate(vc)
    except sch.SchemaError as e:
        logger.critical(e.code)
        sys.exit(1)

    return settings


def create_deployment_protection_token(dep_id):
    """Create a protection token for deployment *dep_id* (also used as its name).

    :return: the JSON body of the creation response.
    """
    payload = {"deploymentId": dep_id, "name": dep_id}
    resp = requests.post(url + "/api/v2/protectionTokens", headers=api_key, json=payload)
    handle_error(resp, "create deployment protection token")
    return resp.json()


def main(args):
    """Synchronize the platform objects described by the vhost config, then deploy.

    Steps:
    1. Validate the config.
    2. Attach the gateway config to the instance.
    3. Create or update the entrypoint and its TLS/PKI profiles.
    4. Create the endpoints with their profiles and secured API revisions, and
       link them to the entrypoint.
    5. Create the deployment.
    6. Optionally create a protection token, download the deployment file and
       redeploy the instance.

    :param args: parsed command-line arguments.
    :return: 0 on success (failures exit earlier through ``sys.exit``).
    """
    settings = vhost_conf_to_settings(args.vhost_config)
    if logger.level == logging.DEBUG:
        print(json.dumps(settings, indent=4))

    instance = get_instance(args.instance_name)
    update_gateway_config(instance, get_gateway_config(settings))

    inbound = settings['inbound']
    entrypoint = get_entrypoint(instance, inbound)
    set_ssl_server_profile(inbound.get('tlsProfile'), entrypoint)
    set_pki_server_profile(inbound.get('pkiProfile'), entrypoint)

    endpoint_ids = []
    for backend in settings['backends']:
        endpoint = get_endpoint(backend)
        set_ssl_client_profile(backend.get('tlsProfile'), endpoint)
        set_pki_client_profile(backend.get('pkiProfile'), endpoint)
        update_endpoint_secured_rev(endpoint, backend['apis'])
        endpoint_ids.append(endpoint['id'])
    set_entrypoint_endpoints(entrypoint, endpoint_ids)

    dep_id = create_deployment(instance, settings, wait_time=args.wait)

    if args.ptoken:
        token = create_deployment_protection_token(dep_id)
        print(f"Protection token for this deployment: {token['value']}")

    if args.output_file:
        resp = requests.get(url + f"/api/v2/deployments/{dep_id}/file", headers=api_key)
        handle_error(resp, 'download deployment file')
        args.output_file.write(base64.b64decode(resp.json()['data']))
        logger.info(f"deployment file written in {args.output_file.name}")

    if args.reconfigure:
        logger.info(f"reconfiguring firewall instance {instance['name']}")
        resp = requests.post(url + f"/api/v2/instances/{instance['id']}/redeploy", headers=api_key)
        handle_error(resp, 'firewall reconfigure')
        logger.info("firewall instance reconfigured")

    return 0


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', required=False, action='count', default=0)
    parser.add_argument('-p', '--platform', required=False, default='https://platform.42crunch.com', type=str)
    parser.add_argument('-i', '--instance-name', required=True, type=str)
    parser.add_argument('-c', '--vhost-config', required=True, type=argparse.FileType('r'))
    parser.add_argument('-t', '--api-token', required=False, type=str)
    parser.add_argument('-r', '--reconfigure', required=False, action='store_true', default=False,
                        help="Automatically reconfigure the instance")
    parser.add_argument('-n', '--create-protection-token', required=False, action='store_true', default=False,
                        dest="ptoken", help="Create a protection token for this deployment")
    parser.add_argument('-o', '--output-file', required=False, type=argparse.FileType('wb'), help=argparse.SUPPRESS)
    parser.add_argument('-w', '--wait', required=False, type=int, help="max time to wait for deployment", default=30)
    args = parser.parse_args()
    logger = logging.getLogger('42crunch-cd')

    log_err = logging.StreamHandler(sys.stderr)
    formatter = RFC5424Formatter('%(pri)s %(isotime)s %(hostname)s %(name)s - - %(levelname)s %(message)s')
    log_err.setFormatter(formatter)
    logger.addHandler(log_err)
    # No -v: WARNING, -v: INFO, -vv or more: DEBUG.
    logger.setLevel([logging.WARNING, logging.INFO, logging.DEBUG][min(args.verbose, 2)])

    # --api-token takes precedence over the API_KEY environment variable; an
    # empty token is treated as missing instead of being sent to the platform.
    token = args.api_token if args.api_token is not None else os.environ.get('API_KEY')
    if not token:
        logger.critical('API TOKEN not found')
        sys.exit(1)
    api_key = {'X-API-KEY': token}

    # API paths are appended as '/api/...'; strip trailing slashes so that
    # '--platform https://host/' does not produce 'https://host//api/...'.
    url = args.platform.rstrip('/')
    sys.exit(main(args))
