Skip to content

Commit

Permalink
Improve typing on Windows (#2803)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreasBackx authored Dec 22, 2024
1 parent 4ffa1ef commit 45f89cd
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 30 deletions.
4 changes: 2 additions & 2 deletions src/click/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,14 @@ def auto_wrap_for_ansi(stream: t.TextIO, color: bool | None = None) -> t.TextIO:
rv = t.cast(t.TextIO, ansi_wrapper.stream)
_write = rv.write

def _safe_write(s):
def _safe_write(s: str) -> int:
try:
return _write(s)
except BaseException:
ansi_wrapper.reset_all()
raise

rv.write = _safe_write
rv.write = _safe_write # type: ignore[method-assign]

try:
_ansi_stream_wrappers[stream] = rv
Expand Down
7 changes: 3 additions & 4 deletions src/click/_termui_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def _translate_ch_to_exc(ch: str) -> None:
return None


if WIN:
if sys.platform == "win32":
import msvcrt

@contextlib.contextmanager
Expand Down Expand Up @@ -711,12 +711,11 @@ def getchar(echo: bool) -> str:
#
# Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
# is doing the right thing in more situations than with `getch`.
func: t.Callable[[], str]

if echo:
func = msvcrt.getwche # type: ignore
func = t.cast(t.Callable[[], str], msvcrt.getwche)
else:
func = msvcrt.getwch # type: ignore
func = t.cast(t.Callable[[], str], msvcrt.getwch)

rv = func()

Expand Down
55 changes: 34 additions & 21 deletions src/click/_winconsole.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sys
import time
import typing as t
from ctypes import Array
from ctypes import byref
from ctypes import c_char
from ctypes import c_char_p
Expand Down Expand Up @@ -67,6 +68,14 @@
EOF = b"\x1a"
MAX_BYTES_WRITTEN = 32767

if t.TYPE_CHECKING:
try:
# Using `typing_extensions.Buffer` instead of `collections.abc`
# on Windows for some reason does not have `Sized` implemented.
from collections.abc import Buffer # type: ignore
except ImportError:
from typing_extensions import Buffer

try:
from ctypes import pythonapi
except ImportError:
Expand All @@ -93,32 +102,32 @@ class Py_buffer(Structure):
PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
PyBuffer_Release = pythonapi.PyBuffer_Release

def get_buffer(obj, writable=False):
def get_buffer(obj: Buffer, writable: bool = False) -> Array[c_char]:
buf = Py_buffer()
flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
flags: int = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
PyObject_GetBuffer(py_object(obj), byref(buf), flags)

try:
buffer_type = c_char * buf.len
buffer_type: Array[c_char] = c_char * buf.len
return buffer_type.from_address(buf.buf)
finally:
PyBuffer_Release(byref(buf))


class _WindowsConsoleRawIOBase(io.RawIOBase):
def __init__(self, handle):
def __init__(self, handle: int | None) -> None:
self.handle = handle

def isatty(self):
def isatty(self) -> t.Literal[True]:
super().isatty()
return True


class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
def readable(self):
def readable(self) -> t.Literal[True]:
return True

def readinto(self, b):
def readinto(self, b: Buffer) -> int:
bytes_to_be_read = len(b)
if not bytes_to_be_read:
return 0
Expand Down Expand Up @@ -150,18 +159,18 @@ def readinto(self, b):


class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
def writable(self):
def writable(self) -> t.Literal[True]:
return True

@staticmethod
def _get_error_message(errno):
def _get_error_message(errno: int) -> str:
if errno == ERROR_SUCCESS:
return "ERROR_SUCCESS"
elif errno == ERROR_NOT_ENOUGH_MEMORY:
return "ERROR_NOT_ENOUGH_MEMORY"
return f"Windows error {errno}"

def write(self, b):
def write(self, b: Buffer) -> int:
bytes_to_be_written = len(b)
buf = get_buffer(b)
code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
Expand Down Expand Up @@ -209,7 +218,7 @@ def __getattr__(self, name: str) -> t.Any:
def isatty(self) -> bool:
return self.buffer.isatty()

def __repr__(self):
def __repr__(self) -> str:
return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"


Expand Down Expand Up @@ -267,16 +276,20 @@ def _get_windows_console_stream(
f: t.TextIO, encoding: str | None, errors: str | None
) -> t.TextIO | None:
if (
get_buffer is not None
and encoding in {"utf-16-le", None}
and errors in {"strict", None}
and _is_console(f)
get_buffer is None
or encoding not in {"utf-16-le", None}
or errors not in {"strict", None}
or not _is_console(f)
):
func = _stream_factories.get(f.fileno())
if func is not None:
b = getattr(f, "buffer", None)
return None

func = _stream_factories.get(f.fileno())
if func is None:
return None

b = getattr(f, "buffer", None)

if b is None:
return None
if b is None:
return None

return func(b)
return func(b)
12 changes: 9 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ commands = pre-commit run --all-files
[testenv:typing]
deps = -r requirements/typing.txt
commands =
mypy
pyright tests/typing
pyright --verifytypes click --ignoreexternal
mypy --platform linux
mypy --platform darwin
mypy --platform win32
pyright tests/typing --pythonplatform Linux
pyright tests/typing --pythonplatform Darwin
pyright tests/typing --pythonplatform Windows
pyright --verifytypes click --ignoreexternal --pythonplatform Linux
pyright --verifytypes click --ignoreexternal --pythonplatform Darwin
pyright --verifytypes click --ignoreexternal --pythonplatform Windows

[testenv:docs]
deps = -r requirements/docs.txt
Expand Down

0 comments on commit 45f89cd

Please sign in to comment.