Source code for peony.oauth
# -*- coding: utf-8 -*-
import asyncio
import base64
import hmac
import random
import string
import time
import urllib.parse
from abc import ABC, abstractmethod
from hashlib import sha1
import aiohttp
from . import __version__, utils
class PeonyHeaders(ABC, dict):
    """
    Dynamic headers for Peony

    This is the base class of :class:`OAuth1Headers` and
    :class:`OAuth2Headers`.

    Header names are normalised with :meth:`str.title` when they are
    stored on the instance, so ``headers["user-agent"] = ...`` ends up
    under the ``"User-Agent"`` key.

    Parameters
    ----------
    compression : bool, optional
        If set to True the client will be able to receive compressed
        responses else it should not happen unless you provide the
        corresponding header when you make a request. Defaults to True.
    user_agent : str, optional
        The user agent set in the headers. Defaults to
        "peony v{version number}"
    headers : dict
        dict containing custom headers
    """

    def __init__(self, compression=True, user_agent=None, headers=None):
        """Add a nice User-Agent"""
        super().__init__()

        if user_agent is None:
            self["User-Agent"] = "peony v" + __version__
        else:
            self["User-Agent"] = user_agent

        if compression:
            self["Accept-Encoding"] = "deflate, gzip"

        if headers is not None:
            # Assign one by one: dict.update() would bypass __setitem__
            # and therefore skip the key normalisation.
            for key, value in headers.items():
                self[key] = value

    def __setitem__(self, key, value):
        """Store ``value`` under the title-cased form of ``key``."""
        super().__setitem__(key.title(), value)

    async def prepare_request(
        self, method, url, headers=None, skip_params=False, proxy=None, **kwargs
    ):
        """
        prepare all the arguments for the request

        Parameters
        ----------
        method : str
            HTTP method used by the request
        url : str
            The url to request
        headers : dict, optional
            Additional headers
        proxy : str
            proxy of the request
        skip_params : bool
            Don't use the parameters to sign the request
        **kwargs
            Any other argument of the request, returned untouched apart
            from the keys set by this method

        Returns
        -------
        dict
            Parameters of the request correctly formatted
        """
        # POST requests carry their parameters in "data", every other
        # method in "params".
        key = "data" if method.lower() == "post" else "params"

        request_params = {}
        if key in kwargs and not skip_params:
            # Only the parameters that take part in the signature are
            # handed over to sign(); with skip_params they stay in kwargs.
            request_params[key] = kwargs.pop(key)

        request_params.update(method=method.upper(), url=url)

        # sign() is handed to utils.execute and awaited, so subclasses may
        # implement it either as a plain method or as a coroutine.
        coro = self.sign(**request_params, skip_params=skip_params, headers=headers)
        request_params["headers"] = await utils.execute(coro)
        request_params["proxy"] = proxy

        kwargs.update(request_params)
        return kwargs

    def _user_headers(self, headers=None):
        """
        Make sure the user doesn't override the Authorization header

        Parameters
        ----------
        headers : dict, optional
            Headers given by the user for a single request

        Returns
        -------
        dict
            A copy of the stored headers updated with the user's headers.
            When an Authorization header is already stored, any user
            supplied key spelling "authorization" (in any letter case)
            is ignored.
        """
        merged = self.copy()

        if headers is not None:
            protect_auth = bool(merged.get("Authorization", False))

            for key, value in headers.items():
                # `merged` is a plain dict (dict.copy), so keys are NOT
                # title-cased here: the comparison has to be
                # case-insensitive or "authorization" would slip through.
                if protect_auth and key.lower() == "authorization":
                    continue

                merged[key] = value

        return merged

    @abstractmethod
    def sign(self, *args, headers=None, **kwargs):
        """
        sign, that is, generate the `Authorization` headers before making
        a request
        """
class OAuth1Headers(PeonyHeaders):
    """
    Dynamic headers implementing OAuth1

    Parameters
    ----------
    consumer_key : str
        Your consumer key
    consumer_secret : str
        Your consumer secret
    access_token : str
        Your access token
    access_token_secret : str
        Your access token secret
    compression : bool, optional
        Forwarded to :class:`PeonyHeaders`
    user_agent : str, optional
        Forwarded to :class:`PeonyHeaders`
    headers : dict, optional
        Other headers, forwarded to :class:`PeonyHeaders`
    """

    def __init__(
        self,
        consumer_key,
        consumer_secret,
        access_token=None,
        access_token_secret=None,
        compression=True,
        user_agent=None,
        headers=None,
    ):
        super().__init__(compression, user_agent, headers)

        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.access_token = access_token
        self.access_token_secret = access_token_secret

        # characters a nonce is drawn from, see gen_nonce()
        self.alphabet = string.ascii_letters + string.digits

    @staticmethod
    def _default_content_type(skip_params):
        """Content-Type used when the request has a body but no explicit type."""
        if skip_params:
            return "application/octet-stream"
        else:
            return "application/x-www-form-urlencoded"

    def gen_nonce(self):
        """
        Generate the ``oauth_nonce`` of a request

        Returns
        -------
        str
            32 characters picked at random from :attr:`alphabet`
        """
        return "".join(random.choice(self.alphabet) for _ in range(32))

    def sign(
        self,
        method="GET",
        url=None,
        data=None,
        params=None,
        skip_params=False,
        headers=None,
        **kwargs
    ):
        """
        Build the headers of a request, including the OAuth1
        ``Authorization`` header

        Parameters
        ----------
        method : str
            HTTP method of the request
        url : str
            url of the request
        data : dict, optional
            Body parameters; when truthy they are the ones being signed
        params : dict, optional
            Query parameters, signed when ``data`` is empty
        skip_params : bool
            Don't use the parameters to sign the request
        headers : dict, optional
            Additional headers given by the user

        Returns
        -------
        dict
            The headers to send with the request
        """
        headers = self._user_headers(headers)

        if data:
            if "Content-Type" not in headers:
                headers["Content-Type"] = self._default_content_type(skip_params)

            params = data.copy()
        elif params:
            params = params.copy()

        oauth = {
            "oauth_consumer_key": self.consumer_key,
            "oauth_nonce": self.gen_nonce(),
            "oauth_signature_method": "HMAC-SHA1",
            "oauth_timestamp": str(int(time.time())),
            "oauth_version": "1.0",
        }

        if self.access_token is not None:
            oauth["oauth_token"] = self.access_token

        # The signature is computed before it is added to `oauth`, so it
        # never takes part in its own base string.
        oauth["oauth_signature"] = self.gen_signature(
            method=method, url=url, params=params, skip_params=skip_params, oauth=oauth
        )

        fields = (
            quote(key) + '="' + quote(value) + '"'
            for key, value in sorted(oauth.items(), key=lambda item: item[0])
        )
        headers["Authorization"] = "OAuth " + ", ".join(fields)

        return headers

    def gen_signature(self, method, url, params, skip_params, oauth):
        """
        Compute the HMAC-SHA1 signature of a request

        Parameters
        ----------
        method : str
            HTTP method of the request
        url : str
            url of the request
        params : dict or None
            Request parameters to sign; left unmodified
        skip_params : bool
            Sign the oauth parameters only
        oauth : dict
            The ``oauth_*`` parameters of the request

        Returns
        -------
        str
            The base64 encoded signature
        """
        if params is None or skip_params:
            signed = oauth
        else:
            # Merge into a copy: updating `params` in place would leak the
            # oauth_* entries into the mapping owned by the caller.
            signed = params.copy()
            signed.update(oauth)

        pieces = []
        for key, value in sorted(signed.items(), key=lambda item: item[0]):
            if key == "q":
                # "q" is encoded with a wider set of safe characters than
                # the other values.
                # NOTE(review): presumably mirrors how the search query is
                # encoded on the wire — confirm against the request code.
                encoded_value = urllib.parse.quote(value, safe="$:!?/()'*@")
            else:
                encoded_value = quote(value)

            pieces.append(quote(key) + "=" + encoded_value)

        base_string = (
            method.upper() + "&" + quote(url) + "&" + quote("&".join(pieces))
        )

        # signing key: "<consumer secret>&<token secret>", the token
        # secret being empty when there is none
        key = quote(self.consumer_secret).encode() + b"&"
        if self.access_token_secret is not None:
            key += quote(self.access_token_secret).encode()

        digest = hmac.new(key, base_string.encode(), sha1).digest()
        return base64.b64encode(digest).decode().rstrip("\n")
class OAuth2Headers(PeonyHeaders):
    """
    Dynamic headers implementing OAuth2

    Parameters
    ----------
    consumer_key : str
        Your consumer key
    consumer_secret : str
        Your consumer secret
    client : .client.BasePeonyClient
        The client to authenticate
    bearer_token : :obj:`str`, optional
        Your bearer_token
    compression : bool, optional
        Forwarded to :class:`PeonyHeaders`
    user_agent : str, optional
        Forwarded to :class:`PeonyHeaders`
    headers : dict, optional
        Other headers, forwarded to :class:`PeonyHeaders`
    """

    def __init__(
        self,
        consumer_key,
        consumer_secret,
        client,
        bearer_token=None,
        compression=True,
        user_agent=None,
        headers=None,
    ):
        super().__init__(compression, user_agent, headers)

        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.client = client
        self.basic_authorization = self.get_basic_authorization()

        # set while a token request started by refresh_token() is in
        # flight; a new Event starts unset
        self._refreshing = asyncio.Event()

        if bearer_token is not None:
            self.token = bearer_token

    async def sign(self, url=None, *args, headers=None, **kwargs):
        """
        Build the headers of a request

        The stored token is dropped when the request targets the token
        invalidation endpoint, and a token is requested when none is
        stored.

        Parameters
        ----------
        url : str, optional
            url of the request
        headers : dict, optional
            Additional headers given by the user

        Returns
        -------
        dict
            The headers to send with the request
        """
        if url == self._invalidate_token.url():
            del self.token
        elif "Authorization" not in self:
            await self.refresh_token()

        return self._user_headers(headers)

    def get_basic_authorization(self):
        """
        Build the headers carrying the consumer credentials

        Returns
        -------
        dict
            ``Authorization`` (Basic scheme) and ``Content-Type`` headers
        """
        creds = quote(self.consumer_key), quote(self.consumer_secret)
        keys = ":".join(creds).encode("utf-8")
        auth = "Basic " + base64.b64encode(keys).decode("utf-8")

        return {
            "Authorization": auth,
            "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8",
        }

    @property
    def token(self):
        """The bearer token currently stored, None if there is none."""
        if "Authorization" in self:
            return self["Authorization"][len("Bearer ") :]

    @token.setter
    def token(self, access_token):
        self["Authorization"] = "Bearer " + access_token

    @token.deleter
    def token(self):
        del self["Authorization"]

    @property
    def _invalidate_token(self):
        """Endpoint object used to invalidate a bearer token."""
        return self.client["api", "", ""].oauth2.invalidate_token

    async def invalidate_token(self):
        """
        Invalidate the stored bearer token

        Raises
        ------
        RuntimeError
            If there is no token to invalidate
        """
        if "Authorization" not in self:
            raise RuntimeError("There is no token to invalidate")

        # kept so the token can be restored if the request fails (sign()
        # deletes it when it sees the invalidation url)
        token = self.token

        try:
            request = self._invalidate_token.post
            data = RawFormData({"access_token": token}, quote_fields=False)
            await request(_data=data, _headers=self.basic_authorization)
        except Exception:
            self.token = token
            raise

    async def refresh_token(self):
        """
        Request a new bearer token and store it

        Returns immediately when a token request is already in flight.
        """
        if self._refreshing.is_set():
            # The event is set, so wait() returns right away.
            # NOTE(review): this looks relied upon by the token request
            # itself, whose signing presumably ends up back here while no
            # token is stored — confirm before making callers block.
            return await self._refreshing.wait()

        self._refreshing.set()
        try:
            request = self.client["api", "", ""].oauth2.token.post
            token = await request(
                grant_type="client_credentials",
                _headers=self.basic_authorization,
                _oauth2_pass=True,
            )

            self.token = token["access_token"]
        finally:
            # Always reset the flag: if it stayed set after a failed
            # request every later call would return without a token.
            self._refreshing.clear()

    async def prepare_request(self, *args, oauth2_pass=False, **kwargs):
        """
        prepare all the arguments for the request

        Parameters
        ----------
        oauth2_pass : bool
            For oauth2 authentication only (don't use it)

        Returns
        -------
        dict
            Parameters of the request correctly formatted
        """
        if not oauth2_pass:
            await self.sign()

        return await super().prepare_request(*args, **kwargs)
class RawFormData(aiohttp.FormData):
    """Form data whose urlencoded body is built from the raw field values."""

    def _gen_form_urlencoded(self):
        """
        Build the urlencoded payload of the form

        Fields are ordered by name and written as ``name=value`` pairs
        joined with ``&``; names and values are used as they are.

        Returns
        -------
        aiohttp.payload.BytesPayload
            The encoded body along with its content type
        """
        ordered_fields = sorted(self._fields, key=lambda field: field[0]["name"])

        body = "&".join(
            "%s=%s" % (type_options["name"], value)
            for type_options, _, value in ordered_fields
        )

        if self._charset is not None:
            charset = self._charset
        else:
            charset = "utf-8"

        return aiohttp.payload.BytesPayload(
            body.encode(),
            content_type="application/x-www-form-urlencoded;charset=" + charset,
        )