Skip to content

Commit

Permalink
Add advanced source transformations to reduce type checking overhead
Browse files Browse the repository at this point in the history
The new 'munge' module performs transformations on the source code.
It uses the AST (abstract syntax tree) representation of Python code
to recognize some idioms such as `if STATIC_TYPING:` and transforms
them into alternatives that have zero overhead in mpy-compiled files
(e.g., `if STATIC_TYPING:` is transformed into `if 0:`, which is eliminated
at compile time due to mpy-cross constant-propagation and dead branch
elimination)

The code assumes the input file is black-formatted. In particular, it
would malfunction if an if-statement and its body are on the same line:
`if STATIC_TYPING: print("boo")` would be incorrectly munged.
  • Loading branch information
jepler committed Jun 17, 2024
1 parent 1f3c5bd commit 3002d23
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 30 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ jobs:
git clone --recurse-submodules https://github.com/adafruit/CircuitPython_Community_Bundle.git
cd CircuitPython_Community_Bundle
circuitpython-build-bundles --filename_prefix test-bundle --library_location libraries --library_depth 2
- name: Munge tests
run: pytest
- name: Build Python package
run: |
pip install --upgrade setuptools wheel twine readme_renderer testresources
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ version.py
.env/*
.DS_Store
.idea/*
testcases/*.out
47 changes: 18 additions & 29 deletions circuitpython_build_tools/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import subprocess
import tempfile

from .munge import munge

# pyproject.toml `py_modules` values that are incorrect. These should all have PRs filed!
# and should be removed when the fixed version is incorporated in its respective bundle.

Expand Down Expand Up @@ -170,16 +172,6 @@ def mpy_cross(mpy_cross_filename, circuitpython_tag, quiet=False):

shutil.copy("build_deps/circuitpython/mpy-cross/mpy-cross", mpy_cross_filename)

def _munge_to_temp(original_path, temp_file, library_version):
with open(original_path, "r", encoding="utf-8") as original_file:
for line in original_file:
line = line.strip("\n")
if line.startswith("__version__"):
line = line.replace("0.0.0-auto.0", library_version)
line = line.replace("0.0.0+auto.0", library_version)
print(line, file=temp_file)
temp_file.flush()

def get_package_info(library_path, package_folder_prefix):
lib_path = pathlib.Path(library_path)
parent_idx = len(lib_path.parts)
Expand Down Expand Up @@ -289,25 +281,22 @@ def library(library_path, output_directory, package_folder_prefix,
full_path = os.path.join(library_path, filename)
output_file = output_directory / filename.relative_to(library_path)
if filename.suffix == ".py":
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file:
temp_file_name = temp_file.name
try:
_munge_to_temp(full_path, temp_file, library_version)
temp_file.close()
if mpy_cross and os.stat(temp_file.name).st_size != 0:
output_file = output_file.with_suffix(".mpy")
mpy_success = subprocess.call([
mpy_cross,
"-o", output_file,
"-s", str(filename.relative_to(library_path)),
temp_file.name
])
if mpy_success != 0:
raise RuntimeError("mpy-cross failed on", full_path)
else:
shutil.copyfile(temp_file_name, output_file)
finally:
os.remove(temp_file_name)
content = munge(full_path, library_version)
if mpy_cross and content:
# TODO: Once 8.x bundles are no longer built, switch to
# sending mpy-cross the code on stdin instead of via
# temporary file (supports the "-" input argument)
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file:
temp_file.write(content)
temp_file.flush()
subprocess.check_output([
mpy_cross,
"-o", output_file.with_suffix(".mpy"),
"-s", str(filename.relative_to(library_path)),
temp_file.name
], input=content.encode('utf-8'))
else:
output_file.write_text(content, encoding="utf-8")
else:
shutil.copyfile(full_path, output_file)

Expand Down
117 changes: 117 additions & 0 deletions circuitpython_build_tools/munge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# The MIT License (MIT)
#
# Copyright (c) 2024 Jeff Epler for Adafruit Industries
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

# Filter program removes some code patterns introduced by type checking,
# to move towards zero overhead static typing in circuitpython libraries
#
# Recognized:
# from __future__ import ... -- eliminated
# try: import typing -- eliminated, but first except: preserved
# try: from typing import ... -- eliminated, but first except: preserved
# if STATIC_TYPING: -- transformed to 'if 0:'
# if sys.implementation_name... -- transformed to unconditional if
# __version__ = ... -- set to library version string
#
# mpy-cross does constant propagation and dead branch elimination of
# 'if 0:' and 'if 1:'
#
# Depends on the file being black-formatted!

import pathlib
import sys
import ast

VERBOSE = 0

# The canonical spelling of this test...
sys_implementation_is_circuitpython = ast.unparse(ast.parse('sys.implementation.name == "circuitpython"'))
sys_implementation_not_circuitpython = ast.unparse(ast.parse('not sys.implementation.name == "circuitpython"'))
sys_implementation_not_circuitpython2 = ast.unparse(ast.parse('sys.implementation.name != "circuitpython"'))

def munge(src: pathlib.Path|str, version_str: str) -> str:
path = pathlib.Path(src)
replacements = {}

def replace(line, new):
if VERBOSE:
replacements[line] = f"{new:<40s} ### {lines[line]}"
else:
replacements[line] = new

def blank_range(node):
for i in range(node.lineno, node.end_lineno+1):
replace(i, "")

def unblank_range(node):
for i in range(node.lineno, node.end_lineno+1):
replacements.pop(i, None)

def imports_from_typing(node):
if isinstance(node, ast.Import) and node.names[0].name == 'typing':
return True
if isinstance(node, ast.ImportFrom) and node.module == 'typing':
return True
return False

def process_statement(node):
# filter out 'from future import...'
if isinstance(node, ast.ImportFrom):
if node.module == '__future__':
blank_range(node)
# filter out 'try: import typing...'
# but preserve the first 'except:' or 'except ImportError'
elif isinstance(node, ast.Try):
b = node.body[0]
if imports_from_typing(node.body[0]):
blank_range(node)
for h in node.handlers:
if h.type is None or ast.unparse(h.type) == 'ImportError' or ast.unparse(h.type) == 'Exception':
unblank_range(h)
replace(h.lineno, 'if 1:')
break
return
elif isinstance(node, ast.If):
node_test = ast.unparse(node.test)
# return the statements in the 'if' branch of 'if sys.implementation...: ...'
if node_test == sys_implementation_is_circuitpython:
replace(node.lineno, 'if 1:')
# return the statements in the 'else' branch of 'if sys.implementation...: ...'
elif node_test == sys_implementation_not_circuitpython or node_test == sys_implementation_not_circuitpython2:
replace(node.lineno, 'if 0:')
# return the statements in the else branch of 'if TYPE_CHECKING: ...'
elif node_test == 'TYPE_CHECKING':
replace(node.lineno, 'if 0:')
elif isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name) and node.targets[0].id == '__version__':
replace(node.lineno, f"__version__ = \"{version_str}\"")

content = pathlib.Path(path).read_text(encoding="utf-8")
# Insert a blank line 0 because ast line numbers are 1-based
lines = [''] + content.rstrip().split('\n')
a = ast.parse(content, path.name)

for node in a.body: process_statement(node)

result = []
for i in range(1, len(lines)):
result.append(replacements.get(i, lines[i]))

return "\n".join(result) + "\n"
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
Click
pytest
requests
semver
wheel
tomli; python_version < "3.11"
wheel
33 changes: 33 additions & 0 deletions testcases/test1.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@




if 1:
pass



if 1:
pass




if 1:
pass



if 1:
pass

__version__ = "1.2.3"

if 1:
print("is circuitpython")

if 0:
print("not circuitpython (1)")

if 0:
print("not circuitpython (2)")
33 changes: 33 additions & 0 deletions testcases/test1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotation

try:
from typing import TYPE_CHECKING
except ImportError:
pass

try:
from typing import TYPE_CHECKING as T
except ImportError:
pass


try:
import typing
except:
pass

try:
import typing as T
except:
pass

__version__ = "0.0.0-auto"

if sys.implementation.name == "circuitpython":
print("is circuitpython")

if sys.implementation.name != "circuitpython":
print("not circuitpython (1)")

if not sys.implementation.name == "circuitpython":
print("not circuitpython (2)")
22 changes: 22 additions & 0 deletions tests/test_munge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys, pathlib
import pytest

top = pathlib.Path(__file__).parent.parent
sys.path.insert(0, str(top))

from circuitpython_build_tools.munge import munge

@pytest.mark.parametrize("test_path", top.glob("testcases/*.py"))
def test_munge(test_path):
result_path = test_path.with_suffix(".out")
result_path.unlink(missing_ok = True)

result_content = munge(test_path, "1.2.3")
result_path.write_text(result_content, encoding="utf-8")

expected_path = test_path.with_suffix(".exp")
expected_content = expected_path.read_text(encoding="utf-8")

assert result_content == expected_content

result_path.unlink()

0 comments on commit 3002d23

Please sign in to comment.