This commit is contained in:
envolution 2024-11-19 23:28:56 -05:00
parent fa4cfefc83
commit 15da543c55

View file

@ -12,25 +12,29 @@ from nvchecker.api import (
HTTPError, session, RichResult, GetVersionError, HTTPError, session, RichResult, GetVersionError,
) )
http_client = None
logger = structlog.get_logger(logger_name=__name__) logger = structlog.get_logger(logger_name=__name__)
ALLOW_REQUEST = None ALLOW_REQUEST = None
RATE_LIMITED_ERROR = False RATE_LIMITED_ERROR = False
http_client = None
GITHUB_GRAPHQL_URL = 'https://api.%s/graphql' GITHUB_GRAPHQL_URL = 'https://api.%s/graphql'
async def get_http_client():
"""Initialize and return the HTTP client."""
global http_client
if http_client is None:
if asyncio.iscoroutine(session):
http_client = await session
else:
http_client = session
return http_client
async def execute_github_query(host: str, owner: str, reponame: str, token: str) -> dict: async def execute_github_query(host: str, owner: str, reponame: str, token: str) -> dict:
""" """
Execute GraphQL query against GitHub API and return the response data. Execute GraphQL query against GitHub API and return the response data.
Centralizes error handling and query execution. Centralizes error handling and query execution.
""" """
global http_client client = await get_http_client()
# Initialize the HTTP client if not already done
if http_client is None:
if asyncio.iscoroutine(session):
http_client = await session
http_client = session
headers = { headers = {
'Authorization': f'bearer {token}', 'Authorization': f'bearer {token}',
@ -39,7 +43,7 @@ async def execute_github_query(host: str, owner: str, reponame: str, token: str)
query_vars = QUERY_GITHUB.replace("$owner", owner).replace("$name", reponame) query_vars = QUERY_GITHUB.replace("$owner", owner).replace("$name", reponame)
async with http_client.post( async with client.post(
GITHUB_GRAPHQL_URL % host, GITHUB_GRAPHQL_URL % host,
headers=headers, headers=headers,
json={'query': query_vars} json={'query': query_vars}
@ -59,36 +63,28 @@ def get_github_token(conf: Entry, host: str, keymanager: KeyManager) -> Optional
return token return token
async def get_version(name, conf, **kwargs): async def get_version(name, conf, **kwargs):
global RATE_LIMITED_ERROR, ALLOW_REQUEST global RATE_LIMITED_ERROR, ALLOW_REQUEST
global http_client if RATE_LIMITED_ERROR:
raise RuntimeError('rate limited')
# Initialize the HTTP client if not already done
if http_client is None:
if asyncio.iscoroutine(session):
http_client = await session
http_client = session
if RATE_LIMITED_ERROR: if ALLOW_REQUEST is None:
raise RuntimeError('rate limited') ALLOW_REQUEST = asyncio.Event()
ALLOW_REQUEST.set()
if ALLOW_REQUEST is None: for _ in range(2): # retry once
ALLOW_REQUEST = asyncio.Event() try:
ALLOW_REQUEST.set() await ALLOW_REQUEST.wait()
return await get_version_real(name, conf, **kwargs)
for _ in range(2): # retry once except HTTPError as e:
try: if e.code in [403, 429]:
await ALLOW_REQUEST.wait() if n := check_ratelimit(e, name):
return await get_version_real(name, conf, **kwargs) ALLOW_REQUEST.clear()
except HTTPError as e: await asyncio.sleep(n+1)
if e.code in [403, 429]: ALLOW_REQUEST.set()
if n := check_ratelimit(e, name): continue
ALLOW_REQUEST.clear() RATE_LIMITED_ERROR = True
await asyncio.sleep(n+1) raise
ALLOW_REQUEST.set()
continue
RATE_LIMITED_ERROR = True
raise
QUERY_GITHUB = """ QUERY_GITHUB = """
query { query {