mirror of
https://github.com/keylase/nvidia-patch.git
synced 2024-11-22 13:37:27 +00:00
406 lines
13 KiB
Python
Executable File
406 lines
13 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import sys
|
|
import json
|
|
import argparse
|
|
import hashlib
|
|
import importlib
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
HASH_DELIM = b'\x00'
|
|
HASH = hashlib.sha256
|
|
|
|
|
|
class BaseDB(ABC):
|
|
@abstractmethod
|
|
def check_key(self, key):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def set_key(self, key, value):
|
|
pass
|
|
|
|
|
|
class FileDB(BaseDB):
|
|
def __init__(self, workdir):
|
|
self._ospath = importlib.import_module('os.path')
|
|
self._tempfile = importlib.import_module('tempfile')
|
|
self._wd = workdir
|
|
self._test_writable()
|
|
|
|
def _test_writable(self):
|
|
TEST_STRING = b"test"
|
|
with self._tempfile.NamedTemporaryFile('w+b', 0, dir=self._wd) as f:
|
|
f.write(TEST_STRING)
|
|
f.flush()
|
|
with open(f.name, 'rb') as tf:
|
|
assert tf.read() == TEST_STRING, "Test write failed"
|
|
|
|
def _get_key_filename(self, key):
|
|
return self._ospath.join(self._wd, key + '.json')
|
|
|
|
def check_key(self, key):
|
|
filename = self._get_key_filename(key)
|
|
return self._ospath.isfile(filename)
|
|
|
|
def set_key(self, key, obj):
|
|
filename = self._get_key_filename(key)
|
|
with open(filename, 'w') as f:
|
|
json.dump(obj, f, indent=4)
|
|
f.flush()
|
|
|
|
|
|
class Hasher:
|
|
def __init__(self, key_components):
|
|
self._key_components = key_components
|
|
|
|
def _eval_key_component(self, obj, component_path):
|
|
res = obj
|
|
try:
|
|
for path_component in component_path:
|
|
res = res[path_component]
|
|
except (KeyError, IndexError):
|
|
return b''
|
|
return str(res).encode('utf-8')
|
|
|
|
def hash_object(self, obj):
|
|
return HASH(HASH_DELIM.join(
|
|
self._eval_key_component(obj, c) for c in self._key_components)
|
|
).hexdigest()
|
|
|
|
|
|
class BaseNotifier(ABC):
|
|
@abstractmethod
|
|
def notify(self, obj):
|
|
pass
|
|
|
|
|
|
class EmailNotifier(BaseNotifier):
|
|
def __init__(self, name, *,
|
|
from_addr,
|
|
to_addrs,
|
|
host='localhost',
|
|
port=None,
|
|
local_hostname=None,
|
|
use_ssl=False,
|
|
use_starttls=False,
|
|
login=None,
|
|
password=None,
|
|
timeout=10):
|
|
self.name = name
|
|
self._from_addr = from_addr
|
|
self._Mailer = importlib.import_module('mailer').Mailer
|
|
self._MIMEText = importlib.import_module('email.mime.text').MIMEText
|
|
self._MIMEMult = importlib.import_module(
|
|
'email.mime.multipart').MIMEMultipart
|
|
self._MIMEBase = importlib.import_module('email.mime.base').MIMEBase
|
|
self._encoders = importlib.import_module('email.encoders')
|
|
self._m = self._Mailer(from_addr=from_addr,
|
|
host=host,
|
|
port=port,
|
|
local_hostname=local_hostname,
|
|
use_ssl=use_ssl,
|
|
use_starttls=use_starttls,
|
|
login=login,
|
|
password=password,
|
|
timeout=timeout)
|
|
self._to_addrs = to_addrs
|
|
|
|
def notify(self, obj):
|
|
msg = self._MIMEMult()
|
|
msg['Subject'] = "New Nvidia driver available!"
|
|
msg['From'] = self._from_addr
|
|
msg['To'] = ', '.join(self._to_addrs)
|
|
body = "See attached JSON"
|
|
msg.attach(self._MIMEText(body, 'plain'))
|
|
p = self._MIMEBase('application', 'octet-stream')
|
|
p.set_payload(json.dumps(obj, indent=4).encode('utf-8'))
|
|
self._encoders.encode_base64(p)
|
|
p.add_header('Content-Disposition', "attachment; filename=obj.json")
|
|
msg.attach(p)
|
|
self._m.send(self._to_addrs, msg.as_string())
|
|
|
|
|
|
class CommandNotifier(BaseNotifier):
|
|
def __init__(self, name, *,
|
|
cmdline,
|
|
timeout=10):
|
|
self.name = name
|
|
self._subprocess = importlib.import_module('subprocess')
|
|
self._cmdline = cmdline
|
|
self._timeout = timeout
|
|
|
|
def notify(self, obj):
|
|
proc = self._subprocess.Popen(self._cmdline,
|
|
stdin=self._subprocess.PIPE)
|
|
try:
|
|
proc.communicate(json.dumps(obj, indent=4).encode('utf-8'),
|
|
self._timeout)
|
|
except self._subprocess.TimeoutExpired:
|
|
proc.kill()
|
|
proc.communicate()
|
|
|
|
|
|
class BaseChannel(ABC):
|
|
@abstractmethod
|
|
def get_latest_driver(self):
|
|
pass
|
|
|
|
|
|
class GFEClientChannel(BaseChannel):
|
|
def __init__(self, name, notebook=False,
|
|
x86_64=True,
|
|
os_version="10.0",
|
|
os_build="17763",
|
|
language=1033,
|
|
beta=False,
|
|
dch=False,
|
|
crd=False,
|
|
timeout=10):
|
|
self.name = name
|
|
self._notebook = notebook
|
|
self._x86_64 = x86_64
|
|
self._os_version = os_version
|
|
self._os_build = os_build
|
|
self._language = language
|
|
self._beta = beta
|
|
self._dch = dch
|
|
self._crd = crd
|
|
self._timeout = timeout
|
|
gfe_get_driver = importlib.import_module('gfe_get_driver')
|
|
self._get_latest_driver = gfe_get_driver.get_latest_geforce_driver
|
|
|
|
def get_latest_driver(self):
|
|
res = self._get_latest_driver(notebook=self._notebook,
|
|
x86_64=self._x86_64,
|
|
os_version=self._os_version,
|
|
os_build=self._os_build,
|
|
language=self._language,
|
|
beta=self._beta,
|
|
dch=self._dch,
|
|
crd=self._crd,
|
|
timeout=self._timeout)
|
|
res.update({
|
|
'ChannelAttributes': {
|
|
'Name': self.name,
|
|
'Type': self.__class__.__name__,
|
|
'OS': 'Windows%d_%d' % (float(self._os_version),
|
|
64 if self._x86_64 else 32),
|
|
'OSBuild': self._os_build,
|
|
'Language': self._language,
|
|
'Beta': self._beta,
|
|
'DCH': self._dch,
|
|
'CRD': self._crd,
|
|
'Mobile': self._notebook,
|
|
}
|
|
})
|
|
return res
|
|
|
|
|
|
class NvidiaDownloadsChannel(BaseChannel):
|
|
def __init__(self, name, *,
|
|
os="Linux_64",
|
|
product="GeForce",
|
|
certlevel="All",
|
|
driver_type="Standard",
|
|
lang="English",
|
|
cuda_ver="Nothing",
|
|
timeout=10):
|
|
self.name = name
|
|
gnd = importlib.import_module('get_nvidia_downloads')
|
|
self._gnd = gnd
|
|
self._os = gnd.OS[os]
|
|
self._product = gnd.Product[product]
|
|
self._certlevel = gnd.CertLevel[certlevel]
|
|
self._driver_type = gnd.DriverType[driver_type]
|
|
self._lang = gnd.DriverLanguage[lang]
|
|
self._cuda_ver = gnd.CUDAToolkitVersion[cuda_ver]
|
|
self._timeout = timeout
|
|
|
|
def get_latest_driver(self):
|
|
drivers = self._gnd.get_drivers(os=self._os,
|
|
product=self._product,
|
|
certlevel=self._certlevel,
|
|
driver_type=self._driver_type,
|
|
lang=self._lang,
|
|
cuda_ver=self._cuda_ver,
|
|
timeout=self._timeout)
|
|
if not drivers:
|
|
return None
|
|
latest = max(drivers, key=lambda d: tuple(d['version'].split('.')))
|
|
return {
|
|
'DriverAttributes': {
|
|
'Version': latest['version'],
|
|
'Name': latest['name'],
|
|
'NameLocalized': latest['name'],
|
|
},
|
|
'ChannelAttributes': {
|
|
'Name': self.name,
|
|
'Type': self.__class__.__name__,
|
|
'OS': self._os.name,
|
|
'Product': self._product.name,
|
|
'CertLevel': self._certlevel.name,
|
|
'DriverType': self._driver_type.name,
|
|
'Lang': self._lang.name,
|
|
'CudaVer': self._cuda_ver.name,
|
|
}
|
|
}
|
|
|
|
|
|
class CudaToolkitDownloadsChannel(BaseChannel):
|
|
def __init__(self, name, *,
|
|
timeout=10):
|
|
self.name = name
|
|
gcd = importlib.import_module('get_cuda_downloads')
|
|
self._gcd = gcd
|
|
self._timeout = timeout
|
|
|
|
def get_latest_driver(self):
|
|
latest = self._gcd.get_latest_cuda_tk(timeout=self._timeout)
|
|
if not latest:
|
|
return None
|
|
return {
|
|
'DriverAttributes': {
|
|
'Version': '???',
|
|
'Name': latest,
|
|
'NameLocalized': latest,
|
|
}
|
|
}
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(
|
|
description="Watches for GeForce experience driver updates for "
|
|
"configured systems",
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
parser.add_argument("-c", "--config",
|
|
default="/etc/nv-driver-locator.json",
|
|
help="config file location")
|
|
args = parser.parse_args()
|
|
return args
|
|
|
|
|
|
class DriverLocator:
|
|
_ret_code = 0
|
|
|
|
def __init__(self, conf):
|
|
self._logger = logging.getLogger(self.__class__.__name__)
|
|
self._channels = self._construct_channels(conf['channels'])
|
|
self._db = self._construct_db(conf['db'])
|
|
self._hasher = Hasher(conf['key_components'])
|
|
self._notifiers = self._construct_notifiers(conf['notifiers'])
|
|
|
|
def _construct_channels(self, channels_config):
|
|
channel_types = {
|
|
'gfe_client': GFEClientChannel,
|
|
'nvidia_downloads': NvidiaDownloadsChannel,
|
|
'cuda_downloads': CudaToolkitDownloadsChannel,
|
|
}
|
|
|
|
channels = []
|
|
for ch in channels_config:
|
|
try:
|
|
ctor = channel_types[ch['type']]
|
|
C = ctor(ch['name'], **ch['params'])
|
|
except Exception as e:
|
|
self._perror("Channel construction failed with exception: %s. "
|
|
"Skipping..." % (str(e),))
|
|
else:
|
|
channels.append(C)
|
|
return channels
|
|
|
|
def _construct_db(self, db_config):
|
|
db_types = {
|
|
'file': FileDB,
|
|
}
|
|
ctor = db_types[db_config['type']]
|
|
db = ctor(**db_config['params'])
|
|
return db
|
|
|
|
def _construct_notifiers(self, notifiers_config):
|
|
notifier_types = {
|
|
'email': EmailNotifier,
|
|
'command': CommandNotifier,
|
|
}
|
|
|
|
notifiers = []
|
|
for nc in notifiers_config:
|
|
try:
|
|
ctor = notifier_types[nc['type']]
|
|
N = ctor(nc['name'], **nc['params'])
|
|
except Exception as e:
|
|
self._perror("Notifier construction failed with exception: %s."
|
|
" Skipping..." % (str(e),))
|
|
else:
|
|
notifiers.append(N)
|
|
return notifiers
|
|
|
|
def _perror(self, err):
|
|
self._ret_code = 3
|
|
self._logger.error(err)
|
|
|
|
def _notify_all(self, obj):
|
|
fails = 0
|
|
for n in self._notifiers:
|
|
try:
|
|
n.notify(obj)
|
|
except Exception as e:
|
|
self._perror("Notify channel %s failed with exception: %s." %
|
|
(n.name, str(e)))
|
|
fails += 1
|
|
return fails < len(self._notifiers)
|
|
|
|
def run(self):
|
|
for ch in self._channels:
|
|
try:
|
|
drv = ch.get_latest_driver()
|
|
except Exception as e:
|
|
self._perror("get_latest_driver() invocation failed for "
|
|
"channel %s. Exception: %s. Continuing..." %
|
|
(repr(ch.name), str(e)))
|
|
continue
|
|
if drv is None:
|
|
self._perror("Driver not found for channel %s" %
|
|
(repr(ch.name),))
|
|
continue
|
|
try:
|
|
key = self._hasher.hash_object(drv)
|
|
except Exception as e:
|
|
self._perror("Key evaluation failed for channel %s. "
|
|
"Exception: %s" % (repr(name), str(e)))
|
|
continue
|
|
if not self._db.check_key(key):
|
|
if self._notify_all(drv):
|
|
self._db.set_key(key, drv)
|
|
return self._ret_code
|
|
|
|
|
|
def setup_logger(name, verbosity):
|
|
logger = logging.getLogger(name)
|
|
logger.setLevel(verbosity)
|
|
handler = logging.StreamHandler()
|
|
handler.setLevel(verbosity)
|
|
handler.setFormatter(logging.Formatter('%(asctime)s '
|
|
'%(levelname)-8s '
|
|
'%(name)s: %(message)s',
|
|
'%Y-%m-%d %H:%M:%S'))
|
|
logger.addHandler(handler)
|
|
return logger
|
|
|
|
|
|
def main():
|
|
args = parse_args()
|
|
setup_logger(DriverLocator.__name__, logging.ERROR)
|
|
|
|
with open(args.config, 'r') as conf_file:
|
|
conf = json.load(conf_file)
|
|
|
|
ret = DriverLocator(conf).run()
|
|
sys.exit(ret)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|