mirror of
https://github.com/keylase/nvidia-patch.git
synced 2025-01-10 06:27:00 +00:00
130 lines
4.1 KiB
Python
130 lines
4.1 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
import urllib.request
|
||
|
import urllib.error
|
||
|
import json
|
||
|
import posixpath
|
||
|
import codecs
|
||
|
import gzip
|
||
|
import sys
|
||
|
from contextlib import contextmanager
|
||
|
import itertools
|
||
|
import string
|
||
|
import codecs
|
||
|
import pprint
|
||
|
import re
|
||
|
import collections
|
||
|
|
||
|
# User-Agent header sent with every request (identifies as Debian's APT).
USER_AGENT = 'Debian APT-HTTP/1.3 (1.6.6)'
# Default index: NVIDIA CUDA apt repository for Ubuntu 18.04, x86_64.
DEFAULT_REPO = ("https://developer.download.nvidia.com/"
                "compute/cuda/repos/ubuntu1804/x86_64/Packages.gz")
# Default package-name pattern (applied with re.match).
DEFAULT_REGEX = '^cuda-drivers$'
# Text encoding of the index; "utf-8-sig" also drops a leading BOM if present.
ENCODING = 'utf-8-sig'

# One matching package: its name and upstream version string.
DriverEntry = collections.namedtuple('DriverEntry', ['name', 'version'])
|
||
|
|
||
|
def upstream_version(version):
    """Return the upstream part of a Debian package version string.

    A Debian version has the form ``[epoch:]upstream_version[-debian_revision]``.
    The epoch ends at the first colon.  The Debian revision starts after the
    *last* hyphen, because upstream versions may themselves contain hyphens.
    Both are stripped here.

    >>> upstream_version('1:450.51.06-1')
    '450.51.06'
    >>> upstream_version('1.2-rc1-3')
    '1.2-rc1'
    """
    epoch, delim, tail = version.partition(':')
    version = tail if delim else epoch
    head, delim, _revision = version.rpartition('-')
    # When there is no hyphen (no Debian revision), rpartition leaves the
    # separator empty and the whole string in the last slot.
    return head if delim else version
|
||
|
|
||
|
@contextmanager
|
||
|
def packages_reader(url, timeout):
|
||
|
http_req = urllib.request.Request(
|
||
|
url,
|
||
|
data=None,
|
||
|
headers={
|
||
|
'User-Agent': USER_AGENT
|
||
|
}
|
||
|
)
|
||
|
with urllib.request.urlopen(http_req, None, timeout) as resp:
|
||
|
if url.endswith('.gz') or resp.headers.get('Content-Type', '') == 'application/x-gzip':
|
||
|
with gzip.GzipFile(fileobj=resp) as reader:
|
||
|
yield codecs.getreader(ENCODING)(reader)
|
||
|
else:
|
||
|
yield codecs.getreader(ENCODING)(resp)
|
||
|
|
||
|
def parse_packages(reader):
    """Yield one dict per stanza of a Debian ``Packages`` index.

    *reader* is any iterable of text lines.  Stanzas are separated by blank
    lines.  Each ``Name: value`` line becomes a lower-cased key with a stripped
    value.  A line starting with whitespace continues the previous field: its
    right-stripped text (leading whitespace kept) is appended to that value.
    A continuation line that appears before any field is ignored.  If a field
    repeats, the later value wins.  Only stanzas that have both ``package``
    and ``version`` are yielded.
    """
    for has_content, stanza in itertools.groupby(reader, lambda s: bool(s.strip())):
        if not has_content:
            continue
        fields = []  # [name, value] pairs, in file order
        for raw in stanza:
            if raw[0] not in string.whitespace:
                name, _, value = raw.partition(':')
                fields.append([name, value.strip()])
            elif fields:
                fields[-1][1] += raw.rstrip()
        record = {name.lower(): value for name, value in fields}
        if 'package' in record and 'version' in record:
            yield record
|
||
|
|
||
|
def _get_deb_versions(*, url=DEFAULT_REPO, name=DEFAULT_REGEX, timeout=10):
    """Lazily yield a ``DriverEntry`` for each matching package in the index.

    A package matches when the regular expression *name* matches its
    ``package`` field from the start (``re.match``).  Versions are reduced
    to their upstream part.  The index at *url* stays open until this
    generator is exhausted or closed.
    """
    matcher = re.compile(name).match
    with packages_reader(url, timeout) as lines:
        yield from (
            DriverEntry(pkg['package'], upstream_version(pkg['version']))
            for pkg in parse_packages(lines)
            if matcher(pkg['package']) is not None
        )
|
||
|
|
||
|
def get_deb_versions(*args, **kwargs):
    """Return the distinct driver entries matching the repository query.

    Takes the same keyword arguments as ``_get_deb_versions`` (``url``,
    ``name``, ``timeout``).  Duplicates are dropped while keeping the order in
    which entries first appear in the index.  The result is therefore
    deterministic across runs, which a ``set`` round-trip is not, because
    string hashing is randomized per process.
    """
    return list(dict.fromkeys(_get_deb_versions(*args, **kwargs)))
|
||
|
|
||
|
|
||
|
def parse_args():
    """Parse the command line.

    Returns:
        argparse.Namespace with ``url`` (str), ``name`` (str, a regular
        expression) and ``timeout`` (float, seconds).
    """
    import argparse
    import math

    def check_positive_float(val):
        # argparse shows a custom message only for ArgumentTypeError; a plain
        # ValueError is reported as a generic "invalid ... value".
        try:
            val = float(val)
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Value %s is not valid positive float" % (repr(val),)) from None
        # "nan <= 0" is False, so a bare sign check would accept nan (and inf);
        # neither is a usable network timeout.
        if not (math.isfinite(val) and val > 0):
            raise argparse.ArgumentTypeError(
                "Value %s is not valid positive float" % (repr(val),))
        return val

    parser = argparse.ArgumentParser(
        description="Retrieves info about latest NVIDIA drivers available in "
        "Nvidia deb packages repositories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-U", "--url",
                        default=DEFAULT_REPO,
                        help="URL for Packages or Packages.gz file")
    parser.add_argument("-N", "--name",
                        default=DEFAULT_REGEX,
                        help="Package name regexp")
    parser.add_argument("-T", "--timeout",
                        type=check_positive_float,
                        default=10.,
                        help="timeout for network operations")
    args = parser.parse_args()
    return args
|
||
|
|
||
|
|
||
|
def main():
    """Command-line entry point: print matching driver packages as JSON.

    Prints ``NOT FOUND`` and exits with status 3 when no package matches.
    Otherwise it dumps the list of ``(name, version)`` entries to stdout;
    ``json`` serializes each one as a two-element array.
    """
    args = parse_args()
    drv = get_deb_versions(url=args.url,
                           name=args.name,
                           timeout=args.timeout)
    # get_deb_versions returns a (possibly empty) list, never None.
    if not drv:
        print("NOT FOUND")
        sys.exit(3)
    json.dump(drv, sys.stdout, indent=4)
    sys.stdout.flush()


if __name__ == '__main__':
    main()
|