#!/usr/bin/env python3

import sys
import json
import argparse
import hashlib
import importlib
import logging
from abc import ABC, abstractmethod


# Separator placed between the encoded key components before hashing.
HASH_DELIM = b'\x00'
# Hash constructor used to derive database keys from driver objects.
HASH = hashlib.sha256


class BaseDB(ABC):
    """Interface of a key store used to remember already seen objects."""

    @abstractmethod
    def check_key(self, key):
        """Tell whether *key* is already recorded in the store."""

    @abstractmethod
    def set_key(self, key, value):
        """Record *value* in the store under *key*."""


class FileDB(BaseDB):
    """Key store that keeps one JSON file per key inside a directory.

    A key counts as present when ``<workdir>/<key>.json`` exists as a
    regular file.
    """

    def __init__(self, workdir):
        """Bind the store to *workdir* and verify that it is writable.

        Raises:
            OSError: if a temporary file cannot be created in *workdir*.
            AssertionError: if the probe data cannot be read back intact.
        """
        self._os = importlib.import_module('os')
        self._ospath = importlib.import_module('os.path')
        self._tempfile = importlib.import_module('tempfile')
        self._wd = workdir
        self._test_writable()

    def _test_writable(self):
        """Write a probe file into the work directory and read it back."""
        TEST_STRING = b"test"
        with self._tempfile.NamedTemporaryFile('w+b', 0, dir=self._wd) as f:
            f.write(TEST_STRING)
            f.flush()
            with open(f.name, 'rb') as tf:
                # Raised explicitly instead of using an ``assert`` statement
                # so the check is not stripped when Python runs with -O.
                if tf.read() != TEST_STRING:
                    raise AssertionError("Test write failed")

    def _get_key_filename(self, key):
        """Return the path of the JSON file that represents *key*."""
        return self._ospath.join(self._wd, key + '.json')

    def check_key(self, key):
        """Return True when a file for *key* exists in the work directory."""
        filename = self._get_key_filename(key)
        return self._ospath.isfile(filename)

    def set_key(self, key, obj):
        """Store *obj* as indented JSON in the file that represents *key*.

        The data is written to a sibling temporary file which is then
        renamed over the final name, so a failed or interrupted write never
        leaves a truncated file that check_key() would report as present.
        """
        filename = self._get_key_filename(key)
        tmp_filename = filename + '.tmp'
        try:
            with open(tmp_filename, 'w') as f:
                json.dump(obj, f, indent=4)
                f.flush()
            self._os.replace(tmp_filename, filename)
        except BaseException:
            # Best-effort cleanup of the partial file; the original error is
            # the one worth reporting.
            try:
                self._os.remove(tmp_filename)
            except OSError:
                pass
            raise


class Hasher:
    """Derives a hex digest from selected fields of a nested object."""

    def __init__(self, key_components):
        """Remember the lookup paths whose values make up the key.

        *key_components* is an iterable of paths; each path is a sequence of
        successive subscripts applied to the hashed object.
        """
        self._key_components = key_components

    def _eval_key_component(self, obj, component_path):
        """Follow *component_path* into *obj* and return the value as bytes.

        The value found is converted with str() and encoded as UTF-8.
        """
        node = obj
        for step in component_path:
            node = node[step]
        text = str(node)
        return text.encode('utf-8')

    def hash_object(self, obj):
        """Return the hex digest of all key components of *obj*.

        The encoded components are joined with HASH_DELIM and fed to HASH.
        """
        parts = [self._eval_key_component(obj, path)
                 for path in self._key_components]
        digest = HASH(HASH_DELIM.join(parts))
        return digest.hexdigest()


class BaseNotifier(ABC):
    """Interface of a sink that is told about newly discovered objects."""

    @abstractmethod
    def notify(self, obj):
        """Deliver a notification about *obj*."""


class EmailNotifier(BaseNotifier):
    """Notifier that e-mails the object as a JSON attachment."""

    def __init__(self, name, *,
                 from_addr,
                 to_addrs,
                 host='localhost',
                 port=None,
                 local_hostname=None,
                 use_ssl=False,
                 use_starttls=False,
                 login=None,
                 password=None,
                 timeout=10):
        """Load the mail helpers and create the underlying Mailer.

        All connection options are forwarded unchanged to ``mailer.Mailer``;
        *to_addrs* is the collection of recipient addresses.
        """
        load = importlib.import_module
        self.name = name
        self._from_addr = from_addr
        self._to_addrs = to_addrs
        self._Mailer = load('mailer').Mailer
        self._MIMEText = load('email.mime.text').MIMEText
        self._MIMEMult = load('email.mime.multipart').MIMEMultipart
        self._MIMEBase = load('email.mime.base').MIMEBase
        self._encoders = load('email.encoders')
        mailer_options = dict(
            from_addr=from_addr,
            host=host,
            port=port,
            local_hostname=local_hostname,
            use_ssl=use_ssl,
            use_starttls=use_starttls,
            login=login,
            password=password,
            timeout=timeout,
        )
        self._m = self._Mailer(**mailer_options)

    def notify(self, obj):
        """Send a multipart message carrying *obj* as ``obj.json``."""
        message = self._MIMEMult()
        headers = (
            ('Subject', "New Nvidia driver available!"),
            ('From', self._from_addr),
            ('To', ', '.join(self._to_addrs)),
        )
        for header, value in headers:
            message[header] = value

        # Plain-text part first, then the base64-encoded JSON attachment.
        message.attach(self._MIMEText("See attached JSON", 'plain'))

        serialized = json.dumps(obj, indent=4).encode('utf-8')
        attachment = self._MIMEBase('application', 'octet-stream')
        attachment.set_payload(serialized)
        self._encoders.encode_base64(attachment)
        attachment.add_header('Content-Disposition',
                              "attachment; filename=obj.json")
        message.attach(attachment)

        self._m.send(self._to_addrs, message.as_string())


class CommandNotifier(BaseNotifier):
    """Notifier that pipes the object as JSON into an external command."""

    def __init__(self, name, *,
                 cmdline,
                 timeout=10):
        """Remember the command to run and how long it may take.

        Args:
            name: label of this notifier, used by callers in error reports.
            cmdline: command passed as-is to ``subprocess.Popen``.
            timeout: seconds to wait for the command to finish.
        """
        self.name = name
        self._subprocess = importlib.import_module('subprocess')
        self._cmdline = cmdline
        self._timeout = timeout

    def notify(self, obj):
        """Run the command, feeding *obj* as indented JSON on its stdin.

        Raises:
            subprocess.TimeoutExpired: the command did not finish in time;
                it is killed and reaped before the error is re-raised.
            subprocess.CalledProcessError: the command exited with a
                non-zero status.
        """
        # Serialise before spawning so a serialisation error cannot leave an
        # orphaned child process behind.
        payload = json.dumps(obj, indent=4).encode('utf-8')
        proc = self._subprocess.Popen(self._cmdline,
                                      stdin=self._subprocess.PIPE)
        try:
            proc.communicate(payload, self._timeout)
        except self._subprocess.TimeoutExpired:
            proc.kill()
            proc.communicate()
            # A timed-out command did not deliver the notification; let the
            # caller see the failure instead of treating it as success.
            raise
        if proc.returncode != 0:
            raise self._subprocess.CalledProcessError(proc.returncode,
                                                      self._cmdline)


class BaseChannel(ABC):
    """Interface of a source that can report the latest driver."""

    @abstractmethod
    def get_latest_driver(self):
        """Return a description of the latest driver from this source."""


class GFEClientChannel(BaseChannel):
    """Channel backed by ``gfe_get_driver.get_latest_geforce_driver``."""

    def __init__(self, name, **kwargs):
        """Store the channel name and the query arguments.

        *kwargs* are kept untouched and passed to the lookup function on
        every get_latest_driver() call.
        """
        self.name = name
        self._kwargs = kwargs
        module = importlib.import_module('gfe_get_driver')
        self._get_latest_driver = module.get_latest_geforce_driver

    def get_latest_driver(self):
        """Return whatever the lookup function yields for the stored args."""
        query = self._kwargs
        return self._get_latest_driver(**query)


class NvidiaDownloadsChannel(BaseChannel):
    """Channel backed by the ``get_nvidia_downloads`` module."""

    def __init__(self, name, *,
                 os="Linux_64",
                 product="GeForce",
                 certlevel="All",
                 driver_type="Standard",
                 lang="English",
                 cuda_ver="Nothing",
                 timeout=10):
        """Resolve the textual search options and remember them.

        Each option string is looked up by name in the matching attribute of
        ``get_nvidia_downloads`` (``OS``, ``Product``, ``CertLevel``,
        ``DriverType``, ``DriverLanguage``, ``CUDAToolkitVersion``); an
        unknown name fails at construction time. *timeout* is passed on to
        ``get_drivers`` unchanged.
        """
        self.name = name
        gnd = importlib.import_module('get_nvidia_downloads')
        self._gnd = gnd
        self._os = gnd.OS[os]
        self._product = gnd.Product[product]
        self._certlevel = gnd.CertLevel[certlevel]
        self._driver_type = gnd.DriverType[driver_type]
        self._lang = gnd.DriverLanguage[lang]
        self._cuda_ver = gnd.CUDAToolkitVersion[cuda_ver]
        self._timeout = timeout

    @staticmethod
    def _version_key(version):
        """Return a sort key that orders dotted versions numerically.

        Purely decimal components compare by integer value, so "100.1" ranks
        above "99.5". Any other component compares as text and sorts before
        every numeric component, which keeps mixed keys comparable.
        """
        key = []
        for part in version.split('.'):
            if part.isdecimal():
                key.append((1, int(part), ''))
            else:
                key.append((0, 0, part))
        return tuple(key)

    def get_latest_driver(self):
        """Return the highest-versioned driver found, or None if none.

        The result is a dict with a single 'DriverAttributes' entry holding
        the 'Version', 'Name' and 'NameLocalized' of the chosen driver.
        """
        drivers = self._gnd.get_drivers(os=self._os,
                                        product=self._product,
                                        certlevel=self._certlevel,
                                        driver_type=self._driver_type,
                                        lang=self._lang,
                                        cuda_ver=self._cuda_ver,
                                        timeout=self._timeout)
        if not drivers:
            return None
        # Compare version components as numbers, not as strings.
        latest = max(drivers, key=lambda d: self._version_key(d['version']))
        return {
            'DriverAttributes': {
                'Version': latest['version'],
                'Name': latest['name'],
                'NameLocalized': latest['name'],
            }
        }


class CudaToolkitDownloadsChannel(BaseChannel):
    """Channel backed by ``get_cuda_downloads.get_latest_cuda_tk``."""

    def __init__(self, name, *,
                 timeout=10):
        """Store the channel name, the helper module and the timeout."""
        self.name = name
        self._timeout = timeout
        self._gcd = importlib.import_module('get_cuda_downloads')

    def get_latest_driver(self):
        """Return the latest toolkit wrapped as driver attributes, or None.

        The value reported by the helper is used for both name fields; the
        version field is always the placeholder '???'.
        """
        toolkit = self._gcd.get_latest_cuda_tk(timeout=self._timeout)
        if not toolkit:
            return None
        attributes = {
            'Version': '???',
            'Name': toolkit,
            'NameLocalized': toolkit,
        }
        return {'DriverAttributes': attributes}


def parse_args():
    """Parse the command line and return the resulting namespace.

    The only option is -c/--config, the path of the configuration file.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=("Watches for GeForce experience driver updates "
                     "for configured systems"))
    cli.add_argument("-c", "--config",
                     help="config file location",
                     default="/etc/nv-driver-locator.json")
    return cli.parse_args()


class DriverLocator:
    """Polls the configured channels and notifies about unseen drivers.

    Each driver is reduced to a key by the hasher. A driver whose key is not
    yet in the database is passed to every notifier, and the key is stored
    once at least one notifier succeeded.
    """

    # Exit status reported by run(); _perror() switches it to 3.
    _ret_code = 0

    def __init__(self, conf):
        """Build channels, database, hasher and notifiers from *conf*.

        *conf* must provide the keys 'channels', 'db', 'key_components' and
        'notifiers'. Channels and notifiers that fail to construct are
        reported and skipped; a database failure propagates to the caller.
        """
        self._logger = logging.getLogger(self.__class__.__name__)
        self._channels = self._construct_channels(conf['channels'])
        self._db = self._construct_db(conf['db'])
        self._hasher = Hasher(conf['key_components'])
        self._notifiers = self._construct_notifiers(conf['notifiers'])

    def _construct_channels(self, channels_config):
        """Instantiate every channel described in *channels_config*.

        Each entry is a mapping with 'type', 'name' and 'params'. Entries
        that cannot be constructed are logged and left out of the result.
        """
        channel_types = {
            'gfe_client': GFEClientChannel,
            'nvidia_downloads': NvidiaDownloadsChannel,
            'cuda_downloads': CudaToolkitDownloadsChannel,
        }

        channels = []
        for ch in channels_config:
            try:
                ctor = channel_types[ch['type']]
                C = ctor(ch['name'], **ch['params'])
            except Exception as e:
                self._perror("Channel construction failed with exception: %s. "
                             "Skipping..." % (str(e),))
            else:
                channels.append(C)
        return channels

    def _construct_db(self, db_config):
        """Instantiate the database described by *db_config*.

        *db_config* is a mapping with 'type' and 'params'; errors are not
        caught here.
        """
        db_types = {
            'file': FileDB,
        }
        ctor = db_types[db_config['type']]
        db = ctor(**db_config['params'])
        return db

    def _construct_notifiers(self, notifiers_config):
        """Instantiate every notifier described in *notifiers_config*.

        Each entry is a mapping with 'type', 'name' and 'params'. Entries
        that cannot be constructed are logged and left out of the result.
        """
        notifier_types = {
            'email': EmailNotifier,
            'command': CommandNotifier,
        }

        notifiers = []
        for nc in notifiers_config:
            try:
                ctor = notifier_types[nc['type']]
                N = ctor(nc['name'], **nc['params'])
            except Exception as e:
                self._perror("Notifier construction failed with exception: %s."
                             " Skipping..." % (str(e),))
            else:
                notifiers.append(N)
        return notifiers

    def _perror(self, err):
        """Log *err* and make run() report exit status 3."""
        self._ret_code = 3
        self._logger.error(err)

    def _notify_all(self, obj):
        """Send *obj* to all notifiers; return True if any succeeded.

        With no notifiers configured the result is False, so nothing is
        ever recorded as seen.
        """
        fails = 0
        for n in self._notifiers:
            try:
                n.notify(obj)
            except Exception as e:
                self._perror("Notify channel %s failed with exception: %s." %
                             (n.name, str(e)))
                fails += 1
        return fails < len(self._notifiers)

    def run(self):
        """Check every channel once and return the process exit status.

        Returns 0 when everything went fine and 3 when any error was
        reported along the way. Failures of a single channel do not stop
        the remaining channels from being processed.
        """
        for ch in self._channels:
            try:
                drv = ch.get_latest_driver()
            except Exception as e:
                self._perror("get_latest_driver() invocation failed for "
                             "channel %s. Exception: %s. Continuing..." %
                             (repr(ch.name), str(e)))
                continue
            if drv is None:
                self._perror("Driver not found for channel %s" %
                             (repr(ch.name),))
                continue
            try:
                key = self._hasher.hash_object(drv)
            except Exception as e:
                # Report the channel's own name; a bare ``name`` is not
                # defined here and would raise NameError in this handler.
                self._perror("Key evaluation failed for channel %s. "
                             "Exception: %s" % (repr(ch.name), str(e)))
                continue
            if not self._db.check_key(key):
                # Store the key only after a notifier accepted the driver,
                # so an undelivered driver is retried on the next run.
                if self._notify_all(drv):
                    self._db.set_key(key, drv)
        return self._ret_code


def setup_logger(name, verbosity):
    """Attach a formatted stream handler to the logger called *name*.

    Both the logger and the new handler are set to *verbosity*. Returns the
    configured logger.
    """
    line_format = '%(asctime)s %(levelname)-8s %(name)s: %(message)s'
    formatter = logging.Formatter(line_format, '%Y-%m-%d %H:%M:%S')

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(verbosity)
    stream_handler.setFormatter(formatter)

    log = logging.getLogger(name)
    log.setLevel(verbosity)
    log.addHandler(stream_handler)
    return log


def main():
    """Load the configuration, run the locator and exit with its status."""
    cli = parse_args()
    setup_logger(DriverLocator.__name__, logging.ERROR)

    with open(cli.config, 'r') as stream:
        settings = json.load(stream)

    locator = DriverLocator(settings)
    sys.exit(locator.run())


# Run the locator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()