Merged
40 changes: 24 additions & 16 deletions docs/arch/pass_infra.rst
@@ -93,8 +93,9 @@ needs to be executed when running under a user-provided optimization level. The
.. code:: c++

class PassInfoNode : public Object {
ffi::String name;
int opt_level;
ffi::String name;
bool traceable;
ffi::Array<ffi::String> required;
};

@@ -125,8 +126,8 @@ Python APIs to create a compilation pipeline using pass context.
class PassContextNode : public Object {
public:
int opt_level{2};
tvm::ffi::Array<tvm::Expr> required_pass;
tvm::ffi::Array<tvm::Expr> disabled_pass;
ffi::Array<ffi::String> required_pass;
ffi::Array<ffi::String> disabled_pass;
mutable ffi::Optional<DiagnosticContext> diag_ctx;
ffi::Map<ffi::String, Any> config;
ffi::Array<instrument::PassInstrument> instruments;
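The fields above drive which passes actually run. As a rough illustration of the enabling rule that combines ``opt_level`` with the required/disabled lists, here is a plain-Python sketch (illustrative names only, not TVM's actual code):

```python
def pass_enabled(name, pass_opt_level, ctx_opt_level, required, disabled):
    # Sketch of the PassContext enabling rule: an explicitly required pass
    # always runs; a disabled pass never runs; otherwise the pass runs only
    # when its opt_level fits under the context's opt_level.
    if name in required:
        return True
    if name in disabled:
        return False
    return pass_opt_level <= ctx_opt_level
```

For example, with a context ``opt_level`` of 2, a level-3 pass runs only if it is listed in ``required_pass``.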
@@ -277,9 +278,7 @@ order that they were appended to the pass list.
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
TVM_FFI_ICHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
mod = GetPass(it)(std::move(mod), pass_ctx);
}
mod = pass(mod, pass_ctx);
}
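The loop above resolves dependencies by name at run time: each pass listed in ``required`` is fetched from a registry and applied to the module before the pass itself. A minimal Python model of that behavior (hypothetical names, not TVM's API):

```python
def apply_pass(mod, name, registry):
    # registry maps a pass name to (transform_fn, required_names).
    # Required passes run first, in declaration order, then the pass
    # itself, mirroring the GetPass(...) loop above.
    fn, required = registry[name]
    for dep in required:
        dep_fn, _ = registry[dep]
        mod = dep_fn(mod)
    return fn(mod)
```

Modeling a module as the list of transforms applied to it, running a pass ``B`` that requires ``A`` yields ``["A", "B"]``.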
@@ -317,19 +316,22 @@ favorably use Python APIs to create a specific pass object.
std::function<Function(Function, IRModule, PassContext)> pass_func,
int opt_level,
ffi::String name,
ffi::Array<ffi::String> required);
ffi::Array<ffi::String> required,
bool traceable = false);

Pass CreatePrimFuncPass(
std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
int opt_level,
ffi::String name,
ffi::Array<ffi::String> required);
ffi::Array<ffi::String> required,
bool traceable = false);

Pass CreateModulePass(
std::function<IRModule(IRModule, PassContext)> pass_func,
int opt_level,
ffi::String name,
ffi::Array<ffi::String> required);
ffi::Array<ffi::String> required,
bool traceable = false);

Pass Sequential(tvm::ffi::Array<Pass> passes, PassInfo pass_info);

@@ -511,21 +513,27 @@ and ``PassContext`` methods. See (`src/ir/transform.cc`_) for more details.
Built-in Instrument
^^^^^^^^^^^^^^^^^^^

There are several built-in instruments. Those marked with *TODO* are not implemented yet.
There are several built-in instruments.

- PassTimingInstrument (see `src/ir/instrument.cc`_)

* Profile the execution time of passes.

- PrintIRBefore(TODO)
- PrintBeforeAll (see `python/tvm/ir/instrument.py`_)

* Print the IR module and pass info before each pass executes.

- PrintAfterAll (see `python/tvm/ir/instrument.py`_)

* Print the IR module and pass info after each pass executes.

- PassPrintingInstrument (see `python/tvm/ir/instrument.py`_)

* Print the IR module before the pass transforms it. :py:func:`tvm.transform.PrintIR`
can also serve this purpose if we insert it around passes. However,
with the ``PassInstrument``, we don't need to modify the sequence of passes.
* Selectively print the IR module before or after specific named passes.

- PrintAfter(TODO)
- DumpIR (see `python/tvm/ir/instrument.py`_)

* Print the IR module after the pass transforms it.
* Dump the IR module to files after each pass executes.
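The instruments above all hang off the same hook points. A toy sketch of the mechanism (simplified, not TVM's actual ``PassInstrument`` API, which also has enter/exit and before-pass hooks) shows why no change to the pass sequence is needed:

```python
class CollectAfterAll:
    # Toy stand-in for an instrument like PrintAfterAll: it observes the
    # pipeline through a run_after_pass hook instead of being inserted
    # into the pass list itself.
    def __init__(self):
        self.log = []

    def run_after_pass(self, mod, pass_name):
        self.log.append(pass_name)

def run_pipeline(mod, passes, instruments):
    # passes: list of (name, fn); every instrument is notified after
    # each pass executes.
    for name, fn in passes:
        mod = fn(mod)
        for ins in instruments:
            ins.run_after_pass(mod, name)
    return mod
```

Because the instrument only observes, the same pipeline runs unchanged with or without it attached.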

Python Frontend
~~~~~~~~~~~~~~~
74 changes: 72 additions & 2 deletions docs/conf.py
@@ -607,11 +607,10 @@ def update_alias_docstring(name, obj, lines):

tvm_class_name_rewrite_map = {
    "tvm.tir": ["Var", "Call"],
"tvm.relax": ["Var", "Call"],
"tvm.relax": ["Var", "Call", "StringImm"],
"tvm.relax.frontend.nn": ["Module"],
}


def distinguish_class_name(name: str, lines: list[str]):
"""Distinguish the docstring of type annotations.

@@ -660,6 +659,77 @@ def strip_ipython_magic(app, docname, source):
source[i] = re.sub(r"%%.*\n\s*", "", source[i])
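The substitution above drops IPython cell magics together with the whitespace run that follows them. A standalone copy of the same regex, for illustration:

```python
import re

def strip_magics(text):
    # Same substitution as strip_ipython_magic: remove a "%%..." cell-magic
    # line along with the trailing newline and following indentation.
    return re.sub(r"%%.*\n\s*", "", text)
```

A cell beginning ``%%timeit`` followed by indented code is reduced to just the code.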


def _patch_python_domain_find_obj():
"""Patch PythonDomain.find_obj to resolve ambiguous cross-references.

Sphinx's ``warn-missing-reference`` event is only fired for unresolved
references. Ambiguous short names such as ``StringImm`` already have
multiple matches at ``PythonDomain.find_obj`` time, so the disambiguation
needs to happen here instead.
"""
from sphinx.domains.python import PythonDomain

if getattr(PythonDomain.find_obj, "_tvm_patched", False):
return

_original_find_obj = PythonDomain.find_obj

def _common_prefix_len(lhs: str, rhs: str) -> int:
count = 0
for lpart, rpart in zip(lhs.split("."), rhs.split(".")):
if lpart != rpart:
break
count += 1
return count

def _dedup_find_obj(self, env, modname, classname, name, objtype, searchmode=0):
matches = _original_find_obj(self, env, modname, classname, name, objtype, searchmode)
if len(matches) <= 1:
return matches

short_name = name.rsplit(".", 1)[-1]

# Prefer a single canonical (non-aliased) entry if Sphinx already found one.
canonical_matches = [match for match in matches if not match[1].aliased]
if len(canonical_matches) == 1:
return canonical_matches

# Use TVM's module context for the known short names we rewrite in docstrings.
if modname:
candidate_modules = sorted(
(
module_name
for module_name, class_names in tvm_class_name_rewrite_map.items()
if short_name in class_names and modname.startswith(module_name)
),
key=len,
reverse=True,
)
for module_name in candidate_modules:
target_name = f"{module_name}.{short_name}"
context_matches = [match for match in matches if match[0] == target_name]
if len(context_matches) == 1:
return context_matches

# Fall back to the unique match that best shares the current module prefix.
match_scores = {
match[0]: _common_prefix_len(modname, match[0]) for match in matches
}
best_score = max(match_scores.values())
if best_score > 1:
best_matches = [match for match in matches if match_scores[match[0]] == best_score]
if len(best_matches) == 1:
return best_matches

return matches

_dedup_find_obj._tvm_patched = True
PythonDomain.find_obj = _dedup_find_obj


_patch_python_domain_find_obj()
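To make the fallback heuristic concrete, here is the ``_common_prefix_len`` logic stood alone with worked examples (the body matches the patch above):

```python
def common_prefix_len(lhs: str, rhs: str) -> int:
    # Count how many leading dotted components two full names share;
    # a higher count means the candidate sits closer to the current module.
    count = 0
    for lpart, rpart in zip(lhs.split("."), rhs.split(".")):
        if lpart != rpart:
            break
        count += 1
    return count
```

From module ``tvm.relax.op``, the candidate ``tvm.relax.StringImm`` scores 2 while ``tvm.tir.StringImm`` scores 1, so the relax symbol wins the prefix fallback.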


def setup(app):
app.connect("source-read", strip_ipython_magic)
app.connect("autodoc-process-docstring", process_docstring)
14 changes: 9 additions & 5 deletions docs/deep_dive/relax/learning.rst
@@ -110,20 +110,24 @@ The code block below shows a low-level numpy implementation of the same model.
def lnumpy_mlp(data, w0, b0, w1, b1):
n = data.shape[0]
lv0 = np.empty((n, 128), dtype="float32")
lnumpy_matmul(data, w0, b0, lv0)
lnumpy_linear(data, w0, b0, lv0)

lv1 = np.empty((n, 128), dtype="float32")
lnumpy_relu(lv0, lv1)
lnumpy_relu0(lv0, lv1)

out = np.empty((n, 10), dtype="float32")
lnumpy_matmul(lv1, w1, b1, out)
lnumpy_linear(lv1, w1, b1, out)
return out
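The helpers called above are defined earlier in the document; for readers of this hunk alone, plausible NumPy definitions would look like the following (assuming ``nn.Linear``-style weights of shape ``(out_features, in_features)``; treat these as a sketch, not the originals):

```python
import numpy as np

def lnumpy_linear(x, w, b, out):
    # Dense layer in destination-passing style: out = x @ w.T + b.
    out[:] = x @ w.T + b

def lnumpy_relu0(x, out):
    # Elementwise ReLU, also destination-passing style.
    out[:] = np.maximum(x, 0.0)
```

The destination-passing convention (results written into a pre-allocated ``out``) is what makes the later lowering to TensorIR natural.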

With the low-level NumPy example in mind, we are now ready to introduce a Relax abstraction
for the end-to-end model execution. The code block below shows a TVMScript implementation of the model.

.. code:: python

from tvm.script import ir as I
    from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
@T.prim_func(private=True)
@@ -167,8 +171,8 @@ for the end-to-end model execution. The code block below shows a TVMScript imple
n = T.int64()
with R.dataflow():
lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32"))
lv1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 256), dtype="float32"))
lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((b, 10), dtype="float32"))
lv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((n, 256), dtype="float32"))
lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((n, 10), dtype="float32"))
Comment on lines +174 to +175
Contributor

medium

While these lines correctly fix the usage of undefined variables, the definition of n on line 172 (n = T.int64()) is problematic. It introduces a new symbolic variable that shadows the symbolic dimension n from the function signature, breaking the connection between the input and intermediate tensor shapes.

For a more robust and clearer example, consider removing line 172. The n used in the R.Tensor struct infos will then correctly refer to the symbolic dimension from the input tensor's shape.

R.output(lv2)
return lv2

2 changes: 2 additions & 0 deletions docs/deep_dive/tensor_ir/learning.rst
@@ -59,6 +59,8 @@ language called TVMScript, which is a domain-specific dialect embedded in python

.. code:: python

    from tvm.script import tir as T

@tvm.script.ir_module
class MyModule:
@T.prim_func
12 changes: 12 additions & 0 deletions docs/reference/api/python/relax/op.rst
@@ -66,6 +66,18 @@ tvm.relax.op.memory
:members:
:imported-members:

tvm.relax.op.vision
*******************
.. automodule:: tvm.relax.op.vision
:members:
:imported-members:

tvm.relax.op.vm
***************
.. automodule:: tvm.relax.op.vm
:members:
:imported-members:

tvm.relax.op.op_attrs
*********************
.. automodule:: tvm.relax.op.op_attrs