mirror of
https://github.com/lilydjwg/nvchecker.git
synced 2025-03-10 06:14:02 +00:00
use asyncio.Semaphore instead of self-made queue
This commit is contained in:
parent
f4983eaea3
commit
185a7e88a9
7 changed files with 15 additions and 38 deletions
|
@ -44,11 +44,11 @@ def main() -> None:
|
||||||
if options.proxy is not None:
|
if options.proxy is not None:
|
||||||
ctx_proxy.set(options.proxy)
|
ctx_proxy.set(options.proxy)
|
||||||
|
|
||||||
token_q = core.token_queue(options.max_concurrency)
|
task_sem = asyncio.Semaphore(options.max_concurrency)
|
||||||
result_q: asyncio.Queue[RawResult] = asyncio.Queue()
|
result_q: asyncio.Queue[RawResult] = asyncio.Queue()
|
||||||
try:
|
try:
|
||||||
futures = core.dispatch(
|
futures = core.dispatch(
|
||||||
entries, token_q, result_q,
|
entries, task_sem, result_q,
|
||||||
keymanager, args.tries,
|
keymanager, args.tries,
|
||||||
)
|
)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
|
|
|
@ -193,17 +193,9 @@ def load_file(
|
||||||
return cast(Entries, config), Options(
|
return cast(Entries, config), Options(
|
||||||
ver_files, max_concurrency, proxy, keymanager)
|
ver_files, max_concurrency, proxy, keymanager)
|
||||||
|
|
||||||
def token_queue(maxsize: int) -> Queue[bool]:
|
|
||||||
token_q: Queue[bool] = Queue(maxsize=maxsize)
|
|
||||||
|
|
||||||
for _ in range(maxsize):
|
|
||||||
token_q.put_nowait(True)
|
|
||||||
|
|
||||||
return token_q
|
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
entries: Entries,
|
entries: Entries,
|
||||||
token_q: Queue[bool],
|
task_sem: asyncio.Semaphore,
|
||||||
result_q: Queue[RawResult],
|
result_q: Queue[RawResult],
|
||||||
keymanager: KeyManager,
|
keymanager: KeyManager,
|
||||||
tries: int,
|
tries: int,
|
||||||
|
@ -232,7 +224,7 @@ def dispatch(
|
||||||
ctx = root_ctx.copy()
|
ctx = root_ctx.copy()
|
||||||
worker = ctx.run(
|
worker = ctx.run(
|
||||||
worker_cls,
|
worker_cls,
|
||||||
token_q, result_q, tasks, keymanager,
|
task_sem, result_q, tasks, keymanager,
|
||||||
)
|
)
|
||||||
if worker_cls is FunctionWorker:
|
if worker_cls is FunctionWorker:
|
||||||
func = mod.get_version # type: ignore
|
func = mod.get_version # type: ignore
|
||||||
|
|
|
@ -5,9 +5,8 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
import contextlib
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict, Optional, List, AsyncGenerator, NamedTuple, Union,
|
Dict, Optional, List, NamedTuple, Union,
|
||||||
Any, Tuple, Callable, Coroutine, Hashable,
|
Any, Tuple, Callable, Coroutine, Hashable,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
)
|
)
|
||||||
|
@ -72,10 +71,10 @@ class Result(NamedTuple):
|
||||||
class BaseWorker:
|
class BaseWorker:
|
||||||
'''The base class for defining `Worker` classes for source plugins.
|
'''The base class for defining `Worker` classes for source plugins.
|
||||||
|
|
||||||
.. py:attribute:: token_q
|
.. py:attribute:: task_sem
|
||||||
:type: Queue[bool]
|
:type: asyncio.Semaphore
|
||||||
|
|
||||||
This is the rate-limiting queue. Workers should obtain one token before doing one unit of work.
|
This is the rate-limiting semaphore. Workers should acquire it while doing one unit of work.
|
||||||
|
|
||||||
.. py:attribute:: result_q
|
.. py:attribute:: result_q
|
||||||
:type: Queue[RawResult]
|
:type: Queue[RawResult]
|
||||||
|
@ -96,28 +95,16 @@ class BaseWorker:
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
token_q: Queue[bool],
|
task_sem: asyncio.Semaphore,
|
||||||
result_q: Queue[RawResult],
|
result_q: Queue[RawResult],
|
||||||
tasks: List[Tuple[str, Entry]],
|
tasks: List[Tuple[str, Entry]],
|
||||||
keymanager: KeyManager,
|
keymanager: KeyManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.token_q = token_q
|
self.task_sem = task_sem
|
||||||
self.result_q = result_q
|
self.result_q = result_q
|
||||||
self.keymanager = keymanager
|
self.keymanager = keymanager
|
||||||
self.tasks = tasks
|
self.tasks = tasks
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
|
||||||
async def acquire_token(self) -> AsyncGenerator[None, None]:
|
|
||||||
'''A context manager to obtain a token from the `token_q` on entrance and
|
|
||||||
release it on exit.'''
|
|
||||||
token = await self.token_q.get()
|
|
||||||
logger.debug('got token')
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
await self.token_q.put(token)
|
|
||||||
logger.debug('return token')
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
'''Run the `tasks`. Subclasses should implement this method.'''
|
'''Run the `tasks`. Subclasses should implement this method.'''
|
||||||
|
@ -227,7 +214,7 @@ class FunctionWorker(BaseWorker):
|
||||||
ctx_ua.set(ua)
|
ctx_ua.set(ua)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self.acquire_token():
|
async with self.task_sem:
|
||||||
version = await self.func(
|
version = await self.func(
|
||||||
name, entry,
|
name, entry,
|
||||||
cache = self.cache,
|
cache = self.cache,
|
||||||
|
|
|
@ -20,8 +20,6 @@ def _decompress_data(url: str, data: bytes) -> str:
|
||||||
elif url.endswith(".gz"):
|
elif url.endswith(".gz"):
|
||||||
import gzip
|
import gzip
|
||||||
data = gzip.decompress(data)
|
data = gzip.decompress(data)
|
||||||
else:
|
|
||||||
raise NotImplementedError(url)
|
|
||||||
|
|
||||||
return data.decode('utf-8')
|
return data.decode('utf-8')
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@ class Worker(BaseWorker):
|
||||||
) -> None:
|
) -> None:
|
||||||
task_by_name: Dict[str, Entry] = dict(self.tasks)
|
task_by_name: Dict[str, Entry] = dict(self.tasks)
|
||||||
|
|
||||||
async with self.acquire_token():
|
async with self.task_sem:
|
||||||
results = await _run_batch_impl(batch, aur_results)
|
results = await _run_batch_impl(batch, aur_results)
|
||||||
for name, version in results.items():
|
for name, version in results.items():
|
||||||
r = RawResult(name, version, task_by_name[name])
|
r = RawResult(name, version, task_by_name[name])
|
||||||
|
|
|
@ -10,7 +10,7 @@ from nvchecker.api import (
|
||||||
class Worker(BaseWorker):
|
class Worker(BaseWorker):
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
exc = GetVersionError('no source specified')
|
exc = GetVersionError('no source specified')
|
||||||
async with self.acquire_token():
|
async with self.task_sem:
|
||||||
for name, conf in self.tasks:
|
for name, conf in self.tasks:
|
||||||
await self.result_q.put(
|
await self.result_q.put(
|
||||||
RawResult(name, exc, conf))
|
RawResult(name, exc, conf))
|
||||||
|
|
|
@ -18,7 +18,7 @@ use_keyfile = False
|
||||||
async def run(
|
async def run(
|
||||||
entries: Entries, max_concurrency: int = 20,
|
entries: Entries, max_concurrency: int = 20,
|
||||||
) -> VersData:
|
) -> VersData:
|
||||||
token_q = core.token_queue(max_concurrency)
|
task_sem = asyncio.Semaphore(max_concurrency)
|
||||||
result_q: asyncio.Queue[RawResult] = asyncio.Queue()
|
result_q: asyncio.Queue[RawResult] = asyncio.Queue()
|
||||||
keyfile = os.environ.get('KEYFILE')
|
keyfile = os.environ.get('KEYFILE')
|
||||||
if use_keyfile and keyfile:
|
if use_keyfile and keyfile:
|
||||||
|
@ -28,7 +28,7 @@ async def run(
|
||||||
keymanager = core.KeyManager(None)
|
keymanager = core.KeyManager(None)
|
||||||
|
|
||||||
futures = core.dispatch(
|
futures = core.dispatch(
|
||||||
entries, token_q, result_q,
|
entries, task_sem, result_q,
|
||||||
keymanager, 1,
|
keymanager, 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue