Source code for stravalib.protocol

"""Protocol
==============
Low-level classes for interacting directly with the Strava API webservers.
"""
from __future__ import annotations

import abc
import functools
import logging
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
from urllib.parse import urlencode, urljoin, urlunsplit

import requests

from stravalib import exc

if TYPE_CHECKING:
    from _typeshed import SupportsRead

# OAuth scope strings understood by Strava's authorization endpoint.
Scope = Literal[
    # Public / private data and athlete profile access.
    "read", "read_all", "profile:read_all", "profile:write",
    # Activity access.
    "activity:read", "activity:read_all", "activity:write",
]


class AccessInfo(TypedDict):
    """Token payload returned by Strava's OAuth token exchange.

    Attributes
    ----------
    access_token : str
        A short-lived token used to access the Strava API.
    refresh_token : str
        The token used to obtain this user's next access token. Expect this
        value to change whenever a new access token is retrieved; once a new
        refresh token has been returned, the older one no longer works.
    expires_at : int
        When the provided access token expires, as a number of seconds since
        the epoch.
    """

    access_token: str
    refresh_token: str
    expires_at: int
class ApiV3(metaclass=abc.ABCMeta):
    """Low-level Strava API client.

    This class is responsible for performing the HTTP requests, rate
    limiting, and error handling.
    """

    server = "www.strava.com"
    api_base = "/api/v3"

    # Scope values accepted by ``authorization_url``. Keep in sync with the
    # module-level ``Scope`` literal type.
    _VALID_SCOPES = frozenset(
        {
            "read",
            "read_all",
            "profile:read_all",
            "profile:write",
            "activity:read",
            "activity:read_all",
            "activity:write",
        }
    )

    # Request parameters whose values are credentials; they are masked
    # before request parameters are written to the log.
    _SENSITIVE_PARAMS = frozenset(
        {"access_token", "client_secret", "code", "refresh_token"}
    )

    def __init__(
        self,
        access_token: str | None = None,
        requests_session: requests.Session | None = None,
        rate_limiter: Callable[[dict[str, str]], None] | None = None,
    ):
        """Initialize this protocol client, optionally providing a (shared)
        :class:`requests.Session` object.

        Parameters
        ----------
        access_token : str
            The token that provides access to a specific Strava account.
        requests_session : :class:`requests.Session`
            An existing :class:`requests.Session` object to use. A new
            session is created when omitted.
        rate_limiter : callable
            Called with the response headers after every request made by
            ``_request``. Defaults to a no-op.
        """
        self.log = logging.getLogger(
            "{0.__module__}.{0.__name__}".format(self.__class__)
        )
        self.access_token = access_token
        if requests_session is not None:
            self.rsession: requests.Session = requests_session
        else:
            self.rsession = requests.Session()
        self.rate_limiter = rate_limiter or (lambda _: None)

    def authorization_url(
        self,
        client_id: int,
        redirect_uri: str,
        approval_prompt: Literal["auto", "force"] = "auto",
        scope: list[Scope] | Scope | None = None,
        state: str | None = None,
    ) -> str:
        """Get the URL needed to authorize your application to access a
        Strava user's information.

        See https://developers.strava.com/docs/authentication/

        Parameters
        ----------
        client_id : int
            The numeric developer client id.
        redirect_uri : str
            The URL that Strava will redirect to after successful (or failed)
            authorization.
        approval_prompt : str
            Whether to prompt for approval even if approval already granted
            to app. Choices are 'auto' or 'force'. (Default is 'auto')
        scope : list[str] or str
            The access scope required. Omit to imply "read" and
            "activity:read". Valid values are 'read', 'read_all',
            'profile:read_all', 'profile:write', 'activity:read',
            'activity:read_all', 'activity:write'.
        state : str
            An arbitrary variable that will be returned to your application
            in the redirect URI.

        Returns
        -------
        str
            The URL to use for authorization link.

        Raises
        ------
        AssertionError
            If ``approval_prompt`` or any ``scope`` value is not supported.
        """
        # AssertionError is kept (rather than ValueError) for backward
        # compatibility with the former ``assert`` statements; raising it
        # explicitly means the validation still runs under ``python -O``.
        if approval_prompt not in ("auto", "force"):
            raise AssertionError(
                "approval_prompt must be 'auto' or 'force', "
                f"got {approval_prompt!r}"
            )

        if scope is None:
            scope = ["read", "activity:read"]
        elif isinstance(scope, (str, bytes)):
            scope = [scope]

        unsupported = set(scope) - self._VALID_SCOPES
        if unsupported:
            raise AssertionError(
                "Unsupported scope value(s): {}".format(unsupported)
            )

        params = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "approval_prompt": approval_prompt,
            "scope": ",".join(scope),
            "response_type": "code",
        }
        if state is not None:
            params["state"] = state

        return urlunsplit(
            ("https", self.server, "/oauth/authorize", urlencode(params), "")
        )

    def exchange_code_for_token(
        self, client_id: int, client_secret: str, code: str
    ) -> AccessInfo:
        """Exchange the temporary authorization code (returned with redirect
        from Strava authorization URL) for a short-lived access token and a
        refresh token (used to obtain the next access token later on).

        On success the new access token is also stored on this client.

        Parameters
        ----------
        client_id : int
            The numeric developer client id.
        client_secret : str
            The developer client secret
        code : str
            The temporary authorization code

        Returns
        -------
        dict
            Dictionary containing the access_token, refresh_token and
            expires_at (number of seconds since Epoch when the provided
            access token will expire)
        """
        return self._token_request(
            {
                "client_id": client_id,
                "client_secret": client_secret,
                "code": code,
                "grant_type": "authorization_code",
            }
        )

    def refresh_access_token(
        self, client_id: int, client_secret: str, refresh_token: str
    ) -> AccessInfo:
        """Exchanges the previous refresh token for a short-lived access
        token and a new refresh token (used to obtain the next access token
        later on).

        On success the new access token is also stored on this client.

        Parameters
        ----------
        client_id : int
            The numeric developer client id.
        client_secret : str
            The developer client secret
        refresh_token : str
            The refresh token obtained from a previous authorization request

        Returns
        -------
        dict
            Dictionary containing the access_token, refresh_token and
            expires_at (number of seconds since Epoch when the provided
            access token will expire)
        """
        return self._token_request(
            {
                "client_id": client_id,
                "client_secret": client_secret,
                "refresh_token": refresh_token,
                "grant_type": "refresh_token",
            }
        )

    def _token_request(self, params: dict[str, Any]) -> AccessInfo:
        """POST ``params`` to the OAuth token endpoint, store the returned
        access token on this client and return the token fields."""
        response = self._request(
            f"https://{self.server}/oauth/token",
            params=params,
            method="POST",
        )
        access_info: AccessInfo = {
            "access_token": response["access_token"],
            "refresh_token": response["refresh_token"],
            "expires_at": response["expires_at"],
        }
        self.access_token = response["access_token"]
        return access_info

    def resolve_url(self, url: str) -> str:
        """Turn an API path into a full URL; absolute URLs pass through.

        Parameters
        ----------
        url : str
            url string to be accessed / resolved. Anything that does not
            start with ``http://`` or ``https://`` is treated as a path
            relative to ``api_base`` on ``server``.

        Returns
        -------
        str
            A string representing the full properly formatted (https) url.
        """
        # Check the full scheme so a relative path that merely begins with
        # the letters "http" is still resolved against the API base.
        if not url.startswith(("http://", "https://")):
            url = urljoin(
                f"https://{self.server}",
                self.api_base + "/" + url.strip("/"),
            )
        return url

    @classmethod
    def _redact_params(
        cls, params: dict[str, Any] | None
    ) -> dict[str, Any] | None:
        """Return a copy of ``params`` with credential values masked, for
        logging only."""
        if params is None:
            return None
        return {
            key: "<redacted>" if key in cls._SENSITIVE_PARAMS else value
            for key, value in params.items()
        }

    def _request(
        self,
        url: str,
        params: dict[str, Any] | None = None,
        files: dict[str, SupportsRead[str | bytes]] | None = None,
        method: Literal["GET", "POST", "PUT", "DELETE"] = "GET",
        check_for_errors: bool = True,
    ) -> Any:
        """Perform the underlying request, returning the parsed JSON results.

        Parameters
        ----------
        url : str
            The request URL (absolute, or a path relative to the API base).
        params : Dict[str,Any]
            Request parameters. The dict passed in is not modified.
        files : Dict[str,file]
            Dictionary of file name to file-like objects (POST only).
        method : str
            The request method (GET/POST/PUT/DELETE).
        check_for_errors : bool
            Whether to raise on an error response (see
            ``_handle_protocol_error``).

        Returns
        -------
        Dict[str,Any]
            The parsed JSON response, or an empty dict for HTTP 204.

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported request methods.
        """
        url = self.resolve_url(url)
        # Credentials (client secret, codes, tokens) are masked in the log.
        self.log.info(
            "%s %r with params %r", method, url, self._redact_params(params)
        )

        # Work on a copy so injecting the access token never leaks into the
        # caller's dict.
        params = {} if params is None else dict(params)
        if self.access_token:
            params["access_token"] = self.access_token

        methods = {
            "GET": self.rsession.get,
            "POST": functools.partial(self.rsession.post, files=files),
            "PUT": self.rsession.put,
            "DELETE": self.rsession.delete,
        }

        try:
            requester = methods[method.upper()]
        except KeyError:
            raise ValueError(
                "Invalid/unsupported request method specified: {}".format(
                    method
                )
            ) from None

        raw = requester(url, params=params)  # type: ignore[operator]
        # Rate limits are taken from HTTP response headers
        # https://developers.strava.com/docs/rate-limits/
        self.rate_limiter(raw.headers)

        if check_for_errors:
            self._handle_protocol_error(raw)

        # 204 = No content
        if raw.status_code == 204:
            return {}
        return raw.json()

    def _handle_protocol_error(
        self, response: requests.Response
    ) -> requests.Response:
        """Parses the raw response from the server, raising a
        :class:`stravalib.exc.Fault` if the server returned an error.

        Parameters
        ----------
        response
            The response object.

        Returns
        -------
        requests.Response
            The same response, when no error was detected.

        Raises
        ------
        ObjectNotFound
            On HTTP 404.
        AccessUnauthorized
            On HTTP 401.
        Fault
            On any other 4xx/5xx status, or when the JSON body carries a
            ``message``/``errors`` entry.
        """
        error_str = None
        try:
            json_response = response.json()
        except ValueError:
            json_response = None

        # Only a JSON object can carry ``message``/``errors``; the body may
        # also be a list or a bare string, on which ``.get`` would fail.
        if isinstance(json_response, dict) and (
            "message" in json_response or "errors" in json_response
        ):
            error_str = "{}: {}".format(
                json_response.get("message", "Undefined error"),
                json_response.get("errors"),
            )

        # Special subclasses for some errors
        if response.status_code == 404:
            msg = "{}: {}".format(response.reason, error_str)
            raise exc.ObjectNotFound(msg, response=response)
        elif response.status_code == 401:
            msg = "{}: {}".format(response.reason, error_str)
            raise exc.AccessUnauthorized(msg, response=response)
        elif 400 <= response.status_code < 500:
            msg = "{} Client Error: {} [{}]".format(
                response.status_code,
                response.reason,
                error_str,
            )
            raise exc.Fault(msg, response=response)
        elif 500 <= response.status_code < 600:
            msg = "{} Server Error: {} [{}]".format(
                response.status_code,
                response.reason,
                error_str,
            )
            raise exc.Fault(msg, response=response)
        elif error_str:
            msg = error_str
            raise exc.Fault(msg, response=response)

        return response

    def _extract_referenced_vars(self, s: str) -> list[str]:
        """Utility method to find the referenced format variables in a string.
        (Assumes string.format() format vars.)

        Parameters
        ----------
        s
            The string that contains format variables. (e.g. "{foo}-text")

        Returns
        -------
        list
            The list of referenced variable names, in first-seen order.
            (e.g. ['foo'])
        """
        found: dict[str, int] = {}
        while True:
            try:
                s.format(**found)
            except KeyError as err:
                # err.args[0] contains the name of the key that was not
                # found; 0 is used because it appears to work with all types
                # of placeholders. (Named ``err`` so the module-level ``exc``
                # import is not shadowed.)
                found[err.args[0]] = 0
            else:
                break
        return list(found.keys())

    def _split_url_kwargs(
        self, url: str, kwargs: dict[str, Any]
    ) -> tuple[str, dict[str, Any]]:
        """Fill the ``{name}`` placeholders of ``url`` from ``kwargs`` and
        return the formatted URL plus the remaining kwargs as request
        params."""
        referenced = self._extract_referenced_vars(url)
        formatted = url.format(**kwargs)
        params = {k: v for k, v in kwargs.items() if k not in referenced}
        return formatted, params

    def get(
        self, url: str, check_for_errors: bool = True, **kwargs: Any
    ) -> Any:
        """Performs a generic GET request for specified params, returning the
        response.

        Keyword arguments named in ``{...}`` placeholders of ``url`` are
        substituted into it; all other keyword arguments are sent as request
        parameters.

        Parameters
        ----------
        url : str
            String representing the url to retrieve
        check_for_errors : bool (default = True)
            Flag used to raise an error (or not)

        Returns
        -------
        dict
            Performs the request and returns a JSON object deserialized as dict
        """
        url, params = self._split_url_kwargs(url, kwargs)
        return self._request(
            url, params=params, check_for_errors=check_for_errors
        )

    def post(
        self,
        url: str,
        files: dict[str, SupportsRead[str | bytes]] | None = None,
        check_for_errors: bool = True,
        **kwargs: Any,
    ) -> Any:
        """Performs a generic POST request for specified params, returning the
        response.

        Keyword arguments named in ``{...}`` placeholders of ``url`` are
        substituted into it; all other keyword arguments are sent as request
        parameters.

        Parameters
        ----------
        url : str
            Url string to be requested.
        files : dict
            Dictionary of file name to file-like objects. Used by _request
        check_for_errors : bool
            Whether to raise an error (or not)

        Returns
        -------
        Any
            The deserialized JSON response.
        """
        url, params = self._split_url_kwargs(url, kwargs)
        return self._request(
            url,
            params=params,
            files=files,
            method="POST",
            check_for_errors=check_for_errors,
        )

    def put(
        self, url: str, check_for_errors: bool = True, **kwargs: Any
    ) -> Any:
        """Performs a generic PUT request for specified params, returning the
        response.

        Keyword arguments named in ``{...}`` placeholders of ``url`` are
        substituted into it; all other keyword arguments are sent as request
        parameters.

        Parameters
        ----------
        url : str
            String representing url to access.
        check_for_errors : bool
            Whether to raise an error (or not)

        Returns
        -------
        Any
            The deserialized JSON response.
        """
        url, params = self._split_url_kwargs(url, kwargs)
        return self._request(
            url, params=params, method="PUT", check_for_errors=check_for_errors
        )

    def delete(
        self, url: str, check_for_errors: bool = True, **kwargs: Any
    ) -> Any:
        """Performs a generic DELETE request for specified params, returning
        the response.

        Keyword arguments named in ``{...}`` placeholders of ``url`` are
        substituted into it; all other keyword arguments are sent as request
        parameters.

        Parameters
        ----------
        url : str
            String representing url to access.
        check_for_errors : bool
            Whether to raise an error (or not)

        Returns
        -------
        Any
            The deserialized JSON response (an empty dict for HTTP 204).
        """
        url, params = self._split_url_kwargs(url, kwargs)
        return self._request(
            url,
            params=params,
            method="DELETE",
            check_for_errors=check_for_errors,
        )