github: share the GraphQL request and token lookup

- Add execute_github_query() and get_github_token(); use them from
  get_latest_tag(), get_latest_release_with_prereleases() and
  get_version_real().
- execute_github_query() posts through the shared nvchecker `session`,
  exactly as the request code it replaces did (no extra client global).
- Fix the conf.ger(...) typos in get_version_real().
- get_version() now raises early once RATE_LIMITED_ERROR is set.

NOTE(review): get_version_real() still sends the GraphQL query when no
token is found (header "bearer None"); confirm whether it should raise
like the use_latest_tag path does.
NOTE(review): get_github_token()'s annotations assume Entry, KeyManager
and Optional are already imported in this module.
NOTE(review): indentation (4 spaces), blank context lines and hunk
positions were reconstructed and may be off; apply with
`git apply --recount --ignore-whitespace` and review the result.

--- github-test.py
+++ github-test.py
@@ -20,8 +20,45 @@ RATE_LIMITED_ERROR = False
 
 GITHUB_GRAPHQL_URL = 'https://api.%s/graphql'
 
+async def execute_github_query(host: str, owner: str, reponame: str, token: str) -> dict:
+    """
+    Execute the shared GraphQL query against the GitHub API.
+
+    Returns the ``repository`` object from the response's ``data`` field.
+    Raises GetVersionError when the response contains an ``errors`` field.
+    """
+    headers = {
+        'Authorization': f'bearer {token}',
+        'Content-Type': 'application/json',
+    }
+
+    # QUERY_GITHUB's $owner/$name placeholders are substituted as plain text.
+    query_vars = QUERY_GITHUB.replace("$owner", owner).replace("$name", reponame)
+
+    async with session.post(
+        GITHUB_GRAPHQL_URL % host,
+        headers=headers,
+        json={'query': query_vars}
+    ) as res:
+        j = await res.json()
+        if 'errors' in j:
+            raise GetVersionError(f"GitHub API error: {j['errors']}")
+        return j['data']['repository']
+
+def get_github_token(conf: Entry, host: str, keymanager: KeyManager) -> Optional[str]:
+    """Get the GitHub token from config, then keymanager, then $GITHUB_TOKEN."""
+    token = conf.get('token')
+    if token is None:
+        token = keymanager.get_key(host.lower(), 'github')
+    if token is None:
+        token = os.environ.get('GITHUB_TOKEN')
+    return token
+
 async def get_version(name, conf, **kwargs):
     global RATE_LIMITED_ERROR, ALLOW_REQUEST
+
+    if RATE_LIMITED_ERROR:
+        raise RuntimeError('rate limited')
 
     if ALLOW_REQUEST is None:
         ALLOW_REQUEST = asyncio.Event()
@@ -91,21 +128,11 @@ query {
 async def get_latest_tag(key: Tuple[str, str, str, str]) -> RichResult:
     host, repo, query, token = key
     owner, reponame = repo.split('/')
-    headers = {
-        'Authorization': f'bearer {token}',
-        'Content-Type': 'application/json',
-    }
-    # Make GraphQL query
-    query_vars = QUERY_GITHUB.replace("$owner", owner).replace("$name", reponame)
-    async with session.post(
-        GITHUB_GRAPHQL_URL % host,
-        headers=headers,
-        json={'query': query_vars}
-    ) as res:
-        j = await res.json()
-        if 'errors' in j:
-            raise GetVersionError(f"GitHub API error: {j['errors']}")
+    if not token:
+        raise GetVersionError('token is required for latest tag query')
+
+    repo_data = await execute_github_query(host, owner, reponame, token)
 
-    refs = j['data']['repository']['refs']['edges']
+    refs = repo_data['refs']['edges']
     if not refs:
         raise GetVersionError('no tag found')
@@ -120,21 +147,11 @@ async def get_latest_tag(key: Tuple[str, str, str, str]) -> RichResult:
 async def get_latest_release_with_prereleases(key: Tuple[str, str, str, str]) -> RichResult:
     host, repo, token, use_release_name = key
     owner, reponame = repo.split('/')
-    headers = {
-        'Authorization': f'bearer {token}',
-        'Content-Type': 'application/json',
-    }
-    # Make GraphQL query
-    query_vars = QUERY_GITHUB.replace("$owner", owner).replace("$name", reponame)
-    async with session.post(
-        GITHUB_GRAPHQL_URL % host,
-        headers=headers,
-        json={'query': query_vars}
-    ) as res:
-        j = await res.json()
-        if 'errors' in j:
-            raise GetVersionError(f"GitHub API error: {j['errors']}")
+    if not token:
+        raise GetVersionError('token is required for latest release query')
+
+    repo_data = await execute_github_query(host, owner, reponame, token)
 
-    releases = j['data']['repository']['releases']['edges']
+    releases = repo_data['releases']['edges']
     if not releases:
         raise GetVersionError('no release found')
@@ -199,43 +216,21 @@ async def get_version_real(
     repo = conf['github']
     owner, reponame = repo.split('/')
     host = conf.get('host', "github.com")
+    token = get_github_token(conf, host, keymanager)
 
-    # Load token from config
-    token = conf.get('token')
-    # Load token from keyman
-    if token is None:
-        token = keymanager.get_key(host.lower(), 'github')
-    # Load token from environment
-    if token is None:
-        token = os.environ.get('GITHUB_TOKEN')
-
     use_latest_tag = conf.get('use_latest_tag', False)
     if use_latest_tag:
         if not token:
             raise GetVersionError('token not given but it is required')
 
         query = conf.get('query', '')
         return await cache.get((host, repo, query, token), get_latest_tag)
 
-    headers = {
-        'Authorization': f'bearer {token}',
-        'Content-Type': 'application/json',
-    }
+    repo_data = await execute_github_query(host, owner, reponame, token)
 
-    # Make GraphQL query
-    query_vars = QUERY_GITHUB.replace("$owner", owner).replace("$name", reponame)
-    async with session.post(
-        GITHUB_GRAPHQL_URL % host,
-        headers=headers,
-        json={'query': query_vars}
-    ) as res:
-        j = await res.json()
-        if 'errors' in j:
-            raise GetVersionError(f"GitHub API error: {j['errors']}")
-
-    use_max_tag = conf.ger('use_max_tag', False)
+    use_max_tag = conf.get('use_max_tag', False)
     if use_max_tag:
-        refs = j['data']['repository']['refs']['edges']
+        refs = repo_data['refs']['edges']
         tags: List[Union[str, RichResult]] = [
             RichResult(
                 version=ref['node']['name'],
@@ -246,13 +241,13 @@ async def get_version_real(
         if not tags:
             raise GetVersionError('No tag found in upstream repository.')
         return tags
-    use_latest_release = conf.ger('use_latest_release', False)
+    use_latest_release = conf.get('use_latest_release', False)
     if use_latest_release:
-        releases = j['data']['repository']['releases']['edges']
+        releases = repo_data['releases']['edges']
 
         if not releases:
             raise GetVersionError('No release found in upstream repository.')
         latest_release = releases[0]['node']
-        use_release_name = conf.ger('use_release_name', False)
+        use_release_name = conf.get('use_release_name', False)
         version = latest_release['name'] if use_release_name else latest_release['tagName']
         return RichResult(
@@ -261,7 +256,7 @@ async def get_version_real(
             url=latest_release['url'],
         )
     else:
-        commit = j['data']['repository']['defaultBranchRef']['target']['history']['edges'][0]['node']
+        commit = repo_data['defaultBranchRef']['target']['history']['edges'][0]['node']
         return RichResult(
             version=commit['committedDate'].rstrip('Z').replace('-', '').replace(':', '').replace('T', '.'),
             revision=commit['oid'],