#!/usr/bin/env python3 import sys import argparse import json import os.path import posixpath from string import Template from functools import partial import urllib.request from constants import OSKind, Product, WinSeries, DATAFILE_PATH, \ DRIVER_URL_TEMPLATE, DRIVER_DIR_PREFIX, BASE_PATH from utils import find_driver, linux_driver_key, windows_driver_key def parse_args(): def check_enum_arg(enum, value): try: return enum[value] except KeyError: raise argparse.ArgumentTypeError("%s is not valid option for %s" % (repr(value), repr(enum.__name__))) parser = argparse.ArgumentParser( description="Adds new Nvidia driver into drivers.json file of " "in your repo working copy", formatter_class=argparse.ArgumentDefaultsHelpFormatter) os_options = parser.add_argument_group("OS options") os_group=os_options.add_mutually_exclusive_group(required=True) os_group.add_argument("-L", "--linux", action="store_const", dest="os", const=OSKind.Linux, help="add Linux driver") os_group.add_argument("-W", "--win", action="store_const", dest="os", const=OSKind.Windows, help="add Windows driver") win_opts = parser.add_argument_group("Windows-specific options") win_opts.add_argument("--variant", default="", help="driver variant (use for special cases like " "\"Studio Driver\")") win_opts.add_argument("-P", "--product", type=partial(check_enum_arg, Product), choices=list(Product), default=Product.GeForce, help="product type") win_opts.add_argument("-w", "--winseries", type=partial(check_enum_arg, WinSeries), choices=list(WinSeries), default=WinSeries.win10, help="Windows series") win_opts.add_argument("--patch32", default="${winseries}_x64/" "${drvprefix}${version}/nvencodeapi.1337", help="template for Windows 32bit patch URL") win_opts.add_argument("--patch64", default="${winseries}_x64/" "${drvprefix}${version}/nvencodeapi64.1337", help="template for Windows 64bit patch URL") win_opts.add_argument("--skip-patch-check", action="store_true", help="skip patch files presense test") parser.add_argument("-U", "--url", help="override driver link") parser.add_argument("--skip-url-check", action="store_true", help="skip driver URL check") parser.add_argument("--no-fbc", dest="fbc", action="store_false", help="add driver w/o NvFBC patch") parser.add_argument("--no-enc", dest="enc", action="store_false", help="add driver w/o NVENC patch") parser.add_argument("version", help="driver version") args = parser.parse_args() return args def posixpath_components(path): result = [] while True: head, tail = posixpath.split(path) if head == path: break result.append(tail) path = head result.reverse() if result and not result[-1]: result.pop() return result def validate_url(url): req = urllib.request.Request(url, method="HEAD") with urllib.request.urlopen(req, timeout=10) as resp: if int(resp.headers['Content-Length']) < 50 * 2**20: raise Exception("Bad driver length: %s" % resp.headers['Content-Length']) def validate_patch(patch64, patch32): wc_base = os.path.abspath(os.path.join(BASE_PATH, "..", "..", "win")) p64_filepath = os.path.join(wc_base, patch64) p32_filepath = os.path.join(wc_base, patch32) if not os.path.exists(p64_filepath): raise Exception("File %s not found!" % p64_filepath) if not os.path.exists(p32_filepath): raise Exception("File %s not found!" % p32_filepath) if os.path.getsize(p64_filepath) == 0: raise Exception("File %s empty!" % p64_filepath) if os.path.exists(p32_filepath) == 0: raise Exception("File %s empty!" % p32_filepath) def validate_unique(drivers, new_driver, kf): if find_driver(drivers, kf(new_driver), kf) is not None: raise Exception("Duplicate driver!") def main(): args = parse_args() if not args.url: if args.os is OSKind.Linux: url_tmpl = DRIVER_URL_TEMPLATE[(args.os, None, None, None)] else: url_tmpl = DRIVER_URL_TEMPLATE[(args.os, args.product, args.winseries, args.variant)] if isinstance(url_tmpl, str): url_tmpl = [url_tmpl] urls = [Template(i).substitute(version=args.version) for i in url_tmpl if i] else: urls = [args.url] url = "" if urls and not args.skip_url_check: last_exc = None for url in urls: try: validate_url(url) break except KeyboardInterrupt: raise except Exception as exc: last_exc = exc else: print("Driver URL validation failed with error: %s" % str(last_exc), file=sys.stderr) print("Driver URL: %s" % ", ".join(urls), file=sys.stderr) print("Please use option -U to override driver link manually", file=sys.stderr) print("or use option --skip-url-check to submit incorrect URL.", file=sys.stderr) return if args.os is OSKind.Windows: driver_dir_prefix = DRIVER_DIR_PREFIX[(args.product, args.variant)] patch64_url = Template(args.patch64).substitute(winseries=args.winseries, drvprefix=driver_dir_prefix, version=args.version) patch32_url = Template(args.patch32).substitute(winseries=args.winseries, drvprefix=driver_dir_prefix, version=args.version) if not args.skip_patch_check: try: validate_patch(patch64_url, patch32_url) except KeyboardInterrupt: raise except Exception as exc: print("Driver patch validation failed with error: %s" % str(exc), file=sys.stderr) print("Use options --patch64 and --patch32 to override patch path ", file=sys.stderr) print("template or use option --skip-patch-check to submit driver with ", file=sys.stderr) print("missing patch files.", file=sys.stderr) return with open(DATAFILE_PATH) as data_file: data = json.load(data_file) drivers = data[args.os.value]['x86_64']['drivers'] if args.os is OSKind.Windows: new_driver = { "os": str(args.winseries), "product": str(args.product), "version": args.version, "variant": args.variant, "patch64_url": patch64_url, "patch32_url": patch32_url, "driver_url": url, } key_fun = windows_driver_key else: new_driver = { "version": args.version, "nvenc_patch": args.enc, "nvfbc_patch": args.fbc, } if url: new_driver["driver_url"] = url key_fun = linux_driver_key drivers = sorted(drivers, key=key_fun) try: validate_unique(drivers, new_driver, key_fun) except KeyboardInterrupt: raise except Exception as exc: print("Driver uniqueness validation failed with error: %s" % str(exc), file=sys.stderr) return data[args.os.value]['x86_64']['drivers'].append(new_driver) with open(DATAFILE_PATH, 'w') as data_file: json.dump(data, data_file, indent=4) data_file.write('\n') if __name__ == '__main__': main()