From cdd31a01e45dfd1fea1bb8234ef64d3366500140 Mon Sep 17 00:00:00 2001 From: Dusk Banks Date: Tue, 22 Feb 2022 18:00:51 -0800 Subject: [PATCH] github: factor out GitHub API querying --- nvchecker/util.py | 26 +++-- nvchecker_source/github.py | 188 ++++++++++++++++++++++--------------- 2 files changed, 129 insertions(+), 85 deletions(-) diff --git a/nvchecker/util.py b/nvchecker/util.py index fb50750..bf7657d 100644 --- a/nvchecker/util.py +++ b/nvchecker/util.py @@ -146,6 +146,14 @@ class BaseWorker: '''Run the `tasks`. Subclasses should implement this method.''' raise NotImplementedError +def _normalize(x: Any) -> Any: + if isinstance(x, list): + return tuple(sorted(_normalize(y) for y in x)) + elif isinstance(x, dict): + return tuple(sorted((_normalize(k), _normalize(v)) for k, v in x.items())) + else: + return x + class AsyncCache: '''A cache for use with async functions.''' cache: Dict[Hashable, Any] @@ -156,28 +164,32 @@ class AsyncCache: self.lock = asyncio.Lock() async def _get_json( - self, key: Tuple[str, str, Tuple[Tuple[str, str], ...]], + self, key: Tuple[str, str, Tuple[Tuple[str, str], ...], object], extra: Any, ) -> Any: - _, url, headers = key - res = await session.get(url, headers=dict(headers)) + _, url, headers, json = key + json = extra # denormalizing json would be a pain, so we sneak it through + res = await (session.get(url=url, headers=dict(headers)) if json is None \ + else session.post(url=url, headers=dict(headers), json=json)) return res.json() async def get_json( self, url: str, *, headers: Dict[str, str] = {}, + json: Optional[object] = None, ) -> Any: '''Get specified ``url`` and return the response content as JSON. The returned data will be cached for reuse. ''' - key = '_jsonurl', url, tuple(sorted(headers.items())) + key = '_jsonurl', url, _normalize(headers), _normalize(json) return await self.get( - key , self._get_json) # type: ignore + key, self._get_json, extra=json) # type: ignore async def get( self, key: Hashable, - func: Callable[[Hashable], Coroutine[Any, Any, Any]], + func: Callable[[Hashable, Optional[Any]], Coroutine[Any, Any, Any]], + extra: Optional[Any] = None, ) -> Any: '''Run async ``func`` and cache its return value by ``key``. @@ -189,7 +201,7 @@ class AsyncCache: async with self.lock: cached = self.cache.get(key) if cached is None: - coro = func(key) + coro = func(key, extra) fu = asyncio.create_task(coro) self.cache[key] = fu diff --git a/nvchecker_source/github.py b/nvchecker_source/github.py index 84eedde..f8e4fa1 100644 --- a/nvchecker_source/github.py +++ b/nvchecker_source/github.py @@ -4,7 +4,7 @@ import itertools import time from urllib.parse import urlencode -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple import structlog @@ -30,6 +30,53 @@ async def get_version(name, conf, **kwargs): except TemporaryError as e: check_ratelimit(e, name) +async def query_graphql( + *, + cache: AsyncCache, + token: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + query: str, + variables: Optional[Dict[str, object]] = None, + json: Optional[Dict[str, object]] = None, + url: Optional[str] = None, + **kwargs, +) -> Any: + if not token: + raise GetVersionError('token not given but it is required') + if headers is None: + headers = {} + headers.setdefault('Authorization', f'bearer {token}') + headers.setdefault('Content-Type', 'application/json') + + if json is None: + json = {} + json['query'] = query + if variables is not None: + json.setdefault('variables', {}).update(variables) + + if url is None: + url = GITHUB_GRAPHQL_URL + return await cache.get_json(url = url, headers = headers, json = json) + +async def query_rest( + *, + cache: AsyncCache, + token: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + parameters: Optional[Dict[str, str]] = None, + url: str, +) -> Any: + if headers is None: + headers = {} + if token: + headers.setdefault('Authorization', f'token {token}') + headers.setdefault('Accept', 'application/vnd.github.quicksilver-preview+json') + + if parameters: + url += '?' + urlencode(parameters) + + return await cache.get_json(url = url, headers = headers) + QUERY_LATEST_TAG = ''' query latestTag( $owner: String!, $name: String!, @@ -51,43 +98,13 @@ query latestTag( } ''' -async def get_latest_tag(key: Tuple[str, Optional[str], str, bool]) -> str: - repo, query, token, use_commit_name = key - owner, reponame = repo.split('/') - headers = { - 'Authorization': f'bearer {token}', - 'Content-Type': 'application/json', - } - variables = { - 'owner': owner, - 'name': reponame, - 'includeCommitName': use_commit_name, - } - if query is not None: - variables['query'] = query - - res = await session.post( - GITHUB_GRAPHQL_URL, - headers = headers, - json = {'query': QUERY_LATEST_TAG, 'variables': variables}, - ) - j = res.json() - - refs = j['data']['repository']['refs']['edges'] - if not refs: - raise GetVersionError('no tag found') - - return next(add_commit_name( - ref['node']['name'], - ref['node']['target']['oid'] if use_commit_name else None, - ) for ref in refs) - async def get_version_real( name: str, conf: Entry, *, cache: AsyncCache, keymanager: KeyManager, **kwargs, ) -> VersionResult: repo = conf['github'] + use_commit_name = conf.get('use_commit_name', False) # Load token from config token = conf.get('token') @@ -95,62 +112,77 @@ async def get_version_real( if token is None: token = keymanager.get_key('github') - use_latest_tag = conf.get('use_latest_tag', False) - use_commit_name = conf.get('use_commit_name', False) - if use_latest_tag: - if not token: - raise GetVersionError('token not given but it is required') - - query = conf.get('query') - return await cache.get((repo, query, token, use_commit_name), get_latest_tag) # type: ignore - - br = conf.get('branch') - path = conf.get('path') - use_latest_release = conf.get('use_latest_release', False) - use_max_tag = conf.get('use_max_tag', False) - if use_latest_release: - url = GITHUB_LATEST_RELEASE % repo - elif use_max_tag: - url = GITHUB_MAX_TAG % repo - else: - url = GITHUB_URL % repo - parameters = {} - if br: - parameters['sha'] = br - if path: - parameters['path'] = path - url += '?' + urlencode(parameters) - headers = { - 'Accept': 'application/vnd.github.quicksilver-preview+json', - } - if token: - headers['Authorization'] = f'token {token}' - - data = await cache.get_json(url, headers = headers) - - if use_max_tag: - tags = [add_commit_name( - ref['ref'].split('/', 2)[-1], - ref['object']['sha'] if use_commit_name else None, - ) for ref in data] + if conf.get('use_latest_tag', False): + owner, reponame = repo.split('/') + j = await query_graphql( + cache = cache, + token = token, + query = QUERY_LATEST_TAG, + variables = { + 'owner': owner, + 'name': reponame, + 'query': conf.get('query'), + 'includeCommitName': use_commit_name, + }, + ) + refs = j['data']['repository']['refs']['edges'] + if not refs: + raise GetVersionError('no tag found') + ref = next( + add_commit_name( + ref['node']['name'], + ref['node']['target']['oid'] if use_commit_name else None, + ) + for ref in refs + ) + return ref + elif conf.get('use_latest_release', False): + data = await query_rest( + cache = cache, + token = token, + url = GITHUB_LATEST_RELEASE % repo, + ) + if 'tag_name' not in data: + raise GetVersionError('No release found in upstream repository.') + tag = data['tag_name'] + return tag + elif conf.get('use_max_tag', False): + data = await query_rest( + cache = cache, + token = token, + url = GITHUB_MAX_TAG % repo, + ) + tags = [ + add_commit_name( + ref['ref'].split('/', 2)[-1], + ref['object']['sha'] if use_commit_name else None, + ) + for ref in data + ] if not tags: raise GetVersionError('No tag found in upstream repository.') return tags - - if use_latest_release: - if 'tag_name' not in data: - raise GetVersionError('No release found in upstream repository.') - version = data['tag_name'] - else: + br = conf.get('branch') + path = conf.get('path') + parameters = {} + if br is not None: + parameters['sha'] = br + if path is not None: + parameters['path'] = path + data = await query_rest( + cache = cache, + token = token, + url = GITHUB_URL % repo, + parameters = parameters, + ) # YYYYMMDD.HHMMSS version = add_commit_name( data[0]['commit']['committer']['date'] \ .rstrip('Z').replace('-', '').replace(':', '').replace('T', '.'), data[0]['sha'] if use_commit_name else None, ) - - return version + return version def check_ratelimit(exc, name): res = exc.response