This commit is contained in:
envolution 2024-11-20 00:02:26 -05:00
parent d3a60cc29a
commit af18ac688b

View file

@ -3,14 +3,14 @@ import time
from urllib.parse import urlencode from urllib.parse import urlencode
from typing import List, Tuple, Union, Optional from typing import List, Tuple, Union, Optional
import asyncio import asyncio
import json # Added for JSON handling import aiohttp
import structlog import structlog
from nvchecker.api import ( from nvchecker.api import (
VersionResult, Entry, AsyncCache, KeyManager, VersionResult, Entry, AsyncCache, KeyManager,
HTTPError, session, RichResult, GetVersionError, HTTPError, session, RichResult, GetVersionError,
) )
DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=60)
logger = structlog.get_logger(logger_name=__name__) logger = structlog.get_logger(logger_name=__name__)
ALLOW_REQUEST = None ALLOW_REQUEST = None
RATE_LIMITED_ERROR = False RATE_LIMITED_ERROR = False
@ -18,18 +18,18 @@ _http_client = None
GITHUB_GRAPHQL_URL = 'https://api.%s/graphql' GITHUB_GRAPHQL_URL = 'https://api.%s/graphql'
async def create_http_client():
"""Create a new aiohttp client session with proper configuration."""
return aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT)
async def get_http_client(): async def get_http_client():
"""Initialize and return the HTTP client.""" """Initialize and return the HTTP client."""
global _http_client global _http_client
if _http_client is not None: if _http_client is not None:
return _http_client return _http_client
# Get the client instance, awaiting if necessary # Create a new client session if none exists
client = await session if asyncio.iscoroutine(session) else session client = await create_http_client()
if not hasattr(client, '__aenter__'):
raise RuntimeError("HTTP client must support async context management")
_http_client = client _http_client = client
return _http_client return _http_client
@ -46,21 +46,29 @@ async def execute_github_query(host: str, owner: str, reponame: str, token: str)
} }
query_vars = QUERY_GITHUB.replace("$owner", owner).replace("$name", reponame) query_vars = QUERY_GITHUB.replace("$owner", owner).replace("$name", reponame)
client = await get_http_client()
try: try:
# Ensure we have a properly initialized client async with client.post(
http_client = await get_http_client() GITHUB_GRAPHQL_URL % host,
async with http_client.post( headers=headers,
GITHUB_GRAPHQL_URL % host, json={'query': query_vars}
headers=headers,
json={'query': query_vars}
) as response: ) as response:
# Check response status # Check response status
response.raise_for_status() response.raise_for_status()
# Parse JSON response # Parse JSON response
data = await response.json() data = await response.json()
# Handle rate limiting headers
remaining = response.headers.get('X-RateLimit-Remaining')
if remaining and int(remaining) == 0:
reset_time = int(response.headers.get('X-RateLimit-Reset', 0))
logger.warning(
"GitHub API rate limit reached",
reset_time=time.ctime(reset_time)
)
# Check for GraphQL errors # Check for GraphQL errors
if 'errors' in data: if 'errors' in data:
raise GetVersionError(f"GitHub API error: {data['errors']}") raise GetVersionError(f"GitHub API error: {data['errors']}")