nvidia-patch/tools/nv-driver-locator/nv-driver-locator.py

329 lines
10 KiB
Python
Executable File

#!/usr/bin/env python3
import sys
import json
import argparse
import hashlib
import importlib
import logging
from abc import ABC, abstractmethod
# Byte separator inserted between the encoded key components before hashing.
HASH_DELIM = b'\0'
# hashlib constructor used to turn the joined key components into a DB key.
HASH = hashlib.sha256
class BaseDB(ABC):
    """Abstract key store that remembers which driver keys were already seen."""

    @abstractmethod
    def check_key(self, key):
        """Return whether a record exists for *key*."""

    @abstractmethod
    def set_key(self, key, value):
        """Store *value* under *key*."""
class FileDB(BaseDB):
    """Directory-backed key store: each key is persisted as ``<key>.json``.

    Args:
        workdir: directory holding the JSON files. Writability is verified
            at construction time.

    Raises:
        OSError: from the constructor if *workdir* cannot be written to.
    """

    def __init__(self, workdir):
        self._os = importlib.import_module('os')
        self._ospath = importlib.import_module('os.path')
        self._tempfile = importlib.import_module('tempfile')
        self._wd = workdir
        self._test_writable()

    def _test_writable(self):
        """Write a scratch file in the work directory and read it back.

        Raises:
            OSError: if the file cannot be created, or the content read back
                differs from what was written.
        """
        TEST_STRING = b"test"
        with self._tempfile.NamedTemporaryFile('w+b', 0, dir=self._wd) as f:
            f.write(TEST_STRING)
            f.flush()
            with open(f.name, 'rb') as tf:
                read_back = tf.read()
        # Explicit check instead of ``assert``: asserts vanish under -O.
        if read_back != TEST_STRING:
            raise OSError("Test write failed")

    def _get_key_filename(self, key):
        """Return the path of the JSON file that stores *key*."""
        return self._ospath.join(self._wd, key + '.json')

    def check_key(self, key):
        """Return True if a record file for *key* exists."""
        filename = self._get_key_filename(key)
        return self._ospath.isfile(filename)

    def set_key(self, key, obj):
        """Persist *obj* as indented JSON under *key*.

        The JSON is written to a temporary sibling file, fsynced, and then
        atomically renamed over ``<key>.json``. ``check_key`` treats any
        existing record file as "already seen", so a truncated file left by
        a failed or interrupted write would suppress notifications forever.
        On error the temporary file is removed and the exception propagates.
        """
        filename = self._get_key_filename(key)
        # The ``.tmp`` suffix keeps the partial file invisible to check_key.
        tmp_name = '%s.%d.tmp' % (filename, self._os.getpid())
        try:
            with open(tmp_name, 'w') as f:
                json.dump(obj, f, indent=4)
                f.flush()
                self._os.fsync(f.fileno())
            self._os.replace(tmp_name, filename)
        except BaseException:
            try:
                self._os.unlink(tmp_name)
            except OSError:
                pass  # best effort: the temp file may never have been created
            raise
class Hasher:
    """Derive a hex key from selected fields of a nested object.

    Args:
        key_components: sequence of paths. Each path is a sequence of
            keys/indexes applied one after another to the object (for
            instance ``["DriverAttributes", "Version"]``).
    """

    def __init__(self, key_components):
        self._key_components = key_components

    def _eval_key_component(self, obj, component_path):
        """Walk *component_path* into *obj*; return the leaf's ``str()`` as UTF-8."""
        node = obj
        for step in component_path:
            node = node[step]
        return str(node).encode('utf-8')

    def hash_object(self, obj):
        """Return the HASH hex digest of all components joined by HASH_DELIM.

        Lookup errors (e.g. KeyError for a missing field) propagate.
        """
        parts = [self._eval_key_component(obj, path)
                 for path in self._key_components]
        return HASH(HASH_DELIM.join(parts)).hexdigest()
class BaseNotifier(ABC):
    """Abstract sink that announces a newly found driver object."""

    @abstractmethod
    def notify(self, obj):
        """Deliver *obj*; implementations raise on failure."""
class EmailNotifier(BaseNotifier):
    """Notify by e-mail, attaching the driver object as ``obj.json``.

    Args:
        name: notifier name used in log messages.
        from_addr: sender address; passed to ``mailer.Mailer`` and used as
            the ``From`` header.
        to_addrs: recipient addresses as a list of strings, or a single
            string for one recipient.
        host, port, local_hostname, use_ssl, use_starttls, login, password,
        timeout: forwarded unchanged to ``mailer.Mailer``.
    """

    def __init__(self, name, *,
                 from_addr,
                 to_addrs,
                 host='localhost',
                 port=None,
                 local_hostname=None,
                 use_ssl=False,
                 use_starttls=False,
                 login=None,
                 password=None,
                 timeout=10):
        self.name = name
        self._from_addr = from_addr
        self._Mailer = importlib.import_module('mailer').Mailer
        self._MIMEText = importlib.import_module('email.mime.text').MIMEText
        self._MIMEMult = importlib.import_module(
            'email.mime.multipart').MIMEMultipart
        self._MIMEBase = importlib.import_module('email.mime.base').MIMEBase
        self._encoders = importlib.import_module('email.encoders')
        self._formatdate = importlib.import_module('email.utils').formatdate
        self._m = self._Mailer(from_addr=from_addr,
                               host=host,
                               port=port,
                               local_hostname=local_hostname,
                               use_ssl=use_ssl,
                               use_starttls=use_starttls,
                               login=login,
                               password=password,
                               timeout=timeout)
        # A bare string would otherwise be iterated character by character
        # by ', '.join() when building the To header.
        if isinstance(to_addrs, str):
            to_addrs = [to_addrs]
        self._to_addrs = list(to_addrs)

    def notify(self, obj):
        """Send one message with *obj* attached as indented JSON.

        Exceptions raised while sending propagate to the caller.
        """
        msg = self._MIMEMult()
        msg['Subject'] = "New Nvidia driver available!"
        msg['From'] = self._from_addr
        msg['To'] = ', '.join(self._to_addrs)
        # Date is a mandatory header per RFC 5322.
        msg['Date'] = self._formatdate(localtime=True)
        body = "See attached JSON"
        msg.attach(self._MIMEText(body, 'plain'))
        p = self._MIMEBase('application', 'octet-stream')
        p.set_payload(json.dumps(obj, indent=4).encode('utf-8'))
        self._encoders.encode_base64(p)
        p.add_header('Content-Disposition', "attachment; filename=obj.json")
        msg.attach(p)
        self._m.send(self._to_addrs, msg.as_string())
class CommandNotifier(BaseNotifier):
    """Notify by running an external command with the driver JSON on stdin.

    Args:
        name: notifier name used in log messages.
        cmdline: program and arguments, executed without a shell.
        timeout: seconds to wait for the command before killing it.
    """

    def __init__(self, name, *,
                 cmdline,
                 timeout=10):
        self.name = name
        self._subprocess = importlib.import_module('subprocess')
        self._cmdline = cmdline
        self._timeout = timeout

    def notify(self, obj):
        """Run the command, feeding *obj* as indented JSON on its stdin.

        Raises:
            subprocess.TimeoutExpired: the command did not finish within
                ``timeout`` seconds. The child is killed and reaped first.
            subprocess.CalledProcessError: the command exited non-zero.

        Raising, rather than returning normally, ensures a command that did
        not complete successfully is counted as a failed notification.
        """
        payload = json.dumps(obj, indent=4).encode('utf-8')
        # run() kills and waits for the child on timeout before re-raising;
        # check=True turns a non-zero exit status into an exception.
        self._subprocess.run(self._cmdline,
                             input=payload,
                             timeout=self._timeout,
                             check=True)
class BaseChannel(ABC):
    """Abstract source that reports the latest available driver."""

    @abstractmethod
    def get_latest_driver(self):
        """Return the latest driver description, or None if none was found."""
class GFEClientChannel(BaseChannel):
    """Channel backed by ``gfe_get_driver.get_latest_geforce_driver``.

    Args:
        name: channel name used in log messages.
        **kwargs: stored and passed verbatim on every query.
    """

    def __init__(self, name, **kwargs):
        self.name = name
        self._kwargs = kwargs
        self._get_latest_driver = importlib.import_module(
            'gfe_get_driver').get_latest_geforce_driver

    def get_latest_driver(self):
        """Query with the stored keyword arguments and return the result."""
        query = self._get_latest_driver
        return query(**self._kwargs)
class NvidiaDownloadsChannel(BaseChannel):
    """Channel that lists drivers via ``get_nvidia_downloads.get_drivers``.

    Each keyword argument is looked up by name in the matching
    ``get_nvidia_downloads`` mapping (``OS``, ``Product``, ``CertLevel``,
    ``DriverType``, ``DriverLanguage``, ``CUDAToolkitVersion``). A failed
    lookup propagates out of the constructor.
    """

    def __init__(self, name, *,
                 os="Linux_64",
                 product="GeForce",
                 certlevel="All",
                 driver_type="Standard",
                 lang="English",
                 cuda_ver="Nothing"):
        self.name = name
        gnd = importlib.import_module('get_nvidia_downloads')
        self._gnd = gnd
        self._os = gnd.OS[os]
        self._product = gnd.Product[product]
        self._certlevel = gnd.CertLevel[certlevel]
        self._driver_type = gnd.DriverType[driver_type]
        self._lang = gnd.DriverLanguage[lang]
        self._cuda_ver = gnd.CUDAToolkitVersion[cuda_ver]

    @staticmethod
    def _version_key(version):
        """Return a sort key comparing dotted version strings numerically.

        Decimal components compare as integers, so "470.103" sorts above
        "470.99" (plain string comparison gets this wrong). A non-numeric
        component sorts below a numeric one at the same position and
        compares as text among its peers, so keys never mix int and str.
        """
        return tuple((1, int(part), '') if part.isdecimal() else (0, 0, part)
                     for part in version.split('.'))

    def get_latest_driver(self):
        """Return the highest-versioned matching driver, or None if none match.

        The result is wrapped in a ``DriverAttributes`` dict (``Version``,
        ``Name``, ``NameLocalized``). ``NameLocalized`` simply reuses the
        name. NOTE(review): presumably this mirrors the GFE client's
        response shape so that ``key_components`` paths work across
        channels; confirm.
        """
        drivers = self._gnd.get_drivers(os=self._os,
                                        product=self._product,
                                        certlevel=self._certlevel,
                                        driver_type=self._driver_type,
                                        lang=self._lang,
                                        cuda_ver=self._cuda_ver)
        if not drivers:
            return None
        latest = max(drivers,
                     key=lambda d: self._version_key(d['version']))
        return {
            'DriverAttributes': {
                'Version': latest['version'],
                'Name': latest['name'],
                'NameLocalized': latest['name'],
            }
        }
def parse_args():
    """Parse the command line and return the namespace (attribute ``config``)."""
    ap = argparse.ArgumentParser(
        description=("Watches for GeForce experience driver updates for "
                     "configured systems"),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    ap.add_argument("-c", "--config",
                    default="/etc/nv-driver-locator.json",
                    help="config file location")
    return ap.parse_args()
class DriverLocator:
    """Poll every configured channel and notify about drivers not seen before.

    Args:
        conf: parsed configuration dict. Keys read: ``channels`` and
            ``notifiers`` (lists of entries with ``type``, ``name`` and
            ``params``), ``db`` (``type`` and ``params``), and
            ``key_components`` (passed to ``Hasher``).
    """

    # Exit status returned by run(): 0 unless _perror() was called, which
    # sets it to 3 on the instance.
    _ret_code = 0

    def __init__(self, conf):
        self._logger = logging.getLogger(self.__class__.__name__)
        self._channels = self._construct_channels(conf['channels'])
        self._db = self._construct_db(conf['db'])
        self._hasher = Hasher(conf['key_components'])
        self._notifiers = self._construct_notifiers(conf['notifiers'])

    def _construct_channels(self, channels_config):
        """Instantiate channels; an entry that fails is logged and skipped."""
        channel_types = {
            'gfe_client': GFEClientChannel,
            'nvidia_downloads': NvidiaDownloadsChannel,
        }
        channels = []
        for ch in channels_config:
            try:
                ctor = channel_types[ch['type']]
                C = ctor(ch['name'], **ch['params'])
            except Exception as e:
                self._perror("Channel construction failed with exception: %s. "
                             "Skipping..." % (str(e),))
            else:
                channels.append(C)
        return channels

    def _construct_db(self, db_config):
        """Instantiate the DB backend; errors propagate (the DB is required)."""
        db_types = {
            'file': FileDB,
        }
        ctor = db_types[db_config['type']]
        db = ctor(**db_config['params'])
        return db

    def _construct_notifiers(self, notifiers_config):
        """Instantiate notifiers; an entry that fails is logged and skipped."""
        notifier_types = {
            'email': EmailNotifier,
            'command': CommandNotifier,
        }
        notifiers = []
        for nc in notifiers_config:
            try:
                ctor = notifier_types[nc['type']]
                N = ctor(nc['name'], **nc['params'])
            except Exception as e:
                self._perror("Notifier construction failed with exception: %s."
                             " Skipping..." % (str(e),))
            else:
                notifiers.append(N)
        return notifiers

    def _perror(self, err):
        """Log *err* and mark this run as failed (exit code 3)."""
        self._ret_code = 3
        self._logger.error(err)

    def _notify_all(self, obj):
        """Send *obj* to every notifier, logging each failure.

        Returns True if at least one notifier succeeded. Returns False if
        all failed or none are configured.
        """
        fails = 0
        for n in self._notifiers:
            try:
                n.notify(obj)
            except Exception as e:
                self._perror("Notify channel %s failed with exception: %s." %
                             (n.name, str(e)))
                fails += 1
        return fails < len(self._notifiers)

    def run(self):
        """Check each channel once and return the process exit code (0 or 3).

        A driver is recorded in the DB only after at least one notifier
        succeeded, so a failed notification is retried on the next run.
        Errors are handled per channel: they are logged, and the remaining
        channels are still processed.
        """
        for ch in self._channels:
            try:
                drv = ch.get_latest_driver()
            except Exception as e:
                self._perror("get_latest_driver() invocation failed for "
                             "channel %s. Exception: %s. Continuing..." %
                             (repr(ch.name), str(e)))
                continue
            if drv is None:
                self._perror("Driver not found for channel %s" %
                             (repr(ch.name),))
                continue
            try:
                key = self._hasher.hash_object(drv)
            except Exception as e:
                # Must use ch.name: a NameError raised here would escape the
                # handler and abort processing of all remaining channels.
                self._perror("Key evaluation failed for channel %s. "
                             "Exception: %s" % (repr(ch.name), str(e)))
                continue
            if self._db.check_key(key):
                continue  # already notified about this driver
            if not self._notify_all(drv):
                continue  # leave unrecorded so the next run retries
            try:
                self._db.set_key(key, drv)
            except Exception as e:
                self._perror("Saving key failed for channel %s. "
                             "Exception: %s" % (repr(ch.name), str(e)))
        return self._ret_code
def setup_logger(name, verbosity):
    """Attach a timestamped stderr handler to logger *name* and return it.

    Both the logger and the new handler are set to level *verbosity*.
    """
    fmt = logging.Formatter(
        '%(asctime)s %(levelname)-8s %(name)s: %(message)s',
        '%Y-%m-%d %H:%M:%S',
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(verbosity)
    stream_handler.setFormatter(fmt)

    target = logging.getLogger(name)
    target.setLevel(verbosity)
    target.addHandler(stream_handler)
    return target
def main():
    """Script entry point: load the JSON config, run once, exit with its code."""
    cli = parse_args()
    setup_logger(DriverLocator.__name__, logging.ERROR)
    with open(cli.config, 'r') as fh:
        config = json.load(fh)
    locator = DriverLocator(config)
    sys.exit(locator.run())


if __name__ == '__main__':
    main()