better and simpler caching mechanism

This commit is contained in:
lilydjwg 2020-08-13 20:42:24 +08:00
parent b76bfb5606
commit 961c1315ef
4 changed files with 69 additions and 78 deletions

View file

@ -1,4 +1,3 @@
# vim: se sw=2:
# MIT licensed
# Copyright (c) 2013-2020 lilydjwg <lilydjwg@gmail.com>, et al.
@ -206,8 +205,7 @@ def dispatch(
)
if worker_cls is FunctionWorker:
func = mod.get_version # type: ignore
cacher = getattr(mod, 'cacher', None)
worker.set_func(func, cacher)
worker.initialize(func)
ret.append(worker.run())

View file

@ -8,7 +8,8 @@ from asyncio import Queue
import contextlib
from typing import (
Dict, Optional, List, AsyncGenerator, NamedTuple, Union,
Any, Tuple, Callable, TYPE_CHECKING,
Any, Tuple, Callable, TypeVar, Coroutine, Generic,
TYPE_CHECKING,
)
from pathlib import Path
@ -71,6 +72,38 @@ class BaseWorker:
await self.token_q.put(token)
logger.debug('return token')
T = TypeVar('T')
S = TypeVar('S')

class AsyncCache(Generic[T, S]):
  '''Per-key cache that lets concurrent callers share one computation.

  The first ``get`` for a key starts ``func(key)`` as a task and stores
  that task in ``cache``; callers arriving while it is still running
  await the same task instead of starting their own. When the task
  finishes successfully, its result replaces the task in ``cache`` and
  later callers get the stored value directly.

  If ``func`` raises, the result is never stored, so the failed task stays
  in ``cache`` and later calls for the same key re-raise its exception.
  '''
  # key -> finished result, or the task that is still computing it
  cache: Dict[T, Union[S, asyncio.Task]]
  # guards the "look up, and start a task if absent" step
  lock: asyncio.Lock

  # Marks "no entry for this key". A dedicated sentinel is needed because
  # None can be a legitimate result and must still count as cached.
  _MISSING = object()

  def __init__(self) -> None:
    self.cache = {}
    self.lock = asyncio.Lock()

  async def get(
    self,
    key: T,
    func: Callable[[T], Coroutine[None, None, S]],
  ) -> S:
    '''Return the value for ``key``, computing it with ``func`` at most once.

    ``key`` is both the cache key and the sole argument passed to ``func``.
    ``func`` is only called when ``key`` has no entry yet. Any exception
    raised by ``func`` propagates to every caller waiting on that key.
    '''
    task = None
    async with self.lock:
      cached = self.cache.get(key, self._MISSING)
      if cached is self._MISSING:
        # publish the task while still holding the lock so that no other
        # caller can start a second computation for the same key
        task = asyncio.create_task(func(key))
        self.cache[key] = task

    if task is not None: # not cached: this caller owns the computation
      result = await task
      self.cache[key] = result
      return result
    if asyncio.isfuture(cached): # pending
      return await cached # type: ignore
    return cached # type: ignore # cached
if TYPE_CHECKING:
from typing_extensions import Protocol
class GetVersionFunc(Protocol):
@ -78,37 +111,22 @@ if TYPE_CHECKING:
self,
name: str, conf: Entry,
*,
cache: AsyncCache,
keymanager: KeyManager,
) -> VersionResult:
...
else:
GetVersionFunc = Any
Cacher = Callable[[str, Entry], str]
class FunctionWorker(BaseWorker):
func = None
cacher = None
func: GetVersionFunc
cache: AsyncCache
cache: Dict[str, Union[
VersionResult,
asyncio.Task,
]]
lock: asyncio.Lock
def set_func(
self,
func: GetVersionFunc,
cacher: Optional[Cacher],
) -> None:
def initialize(self, func: GetVersionFunc) -> None:
self.func = func
self.cacher = cacher
if cacher:
self.cache = {}
self.lock = asyncio.Lock()
self.cache = AsyncCache()
async def run(self) -> None:
assert self.func is not None
futures = [
self.run_one(name, entry)
for name, entry in self.tasks
@ -123,43 +141,15 @@ class FunctionWorker(BaseWorker):
try:
async with self.acquire_token():
if self.cacher:
version = await self.run_one_may_cache(
name, entry)
else:
version = await self.func(
name, entry, keymanager = self.keymanager,
)
version = await self.func(
name, entry,
cache = self.cache,
keymanager = self.keymanager,
)
await self.result_q.put(RawResult(name, version, entry))
except Exception as e:
await self.result_q.put(RawResult(name, e, entry))
async def run_one_may_cache(
  self, name: str, entry: Entry,
) -> VersionResult:
  '''Run ``self.func`` for one entry, sharing results between entries
  that map to the same cache key.

  The key comes from ``self.cacher(name, entry)``. The first caller for a
  key starts ``self.func`` as a task and stores it in ``self.cache``;
  callers that find a task there await it, and callers that find a
  finished value return it directly.

  NOTE: a stored result of ``None`` looks the same as a missing entry,
  so the function is run again for such keys.
  '''
  assert self.cacher is not None
  assert self.func is not None

  cache_key = self.cacher(name, entry)
  task = None
  async with self.lock:
    hit = self.cache.get(cache_key)
    if hit is None:
      # nothing stored yet: start the lookup and publish its task while
      # the lock is still held
      task = asyncio.create_task(
        self.func(name, entry, keymanager = self.keymanager),
      )
      self.cache[cache_key] = task

  if task is not None: # not cached
    result = await task
    self.cache[cache_key] = result
    return result
  if asyncio.isfuture(hit): # pending
    return await hit # type: ignore
  return hit # type: ignore # cached
class GetVersionError(Exception):
def __init__(self, msg: str, **kwargs: Any) -> None:
self.msg = msg

View file

@ -1,23 +1,24 @@
# MIT licensed
# Copyright (c) 2013-2017 lilydjwg <lilydjwg@gmail.com>, et al.
# Copyright (c) 2013-2020 lilydjwg <lilydjwg@gmail.com>, et al.
import structlog
from . import session, conf_cacheable_with_name
from nvchecker.httpclient import session # type: ignore
logger = structlog.get_logger(logger_name=__name__)
URL = 'https://www.archlinux.org/packages/search/json/'
get_cacheable_conf = conf_cacheable_with_name('archpkg')
async def request(pkg):
  '''Query the search endpoint at ``URL`` for ``pkg`` and return the
  decoded JSON response.'''
  query = {"name": pkg}
  async with session.get(URL, params=query) as response:
    payload = await response.json()
    return payload
async def get_version(name, conf, **kwargs):
async def get_version(name, conf, *, cache, **kwargs):
pkg = conf.get('archpkg') or name
strip_release = conf.getboolean('strip-release', False)
strip_release = conf.get('strip_release', False)
provided = conf.get('provided')
async with session.get(URL, params={"name": pkg}) as res:
data = await res.json()
data = await cache.get(pkg, request)
if not data['results']:
logger.error('Arch package not found', name=name)

View file

@ -9,12 +9,8 @@ from nvchecker.util import GetVersionError
logger = structlog.get_logger(logger_name=__name__)
def cacher(name, conf):
  '''Return the cache key for an entry: its configured ``cmd`` value.

  ``name`` is accepted but not used. Raises ``KeyError`` when ``conf``
  has no ``cmd`` key.
  '''
  cmd = conf['cmd']
  return cmd
async def get_version(name, conf, *, keymanager=None):
cmd = conf['cmd']
logger.debug('running cmd', name=name, cmd=cmd)
async def run_cmd(cmd: str) -> str:
logger.debug('running cmd', cmd=cmd)
p = await asyncio.create_subprocess_shell(
cmd,
stdout=asyncio.subprocess.PIPE,
@ -22,17 +18,23 @@ async def get_version(name, conf, *, keymanager=None):
)
output, error = await p.communicate()
output = output.strip().decode('latin1')
error = error.strip().decode(errors='replace')
output_s = output.strip().decode('latin1')
error_s = error.strip().decode(errors='replace')
if p.returncode != 0:
raise GetVersionError(
'command exited with error',
cmd=cmd, error=error,
name=name, returncode=p.returncode)
elif not output:
cmd=cmd, error=error_s,
returncode=p.returncode)
elif not output_s:
raise GetVersionError(
'command exited without output',
cmd=cmd, error=error,
name=name, returncode=p.returncode)
cmd=cmd, error=error_s,
returncode=p.returncode)
else:
return output
return output_s
async def get_version(
  name, conf, *, cache, keymanager=None
):
  '''Return the result of running the entry's configured ``cmd``.

  The command string is used as the cache key and is passed to
  ``run_cmd`` through ``cache.get``. ``name`` and ``keymanager`` are
  accepted but not used here.
  '''
  return await cache.get(conf['cmd'], run_cmd)