nvidia-patch/tools/readme-autogen/add_driver.py
2021-01-26 12:22:43 -08:00

202 lines
8.2 KiB
Python
Executable File

#!/usr/bin/env python3
import sys
import argparse
import json
import os.path
import posixpath
from string import Template
from itertools import groupby
from functools import partial
import urllib.request
from constants import OSKind, Product, WinSeries, DATAFILE_PATH, \
DRIVER_URL_TEMPLATE, DRIVER_DIR_PREFIX, BASE_PATH, REPO_BASE
from utils import find_driver, linux_driver_key, windows_driver_key
def parse_args(argv=None):
    """Build the command-line parser for this tool and parse the arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``os`` (OSKind), ``variant``, ``product``
        (Product), ``winseries`` (WinSeries), ``patch32``/``patch64``
        (string.Template patterns), ``skip_patch_check``, ``url``,
        ``skip_url_check``, ``fbc``, ``enc`` and ``version``.
    """
    def check_enum_arg(enum, value):
        """Return the member of *enum* named *value* (argparse ``type=`` hook)."""
        try:
            return enum[value]
        except KeyError:
            # argparse turns ArgumentTypeError into a usage error message.
            raise argparse.ArgumentTypeError(
                "%s is not valid option for %s"
                % (repr(value), repr(enum.__name__))) from None

    parser = argparse.ArgumentParser(
        description="Adds new Nvidia driver into drivers.json file "
                    "in your repo working copy",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    os_options = parser.add_argument_group("OS options")
    # Exactly one of -L / -W must be given; both store into args.os.
    os_group = os_options.add_mutually_exclusive_group(required=True)
    os_group.add_argument("-L", "--linux",
                          action="store_const",
                          dest="os",
                          const=OSKind.Linux,
                          help="add Linux driver")
    os_group.add_argument("-W", "--win",
                          action="store_const",
                          dest="os",
                          const=OSKind.Windows,
                          help="add Windows driver")

    win_opts = parser.add_argument_group("Windows-specific options")
    win_opts.add_argument("--variant",
                          default="",
                          help="driver variant (use for special cases like "
                               "\"Studio Driver\")")
    win_opts.add_argument("-P", "--product",
                          type=partial(check_enum_arg, Product),
                          choices=list(Product),
                          default=Product.GeForce,
                          help="product type")
    win_opts.add_argument("-w", "--winseries",
                          type=partial(check_enum_arg, WinSeries),
                          choices=list(WinSeries),
                          default=WinSeries.win10,
                          help="Windows series")
    # Patch locations are string.Template patterns; main() fills in
    # ${winseries}, ${drvprefix} and ${version}.
    win_opts.add_argument("--patch32",
                          default="${winseries}_x64/"
                                  "${drvprefix}${version}/nvencodeapi.1337",
                          help="template for Windows 32bit patch URL")
    win_opts.add_argument("--patch64",
                          default="${winseries}_x64/"
                                  "${drvprefix}${version}/nvencodeapi64.1337",
                          help="template for Windows 64bit patch URL")
    win_opts.add_argument("--skip-patch-check",
                          action="store_true",
                          help="skip patch files presence test")

    parser.add_argument("-U", "--url",
                        help="override driver link")
    parser.add_argument("--skip-url-check",
                        action="store_true",
                        help="skip driver URL check")
    parser.add_argument("--no-fbc",
                        dest="fbc",
                        action="store_false",
                        help="add driver w/o NvFBC patch")
    parser.add_argument("--no-enc",
                        dest="enc",
                        action="store_false",
                        help="add driver w/o NVENC patch")
    parser.add_argument("version",
                        help="driver version")
    return parser.parse_args(argv)
def posixpath_components(path):
    """Break a POSIX *path* into a list of its components.

    The path is peeled from the right with ``posixpath.split`` until the
    head no longer shrinks. A leading root ("/") therefore does not appear
    in the result, and the empty component that a trailing slash produces
    is discarded.

    >>> posixpath_components("a/b/c")
    ['a', 'b', 'c']
    """
    pieces = []
    remaining = path
    head, tail = posixpath.split(remaining)
    while head != remaining:
        pieces.append(tail)
        remaining = head
        head, tail = posixpath.split(remaining)
    # Components were collected last-to-first.
    pieces = pieces[::-1]
    if pieces and not pieces[-1]:
        pieces.pop()
    return pieces
def validate_url(url, min_length=50 * 2**20):
    """Check that *url* answers a HEAD request with a large enough body.

    A driver package is big, so a small Content-Length suggests the link
    serves something other than the driver, such as an error page.

    Args:
        url: driver download link to probe.
        min_length: smallest acceptable Content-Length in bytes
            (default 50 MiB, the previous hard-coded threshold).

    Raises:
        urllib.error.URLError (or a subclass): if the request itself fails.
        Exception: if the response has no Content-Length, the header is not
            an integer, or it is smaller than *min_length*.
    """
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req, timeout=10) as resp:
        # Case-insensitive lookup; None when the server omits the header.
        raw_length = resp.headers.get('Content-Length')
    if raw_length is None:
        raise Exception("Driver URL response has no Content-Length header")
    try:
        length = int(raw_length)
    except ValueError:
        raise Exception("Bad driver length: %s" % raw_length) from None
    if length < min_length:
        raise Exception("Bad driver length: %s" % raw_length)
def validate_patch(patch64, patch32, wc_base=None):
    """Check that both Windows patch files exist and are non-empty.

    Args:
        patch64: path of the 64-bit patch, relative to *wc_base*.
        patch32: path of the 32-bit patch, relative to *wc_base*.
        wc_base: directory the patch paths are relative to. Defaults to the
            ``win`` directory two levels above ``BASE_PATH``.

    Raises:
        Exception: "File ... not found!" if either path is not a regular
            file (both are checked before any size check), otherwise
            "File ... empty!" for the first zero-byte patch.
    """
    if wc_base is None:
        wc_base = os.path.abspath(os.path.join(BASE_PATH, "..", "..", "win"))
    filepaths = [os.path.join(wc_base, patch) for patch in (patch64, patch32)]
    for filepath in filepaths:
        # isfile() rather than exists(): a directory is not a patch.
        if not os.path.isfile(filepath):
            raise Exception("File %s not found!" % filepath)
    for filepath in filepaths:
        # An empty patch is as unusable as a missing one, for both bitnesses.
        if os.path.getsize(filepath) == 0:
            raise Exception("File %s empty!" % filepath)
def validate_unique(drivers, new_driver, kf):
    """Raise if *drivers* already holds an entry matching *new_driver*.

    The match is looked up with ``find_driver`` using the key ``kf`` computes
    for *new_driver*; *kf* is also handed to ``find_driver``.

    Raises:
        Exception: "Duplicate driver!" when a match is found.
    """
    existing = find_driver(drivers, kf(new_driver), kf)
    if existing is None:
        return
    raise Exception("Duplicate driver!")
def _print_errors(*lines):
    """Print each of *lines* to stderr, one per line."""
    for line in lines:
        print(line, file=sys.stderr)


def main():
    """Validate the driver described on the command line and record it.

    Resolves the driver URL (from -U or DRIVER_URL_TEMPLATE) and probes it
    unless --skip-url-check. For Windows it builds the patch paths and
    checks them unless --skip-patch-check. It then rejects duplicates and
    appends the new entry to DATAFILE_PATH.

    Returns:
        0 on success; 1 if a lookup or validation step failed. On failure
        an explanation has been printed to stderr and the data file is not
        written. The value is suitable as a process exit status.
    """
    args = parse_args()

    if args.url is None:
        if args.os is OSKind.Linux:
            template_key = (args.os, None, None, None)
        else:
            template_key = (args.os, args.product, args.winseries, args.variant)
        try:
            url_template = DRIVER_URL_TEMPLATE[template_key]
        except KeyError:
            _print_errors("No driver URL template is known for this "
                          "OS/product/Windows series/variant combination.",
                          "Please use option -U to specify driver link manually.")
            return 1
        url = Template(url_template).substitute(version=args.version)
    else:
        url = args.url

    # KeyboardInterrupt derives from BaseException, not Exception, so the
    # handlers below already let Ctrl-C propagate.
    if url and not args.skip_url_check:
        try:
            validate_url(url)
        except Exception as exc:
            _print_errors("Driver URL validation failed with error: %s" % str(exc),
                          "Please use option -U to override driver link manually",
                          "or use option --skip-url-check to submit incorrect URL.")
            return 1

    if args.os is OSKind.Windows:
        try:
            driver_dir_prefix = DRIVER_DIR_PREFIX[(args.product, args.variant)]
        except KeyError:
            _print_errors("No driver directory prefix is known for this "
                          "product/variant combination.")
            return 1
        substitutions = {
            "winseries": args.winseries,
            "drvprefix": driver_dir_prefix,
            "version": args.version,
        }
        patch64_url = Template(args.patch64).substitute(substitutions)
        patch32_url = Template(args.patch32).substitute(substitutions)
        if not args.skip_patch_check:
            try:
                validate_patch(patch64_url, patch32_url)
            except Exception as exc:
                _print_errors(
                    "Driver patch validation failed with error: %s" % str(exc),
                    "Use options --patch64 and --patch32 to override patch path ",
                    "template or use option --skip-patch-check to submit driver with ",
                    "missing patch files.")
                return 1

    # JSON text is UTF-8; don't depend on the locale's default encoding.
    with open(DATAFILE_PATH, encoding='utf-8') as data_file:
        data = json.load(data_file)
    os_drivers = data[args.os.value]['x86_64']['drivers']

    if args.os is OSKind.Windows:
        new_driver = {
            "os": str(args.winseries),
            "product": str(args.product),
            "version": args.version,
            "variant": args.variant,
            "patch64_url": patch64_url,
            "patch32_url": patch32_url,
            "driver_url": url,
        }
        key_fun = windows_driver_key
    else:
        new_driver = {
            "version": args.version,
            "nvenc_patch": args.enc,
            "nvfbc_patch": args.fbc,
        }
        if url:
            new_driver["driver_url"] = url
        key_fun = linux_driver_key

    # The lookup gets a copy sorted by the same key function it searches
    # with; presumably find_driver relies on that order (TODO confirm).
    try:
        validate_unique(sorted(os_drivers, key=key_fun), new_driver, key_fun)
    except Exception as exc:
        _print_errors("Driver uniqueness validation failed with error: %s" % str(exc))
        return 1

    os_drivers.append(new_driver)
    with open(DATAFILE_PATH, 'w', encoding='utf-8') as data_file:
        json.dump(data, data_file, indent=4)
        data_file.write('\n')
    return 0
if __name__ == '__main__':
    # Forward main()'s return value as the exit status so shells/CI can
    # detect a failed validation (None still maps to status 0).
    sys.exit(main())