Skip to content

Commit

Permalink
Extract the Python APIs in the pt1 dir back to the root (llvm#3237)
Browse files Browse the repository at this point in the history
  • Loading branch information
penguin-wwy authored and archana-ramalingam committed May 8, 2024
1 parent 8f0a679 commit e865616
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 188 deletions.
1 change: 0 additions & 1 deletion projects/pt1/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
SOURCES
torchscript.py
_dynamo_fx_importer.py
compiler_utils.py
dynamo.py
_version.py
)
Expand Down
75 changes: 0 additions & 75 deletions projects/pt1/python/torch_mlir/compiler_utils.py

This file was deleted.

105 changes: 6 additions & 99 deletions projects/pt1/python/torch_mlir/torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,65 +17,15 @@
from torch_mlir.dynamo import _get_decomposition_table
from torch.fx.experimental.proxy_tensor import make_fx

from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
OutputType,
lower_mlir_module
)
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library


class OutputType(Enum):
"""The kind of output that `torchscript.compile` can produce.
In MLIR terminology, this describes the mix of dialects that will be
produced by the conversion process.
In user-facing API's, this type can always be passed interchangeably with an
appropriate string specifying the output type. The allowed strings are
the set of enum vales, allowed to be case insensitive and with `-` allowed
in place of `_`. The `OutputType.get` static method can be used to convert
from a string to an `OutputType` instance.
"""

# This output type consists of `torch` dialect ops that have been converted
# maximally to value semantics, decomposed, and shapes have been inferred.
TORCH = "torch"

# The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
# `arith` ops (and also `math` and `tm_tensor`). It can be thought of
# as taking the `TORCH` output type and lowering it so that tensor
# computations are done with `linalg`-on-tensors ops.
LINALG_ON_TENSORS = "linalg-on-tensors"

# This output type consists of `tosa` dialect ops. It can be thought of
# as taking the `TORCH` output type and lowering it to TOSA.
TOSA = "tosa"

# This output type consists of `stablehlo` dialect ops. It can be thought of
# as taking the `TORCH` output type and lowering it to StableHLO.
STABLEHLO = "stablehlo"

# Raw output of the JIT IR importer. This is not expected to be useful
# for end-users, but can be convenient for development or reporting bugs.
RAW = "raw"

@staticmethod
def get(spec: Union[str, "OutputType"]) -> "OutputType":
"""Gets an OutputType from allowed way to specify one.
Args:
spec: An OutputType instance or the case-insensitive name of one of the
enum values.
Returns:
An OutputType instance.
"""
if isinstance(spec, OutputType):
return spec
spec = spec.upper().replace("-", "_")
if spec not in OutputType.__members__:
raise ValueError(f"For output_type= argument, expected one of: "
f"{', '.join(OutputType.__members__.keys())}")
return OutputType[spec]


class TensorPlaceholder:
"""A class that represents a formal parameter of a given shape and dtype.
Expand Down Expand Up @@ -270,49 +220,6 @@ def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra
return ""


def _lower_mlir_module(verbose, output_type, module):
if verbose:
print("\n====================")
print("Torch Backend IR")
print(module)

if output_type == OutputType.TORCH:
return module

if output_type == OutputType.TOSA:
run_pipeline_with_repro_report(
module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
"Lowering Torch Backend IR -> TOSA Backend IR")
if verbose:
print("\n====================")
print("TOSA Backend IR")
print(module)
return module

if output_type == OutputType.LINALG_ON_TENSORS:
run_pipeline_with_repro_report(
module,
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
if verbose:
print("\n====================")
print("LINALG Backend IR")
print(module)
return module

elif output_type == OutputType.STABLEHLO:
run_pipeline_with_repro_report(
module,
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
"Lowering Torch Backend IR -> StableHLO Backend IR")
if verbose:
print("\n====================")
print("StableHLO Backend IR")
print(module)
return module
raise Exception(f"Unknown OutputType: {output_type}")


def compile(model: torch.nn.Module,
example_args: _example_args,
output_type: Union[str, "OutputType"] = OutputType.TORCH,
Expand Down Expand Up @@ -464,4 +371,4 @@ def compile(model: torch.nn.Module,
enable_ir_printing=enable_ir_printing,
)

return _lower_mlir_module(verbose, output_type, mb.module)
return lower_mlir_module(verbose, output_type, mb.module)
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from torch.export import ExportedProgram

from torch_mlir import fx
from torch_mlir.torchscript import (
_example_args,
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
lower_mlir_module,
OutputType,
)
from torch_mlir.torchscript import (
BACKEND_LEGAL_OPS,
run_pipeline_with_repro_report,
_lower_mlir_module,
_canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
Expand Down Expand Up @@ -76,7 +77,7 @@ def jit(
"Lowering TorchFX IR -> Torch Backend IR",
)

return _lower_mlir_module(verbose, output_type, mlir_module)
return lower_mlir_module(verbose, output_type, mlir_module)


class FxImporterTestConfig(TestConfig):
Expand Down
10 changes: 6 additions & 4 deletions projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
set_model_name,
)

from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
lower_mlir_module,
OutputType,
)
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
from torch_mlir.dynamo import _get_decomposition_table
from torch_mlir.torchscript import (
_example_args,
OutputType,
BACKEND_LEGAL_OPS,
run_pipeline_with_repro_report,
_lower_mlir_module,
_canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
Expand Down Expand Up @@ -148,7 +150,7 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule,
"Lowering TorchFX IR -> Torch Backend IR",
)

return _lower_mlir_module(verbose, output_type, mlir_module)
return lower_mlir_module(verbose, output_type, mlir_module)


class TorchDynamoTestConfig(TestConfig):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
# Also available under a BSD-style license. See LICENSE.


from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
lower_mlir_module,
OutputType,
)
from torch_mlir.ir import *
from torch_mlir.passmanager import *
from torch_mlir.torchscript import OutputType
from torch_mlir.torchscript import _lower_mlir_module

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend

Expand Down Expand Up @@ -58,7 +60,7 @@ def compile(self, imported_module: Module):
"Lowering TorchFX IR -> Torch Backend IR",
)

imported_module = _lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
imported_module = lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
compiled_module = self.refbackend.compile(imported_module)
return compiled_module

Expand Down
1 change: 1 addition & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
compiler_utils.py
fx.py
extras/fx_decomp_util.py
)
Expand Down
Loading

0 comments on commit e865616

Please sign in to comment.