Skip to content

Commit

Permalink
Make possible to use one-shot iterators as children nodes.
Browse files Browse the repository at this point in the history
Previously there were checks for generators specifically. We also support
one-off iterables that are not based on generators such as itertools.chain()
or any non-generator based object that implements __next__().

This regression was introduced in #56.

Based on #71.
  • Loading branch information
pelme committed Dec 15, 2024
1 parent ac4b0bd commit 66ab3b7
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 21 deletions.
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## NEXT
- Fixed handling of non-generator iterables such as `itertools.chain()` as
children. Thanks to Aleksei Pirogov ([@astynax](https://github.com/astynax)).
[PR #72](https://github.com/pelme/htpy/pull/72).

## 24.10.1 - 2024-10-24
- Fix handling of Python keywords such as `<del>` in html2htpy. [PR #61](https://github.com/pelme/htpy/pull/61).

Expand Down
29 changes: 20 additions & 9 deletions htpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import keyword
import typing as t
from collections.abc import Callable, Generator, Iterable, Iterator
from collections.abc import Callable, Iterable, Iterator

from markupsafe import Markup as _Markup
from markupsafe import escape as _escape
Expand Down Expand Up @@ -297,14 +297,26 @@ def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:


def _validate_children(children: t.Any) -> None:
if isinstance(children, _KnownValidChildren):
# Non-lazy iterables:
# list and tuple are iterables and part of _KnownValidChildren. Since we
# know they can be consumed multiple times, we validate them recursively now
# rather than at render time to provide better error messages.
if isinstance(children, list | tuple):
for child in children: # pyright: ignore[reportUnknownVariableType]
_validate_children(child)
return

if isinstance(children, Iterable) and not isinstance(children, _KnownInvalidChildren):
for child in children: # pyright: ignore [reportUnknownVariableType]
_validate_children(child)
# bytes, bytearray etc:
# These are Iterable (part of _KnownValidChildren) but still not
# useful as a child node.
if isinstance(children, _KnownInvalidChildren):
raise TypeError(f"{children!r} is not a valid child element")

# Element, str, int and all other regular/valid types.
if isinstance(children, _KnownValidChildren):
return

# Arbitrary objects that are not valid children.
raise TypeError(f"{children!r} is not a valid child element")


Expand Down Expand Up @@ -487,15 +499,14 @@ def __html__(self) -> str: ...


_KnownInvalidChildren: UnionType = bytes | bytearray | memoryview

_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType]
_KnownValidChildren: UnionType = (
None
| BaseElement
| ContextProvider # pyright: ignore [reportMissingTypeArgument]
| ContextConsumer # pyright: ignore [reportMissingTypeArgument]
| Callable # pyright: ignore [reportMissingTypeArgument]
| str
| int
| Generator # pyright: ignore [reportMissingTypeArgument]
| _HasHtml
| Callable
| Iterable
)
67 changes: 55 additions & 12 deletions tests/test_children.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
import re
import typing as t
from collections.abc import Iterator

import pytest
from markupsafe import Markup
Expand All @@ -16,13 +17,36 @@
from .conftest import Trace

if t.TYPE_CHECKING:
from collections.abc import Callable, Generator
from collections.abc import Callable

from htpy import Node

from .conftest import RenderFixture, TraceFixture


T = t.TypeVar("T")


class SingleShotIterator(t.Generic[T]):
def __init__(self, value: T) -> None:
self.value = value
self.consumed = False

def __iter__(self) -> SingleShotIterator[T]:
return self

def __next__(self) -> T:
if self.consumed:
raise StopIteration

self.consumed = True

return self.value


assert isinstance(SingleShotIterator("foo"), Iterator)


def test_void_element(render: RenderFixture) -> None:
result = input(name="foo")
assert_type(result, VoidElement)
Expand Down Expand Up @@ -88,12 +112,12 @@ def test_flatten_very_nested_children(render: RenderFixture) -> None:


def test_flatten_nested_generators(render: RenderFixture) -> None:
def cols() -> Generator[str, None, None]:
def cols() -> Iterator[str]:
yield "a"
yield "b"
yield "c"

def rows() -> Generator[Generator[str, None, None], None, None]:
def rows() -> Iterator[Iterator[str]]:
yield cols()
yield cols()
yield cols()
Expand All @@ -104,11 +128,21 @@ def rows() -> Generator[Generator[str, None, None], None, None]:


def test_generator_children(render: RenderFixture) -> None:
gen: Generator[Element, None, None] = (li[x] for x in ["a", "b"])
gen: Iterator[Element] = (li[x] for x in ["a", "b"])
result = ul[gen]
assert render(result) == ["<ul>", "<li>", "a", "</li>", "<li>", "b", "</li>", "</ul>"]


def test_non_generator_iterator(render: RenderFixture) -> None:
result = ul[SingleShotIterator("hello")]

assert render(result) == [
"<ul>",
"hello",
"</ul>",
]


def test_html_tag_with_doctype(render: RenderFixture) -> None:
result = html(foo="bar")["hello"]
assert render(result) == ["<!doctype html>", '<html foo="bar">', "hello", "</html>"]
Expand Down Expand Up @@ -137,7 +171,7 @@ def test_ignored(render: RenderFixture, ignored_value: t.Any) -> None:


def test_lazy_iter(render: RenderFixture, trace: TraceFixture) -> None:
def generate_list() -> Generator[Element, None, None]:
def generate_list() -> Iterator[Element]:
trace("before yield")
yield li("#a")
trace("after yield")
Expand Down Expand Up @@ -202,7 +236,7 @@ def test_safe_children(render: RenderFixture) -> None:


def test_nested_callable_generator(render: RenderFixture) -> None:
def func() -> Generator[str, None, None]:
def func() -> Iterator[str]:
return (x for x in "abc")

assert render(div[func]) == ["<div>", "a", "b", "c", "</div>"]
Expand Down Expand Up @@ -260,7 +294,19 @@ def test_invalid_child_direct(not_a_child: t.Any) -> None:


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_nested_iterable(not_a_child: t.Any) -> None:
def test_invalid_child_wrapped_in_list(not_a_child: t.Any) -> None:
with pytest.raises(TypeError, match="is not a valid child element"):
div[[not_a_child]]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_wrapped_in_tuple(not_a_child: t.Any) -> None:
with pytest.raises(TypeError, match="is not a valid child element"):
div[(not_a_child,)]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_nested_iterator(not_a_child: t.Any) -> None:
with pytest.raises(TypeError, match="is not a valid child element"):
div[[not_a_child]]

Expand All @@ -276,14 +322,11 @@ def test_invalid_child_lazy_callable(not_a_child: t.Any, render: RenderFixture)


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_lazy_generator(not_a_child: t.Any, render: RenderFixture) -> None:
def test_invalid_child_lazy_iterator(not_a_child: t.Any, render: RenderFixture) -> None:
"""
Ensure proper exception is raised for lazily evaluated invalid children.
"""

def gen() -> t.Any:
yield not_a_child

element = div[gen()]
element = div[SingleShotIterator(not_a_child)]
with pytest.raises(TypeError, match="is not a valid child element"):
render(element)

0 comments on commit 66ab3b7

Please sign in to comment.