diff --git a/docs/changelog.md b/docs/changelog.md
index 767566d..cfc362e 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,5 +1,9 @@
# Changelog
+## next
+- Added support for passing data between components via Context. See the [Usage
+docs](usage.md#passing-data-with-context) for more information. [PR #48](https://github.com/pelme/htpy/pull/48).
+
## 24.8.1 - 2024-08-16
- Added the `comment()` function to render HTML comments.
[Documentation](usage.md#html-comments) / [Issue
diff --git a/docs/usage.md b/docs/usage.md
index d75ca7e..f38d6fe 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -393,3 +393,78 @@ got a chunk: '
'
got a chunk: 'b'
got a chunk: ''
```
+
+## Passing Data with Context
+
+Usually, you pass data via regular function calls and arguments via your
+components. Contexts can be used to avoid having to pass the data manually
+between components. Contexts in htpy is conceptually similar to contexts in
+React.
+
+Using contexts in htpy involves:
+
+- Creating a context object with `my_context = Context(name[, *, default])` to
+define the type and optional default value of a context variable.
+- Using `my_context.provider(value, lambda: children)` to set the value of a context variable for a subtree.
+- Adding the `@my_context.consumer` decorator to a component that requires the
+context value. The decorator will add the context value as the first argument to the decorated function.
+
+A context value can be passed arbitrarily deep between components. It is
+possible to nest multiple context provider and different values can be used in
+different subtrees.
+
+The `Context` class is a generic and fully supports static type checking.
+
+The values are passed as part of the tree used to render components without
+using global state. It is safe to use contexts for lazy constructs such as
+callables and generators.
+
+### Example
+
+This example shows how context can be used to pass data between components:
+
+- `theme_context: Context[Theme] = Context("theme", default="light")` creates a
+context object that can later be used to define/retrieve the value. In this
+case, `"light"` acts as the default value if no other value is provided.
+- `theme_context.provider(value, lambda: subtree)` defines the value of the
+`theme_context` for the subtree. In this case the value is set to `"dark"` which
+overrides the default value.
+- The `sidebar` component uses the `@theme_context.consumer` decorator. This
+will make htpy pass the current context value as the first argument to the
+component function.
+- In this example, a `Theme` type is used to ensure that the correct types are
+used when providing the value as well as when it is consumed.
+
+```py
+from typing import Literal
+
+from htpy import Context, Node, div, h1
+
+Theme = Literal["light", "dark"]
+
+theme_context: Context[Theme] = Context("theme", default="light")
+
+
+def my_page() -> Node:
+ return theme_context.provider(
+ "dark",
+ lambda: div[
+ h1["Hello!"],
+ sidebar("The Sidebar!"),
+ ],
+ )
+
+
+@theme_context.consumer
+def sidebar(theme: Theme, title: str) -> Node:
+ return div(class_=f"theme-{theme}")[title]
+
+
+print(my_page())
+```
+
+Output:
+
+```html
+
+```
diff --git a/examples/context.py b/examples/context.py
new file mode 100644
index 0000000..a593daa
--- /dev/null
+++ b/examples/context.py
@@ -0,0 +1,25 @@
+from typing import Literal
+
+from htpy import Context, Node, div, h1
+
+Theme = Literal["light", "dark"]
+
+theme_context: Context[Theme] = Context("theme", default="light")
+
+
+def my_page() -> Node:
+ return theme_context.provider(
+ "dark",
+ lambda: div[
+ h1["Hello!"],
+ sidebar("The Sidebar!"),
+ ],
+ )
+
+
+@theme_context.consumer
+def sidebar(theme: Theme, title: str) -> Node:
+ return div(class_=f"theme-{theme}")[title]
+
+
+print(my_page())
diff --git a/htpy/__init__.py b/htpy/__init__.py
index 4001d92..3a2b49c 100644
--- a/htpy/__init__.py
+++ b/htpy/__init__.py
@@ -3,6 +3,7 @@
__version__ = "24.8.1"
__all__: list[str] = []
+import dataclasses
import functools
import typing as t
from collections.abc import Callable, Iterable, Iterator
@@ -109,7 +110,58 @@ def _attrs_string(attrs: dict[str, Attribute]) -> str:
return " " + result
+T = t.TypeVar("T")
+P = t.ParamSpec("P")
+
+
+@dataclasses.dataclass(frozen=True)
+class ContextProvider(t.Generic[T]):
+ context: Context[T]
+ value: T
+ func: Callable[[], Node]
+
+ def __iter__(self) -> Iterator[str]:
+ return iter_node(self)
+
+ def __str__(self) -> str:
+ return render_node(self)
+
+
+@dataclasses.dataclass(frozen=True)
+class ContextConsumer(t.Generic[T]):
+ context: Context[T]
+ debug_name: str
+ func: Callable[[T], Node]
+
+
+class _NO_DEFAULT:
+ pass
+
+
+class Context(t.Generic[T]):
+ def __init__(self, name: str, *, default: T | type[_NO_DEFAULT] = _NO_DEFAULT) -> None:
+ self.name = name
+ self.default = default
+
+ def provider(self, value: T, children_func: Callable[[], Node]) -> ContextProvider[T]:
+ return ContextProvider(self, value, children_func)
+
+ def consumer(
+ self,
+ func: Callable[t.Concatenate[T, P], Node],
+ ) -> Callable[P, ContextConsumer[T]]:
+ @functools.wraps(func)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextConsumer[T]:
+ return ContextConsumer(self, func.__name__, lambda value: func(value, *args, **kwargs))
+
+ return wrapper
+
+
def iter_node(x: Node) -> Iterator[str]:
+ return _iter_node_context(x, {})
+
+
+def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> Iterator[str]:
while not isinstance(x, BaseElement) and callable(x):
x = x()
@@ -123,12 +175,22 @@ def iter_node(x: Node) -> Iterator[str]:
return
if isinstance(x, BaseElement):
- yield from x
+ yield from x._iter_context(context_dict) # pyright: ignore [reportPrivateUsage]
+ elif isinstance(x, ContextProvider):
+ yield from _iter_node_context(x.func(), {**context_dict, x.context: x.value}) # pyright: ignore [reportUnknownMemberType]
+ elif isinstance(x, ContextConsumer):
+ context_value = context_dict.get(x.context, x.context.default)
+ if context_value is _NO_DEFAULT:
+ raise LookupError(
+ f'Context value for "{x.context.name}" does not exist, '
+ f"requested by {x.debug_name}()."
+ )
+ yield from _iter_node_context(x.func(context_value), context_dict)
elif isinstance(x, str | _HasHtml):
yield str(_escape(x))
elif isinstance(x, Iterable): # pyright: ignore [reportUnnecessaryIsInstance]
for child in x:
- yield from iter_node(child)
+ yield from _iter_node_context(child, context_dict)
else:
raise ValueError(f"{x!r} is not a valid child element")
@@ -201,8 +263,11 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen
)
def __iter__(self) -> Iterator[str]:
+ return self._iter_context({})
+
+ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield f"<{self._name}{_attrs_string(self._attrs)}>"
- yield from iter_node(self._children)
+ yield from _iter_node_context(self._children, ctx)
yield f"{self._name}>"
def __repr__(self) -> str:
@@ -221,13 +286,13 @@ def __getitem__(self: ElementSelf, children: Node) -> ElementSelf:
class HTMLElement(Element):
- def __iter__(self) -> Iterator[str]:
+ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield ""
- yield from super().__iter__()
+ yield from super()._iter_context(ctx)
class VoidElement(BaseElement):
- def __iter__(self) -> Iterator[str]:
+ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield f"<{self._name}{_attrs_string(self._attrs)}>"
@@ -248,7 +313,15 @@ def __html__(self) -> str: ...
_ClassNamesDict: t.TypeAlias = dict[str, bool]
_ClassNames: t.TypeAlias = Iterable[str | None | bool | _ClassNamesDict] | _ClassNamesDict
Node: t.TypeAlias = (
- None | bool | str | BaseElement | _HasHtml | Iterable["Node"] | Callable[[], "Node"]
+ None
+ | bool
+ | str
+ | BaseElement
+ | _HasHtml
+ | Iterable["Node"]
+ | Callable[[], "Node"]
+ | ContextProvider[t.Any]
+ | ContextConsumer[t.Any]
)
Attribute: t.TypeAlias = None | bool | str | _HasHtml | _ClassNames
diff --git a/tests/test_context.py b/tests/test_context.py
new file mode 100644
index 0000000..77a88a2
--- /dev/null
+++ b/tests/test_context.py
@@ -0,0 +1,90 @@
+import typing as t
+
+import pytest
+
+from htpy import Context, Node, div
+
+letter_ctx: Context[t.Literal["a", "b", "c"]] = Context("letter", default="a")
+no_default_ctx = Context[str]("no_default")
+
+
+@letter_ctx.consumer
+def display_letter(letter: t.Literal["a", "b", "c"], greeting: str) -> str:
+ return f"{greeting}: {letter}!"
+
+
+@no_default_ctx.consumer
+def display_no_default(value: str) -> str:
+ return f"{value=}"
+
+
+def test_context_default() -> None:
+ result = div[display_letter("Yo")]
+ assert str(result) == "Yo: a!
"
+
+
+def test_context_provider() -> None:
+ result = letter_ctx.provider("c", lambda: div[display_letter("Hello")])
+ assert str(result) == "Hello: c!
"
+
+
+def test_no_default() -> None:
+ with pytest.raises(
+ LookupError,
+ match='Context value for "no_default" does not exist, requested by display_no_default()',
+ ):
+ str(div[display_no_default()])
+
+
+def test_nested_override() -> None:
+ result = div[
+ letter_ctx.provider(
+ "b",
+ lambda: letter_ctx.provider(
+ "c",
+ lambda: display_letter("Nested"),
+ ),
+ )
+ ]
+ assert str(result) == "Nested: c!
"
+
+
+def test_multiple_consumers() -> None:
+ a_ctx: Context[t.Literal["a"]] = Context("a_ctx", default="a")
+ b_ctx: Context[t.Literal["b"]] = Context("b_ctx", default="b")
+
+ @b_ctx.consumer
+ @a_ctx.consumer
+ def ab_display(a: t.Literal["a"], b: t.Literal["b"], greeting: str) -> str:
+ return f"{greeting} a={a}, b={b}"
+
+ result = div[ab_display("Hello")]
+ assert str(result) == "Hello a=a, b=b
"
+
+
+def test_nested_consumer() -> None:
+ ctx: Context[str] = Context("ctx")
+
+ @ctx.consumer
+ def outer(value: str) -> Node:
+ return inner(value)
+
+ @ctx.consumer
+ def inner(value: str, from_outer: str) -> Node:
+ return f"outer: {from_outer}, inner: {value}"
+
+ result = div[ctx.provider("foo", outer)]
+
+ assert str(result) == "outer: foo, inner: foo
"
+
+
+def test_context_passed_via_iterable() -> None:
+ ctx: Context[str] = Context("ctx")
+
+ @ctx.consumer
+ def echo(value: str) -> str:
+ return value
+
+ result = div[ctx.provider("foo", lambda: [echo()])]
+
+ assert str(result) == "foo
"