diff --git a/pyproject.toml b/pyproject.toml index 0df45a5..dde3387 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,5 +51,8 @@ version = { attr = "serpapi.__version__.__version__" } [tool.setuptools.packages.find] exclude = ["tests", "tests.*"] +[tool.setuptools.package-data] +"serpapi" = ["py.typed"] + [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/serpapi/core.py b/serpapi/core.py index 7d4454a..94e00da 100644 --- a/serpapi/core.py +++ b/serpapi/core.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Optional, Union + from .http import HTTPClient from .exceptions import SearchIDNotProvided from .models import SerpResults @@ -25,13 +27,13 @@ class Client(HTTPClient): DASHBOARD_URL = "https://serpapi.com/dashboard" - def __init__(self, *, api_key=None, timeout=None): + def __init__(self, *, api_key: Optional[str] = None, timeout: Optional[float] = None) -> None: super().__init__(api_key=api_key, timeout=timeout) def __repr__(self): return "" - def search(self, params: dict = None, **kwargs): + def search(self, params: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Union[SerpResults, str]: """Fetch a page of results from SerpApi. Returns a :class:`SerpResults ` object, or unicode text (*e.g.* if ``'output': 'html'`` was passed). The following three calls are equivalent: @@ -72,11 +74,11 @@ def search(self, params: dict = None, **kwargs): if kwargs: params.update(kwargs) - r = self.request("GET", "/search", params=params, **request_kwargs) + r = self.request("GET", "/search", params=params, assert_200=True, **request_kwargs) return SerpResults.from_http_response(r, client=self) - def search_archive(self, params: dict = None, **kwargs): + def search_archive(self, params: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Union[SerpResults, str]: """Get a result from the SerpApi Search Archive API. :param search_id: the Search ID of the search to retrieve from the archive. @@ -105,10 +107,10 @@ def search_archive(self, params: dict = None, **kwargs): f"Please provide 'search_id', found here: { self.DASHBOARD_URL }" ) - r = self.request("GET", f"/searches/{ search_id }", params=params, **request_kwargs) + r = self.request("GET", f"/searches/{ search_id }", params=params, assert_200=True, **request_kwargs) return SerpResults.from_http_response(r, client=self) - def locations(self, params: dict = None, **kwargs): + def locations(self, params: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any: """Get a list of supported Google locations. @@ -139,7 +141,7 @@ def locations(self, params: dict = None, **kwargs): ) return r.json() - def account(self, params: dict = None, **kwargs): + def account(self, params: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any: """Get SerpApi account information. :param api_key: the API Key to use for SerpApi.com. diff --git a/serpapi/exceptions.py b/serpapi/exceptions.py index c268c4d..14d1b5c 100644 --- a/serpapi/exceptions.py +++ b/serpapi/exceptions.py @@ -22,8 +22,8 @@ class SearchIDNotProvided(ValueError, SerpApiError): class HTTPError(requests.exceptions.HTTPError, SerpApiError): """HTTP Error.""" - def __init__(self, original_exception): - if (isinstance(original_exception, requests.exceptions.HTTPError)): + def __init__(self, original_exception: Exception) -> None: + if isinstance(original_exception, requests.exceptions.HTTPError): http_error_exception: requests.exceptions.HTTPError = original_exception self.status_code = http_error_exception.response.status_code @@ -35,7 +35,7 @@ def __init__(self, original_exception): self.status_code = -1 self.error = None - super().__init__(*original_exception.args, response=getattr(original_exception, 'response', None), request=getattr(original_exception, 'request', None)) + super().__init__(*original_exception.args, response=getattr(original_exception, "response", None), request=getattr(original_exception, "request", None)) diff --git a/serpapi/http.py b/serpapi/http.py index 16da0da..1e08e3a 100644 --- a/serpapi/http.py +++ b/serpapi/http.py @@ -1,4 +1,5 @@ import requests +from typing import Any, Dict, Optional from .exceptions import ( HTTPError, @@ -14,14 +15,14 @@ class HTTPClient: BASE_DOMAIN = "https://serpapi.com" USER_AGENT = f"serpapi-python, v{__version__}" - def __init__(self, *, api_key=None, timeout=None): + def __init__(self, *, api_key: Optional[str] = None, timeout: Optional[float] = None) -> None: # Used to authenticate requests. # TODO: do we want to support the environment variable? Seems like a security risk. self.api_key = api_key self.timeout = timeout self.session = requests.Session() - def request(self, method, path, params, *, assert_200=True, **kwargs): + def request(self, method: str, path: str, params: Dict[str, Any], *, assert_200: bool = True, **kwargs: Any) -> requests.Response: # Inject the API Key into the params. if "api_key" not in params: params["api_key"] = self.api_key @@ -59,7 +60,7 @@ def request(self, method, path, params, *, assert_200=True, **kwargs): return r -def raise_for_status(r): +def raise_for_status(r: requests.Response) -> None: """Raise an exception if the status code is not 200.""" # TODO: put custom behavior in here for various status codes. diff --git a/serpapi/models.py b/serpapi/models.py index 0aa7720..91d4a1e 100644 --- a/serpapi/models.py +++ b/serpapi/models.py @@ -1,13 +1,17 @@ import json +from typing import Any, Dict, Iterator, Optional, Union, TYPE_CHECKING -from pprint import pformat from collections import UserDict +import requests + from .textui import prettify_json -from .exceptions import HTTPError + +if TYPE_CHECKING: + from .core import Client -class SerpResults(UserDict): +class SerpResults(UserDict[str, Any]): """A dictionary-like object that represents the results of a SerpApi request. .. code-block:: python @@ -21,68 +25,71 @@ class SerpResults(UserDict): It can be used like a dictionary, but also has some additional methods. """ - def __init__(self, data, *, client): + def __init__(self, data: Dict[str, Any], *, client: Optional["Client"]) -> None: super().__init__(data) self.client = client - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return self.data - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: self.data = state - def __repr__(self): + def __repr__(self) -> str: """The visual representation of the data, which is pretty printed, for ease of use. """ return prettify_json(json.dumps(self.data, indent=4)) - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: """Returns the data as a standard Python dictionary. This can be useful when using ``json.dumps(search), for example.""" return self.data.copy() @property - def next_page_url(self): + def next_page_url(self) -> Optional[str]: """The URL of the next page of results, if any.""" - serpapi_pagination = self.data.get("serpapi_pagination") + serpapi_pagination: Optional[Dict[str, Any]] = self.data.get("serpapi_pagination") if serpapi_pagination: - return serpapi_pagination.get("next") + next_url = serpapi_pagination.get("next") + return next_url if isinstance(next_url, str) else None + return None - def next_page(self): + def next_page(self) -> Optional[Union["SerpResults", str]]: """Return the next page of results, if any.""" - if self.next_page_url: + if self.next_page_url and self.client is not None: # Include support for the API key, as it is not included in the next page URL. params = {"api_key": self.client.api_key} r = self.client.request("GET", path=self.next_page_url, params=params) return SerpResults.from_http_response(r, client=self.client) - def yield_pages(self, max_pages=1_000): + return None + + def yield_pages(self, max_pages: int = 1_000) -> Iterator[Union["SerpResults", str]]: """A generator that ``yield`` s the next ``n`` pages of search results, if any. :param max_pages: limit the number of pages yielded to ``n``. """ current_page_count = 0 - - current_page = self + current_page: Union["SerpResults", str, None] = self while current_page and current_page_count < max_pages: yield current_page current_page_count += 1 - if current_page.next_page_url: + if isinstance(current_page, SerpResults) and current_page.next_page_url: current_page = current_page.next_page() else: break @classmethod - def from_http_response(cls, r, *, client=None): + def from_http_response(cls, r: requests.Response, *, client: Optional["Client"] = None) -> Union["SerpResults", str]: """Construct a SerpResults object from an HTTP response. :param assert_200: if ``True`` (default), raise an exception if the status code is not 200. @@ -93,9 +100,9 @@ def from_http_response(cls, r, *, client=None): """ try: - cls = cls(r.json(), client=client) + inst = cls(r.json(), client=client) - return cls + return inst except ValueError: # If the response is not JSON, return the raw text. return r.text diff --git a/serpapi/py.typed b/serpapi/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/serpapi/textui.py b/serpapi/textui.py index d82b788..57fa945 100644 --- a/serpapi/textui.py +++ b/serpapi/textui.py @@ -1,16 +1,9 @@ -try: - import pygments - from pygments import highlight, lexers, formatters -except ImportError: - pygments = None - - -def prettify_json(s): - if pygments: - return highlight( - s, - lexers.JsonLexer(), - formatters.TerminalFormatter(), - ) - else: +def prettify_json(s: str) -> str: + try: + from pygments import highlight + from pygments.lexers import get_lexer_by_name #type: ignore + from pygments.formatters import TerminalFormatter + except ImportError: return s + + return highlight(s, get_lexer_by_name("JSON"), TerminalFormatter()) \ No newline at end of file