#!/usr/bin/env python3
"""Add a new Nvidia driver entry to the drivers.json data file.

The driver download URL (and, for Windows, the patch file locations in the
local working copy) are validated before the entry is appended.
"""

import sys
import argparse
import json
import os.path
import posixpath
from string import Template
from itertools import groupby
from functools import partial
import urllib.request
import urllib.parse

from constants import OSKind, Product, WinSeries, DATAFILE_PATH, \
    DRIVER_URL_TEMPLATE, DRIVER_DIR_PREFIX, BASE_PATH, REPO_BASE
from utils import find_driver, linux_driver_key, windows_driver_key

# Smallest Content-Length (50 MiB) accepted for a driver download; smaller
# responses are rejected by validate_url as implausibly small.
MIN_DRIVER_SIZE = 50 * 2**20


def parse_args():
    """Build the command-line parser and return the parsed arguments."""
    def check_enum_arg(enum, value):
        # Convert a member *name* given on the command line into the member.
        try:
            return enum[value]
        except KeyError:
            raise argparse.ArgumentTypeError(
                "%s is not valid option for %s" %
                (repr(value), repr(enum.__name__)))

    parser = argparse.ArgumentParser(
        description="Adds new Nvidia driver into drivers.json file of "
        "in your repo working copy",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    os_options = parser.add_argument_group("OS options")
    os_group = os_options.add_mutually_exclusive_group(required=True)
    os_group.add_argument("-L", "--linux",
                          action="store_const",
                          dest="os",
                          const=OSKind.Linux,
                          help="add Linux driver")
    os_group.add_argument("-W", "--win",
                          action="store_const",
                          dest="os",
                          const=OSKind.Windows,
                          help="add Windows driver")

    win_opts = parser.add_argument_group("Windows-specific options")
    win_opts.add_argument("--variant",
                          default="",
                          help="driver variant (use for special cases like "
                          "\"Studio Driver\")")
    win_opts.add_argument("-P", "--product",
                          type=partial(check_enum_arg, Product),
                          choices=list(Product),
                          default=Product.GeForce,
                          help="product type")
    win_opts.add_argument("-w", "--winseries",
                          type=partial(check_enum_arg, WinSeries),
                          choices=list(WinSeries),
                          default=WinSeries.win10,
                          help="Windows series")
    win_opts.add_argument("--patch32",
                          default="https://raw.githubusercontent.com/keylase/"
                          "nvidia-patch/master/win/${winseries}_x64/"
                          "${drvprefix}${version}/nvcuvid32.1337",
                          help="template for Windows 32bit patch URL")
    win_opts.add_argument("--patch64",
                          default="https://raw.githubusercontent.com/keylase/"
                          "nvidia-patch/master/win/${winseries}_x64/"
                          "${drvprefix}${version}/nvcuvid64.1337",
                          help="template for Windows 64bit patch URL")
    win_opts.add_argument("--skip-patch-check",
                          action="store_true",
                          help="skip patch files presense test")

    parser.add_argument("-U", "--url",
                        help="override driver link")
    parser.add_argument("--skip-url-check",
                        action="store_true",
                        help="skip driver URL check")
    parser.add_argument("version",
                        help="driver version")
    return parser.parse_args()


def posixpath_components(path):
    """Split a POSIX path into its components.

    "a/b/c" -> ["a", "b", "c"]. A trailing separator does not yield a
    trailing empty component, and the root of an absolute path is dropped.
    """
    result = []
    while True:
        head, tail = posixpath.split(path)
        if head == path:
            # Reached "" (relative path) or "/" (absolute path).
            break
        result.append(tail)
        path = head
    result.reverse()
    # A trailing "/" produces one empty last component; discard it.
    if result and not result[-1]:
        result.pop()
    return result


def validate_url(url):
    """Check via a HEAD request that *url* serves a plausibly sized file.

    Raises:
        Exception: if the response has no Content-Length header or reports a
            size below MIN_DRIVER_SIZE. Network and HTTP errors propagate
            from urllib unchanged.
    """
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req, timeout=10) as resp:
        length = resp.headers['Content-Length']
    if length is None:
        # Without this check int(None) would raise an unhelpful TypeError.
        raise Exception("Driver URL response has no Content-Length header")
    if int(length) < MIN_DRIVER_SIZE:
        raise Exception("Bad driver length: %s" % length)


def _patch_local_path(url_path, base_path, wc_base):
    """Map a URL path lying under *base_path* to a file path under *wc_base*."""
    relpath = posixpath.relpath(url_path, base_path)
    return os.path.join(wc_base, *posixpath_components(relpath))


def validate_patch(patch64, patch32):
    """Check that both patch URLs point at non-empty files in the working copy.

    Both URLs must share REPO_BASE's scheme and network location and lie
    under its path; the remainder of each path is resolved relative to the
    directory two levels above BASE_PATH (presumably the repository
    working-copy root -- confirm against constants.BASE_PATH).

    Raises:
        Exception: describing the first check that failed.
    """
    wc_base = os.path.abspath(os.path.join(BASE_PATH, "..", ".."))
    base_parse = urllib.parse.urlsplit(REPO_BASE, scheme='http')
    p64_parse = urllib.parse.urlsplit(patch64, scheme='http')
    p32_parse = urllib.parse.urlsplit(patch32, scheme='http')

    if not (p64_parse.scheme == p32_parse.scheme == base_parse.scheme):
        raise Exception("URL scheme doesn't match repo base URL scheme")
    if not (p64_parse.netloc == p32_parse.netloc == base_parse.netloc):
        raise Exception("URL network location doesn't match repo base URL network location")
    if posixpath.commonpath((base_parse.path, p64_parse.path, p32_parse.path)) != \
            posixpath.commonpath((base_parse.path,)):
        raise Exception("URL is not subpath of repo base path")

    p64_filepath = _patch_local_path(p64_parse.path, base_parse.path, wc_base)
    p32_filepath = _patch_local_path(p32_parse.path, base_parse.path, wc_base)

    if not os.path.exists(p64_filepath):
        raise Exception("File %s not found!" % p64_filepath)
    if not os.path.exists(p32_filepath):
        raise Exception("File %s not found!" % p32_filepath)
    if os.path.getsize(p64_filepath) == 0:
        raise Exception("File %s empty!" % p64_filepath)
    # Previously this tested os.path.exists(...) == 0, which can never be
    # true here, so an empty 32-bit patch slipped through.
    if os.path.getsize(p32_filepath) == 0:
        raise Exception("File %s empty!" % p32_filepath)


def validate_unique(drivers, new_driver, kf):
    """Raise if *drivers* already has an entry whose key (via *kf*) matches."""
    if find_driver(drivers, kf(new_driver), kf) is not None:
        raise Exception("Duplicate driver!")


def main():
    """Validate the requested driver and append it to the data file.

    Returns:
        0 on success; 1 if any validation step failed (nothing is written).
    """
    args = parse_args()

    if args.url is None:
        if args.os is OSKind.Linux:
            # Linux templates are keyed without product/series/variant.
            url_tmpl = Template(DRIVER_URL_TEMPLATE[(args.os, None, None, None)])
        else:
            url_tmpl = Template(DRIVER_URL_TEMPLATE[(args.os, args.product,
                                                     args.winseries,
                                                     args.variant)])
        url = url_tmpl.substitute(version=args.version)
    else:
        url = args.url

    if not args.skip_url_check:
        try:
            validate_url(url)
        except Exception as exc:
            print("Driver URL validation failed with error: %s" % str(exc), file=sys.stderr)
            print("Please use option -U to override driver link manually", file=sys.stderr)
            print("or use option --skip-url-check to submit incorrect URL.", file=sys.stderr)
            return 1

    if args.os is OSKind.Windows:
        driver_dir_prefix = DRIVER_DIR_PREFIX[(args.product, args.variant)]
        patch64_url = Template(args.patch64).substitute(winseries=args.winseries,
                                                        drvprefix=driver_dir_prefix,
                                                        version=args.version)
        patch32_url = Template(args.patch32).substitute(winseries=args.winseries,
                                                        drvprefix=driver_dir_prefix,
                                                        version=args.version)
        if not args.skip_patch_check:
            try:
                validate_patch(patch64_url, patch32_url)
            except Exception as exc:
                print("Driver patch validation failed with error: %s" % str(exc), file=sys.stderr)
                print("Use options --patch64 and --patch32 to override patch path ", file=sys.stderr)
                print("template or use option --skip-patch-check to submit driver with ", file=sys.stderr)
                print("missing patch files.", file=sys.stderr)
                return 1

    with open(DATAFILE_PATH, encoding="utf-8") as data_file:
        data = json.load(data_file)

    drivers = data[args.os.value]['x86_64']['drivers']
    if args.os is OSKind.Windows:
        new_driver = {
            "os": str(args.winseries),
            "product": str(args.product),
            "version": args.version,
            "variant": args.variant,
            "patch64_url": patch64_url,
            "patch32_url": patch32_url,
            "driver_url": url,
        }
        key_fun = windows_driver_key
    else:
        new_driver = {
            "version": args.version,
            "driver_url": url,
        }
        key_fun = linux_driver_key

    # find_driver receives a list sorted by the same key function; it
    # presumably depends on that ordering -- TODO confirm in utils.
    drivers = sorted(drivers, key=key_fun)
    try:
        validate_unique(drivers, new_driver, key_fun)
    except Exception as exc:
        print("Driver uniqueness validation failed with error: %s" % str(exc), file=sys.stderr)
        return 1

    data[args.os.value]['x86_64']['drivers'].append(new_driver)
    # Serialize before opening for write so a serialization error cannot
    # truncate the existing data file.
    serialized = json.dumps(data, indent=4)
    with open(DATAFILE_PATH, 'w', encoding="utf-8") as data_file:
        data_file.write(serialized)
    return 0


if __name__ == '__main__':
    sys.exit(main())