From 961c1315ef6a93a4b49982013d7d2ad9a6afc924 Mon Sep 17 00:00:00 2001 From: lilydjwg Date: Thu, 13 Aug 2020 20:42:24 +0800 Subject: [PATCH] better and simpler caching mechanism --- nvchecker/core.py | 4 +- nvchecker/util.py | 98 +++++++++---------- .../source => nvchecker_source}/archpkg.py | 15 +-- nvchecker_source/cmd.py | 30 +++--- 4 files changed, 69 insertions(+), 78 deletions(-) rename {nvchecker-old/source => nvchecker_source}/archpkg.py (69%) diff --git a/nvchecker/core.py b/nvchecker/core.py index 5dbf1bf..d7c4baf 100644 --- a/nvchecker/core.py +++ b/nvchecker/core.py @@ -1,4 +1,3 @@ -# vim: se sw=2: # MIT licensed # Copyright (c) 2013-2020 lilydjwg , et al. @@ -206,8 +205,7 @@ def dispatch( ) if worker_cls is FunctionWorker: func = mod.get_version # type: ignore - cacher = getattr(mod, 'cacher', None) - worker.set_func(func, cacher) + worker.initialize(func) ret.append(worker.run()) diff --git a/nvchecker/util.py b/nvchecker/util.py index 9054f39..363dffd 100644 --- a/nvchecker/util.py +++ b/nvchecker/util.py @@ -8,7 +8,8 @@ from asyncio import Queue import contextlib from typing import ( Dict, Optional, List, AsyncGenerator, NamedTuple, Union, - Any, Tuple, Callable, TYPE_CHECKING, + Any, Tuple, Callable, TypeVar, Coroutine, Generic, + TYPE_CHECKING, ) from pathlib import Path @@ -71,6 +72,38 @@ class BaseWorker: await self.token_q.put(token) logger.debug('return token') +T = TypeVar('T') +S = TypeVar('S') + +class AsyncCache(Generic[T, S]): + cache: Dict[T, Union[S, asyncio.Task]] + lock: asyncio.Lock + + def __init__(self) -> None: + self.cache = {} + self.lock = asyncio.Lock() + + async def get( + self, + key: T, + func: Callable[[T], Coroutine[None, None, S]], + ) -> S: + async with self.lock: + cached = self.cache.get(key) + if cached is None: + coro = func(key) + fu = asyncio.create_task(coro) + self.cache[key] = fu + + if asyncio.isfuture(cached): # pending + return await cached # type: ignore + elif cached is not None: # cached + return cached # type: ignore + else: # not cached + r = await fu + self.cache[key] = r + return r + if TYPE_CHECKING: from typing_extensions import Protocol class GetVersionFunc(Protocol): @@ -78,37 +111,22 @@ if TYPE_CHECKING: self, name: str, conf: Entry, *, + cache: AsyncCache, keymanager: KeyManager, ) -> VersionResult: ... else: GetVersionFunc = Any -Cacher = Callable[[str, Entry], str] - class FunctionWorker(BaseWorker): - func = None - cacher = None + func: GetVersionFunc + cache: AsyncCache - cache: Dict[str, Union[ - VersionResult, - asyncio.Task, - ]] - lock: asyncio.Lock - - def set_func( - self, - func: GetVersionFunc, - cacher: Optional[Cacher], - ) -> None: + def initialize(self, func: GetVersionFunc) -> None: self.func = func - self.cacher = cacher - if cacher: - self.cache = {} - self.lock = asyncio.Lock() + self.cache = AsyncCache() async def run(self) -> None: - assert self.func is not None futures = [ self.run_one(name, entry) for name, entry in self.tasks @@ -123,43 +141,15 @@ class FunctionWorker(BaseWorker): try: async with self.acquire_token(): - if self.cacher: - version = await self.run_one_may_cache( - name, entry) - else: - version = await self.func( - name, entry, keymanager = self.keymanager, - ) + version = await self.func( + name, entry, + cache = self.cache, + keymanager = self.keymanager, + ) await self.result_q.put(RawResult(name, version, entry)) except Exception as e: await self.result_q.put(RawResult(name, e, entry)) - async def run_one_may_cache( - self, name: str, entry: Entry, - ) -> VersionResult: - assert self.cacher is not None - assert self.func is not None - - key = self.cacher(name, entry) - - async with self.lock: - cached = self.cache.get(key) - if cached is None: - coro = self.func( - name, entry, keymanager = self.keymanager, - ) - fu = asyncio.create_task(coro) - self.cache[key] = fu - - if asyncio.isfuture(cached): # pending - return await cached # type: ignore - elif cached is not None: # cached - return cached # type: ignore - else: # not cached - version = await fu - self.cache[key] = version - return version - class GetVersionError(Exception): def __init__(self, msg: str, **kwargs: Any) -> None: self.msg = msg diff --git a/nvchecker-old/source/archpkg.py b/nvchecker_source/archpkg.py similarity index 69% rename from nvchecker-old/source/archpkg.py rename to nvchecker_source/archpkg.py index 9540a79..d6bc379 100644 --- a/nvchecker-old/source/archpkg.py +++ b/nvchecker_source/archpkg.py @@ -1,23 +1,24 @@ # MIT licensed -# Copyright (c) 2013-2017 lilydjwg , et al. +# Copyright (c) 2013-2020 lilydjwg , et al. import structlog -from . import session, conf_cacheable_with_name +from nvchecker.httpclient import session # type: ignore logger = structlog.get_logger(logger_name=__name__) URL = 'https://www.archlinux.org/packages/search/json/' -get_cacheable_conf = conf_cacheable_with_name('archpkg') +async def request(pkg): + async with session.get(URL, params={"name": pkg}) as res: + return await res.json() -async def get_version(name, conf, **kwargs): +async def get_version(name, conf, *, cache, **kwargs): pkg = conf.get('archpkg') or name - strip_release = conf.getboolean('strip-release', False) + strip_release = conf.get('strip_release', False) provided = conf.get('provided') - async with session.get(URL, params={"name": pkg}) as res: - data = await res.json() + data = await cache.get(pkg, request) if not data['results']: logger.error('Arch package not found', name=name) diff --git a/nvchecker_source/cmd.py b/nvchecker_source/cmd.py index 40f8531..bf72c22 100644 --- a/nvchecker_source/cmd.py +++ b/nvchecker_source/cmd.py @@ -9,12 +9,8 @@ from nvchecker.util import GetVersionError logger = structlog.get_logger(logger_name=__name__) -def cacher(name, conf): - return conf['cmd'] - -async def get_version(name, conf, *, keymanager=None): - cmd = conf['cmd'] - logger.debug('running cmd', name=name, cmd=cmd) +async def run_cmd(cmd: str) -> str: + logger.debug('running cmd', cmd=cmd) p = await asyncio.create_subprocess_shell( cmd, stdout=asyncio.subprocess.PIPE, @@ -22,17 +18,23 @@ async def get_version(name, conf, *, keymanager=None): ) output, error = await p.communicate() - output = output.strip().decode('latin1') - error = error.strip().decode(errors='replace') + output_s = output.strip().decode('latin1') + error_s = error.strip().decode(errors='replace') if p.returncode != 0: raise GetVersionError( 'command exited with error', - cmd=cmd, error=error, - name=name, returncode=p.returncode) - elif not output: + cmd=cmd, error=error_s, + returncode=p.returncode) + elif not output_s: raise GetVersionError( 'command exited without output', - cmd=cmd, error=error, - name=name, returncode=p.returncode) + cmd=cmd, error=error_s, + returncode=p.returncode) else: - return output + return output_s + +async def get_version( + name, conf, *, cache, keymanager=None +): + cmd = conf['cmd'] + return await cache.get(cmd, run_cmd)