Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing: Improve FixtureDefinition and FixtureDef #13036

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dependencies = [
"packaging",
"pluggy>=1.5,<2",
"tomli>=1; python_version<'3.11'",
"typing-extensions; python_version<'3.10'",
]
optional-dependencies.dev = [
"argcomplete",
Expand Down
61 changes: 39 additions & 22 deletions src/_pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
from typing_extensions import TypeAlias
else:
from typing import ParamSpec
from typing import TypeAlias

Check warning on line 83 in src/_pytest/fixtures.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/fixtures.py#L82-L83

Added lines #L82 - L83 were not covered by tests


if TYPE_CHECKING:
from _pytest.python import CallSpec2
Expand All @@ -84,14 +91,20 @@

# The value of the fixture -- return/yield of the fixture function (type variable).
FixtureValue = TypeVar("FixtureValue")
# The type of the fixture function (type variable).
FixtureFunction = TypeVar("FixtureFunction", bound=Callable[..., object])
# The type of a fixture function (type alias generic in fixture value).
_FixtureFunc = Union[
Callable[..., FixtureValue], Callable[..., Generator[FixtureValue]]

# The parameters that a fixture function receives.
FixtureParams = ParamSpec("FixtureParams")

# A dict of fixture name -> its FixtureDef.
FixtureDefDict: TypeAlias = dict[str, "FixtureDef[Any]"]

# The type of fixture function (type alias generic in fixture params and value).
_FixtureFunc: TypeAlias = Union[
Callable[FixtureParams, FixtureValue],
Callable[FixtureParams, Generator[FixtureValue, None, None]],
]
# The type of FixtureDef.cached_result (type alias generic in fixture value).
_FixtureCachedResult = Union[
_FixtureCachedResult: TypeAlias = Union[
tuple[
# The result.
FixtureValue,
Expand Down Expand Up @@ -360,7 +373,7 @@
pyfuncitem: Function,
fixturename: str | None,
arg2fixturedefs: dict[str, Sequence[FixtureDef[Any]]],
fixture_defs: dict[str, FixtureDef[Any]],
fixture_defs: FixtureDefDict,
*,
_ispytest: bool = False,
) -> None:
Expand Down Expand Up @@ -886,7 +899,9 @@


def call_fixture_func(
fixturefunc: _FixtureFunc[FixtureValue], request: FixtureRequest, kwargs
fixturefunc: _FixtureFunc[FixtureParams, FixtureValue],
request: FixtureRequest,
kwargs: FixtureParams.kwargs,
) -> FixtureValue:
if inspect.isgeneratorfunction(fixturefunc):
fixturefunc = cast(Callable[..., Generator[FixtureValue]], fixturefunc)
Expand Down Expand Up @@ -957,7 +972,7 @@
config: Config,
baseid: str | None,
argname: str,
func: _FixtureFunc[FixtureValue],
func: _FixtureFunc[Any, FixtureValue],
scope: Scope | _ScopeName | Callable[[str, Config], _ScopeName] | None,
params: Sequence[object] | None,
ids: tuple[object | None, ...] | Callable[[Any], object | None] | None = None,
Expand Down Expand Up @@ -1113,7 +1128,7 @@

def resolve_fixture_function(
fixturedef: FixtureDef[FixtureValue], request: FixtureRequest
) -> _FixtureFunc[FixtureValue]:
) -> _FixtureFunc[Any, FixtureValue]:
"""Get the actual callable that can be called to obtain the fixture
value."""
fixturefunc = fixturedef.func
Expand Down Expand Up @@ -1192,7 +1207,9 @@
def __post_init__(self, _ispytest: bool) -> None:
check_ispytest(_ispytest)

def __call__(self, function: FixtureFunction) -> FixtureFunctionDefinition:
def __call__(
self, function: Callable[FixtureParams, FixtureValue]
) -> FixtureFunctionDefinition[FixtureParams, FixtureValue]:
if inspect.isclass(function):
raise ValueError("class fixtures not supported (maybe in the future)")

Expand All @@ -1219,12 +1236,10 @@
return fixture_definition


# TODO: paramspec/return type annotation tracking and storing
class FixtureFunctionDefinition:
class FixtureFunctionDefinition(Generic[FixtureParams, FixtureValue]):
def __init__(
self,
*,
function: Callable[..., Any],
function: Callable[FixtureParams, FixtureValue],
fixture_function_marker: FixtureFunctionMarker,
instance: object | None = None,
_ispytest: bool = False,
Expand All @@ -1237,7 +1252,7 @@
self._fixture_function_marker = fixture_function_marker
if instance is not None:
self._fixture_function = cast(
Callable[..., Any], function.__get__(instance)
Callable[FixtureParams, FixtureValue], function.__get__(instance)
)
else:
self._fixture_function = function
Expand All @@ -1246,7 +1261,9 @@
def __repr__(self) -> str:
return f"<pytest_fixture({self._fixture_function})>"

def __get__(self, instance, owner=None):
def __get__(
self, instance: object, owner: type | None = None
) -> FixtureFunctionDefinition[FixtureParams, FixtureValue]:
"""Behave like a method if the function it was applied to was a method."""
return FixtureFunctionDefinition(
function=self._fixture_function,
Expand All @@ -1270,14 +1287,14 @@

@overload
def fixture(
fixture_function: Callable[..., object],
fixture_function: Callable[FixtureParams, FixtureValue],
*,
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
params: Iterable[object] | None = ...,
autouse: bool = ...,
ids: Sequence[object | None] | Callable[[Any], object | None] | None = ...,
name: str | None = ...,
) -> FixtureFunctionDefinition: ...
) -> FixtureFunctionDefinition[FixtureParams, FixtureValue]: ...


@overload
Expand All @@ -1293,14 +1310,14 @@


def fixture(
fixture_function: FixtureFunction | None = None,
fixture_function: Callable[FixtureParams, FixtureValue] | None = None,
*,
scope: _ScopeName | Callable[[str, Config], _ScopeName] = "function",
params: Iterable[object] | None = None,
autouse: bool = False,
ids: Sequence[object | None] | Callable[[Any], object | None] | None = None,
name: str | None = None,
) -> FixtureFunctionMarker | FixtureFunctionDefinition:
) -> FixtureFunctionMarker | FixtureFunctionDefinition[FixtureParams, FixtureValue]:
"""Decorator to mark a fixture factory function.

This decorator can be used, with or without parameters, to define a
Expand Down Expand Up @@ -1688,7 +1705,7 @@
self,
*,
name: str,
func: _FixtureFunc[object],
func: _FixtureFunc[Any, object],
nodeid: str | None,
scope: Scope | _ScopeName | Callable[[str, Config], _ScopeName] = "function",
params: Sequence[object] | None = None,
Expand Down
5 changes: 3 additions & 2 deletions src/_pytest/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import FixtureDefDict
from _pytest.fixtures import FixtureRequest
from _pytest.fixtures import FuncFixtureInfo
from _pytest.fixtures import get_scope_node
Expand Down Expand Up @@ -1085,7 +1086,7 @@ def get_direct_param_fixture_func(request: FixtureRequest) -> Any:


# Used for storing pseudo fixturedefs for direct parametrization.
name2pseudofixturedef_key = StashKey[dict[str, FixtureDef[Any]]]()
name2pseudofixturedef_key = StashKey[FixtureDefDict]()


@final
Expand Down Expand Up @@ -1271,7 +1272,7 @@ def parametrize(
if node is None:
name2pseudofixturedef = None
else:
default: dict[str, FixtureDef[Any]] = {}
default: FixtureDefDict = {}
name2pseudofixturedef = node.stash.setdefault(
name2pseudofixturedef_key, default
)
Expand Down
Loading