Skip to content

Commit

Permalink
Include torch-mlir-opt in Python wheels (#3964)
Browse files Browse the repository at this point in the history
This adds the `torch-mlir-opt` tool to the Python wheels, which allows
to use the commandline tool via a pip installed package instead of
having to compile the torch-mlir project yourself. The executable is
still installed to the deault location and copied over via the
`setup.py` to be included in the Python wheel. This could be refactored
and handled within CMake in a follow-up.
  • Loading branch information
marbre authored Jan 17, 2025
1 parent 33337fc commit db82bb9
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Tools
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
tools/import_onnx/__main__.py
tools/opt/__main__.py
)

declare_mlir_python_sources(TorchMLIRSiteInitialize
Expand Down Expand Up @@ -123,3 +124,5 @@ add_mlir_python_modules(TorchMLIRPythonModules
COMMON_CAPI_LINK_LIBS
TorchMLIRAggregateCAPI
)

add_dependencies(TorchMLIRPythonModules torch-mlir-opt)
40 changes: 40 additions & 0 deletions python/torch_mlir/tools/opt/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

"""Torch-MLIR modular optimizer driver
Typically, when installed from a wheel, this can be invoked as:
torch-mlir-opt [options] <input file>
To see available passes, dialects, and options, run:
torch-mlir-opt --help
"""
import os
import platform
import subprocess
import sys

from typing import Optional


def _get_builtin_tool(exe_name: str) -> Optional[str]:
if platform.system() == "Windows":
exe_name = exe_name + ".exe"
this_path = os.path.dirname(__file__)
tool_path = os.path.join(this_path, "..", "..", "_mlir_libs", exe_name)
return tool_path


def main(args=None):
if args is None:
args = sys.argv[1:]
exe = _get_builtin_tool("torch-mlir-opt")
return subprocess.call(args=[exe] + args)


if __name__ == "__main__":
sys.exit(main())
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def run(self):

shutil.copytree(python_package_dir, target_dir, symlinks=False)

torch_mlir_opt_src = os.path.join(cmake_build_dir, "bin", "torch-mlir-opt")
torch_mlir_opt_dst = os.path.join(
target_dir, "torch_mlir", "_mlir_libs", "torch-mlir-opt"
)
shutil.copy2(torch_mlir_opt_src, torch_mlir_opt_dst, follow_symlinks=False)


class CMakeExtension(Extension):
def __init__(self, name, sourcedir=""):
Expand Down Expand Up @@ -267,6 +273,7 @@ def build_extension(self, ext):
entry_points={
"console_scripts": [
"torch-mlir-import-onnx = torch_mlir.tools.import_onnx:_cli_main",
"torch-mlir-opt = torch_mlir.tools.opt.__main__:main",
],
},
zip_safe=False,
Expand Down

0 comments on commit db82bb9

Please sign in to comment.