add httptoken global, update test

This commit is contained in:
Maud LAURENT 2021-05-10 17:12:39 +02:00
parent 4aa5078169
commit 0cb999dc14
7 changed files with 40 additions and 43 deletions

View file

@ -7,4 +7,4 @@ from .util import (
AsyncCache, KeyManager, GetVersionError,
)
from .sortversion import sort_version_keys
from .ctxvars import tries, proxy, user_agent
from .ctxvars import tries, proxy, user_agent, httptoken

View file

@ -13,3 +13,4 @@ DEFAULT_USER_AGENT = f'lilydjwg/nvchecker {__version__}'
tries = ContextVar('tries', default=1)
proxy: ContextVar[Optional[str]] = ContextVar('proxy', default=None)
user_agent = ContextVar('user_agent', default=DEFAULT_USER_AGENT)
httptoken = ContextVar('httptoken', default=None)

View file

@ -5,7 +5,7 @@ import structlog
from typing import Optional, Dict, Mapping
import json as _json
from ..ctxvars import tries, proxy, user_agent
from ..ctxvars import tries, proxy, user_agent, httptoken
logger = structlog.get_logger(logger_name=__name__)
@ -65,9 +65,11 @@ class BaseSession:
t = tries.get()
p = proxy.get()
ua = user_agent.get()
httpt = httptoken.get()
headers = headers.copy()
headers.setdefault('User-Agent', ua)
headers.setdefault('Authorization', httpt)
for i in range(1, t+1):
try:

View file

@ -21,6 +21,7 @@ from .httpclient import session
from .ctxvars import tries as ctx_tries
from .ctxvars import proxy as ctx_proxy
from .ctxvars import user_agent as ctx_ua
from .ctxvars import httptoken as ctx_httpt
logger = structlog.get_logger(logger_name=__name__)
@ -224,7 +225,13 @@ class FunctionWorker(BaseWorker):
ua = entry.get('user_agent', None)
if ua is not None:
ctx_ua.set(ua)
httpt = entry.get('httptoken_'+name, None)
if httpt is not None:
ctx_httpt.set(httpt)
else:
httpt = self.keymanager.get_key('httptoken_'+name)
if httpt is not None:
ctx_httpt.set(httpt)
try:
async with self.task_sem:
version = await self.func(
@ -236,6 +243,7 @@ class FunctionWorker(BaseWorker):
except Exception as e:
await self.result_q.put(RawResult(name, e, entry))
class GetVersionError(Exception):
'''An error occurred while getting version information.

View file

@ -1,45 +1,31 @@
# MIT licensed
# Copyright (c) 2020 Ypsilik <tt2laurent.maud@gmail.com>, et al.
# Copyright (c) 2013-2020 lilydjwg <lilydjwg@gmail.com>, et al.
import re
import sre_constants
from lxml import html, etree
from nvchecker.api import (
VersionResult, Entry, KeyManager,
TemporaryError, session
)
from nvchecker.api import session, GetVersionError
async def get_version(name, conf, **kwargs):
return await get_version_real(name, conf, **kwargs)
async def get_version(name, conf, *, cache, **kwargs):
key = tuple(sorted(conf.items()))
return await cache.get(key, get_version_impl)
async def get_version_real(
name: str, conf: Entry, *, keymanager: KeyManager,
**kwargs,
) -> VersionResult:
async def get_version_impl(info):
conf = dict(info)
encoding = conf.get('encoding', 'latin1')
encoding = conf.get('encoding', 'latin1')
# Load token from config
token = conf.get('token')
# Load token from keyman
if token is None:
key_name = 'htmlparser_' + name
token = keymanager.get_key(key_name)
# Set private token if token exists.
headers = {}
if token:
headers["Authorization"] = token
data = await session.get(conf.get('url'), headers=headers)
body = html.fromstring(data.body.decode(encoding))
try:
checkxpath = body.xpath(conf.get('xpath'))
except etree.XPathEvalError as e:
raise GetVersionError('bad xpath', exc_info=e)
try:
version = body.xpath(conf.get('xpath'))
except ValueError:
if not conf.get('missing_ok', False):
raise GetVersionError('version string not found.')
return version
res = await session.get(conf['url'])
body = html.fromstring(res.body.decode(encoding))
try:
checkxpath = body.xpath(conf.get('xpath'))
except etree.XPathEvalError as e:
raise GetVersionError('bad xpath', exc_info=e)
try:
version = body.xpath(conf.get('xpath'))
except ValueError:
if not conf.get('missing_ok', False):
raise GetVersionError('version string not found.')
return version

View file

@ -17,7 +17,7 @@ async def test_get_version_withtoken(get_version, httpbin):
assert await get_version("unifiedremote", {
"source": "httpheader",
"url": httpbin.url + "/basic-auth/username/superpassword",
"token": "Basic dXNlcm5hbWU6c3VwZXJwYXNzd29yZA==",
"httptoken_unifiedremote": "Basic dXNlcm5hbWU6c3VwZXJwYXNzd29yZA==",
"header": "server",
"regex": r'([0-9.]+)*',
}) != None

View file

@ -47,7 +47,7 @@ async def test_regex_with_tokenBasic(get_version, httpbin):
assert await get_version("example", {
"source": "regex",
"url": httpbin.url + "/basic-auth/username/superpassword",
"token": "Basic dXNlcm5hbWU6c3VwZXJwYXNzd29yZA==",
"httptoken_example": "Basic dXNlcm5hbWU6c3VwZXJwYXNzd29yZA==",
"regex": r'"user":"([a-w]+)"',
}) == "username"
@ -55,6 +55,6 @@ async def test_regex_with_tokenBearer(get_version, httpbin):
assert await get_version("example", {
"source": "regex",
"url": httpbin.url + "/bearer",
"token": "Bearer username:password",
"httptoken_example": "Bearer username:password",
"regex": r'"token":"([a-w]+):.*"',
}) == "username"