diff --git a/docs/source/conf.py b/docs/source/conf.py
index 08590bdf..476ee9c1 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -39,7 +39,7 @@
project = "Allo"
author = "Allo Authors"
-copyright = "2024, Allo Authors"
+copyright = "2025, Allo Authors"
# The full version, including alpha/beta/rc tags
release = "0.5"
diff --git a/docs/source/dive/ip.rst b/docs/source/dive/ip.rst
new file mode 100644
index 00000000..6d4fab85
--- /dev/null
+++ b/docs/source/dive/ip.rst
@@ -0,0 +1,88 @@
+.. Copyright Allo authors. All Rights Reserved.
+ SPDX-License-Identifier: Apache-2.0
+
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+##############
+IP Integration
+##############
+
+Apart from directly writing Allo kernels in Python, we also support integrating existing C++ HLS kernels into Allo. This feature is useful when you already have optimized C++ HLS code that you want to integrate into Allo. The following example shows how to integrate a simple vector addition kernel written in C++ into Allo.
+
+Suppose the C++ kernel header is defined in the ``vadd.h`` file:
+
+.. code-block:: cpp
+
+ #ifndef VADD_H
+ #define VADD_H
+
+ void vadd(int A[32], int B[32], int C[32]);
+
+ #endif // VADD_H
+
+And the corresponding implementation is defined in the ``vadd.cpp`` file:
+
+.. code-block:: cpp
+
+ #include "vadd.h"
+ using namespace std;
+
+ void vadd(int A[32], int B[32], int C[32]) {
+ for (int i = 0; i < 32; ++i) {
+ C[i] = A[i] + B[i];
+ }
+ }
+
+In Allo, we can create an *IP module* to wrap the C++ kernel. Basically, we need to provide the top-level function name, the header files, and the implementation files. Currently, an Allo signature is also required to specify the input and output types of the kernel. Allo will automatically compile the C++ kernel and generate the corresponding Python wrapper based on the provided files and signature. The last argument ``link_hls`` determines whether the C++ compiler should link the Vitis HLS libraries (e.g., ``ap_int``), which is only available if Vitis HLS is installed on your machine.
+
+.. code-block:: python
+
+ vadd = allo.IPModule(
+ top="vadd",
+ headers=["vadd.h"],
+ impls=["vadd.cpp"],
+ signature=["int32[32]", "int32[32]", "int32[32]"],
+ link_hls=False,
+ )
+
+After creating the IP module, we can use it in Allo as a normal Python function. For example, we can directly call the ``vadd`` function to perform vector addition. The inputs and outputs are automatically wrapped and unwrapped as NumPy arrays, which greatly simplifies the burden of managing the complex C-Python interface. This is also very useful when you want to debug the HLS kernels with Python data.
+
+.. code-block:: python
+
+ np_A = np.random.randint(0, 100, (32,)).astype(np.int32)
+ np_B = np.random.randint(0, 100, (32,)).astype(np.int32)
+ np_C = np.zeros((32,), dtype=np.int32)
+ vadd(np_A, np_B, np_C)
+ np.testing.assert_allclose(np_A + np_B, np_C, atol=1e-6)
+
+Moreover, the IP module can also be called inside a normal Allo kernel. In the following example, we wrap the ``vadd`` function into an Allo ``kernel`` and use it to perform vector addition. The Allo kernel can then be further customized and compiled together with the external C++ HLS kernel.
+
+.. code-block:: python
+
+ def kernel(A: int32[32], B: int32[32]) -> int32[32]:
+ C: int32[32] = 0
+ vadd(A, B, C)
+ return C
+
+ s = allo.customize(kernel)
+ print(s.module)
+ mod = s.build()
+ np_A = np.random.randint(0, 100, (32,)).astype(np.int32)
+ np_B = np.random.randint(0, 100, (32,)).astype(np.int32)
+ allo_C = mod(np_A, np_B)
+ np.testing.assert_allclose(np_A + np_B, allo_C, atol=1e-6)
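+
+If the C++ kernel uses Vitis HLS data types such as ``ap_int``, set ``link_hls=True`` so that the Vitis HLS libraries are linked during compilation. The following is a minimal sketch; the ``scaled_add`` kernel and its files are hypothetical and only illustrate the flag:
+
+.. code-block:: python
+
+    # scaled_add.h/.cpp are hypothetical files whose implementation uses ap_int
+    scaled_add = allo.IPModule(
+        top="scaled_add",
+        headers=["scaled_add.h"],
+        impls=["scaled_add.cpp"],
+        signature=["int32[32]", "int32[32]", "int32[32]"],
+        link_hls=True,  # requires Vitis HLS to be installed on your machine
+    )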
diff --git a/docs/source/dive/pytorch.rst b/docs/source/dive/pytorch.rst
new file mode 100644
index 00000000..9076dd06
--- /dev/null
+++ b/docs/source/dive/pytorch.rst
@@ -0,0 +1,70 @@
+.. Copyright Allo authors. All Rights Reserved.
+ SPDX-License-Identifier: Apache-2.0
+
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+###################
+PyTorch Integration
+###################
+
+In this document, we will show how to directly compile PyTorch models to Allo.
+First, users can define a PyTorch module as usual:
+
+.. code-block:: python
+
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+
+ class Model(nn.Module):
+ def __init__(self):
+ super(Model, self).__init__()
+
+ def forward(self, x, y):
+ x = x + y
+ x = F.relu(x)
+ return x
+
+ model = Model()
+ model.eval()
+
+Then, users can compile the PyTorch model to Allo by using the ``allo.frontend.from_pytorch`` API:
+
+.. code-block:: python
+
+ import allo
+ example_inputs = [torch.rand(1, 3, 10, 10), torch.rand(1, 3, 10, 10)]
+ llvm_mod = allo.frontend.from_pytorch(model, example_inputs=example_inputs)
+
+After compilation, we can invoke the generated Allo LLVM module as usual by passing in the NumPy inputs:
+
+.. code-block:: python
+
+ golden = model(*example_inputs)
+ np_inputs = [x.detach().numpy() for x in example_inputs]
+ res = llvm_mod(*np_inputs)
+ torch.testing.assert_close(res, golden.detach().numpy())
+ print("Passed!")
+
+The process should be very similar to the original Allo workflow.
+The default target is LLVM. We can also change the backend to other compilers such as Vitis HLS by specifying the ``target``:
+
+.. code-block:: python
+
+ mod = allo.frontend.from_pytorch(model, example_inputs=example_inputs, target="vhls")
+ print(mod.hls_code)
diff --git a/docs/source/index.rst b/docs/source/index.rst
index abbfa8d1..289577db 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -40,6 +40,17 @@ Allo is an Accelerator Design Language (ADL) and compiler that facilitates the c
gallery/tutorial_02_vhls.rst
+.. toctree::
+ :maxdepth: 1
+ :caption: Deep Dive
+
+ gallery/dive_01_data_types.rst
+ gallery/dive_02_template.rst
+ gallery/dive_03_composition.rst
+ dive/ip.rst
+ dive/pytorch.rst
+ gallery/dive_04_features.rst
+
.. toctree::
:maxdepth: 1
:caption: Developer Guide
diff --git a/tutorials/dive_01_data_types.py b/tutorials/dive_01_data_types.py
new file mode 100644
index 00000000..180bf685
--- /dev/null
+++ b/tutorials/dive_01_data_types.py
@@ -0,0 +1,114 @@
+# Copyright Allo authors. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Data Types and Type Casting
+===========================
+
+**Author**: Hongzheng Chen (hzchen@cs.cornell.edu)
+
+This document will discuss the Allo-supported data types in detail.
+All the data types are defined in the ``allo.ir.types`` module.
+"""
+
+import allo
+from allo.ir.types import int16, int32, float32, Int, UInt, Float, Fixed
+
+##############################################################################
+# Currently, Allo supports three base data types for mathematical operations:
+#
+# - Integers: ``Int(bitwidth)``, ``UInt(bitwidth)``
+# - Floating points: ``Float(bitwidth)`` (only 16, 32, and 64 bits are supported)
+# - Fixed points: ``Fixed(bitwidth, frac)``, ``UFixed(bitwidth, frac)``
+#
+# For example, one can declare a 15-bit integer as ``Int(15)`` and an unsigned 8-bit fixed-point number with 3 fractional bits as ``UFixed(8, 3)``.
+# For all the C/C++ supported data types, we provide shorthands like ``float32`` and ``int16`` to easily declare them.
+
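+# %%
+# As a quick illustration, these type constructors can be instantiated directly.
+# A minimal sketch using the constructors imported above:
+
+t_int15 = Int(15)  # a 15-bit signed integer type
+t_fixed = Fixed(8, 3)  # an 8-bit signed fixed-point type with 3 fractional bits
+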
+# %%
+# Notice that, unlike native Python, Allo requires the program to be **strongly and statically typed**.
+# The variable types are either declared explicitly or inferred from the context.
+# For a variable that first appears in the program, we should declare it with an expected data type using Python's type hint notation:
+
+a: int32
+
+# %%
+# Once the data types are defined, an important consideration is how to handle
+# operations between variables of different types. Allo supports two types of casting:
+# (1) implicit casting that is automatically done by the Allo compiler;
+# and (2) explicit casting that is manually done by the user.
+
+##############################################################################
+# Implicit Casting
+# ----------------
+# Allo has a strong type system that follows the MLIR convention to enforce that the operands of an arithmetic operation have the same type.
+# However, it is burdensome for users to cast the variables every time, and manually avoiding overflow during computation is error-prone.
+# Therefore, Allo is equipped with built-in casting rules that automatically cast the operands to a common type before the operation, which is called *implicit casting*.
+# An example is shown below:
+
+
+def add(a: int32, b: int32) -> int32:
+ return a + b
+
+
+s = allo.customize(add)
+print(s.module)
+
+# %%
+# We can see that ``a`` and ``b`` are first cast to ``int33``, added
+# together, and then converted back to ``int32``.
+# This is to avoid overflow and is automatically inferred by the Allo compiler.
+
+
+##############################################################################
+# Explicit Casting
+# ----------------
+# One can also explicitly cast the variable to a specific type by creating an intermediate variable,
+# or use Python's built-in functions like ``float()`` and ``int()`` to explicitly cast a variable to ``float32`` or ``int32``.
+# Another example is shown below:
+
+
+def cast(a: int32) -> int16:
+ b: float32 = a # explicit
+ c: float32 = b * 2
+ d: float32 = float(a) * 2
+ e: int16 = c + d
+ return e
+
+
+s = allo.customize(cast)
+print(s.module)
+
+# %%
+# By explicitly creating an intermediate variable ``b``, we can cast the ``int32`` variable ``a`` to the desired floating-point type.
+# Similarly, calling ``float(a)`` can also cast ``a`` to a floating-point type.
+#
+# .. note::
+#
+#    The explicit casting between integers and floating points described above preserves the value, but the precision may change.
+#    If you want a union type that reinterprets the same bits as either an integer or a floating point, use the ``.bitcast()`` API instead. For example, ``a.bitcast()`` can reinterpret an ``int32`` as a ``float32`` with the bit pattern preserved.
+
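+# %%
+# A minimal sketch of ``.bitcast()`` usage (hedged; assuming ``a`` is an ``int32``
+# variable inside a kernel, with the target type inferred from the annotation):
+#
+# .. code-block:: python
+#
+#     b: float32 = a.bitcast()  # reinterpret the 32 bits of ``a`` as float32
+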
+##############################################################################
+# Bit Operations
+# --------------
+# As hardware accelerators have the ability to manipulate each bit of the data, Allo supports bit operations on
+# integer types. For example, we can access a specific bit of an integer ``a`` using the indexing operator:
+#
+# .. code-block:: python
+#
+# a[15]
+
+# %%
+# We can also extract a chunk of bits from an integer using the slicing operator:
+#
+# .. code-block:: python
+#
+# a[0:16]
+#
+# .. note::
+#
+# Allo follows the Python convention that the upper bound is not included, so ``[0:16]`` means
+# extracting the first 16 bits, which is different from the Xilinx HLS convention that uses ``[0:15]``
+# to indicate the first 16 bits.
+
+# %%
+# Not only constants but also variables can be used as the bit index or the slice range.
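+#
+# For example, a minimal sketch (hedged; assuming ``a`` and ``i`` are integer variables inside a kernel):
+#
+# .. code-block:: python
+#
+#     b = a[i]  # a variable bit index
+#     c = a[i : i + 8]  # slice bounds may also involve variables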
diff --git a/tutorials/dive_02_template.py b/tutorials/dive_02_template.py
new file mode 100644
index 00000000..d0868c4c
--- /dev/null
+++ b/tutorials/dive_02_template.py
@@ -0,0 +1,82 @@
+# Copyright Allo authors. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Template Kernels
+================
+
+**Author**: Hongzheng Chen (hzchen@cs.cornell.edu)
+
+This document explains how to write a template kernel in Allo.
+Template kernels are useful when we need to reuse a kernel with different data types or when certain computation patterns depend on specific constants.
+By leveraging template kernels, we can achieve greater flexibility and reusability in the code.
+"""
+
+import allo
+from allo.ir.types import int32, float32
+
+# %%
+# We follow Python's convention and use a *type variable* to define a template kernel.
+# Specifically, the type variable is specified after the function name using square brackets: ``def kernel[T](...)``, and the type variable can be used in the function signature and body.
+# Importantly, as the native Python interpreter does not support Allo's type declaration (i.e., base type + shape), we need to use string annotations like ``"T[10]"`` to specify the type of the variables.
+# Otherwise, it will raise a type error.
+#
+# In the following, we define a simple addition function that adds 1 to each element of the input array.
+# To invoke the kernel with a specific data type, we can use the ``instantiate`` argument in the ``allo.customize`` function.
+
+
+def kernel[T](A: "T[10]") -> "T[10]":
+ B: T[10]
+ for i in range(10):
+ B[i] = A[i] + 1
+ return B
+
+
+s = allo.customize(kernel, instantiate=[int32])
+print(s.module)
+
+# %%
+# We can see that the kernel is specialized with the given ``int32`` data type.
+# Similarly, we can directly declare a new kernel by specifying ``float32`` as the data type.
+
+s = allo.customize(kernel, instantiate=[float32])
+print(s.module)
+
+# %%
+# If we want to specialize not only the data type but also the shape of the array, we can provide another type variable and pass it to the ``instantiate`` argument.
+# Note that we can use the ``: base_type`` notation to constrain a type variable; here we constrain ``M`` to be an integer.
+
+
+def kernel2[T, M: int32](A: "T[M]") -> "T[M]":
+ B: T[M]
+ for i in range(M):
+ B[i] = A[i] + 1
+ return B
+
+
+s = allo.customize(kernel2, instantiate=[int32, 20])
+print(s.module)
+
+# %%
+# Furthermore, Allo's templates also enable metaprogramming that evaluates type variables at compile time.
+# Specifically, we can use ``allo.meta_if``, ``allo.meta_elif``, and ``allo.meta_else`` to conditionally generate code based on the type variables.
+# Just make sure the conditions can be evaluated at compile time.
+
+
+def kernel3[T, M: int32](A: "T[M]") -> "T[M]":
+ B: T[M]
+ for i in range(M):
+ with allo.meta_if(T == int32):
+ B[i] = A[i] + 1
+ with allo.meta_else():
+ B[i] = A[i] - 1
+ return B
+
+
+# %%
+# In the final generated code, we can see that only a single branch is generated based on the given data type.
+
+s = allo.customize(kernel3, instantiate=[int32, 20])
+print(s.module)
+s = allo.customize(kernel3, instantiate=[float32, 20])
+print(s.module)
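+
+# %%
+# The ``allo.meta_elif`` primitive works analogously to Python's ``elif``.
+# A minimal sketch of a three-way branch (not executed here):
+#
+# .. code-block:: python
+#
+#     with allo.meta_if(T == int32):
+#         B[i] = A[i] + 1
+#     with allo.meta_elif(T == float32):
+#         B[i] = A[i] - 1
+#     with allo.meta_else():
+#         B[i] = A[i] * 2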
diff --git a/tutorials/dive_03_composition.py b/tutorials/dive_03_composition.py
new file mode 100644
index 00000000..ef3fc175
--- /dev/null
+++ b/tutorials/dive_03_composition.py
@@ -0,0 +1,120 @@
+# Copyright Allo authors. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Kernel Composition
+==================
+
+**Author**: Hongzheng Chen (hzchen@cs.cornell.edu)
+
+This document will discuss kernel composition.
+In the previous tutorials, we have seen how to write a simple kernel.
+However, in real applications, we often need to compose multiple kernels together.
+
+In the following example, we define a ``matrix_add`` and a ``gemm`` kernel, and wrap them into a ``top``-level function.
+"""
+
+import allo
+from allo.ir.types import int32, float32
+
+M, K, N = 32, 32, 32
+
+
+def matrix_add(A: int32[M, N]) -> int32[M, N]:
+ B: int32[M, N] = 0
+ for i, j in allo.grid(M, N):
+ B[i, j] = A[i, j] + 1
+ return B
+
+
+def gemm(A: int32[M, K], B: int32[K, N]) -> int32[M, N]:
+ C: int32[M, N] = 0
+ for i, j in allo.grid(M, N):
+ for k in allo.reduction(K):
+ C[i, j] += A[i, k] * B[k, j]
+ return C
+
+
+def top(A: int32[M, K], B: int32[K, N]) -> int32[M, N]:
+ C = gemm(A, B)
+ D = matrix_add(C)
+ return D
+
+
+# %%
+# Different teams or people can then work on different parts of the code and optimize each kernel.
+# We first create a schedule for the ``matrix_add`` kernel, and add several optimizations.
+
+s1 = allo.customize(matrix_add)
+s1.pipeline("j")
+print(s1.module)
+
+# %%
+# Then we create a schedule for the ``gemm`` kernel and optimize it.
+
+s2 = allo.customize(gemm)
+s2.reorder("k", "j")
+s2.buffer_at(s2.C, axis="i")
+s2.pipeline("j")
+print(s2.module)
+
+# %%
+# Notice that we have so far only optimized the individual kernels without incorporating them into the top-level function, as shown in the following printed module.
+
+s = allo.customize(top)
+print(s.module)
+
+# %%
+# Therefore, after each part has been optimized, we need to explicitly *compose* them together.
+# In Allo, we can use the ``.compose()`` primitive to compose the schedules together into the parent function.
+
+s.compose([s1, s2])
+print(s.module)
+
+# %%
+# We can see that the schedules for the ``matrix_add`` and ``gemm`` kernels are both correctly optimized in the top-level function.
+
+##############################################################################
+# Template Composition
+# --------------------
+# Sometimes we may define a template kernel and invoke it with different template arguments. Allo provides an *id* option to specify the exact kernel instance to be composed.
+
+
+def kernel[T_in, T_out, S](A: "T_in[S]") -> "T_out[S]":
+ B: T_out[S] = 0
+ for i in range(S):
+ with allo.meta_if(T_out == int32):
+ B[i] = A[i] + 1
+ with allo.meta_else():
+ B[i] = A[i] * 2
+ return B
+
+
+def top2(A: int32[M]) -> float32[M]:
+ C = kernel[int32, int32, M, "K1"](A)
+ D = kernel[int32, float32, M, "K2"](C)
+ return D
+
+
+# %%
+# Specifically, the last argument of the template kernel is the *id* of the kernel instance. Later on, we can use this ID to distinguish the two instances during composition.
+# We first customize the two template kernels with different optimizations.
+
+s1 = allo.customize(kernel, instantiate=[int32, int32, M])
+s1.unroll("i", factor=4)
+print(s1.module)
+
+s2 = allo.customize(kernel, instantiate=[int32, float32, M])
+s2.pipeline("i")
+print(s2.module)
+
+# %%
+# Finally, we compose the two template kernels into the top-level function with the ID specified.
+
+s = allo.customize(top2)
+s.compose(s1, id="K1")
+s.compose(s2, id="K2")
+print(s.module)
+
+# %%
+# We can see from the printed module that the loop in the first kernel is unrolled by a factor of 4, and the loop in the second kernel is pipelined.
diff --git a/tutorials/dive_04_features.py b/tutorials/dive_04_features.py
new file mode 100644
index 00000000..e087285f
--- /dev/null
+++ b/tutorials/dive_04_features.py
@@ -0,0 +1,76 @@
+# Copyright Allo authors. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Other Features
+==============
+
+**Author**: Hongzheng Chen (hzchen@cs.cornell.edu)
+
+This document will discuss other features that are not covered in the previous tutorials.
+"""
+
+##############################################################################
+# Dynamic Shapes
+# --------------
+# In some cases, the shape of a tensor is not known at compile time, so we can use ``[...]`` to represent a dynamic shape.
+# In the generated MLIR module, the tensor shape contains a ``?``, which means the shape is not predefined,
+# but we can still run the LLVM module with NumPy arrays of arbitrary shapes.
+
+import allo
+from allo.ir.types import int32, float32
+import numpy as np
+
+
+def kernel(A: float32[...], B: float32[...], size: int32):
+ for i in range(size):
+ B[i] = A[i]
+
+
+s = allo.customize(kernel)
+print(s.module)
+np_A = np.random.random((256,)).astype(np.float32)
+allo_A = np.zeros((256,)).astype(np.float32)
+mod = s.build()
+mod(np_A, allo_A, 256)
+np.testing.assert_allclose(np_A, allo_A)
+
+# %%
+# We can also check in the generated HLS code that the arguments are declared as pointers.
+
+code = s.build(target="vhls")
+print(code)
+
+##############################################################################
+# Tuple Return
+# ------------
+# Another feature is tuple support. As in Python, we can return multiple values from a function; Allo
+# supports this by explicitly specifying the return type as a tuple.
+
+
+def callee(a: float32, b: float32) -> (float32, float32):
+ c: float32 = a + b
+ d: float32 = a - b
+ return c, d
+
+
+def kernel(A: float32[10], B: float32[10]) -> (float32[10], float32[10]):
+ C: float32[10] = 0
+ D: float32[10] = 0
+ for i in range(10):
+ C[i], D[i] = callee(A[i], B[i])
+ return C, D
+
+
+s = allo.customize(kernel)
+print(s.module)
+mod = s.build()
+np_A = np.random.random((10,)).astype(np.float32)
+np_B = np.random.random((10,)).astype(np.float32)
+np_C, np_D = mod(np_A, np_B)
+np_C_ref = np.zeros((10,), dtype=np.float32)
+np_D_ref = np.zeros((10,), dtype=np.float32)
+for i in range(10):
+ np_C_ref[i], np_D_ref[i] = callee(np_A[i], np_B[i])
+np.testing.assert_allclose(np_C, np_C_ref)
+np.testing.assert_allclose(np_D, np_D_ref)
diff --git a/tutorials/tutorial_01_get_started.py b/tutorials/tutorial_01_get_started.py
index 815d271d..253d0411 100644
--- a/tutorials/tutorial_01_get_started.py
+++ b/tutorials/tutorial_01_get_started.py
@@ -34,8 +34,10 @@
# %%
# We then define a function that takes two 32x32 matrices as inputs and
# returns a 32x32 matrix as output. The variable declaration is defined
-# as ``<name>: <type>[<shape>]``. We require **strict type annotation** in
-# Allo's kernels, which is different from directly programming in Python.
+# as ``<name>: <type>[<shape>]``, and the function type is defined as
+# ``(<input_type0>, <input_type1>, ...) -> <output_type>``.
+# We require **strict type annotation** in Allo's kernels, which is different
+# from directly programming in Python.
#
# Inside the kernel, we provide a shorthand for the loop iterator. For example,
# ``for i, j, k in allo.grid(32, 32, 32)`` is equivalent to the following