use asyncio.Semaphore instead of self-made queue

This commit is contained in:
lilydjwg 2020-09-20 22:15:54 +08:00
parent f4983eaea3
commit 185a7e88a9
7 changed files with 15 additions and 38 deletions

View file

@@ -44,11 +44,11 @@ def main() -> None:
if options.proxy is not None: if options.proxy is not None:
ctx_proxy.set(options.proxy) ctx_proxy.set(options.proxy)
token_q = core.token_queue(options.max_concurrency) task_sem = asyncio.Semaphore(options.max_concurrency)
result_q: asyncio.Queue[RawResult] = asyncio.Queue() result_q: asyncio.Queue[RawResult] = asyncio.Queue()
try: try:
futures = core.dispatch( futures = core.dispatch(
entries, token_q, result_q, entries, task_sem, result_q,
keymanager, args.tries, keymanager, args.tries,
) )
except ModuleNotFoundError as e: except ModuleNotFoundError as e:

View file

@@ -193,17 +193,9 @@ def load_file(
return cast(Entries, config), Options( return cast(Entries, config), Options(
ver_files, max_concurrency, proxy, keymanager) ver_files, max_concurrency, proxy, keymanager)
def token_queue(maxsize: int) -> Queue[bool]:
token_q: Queue[bool] = Queue(maxsize=maxsize)
for _ in range(maxsize):
token_q.put_nowait(True)
return token_q
def dispatch( def dispatch(
entries: Entries, entries: Entries,
token_q: Queue[bool], task_sem: asyncio.Semaphore,
result_q: Queue[RawResult], result_q: Queue[RawResult],
keymanager: KeyManager, keymanager: KeyManager,
tries: int, tries: int,
@@ -232,7 +224,7 @@ def dispatch(
ctx = root_ctx.copy() ctx = root_ctx.copy()
worker = ctx.run( worker = ctx.run(
worker_cls, worker_cls,
token_q, result_q, tasks, keymanager, task_sem, result_q, tasks, keymanager,
) )
if worker_cls is FunctionWorker: if worker_cls is FunctionWorker:
func = mod.get_version # type: ignore func = mod.get_version # type: ignore

View file

@@ -5,9 +5,8 @@ from __future__ import annotations
import asyncio import asyncio
from asyncio import Queue from asyncio import Queue
import contextlib
from typing import ( from typing import (
Dict, Optional, List, AsyncGenerator, NamedTuple, Union, Dict, Optional, List, NamedTuple, Union,
Any, Tuple, Callable, Coroutine, Hashable, Any, Tuple, Callable, Coroutine, Hashable,
TYPE_CHECKING, TYPE_CHECKING,
) )
@@ -72,10 +71,10 @@ class Result(NamedTuple):
class BaseWorker: class BaseWorker:
'''The base class for defining `Worker` classes for source plugins. '''The base class for defining `Worker` classes for source plugins.
.. py:attribute:: token_q .. py:attribute:: task_sem
:type: Queue[bool] :type: asyncio.Semaphore
This is the rate-limiting queue. Workers should obtain one token before doing one unit of work. This is the rate-limiting semaphore. Workers should acquire it while doing one unit of work.
.. py:attribute:: result_q .. py:attribute:: result_q
:type: Queue[RawResult] :type: Queue[RawResult]
@@ -96,28 +95,16 @@ class BaseWorker:
''' '''
def __init__( def __init__(
self, self,
token_q: Queue[bool], task_sem: asyncio.Semaphore,
result_q: Queue[RawResult], result_q: Queue[RawResult],
tasks: List[Tuple[str, Entry]], tasks: List[Tuple[str, Entry]],
keymanager: KeyManager, keymanager: KeyManager,
) -> None: ) -> None:
self.token_q = token_q self.task_sem = task_sem
self.result_q = result_q self.result_q = result_q
self.keymanager = keymanager self.keymanager = keymanager
self.tasks = tasks self.tasks = tasks
@contextlib.asynccontextmanager
async def acquire_token(self) -> AsyncGenerator[None, None]:
'''A context manager to obtain a token from the `token_q` on entrance and
release it on exit.'''
token = await self.token_q.get()
logger.debug('got token')
try:
yield
finally:
await self.token_q.put(token)
logger.debug('return token')
@abc.abstractmethod @abc.abstractmethod
async def run(self) -> None: async def run(self) -> None:
'''Run the `tasks`. Subclasses should implement this method.''' '''Run the `tasks`. Subclasses should implement this method.'''
@@ -227,7 +214,7 @@ class FunctionWorker(BaseWorker):
ctx_ua.set(ua) ctx_ua.set(ua)
try: try:
async with self.acquire_token(): async with self.task_sem:
version = await self.func( version = await self.func(
name, entry, name, entry,
cache = self.cache, cache = self.cache,

View file

@@ -20,8 +20,6 @@ def _decompress_data(url: str, data: bytes) -> str:
elif url.endswith(".gz"): elif url.endswith(".gz"):
import gzip import gzip
data = gzip.decompress(data) data = gzip.decompress(data)
else:
raise NotImplementedError(url)
return data.decode('utf-8') return data.decode('utf-8')

View file

@@ -67,7 +67,7 @@ class Worker(BaseWorker):
) -> None: ) -> None:
task_by_name: Dict[str, Entry] = dict(self.tasks) task_by_name: Dict[str, Entry] = dict(self.tasks)
async with self.acquire_token(): async with self.task_sem:
results = await _run_batch_impl(batch, aur_results) results = await _run_batch_impl(batch, aur_results)
for name, version in results.items(): for name, version in results.items():
r = RawResult(name, version, task_by_name[name]) r = RawResult(name, version, task_by_name[name])

View file

@@ -10,7 +10,7 @@ from nvchecker.api import (
class Worker(BaseWorker): class Worker(BaseWorker):
async def run(self) -> None: async def run(self) -> None:
exc = GetVersionError('no source specified') exc = GetVersionError('no source specified')
async with self.acquire_token(): async with self.task_sem:
for name, conf in self.tasks: for name, conf in self.tasks:
await self.result_q.put( await self.result_q.put(
RawResult(name, exc, conf)) RawResult(name, exc, conf))

View file

@@ -18,7 +18,7 @@ use_keyfile = False
async def run( async def run(
entries: Entries, max_concurrency: int = 20, entries: Entries, max_concurrency: int = 20,
) -> VersData: ) -> VersData:
token_q = core.token_queue(max_concurrency) task_sem = asyncio.Semaphore(max_concurrency)
result_q: asyncio.Queue[RawResult] = asyncio.Queue() result_q: asyncio.Queue[RawResult] = asyncio.Queue()
keyfile = os.environ.get('KEYFILE') keyfile = os.environ.get('KEYFILE')
if use_keyfile and keyfile: if use_keyfile and keyfile:
@@ -28,7 +28,7 @@ async def run(
keymanager = core.KeyManager(None) keymanager = core.KeyManager(None)
futures = core.dispatch( futures = core.dispatch(
entries, token_q, result_q, entries, task_sem, result_q,
keymanager, 1, keymanager, 1,
) )