From 16ee779d7abeb77e91bc98e9e3ea78f98c3a5d76 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 8 Nov 2023 11:13:12 +0200 Subject: [PATCH] Add `pytest_django.DjangoAssertNumQueries` for typing purposes This allows typing the `django_assert_num_queries` and `django_assert_max_num_queries` fixtures. --- docs/helpers.rst | 18 ++++++++++++++++++ pytest_django/__init__.py | 3 ++- pytest_django/fixtures.py | 19 ++++++++++++++++--- tests/test_fixtures.py | 14 +++++++++----- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/docs/helpers.rst b/docs/helpers.rst index e2b49fee..b0fac00a 100644 --- a/docs/helpers.rst +++ b/docs/helpers.rst @@ -447,6 +447,15 @@ Example usage:: assert 'foo' in captured.captured_queries[0]['sql'] +If you use type annotations, you can annotate the fixture like this:: + + from pytest_django import DjangoAssertNumQueries + + def test_num_queries( + django_assert_num_queries: DjangoAssertNumQueries, + ): + ... + .. fixture:: django_assert_max_num_queries @@ -470,6 +479,15 @@ Example usage:: Item.objects.create('foo') Item.objects.create('bar') +If you use type annotations, you can annotate the fixture like this:: + + from pytest_django import DjangoAssertNumQueries + + def test_max_num_queries( + django_assert_max_num_queries: DjangoAssertNumQueries, + ): + ... + .. fixture:: django_capture_on_commit_callbacks diff --git a/pytest_django/__init__.py b/pytest_django/__init__.py index 0dc1c6d4..4e551d5b 100644 --- a/pytest_django/__init__.py +++ b/pytest_django/__init__.py @@ -5,12 +5,13 @@ __version__ = "unknown" -from .fixtures import DjangoCaptureOnCommitCallbacks +from .fixtures import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks from .plugin import DjangoDbBlocker __all__ = [ "__version__", + "DjangoAssertNumQueries", "DjangoCaptureOnCommitCallbacks", "DjangoDbBlocker", ] diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 03d96f2c..c6b6a1cf 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -601,12 +601,25 @@ def _live_server_helper(request: pytest.FixtureRequest) -> Generator[None, None, live_server._live_server_modified_settings.disable() +class DjangoAssertNumQueries(Protocol): + """The type of the `django_assert_num_queries` and + `django_assert_max_num_queries` fixtures.""" + + def __call__( + self, + num: int, + connection: Any | None = ..., + info: str | None = ..., + ) -> django.test.utils.CaptureQueriesContext: + pass # pragma: no cover + + @contextmanager def _assert_num_queries( config: pytest.Config, num: int, exact: bool = True, - connection=None, + connection: Any | None = None, info: str | None = None, ) -> Generator[django.test.utils.CaptureQueriesContext, None, None]: from django.test.utils import CaptureQueriesContext @@ -641,12 +654,12 @@ def _assert_num_queries( @pytest.fixture() -def django_assert_num_queries(pytestconfig: pytest.Config): +def django_assert_num_queries(pytestconfig: pytest.Config) -> DjangoAssertNumQueries: return partial(_assert_num_queries, pytestconfig) @pytest.fixture() -def django_assert_max_num_queries(pytestconfig: pytest.Config): +def django_assert_max_num_queries(pytestconfig: pytest.Config) -> DjangoAssertNumQueries: return partial(_assert_num_queries, pytestconfig, exact=False) diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index 58fa7da9..20c907fc 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -18,7 +18,7 @@ from .helpers import DjangoPytester -from pytest_django import DjangoCaptureOnCommitCallbacks, DjangoDbBlocker +from pytest_django import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks, DjangoDbBlocker from pytest_django_test.app.models import Item @@ -91,7 +91,7 @@ def test_async_rf(async_rf: AsyncRequestFactory) -> None: @pytest.mark.django_db def test_django_assert_num_queries_db( request: pytest.FixtureRequest, - django_assert_num_queries, + django_assert_num_queries: DjangoAssertNumQueries, ) -> None: with nonverbose_config(request.config): with django_assert_num_queries(3): @@ -111,7 +111,7 @@ def test_django_assert_num_queries_db( @pytest.mark.django_db def test_django_assert_max_num_queries_db( request: pytest.FixtureRequest, - django_assert_max_num_queries, + django_assert_max_num_queries: DjangoAssertNumQueries, ) -> None: with nonverbose_config(request.config): with django_assert_max_num_queries(2): @@ -134,7 +134,9 @@ def test_django_assert_max_num_queries_db( @pytest.mark.django_db(transaction=True) def test_django_assert_num_queries_transactional_db( - request: pytest.FixtureRequest, transactional_db: None, django_assert_num_queries + request: pytest.FixtureRequest, + transactional_db: None, + django_assert_num_queries: DjangoAssertNumQueries, ) -> None: with nonverbose_config(request.config): with transaction.atomic(): @@ -187,7 +189,9 @@ def test_queries(django_assert_num_queries): @pytest.mark.django_db -def test_django_assert_num_queries_db_connection(django_assert_num_queries) -> None: +def test_django_assert_num_queries_db_connection( + django_assert_num_queries: DjangoAssertNumQueries, +) -> None: from django.db import connection with django_assert_num_queries(1, connection=connection):