diff --git a/nvchecker_source/github-test.py b/nvchecker_source/github-test.py index 320ad49..c451aed 100644 --- a/nvchecker_source/github-test.py +++ b/nvchecker_source/github-test.py @@ -12,25 +12,29 @@ from nvchecker.api import ( HTTPError, session, RichResult, GetVersionError, ) -http_client = None logger = structlog.get_logger(logger_name=__name__) ALLOW_REQUEST = None RATE_LIMITED_ERROR = False +http_client = None GITHUB_GRAPHQL_URL = 'https://api.%s/graphql' +async def get_http_client(): + """Initialize and return the HTTP client.""" + global http_client + if http_client is None: + if asyncio.iscoroutine(session): + http_client = await session + else: + http_client = session + return http_client + async def execute_github_query(host: str, owner: str, reponame: str, token: str) -> dict: """ Execute GraphQL query against GitHub API and return the response data. Centralizes error handling and query execution. """ - global http_client - - # Initialize the HTTP client if not already done - if http_client is None: - if asyncio.iscoroutine(session): - http_client = await session - http_client = session + client = await get_http_client() headers = { 'Authorization': f'bearer {token}', @@ -39,7 +43,7 @@ async def execute_github_query(host: str, owner: str, reponame: str, token: str) query_vars = QUERY_GITHUB.replace("$owner", owner).replace("$name", reponame) - async with http_client.post( + async with client.post( GITHUB_GRAPHQL_URL % host, headers=headers, json={'query': query_vars} @@ -59,36 +63,28 @@ def get_github_token(conf: Entry, host: str, keymanager: KeyManager) -> Optional return token async def get_version(name, conf, **kwargs): - global RATE_LIMITED_ERROR, ALLOW_REQUEST + global RATE_LIMITED_ERROR, ALLOW_REQUEST - global http_client - - # Initialize the HTTP client if not already done - if http_client is None: - if asyncio.iscoroutine(session): - http_client = await session - http_client = session + if RATE_LIMITED_ERROR: + raise RuntimeError('rate limited') - if RATE_LIMITED_ERROR: - raise RuntimeError('rate limited') + if ALLOW_REQUEST is None: + ALLOW_REQUEST = asyncio.Event() + ALLOW_REQUEST.set() - if ALLOW_REQUEST is None: - ALLOW_REQUEST = asyncio.Event() - ALLOW_REQUEST.set() - - for _ in range(2): # retry once - try: - await ALLOW_REQUEST.wait() - return await get_version_real(name, conf, **kwargs) - except HTTPError as e: - if e.code in [403, 429]: - if n := check_ratelimit(e, name): - ALLOW_REQUEST.clear() - await asyncio.sleep(n+1) - ALLOW_REQUEST.set() - continue - RATE_LIMITED_ERROR = True - raise + for _ in range(2): # retry once + try: + await ALLOW_REQUEST.wait() + return await get_version_real(name, conf, **kwargs) + except HTTPError as e: + if e.code in [403, 429]: + if n := check_ratelimit(e, name): + ALLOW_REQUEST.clear() + await asyncio.sleep(n+1) + ALLOW_REQUEST.set() + continue + RATE_LIMITED_ERROR = True + raise QUERY_GITHUB = """ query {