diff --git a/docs/helpers.rst b/docs/helpers.rst index 7dfd8e2c..e2b49fee 100644 --- a/docs/helpers.rst +++ b/docs/helpers.rst @@ -514,6 +514,14 @@ Example usage:: assert mailoutbox[0].subject == 'Contact Form' assert mailoutbox[0].body == 'I like your site' +If you use type annotations, you can annotate the fixture like this:: + + from pytest_django import DjangoCaptureOnCommitCallbacks + + def test_on_commit( + django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks, + ): + ... .. fixture:: mailoutbox diff --git a/pytest_django/__init__.py b/pytest_django/__init__.py index f99bffa5..0dc1c6d4 100644 --- a/pytest_django/__init__.py +++ b/pytest_django/__init__.py @@ -5,10 +5,12 @@ __version__ = "unknown" +from .fixtures import DjangoCaptureOnCommitCallbacks from .plugin import DjangoDbBlocker __all__ = [ "__version__", + "DjangoCaptureOnCommitCallbacks", "DjangoDbBlocker", ] diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 8a87bb20..03d96f2c 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -7,12 +7,15 @@ from typing import ( TYPE_CHECKING, Any, + Callable, ClassVar, + ContextManager, Generator, Iterable, List, Literal, Optional, + Protocol, Tuple, Union, ) @@ -647,8 +650,20 @@ def django_assert_max_num_queries(pytestconfig: pytest.Config): return partial(_assert_num_queries, pytestconfig, exact=False) +class DjangoCaptureOnCommitCallbacks(Protocol): + """The type of the `django_capture_on_commit_callbacks` fixture.""" + + def __call__( + self, + *, + using: str = ..., + execute: bool = ..., + ) -> ContextManager[list[Callable[[], Any]]]: + pass # pragma: no cover + + @pytest.fixture() -def django_capture_on_commit_callbacks(): +def django_capture_on_commit_callbacks() -> DjangoCaptureOnCommitCallbacks: from django.test import TestCase - return TestCase.captureOnCommitCallbacks + return TestCase.captureOnCommitCallbacks # type: ignore[no-any-return] diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index 7f4084f0..58fa7da9 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -18,7 +18,7 @@ from .helpers import DjangoPytester -from pytest_django import DjangoDbBlocker +from pytest_django import DjangoCaptureOnCommitCallbacks, DjangoDbBlocker from pytest_django_test.app.models import Item @@ -232,7 +232,9 @@ def test_queries(django_assert_num_queries): @pytest.mark.django_db -def test_django_capture_on_commit_callbacks(django_capture_on_commit_callbacks) -> None: +def test_django_capture_on_commit_callbacks( + django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks, +) -> None: if not connection.features.supports_transactions: pytest.skip("transactions required for this test") @@ -255,7 +257,9 @@ def test_django_capture_on_commit_callbacks(django_capture_on_commit_callbacks) @pytest.mark.django_db(databases=["default", "second"]) -def test_django_capture_on_commit_callbacks_multidb(django_capture_on_commit_callbacks) -> None: +def test_django_capture_on_commit_callbacks_multidb( + django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks, +) -> None: if not connection.features.supports_transactions: pytest.skip("transactions required for this test") @@ -282,7 +286,7 @@ def test_django_capture_on_commit_callbacks_multidb(django_capture_on_commit_cal @pytest.mark.django_db(transaction=True) def test_django_capture_on_commit_callbacks_transactional( - django_capture_on_commit_callbacks, + django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks, ) -> None: if not connection.features.supports_transactions: pytest.skip("transactions required for this test")