nvchecker/nvchecker/core.py
2024-12-17 06:14:14 -08:00

461 lines
13 KiB
Python

# MIT licensed
# Copyright (c) 2013-2020, 2024 lilydjwg <lilydjwg@gmail.com>, et al.
from __future__ import annotations
import os
import sys
import asyncio
from asyncio import Queue
import logging
import argparse
from typing import (
Tuple, NamedTuple, Optional, List, Union,
cast, Dict, Awaitable, Sequence, Any,
TYPE_CHECKING,
)
import types
from pathlib import Path
from importlib import import_module
import re
import contextvars
import json
import dataclasses
import structlog
if TYPE_CHECKING:
import tomli as tomllib
else:
try:
import tomllib
except ModuleNotFoundError:
import tomli as tomllib
import platformdirs
from .lib import nicelogger
from . import slogconf
from .util import (
Entry, Entries, KeyManager, RawResult, RichResult, ResultData,
FunctionWorker, GetVersionError,
FileLoadError, EntryWaiter,
)
from . import __version__
from .sortversion import sort_version_keys
from .ctxvars import tries as ctx_tries
from .ctxvars import entry_waiter as ctx_entry_waiter
from . import httpclient
logger = structlog.get_logger(logger_name=__name__)
def get_default_config() -> str:
  '''Return the path of the default configuration file.

  The path is ``nvchecker.toml`` joined onto the user configuration
  directory that platformdirs reports for the ``nvchecker`` application.
  '''
  return os.path.join(
    platformdirs.user_config_dir(appname='nvchecker'),
    'nvchecker.toml',
  )
def add_common_arguments(parser: argparse.ArgumentParser) -> None:
  '''Register the command line options shared by the nvchecker tools.

  Adds the logging options (-l/--logging, --logger, --json-log-fd),
  -V/--version and -c/--file to ``parser``. The default of -c/--file is
  the path returned by get_default_config().
  '''
  # (flags, keyword arguments) pairs, registered in this order
  option_specs = [
    (('-l', '--logging'), dict(
      choices=('debug', 'info', 'warning', 'error'), default='info',
      help='logging level (default: info)',
    )),
    (('--logger',), dict(
      default='pretty',
      choices=['pretty', 'json', 'both'],
      help='select which logger to use',
    )),
    (('--json-log-fd',), dict(
      metavar='FD',
      # the given descriptor number is opened for writing at parse time
      type=lambda fd: os.fdopen(int(fd), mode='w'),
      help='specify fd to send json logs to. stdout by default',
    )),
    (('-V', '--version'), dict(
      action='store_true',
      help='show version and exit',
    )),
  ]
  for flags, kwargs in option_specs:
    parser.add_argument(*flags, **kwargs)

  config_path = get_default_config()
  parser.add_argument(
    '-c', '--file',
    metavar='FILE', type=str,
    default=config_path,
    help=f'software version configuration file [default: {config_path}]',
  )
def process_common_arguments(args: argparse.Namespace) -> bool:
  '''Configure logging from the parsed common arguments.

  Builds the structlog processor chain and logger factory selected by
  ``args.logger`` and installs them with structlog.configure(). When
  ``args.version`` is set, the program name and version are printed.

  Returns True if the caller should stop (the version was printed),
  False otherwise.
  '''
  want_pretty = args.logger in ('pretty', 'both')
  want_json = args.logger in ('json', 'both')

  chain = [
    slogconf.exc_info,
    slogconf.filter_exc,
    slogconf.filter_nones,
    slogconf.filter_taskname,
  ]
  factory = None

  if want_pretty:
    slogconf.fix_logging()
    level = getattr(logging, args.logging.upper())
    nicelogger.enable_pretty_logging(level)
    chain.append(slogconf.stdlib_renderer)
    if args.logger == 'pretty':
      # pretty only: structlog's own output goes to the null device
      factory = structlog.PrintLoggerFactory(
        file=open(os.devnull, 'w'),
      )
      chain.append(slogconf.null_renderer)

  if want_json:
    chain += [
      structlog.processors.format_exc_info,
      slogconf.json_renderer,
    ]

  if factory is None:
    # use the file opened for --json-log-fd if given, stdout otherwise
    factory = structlog.PrintLoggerFactory(
      file=args.json_log_fd or sys.stdout)

  structlog.configure(processors=chain, logger_factory=factory)

  if not args.version:
    return False

  print(f'{os.path.basename(sys.argv[0])} v{__version__}')
  return True
def safe_overwrite(file: Path, data: Union[bytes, str], *,
                   method: str = 'write', mode: str = 'w', encoding: Optional[str] = None) -> None:
  '''Replace the content of ``file`` without leaving it half-written.

  The data is written to ``<resolved path>.tmp`` first, flushed and
  fsync'ed, and only then moved over the target. If anything fails the
  target keeps its old content and the temporary file is removed.

  file: the file to overwrite; symlinks are resolved first.
  data: the object passed to the file's writing method.
  method: name of the file method called with ``data``.
  mode, encoding: passed to open() for the temporary file.

  Whatever open(), the writing method, fsync or the final move raise is
  propagated to the caller.
  '''
  # FIXME: directory has no read perm
  # FIXME: hard links
  resolved_path = file.resolve()
  tmpname = str(resolved_path) + '.tmp'
  try:
    # if not using "with", write can fail without exception
    with open(tmpname, mode, encoding=encoding) as f:
      getattr(f, method)(data)
      # see also: https://thunk.org/tytso/blog/2009/03/15/dont-fear-the-fsync/
      f.flush()
      os.fsync(f.fileno())
    # if the above write failed (because disk is full etc), the old data should be kept
    # os.replace also overwrites an existing target on Windows, where
    # os.rename raises FileExistsError
    os.replace(tmpname, resolved_path)
  except BaseException:
    # do not leave a partial temporary file behind
    try:
      os.unlink(tmpname)
    except OSError:
      pass
    raise
def read_verfile(file: Path) -> ResultData:
  '''Load a version record file.

  Three layouts are understood:

  * a JSON object with ``"version": 2`` whose ``"data"`` member maps names
    to the keyword arguments of RichResult;
  * a JSON object without a ``"version"`` member, mapping names to version
    strings;
  * the old plain text format with one ``name version`` pair per line
    (used when the content is not valid JSON). Blank lines are skipped.

  Returns a mapping from name to RichResult; it is empty when the file
  does not exist. Raises Exception for any other ``"version"`` value.
  '''
  try:
    with open(file) as f:
      data = f.read()
  except FileNotFoundError:
    return {}

  try:
    v = json.loads(data)
  except json.decoder.JSONDecodeError:
    # old format
    v = {}
    for l in data.splitlines():
      if not l.strip():
        # a blank line cannot be unpacked into (name, version)
        continue
      name, ver = l.rstrip().split(None, 1)
      v[name] = ver

  if v.get('version') is None:
    v = {k: RichResult(version=a) for k, a in v.items()}
  elif v['version'] == 2:
    v = {k: RichResult(**a) for k, a in v['data'].items()}
  else:
    raise Exception('unknown verfile version', v['version'])

  return v
def write_verfile(file: Path, versions: ResultData) -> None:
  '''Save ``versions`` to ``file`` as a version 2 JSON record.

  The entries are stored under ``"data"``, ordered by name. Values that
  json cannot encode itself are converted by json_encode, and the text is
  written through safe_overwrite.
  '''
  payload = {
    'version': 2,
    # sort and indent to make it friendly to human and git
    'data': {name: versions[name] for name in sorted(versions)},
  }
  text = json.dumps(
    payload,
    indent = 2,
    ensure_ascii = False,
    default = json_encode,
  )
  safe_overwrite(file, text + '\n')
def json_encode(obj):
  '''``default`` hook for json.dumps that serializes RichResult objects.

  Returns the dataclass fields of ``obj`` as a dict with the fields that
  are None left out. Raises TypeError for anything that is not a
  RichResult.
  '''
  if not isinstance(obj, RichResult):
    raise TypeError(obj)

  fields = dataclasses.asdict(obj)
  return {key: value for key, value in fields.items() if value is not None}
class Options(NamedTuple):
  '''Global settings that load_file returns alongside the entries.'''
  # (oldver, newver) record file paths; load_file leaves this None unless
  # both keys are configured
  ver_files: Optional[Tuple[Path, Path]]
  # "max_concurrency" setting; load_file falls back to 20
  max_concurrency: int
  # "proxy" setting, None when not configured
  proxy: Optional[str]
  # built from the configured "keyfile", or KeyManager(None)
  keymanager: KeyManager
  # the "source" table: per-source settings keyed by source name
  source_configs: Dict[str, Dict[str, Any]]
  # "httplib" setting, None when not configured
  httplib: Optional[str]
  # "http_timeout" setting; load_file falls back to 20
  http_timeout: int
def load_file(
  file: str, *,
  use_keymanager: bool,
) -> Tuple[Entries, Options]:
  '''Read a TOML configuration file.

  The ``__config__`` table is removed from the parsed document and turned
  into an Options value; what is left is returned as the entries.
  Relative ``oldver``, ``newver`` and ``keyfile`` paths are taken relative
  to the directory of ``file``, after ``~`` and environment variables are
  expanded. The key file is only used when ``use_keymanager`` is true.

  Raises FileLoadError if the file cannot be read or parsed.
  '''
  try:
    with open(file, 'rb') as f:
      config = tomllib.load(f)
  except (OSError, tomllib.TOMLDecodeError) as e:
    raise FileLoadError('version configuration file', file, e)

  keymanager = KeyManager(None)

  # a missing __config__ table behaves exactly like an empty one
  settings = config.pop('__config__', {})
  base_dir = Path(file).parent

  def resolve(raw):
    # expand "~" first, then environment variables, then anchor the path
    return base_dir / os.path.expandvars(os.path.expanduser(raw))

  ver_files: Optional[Tuple[Path, Path]] = None
  if 'oldver' in settings and 'newver' in settings:
    ver_files = (
      resolve(settings.get('oldver')),
      resolve(settings.get('newver')),
    )

  if use_keymanager:
    keyfile = settings.get('keyfile')
    if keyfile:
      keymanager = KeyManager(resolve(keyfile))

  return cast(Entries, config), Options(
    ver_files = ver_files,
    max_concurrency = settings.get('max_concurrency', 20),
    proxy = settings.get('proxy'),
    keymanager = keymanager,
    source_configs = settings.get('source', {}),
    httplib = settings.get('httplib', None),
    http_timeout = settings.get('http_timeout', 20),
  )
def setup_httpclient(
  max_concurrency: int = 20,
  httplib: Optional[str] = None,
  http_timeout: int = 20,
) -> Dispatcher:
  '''Set up the HTTP client and return a new Dispatcher.

  When ``httplib`` is empty or None, the library is picked by
  httpclient.find_best_httplib().
  '''
  chosen = httplib if httplib else httpclient.find_best_httplib()
  httpclient.setup(chosen, max_concurrency, http_timeout)
  return Dispatcher()
class Dispatcher:
  '''Groups entries by their source and starts one worker per source.'''

  def dispatch(
    self,
    entries: Entries,
    task_sem: asyncio.Semaphore,
    result_q: Queue[RawResult],
    keymanager: KeyManager,
    entry_waiter: EntryWaiter,
    tries: int,
    source_configs: Dict[str, Dict[str, Any]],
  ) -> List[asyncio.Future]:
    '''Create the workers for ``entries``.

    Every entry is assigned to the module ``nvchecker_source.<source>``
    (``source`` defaults to ``'none'``). A module is imported the first
    time it is needed; if ``source_configs`` has a non-empty table for it
    and the module provides ``configure``, that is called with the table.

    Each module gets one worker: ``mod.Worker`` when defined, otherwise a
    FunctionWorker initialized with ``mod.get_version``. Workers are
    created and started inside their own copy of a context in which
    ``tries`` and ``entry_waiter`` are set.

    Returns the results of calling each worker's ``_run_maynot_raise``,
    one per source module, for the caller to await.
    '''
    mods: Dict[str, Tuple[types.ModuleType, List]] = {}
    ctx_tries.set(tries)
    ctx_entry_waiter.set(entry_waiter)
    root_ctx = contextvars.copy_context()

    for name, entry in entries.items():
      source = entry.get('source', 'none')
      if source not in mods:
        mod = import_module('nvchecker_source.' + source)
        tasks: List[Tuple[str, Entry]] = []
        mods[source] = mod, tasks
        config = source_configs.get(source)
        if config:
          # configure() is optional; a source without it must not fail
          # just because the user wrote a config table for it
          configure = getattr(mod, 'configure', None)
          if configure:
            configure(config)
      else:
        tasks = mods[source][1]
      tasks.append((name, entry))

    ret = []
    for mod, tasks in mods.values():
      if hasattr(mod, 'Worker'):
        worker_cls = mod.Worker
      else:
        worker_cls = FunctionWorker

      # every worker runs in its own copy of the root context
      ctx = root_ctx.copy()
      worker = ctx.run(
        worker_cls,
        task_sem, result_q, tasks, keymanager,
      )
      if worker_cls is FunctionWorker:
        func = mod.get_version
        ctx.run(worker.initialize, func)

      ret.append(ctx.run(worker._run_maynot_raise))

    return ret
def substitute_version(
  version: str, conf: Entry,
) -> str:
  '''
  Rewrite the version string according to the entry's configuration.
  See usage.rst#global-options for details.

  A leading ``prefix`` is removed first, then every match of
  ``from_pattern`` is replaced by ``to_pattern``. Raises ValueError when
  ``from_pattern`` is set without ``to_pattern``.
  '''
  strip = conf.get('prefix')
  if strip and version.startswith(strip):
    version = version[len(strip):]

  search = conf.get('from_pattern')
  if not search:
    return version

  replacement = conf.get('to_pattern')
  if replacement is None:
    raise ValueError("from_pattern exists but to_pattern doesn't")

  return re.sub(search, replacement, version)
def apply_list_options(
  versions: List[Union[str, RichResult]],
  conf: Entry,
  name: str,
) -> Optional[Union[str, RichResult]]:
  '''Pick one version out of a list of candidates.

  The candidates are filtered by the entry's ``include_regex``,
  ``exclude_regex`` (both must match the whole string) and the
  whitespace separated ``ignored`` list. What remains is sorted with the
  key function named by ``sort_version_key`` (``parse_version`` by
  default) and the last item is returned.

  Returns None when nothing is left; a warning is logged when that is
  because ``include_regex`` matched nothing.
  '''
  include = conf.get('include_regex')
  if versions and include:
    is_included = re.compile(include).fullmatch
    kept = [v for v in versions if is_included(str(v))]
    if not kept:
      logger.warning('include_regex matched no versions',
                     name=name, versions=versions, regex=include)
      return None
    versions = kept

  exclude = conf.get('exclude_regex')
  if exclude:
    is_excluded = re.compile(exclude).fullmatch
    versions = [v for v in versions if not is_excluded(str(v))]

  ignored = set(conf.get('ignored', '').split())
  if ignored:
    versions = [v for v in versions if str(v) not in ignored]

  if not versions:
    return None

  key_name = conf.get("sort_version_key", "parse_version")
  sort_version_key = sort_version_keys[key_name]
  versions.sort(key=lambda version: sort_version_key(str(version))) # type: ignore
  return versions[-1]
def _process_result(r: RawResult) -> Union[RichResult, Exception]:
  '''Turn the raw result of a source into a RichResult or an exception.

  Exceptions carried by ``r`` are logged and returned unchanged. A list
  is reduced to one item with apply_list_options. Newlines in the version
  are replaced by spaces and the entry's substitutions are applied.

  An exception object is returned (not raised) when the source failed,
  when no version is left, or when the substitution fails.
  '''
  version = r.version
  conf = r.conf
  name = r.name

  if isinstance(version, GetVersionError):
    kw = version.kwargs
    kw['name'] = name
    logger.error(version.msg, **kw)
    return version

  if isinstance(version, Exception):
    logger.error('unexpected error happened',
                 name=r.name, exc_info=r.version)
    return version

  if isinstance(version, list):
    picked = apply_list_options(version, conf, name)
  else:
    picked = version

  if isinstance(picked, RichResult):
    # keep the extra fields that come with a rich result
    version_str = picked.version
    url, gitref, revision = picked.url, picked.gitref, picked.revision
  else:
    version_str = picked
    url = gitref = revision = None

  if not version_str:
    return ValueError('no version returned')

  version_str = version_str.replace('\n', ' ')
  try:
    return RichResult(
      version = substitute_version(version_str, conf),
      url = url,
      gitref = gitref,
      revision = revision,
    )
  except (ValueError, re.error) as e:
    logger.exception('error occurred in version substitutions', name=name)
    return e
def check_version_update(
  oldvers: ResultData,
  name: str,
  r: RichResult,
  verbose: bool,
) -> None:
  '''Log whether the version of ``name`` differs from the recorded one.

  An ``updated`` event is logged at info level when there is no recorded
  version or it differs from ``r.version``. Otherwise ``up-to-date`` is
  logged, at info level if ``verbose`` is true and at debug level if not.
  '''
  previous = oldvers.get(name)
  oldver = previous.version if previous else None

  if oldver and oldver == r.version:
    # provide visible user feedback if it was the only entry
    level = logging.INFO if verbose else logging.DEBUG
    logger.log(level, 'up-to-date', name=name, version=r.version, url=r.url)
    return

  logger.info(
    'updated',
    name = name,
    version = r.version,
    revision = r.revision,
    old_version = oldver,
    url = r.url,
  )
async def process_result(
  oldvers: ResultData,
  result_q: Queue[RawResult],
  entry_waiter: EntryWaiter,
  verbose: bool = False,
) -> Tuple[ResultData, bool]:
  '''Consume raw results from ``result_q`` until cancelled.

  Every result goes through _process_result. Failures are logged and
  reported to ``entry_waiter``; successes are compared with ``oldvers``,
  reported to ``entry_waiter`` and collected.

  The loop only ends when the task is cancelled. It then returns the
  collected results by name and whether any entry failed.
  '''
  results = {}
  failed = False

  try:
    while True:
      raw = await result_q.get()
      try:
        outcome = _process_result(raw)
      except Exception as e:
        logger.exception('error processing result', result=raw)
        outcome = e

      if isinstance(outcome, Exception):
        entry_waiter.set_exception(raw.name, outcome)
        # no versions are returned from "apply_list_options"?
        logger.error('no-result', name=raw.name, error=repr(outcome))
        failed = True
      else:
        check_version_update(oldvers, raw.name, outcome, verbose)
        entry_waiter.set_result(raw.name, outcome.version)
        results[raw.name] = outcome
  except asyncio.CancelledError:
    return results, failed
async def run_tasks(
  futures: Sequence[Awaitable[None]]
) -> None:
  '''Await all of ``futures``, each one as soon as it completes.

  The first exception raised by any of them propagates to the caller.
  '''
  finished_first = asyncio.as_completed(futures)
  for finished in finished_first:
    await finished