mirror of
https://github.com/keylase/nvidia-patch.git
synced 2025-08-09 20:14:01 +00:00
nv-driver-locator: implemented new NvidiaDownloadsChannel "plugin"
This commit is contained in:
328
tools/nv-driver-locator/nv-driver-locator.py
Executable file
328
tools/nv-driver-locator/nv-driver-locator.py
Executable file
@@ -0,0 +1,328 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import hashlib
|
||||
import importlib
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
HASH_DELIM = b'\x00'
|
||||
HASH = hashlib.sha256
|
||||
|
||||
|
||||
class BaseDB(ABC):
|
||||
@abstractmethod
|
||||
def check_key(self, key):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_key(self, key, value):
|
||||
pass
|
||||
|
||||
|
||||
class FileDB(BaseDB):
|
||||
def __init__(self, workdir):
|
||||
self._ospath = importlib.import_module('os.path')
|
||||
self._tempfile = importlib.import_module('tempfile')
|
||||
self._wd = workdir
|
||||
self._test_writable()
|
||||
|
||||
def _test_writable(self):
|
||||
TEST_STRING = b"test"
|
||||
with self._tempfile.NamedTemporaryFile('w+b', 0, dir=self._wd) as f:
|
||||
f.write(TEST_STRING)
|
||||
f.flush()
|
||||
with open(f.name, 'rb') as tf:
|
||||
assert tf.read() == TEST_STRING, "Test write failed"
|
||||
|
||||
def _get_key_filename(self, key):
|
||||
return self._ospath.join(self._wd, key + '.json')
|
||||
|
||||
def check_key(self, key):
|
||||
filename = self._get_key_filename(key)
|
||||
return self._ospath.isfile(filename)
|
||||
|
||||
def set_key(self, key, obj):
|
||||
filename = self._get_key_filename(key)
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(obj, f, indent=4)
|
||||
f.flush()
|
||||
|
||||
|
||||
class Hasher:
|
||||
def __init__(self, key_components):
|
||||
self._key_components = key_components
|
||||
|
||||
def _eval_key_component(self, obj, component_path):
|
||||
res = obj
|
||||
for path_component in component_path:
|
||||
res = res[path_component]
|
||||
return str(res).encode('utf-8')
|
||||
|
||||
def hash_object(self, obj):
|
||||
return HASH(HASH_DELIM.join(
|
||||
self._eval_key_component(obj, c) for c in self._key_components)
|
||||
).hexdigest()
|
||||
|
||||
|
||||
class BaseNotifier(ABC):
|
||||
@abstractmethod
|
||||
def notify(self, obj):
|
||||
pass
|
||||
|
||||
|
||||
class EmailNotifier(BaseNotifier):
|
||||
def __init__(self, name, *,
|
||||
from_addr,
|
||||
to_addrs,
|
||||
host='localhost',
|
||||
port=None,
|
||||
local_hostname=None,
|
||||
use_ssl=False,
|
||||
use_starttls=False,
|
||||
login=None,
|
||||
password=None,
|
||||
timeout=10):
|
||||
self.name = name
|
||||
self._from_addr = from_addr
|
||||
self._Mailer = importlib.import_module('mailer').Mailer
|
||||
self._MIMEText = importlib.import_module('email.mime.text').MIMEText
|
||||
self._MIMEMult = importlib.import_module(
|
||||
'email.mime.multipart').MIMEMultipart
|
||||
self._MIMEBase = importlib.import_module('email.mime.base').MIMEBase
|
||||
self._encoders = importlib.import_module('email.encoders')
|
||||
self._m = self._Mailer(from_addr=from_addr,
|
||||
host=host,
|
||||
port=port,
|
||||
local_hostname=local_hostname,
|
||||
use_ssl=use_ssl,
|
||||
use_starttls=use_starttls,
|
||||
login=login,
|
||||
password=password,
|
||||
timeout=timeout)
|
||||
self._to_addrs = to_addrs
|
||||
|
||||
def notify(self, obj):
|
||||
msg = self._MIMEMult()
|
||||
msg['Subject'] = "New Nvidia driver available!"
|
||||
msg['From'] = self._from_addr
|
||||
msg['To'] = ', '.join(self._to_addrs)
|
||||
body = "See attached JSON"
|
||||
msg.attach(self._MIMEText(body, 'plain'))
|
||||
p = self._MIMEBase('application', 'octet-stream')
|
||||
p.set_payload(json.dumps(obj, indent=4).encode('utf-8'))
|
||||
self._encoders.encode_base64(p)
|
||||
p.add_header('Content-Disposition', "attachment; filename=obj.json")
|
||||
msg.attach(p)
|
||||
self._m.send(self._to_addrs, msg.as_string())
|
||||
|
||||
|
||||
class CommandNotifier(BaseNotifier):
|
||||
def __init__(self, name, *,
|
||||
cmdline,
|
||||
timeout=10):
|
||||
self.name = name
|
||||
self._subprocess = importlib.import_module('subprocess')
|
||||
self._cmdline = cmdline
|
||||
self._timeout = timeout
|
||||
|
||||
def notify(self, obj):
|
||||
proc = self._subprocess.Popen(self._cmdline,
|
||||
stdin=self._subprocess.PIPE)
|
||||
try:
|
||||
proc.communicate(json.dumps(obj, indent=4).encode('utf-8'),
|
||||
self._timeout)
|
||||
except self._subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.communicate()
|
||||
|
||||
|
||||
class BaseChannel(ABC):
|
||||
@abstractmethod
|
||||
def get_latest_driver(self):
|
||||
pass
|
||||
|
||||
|
||||
class GFEClientChannel(BaseChannel):
|
||||
def __init__(self, name, **kwargs):
|
||||
self.name = name
|
||||
self._kwargs = kwargs
|
||||
gfe_get_driver = importlib.import_module('gfe_get_driver')
|
||||
self._get_latest_driver = gfe_get_driver.get_latest_geforce_driver
|
||||
|
||||
def get_latest_driver(self):
|
||||
return self._get_latest_driver(**self._kwargs)
|
||||
|
||||
|
||||
class NvidiaDownloadsChannel(BaseChannel):
|
||||
def __init__(self, name, *,
|
||||
os="Linux_64",
|
||||
product="GeForce",
|
||||
certlevel="All",
|
||||
driver_type="Standard",
|
||||
lang="English",
|
||||
cuda_ver="Nothing"):
|
||||
self.name = name
|
||||
gnd = importlib.import_module('get_nvidia_downloads')
|
||||
self._gnd = gnd
|
||||
self._os = gnd.OS[os]
|
||||
self._product = gnd.Product[product]
|
||||
self._certlevel = gnd.CertLevel[certlevel]
|
||||
self._driver_type = gnd.DriverType[driver_type]
|
||||
self._lang = gnd.DriverLanguage[lang]
|
||||
self._cuda_ver = gnd.CUDAToolkitVersion[cuda_ver]
|
||||
|
||||
def get_latest_driver(self):
|
||||
drivers = self._gnd.get_drivers(os=self._os,
|
||||
product=self._product,
|
||||
certlevel=self._certlevel,
|
||||
driver_type=self._driver_type,
|
||||
lang=self._lang,
|
||||
cuda_ver=self._cuda_ver)
|
||||
if not drivers:
|
||||
return None
|
||||
latest = max(drivers, key=lambda d: tuple(d['version'].split('.')))
|
||||
return {
|
||||
'DriverAttributes': {
|
||||
'Version': latest['version'],
|
||||
'Name': latest['name'],
|
||||
'NameLocalized': latest['name'],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Watches for GeForce experience driver updates for "
|
||||
"configured systems",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("-c", "--config",
|
||||
default="/etc/nv-driver-locator.json",
|
||||
help="config file location")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
class DriverLocator:
|
||||
_ret_code = 0
|
||||
|
||||
def __init__(self, conf):
|
||||
self._logger = logging.getLogger(self.__class__.__name__)
|
||||
self._channels = self._construct_channels(conf['channels'])
|
||||
self._db = self._construct_db(conf['db'])
|
||||
self._hasher = Hasher(conf['key_components'])
|
||||
self._notifiers = self._construct_notifiers(conf['notifiers'])
|
||||
|
||||
def _construct_channels(self, channels_config):
|
||||
channel_types = {
|
||||
'gfe_client': GFEClientChannel,
|
||||
'nvidia_downloads': NvidiaDownloadsChannel,
|
||||
}
|
||||
|
||||
channels = []
|
||||
for ch in channels_config:
|
||||
try:
|
||||
ctor = channel_types[ch['type']]
|
||||
C = ctor(ch['name'], **ch['params'])
|
||||
except Exception as e:
|
||||
self._perror("Channel construction failed with exception: %s. "
|
||||
"Skipping..." % (str(e),))
|
||||
else:
|
||||
channels.append(C)
|
||||
return channels
|
||||
|
||||
def _construct_db(self, db_config):
|
||||
db_types = {
|
||||
'file': FileDB,
|
||||
}
|
||||
ctor = db_types[db_config['type']]
|
||||
db = ctor(**db_config['params'])
|
||||
return db
|
||||
|
||||
def _construct_notifiers(self, notifiers_config):
|
||||
notifier_types = {
|
||||
'email': EmailNotifier,
|
||||
'command': CommandNotifier,
|
||||
}
|
||||
|
||||
notifiers = []
|
||||
for nc in notifiers_config:
|
||||
try:
|
||||
ctor = notifier_types[nc['type']]
|
||||
N = ctor(nc['name'], **nc['params'])
|
||||
except Exception as e:
|
||||
self._perror("Notifier construction failed with exception: %s."
|
||||
" Skipping..." % (str(e),))
|
||||
else:
|
||||
notifiers.append(N)
|
||||
return notifiers
|
||||
|
||||
def _perror(self, err):
|
||||
self._ret_code = 3
|
||||
self._logger.error(err)
|
||||
|
||||
def _notify_all(self, obj):
|
||||
fails = 0
|
||||
for n in self._notifiers:
|
||||
try:
|
||||
n.notify(obj)
|
||||
except Exception as e:
|
||||
self._perror("Notify channel %s failed with exception: %s." %
|
||||
(n.name, str(e)))
|
||||
fails += 1
|
||||
return fails < len(self._notifiers)
|
||||
|
||||
def run(self):
|
||||
for ch in self._channels:
|
||||
try:
|
||||
drv = ch.get_latest_driver()
|
||||
except Exception as e:
|
||||
self._perror("get_latest_driver() invocation failed for "
|
||||
"channel %s. Exception: %s. Continuing..." %
|
||||
(repr(ch.name), str(e)))
|
||||
continue
|
||||
if drv is None:
|
||||
self._perror("Driver not found for channel %s" %
|
||||
(repr(ch.name),))
|
||||
continue
|
||||
try:
|
||||
key = self._hasher.hash_object(drv)
|
||||
except Exception as e:
|
||||
self._perror("Key evaluation failed for channel %s. "
|
||||
"Exception: %s" % (repr(name), str(e)))
|
||||
continue
|
||||
if not self._db.check_key(key):
|
||||
if self._notify_all(drv):
|
||||
self._db.set_key(key, drv)
|
||||
return self._ret_code
|
||||
|
||||
|
||||
def setup_logger(name, verbosity):
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(verbosity)
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(verbosity)
|
||||
handler.setFormatter(logging.Formatter('%(asctime)s '
|
||||
'%(levelname)-8s '
|
||||
'%(name)s: %(message)s',
|
||||
'%Y-%m-%d %H:%M:%S'))
|
||||
logger.addHandler(handler)
|
||||
return logger
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
setup_logger(DriverLocator.__name__, logging.ERROR)
|
||||
|
||||
with open(args.config, 'r') as conf_file:
|
||||
conf = json.load(conf_file)
|
||||
|
||||
ret = DriverLocator(conf).run()
|
||||
sys.exit(ret)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Reference in New Issue
Block a user