# Source code for peony.oauth
# -*- coding: utf-8 -*-
import asyncio
import base64
import hmac
import random
import string
import time
import urllib.parse
from abc import ABC, abstractmethod
from hashlib import sha1
import aiohttp
from . import __version__, utils
class PeonyHeaders(ABC, dict):
    """
    Dynamic headers for Peony

    This is the base class of :class:`OAuth1Headers` and
    :class:`OAuth2Headers`.

    Header names are normalized with :meth:`str.title` when they are set
    through item assignment, so ``self['user-agent']`` and
    ``self['User-Agent']`` refer to the same entry.

    Parameters
    ----------
    compression : bool, optional
        If set to True the client will be able to receive compressed
        responses else it should not happen unless you provide the
        corresponding header when you make a request. Defaults to True.
    user_agent : str, optional
        The user agent set in the headers. Defaults to
        "peony v{version number}"
    headers : dict
        dict containing custom headers
    """

    def __init__(self, compression=True, user_agent=None, headers=None):
        """ Add a nice User-Agent """
        super().__init__()

        if user_agent is None:
            self['User-Agent'] = "peony v" + __version__
        else:
            self['User-Agent'] = user_agent

        if compression:
            self['Accept-Encoding'] = "deflate, gzip"

        if headers is not None:
            # go through __setitem__ so that the names get normalized
            for key, value in headers.items():
                self[key] = value

    def __setitem__(self, key, value):
        # Normalize header names so that membership tests such as
        # ``'Authorization' in self`` don't depend on the caller's casing
        super().__setitem__(key.title(), value)

    async def prepare_request(self, method, url,
                              headers=None,
                              skip_params=False,
                              proxy=None,
                              **kwargs):
        """
        prepare all the arguments for the request

        Parameters
        ----------
        method : str
            HTTP method used by the request
        url : str
            The url to request
        headers : dict, optional
            Additional headers
        proxy : str
            proxy of the request
        skip_params : bool
            Don't use the parameters to sign the request
        **kwargs
            Other arguments of the request, returned with the signed
            arguments merged in

        Returns
        -------
        dict
            Parameters of the request correctly formatted
        """
        # POST requests send their parameters as 'data', other methods
        # as 'params'
        if method.lower() == "post":
            key = 'data'
        else:
            key = 'params'

        # with skip_params the parameters stay in kwargs and are passed
        # through untouched instead of being handed to sign()
        if key in kwargs and not skip_params:
            request_params = {key: kwargs.pop(key)}
        else:
            request_params = {}

        request_params.update(dict(method=method.upper(), url=url))

        # sign() is a plain method in OAuth1Headers and a coroutine function
        # in OAuth2Headers; utils.execute presumably handles both cases
        coro = self.sign(**request_params, skip_params=skip_params,
                         headers=headers)
        request_params['headers'] = await utils.execute(coro)
        request_params['proxy'] = proxy

        kwargs.update(request_params)

        return kwargs

    def _user_headers(self, headers=None):
        """
        Merge the user supplied ``headers`` into a copy of these headers

        User supplied header names are title-cased, like the names set
        through item assignment, so that they replace the default header
        instead of being sent alongside it. If an ``Authorization`` header
        is already set, the user can't override it, whatever the
        capitalization they used.

        Parameters
        ----------
        headers : dict, optional
            Additional headers

        Returns
        -------
        dict
            A new dict; this instance is left unchanged
        """
        merged = self.copy()
        if headers is None:
            return merged

        protect_authorization = bool(merged.get('Authorization', False))
        for key, value in headers.items():
            name = key.title()
            if protect_authorization and name == 'Authorization':
                continue
            merged[name] = value

        return merged

    @abstractmethod
    def sign(self, *args, headers=None, **kwargs):
        """
        sign, that is, generate the `Authorization` headers before making
        a request

        :meth:`prepare_request` calls it with the keyword arguments
        ``method``, ``url``, ``skip_params``, ``headers`` and possibly
        ``data`` or ``params``. It must return the headers of the request,
        either directly or from a coroutine.
        """
class OAuth1Headers(PeonyHeaders):
    """
    Dynamic headers implementing OAuth1

    Parameters
    ----------
    consumer_key : str
        Your consumer key
    consumer_secret : str
        Your consumer secret
    access_token : str, optional
        Your access token
    access_token_secret : str, optional
        Your access token secret
    compression : bool, optional
        Accept compressed responses, see :class:`PeonyHeaders`
    user_agent : str, optional
        The user agent, see :class:`PeonyHeaders`
    headers : dict, optional
        Other headers
    """

    def __init__(self, consumer_key, consumer_secret,
                 access_token=None, access_token_secret=None,
                 compression=True, user_agent=None, headers=None):
        super().__init__(compression, user_agent, headers)

        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.access_token = access_token
        self.access_token_secret = access_token_secret
        # characters used to build the oauth_nonce
        self.alphabet = string.ascii_letters + string.digits

    @staticmethod
    def _default_content_type(skip_params):
        # Content-Type used for a body when the user didn't set one
        if skip_params:
            return "application/octet-stream"
        else:
            return "application/x-www-form-urlencoded"

    def gen_nonce(self, length=32):
        """
        Generate a random ``oauth_nonce``

        Parameters
        ----------
        length : int, optional
            Number of characters of the nonce. Defaults to 32.

        Returns
        -------
        str
            ``length`` characters drawn from ``self.alphabet``
        """
        # local import: the module-level imports don't include secrets
        import secrets
        return ''.join(secrets.choice(self.alphabet) for _ in range(length))

    def sign(self, method='GET', url=None,
             data=None,
             params=None,
             skip_params=False,
             headers=None,
             **kwargs):
        """
        Generate the OAuth1 ``Authorization`` header of a request

        Parameters
        ----------
        method : str, optional
            HTTP method of the request. Defaults to ``'GET'``.
        url : str
            The url to request
        data : dict, optional
            Body parameters; when not empty they are signed instead of
            ``params`` (unless ``skip_params`` is set)
        params : dict, optional
            Query parameters, signed unless ``skip_params`` is set
        skip_params : bool, optional
            Don't use the parameters to sign the request
        headers : dict, optional
            Additional headers
        **kwargs
            Ignored

        Returns
        -------
        dict
            The headers of the request, including ``Authorization``
        """
        headers = self._user_headers(headers)

        if data:
            if 'Content-Type' not in headers:
                default = self._default_content_type(skip_params)
                headers['Content-Type'] = default
            params = data

        oauth = {
            'oauth_consumer_key': self.consumer_key,
            'oauth_nonce': self.gen_nonce(),
            'oauth_signature_method': 'HMAC-SHA1',
            'oauth_timestamp': str(int(time.time())),
            'oauth_version': '1.0'
        }

        if self.access_token is not None:
            oauth['oauth_token'] = self.access_token

        oauth['oauth_signature'] = self.gen_signature(method=method, url=url,
                                                      params=params,
                                                      skip_params=skip_params,
                                                      oauth=oauth)

        # NOTE(review): `quote` is a module-level helper not visible in this
        # chunk; presumably a percent-encoder with no safe characters
        headers['Authorization'] = "OAuth " + ", ".join(
            quote(key) + '="' + quote(value) + '"'
            for key, value in sorted(oauth.items())
        )

        return headers

    def gen_signature(self, method, url, params, skip_params, oauth):
        """
        Compute the HMAC-SHA1 ``oauth_signature`` of a request

        Parameters
        ----------
        method : str
            HTTP method of the request
        url : str
            The url to request
        params : dict or None
            Parameters of the request; this dict is not modified
        skip_params : bool
            Only sign the ``oauth`` parameters
        oauth : dict
            The ``oauth_*`` parameters of the request

        Returns
        -------
        str
            The base64 encoded signature
        """
        if params is None or skip_params:
            signed_params = oauth
        else:
            # build a new dict so the oauth_* keys never end up in the
            # caller's request parameters
            signed_params = {**params, **oauth}

        pairs = []
        for key, value in sorted(signed_params.items()):
            if key == "q":
                # the search query keeps a few characters unescaped, unlike
                # the other parameters; presumably this mirrors how the
                # query is sent (kept as-is from the original code)
                encoded_value = urllib.parse.quote(value, safe="$:!?/()'*@")
            else:
                encoded_value = quote(value)
            pairs.append(quote(key) + '=' + encoded_value)
        param_string = '&'.join(pairs)

        base_string = (method.upper() + '&' + quote(url) + '&'
                       + quote(param_string))

        key = quote(self.consumer_secret).encode() + b"&"
        if self.access_token_secret is not None:
            key += quote(self.access_token_secret).encode()

        digest = hmac.new(key, base_string.encode(), sha1).digest()
        # b64encode never emits newlines, so no stripping is needed
        return base64.b64encode(digest).decode()
class OAuth2Headers(PeonyHeaders):
    """
    Dynamic headers implementing OAuth2

    Parameters
    ----------
    consumer_key : str
        Your consumer key
    consumer_secret : str
        Your consumer secret
    client : .client.BasePeonyClient
        The client to authenticate
    bearer_token : :obj:`str`, optional
        Your bearer_token
    compression : bool, optional
        Accept compressed responses, see :class:`PeonyHeaders`
    user_agent : str, optional
        The user agent, see :class:`PeonyHeaders`
    headers : dict, optional
        Other headers
    """

    def __init__(self, consumer_key, consumer_secret, client,
                 bearer_token=None, compression=True, user_agent=None,
                 headers=None):
        super().__init__(compression, user_agent, headers)

        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.client = client
        self.basic_authorization = self.get_basic_authorization()

        # set while refresh_token() is requesting a bearer token
        # (a new Event starts cleared)
        self._refreshing = asyncio.Event()

        if bearer_token is not None:
            self.token = bearer_token

    async def sign(self, url=None, *args, headers=None, **kwargs):
        """
        Make sure a bearer token is available and return the headers

        Parameters
        ----------
        url : str, optional
            The url to request
        headers : dict, optional
            Additional headers
        *args, **kwargs
            Ignored

        Returns
        -------
        dict
            The headers of the request
        """
        if url == self._invalidate_token.url():
            # drop the bearer token so that _user_headers keeps the
            # Authorization header passed by the caller (invalidate_token
            # passes basic_authorization)
            if 'Authorization' in self:
                del self.token
        elif 'Authorization' not in self:
            await self.refresh_token()

        return self._user_headers(headers)

    def get_basic_authorization(self):
        """
        Headers used to authenticate with the consumer credentials

        Returns
        -------
        dict
            ``Authorization`` (Basic scheme, base64 of
            ``quote(consumer_key):quote(consumer_secret)``) and
            ``Content-Type`` headers
        """
        creds = quote(self.consumer_key), quote(self.consumer_secret)
        keys = ':'.join(creds).encode('utf-8')
        auth = "Basic " + base64.b64encode(keys).decode('utf-8')
        return {'Authorization': auth,
                'Content-Type': "application/x-www-form-urlencoded;"
                                "charset=UTF-8"}

    @property
    def token(self):
        """ The bearer token, or None if no Authorization header is set """
        if 'Authorization' in self:
            return self['Authorization'][len("Bearer "):]
        return None

    @token.setter
    def token(self, access_token):
        self['Authorization'] = "Bearer " + access_token

    @token.deleter
    def token(self):
        del self['Authorization']

    @property
    def _invalidate_token(self):
        # endpoint of the client used to invalidate the bearer token
        return self.client['api', '', ''].oauth2.invalidate_token

    async def invalidate_token(self):
        """
        Invalidate the current bearer token

        The token is kept if the request fails.

        Raises
        ------
        RuntimeError
            If there is no token to invalidate
        """
        if 'Authorization' not in self:
            raise RuntimeError('There is no token to invalidate')

        token = self.token
        try:
            request = self._invalidate_token.post
            data = RawFormData({'access_token': token}, quote_fields=False)
            await request(_data=data, _headers=self.basic_authorization)
        except Exception:
            # sign() removed the token for this request, restore it
            self.token = token
            raise

    async def refresh_token(self):
        """
        Request a new bearer token using the client credentials

        NOTE(review): if a refresh is already in progress this returns
        immediately (waiting on a set Event doesn't block). The nested call
        made by sign() while the token request itself is prepared presumably
        relies on that. Other coroutines therefore don't wait for the new
        token either; confirm that this is intended.
        """
        if self._refreshing.is_set():
            return await self._refreshing.wait()

        self._refreshing.set()
        try:
            request = self.client['api', "", ""].oauth2.token.post
            token = await request(grant_type="client_credentials",
                                  _headers=self.basic_authorization,
                                  _oauth2_pass=True)
            self.token = token['access_token']
        finally:
            # always clear the flag: if it stayed set after a failed
            # request, every later call would return early and no token
            # would ever be requested again
            self._refreshing.clear()

    async def prepare_request(self, *args, oauth2_pass=False, **kwargs):
        """
        prepare all the arguments for the request

        Parameters
        ----------
        oauth2_pass : bool
            For oauth2 authentication only (don't use it)
        *args, **kwargs
            Passed to :meth:`PeonyHeaders.prepare_request`

        Returns
        -------
        dict
            Parameters of the request correctly formatted
        """
        if not oauth2_pass:
            await self.sign()

        return await super().prepare_request(*args, **kwargs)
class RawFormData(aiohttp.FormData):
    """
    Form data whose url encoded body is built from the raw field values

    The fields are sorted by name and their values are joined verbatim,
    without percent-encoding.
    """

    def _gen_form_urlencoded(self):
        # each entry of self._fields unpacks to (type_options, headers, value)
        fields = sorted(self._fields, key=lambda field: field[0]['name'])
        data = "&".join("%s=%s" % (type_options['name'], value)
                        for type_options, _, value in fields)

        charset = self._charset if self._charset is not None else 'utf-8'
        content_type = "application/x-www-form-urlencoded;charset=" + charset
        # encode the body with the charset announced in the Content-Type
        return aiohttp.payload.BytesPayload(data.encode(charset),
                                            content_type=content_type)