Source code for peony.oauth

# -*- coding: utf-8 -*-

import asyncio
import base64
import hmac
import random
import secrets
import string
import time
import urllib.parse
from abc import ABC, abstractmethod
from hashlib import sha1

import aiohttp

from . import __version__, utils


def quote(s):
    """
    Percent-encode ``s`` with an empty "safe" set.

    Unlike :func:`urllib.parse.quote`'s default, ``/`` is escaped too, so
    only the characters urllib never escapes are left as-is.
    """
    # the second positional argument of urllib.parse.quote is ``safe``
    return urllib.parse.quote(s, "")
class PeonyHeaders(ABC, dict):
    """
    Dynamic headers for Peony

    This is the base class of :class:`OAuth1Headers` and
    :class:`OAuth2Headers`.

    Header names are stored title-cased (``"content-type"`` is stored as
    ``"Content-Type"``).

    Parameters
    ----------
    compression : bool, optional
        If set to True the client will be able to receive compressed
        responses else it should not happen unless you provide the
        corresponding header when you make a request. Defaults to True.
    user_agent : str, optional
        The user agent set in the headers. Defaults to
        "peony v{version number}"
    headers : dict
        dict containing custom headers
    """

    def __init__(self, compression=True, user_agent=None, headers=None):
        """Add a nice User-Agent"""
        super().__init__()

        if user_agent is None:
            self["User-Agent"] = "peony v" + __version__
        else:
            self["User-Agent"] = user_agent

        if compression:
            self["Accept-Encoding"] = "deflate, gzip"

        if headers is not None:
            # go through __setitem__ so every custom key is title-cased
            for key, value in headers.items():
                self[key] = value

    def __setitem__(self, key, value):
        super().__setitem__(key.title(), value)

    async def prepare_request(
        self, method, url, headers=None, skip_params=False, proxy=None, **kwargs
    ):
        """
        prepare all the arguments for the request

        Parameters
        ----------
        method : str
            HTTP method used by the request
        url : str
            The url to request
        headers : dict, optional
            Additional headers
        proxy : str
            proxy of the request
        skip_params : bool
            Don't use the parameters to sign the request
        **kwargs
            Other arguments of the request. ``data`` (for POST) or
            ``params`` (otherwise) is passed to :meth:`sign` unless
            ``skip_params`` is True, in which case it is left in the
            returned arguments untouched.

        Returns
        -------
        dict
            Parameters of the request correctly formatted
        """
        key = "data" if method.lower() == "post" else "params"

        if key in kwargs and not skip_params:
            request_params = {key: kwargs.pop(key)}
        else:
            request_params = {}

        request_params.update(dict(method=method.upper(), url=url))

        # sign() is synchronous for OAuth1 and a coroutine for OAuth2;
        # utils.execute presumably handles both — confirm in utils
        coro = self.sign(**request_params, skip_params=skip_params, headers=headers)
        request_params["headers"] = await utils.execute(coro)
        request_params["proxy"] = proxy

        kwargs.update(request_params)
        return kwargs

    def _user_headers(self, headers=None):
        """
        Merge the user's headers into a copy of these headers.

        User header names are title-cased like the stored ones, so that
        e.g. ``"content-type"`` and ``"Content-Type"`` are the same header.
        If these headers already hold a (truthy) ``Authorization`` value,
        the user cannot override it, whatever the case of the key.

        Parameters
        ----------
        headers : dict, optional
            Headers provided by the user for a single request

        Returns
        -------
        dict
            A plain ``dict`` copy of the headers with the user's headers
            applied
        """
        h = self.copy()

        if headers is not None:
            protect_authorization = bool(h.get("Authorization", False))
            for key, value in headers.items():
                key = key.title()
                if protect_authorization and key == "Authorization":
                    continue
                h[key] = value

        return h

    @abstractmethod
    def sign(self, *args, headers=None, **kwargs):
        """
        sign, that is, generate the `Authorization` headers before making
        a request
        """
class OAuth1Headers(PeonyHeaders):
    """
    Dynamic headers implementing OAuth1

    Parameters
    ----------
    consumer_key : str
        Your consumer key
    consumer_secret : str
        Your consumer secret
    access_token : str
        Your access token
    access_token_secret : str
        Your access token secret
    **kwargs
        Other headers
    """

    def __init__(
        self,
        consumer_key,
        consumer_secret,
        access_token=None,
        access_token_secret=None,
        compression=True,
        user_agent=None,
        headers=None,
    ):
        super().__init__(compression, user_agent, headers)

        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.access_token = access_token
        self.access_token_secret = access_token_secret
        # characters used to build the oauth_nonce
        self.alphabet = string.ascii_letters + string.digits

    @staticmethod
    def _default_content_type(skip_params):
        if skip_params:
            return "application/octet-stream"
        else:
            return "application/x-www-form-urlencoded"

    def sign(
        self,
        method="GET",
        url=None,
        data=None,
        params=None,
        skip_params=False,
        headers=None,
        **kwargs
    ):
        """
        Build the request headers, including the OAuth1 ``Authorization``.

        Parameters
        ----------
        method : str
            HTTP method of the request
        url : str
            The url to request
        data : dict, optional
            Body parameters; when given they are signed instead of
            ``params`` and a default ``Content-Type`` is set if missing
        params : dict, optional
            Query parameters to sign
        skip_params : bool
            Don't use the parameters to sign the request
        headers : dict, optional
            Headers provided by the user

        Returns
        -------
        dict
            The headers of the request
        """
        headers = self._user_headers(headers)

        if data:
            if "Content-Type" not in headers:
                headers["Content-Type"] = self._default_content_type(skip_params)
            params = data.copy()
        elif params:
            params = params.copy()

        oauth = {
            "oauth_consumer_key": self.consumer_key,
            "oauth_nonce": self.gen_nonce(),
            "oauth_signature_method": "HMAC-SHA1",
            "oauth_timestamp": str(int(time.time())),
            "oauth_version": "1.0",
        }

        if self.access_token is not None:
            oauth["oauth_token"] = self.access_token

        oauth["oauth_signature"] = self.gen_signature(
            method=method,
            url=url,
            params=params,
            skip_params=skip_params,
            oauth=oauth,
        )

        headers["Authorization"] = "OAuth " + ", ".join(
            quote(key) + '="' + quote(value) + '"'
            for key, value in sorted(oauth.items(), key=lambda i: i[0])
        )

        return headers

    def gen_nonce(self):
        """Return a random 32 character nonce drawn from ``self.alphabet``."""
        # secrets instead of random: the nonce should not be predictable
        return "".join(secrets.choice(self.alphabet) for _ in range(32))

    def gen_signature(self, method, url, params, skip_params, oauth):
        """
        Compute the HMAC-SHA1 ``oauth_signature`` of a request.

        Parameters
        ----------
        method : str
            HTTP method of the request
        url : str
            The url to request
        params : dict or None
            Request parameters to include in the signature. This mapping
            is not modified.
        skip_params : bool
            Sign only the ``oauth`` parameters
        oauth : dict
            The ``oauth_*`` parameters of the request

        Returns
        -------
        str
            The base64 encoded signature
        """
        signature = method.upper() + "&" + quote(url) + "&"

        if params is None or skip_params:
            params = oauth
        else:
            # work on a copy so the caller's params never receive oauth_* keys
            params = params.copy()
            params.update(oauth)

        pairs = []
        for key, value in sorted(params.items(), key=lambda i: i[0]):
            if key == "q":
                # NOTE(review): `q` is encoded with a looser safe set than
                # the other parameters, presumably to match how the API
                # signs search queries — confirm before changing
                encoded_value = urllib.parse.quote(value, safe="$:!?/()'*@")
            else:
                encoded_value = quote(value)
            pairs.append(quote(key) + "=" + encoded_value)

        signature += quote("&".join(pairs))

        key = quote(self.consumer_secret).encode() + b"&"
        if self.access_token_secret is not None:
            key += quote(self.access_token_secret).encode()

        digest = hmac.new(key, signature.encode(), sha1).digest()
        return base64.b64encode(digest).decode().rstrip("\n")
class OAuth2Headers(PeonyHeaders):
    """
    Dynamic headers implementing OAuth2

    Parameters
    ----------
    consumer_key : str
        Your consumer key
    consumer_secret : str
        Your consumer secret
    client : .client.BasePeonyClient
        The client to authenticate
    bearer_token : :obj:`str`, optional
        Your bearer_token
    **kwargs
        Other headers
    """

    def __init__(
        self,
        consumer_key,
        consumer_secret,
        client,
        bearer_token=None,
        compression=True,
        user_agent=None,
        headers=None,
    ):
        super().__init__(compression, user_agent, headers)

        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.client = client
        self.basic_authorization = self.get_basic_authorization()

        # set while refresh_token() is requesting a token (starts cleared)
        self._refreshing = asyncio.Event()

        if bearer_token is not None:
            self.token = bearer_token

    async def sign(self, url=None, *args, headers=None, **kwargs):
        """
        Return the request headers, making sure a bearer token is set.

        A request to the invalidate_token endpoint removes the stored
        bearer token first; otherwise a token is requested if none is set.

        Parameters
        ----------
        url : str, optional
            The url to request
        headers : dict, optional
            Headers provided by the user

        Returns
        -------
        dict
            The headers of the request
        """
        if url == self._invalidate_token.url():
            del self.token
        elif "Authorization" not in self:
            await self.refresh_token()

        return self._user_headers(headers)

    def get_basic_authorization(self):
        """Return the Basic auth headers built from the consumer credentials."""
        creds = quote(self.consumer_key), quote(self.consumer_secret)
        keys = ":".join(creds).encode("utf-8")

        auth = "Basic " + base64.b64encode(keys).decode("utf-8")
        return {
            "Authorization": auth,
            "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8",
        }

    @property
    def token(self):
        """The bearer token, or None when no token is set."""
        if "Authorization" in self:
            return self["Authorization"][len("Bearer ") :]

    @token.setter
    def token(self, access_token):
        self["Authorization"] = "Bearer " + access_token

    @token.deleter
    def token(self):
        del self["Authorization"]

    @property
    def _invalidate_token(self):
        return self.client["api", "", ""].oauth2.invalidate_token

    async def invalidate_token(self):
        """
        Invalidate the current bearer token.

        Raises
        ------
        RuntimeError
            If there is no token to invalidate
        """
        if "Authorization" not in self:
            raise RuntimeError("There is no token to invalidate")

        token = self.token
        try:
            request = self._invalidate_token.post
            data = RawFormData({"access_token": token}, quote_fields=False)
            await request(_data=data, _headers=self.basic_authorization)
        except Exception:
            # sign() removed the token for this request; put it back on failure
            self.token = token
            raise

    async def refresh_token(self):
        """Request a new bearer token and store it."""
        if self._refreshing.is_set():
            # NOTE(review): Event.wait() returns immediately when the event
            # is set, so concurrent callers do not wait for the refresh to
            # finish. The token request itself presumably goes through
            # sign() again, and relies on this immediate return to avoid
            # a deadlock — confirm before changing.
            return await self._refreshing.wait()

        self._refreshing.set()
        try:
            request = self.client["api", "", ""].oauth2.token.post
            token = await request(
                grant_type="client_credentials",
                _headers=self.basic_authorization,
                _oauth2_pass=True,
            )
            self.token = token["access_token"]
        finally:
            # always release the flag: if it stayed set after a failure,
            # every later refresh_token() call would return without
            # fetching a token
            self._refreshing.clear()

    async def prepare_request(self, *args, oauth2_pass=False, **kwargs):
        """
        prepare all the arguments for the request

        Parameters
        ----------
        oauth2_pass : bool
            For oauth2 authentication only (don't use it)

        Returns
        -------
        dict
            Parameters of the request correctly formatted
        """
        if not oauth2_pass:
            await self.sign()

        return await super().prepare_request(*args, **kwargs)
class RawFormData(aiohttp.FormData):
    """
    :class:`aiohttp.FormData` whose urlencoded body uses the raw field
    values.

    Fields are sorted by name and joined as ``name=value`` pairs without
    any percent-encoding, so the values must already be encoded as the
    server expects.
    """

    def _gen_form_urlencoded(self):
        charset = self._charset if self._charset is not None else "utf-8"

        fields = sorted(self._fields, key=lambda field: field[0]["name"])
        data = "&".join(
            "%s=%s" % (type_options["name"], value)
            for type_options, _, value in fields
        )

        content_type = "application/x-www-form-urlencoded;charset=" + charset
        # encode with the charset advertised in the Content-Type header
        return aiohttp.payload.BytesPayload(
            data.encode(charset), content_type=content_type
        )