diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst index 2034e99db429..aa882f328ec7 100644 --- a/docs/arch/pass_infra.rst +++ b/docs/arch/pass_infra.rst @@ -93,8 +93,9 @@ needs to be executed when running under a user-provided optimization level. The .. code:: c++ class PassInfoNode : public Object { - ffi::String name; int opt_level; + ffi::String name; + bool traceable; ffi::Array required; }; @@ -125,8 +126,8 @@ Python APIs to create a compilation pipeline using pass context. class PassContextNode : public Object { public: int opt_level{2}; - tvm::ffi::Array required_pass; - tvm::ffi::Array disabled_pass; + ffi::Array required_pass; + ffi::Array disabled_pass; mutable ffi::Optional diag_ctx; ffi::Map config; ffi::Array instruments; @@ -277,9 +278,7 @@ order that they were appended to the pass list. const PassInfo& pass_info = pass->Info(); if (!PassEnabled(pass_info)) continue; for (const auto& it : pass_info->required) { - const auto* name = it.as(); - TVM_FFI_ICHECK(name); - mod = GetPass(name->value)(mod, pass_ctx); + mod = GetPass(it)(std::move(mod), pass_ctx); } mod = pass(mod, pass_ctx); } @@ -317,19 +316,22 @@ favorably use Python APIs to create a specific pass object. std::function pass_func, int opt_level, ffi::String name, - ffi::Array required); + ffi::Array required, + bool traceable = false); Pass CreatePrimFuncPass( std::function pass_func, int opt_level, ffi::String name, - ffi::Array required); + ffi::Array required, + bool traceable = false); Pass CreateModulePass( std::function pass_func, int opt_level, ffi::String name, - ffi::Array required); + ffi::Array required, + bool traceable = false); Pass Sequential(tvm::ffi::Array passes, PassInfo pass_info); @@ -511,21 +513,27 @@ and ``PassContext`` methods. See (`src/ir/transform.cc`_) for more details. Built-in Instrument ^^^^^^^^^^^^^^^^^^^ -There are several built-in instruments. Those marked with *TODO* are not implemented yet. +There are several built-in instruments. 
- PassTimingInstrument (see `src/ir/instrument.cc`_) * Profile the execution time of passes. -- PrintIRBefore(TODO) +- PrintBeforeAll (see `python/tvm/ir/instrument.py`_) + + * Print the IR module and pass info before each pass executes. + +- PrintAfterAll (see `python/tvm/ir/instrument.py`_) + + * Print the IR module and pass info after each pass executes. + +- PassPrintingInstrument (see `python/tvm/ir/instrument.py`_) - * Print the IR module before the pass transforms it. :py:func:`tvm.transform.PrintIR` - can also serve this purpose if we insert it around passes. However, - with the ``PassInstrument``, we don't need to modify the sequence of passes. + * Selectively print the IR module before or after specific named passes. -- PrintAfter(TODO) +- DumpIR (see `python/tvm/ir/instrument.py`_) - * Print the IR module after the pass transforms it. + * Dump the IR module to files after each pass executes. Python Frontend ~~~~~~~~~~~~~~~ diff --git a/docs/conf.py b/docs/conf.py index 902ff6a657e5..1502c72a2eb5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -607,11 +607,10 @@ def update_alias_docstring(name, obj, lines): tvm_class_name_rewrite_map = { "tvm.tir": ["Var", "Call"], - "tvm.relax": ["Var", "Call"], + "tvm.relax": ["Var", "Call", "StringImm"], "tvm.relax.frontend.nn": ["Module"], } - def distinguish_class_name(name: str, lines: list[str]): """Distinguish the docstring of type annotations. @@ -660,6 +659,77 @@ def strip_ipython_magic(app, docname, source): source[i] = re.sub(r"%%.*\n\s*", "", source[i]) +def _patch_python_domain_find_obj(): + """Patch PythonDomain.find_obj to resolve ambiguous cross-references. + + Sphinx's ``warn-missing-reference`` event is only fired for unresolved + references. Ambiguous short names such as ``StringImm`` already have + multiple matches at ``PythonDomain.find_obj`` time, so the disambiguation + needs to happen here instead. 
+ """ + from sphinx.domains.python import PythonDomain + + if getattr(PythonDomain.find_obj, "_tvm_patched", False): + return + + _original_find_obj = PythonDomain.find_obj + + def _common_prefix_len(lhs: str, rhs: str) -> int: + count = 0 + for lpart, rpart in zip(lhs.split("."), rhs.split(".")): + if lpart != rpart: + break + count += 1 + return count + + def _dedup_find_obj(self, env, modname, classname, name, objtype, searchmode=0): + matches = _original_find_obj(self, env, modname, classname, name, objtype, searchmode) + if len(matches) <= 1: + return matches + + short_name = name.rsplit(".", 1)[-1] + + # Prefer a single canonical (non-aliased) entry if Sphinx already found one. + canonical_matches = [match for match in matches if not match[1].aliased] + if len(canonical_matches) == 1: + return canonical_matches + + # Use TVM's module context for the known short names we rewrite in docstrings. + if modname: + candidate_modules = sorted( + ( + module_name + for module_name, class_names in tvm_class_name_rewrite_map.items() + if short_name in class_names and modname.startswith(module_name) + ), + key=len, + reverse=True, + ) + for module_name in candidate_modules: + target_name = f"{module_name}.{short_name}" + context_matches = [match for match in matches if match[0] == target_name] + if len(context_matches) == 1: + return context_matches + + # Fall back to the unique match that best shares the current module prefix. 
+ match_scores = { match[0]: _common_prefix_len(modname, match[0]) for match in matches } + best_score = max(match_scores.values()) + if best_score > 1: + best_matches = [match for match in matches if match_scores[match[0]] == best_score] + if len(best_matches) == 1: + return best_matches + + return matches + + _dedup_find_obj._tvm_patched = True + PythonDomain.find_obj = _dedup_find_obj + + +_patch_python_domain_find_obj() + + def setup(app): app.connect("source-read", strip_ipython_magic) app.connect("autodoc-process-docstring", process_docstring) diff --git a/docs/deep_dive/relax/learning.rst b/docs/deep_dive/relax/learning.rst index 72dc21186b8b..b24ed2675207 100644 --- a/docs/deep_dive/relax/learning.rst +++ b/docs/deep_dive/relax/learning.rst @@ -110,13 +110,13 @@ The code block below shows a low-level numpy implementation of the same model. def lnumpy_mlp(data, w0, b0, w1, b1): n = data.shape[0] lv0 = np.empty((n, 128), dtype="float32") - lnumpy_matmul(data, w0, b0, lv0) + lnumpy_linear(data, w0, b0, lv0) lv1 = np.empty((n, 128), dtype="float32") - lnumpy_relu(lv0, lv1) + lnumpy_relu0(lv0, lv1) out = np.empty((n, 10), dtype="float32") - lnumpy_matmul(lv1, w1, b1, out) + lnumpy_linear(lv1, w1, b1, out) return out With the low-level NumPy example in mind, now we are ready to introduce an Relax abstraction @@ -124,6 +124,10 @@ for the end-to-end model execution. The code block below shows a TVMScript imple .. code:: python + from tvm.script import ir as I + from tvm.script import tir as T + from tvm.script import relax as R + @I.ir_module class Module: @T.prim_func(private=True) @@ -167,8 +171,8 @@ for the end-to-end model execution. 
The code block below shows a TVMScript imple n = T.int64() with R.dataflow(): lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32")) - lv1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 256), dtype="float32")) - lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((b, 10), dtype="float32")) + lv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((n, 256), dtype="float32")) + lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((n, 10), dtype="float32")) R.output(lv2) return lv2 diff --git a/docs/deep_dive/tensor_ir/learning.rst b/docs/deep_dive/tensor_ir/learning.rst index 229d6d9d69ca..e93870c32d57 100644 --- a/docs/deep_dive/tensor_ir/learning.rst +++ b/docs/deep_dive/tensor_ir/learning.rst @@ -59,6 +59,8 @@ language called TVMScript, which is a domain-specific dialect embedded in python .. code:: python + from tvm.script import tir as T + @tvm.script.ir_module class MyModule: @T.prim_func diff --git a/docs/reference/api/python/relax/op.rst b/docs/reference/api/python/relax/op.rst index 922af768f50f..2f4ebe2a4912 100644 --- a/docs/reference/api/python/relax/op.rst +++ b/docs/reference/api/python/relax/op.rst @@ -66,6 +66,18 @@ tvm.relax.op.memory :members: :imported-members: +tvm.relax.op.vision +******************* +.. automodule:: tvm.relax.op.vision + :members: + :imported-members: + +tvm.relax.op.vm +*************** +.. automodule:: tvm.relax.op.vm + :members: + :imported-members: + tvm.relax.op.op_attrs ********************* .. automodule:: tvm.relax.op.op_attrs