From 05cd61c6b9ed829a7d1cacc01af8bbea2e7422a1 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Tue, 3 Feb 2026 12:33:40 -0800 Subject: [PATCH 1/3] Remove legacy bindings --- CMakeLists.txt | 144 +- csrc/fusion_segmenter.cpp | 2 +- csrc/multidevice/utils.h | 3 +- csrc/options.cpp | 2 - csrc/options.h | 2 - csrc/runtime/fusion_kernel_runtime.cpp | 12 - csrc/serde/Serde.md | 4 +- csrc/serde/fusion_record.cpp | 952 ---- csrc/serde/fusion_record.h | 124 - csrc/type.h | 4 +- .../autotune_inner_reduction.py | 2 + doc/dev/python_scheduling/autotune_matmul.py | 2 + .../python_scheduling/autotune_pointwise.py | 2 + doc/dev/python_scheduling/profile_matmul.py | 1 + python/nvfuser/README.md | 210 - python/nvfuser/__init__.py | 649 --- python/nvfuser/__init__.pyi | 4 - python/nvfuser/benchmark_utils.py | 160 - python/nvfuser/contrib/__init__.py | 9 - python/nvfuser/contrib/nn/__init__.py | 13 - python/nvfuser/contrib/nn/normalization.py | 725 --- python/nvfuser/nvfuser_version.py | 69 - python/nvfuser/pytorch_utils.py | 190 - python/python_frontend/fusion_cache.cpp | 953 ---- python/python_frontend/fusion_cache.h | 320 -- python/python_frontend/fusion_definition.cpp | 769 --- python/python_frontend/fusion_definition.h | 389 -- python/python_frontend/fusion_record.h | 3675 --------------- python/python_frontend/fusion_state.cpp | 297 -- python/python_frontend/fusion_state.h | 143 - .../python_frontend/multidevice_bindings.cpp | 103 - python/python_frontend/python_bindings.cpp | 4196 ----------------- python/python_frontend/python_bindings.h | 27 - .../python_bindings_extension.cpp | 18 - python/python_frontend/schedule_bindings.cpp | 517 -- python/python_frontend/segmentation.cpp | 369 -- python/python_frontend/segmentation.h | 246 - python/python_frontend/translation.cpp | 1484 ------ python/python_frontend/translation.h | 20 - python/python_frontend/translation_utils.cpp | 80 - python/python_frontend/translation_utils.h | 300 -- python/utils.py | 9 +- tests/python/direct/test_import.py | 17 - tests/python/utils/__init__.py | 6 - tests/python/utils/utils.py | 358 -- tools/env-config/env_options.yaml | 12 - 46 files changed, 20 insertions(+), 17573 deletions(-) delete mode 100644 csrc/serde/fusion_record.cpp delete mode 100644 csrc/serde/fusion_record.h delete mode 100644 python/nvfuser/README.md delete mode 100644 python/nvfuser/__init__.py delete mode 100644 python/nvfuser/__init__.pyi delete mode 100644 python/nvfuser/benchmark_utils.py delete mode 100644 python/nvfuser/contrib/__init__.py delete mode 100644 python/nvfuser/contrib/nn/__init__.py delete mode 100644 python/nvfuser/contrib/nn/normalization.py delete mode 100644 python/nvfuser/nvfuser_version.py delete mode 100644 python/nvfuser/pytorch_utils.py delete mode 100644 python/python_frontend/fusion_cache.cpp delete mode 100644 python/python_frontend/fusion_cache.h delete mode 100644 python/python_frontend/fusion_definition.cpp delete mode 100644 python/python_frontend/fusion_definition.h delete mode 100644 python/python_frontend/fusion_record.h delete mode 100644 python/python_frontend/fusion_state.cpp delete mode 100644 python/python_frontend/fusion_state.h delete mode 100644 python/python_frontend/multidevice_bindings.cpp delete mode 100644 python/python_frontend/python_bindings.cpp delete mode 100644 python/python_frontend/python_bindings.h delete mode 100644 python/python_frontend/python_bindings_extension.cpp delete mode 100644 python/python_frontend/schedule_bindings.cpp delete mode 100644 python/python_frontend/segmentation.cpp delete mode 100644 python/python_frontend/segmentation.h delete mode 100644 python/python_frontend/translation.cpp delete mode 100644 python/python_frontend/translation.h delete mode 100644 python/python_frontend/translation_utils.cpp delete mode 100644 python/python_frontend/translation_utils.h delete mode 100644 tests/python/utils/__init__.py delete mode 100644 tests/python/utils/utils.py diff --git a/CMakeLists.txt b/CMakeLists.txt index d21425d2e9e..24221045199 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,7 +18,6 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(NVFUSER_ROOT ${PROJECT_SOURCE_DIR}) set(NVFUSER_SRCS_DIR "${NVFUSER_ROOT}/csrc") set(NVFUSER_PYTHON_DIR "${NVFUSER_ROOT}/python") -set(NVFUSER_PYTHON_BINDINGS "${NVFUSER_ROOT}/python/python_frontend") set(NVFUSER_PYTHON_COMMON "${NVFUSER_ROOT}/python/python_common") set(NVFUSER_PYTHON_DIRECT_BINDINGS "${NVFUSER_ROOT}/python/python_direct") set(NVFUSER_CUTLASS "${NVFUSER_ROOT}/cutlass") @@ -381,21 +380,6 @@ if(NOT MSVC) ) endif() -if(BUILD_PYTHON) - list(APPEND NVFUSER_SRCS - ${NVFUSER_PYTHON_BINDINGS}/fusion_cache.cpp - ${NVFUSER_PYTHON_BINDINGS}/fusion_definition.cpp - ${NVFUSER_PYTHON_BINDINGS}/fusion_state.cpp - ${NVFUSER_PYTHON_BINDINGS}/segmentation.cpp - ${NVFUSER_PYTHON_BINDINGS}/translation.cpp - ${NVFUSER_PYTHON_BINDINGS}/translation_utils.cpp - ${NVFUSER_SRCS_DIR}/serde/fusion_record.cpp - ${NVFUSER_PYTHON_COMMON}/distributed_tensor.cpp - ${NVFUSER_PYTHON_COMMON}/python_utils.cpp - ${NVFUSER_PYTHON_COMMON}/translation_names.cpp - ) -endif() - # We create both static and shared libraries. # # Shared libraries are what ships, but a large advantage of static libraries is @@ -606,128 +590,6 @@ install(DIRECTORY "${NVFUSER_ROOT}/lib/dynamic_type/src/dynamic_type" DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/nvfuser") if(BUILD_PYTHON) - # ----------------------------- - # build nvfuser python library - # ----------------------------- - # nvfuser python API sources - set(NVFUSER_PYTHON_SRCS) - list(APPEND NVFUSER_PYTHON_SRCS - ${NVFUSER_PYTHON_BINDINGS}/multidevice_bindings.cpp - ${NVFUSER_PYTHON_BINDINGS}/python_bindings.cpp - ${NVFUSER_PYTHON_BINDINGS}/python_bindings_extension.cpp - ${NVFUSER_PYTHON_BINDINGS}/schedule_bindings.cpp - ) - - add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS}) - target_include_directories(nvf_py_internal PUBLIC ${NVFUSER_PYTHON_DIR}) - target_include_directories(nvf_py_internal PUBLIC ${NVFUSER_PYTHON_COMMON}) - target_include_directories(nvf_py_internal PUBLIC ${NVFUSER_CUTLASS}) - target_include_directories(nvf_py_internal SYSTEM INTERFACE - ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include - ) - - # setup python API version - add_custom_command( - OUTPUT ${NVFUSER_PYTHON_DIR}/nvfuser/version.py - COMMAND - "${Python_EXECUTABLE}" -c \"from pathlib import Path\; Path('${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py') .touch() \" - COMMAND - "${Python_EXECUTABLE}" ${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py nvfuser - DEPENDS ${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py - DEPENDS ${NVFUSER_PYTHON_DIR}/version.txt - WORKING_DIRECTORY ${NVFUSER_PYTHON_DIR}/tools/ - ) - add_custom_target( - gen_nvfuser_version ALL - DEPENDS ${NVFUSER_PYTHON_DIR}/nvfuser/version.py - ) - add_dependencies(nvf_py_internal gen_nvfuser_version) - - target_compile_definitions(nvf_py_internal PRIVATE - "-DTORCH_CUDA_BUILD_MAIN_LIB" - "-DC10_BUILD_MAIN_LIB=1" - EXTENSION_NAME=_C - ) - - add_library(nvfuser MODULE $) - target_compile_definitions(nvfuser PRIVATE - "-DTORCH_CUDA_BUILD_MAIN_LIB" - "-DC10_BUILD_MAIN_LIB=1" - EXTENSION_NAME=_C - ) - - if(NOT MSVC) - target_compile_options(nvf_py_internal PRIVATE -Wall -Wno-unused-function) - target_compile_options(nvf_py_internal PRIVATE -Werror) - - # Add function/data sections for dead code elimination - target_compile_options(nvf_py_internal PRIVATE - "-ffunction-sections" - "-fdata-sections" - ) - - set(NVF_LIB_SUFFIX ".so") - else() - set(NVF_LIB_SUFFIX ".pyd") - endif() - - set_target_properties(nvfuser PROPERTIES - C_STANDARD ${NVFUSER_C_STANDARD} - CUDA_STANDARD ${NVFUSER_CUDA_STANDARD} - CXX_STANDARD ${NVFUSER_CPP_STANDARD} - CXX_STANDARD_REQUIRED ON - CXX_VISIBILITY_PRESET hidden - INSTALL_RPATH - "$ORIGIN/lib:$ORIGIN/../nvfuser_common/lib:$ORIGIN/../nvidia/cuda_runtime/lib:$ORIGIN/../nvidia/cuda_nvrtc/lib:$ORIGIN/../../nvidia/cuda_cupti/lib:$ORIGIN/../torch/lib" - POSITION_INDEPENDENT_CODE Yes - SUFFIX ${NVF_LIB_SUFFIX} - VISIBILITY_INLINES_HIDDEN Yes - ) - set_target_properties(nvf_py_internal PROPERTIES - C_STANDARD ${NVFUSER_C_STANDARD} - CUDA_STANDARD ${NVFUSER_CUDA_STANDARD} - CXX_STANDARD ${NVFUSER_CPP_STANDARD} - CXX_STANDARD_REQUIRED ON - CXX_VISIBILITY_PRESET hidden - INSTALL_RPATH - "$ORIGIN/lib:$ORIGIN/../nvidia/cuda_runtime/lib:$ORIGIN/../nvidia/cuda_nvrtc/lib:$ORIGIN/../../nvidia/cuda_cupti/lib:$ORIGIN/../torch/lib" - POSITION_INDEPENDENT_CODE Yes - VISIBILITY_INLINES_HIDDEN Yes - ) - - if (NVFUSER_USE_CUTLASS) - target_link_libraries(nvf_py_internal PRIVATE nvf_cutlass) - endif() - - if (NOT MSVC) - target_link_libraries(nvf_py_internal PRIVATE CUDA::cupti) - endif() - - target_link_libraries(nvf_py_internal PRIVATE - nvfuser_codegen - "${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so" - pybind11::pybind11 pybind11::headers - ) - - target_link_libraries(nvfuser PRIVATE - nvf_py_internal - Python::Module - ) - - # Add dead code elimination flags to reduce file size - if(NOT MSVC) - target_link_options(nvfuser PRIVATE - "-Wl,--gc-sections" - "-Wl,--as-needed" - $<$:-s> - ) - endif() - - set_target_properties(nvfuser PROPERTIES - INSTALL_RPATH "$ORIGIN:$ORIGIN/lib:$ORIGIN/../build:$ORIGIN/../nvfuser_common/lib" - ) - install(TARGETS nvfuser DESTINATION lib) - # ------------------------------------------------ # build nvfuser direct python library # ------------------------------------------------ @@ -750,6 +612,9 @@ if(BUILD_PYTHON) ${NVFUSER_PYTHON_DIRECT_BINDINGS}/profile.cpp ${NVFUSER_PYTHON_DIRECT_BINDINGS}/direct_utils.cpp ${NVFUSER_PYTHON_DIRECT_BINDINGS}/python_translate.cpp + ${NVFUSER_PYTHON_COMMON}/distributed_tensor.cpp + ${NVFUSER_PYTHON_COMMON}/python_utils.cpp + ${NVFUSER_PYTHON_COMMON}/translation_names.cpp ) add_library(nvf_py_direct_internal OBJECT ${NVFUSER_PYTHON_DIRECT_SRCS}) @@ -1435,9 +1300,6 @@ target_include_directories(codegen_internal PRIVATE "${CMAKE_BINARY_DIR}/include install(EXPORT NvfuserTargets FILE NvfuserConfig.cmake DESTINATION share/cmake/nvfuser) file(CREATE_LINK "${CMAKE_BINARY_DIR}" "${NVFUSER_ROOT}/bin" SYMBOLIC) -# These symbolic links help IDEs like Cursor resolve symbols in nvfuser and -# nvfuser_direct. -file(CREATE_LINK "${NVFUSER_ROOT}/python/nvfuser" "${NVFUSER_ROOT}/nvfuser" SYMBOLIC) file(CREATE_LINK "${NVFUSER_ROOT}/python/nvfuser_direct" "${NVFUSER_ROOT}/nvfuser_direct" SYMBOLIC) message(STATUS "******** Nvfuser configuration summary ********") diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index c397bd6af40..8320ced5fa7 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -4535,7 +4535,7 @@ bool SegmentCandidateFinder::privatizeUpCastOrSqueezeOp() { // More details of the issue regarding merging horizontal groups can be // found in issue 3829 -- https://github.com/NVIDIA/Fuser/issues/3829. // Even with a squeeze op with 2 uses, this test case: - // https://github.com/NVIDIA/Fuser/blob/70ab277c7d91bcc24cd50dd75cedd79863a24f96/tests/python/test_python_frontend.py#L3666C1-L3666C30 + // https://github.com/NVIDIA/Fuser/blob/69da2d1972eb19bf7a04cef0c4debe9f55d8e11c/tests/python/direct/test_repro.py#L801 // demonstrates that privatizing the squeeze op leads to horizontal groups // that can't be merged back. if (maybe_upcast_squeeze_out_tv->definition()->isA() && diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index e924e7fcc75..755b7ae53e4 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -55,7 +55,8 @@ std::unordered_map mapDeviceAndStreamParallelTypeToId // `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in // `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's // extent if that IterDomain is sharded. -int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type); +NVF_API int64_t +getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type); // Returns the IterDomain that's parallelized on `parallel_type` in the domain // of type `domain_type`. diff --git a/csrc/options.cpp b/csrc/options.cpp index 6d587e35afd..14dddd89eec 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -139,8 +139,6 @@ std::unordered_map> Options< {"ptx", DebugDumpOption::Ptx}, {"ptxas_verbose", DebugDumpOption::PrintPtxasLog}, {"python_definition", DebugDumpOption::PythonDefinition}, - {"python_definition_segments", DebugDumpOption::PythonDefinitionSegments}, - {"python_frontend_debug", DebugDumpOption::PythonFrontendDebug}, {"sass", DebugDumpOption::Sass}, {"sass_to_file", DebugDumpOption::SassToFile}, {"segmented_fusion", DebugDumpOption::FusionSegments}, diff --git a/csrc/options.h b/csrc/options.h index 4c72c757460..6dad3909997 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -72,8 +72,6 @@ enum class DebugDumpOption { PreSegmenterLogging, HostIrLowering, //! Dump the Host IR after each lowering pass PythonDefinition, //! Python Frontend Fusion Definition. - PythonDefinitionSegments, //! Python Frontend Fusion Definition of segments. - PythonFrontendDebug, //! Python Frontend debug information. TransformPropagator, //! When running TransformPropagator, print propagation //! path and replay result Cubin, //! Dump compiled CUBIN diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index 6bb73ba9aad..b5caa8c3ac8 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -17,8 +17,6 @@ #include "instrumentation.h" #include "ir/base_nodes.h" #include "preseg_passes/pre_segmenter.h" -#include "python_frontend/fusion_definition.h" -#include "python_frontend/translation.h" #include "runtime/executor.h" #include "runtime/executor_dispatch.h" #include "runtime/fusion_cache_utils.h" @@ -430,16 +428,6 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { FusionProfiler::startCompile(); } - if (isDebugDumpEnabled(DebugDumpOption::PythonDefinitionSegments)) { - for (SegmentedGroup* group_to_run : runtime_workspace_.group_run_order) { - debug() << "Python definition for segmented group " - << group_to_run->groupId() << ":" << std::endl; - python_frontend::FusionDefinition fd(/*id=*/std::nullopt); - python_frontend::translate(group_to_run->getFusion(), &fd); - fd.print(debug()); - } - } - const std::vector all_runtime_inputs = prepareInputs(args); diff --git a/csrc/serde/Serde.md b/csrc/serde/Serde.md index 92ed1cc031a..7d92e4e7e1b 100644 --- a/csrc/serde/Serde.md +++ b/csrc/serde/Serde.md @@ -6,6 +6,8 @@ # NvFuser Serialization +## Serialization is disabled because legacy bindings are removed. + Serde is an acronym of serialization and deserialization. # Overview @@ -66,7 +68,7 @@ References: # Serde Testing -In test_python_frontend.py, the `exec_nvfuser` function is decorated with the `serde_check` functions. Every unit test should automatically test serialization. +The `exec_nvfuser` function is decorated with the `serde_check` functions. Every unit test should automatically test serialization. ```python def serde_check(test_fn: Callable): diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp deleted file mode 100644 index 2be8e01e86d..00000000000 --- a/csrc/serde/fusion_record.cpp +++ /dev/null @@ -1,952 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include -#include -#include -#include -#include - -namespace nvfuser::serde { - -std::vector parseStateArgs( - const flatbuffers::Vector* args) { - std::vector result; - for (auto s : *args) { - result.emplace_back(s->index(), s->type()); - } - return result; -} - -std::optional mapContiguityEnumToOptional(Contiguity v) { - switch (v) { - case Contiguity::Strided: - return std::optional(false); - case Contiguity::Contiguous: - return std::optional(true); - case Contiguity::None: - return std::nullopt; - } - NVF_THROW("Invalid contiguity type."); - return std::nullopt; -} - -template -python_frontend::RecordFunctor* deserializeOpRecord( - const std::unordered_map& str_to_func_map, - RecordType record_type, - const RecordFunctor* buffer) { - NVF_ERROR( - str_to_func_map.find(buffer->name()->str()) != str_to_func_map.end(), - "Missing mapping from operation string to nvfuser function in serde " - "deserialization: ", - buffer->name()->str()); - return new python_frontend::OpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->name()->str(), - record_type, - str_to_func_map.at(buffer->name()->str())); -} - -python_frontend::RecordFunctor* deserializeReductionRecord( - std::function&, - bool, - nvfuser::DataType)> fusion_op, - RecordType record_type, - const RecordFunctor* buffer) { - auto data = buffer->data_as_Reduction(); - return new python_frontend::ReductionOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->name()->str(), - record_type, - fusion_op, - parseVector(data->axes()), - data->keep_dim(), - mapToNvfuserDtype(data->dtype())); -} - -python_frontend::RecordFunctor* deserializeScanOpRecord( - std::function fusion_op, - RecordType record_type, - const RecordFunctor* buffer) { - auto data = buffer->data_as_ScanOp(); - BinaryOpType op_type; - if (record_type == RecordType::ScanOpCumsum) { - op_type = BinaryOpType::Add; - } else { - NVF_THROW("Only cumsum scan operation is supported."); - } - - return new python_frontend::ScanOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->name()->str(), - record_type, - fusion_op, - data->dim(), - op_type); -} - -void RecordFunctorFactory::registerAllParsers() { - auto deserializeStartRecord = [](const RecordFunctor* buffer) { - return new python_frontend::StartRecord(); - }; - registerParser(RecordType::Start, deserializeStartRecord); - - auto deserializeEndRecord = [](const RecordFunctor* buffer) { - return new python_frontend::EndRecord(); - }; - registerParser(RecordType::End, deserializeEndRecord); - - // Unary Ops - auto unary_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord( - unary_tv, RecordType::Unary_TV, buffer); - }; - registerParser(RecordType::Unary_TV, unary_tv_parser); - - auto unary_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord( - unary_val, RecordType::Unary_VAL, buffer); - }; - registerParser(RecordType::Unary_VAL, unary_val_parser); - - // Binary Ops - auto binary_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - binary_tv_fn, - TensorView*, - TensorView*, - TensorView*>(binary_tv, RecordType::Binary_TV, buffer); - }; - registerParser(RecordType::Binary_TV, binary_tv_parser); - - auto binary_tv_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - binary_tv_val_fn, - TensorView*, - TensorView*, - Val*>(binary_tv_val, RecordType::Binary_TV_VAL, buffer); - }; - registerParser(RecordType::Binary_TV_VAL, binary_tv_val_parser); - - auto binary_val_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - binary_val_tv_fn, - TensorView*, - Val*, - TensorView*>(binary_val_tv, RecordType::Binary_VAL_TV, buffer); - }; - registerParser(RecordType::Binary_VAL_TV, binary_val_tv_parser); - - auto binary_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord( - binary_val, RecordType::Binary_VAL, buffer); - }; - registerParser(RecordType::Binary_VAL, binary_val_parser); - - // Ternary Ops - auto ternary_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_tv_fn, - TensorView*, - TensorView*, - TensorView*, - TensorView*>(ternary_tv, RecordType::Ternary_TV, buffer); - }; - registerParser(RecordType::Ternary_TV, ternary_tv_parser); - - auto ternary_tv_tv_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_tv_tv_val_fn, - TensorView*, - TensorView*, - TensorView*, - Val*>(ternary_tv_tv_val, RecordType::Ternary_TV_TV_VAL, buffer); - }; - registerParser(RecordType::Ternary_TV_TV_VAL, ternary_tv_tv_val_parser); - - auto ternary_tv_val_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_tv_val_tv_fn, - TensorView*, - TensorView*, - Val*, - TensorView*>(ternary_tv_val_tv, RecordType::Ternary_TV_VAL_TV, buffer); - }; - registerParser(RecordType::Ternary_TV_VAL_TV, ternary_tv_val_tv_parser); - - auto ternary_val_tv_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_val_tv_tv_fn, - TensorView*, - Val*, - TensorView*, - TensorView*>(ternary_val_tv_tv, RecordType::Ternary_VAL_TV_TV, buffer); - }; - registerParser(RecordType::Ternary_VAL_TV_TV, ternary_val_tv_tv_parser); - - auto ternary_val_val_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_val_val_tv_fn, - TensorView*, - Val*, - Val*, - TensorView*>( - ternary_val_val_tv, RecordType::Ternary_VAL_VAL_TV, buffer); - }; - registerParser(RecordType::Ternary_VAL_VAL_TV, ternary_val_val_tv_parser); - - auto ternary_tv_val_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_tv_val_val_fn, - TensorView*, - TensorView*, - Val*, - Val*>(ternary_tv_val_val, RecordType::Ternary_TV_VAL_VAL, buffer); - }; - registerParser(RecordType::Ternary_TV_VAL_VAL, ternary_tv_val_val_parser); - - auto ternary_val_tv_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_val_tv_val_fn, - TensorView*, - Val*, - TensorView*, - Val*>(ternary_val_tv_val, RecordType::Ternary_VAL_TV_VAL, buffer); - }; - registerParser(RecordType::Ternary_VAL_TV_VAL, ternary_val_tv_val_parser); - - auto ternary_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord( - ternary_val, RecordType::Ternary_VAL, buffer); - }; - registerParser(RecordType::Ternary_VAL, ternary_val_parser); - - // Ternary-Alpha Ops - auto ternary_alpha_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_tv_fn, - TensorView*, - TensorView*, - TensorView*, - TensorView*, - Val*>(ternary_alpha_tv, RecordType::Ternary_Alpha_TV, buffer); - }; - registerParser(RecordType::Ternary_Alpha_TV, ternary_alpha_tv_parser); - - auto ternary_alpha_tv_tv_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_tv_tv_val_fn, - TensorView*, - TensorView*, - TensorView*, - Val*, - Val*>( - ternary_alpha_tv_tv_val, RecordType::Ternary_Alpha_TV_TV_VAL, buffer); - }; - registerParser( - RecordType::Ternary_Alpha_TV_TV_VAL, ternary_alpha_tv_tv_val_parser); - - auto ternary_alpha_tv_val_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_tv_val_tv_fn, - TensorView*, - TensorView*, - Val*, - TensorView*, - Val*>( - ternary_alpha_tv_val_tv, RecordType::Ternary_Alpha_TV_VAL_TV, buffer); - }; - registerParser( - RecordType::Ternary_Alpha_TV_VAL_TV, ternary_alpha_tv_val_tv_parser); - - auto ternary_alpha_val_tv_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_val_tv_tv_fn, - TensorView*, - Val*, - TensorView*, - TensorView*, - Val*>( - ternary_alpha_val_tv_tv, RecordType::Ternary_Alpha_VAL_TV_TV, buffer); - }; - registerParser( - RecordType::Ternary_Alpha_VAL_TV_TV, ternary_alpha_val_tv_tv_parser); - - auto ternary_alpha_val_val_tv_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_val_val_tv_fn, - TensorView*, - Val*, - Val*, - TensorView*, - Val*>( - ternary_alpha_val_val_tv, RecordType::Ternary_Alpha_VAL_VAL_TV, buffer); - }; - registerParser( - RecordType::Ternary_Alpha_VAL_VAL_TV, ternary_alpha_val_val_tv_parser); - - auto ternary_alpha_tv_val_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_tv_val_val_fn, - TensorView*, - TensorView*, - Val*, - Val*, - Val*>( - ternary_alpha_tv_val_val, RecordType::Ternary_Alpha_TV_VAL_VAL, buffer); - }; - registerParser( - RecordType::Ternary_Alpha_TV_VAL_VAL, ternary_alpha_tv_val_val_parser); - - auto ternary_alpha_val_tv_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_val_tv_val_fn, - TensorView*, - Val*, - TensorView*, - Val*, - Val*>( - ternary_alpha_val_tv_val, RecordType::Ternary_Alpha_VAL_TV_VAL, buffer); - }; - registerParser( - RecordType::Ternary_Alpha_VAL_TV_VAL, ternary_alpha_val_tv_val_parser); - - auto ternary_alpha_val_parser = [&](const RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_val_fn, - Val*, - Val*, - Val*, - Val*, - Val*>(ternary_alpha_val, RecordType::Ternary_Alpha_VAL, buffer); - }; - registerParser(RecordType::Ternary_Alpha_VAL, ternary_alpha_val_parser); - - auto deserializeSdpaFwdRecord = [&](const RecordFunctor* buffer) { - return new python_frontend::SdpaFwdOpRecord( - parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); - }; - registerParser(RecordType::SdpaFwdOp, deserializeSdpaFwdRecord); - - auto deserializeSdpaBwdRecord = [&](const RecordFunctor* buffer) { - return new python_frontend::SdpaBwdOpRecord( - parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); - }; - registerParser(RecordType::SdpaBwdOp, deserializeSdpaBwdRecord); - - auto deserializeEmbeddingFwdRecord = [&](const RecordFunctor* buffer) { - return new python_frontend::EmbeddingFwdOpRecord( - parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); - }; - registerParser(RecordType::EmbeddingFwdOp, deserializeEmbeddingFwdRecord); - - // END OpRecord Parsers - - // START Reduction Parsers - auto reduction_max_parser = [](const RecordFunctor* buffer) { - return deserializeReductionRecord(max, RecordType::ReductionMax, buffer); - }; - registerParser(RecordType::ReductionMax, reduction_max_parser); - - auto reduction_min_parser = [](const RecordFunctor* buffer) { - return deserializeReductionRecord(min, RecordType::ReductionMin, buffer); - }; - registerParser(RecordType::ReductionMin, reduction_min_parser); - - auto reduction_prod_parser = [](const RecordFunctor* buffer) { - return deserializeReductionRecord(prod, RecordType::ReductionProd, buffer); - }; - registerParser(RecordType::ReductionProd, reduction_prod_parser); - - auto reduction_sum_parser = [](const RecordFunctor* buffer) { - return deserializeReductionRecord(sum, RecordType::ReductionSum, buffer); - }; - registerParser(RecordType::ReductionSum, reduction_sum_parser); - // END Reduction Parsers - - // START ScanOp Parsers - auto scanop_cumsum_parser = [](const RecordFunctor* buffer) { - return deserializeScanOpRecord(cumsum, RecordType::ScanOpCumsum, buffer); - }; - registerParser(RecordType::ScanOpCumsum, scanop_cumsum_parser); - // END ScanOp Parsers - - auto deserializeBatchNormRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_BatchNorm(); - return new python_frontend::BatchNormOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - data->training(), - data->channels_last()); - }; - registerParser(RecordType::BatchNormOp, deserializeBatchNormRecord); - - auto deserializeBroadcastRecord = [](const RecordFunctor* buffer) { - return new python_frontend::BroadcastOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->name()->str(), - parseBoolVector(buffer->data_as_Broadcast()->broadcast_dims())); - }; - registerParser(RecordType::BroadcastOp, deserializeBroadcastRecord); - - auto deserializeCatRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Cat(); - return new python_frontend::CatOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - data->dim(), - data->manual_padding()); - }; - registerParser(RecordType::CatOp, deserializeCatRecord); - - auto deserializeBroadcastInDimRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_BroadcastInDim(); - return new python_frontend::BroadcastInDimOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - data->output_size(), - parseVector(data->broadcast_dims())); - }; - registerParser(RecordType::BroadcastInDim, deserializeBroadcastInDimRecord); - - auto deserializeExpandRecord = [](const RecordFunctor* buffer) { - return new python_frontend::ExpandOpRecord( - parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); - }; - registerParser(RecordType::ExpandOp, deserializeExpandRecord); - - auto deserializeCastTvRecord = [](const RecordFunctor* buffer) { - std::function fusion_op = - static_cast(castOp); - return new python_frontend::CastOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->name()->str(), - RecordType::CastTv, - fusion_op, - mapToNvfuserDtype(buffer->data_as_Dtype()->dtype())); - }; - registerParser(RecordType::CastTv, deserializeCastTvRecord); - - auto deserializeCastValRecord = [](const RecordFunctor* buffer) { - std::function fusion_op = - static_cast(castOp); - return new python_frontend::CastOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->name()->str(), - RecordType::CastVal, - fusion_op, - mapToNvfuserDtype(buffer->data_as_Dtype()->dtype())); - }; - registerParser(RecordType::CastVal, deserializeCastValRecord); - - auto deserializeScalarRecord = [](const RecordFunctor* buffer) { - return new python_frontend::ScalarRecord( - parseStateArgs(buffer->outputs()), - deserializePolymorphicValue(buffer->data_as_Scalar()), - mapToNvfuserDtype(buffer->data_as_Scalar()->dtype())); - }; - registerParser(RecordType::Scalar, deserializeScalarRecord); - - auto deserializeFullRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_TensorCreationSymbolic(); - return new python_frontend::FullOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - mapToNvfuserDtype(data->dtype())); - }; - registerParser(RecordType::FullOp, deserializeFullRecord); - - auto deserializeIotaRecord = [](const RecordFunctor* buffer) { - return new python_frontend::IotaOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - mapToNvfuserDtype(buffer->data_as_Dtype()->dtype())); - }; - registerParser(RecordType::IotaOp, deserializeIotaRecord); - - auto deserializeGatherRecord = [](const RecordFunctor* buffer) { - return new python_frontend::GatherOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->data_as_Dimension()->dim()); - }; - registerParser(RecordType::GatherOp, deserializeGatherRecord); - - auto deserializeTakeAlongAxisRecord = [](const RecordFunctor* buffer) { - return new python_frontend::TakeAlongAxisOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->data_as_Dimension()->dim()); - }; - registerParser(RecordType::TakeAlongAxisOp, deserializeTakeAlongAxisRecord); - - auto deserializeIndexSelectRecord = [](const RecordFunctor* buffer) { - return new python_frontend::IndexSelectOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->data_as_Dimension()->dim()); - }; - registerParser(RecordType::IndexSelectOp, deserializeIndexSelectRecord); - - auto deserializeScatterRecord = [](const RecordFunctor* buffer) { - return new python_frontend::ScatterOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->data_as_Dimension()->dim()); - }; - registerParser(RecordType::ScatterOp, deserializeScatterRecord); - - auto deserializeIndexPutAccumulateRecord = [](const RecordFunctor* buffer) { - return new python_frontend::IndexPutAccumulateOpRecord( - parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); - }; - registerParser( - RecordType::IndexPutAccumulateOp, deserializeIndexPutAccumulateRecord); - - auto deserializeSelectRecord = [](const RecordFunctor* buffer) { - return new python_frontend::SelectOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->data_as_Dimension()->dim()); - }; - registerParser(RecordType::SelectOp, deserializeSelectRecord); - - auto deserializeOutputTvRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Output(); - return new python_frontend::OutputRecord( - parseStateArgs(buffer->args()), - RecordType::OutputTv, - parseVector(data->stride_order())); - }; - registerParser(RecordType::OutputTv, deserializeOutputTvRecord); - - auto deserializeOutputValRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Output(); - return new python_frontend::OutputRecord( - parseStateArgs(buffer->args()), - RecordType::OutputVal, - parseVector(data->stride_order())); - }; - registerParser(RecordType::OutputVal, deserializeOutputValRecord); - - auto deserializePadRecord = [](const RecordFunctor* buffer) { - return new python_frontend::PadOpRecord( - parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); - }; - registerParser(RecordType::PadOp, deserializePadRecord); - - auto deserializePermuteRecord = [](const RecordFunctor* buffer) { - return new python_frontend::DimsOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - parseVector(buffer->data_as_Dims()->dims()), - buffer->name()->str()); - }; - registerParser(RecordType::PermuteOp, deserializePermuteRecord); - - auto deserializeStrideOrderRecord = [](const RecordFunctor* buffer) { - return new python_frontend::DimsOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - parseVector(buffer->data_as_Dims()->dims()), - buffer->name()->str()); - }; - registerParser(RecordType::StrideOrderOp, deserializeStrideOrderRecord); - - auto deserializeNormalDistRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_TensorCreationSymbolic(); - return new python_frontend::RandomDistOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - mapToNvfuserDtype(data->dtype())); - }; - registerParser(RecordType::NormalDistOp, deserializeNormalDistRecord); - - auto deserializeUniformDistRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_TensorCreationSymbolic(); - return new python_frontend::RandomDistOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - mapToNvfuserDtype(data->dtype())); - }; - registerParser(RecordType::UniformDistOp, deserializeUniformDistRecord); - - auto deserializeReshapeRecord = [](const RecordFunctor* buffer) { - return new python_frontend::ReshapeOpRecord( - parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); - }; - registerParser(RecordType::ReshapeOp, deserializeReshapeRecord); - - auto deserializeSliceRecord = [](const RecordFunctor* buffer) { - return new python_frontend::SliceOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - buffer->data_as_Slice()->manual_normalization()); - }; - registerParser(RecordType::SliceOp, deserializeSliceRecord); - - auto deserializeSqueezeRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Squeeze(); - return new python_frontend::SqueezeOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - parseVector(data->squeeze_dims()), - data->squeeze_expanded()); - }; - registerParser(RecordType::SqueezeOp, deserializeSqueezeRecord); - - auto deserializeTensorRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Tensor(); - - std::vector> contiguity; - std::transform( - data->contiguity()->cbegin(), - data->contiguity()->cend(), - std::back_inserter(contiguity), - mapContiguityEnumToOptional); - - return new python_frontend::TensorRecord( - parseStateArgs(buffer->outputs()), - parseVector(data->sizes()), - contiguity, - mapToNvfuserDtype(data->dtype()), - data->is_cpu(), - parseVector(data->stride_order())); - }; - registerParser(RecordType::Tensor, deserializeTensorRecord); - - auto deserializeTensorSizesRecord = [](const RecordFunctor* buffer) { - return new python_frontend::TensorSizesRecord( - parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); - }; - registerParser(RecordType::TensorSizes, deserializeTensorSizesRecord); - - auto deserializeShapeOpRecord = [](const RecordFunctor* buffer) { - return new python_frontend::ShapeOpRecord( - parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); - }; - registerParser(RecordType::ShapeOp, deserializeShapeOpRecord); - - auto deserializeSizeOpRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Size(); - return new python_frontend::SizeOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - data->dim()); - }; - registerParser(RecordType::SizeOp, deserializeSizeOpRecord); - - auto deserializeAtOpRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_At(); - return new python_frontend::AtOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - data->index()); - }; - registerParser(RecordType::AtOp, deserializeAtOpRecord); - - auto deserializeVarianceRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Norm(); - return new python_frontend::VarianceOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - parseVector(data->axes()), - data->correction(), - data->keep_dim()); - }; - registerParser(RecordType::VarianceOp, deserializeVarianceRecord); - - auto deserializeVarianceMeanRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Norm(); - return new python_frontend::VarianceMeanOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - parseVector(data->axes()), - data->correction(), - data->keep_dim()); - }; - registerParser(RecordType::VarianceMeanOp, deserializeVarianceMeanRecord); - - auto deserializeVectorRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Vector(); - return new python_frontend::VectorRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - mapToNvfuserDtype(data->dtype())); - }; - registerParser(RecordType::Vector, deserializeVectorRecord); - - auto deserializeWelfordRecord = [](const RecordFunctor* buffer) { - return new python_frontend::WelfordOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - parseVector(buffer->data_as_Welford()->axes())); - }; - registerParser(RecordType::WelfordOp, deserializeWelfordRecord); - - auto deserializeArgsortRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Sort(); - return new python_frontend::ArgsortOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - data->dim(), - data->descending(), - data->stable()); - }; - registerParser(RecordType::ArgsortOp, deserializeArgsortRecord); - - auto deserializeTopKRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_TopK(); - return new python_frontend::TopKOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - data->dim(), - data->largest(), - data->sorted()); - }; - registerParser(RecordType::TopKOp, deserializeTopKRecord); -} - -void RecordFunctorFactory::setupFunctionMaps() { -#define NVFUSER_UNARY_TV_OP(op_str, op_name) \ - unary_tv.emplace( \ - ("ops." op_str), static_cast(op_name)); \ - unary_val.emplace(("ops." op_str), static_cast(op_name)); - -#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \ - binary_tv_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); - -#define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \ - binary_tv.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); - -#define NVFUSER_TERNARY_TV_ONLY_OP(op_str, op_name) \ - ternary_tv.emplace( \ - ("ops." op_str), \ - static_cast( \ - op_name)); - -#define NVFUSER_BINARY_TV_OP(op_str, op_name) \ - binary_tv.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - binary_val.emplace( \ - ("ops." op_str), static_cast(op_name)); \ - binary_tv_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - binary_val_tv.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); - -#define NVFUSER_BINARY_TV_ALPHA_OP(op_str, op_name) \ - ternary_val.emplace( \ - ("ops." op_str), static_cast(op_name)); \ - ternary_tv_tv_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_tv_val_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_val_tv_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); - -#define NVFUSER_TERNARY_TV_OP(op_str, op_name) \ - ternary_tv.emplace( \ - ("ops." op_str), \ - static_cast( \ - op_name)); \ - ternary_val.emplace( \ - ("ops." op_str), static_cast(op_name)); \ - ternary_tv_tv_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_tv_val_tv.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_val_tv_tv.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_val_val_tv.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_tv_val_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_val_tv_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); - -#define NVFUSER_THRESHOLD_TV_OP(op_str, op_name) \ - ternary_val.emplace( \ - ("ops." op_str), static_cast(op_name)); \ - ternary_tv_val_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); - -#define NVFUSER_TERNARY_TV_ALPHA_OP(op_str, op_name) \ - ternary_alpha_tv.emplace( \ - ("ops." op_str), \ - static_cast< \ - TensorView* (*)(TensorView*, TensorView*, TensorView*, Val*)>( \ - op_name)); \ - ternary_alpha_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_alpha_tv_tv_val.emplace( \ - ("ops." op_str), \ - static_cast( \ - op_name)); \ - ternary_alpha_tv_val_tv.emplace( \ - ("ops." op_str), \ - static_cast( \ - op_name)); \ - ternary_alpha_val_tv_tv.emplace( \ - ("ops." op_str), \ - static_cast( \ - op_name)); \ - ternary_alpha_val_val_tv.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_alpha_tv_val_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); \ - ternary_alpha_val_tv_val.emplace( \ - ("ops." op_str), \ - static_cast(op_name)); - - NVFUSER_UNARY_TV_OP("abs", abs) - NVFUSER_UNARY_TV_OP("acos", acos) - NVFUSER_UNARY_TV_OP("acosh", acosh) - NVFUSER_UNARY_TV_OP("asin", asin) - NVFUSER_UNARY_TV_OP("asinh", asinh) - NVFUSER_UNARY_TV_OP("atan", atan) - NVFUSER_UNARY_TV_OP("atanh", atanh) - NVFUSER_UNARY_TV_OP("ceil", ceil) - NVFUSER_UNARY_TV_OP("cos", cos) - NVFUSER_UNARY_TV_OP("cosh", cosh) - NVFUSER_UNARY_TV_OP("exp", exp) - NVFUSER_UNARY_TV_OP("exp2", exp2) - NVFUSER_UNARY_TV_OP("expm1", expm1) - NVFUSER_UNARY_TV_OP("erf", erf) - NVFUSER_UNARY_TV_OP("erfc", erfc) - NVFUSER_UNARY_TV_OP("erfinv", erfinv) - NVFUSER_UNARY_TV_OP("erfcinv", erfcinv) - NVFUSER_UNARY_TV_OP("floor", floor) - NVFUSER_UNARY_TV_OP("frac", frac) - NVFUSER_UNARY_TV_OP("lgamma", lgamma) - NVFUSER_UNARY_TV_OP("logical_not", logical_not) - NVFUSER_UNARY_TV_OP("log", log) - NVFUSER_UNARY_TV_OP("log10", log10) - NVFUSER_UNARY_TV_OP("log1p", log1p) - NVFUSER_UNARY_TV_OP("log2", log2) - NVFUSER_UNARY_TV_OP("neg", neg) - NVFUSER_UNARY_TV_OP("bitwise_not", bitwise_not) - NVFUSER_UNARY_TV_OP("relu", relu) - NVFUSER_UNARY_TV_OP("rand_like", rand_like) - NVFUSER_UNARY_TV_OP("randn_like", randn_like) - NVFUSER_UNARY_TV_OP("reciprocal", reciprocal) - NVFUSER_UNARY_TV_OP("round", round) - NVFUSER_UNARY_TV_OP("rsqrt", rsqrt) - NVFUSER_UNARY_TV_OP("segment_set", segment_set) - NVFUSER_UNARY_TV_OP("set", set) - NVFUSER_UNARY_TV_OP("sign", sign) - NVFUSER_UNARY_TV_OP("sigmoid", sigmoid) - NVFUSER_UNARY_TV_OP("signbit", signbit) - NVFUSER_UNARY_TV_OP("silu", silu) - NVFUSER_UNARY_TV_OP("sin", sin) - NVFUSER_UNARY_TV_OP("sinh", sinh) - NVFUSER_UNARY_TV_OP("sqrt", sqrt) - NVFUSER_UNARY_TV_OP("tan", tan) - NVFUSER_UNARY_TV_OP("tanh", tanh) - NVFUSER_UNARY_TV_OP("trunc", trunc) - NVFUSER_UNARY_TV_OP("isfinite", isfinite) - NVFUSER_UNARY_TV_OP("isinf", isinf) - NVFUSER_UNARY_TV_OP("isnan", isnan) - NVFUSER_UNARY_TV_OP("isneginf", isneginf) - NVFUSER_UNARY_TV_OP("isposinf", isposinf) - NVFUSER_UNARY_TV_OP("isreal", isreal) - NVFUSER_UNARY_TV_OP("real", real) - NVFUSER_UNARY_TV_OP("imag", imag) - - NVFUSER_UNARY_TV_ALPHA_OP("triu", triu) - - NVFUSER_BINARY_TV_ONLY_OP("matmul", matmul) - NVFUSER_TERNARY_TV_ONLY_OP( - "grouped_mm", - [](TensorView* mat1, TensorView* mat2, TensorView* offsets) { - ScaledTensorView scaled_out = grouped_mm(mat1, mat2, offsets); - return scaled_out.tv; - }) - NVFUSER_BINARY_TV_ONLY_OP("linear", linear) - NVFUSER_TERNARY_TV_ONLY_OP("linear", linear) - - NVFUSER_BINARY_TV_OP("add", add) - NVFUSER_BINARY_TV_OP("atan2", atan2) - NVFUSER_BINARY_TV_OP("div", div) - NVFUSER_BINARY_TV_OP("truediv", truediv) - NVFUSER_BINARY_TV_OP("fmod", fmod) - NVFUSER_BINARY_TV_OP("mul", mul) - NVFUSER_BINARY_TV_OP("nextafter", nextafter) - NVFUSER_BINARY_TV_OP("pow", pow) - NVFUSER_BINARY_TV_OP("remainder", remainder) - NVFUSER_BINARY_TV_OP("sub", sub) - NVFUSER_BINARY_TV_OP("minimum", minimum) - NVFUSER_BINARY_TV_OP("maximum", maximum) - NVFUSER_BINARY_TV_OP("mod", mod) - NVFUSER_BINARY_TV_OP("eq", eq) - NVFUSER_BINARY_TV_OP("ge", ge) - NVFUSER_BINARY_TV_OP("gt", gt) - NVFUSER_BINARY_TV_OP("le", le) - NVFUSER_BINARY_TV_OP("lt", lt) - NVFUSER_BINARY_TV_OP("ne", ne) - NVFUSER_BINARY_TV_OP("bitwise_and", bitwise_and) - NVFUSER_BINARY_TV_OP("bitwise_or", bitwise_or) - NVFUSER_BINARY_TV_OP("bitwise_xor", bitwise_xor) - NVFUSER_BINARY_TV_OP("logical_and", logical_and) - NVFUSER_BINARY_TV_OP("logical_or", logical_or) - NVFUSER_BINARY_TV_OP("bitwise_left_shift", bitwise_left_shift) - NVFUSER_BINARY_TV_OP("bitwise_right_shift", bitwise_right_shift) - NVFUSER_BINARY_TV_OP("logical_right_shift", logical_right_shift) - NVFUSER_BINARY_TV_OP("gcd", gcd) - NVFUSER_BINARY_TV_OP("ceilDiv", ceilDiv) - - NVFUSER_BINARY_TV_ALPHA_OP("add_alpha", add_alpha) - NVFUSER_BINARY_TV_ALPHA_OP("sub_alpha", sub_alpha) - - NVFUSER_TERNARY_TV_OP("lerp", lerp) - NVFUSER_TERNARY_TV_OP("where", where) - - // The following ops behave like TernaryOps but are only TV_VAL_VAL - ternary_tv_val_val.emplace( - "ops.rand_like", - static_cast(rand_like)); - ternary_tv_val_val.emplace( - "ops.randn_like", - static_cast(randn_like)); - - NVFUSER_THRESHOLD_TV_OP("clamp", clamp) - NVFUSER_THRESHOLD_TV_OP("threshold", threshold) - - NVFUSER_TERNARY_TV_ALPHA_OP("addcmul", addcmul) -} - -} // namespace nvfuser::serde diff --git a/csrc/serde/fusion_record.h b/csrc/serde/fusion_record.h deleted file mode 100644 index bb6c7463a9e..00000000000 --- a/csrc/serde/fusion_record.h +++ /dev/null @@ -1,124 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once -#include -#include -#include - -namespace nvfuser::serde { - -// Forward definition for RecordFunctor -struct RecordFunctor; - -// OpRecord Function Signatures -// ======================================================================== -// Unary Functions -typedef std::function unary_tv_fn; -typedef std::function unary_val_fn; - -// ======================================================================== -// Binary Functions -typedef std::function binary_tv_fn; -typedef std::function binary_val_fn; -typedef std::function binary_tv_val_fn; -typedef std::function binary_val_tv_fn; - -// ======================================================================== -// Ternary Functions -// Binary with Alpha Functions -typedef std::function - ternary_tv_fn; -typedef std::function ternary_val_fn; -typedef std::function - ternary_tv_tv_val_fn; -typedef std::function - ternary_tv_val_tv_fn; -typedef std::function - ternary_val_tv_tv_fn; -typedef std::function - ternary_val_val_tv_fn; -typedef std::function - ternary_tv_val_val_fn; -typedef std::function - ternary_val_tv_val_fn; - -// ======================================================================== -// Ternary with Alpha Functions -typedef std::function - ternary_alpha_tv_fn; -typedef std::function ternary_alpha_val_fn; -typedef std::function - ternary_alpha_tv_tv_val_fn; -typedef std::function - ternary_alpha_tv_val_tv_fn; -typedef std::function - ternary_alpha_val_tv_tv_fn; -typedef std::function - ternary_alpha_val_val_tv_fn; -typedef std::function - ternary_alpha_tv_val_val_fn; -typedef std::function - ternary_alpha_val_tv_val_fn; -// ======================================================================== - -//! The RecordFunctorFactory class is used to deserialize the flatbuffer -//! RecordFunctor table. We create an enum type for each RecordFunctor class. -//! Each template specialization has a unique RecordType and parser function. -class RecordFunctorFactory - : public Factory { - public: - RecordFunctorFactory() - : Factory((nvfuser::toUnderlying(RecordType::MAX) + 1)) { - setupFunctionMaps(); - registerAllParsers(); - } - - private: - void registerAllParsers(); - void setupFunctionMaps(); - - // String to Operation maps - // Unary Functions - std::unordered_map unary_tv; - std::unordered_map unary_val; - - // Binary Functions - std::unordered_map binary_tv; - std::unordered_map binary_val; - std::unordered_map binary_tv_val; - std::unordered_map binary_val_tv; - - // Ternary Functions - // Binary with Alpha Functions - std::unordered_map ternary_tv; - std::unordered_map ternary_val; - std::unordered_map ternary_tv_tv_val; - std::unordered_map ternary_tv_val_tv; - std::unordered_map ternary_val_tv_tv; - std::unordered_map ternary_val_val_tv; - std::unordered_map ternary_tv_val_val; - std::unordered_map ternary_val_tv_val; - - // Ternary with Alpha Functions - std::unordered_map ternary_alpha_tv; - std::unordered_map ternary_alpha_val; - std::unordered_map - ternary_alpha_tv_tv_val; - std::unordered_map - ternary_alpha_tv_val_tv; - std::unordered_map - ternary_alpha_val_tv_tv; - std::unordered_map - ternary_alpha_val_val_tv; - std::unordered_map - ternary_alpha_tv_val_val; - std::unordered_map - ternary_alpha_val_tv_val; -}; - -} // namespace nvfuser::serde diff --git a/csrc/type.h b/csrc/type.h index a2c02b18c5b..ea40db96d09 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -1006,9 +1006,9 @@ AdjustLastDim getLastDimAdjustment(const DataType& dtype); NVF_API std::ostream& operator<<(std::ostream&, const ValType); std::ostream& operator<<(std::ostream&, const PredicateType); NVF_API std::ostream& operator<<(std::ostream&, const DataType); -std::ostream& operator<<(std::ostream&, const UnaryOpType); +NVF_API std::ostream& operator<<(std::ostream&, const UnaryOpType); NVF_API std::ostream& operator<<(std::ostream&, const BinaryOpType); -std::ostream& operator<<(std::ostream&, const TernaryOpType); +NVF_API std::ostream& operator<<(std::ostream&, const TernaryOpType); std::ostream& operator<<(std::ostream&, const RNGOpType); NVF_API std::ostream& operator<<(std::ostream&, const ParallelType); NVF_API std::ostream& operator<<(std::ostream&, const MemoryType); diff --git a/doc/dev/python_scheduling/autotune_inner_reduction.py b/doc/dev/python_scheduling/autotune_inner_reduction.py index 3b0c40fc1de..e419e99c827 100644 --- a/doc/dev/python_scheduling/autotune_inner_reduction.py +++ b/doc/dev/python_scheduling/autotune_inner_reduction.py @@ -5,6 +5,8 @@ import torch import itertools + +# TODO Update script to use nvfuser_direct module from nvfuser import FusionDefinition, SchedulerType, DataType, ParallelType from enum import Enum from dataclasses import dataclass diff --git a/doc/dev/python_scheduling/autotune_matmul.py b/doc/dev/python_scheduling/autotune_matmul.py index da44c72faa0..4accb2174ec 100644 --- a/doc/dev/python_scheduling/autotune_matmul.py +++ b/doc/dev/python_scheduling/autotune_matmul.py @@ -5,6 +5,8 @@ import torch import itertools + +# TODO Update script to use nvfuser_direct module from nvfuser import FusionDefinition, SchedulerType # Description of the problem diff --git a/doc/dev/python_scheduling/autotune_pointwise.py b/doc/dev/python_scheduling/autotune_pointwise.py index cbda494b3cb..6ced812d0ae 100644 --- a/doc/dev/python_scheduling/autotune_pointwise.py +++ b/doc/dev/python_scheduling/autotune_pointwise.py @@ -6,6 +6,8 @@ import torch import itertools import math + +# TODO Update script to use nvfuser_direct module from nvfuser import FusionDefinition, SchedulerType, DataType from dataclasses import dataclass from enum import Enum diff --git a/doc/dev/python_scheduling/profile_matmul.py b/doc/dev/python_scheduling/profile_matmul.py index 073b9ad46a4..da433cc56d7 100644 --- a/doc/dev/python_scheduling/profile_matmul.py +++ b/doc/dev/python_scheduling/profile_matmul.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +# TODO Update script to use nvfuser_direct module from nvfuser import ( FusionDefinition, SchedulerType, diff --git a/python/nvfuser/README.md b/python/nvfuser/README.md deleted file mode 100644 index 534b382995b..00000000000 --- a/python/nvfuser/README.md +++ /dev/null @@ -1,210 +0,0 @@ - - -# nvFuser Python Frontend - -This frontend allows for a user to describe the set of operations for nvFuser to fuse via 1 or more kernels. This frontend is intended to be an integration point with PyTorch or standalone applications. - -# Usage - -## Example 1 - Define and Execute a Fusion - -```python -import torch -from nvfuser import FusionDefinition, DataType - -with FusionDefinition() as fd : - t0 = fd.define_tensor(shape=[-1, 1, -1], - contiguity=[True, None, True], - dtype=DataType.Float) - t1 = fd.define_tensor([-1, -1, -1]) - c0 = fd.define_scalar(3.0) - - t2 = fd.ops.add(t0, t1) - t3 = fd.ops.mul(t2, c0) - t4 = fd.ops.sum(t3, [-1], False, DataType.Float) - - fd.add_output(t4) - -input1 = torch.ones(2, 1, 8, device='cuda') -input2 = torch.ones(2, 4, 8, device='cuda') - -nvf_out = fd.execute([input1, input2])[0] -``` - -## Example 2 - Lookup and Execute a `FusionDefinition` Based on Id - -```python -fid = 0 -fd = FusionDefinition(fid) - -input1 = torch.ones(2, 1, 8, device='cuda') -input2 = torch.ones(2, 4, 8, device='cuda') - -nvf_out = fd.execute([input1, input2])[0] -``` - -## Components - -### `FusionDefinition` Context Manager - Interface for Defining Fusions -* `execute([inputs])`: Allows you to execute the currently defined fusion with a list of given inputs and returns a list of tensors. -* `id()`: Returns the fusion id for a given definition. -* `fusion_ir()`: Returns the Fusion IR (Intermediate Representation) as a string. -* `last_cuda_code(intrinsic_code=False)`: Returns the generated CUDA code for the last executed inputs. -* `debug_output()`: Returns debug output if capture_debug_output=True was used during execution. - -#### Defining Input Tensors -_All intermediate tensors are created by operations. Constant tensors do not exist._ - -There are 3 ways to define tensors that will be enumerated below. - -##### 1.) Defining tensors with symbolic dimensions -This interface tells nvFuser that the tensor has symbolic dimensions that are not necessarily contiguous in memory. Use `-1` for each symbolic dimension. The user also has the ability to specify a data type. The default type is `Float`. -```python -t0 = fd.define_tensor([-1, -1, -1]) # 3D tensor -t1 = fd.define_tensor([-1, -1], dtype=DataType.Half) # 2D tensor -``` - -##### 2.) Defining tensors by a list of concrete sizes and a list of strides -The `sizes` parameter defines the number of dimensions and the size of each dimension. The `strides` parameter has to have the same number of dimensions as the `sizes` parameter. -nvFuser translates the concrete sizes and strides into symbolic sizes and contiguity information that can be directly defined via the next way to define tensors. This allows the user to directly take a Pytorch defined tensor and query its sizes and strides in order to apply them in the definition. -```python -t0 = fd.define_tensor(sizes=[2, 4, 6], strides=[24, 6, 1], dtype=DataType.Half) -``` - -##### 3.) Defining tensors by a list of symbolic sizes and a list of contiguity information -The list of symbolic sizes defines the number of dimensions and `-1` is given for each dimension unless it is a broadcast dimension that is defined with a `1`. The contiguity information is viewed from right to left. A `True` definition indicates the current dimension is contiguous with the dimension to its right. - -```python -t0 = fd.define_tensor(shape=[-1, 1, -1], contiguity=[True, None, True], dtype=DataType.Float) -``` - -#### Defining Input Scalars -_All intermediate scalars, except for constants, are created by operations._ - -The only thing the user has to define for a scalar is its type. - -```python -s0 = fd.define_scalar(dtype=DataType.Half) -``` - -#### Defining Constant Scalars - -Constants can be of types: `Bool`, `ComplexDouble`, `Double`, or `Int`. The definition only takes a constant and the type is inferred by the constant's type. - -```python -c0 = fd.define_scalar(3.0) -``` - -**Note**: you cannot use Python literals directly: -```python -# Correct - define scalar constant first -scalar_const = fd.define_scalar(2.0) -result = fd.ops.mul(tensor, scalar_const) - -# Incorrect - this will cause a TypeError -result = fd.ops.mul(tensor, 2.0) # ERROR! -``` - -#### Defining Operations - -Operators are added with the following notation: -```python -output = fd.ops.foo(arg1, ... ) -``` - - -You can see a supported list of operations with the following query: -```python -python -c "from nvfuser import FusionDefinition; help(FusionDefinition.Operators)" -``` -#### Notating Outputs - -The `FusionDefinition` `add_output` method is used to indicate an intermediate is an output to the fusion. - -```python -add_output(output: Tensor) -# or -add_output(output: Scalar) -``` - -# Complete Working Example - -Here's a complete, tested example that demonstrates correct API usage: - -```python -import torch -from nvfuser import FusionDefinition, DataType - -def main(): - # Check CUDA availability - if not torch.cuda.is_available(): - print("CUDA is not available. nvfuser requires CUDA.") - return - - # Define a fusion that computes (x + y) * 2 - with FusionDefinition() as fd: - # Define input tensors with explicit shapes - x = fd.define_tensor([-1, -1], dtype=DataType.Float) # 2D tensor - y = fd.define_tensor([-1, -1], dtype=DataType.Float) # 2D tensor - - # Define operations - sum_result = fd.ops.add(x, y) # x + y - two = fd.define_scalar(2.0) # scalar constant - final_result = fd.ops.mul(sum_result, two) # (x + y) * 2 - - # Mark output - fd.add_output(final_result) - - # Create input tensors on GPU - input_x = torch.ones(3, 4, device='cuda', dtype=torch.float32) - input_y = torch.ones(3, 4, device='cuda', dtype=torch.float32) * 2 - - # Execute the fusion - nvf_result = fd.execute([input_x, input_y])[0] - - # Compare with PyTorch eager execution - eager_result = (input_x + input_y) * 2.0 - - print(f"Results match: {torch.allclose(nvf_result, eager_result)}") - - # Get debug information (only available after execution) - print(f"Fusion ID: {fd.id()}") - print(f"Fusion IR:\n{fd.fusion_ir()}") - -if __name__ == "__main__": - main() -``` - -# Debug Information -**Query a list of supported operations:** -```python -python -c "from nvfuser import FusionDefinition; help(FusionDefinition.Operators)" -``` - -**Get debug information after execution:** -```python -# These methods require the fusion to be executed first -print(f"Fusion ID: {fd.id()}") -print(f"Fusion IR:\n{fd.fusion_ir()}") -print(f"Generated CUDA code:\n{fd.last_cuda_code()}") -``` - -**View the fusion definitions that are executed by setting an environment variable:** -```python -export NVFUSER_DUMP=python_definition -``` -Example Output: -```python -def nvfuser_fusion_id0(fd : FusionDefinition) -> None : - T0 = fd.define_tensor(shape=[-1, 1, -1], contiguity=[True, None, True], dtype=DataType.Float) - T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[False, False, False], dtype=DataType.Float) - S2 = fd.define_scalar(3.00000) - T3 = fd.ops.add(T0, T1) - T4 = fd.ops.mul(T3, S2) - T5 = fd.ops.sum(T4, axes=[-1], keepdim=False, dtype=DataType.Float) - fd.add_output(T5) -``` diff --git a/python/nvfuser/__init__.py b/python/nvfuser/__init__.py deleted file mode 100644 index dfe03aab396..00000000000 --- a/python/nvfuser/__init__.py +++ /dev/null @@ -1,649 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -import sys -import warnings - -if "nvfuser_direct" in sys.modules: - warnings.warn( - "Be careful! You've imported nvfuser when the nvfuser_direct module is already imported.", - UserWarning, - ) - -import logging -import os -import re -from typing import Callable -import warnings - -import torch - -# This is needed when libnvfuser.so is patched and doesn't have the pytorch library location available. -pytorch_lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib") -if pytorch_lib_dir not in sys.path: - sys.path.append(pytorch_lib_dir) - -# we need to import _C here to avoid confusing error message generated from failure in this python script ended up with -# complaining on `_C` not defined for `_C._FusionDefinition` -from . import _C -from ._C import * # noqa: F401,F403 - -from . import contrib # noqa: F401 - - -logger = logging.getLogger("nvfuser") - - -# Register automatic serialization of Nvfuser cache hierarchy and cuda kernels. -def enable_automatic_serialization(): - import atexit - - atexit.register(_C.serialize) - - # A separate process is created for each device in a distributed setting. - # Each FusionCache becomes associated with a single device. - # Automatic serialization saves a separate cache for each device. - # Set the FusionCache id to the ddp local rank. - env_var_ddp_local_rank = os.environ.get("LOCAL_RANK", None) - if env_var_ddp_local_rank is not None: - env_var_ddp_local_rank = int(env_var_ddp_local_rank) - _C.FusionCache.get(max_fusions := 8192, env_var_ddp_local_rank) - - -# Unregister automatic serialization of Nvfuser cache hierarchy and cuda kernels. -def disable_automatic_serialization(): - import atexit - - atexit.unregister(_C.serialize) - - -class FusionDefinition(_C._FusionDefinition): - def __init__( - self, - id=None, - max_length=9999, - use_multidevice_executor=False, - backend_type=CommunicatorBackend.nccl, - ): - super(FusionDefinition, self).__init__( - id, max_length, use_multidevice_executor, backend_type - ) - self.profiled = False - - def segment(self, inputs): - """ - Decompose this FusionDefinition into a sequence of segment - FusionDefinitions. - - This function runs the nvfuser segmentation algorithm and translates the - segments into their corresponding FusionDefinitions. - - Args: - inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion. - - Returns: - List[FusionDefinition]: The FusionDefinitions corresponding to the - sub-fusion segments of this FusionDefinition. - """ - num_segments = self._setup_segmentation(inputs) - if num_segments == 1: - self._finalize_segmentation() - return [] - - # Track all segments for this FusionDefinition - self.segments = [] - - # Track map_segment_fid_to_original_fid for each segment - self.segment_index_space_maps = {} - - # Track the last segment a value is used as an input - self.map_value_to_last_used_segment = {} - - for idx in range(num_segments): - new_fd = FusionDefinition() - map_segment_fid_to_original_fid = self._build_segment(new_fd, idx) - - for segment_input in new_fd.inputs(): - original_input = map_segment_fid_to_original_fid[segment_input] - self.map_value_to_last_used_segment[original_input] = idx - - self.segment_index_space_maps[new_fd] = map_segment_fid_to_original_fid - self.segments.append(new_fd) - self._finalize_segmentation() - return self.segments - - def __enter__(self): - return self._setup_definition() - - def __exit__(self, type, value, traceback): - try: - self._finalize_definition() - except Exception as err: - logger.exception(self._repro_error_str("defining")) - raise - - def definition(self): - raise NotImplementedError("definition() should be implemented by child class!") - - def _execute_segments(self, input_arguments, *, device=None, profile=False): - """ - Run the sequence of FusionDefinition segments to generate the results - of this FusionDefinition. - - This FusionDefinition acts an argument manager. It gathers input - arguments for the segments and stores their output results. After - running a segment, any redundant intermediate values, which are - unnecessary for any other segments, are deleted to save memory. - - Args: - inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion. - - Kwargs: - device (Optional[Union[int, str, torch.device]]): This is a hint to run - the Fusion on the given CUDA device. This is not typically - necessary, as the device is usually inferred from the locations - of input tensors. However, for some fusion definitions, no - tensors will be input (for example when all tensors are - generated with `full` or `uniform` ops). In these cases, we - must either tell NVFuser where to run the resulting kernel, or - let it default to 0. Note that passing this option providing - and input tensors that lie on another device is an error. - profile (bool): Captures a CUPTI based profile of a fusion. - - - Returns: - List[Tensor]: The output results for this FusionDefinition. - """ - assert len(self.segments) > 0 - assert len(self.segments) == len(self.segment_index_space_maps) - - input_arguments_with_extents = [*input_arguments] - for a in input_arguments: - if type(a) is torch.Tensor: - input_arguments_with_extents.extend(a.size()) - - # Map inputs arguments to original fid - map_original_fid_to_value = { - fd_state: argument - for fd_state, argument in zip( - self.inputs() + self.extents(), input_arguments_with_extents - ) - } - - # Run all segments in correct order - for idx, segment in enumerate(self.segments): - segment_to_original_map = self.segment_index_space_maps[segment] - - # Gather segment input arguments - segment_arguments = [ - map_original_fid_to_value[segment_to_original_map[fd_state]] - for fd_state in segment.inputs() - ] - - # Run segment - segment_outputs = segment.execute( - segment_arguments, device=device, profile=profile - ) - - # Update original fusion definition indices to outputs - for fd_state, output in zip(segment.outputs(), segment_outputs): - map_original_fid_to_value[segment_to_original_map[fd_state]] = output - - # Destroy any arguments that are not used by future segments - for segment_input in segment.inputs(): - original_input = segment_to_original_map[segment_input] - if ( - original_input not in self.outputs() - and self.map_value_to_last_used_segment[original_input] == idx - ): - del map_original_fid_to_value[original_input] - - # Map output fid to actual results - return [map_original_fid_to_value[fd_state] for fd_state in self.outputs()] - - def execute( - self, - inputs, - *, - device=None, - override_user_schedule=False, - capture_debug_output=False, - print_repro=False, - profile=False, - save_repro_inputs=False, - _enable_options: list[str] = [], - _disable_options: list[str] = [], - ) -> list[torch.Tensor] | tuple[list[torch.Tensor], list[Sharding]]: - """ - Executes an nvFuser set of kernels for a given Fusion - - The FusionDefinition will be executed on a single CUDA device. - Typically, which device to run on is determined by the devices where - the input tensors reside. However, if the Fusion is defined such that - none of the inputs are tensors, we are not able to infer a device from - the inputs. For example, the following FusionDefinition will be unable - to unambiguously infer the device of its output: - - with FusionDefinition() as fd: - tv1 = fd.ops.full([5]) - fd.add_output(tv1) - - In that case, we default to selecting the first CUDA - device, i.e. `torch.device("cuda:0")`. This method enables selecting an - alternative preferred device. - - Args: - inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion. - - Kwargs: - device (Optional[Union[int, str, torch.device]]): This is a hint to run - the Fusion on the given CUDA device. This is not typically - necessary, as the device is usually inferred from the locations - of input tensors. However, for some fusion definitions, no - tensors will be input (for example when all tensors are - generated with `full` or `uniform` ops). In these cases, we - must either tell NVFuser where to run the resulting kernel, or - let it default to 0. Note that passing this option providing - and input tensors that lie on another device is an error. - override_user_schedule (bool): For a user defined schedule, - override with auto-generated schedule (default: False) - capture_debug_output (bool): Whether to capture any printed - debugging information as a string. If True, the string can be - retrieved after execution using :meth:`get_debug_output`. If False, - then that method will return None when called. - print_repro (bool): Prints a reproduction script to stdout. - profile (bool): Captures a CUPTI based profile of a fusion. - save_repro_inputs (bool): Saves the inputs for last_repro_script() to - provide a provide a reproduction script. - _enable_options/_disable_options (list): NVFUSER_ENABLE/DISABLE options to use. - This is an alternative to environment variables. - Note: Currently, we do not cache/store these options in the FusionCache which makes it - plausible to reuse kernels when executing the same fusion definition with different sets of options. - Reset the FusionCache manually to avoid inadvertent kernel reuse when between different sets of options. - - Returns: - A list of output tensors and, if multidevice_schedule is defined, a - list of output shardings. The latter is important to pack the outputs - into DTensors for framework integration. - """ - self.profiled = profile - - if not isinstance(device, int) and device is not None: - if not isinstance(device, torch.device): - device = torch.device(device) - assert ( - device.type == "cuda" - ), "If device argument is passed it must be a CUDA device" - device = device.index - - # if definition is not defined by a context manager, try a child class - defined_multidevice_schedule = hasattr(self, "multidevice_schedule") - if self.id() is None: - self._setup_definition() - self.definition() - self._finalize_definition() - - defined_schedule = hasattr(self, "schedule") and isinstance( - self.schedule, Callable - ) - assert not ( - defined_multidevice_schedule and defined_schedule - ), "I haven't tested what if both are defined. We don't plan to support this use case although it may just work." - - if defined_multidevice_schedule: - # Unlike `schedule`, `multidevice_schedule` is designed for inter-device - # scheduling, The scheduling is done before concretization and therefore - # before pre-segmentation. `schedule` however assumes the FusionDefinition - # has been concretized and pre-segmented, and therefore requires - # `_setup_schedule` and `_finalize_schedule` to be called before and after. - # - # Note: there's a plan to embed multidevice schedules into FusionDefinition - # as annotating nodes. This may eventually replace `multidevice_schedule`. - self._setup_multidevice_schedule() - self.multidevice_schedule() - self._finalize_multidevice_schedule() - - # If schedule is defined by child class and schedule is not defined for - # inputs, make a schedule. - if defined_schedule: - # Schedule fusion if it does not exist yet or profiling fusion - if profile or not self._exist_schedule(inputs): - self._setup_schedule(inputs, overwrite_existing_schedule=profile) - self.schedule() - self._finalize_schedule(inputs) - - if save_repro_inputs: - from torch._subclasses.fake_tensor import FakeTensorMode - - fake_mode = FakeTensorMode() - self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs] - - if hasattr(self, "segments") and len(self.segments) > 0: - return self._execute_segments(inputs, device=device, profile=profile) - - try: - if print_repro: - print(self.repro_script_for(inputs)) - if len(_enable_options) or len(_disable_options): - warnings.warn( - "Reset the FusionCache manually to avoid reusing kernels when re-executing the fusion definition with different options." - ) - - out_tensors, out_shardings = self._execute( - inputs, - device=device, - override_user_schedule=override_user_schedule, - capture_debug_output=capture_debug_output, - profile=profile, - _enable_options=_enable_options, - _disable_options=_disable_options, - ) - - if defined_multidevice_schedule: - return out_tensors, out_shardings - - assert len(out_shardings) == 0 - return out_tensors - - except Exception as err: - logger.exception(self._repro_error_str("executing", inputs)) - raise - - def debug_output(self): - """ - Retrieve string of captured debug information from the previous execution. - - Note that `capture_debug_output=True` must be passed to `execute()` in - order to enable capturing this output. Otherwise, this method will - return `None`. - - Returns: - Optional[String] : the captured debug output for the previous call - to execute(). If the `capture_debug_output` argument to that call - was False, returns None. Otherwise, returns the output as a string. - """ - return self._debug_output() - - def from_pytorch(self, tensor, static_sizes=False): - """ - Defines an nvfuser input tensor from a pytorch tensor and defaults - to definining a symbolic tensor for dynamic shape usage. - - Args: - tensor (torch.Tensor): Input tensor to nvFuser - static_sizes (bool) : Interprets sizes as static rather than - as symbolic for dynamic shape usage - - Returns: - nvfuser.Tensor - """ - try: - from .pytorch_utils import torch_dtype_to_nvfuser_dtype - except ImportError: - raise ImportError("Unable to import pytorch_utils!") - - supported_tensor = tensor.is_cuda or (tensor.is_cpu and len(tensor.size()) == 0) - if not supported_tensor: - raise ValueError( - f"Found unsupported device {tensor.device}, only scalar CPU or CUDA tensors are supported" - ) - - return self.define_tensor( - sizes=tensor.size(), - strides=tensor.stride(), - dtype=torch_dtype_to_nvfuser_dtype(tensor.dtype), - static_sizes=static_sizes, - is_cpu=tensor.is_cpu, - ) - - def fusion_ir(self): - """ - Returns the uscheduled Fusion IR for the given definition that corresponds to all scheduled inputs. - - Returns: - String - """ - return self._fusion_ir() - - def last_cuda_code(self, intrinsic_code=False, **kwargs): - """ - Returns the Cuda Code for the last executed set of inputs - - Args: - intrinsic_code (Bool): Include all the additional code required to run kernel(s). (default: False) - - Kwargs: - override_user_schedule (Bool): For a user defined schedule, override with auto-generated schedule (default: False) - - Returns: - String - """ - override_user_schedule = kwargs.pop("override_user_schedule", False) - return self._last_cuda_code(intrinsic_code, override_user_schedule) - - def cuda_code_for(self, inputs, intrinsic_code=False, **kwargs): - """ - Returns the Cuda Code for the given inputs - - Args: - inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion. - intrinsic_code (Bool): Include all the additional code required to run kernel(s). (default: False) - - Kwargs: - override_user_schedule (Bool): For a user defined schedule, override with auto-generated schedule (default: False) - - Returns: - String - """ - override_user_schedule = kwargs.pop("override_user_schedule", False) - return self._cuda_code_for(inputs, intrinsic_code, override_user_schedule) - - def last_scheduled_fusion_ir(self, tensor_transforms=False, **kwargs): - """ - Returns the Scheduled Fusion IR for the last executed set of inputs - - Args: - tensor_transforms (Bool): Include tensor transforms that were applied through scheduling. (default: False) - - Kwargs: - override_user_schedule (Bool): For a user defined schedule, override with auto-generated schedule (default: False) - - Returns: - String - """ - override_user_schedule = kwargs.pop("override_user_schedule", False) - return self._last_scheduled_fusion_ir(tensor_transforms, override_user_schedule) - - def scheduled_fusion_ir_for(self, inputs, tensor_transforms=False, **kwargs): - """ - Returns the Scheduled Fusion IR for the last executed set of inputs - - Args: - inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion. - tensor_transforms (Bool): Include tensor transforms that were applied through scheduling. (default: False) - - Kwargs: - override_user_schedule (Bool): For a user defined schedule, override with auto-generated schedule (default: False) - - Returns: - String - """ - override_user_schedule = kwargs.pop("override_user_schedule", False) - return self._scheduled_fusion_ir_for( - inputs, tensor_transforms, override_user_schedule - ) - - def profile(self): - """ - Returns the FusionProfile object from the CUPTI based FusionProfiler - - Returns: - FusionProfile - """ - if not self.profiled: - raise ValueError( - "The execute() method was not previously called with profiling enabled!" - ) - - fp = self._profile() - - if fp.fusion_id < 0: - raise ValueError( - "Something went wrong with Fusion Profiling as an illegal fusion_id was returned! " - + str(fp.fusion_id) - ) - if fp.segments < 1: - raise ValueError( - "Something went wrong with Fusion Profiling as no kernel segments were profiled!" - + str(fp.segments) - ) - - return fp - - def last_repro_script(self) -> str: - assert ( - self.fake_inputs is not None - ), "fd.last_repro_script() cannot provide a repro because fd.execute(inputs, save_repro_state=True) was not executed!" - script = self.repro_script_for(self.fake_inputs) - return script - - def repro_script_for(self, inputs: list | None = None) -> str: - msg = "# CUDA devices:\n" - for i in range(torch.cuda.device_count()): - msg += f"# {i}: {torch.cuda.get_device_name(i)}\n" - fusion_func_name = ( - "nvfuser_incomplete_fusion" - if self.id() is None - else f"nvfuser_fusion_id{self.id()}" - ) - msg += ( - f"# torch version: {torch.__version__}\n" - f"# cuda version: {torch.version.cuda}\n" - f"# nvfuser version: {version()}\n" - "import torch\n" - "from nvfuser import FusionDefinition, DataType\n" - f"{self}" - "with FusionDefinition() as fd:\n" - f" {fusion_func_name}(fd)\n" - ) - if inputs is not None: - msg += "\ninputs = [\n" - for i in inputs: - if isinstance(i, torch.Tensor): - if i.is_contiguous(): - msg += f" torch.testing.make_tensor({tuple(i.size())}, dtype={i.dtype}, device='{i.device}'),\n" - else: - # max linear index determines number of elements to generate - sz = 1 - for szi, stri in zip(i.size(), i.stride()): - if szi == 0: - sz = 0 - break - sz += (szi - 1) * stri - if i.dtype.is_floating_point: - msg += ( - f" torch.randn({sz}, dtype={i.dtype}, device='{i.device}')" - f".as_strided({tuple(i.size())}, {tuple(i.stride())}),\n" - ) - else: - upper_bound = 2 if i.dtype == torch.bool else 10 - msg += ( - f" torch.randint(0, {upper_bound}, ({sz},), dtype={i.dtype}, device='{i.device}')" - f".as_strided({tuple(i.size())}, {tuple(i.stride())}),\n" - ) - else: - input_as_string = str(i) - # `nan` and `inf` are stringified as is, which are not - # defined in Python. So we replace them with `float("nan")` - # and `float("inf")`. `-inf` is replaced with - # `-float("inf")`, which equals `float("-inf")`. - input_as_string = re.sub( - r"\binf\b", 'float("inf")', input_as_string - ) - input_as_string = re.sub( - r"\bnan\b", 'float("nan")', input_as_string - ) - msg += f" {input_as_string},\n" - msg += "]" - msg += "\nfd.execute(inputs)\n" - - return msg - - def _repro_error_str(self, section: str, inputs: list | None = None): - msg = ( - f"An error occurred while {section} nvFuser FusionDefinition {self.id()}.\n" - "If you believe this is a bug or need assistance, please file an issue at " - "https://github.com/NVIDIA/Fuser/issues/new\n" - f"Here's a script to reproduce the error:\n" - "```python\n" - ) - msg += self.repro_script_for(inputs) - msg += "```\n" - return msg - - def validate( - self, - inputs: list[torch.Tensor], - reference_outputs: list[torch.Tensor] = None, - **kwargs, - ): - """ - Validates the fusion outputs against the provided reference outputs, using variable tolerances determined based on datatype and reduction size. - - Inputs: - inputs: A list of inputs expected by the fusion definition - reference_outputs: A list of reference outputs to validate against - """ - fusion_outputs = self.execute(inputs, **kwargs) - - if reference_outputs is None: - return self.validate_with_auto_inferred_outputs(fusion_outputs, inputs) - - assert len(fusion_outputs) == len( - reference_outputs - ), f"Expected {len(fusion_outputs)} reference outputs for validation." - - tolerance_values = self.getValTolerances(inputs) - assert len(tolerance_values) == len( - fusion_outputs - ), f"Missing tolerance values, expected {len(fusion_outputs)}, got {len(tolerance_values)}" - - for inx, fusion_output in enumerate(fusion_outputs): - atol, rtol = tolerance_values[inx] - reference_output = reference_outputs[inx] - - assert ( - reference_output.shape == fusion_output.shape - ), "Mismatch in reference and fusion output dimensions" - if torch.is_floating_point(fusion_output) or torch.is_complex( - fusion_output - ): - assert torch.allclose( - fusion_output, reference_output, atol=atol, rtol=rtol - ), f"Max error: {torch.abs(torch.max(fusion_output - reference_output))}, \ - Absolute tolerance: {atol}, Relative tolerance: {rtol}" - - else: - assert torch.equal( - fusion_output, reference_output - ), "Mismatch in reference and fusion output values, datatype is not float/complex." - - -from .nvfuser_version import __version__ - - -def version(): - r"""returns nvfuser version in format of a string 'm.n.p+git[7d-sha]'. - - We strip the git[7d-sha] and convert the string to - `nvfuser_version.Version` for comparison. e.g. you can use it as: - import nvfuser - print(nvfuser.version()) # 0.0.1+git21df524 - nvfuser.version() == '0.0.1` # True - nvfuser.version() > '0.0.0` # True - - from nvfuser_version import Version - nvfuser.version() < Version('1.0.0') # True - """ - return __version__ diff --git a/python/nvfuser/__init__.pyi b/python/nvfuser/__init__.pyi deleted file mode 100644 index 2fc3ab3b3e0..00000000000 --- a/python/nvfuser/__init__.pyi +++ /dev/null @@ -1,4 +0,0 @@ -from typing import List - - -def compute_contiguity(sizes, strides) -> List[bool]: ... diff --git a/python/nvfuser/benchmark_utils.py b/python/nvfuser/benchmark_utils.py deleted file mode 100644 index 6a0305339c2..00000000000 --- a/python/nvfuser/benchmark_utils.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -import torch -from cupti import cupti -import cxxfilt -import pytest - - -# Base class for all timers used by pytest-benchmark. -class Timer: - def __init__(self): - self.current_time = 0.0 - - def _increment_global_time(self, elapsed_time: float) -> None: - self.current_time += elapsed_time - - def __call__(self): - raise NotImplementedError("Subclass must implement this method") - - def cleanup(self): - pass - - -def demangle_kernel_name(mangled_name): - try: - return cxxfilt.demangle(mangled_name) - except Exception: - return mangled_name # Return original if demangling fails - - -def cupti_call_safe(func, *args): - """Wrapper for CUPTI calls. Failing CUPTI calls will exit the program.""" - try: - return func(*args) - except Exception as e: - print(f"CUPTI call {func.__name__} failed: {e}") - pytest.exit(1) - - -class CuptiProfiler: - # List of activities to be recorded by CUPTI. - activity_kinds: list[cupti.ActivityKind] = [ - cupti.ActivityKind.CONCURRENT_KERNEL, - ] - - # Private class variable to store the subscriber handle. - __subscriber_handle = None - - def _error_if_not_valid(self) -> None: - if not self.is_valid: - raise RuntimeError( - "CuptiProfiler is not valid. " "This instance has been torn down." - ) - - def _func_buffer_requested(self) -> tuple[int, int]: - # 8MB buffer size as recommended by CUPTI samples. - # max_num_records=0 indicates the buffer is filled with as many records as possible. - buffer_size = 8 * 1024 * 1024 - max_num_records = 0 - return buffer_size, max_num_records - - def _func_buffer_completed(self, activities: list[cupti.ActivityAPI]) -> None: - for activity in activities: - # Activity.end and Activity.start are in nanoseconds. - duration = (activity.end - activity.start) / 1e9 - self.profiler_output.append((demangle_kernel_name(activity.name), duration)) - - def __init__(self): - if CuptiProfiler.__subscriber_handle is not None: - raise RuntimeError( - "Only one instance of CuptiProfiler can be created. " - "CUPTI only supports one subscriber at a time." - ) - - self.profiler_output: list[tuple[str, float]] = [] - - # Subscribe to CUPTI and register activity callbacks. - CuptiProfiler.__subscriber_handle = cupti_call_safe(cupti.subscribe, None, None) - cupti_call_safe( - cupti.activity_register_callbacks, - self._func_buffer_requested, - self._func_buffer_completed, - ) - self.is_valid = True - - def start(self) -> None: - self._error_if_not_valid() - cupti_call_safe(cupti.activity_flush_all, 1) - self.profiler_output = [] - for activity_kind in CuptiProfiler.activity_kinds: - cupti_call_safe(cupti.activity_enable, activity_kind) - - def stop(self) -> list[tuple[str, float]]: - self._error_if_not_valid() - for activity_kind in CuptiProfiler.activity_kinds: - cupti_call_safe(cupti.activity_disable, activity_kind) - cupti_call_safe(cupti.activity_flush_all, 0) - return self.profiler_output - - def teardown_cupti(self) -> None: - self._error_if_not_valid() - if CuptiProfiler.__subscriber_handle is None: - return - cupti_call_safe(cupti.unsubscribe, CuptiProfiler.__subscriber_handle) - cupti_call_safe(cupti.finalize) - CuptiProfiler.__subscriber_handle = None - # Invalidate the profiler so it cannot be used again. - self.is_valid = False - - -class CuptiTimer(Timer): - def __init__(self): - super().__init__() - self.cupti_profiler = CuptiProfiler() - self.is_running = False - - def __call__(self): - torch.cuda.synchronize() - - if not self.is_running: - self.cupti_profiler.start() - self.is_running = True - return self.current_time - - profiler_output = self.cupti_profiler.stop() - self.is_running = False - - # Check if any activities were recorded - if len(profiler_output) == 0: - self.cleanup() - raise RuntimeError("No activities were recorded.") - - self._increment_global_time(sum(duration for _, duration in profiler_output)) - return self.current_time - - def cleanup(self): - self.is_running = False - self.cupti_profiler.teardown_cupti() - - -class FusionProfileTimer(Timer): - def __init__(self): - super().__init__() - self.fd = None - # Specifies if the timer in host measurement is called at the start/finish of execution. - # Timings are measured at the end of execution. - self.execution_start = True - - def set_fd(self, fd): - self.fd = fd - - def __call__(self): - if not self.execution_start: - profile = self.fd.profile() - elapsed_host_time = profile.host_time_ms / 1e3 - self._increment_global_time(elapsed_host_time) - self.execution_start = not self.execution_start - return self.current_time diff --git a/python/nvfuser/contrib/__init__.py b/python/nvfuser/contrib/__init__.py deleted file mode 100644 index 6ea00f3ec7d..00000000000 --- a/python/nvfuser/contrib/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -from . import nn - - -__all__ = [ - "nn", -] diff --git a/python/nvfuser/contrib/nn/__init__.py b/python/nvfuser/contrib/nn/__init__.py deleted file mode 100644 index 50bf839a515..00000000000 --- a/python/nvfuser/contrib/nn/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -from .normalization import InstanceNorm1dNVFuser -from .normalization import InstanceNorm2dNVFuser -from .normalization import InstanceNorm3dNVFuser - - -__all__ = [ - "InstanceNorm1dNVFuser", - "InstanceNorm2dNVFuser", - "InstanceNorm3dNVFuser", -] diff --git a/python/nvfuser/contrib/nn/normalization.py b/python/nvfuser/contrib/nn/normalization.py deleted file mode 100644 index c01faf86cdb..00000000000 --- a/python/nvfuser/contrib/nn/normalization.py +++ /dev/null @@ -1,725 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -import enum -from typing import Any, Dict, List, Optional, Tuple - -import torch - -import nvfuser - -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype - - -__all__ = [ - "InstanceNorm1dNVFuser", - "InstanceNorm2dNVFuser", - "InstanceNorm3dNVFuser", -] - - -NamedAxis = enum.Enum("NamedAxis", ["BATCH", "CHANNEL"]) - - -def partially_contig_tensor( - fd: "nvfuser.FusionDefinition", - x: torch.Tensor, -) -> "nvfuser.Tensor": - return fd.define_tensor( - shape=[1 if dim_size == 1 else -1 for dim_size in x.size()], - contiguity=nvfuser.compute_contiguity(x.size(), x.stride()), - dtype=torch_dtype_to_nvfuser_dtype(x.dtype), - ) - - -def norm_fusion_forward( - fd: "nvfuser.FusionDefinition", - inputs: List[torch.Tensor], - x: "nvfuser.Tensor", - weight: Optional["nvfuser.Tensor"], - bias: Optional["nvfuser.Tensor"], - running_mean: Optional["nvfuser.Tensor"], - running_var: Optional["nvfuser.Tensor"], - eps: "nvfuser.Scalar", - use_input_stats: bool, - momentum: "nvfuser.Scalar", - channels_last: bool, - x_datatype: "nvfuser.DataType", - unbiased: bool = False, - *, - stat_axes: List[NamedAxis], -) -> Tuple["nvfuser.Tensor", "nvfuser.Tensor", "nvfuser.Tensor"]: - """Modify FusionDefinition to add a generic normalization layer (forward). - - This can be used to construct a BatchNorm, GroupNorm, InstanceNorm, or - LayerNorm network by indicating different sets of axes to preserve. - - BatchNorm: `stat_axes = [NamedAxis.CHANNEL]` - LayerNorm: `stat_axes = [NamedAxis.BATCH]` - InstanceNorm: `stat_axes = [NamedAxis.BATCH, NamedAxis.CHANNEL]` - - Args: - fd: An initialized FusionDefinition. - inputs: A list of :class:'torch.Tensor' inputs to the - `FusionDefinition` `fd`. - x: An input NVFuser tensor. - weight: If given, multiply normed output by this `Tensor`. It should be - one-dimensional if `NamedAxis.CHANNEL` is in `stat_axes`, and - zero-dimensional otherwise. It will be broadcast along all other - dimensions. - bias: If given, add this `Tensor` to normed output. It should be - one-dimensional if `NamedAxis.CHANNEL` is in `stat_axes`, and - zero-dimensional otherwise. It will be broadcast along all other - dimensions. - running_mean: If given, a running mean estimate that will be modified - in place. - running_var: If given, a running variance estimate that will be - modified in place. - eps: Amount to regularize the square root needed to convert variance to - standard deviation. - use_input_stats: Whether to compute the stats of this batch or to - _only_ use the provided running_mean and running_var. - momentum: Momentum for exponentially weighted moving average of running - stats. - channels_last: Whether channels are in position -1 (`True`) or 1 - (`False`). - x_datatype: :class:'DataType' of input :class:'Tensor' `x` - unbiased: Whether to use unbiased variance for computing current batch - statistics. Note that unbiased estimates are always used for - running variance updates, regardless of this argument's value. - stat_axes: A list of `NamedAxis` objects indicating a combination of - axes with which to index the computed statistics. This can be used - to implement multiple types of normalization layers, since most of - those differ only in which axes are reduced over. - Returns: - The normalized output, as well as mean and 1/std. Note that - `fd.add_output` is _not_ called by this function. - """ - assert (running_var is None) == ( - running_mean is None - ), "Iff running mean or var is given, the other should be" - - # dyn_shape holds Scalars describing the size of the input x - dyn_shape = fd.ops.tensor_sizes(x) - - num_dims = len(dyn_shape) - - batch_dim = 0 - batch_size = dyn_shape[batch_dim] - - channel_dim = num_dims - 1 if channels_last else 1 - num_channels = dyn_shape[channel_dim] - - # Running stats will be kept possibly for channel but never by instance, so - # we will reduce along batch_dim before updating running stats. - # These are used to broadcast in spatial dims - is_spatial_dim = [True] * num_dims - is_spatial_or_batch_dim = [True] * num_dims - - if NamedAxis.BATCH in stat_axes: - is_spatial_dim[batch_dim] = False - if NamedAxis.CHANNEL in stat_axes: - is_spatial_dim[channel_dim] = False - is_spatial_or_batch_dim[channel_dim] = False - x_reduction_axes = [ax for ax, flag in enumerate(is_spatial_dim) if flag] - num_features = fd.define_scalar(1) - for ax in x_reduction_axes: - num_features = fd.ops.mul(num_features, dyn_shape[ax]) - - if use_input_stats or running_mean is None: - # In NVFuser Python we pass correction=1 to request unbiased variance calculation - x_var, x_mean = fd.ops.var_mean(x, x_reduction_axes, int(unbiased)) - if running_mean is not None: - one = fd.define_scalar(1.0) - rev_momentum = fd.ops.sub(one, momentum) - - # do running mean with momentum - current_mean_hat = fd.ops.mul(x_mean, momentum) - mean_hat = fd.ops.mul(running_mean, rev_momentum) - new_mean_hat = fd.ops.add(mean_hat, current_mean_hat) - - # If computing stats for each instance, we don't want to keep those - # for our running mean calculation, so we sum them here - new_mean_sum = ( - fd.ops.sum(new_mean_hat, [0]) - if NamedAxis.BATCH in stat_axes - else new_mean_hat - ) - - rev_batch_size = fd.ops.reciprocal(batch_size) - new_mean_channels_only = fd.ops.mul(new_mean_sum, rev_batch_size) - if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]: - new_mean_channels_only = fd.ops.cast(new_mean_channels_only, x_datatype) - fd.add_output(new_mean_channels_only, alias_input=running_mean) - - # running var calculation - x_var_unbiased = x_var - if not unbiased: - # multiply by correction to go from biased to unbiased estimate - b2ub = fd.ops.div( - num_features, fd.ops.sub(num_features, fd.define_scalar(1)) - ) - x_var_unbiased = fd.ops.mul(x_var, b2ub) - - current_var_hat = fd.ops.mul(x_var_unbiased, momentum) - var_hat = fd.ops.mul(running_var, rev_momentum) - new_var_hat = fd.ops.add(var_hat, current_var_hat) - - # See above about reducing over batch dim for running stats - new_var_sum = ( - fd.ops.sum(new_var_hat, [0]) - if NamedAxis.BATCH in stat_axes - else new_var_hat - ) - - new_var_channels_only = fd.ops.mul(new_var_sum, rev_batch_size) - if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]: - new_var_channels_only = fd.ops.cast(new_var_channels_only, x_datatype) - fd.add_output(new_var_channels_only, alias_input=running_var) - - mean = x_mean - mean_bcast = fd.ops.broadcast(mean, is_spatial_dim) - x_sub_mean = fd.ops.sub(x, mean_bcast) - - var_eps = fd.ops.add(x_var, eps) - invstd = fd.ops.rsqrt(var_eps) - invstd_bcast = fd.ops.broadcast(invstd, is_spatial_dim) - - x_normed = fd.ops.mul(x_sub_mean, invstd_bcast) - - else: # This is inference mode with running stats - assert running_mean is not None - r_mean_bcast = fd.ops.broadcast(running_mean, is_spatial_or_batch_dim) - x_sub_mean = fd.ops.sub(x, r_mean_bcast) - - var_eps = fd.ops.add(running_var, eps) - invstd = fd.ops.rsqrt(var_eps) - invstd_bcast = fd.ops.broadcast(invstd, is_spatial_or_batch_dim) - - mean = running_mean - x_normed = fd.ops.mul(x_sub_mean, invstd_bcast) - - if weight is not None: - weight_bcast = fd.ops.broadcast(weight, is_spatial_or_batch_dim) - x_normed = fd.ops.mul(x_normed, weight_bcast) - if bias is not None: - bias_bcast = fd.ops.broadcast(bias, is_spatial_or_batch_dim) - x_normed = fd.ops.add(x_normed, bias_bcast) - - return x_normed, mean, invstd - - -def norm_fusion_backward( - fd: "nvfuser.FusionDefinition", - inputs: List[torch.Tensor], - x: "nvfuser.Tensor", - grad_output: "nvfuser.Tensor", - mean: Optional[torch.Tensor], - invstd: torch.Tensor, - weight: Optional["nvfuser.Tensor"], - bias: Optional["nvfuser.Tensor"], - running_mean: Optional["nvfuser.Tensor"], - running_var: Optional["nvfuser.Tensor"], - use_input_stats: bool, - channels_last: bool, - x_datatype: "nvfuser.DataType", - *, - stat_axes: List[NamedAxis], -) -> Tuple["nvfuser.Tensor", "nvfuser.Tensor", "nvfuser.Tensor"]: - """ - Modify FusionDefinition to add a generic normalization layer (backward). - - Args: - fd: An initialized FusionDefinition. - inputs: A list of :class:'torch.Tensor' inputs to the - `FusionDefinition` `fd`. - x: The input NVFuser tensor. - grad_output: NVFuser tensor representing gradient of loss with respect - to downstream activation (typical input to backward()). - mean: The mean used in the forward normalization. - invstd: The reciprocal of standard deviation used in the forward normalization. - weight: If given, multiply normed output by this `Tensor`. It should be - one-dimensional if `NamedAxis.CHANNEL` is in `stat_axes`, and - zero-dimensional otherwise. It will be broadcast along all other - dimensions. - bias: If given, add this `Tensor` to normed output. It should be - one-dimensional if `NamedAxis.CHANNEL` is in `stat_axes`, and - zero-dimensional otherwise. It will be broadcast along all other - dimensions. - running_mean: If given, a running mean estimate that will be modified - in place. - running_var: If given, a running variance estimate that will be - modified in place. - use_input_stats: Whether to compute the stats of this batch or to - _only_ use the provided running_mean and running_var. - channels_last: Whether channels are in position -1 (`True`) or 1 - (`False`). - x_datatype: :class:'DataType' of input :class:'Tensor' `x` - stat_axes: A list of `NamedAxis` objects indicating a combination of - axes with which to index the computed statistics. This can be used - to implement multiple types of normalization layers, since most of - those differ only in which axes are reduced over. - Returns: - The normalized output, as well as mean and 1/std. Note that - `fd.add_output` is _not_ called by this function. - """ - assert not ( - (running_var is None) ^ (running_mean is None) - ), "Iff running mean or var is given, the other should be" - - # dyn_shape holds Scalars describing the size of the input x - dyn_shape = fd.ops.tensor_sizes(x) - - num_dims = len(dyn_shape) - - batch_dim = 0 - batch_size = dyn_shape[batch_dim] - - channel_dim = num_dims - 1 if channels_last else 1 - num_channels = dyn_shape[channel_dim] - - # Running stats will be kept possibly for channel but never by instance, so - # we will reduce along batch_dim before updating running stats. - # These are used to broadcast in spatial dims - is_spatial_dim = [True] * num_dims - is_spatial_or_batch_dim = [True] * num_dims - - if NamedAxis.BATCH in stat_axes: - is_spatial_dim[batch_dim] = False - if NamedAxis.CHANNEL in stat_axes: - is_spatial_dim[channel_dim] = False - is_spatial_or_batch_dim[channel_dim] = False - x_reduction_axes = [ax for ax, flag in enumerate(is_spatial_dim) if flag] - num_features = fd.define_scalar(1) - for ax in x_reduction_axes: - num_features = fd.ops.mul(num_features, dyn_shape[ax]) - - mean = fd.ops.broadcast(mean, is_spatial_dim) - - norm = fd.ops.reciprocal(num_features) - grad_output_sum = fd.ops.sum(grad_output, x_reduction_axes) - dot_p = fd.ops.sum( - fd.ops.mul( - grad_output, - fd.ops.sub(x, mean), - ), - x_reduction_axes, - ) - grad_mean = fd.ops.broadcast(fd.ops.mul(grad_output_sum, norm), is_spatial_dim) - proj_scale = fd.ops.broadcast( - fd.ops.mul( - fd.ops.mul(dot_p, norm), - fd.ops.mul(invstd, invstd), - ), - is_spatial_dim, - ) - - invstd_bcast = fd.ops.broadcast(invstd, is_spatial_dim) - grad_scale = ( - invstd_bcast - if weight is None - else fd.ops.mul( - invstd_bcast, - fd.ops.broadcast(weight, is_spatial_or_batch_dim), - ) - ) - if use_input_stats: - proj = fd.ops.mul(fd.ops.sub(x, mean), proj_scale) - grad_input = fd.ops.mul( - fd.ops.sub( - fd.ops.sub(grad_output, proj), - grad_mean, - ), - grad_scale, - ) - else: - grad_input = fd.ops.mul(grad_output, grad_scale) - - if weight is not None: - grad_weight = fd.ops.mul(dot_p, invstd) - grad_weight_reduced = fd.ops.sum(grad_weight, [0]) - else: - grad_weight_reduced = None - if bias is not None: - grad_bias = grad_output_sum - grad_bias_reduced = fd.ops.sum(grad_bias, [0]) - else: - grad_bias_reduced = None - - return grad_input, grad_weight_reduced, grad_bias_reduced - - -class NormNVFuserFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, # contexts are actually objects of the type we are currently defining - x: torch.Tensor, - weight: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - running_mean: Optional[torch.Tensor], - running_var: Optional[torch.Tensor], - use_input_stats: bool, - momentum: float, - eps: float, - unbiased: bool, - stat_axes: List[NamedAxis], - ) -> torch.Tensor: - # When x.shape[1] == 1, is_contiguous will tell us the tensor is - # channels_last, even when it is ordinary contiguous. This causes some - # issues so we only detect channels_last when channels > 1 - channels_last = x.shape[1] > 1 and ( - x.is_contiguous(memory_format=torch.channels_last) - or x.is_contiguous(memory_format=torch.channels_last_3d) - ) - xorig = x - if channels_last: - order = [0] + list(range(2, len(x.shape))) + [1] - x = x.permute(order) - - x_datatype = torch_dtype_to_nvfuser_dtype(x.dtype) - - with nvfuser.FusionDefinition() as fd: - tv_x = partially_contig_tensor(fd, x) - inputs = [x] - if weight is not None: - tv_weight = partially_contig_tensor(fd, weight) - inputs.append(weight) - else: - tv_weight = None - - if bias is not None: - tv_bias = partially_contig_tensor(fd, bias) - inputs.append(bias) - else: - tv_bias = None - - if running_mean is None: - tv_running_mean = None - tv_running_var = None - else: - assert running_var is not None - tv_running_mean = partially_contig_tensor(fd, running_mean) - tv_running_var = partially_contig_tensor(fd, running_var) - inputs.extend([running_mean, running_var]) - - s_momentum = fd.define_scalar(nvfuser.DataType.Double) - s_eps = fd.define_scalar(nvfuser.DataType.Double) - inputs.extend([momentum, eps]) - - # cast inputs if necessary - if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]: - tv_x = fd.ops.cast(tv_x, nvfuser.DataType.Float) - if weight is not None and weight.dtype in [torch.half, torch.bfloat16]: - tv_weight = fd.ops.cast(tv_weight, nvfuser.DataType.Float) - if bias is not None and bias.dtype in [torch.half, torch.bfloat16]: - tv_bias = fd.ops.cast(tv_bias, nvfuser.DataType.Float) - - out, mean, invstd = norm_fusion_forward( - fd, - inputs, - tv_x, - tv_weight, - tv_bias, - tv_running_mean, - tv_running_var, - s_eps, - use_input_stats, - s_momentum, - channels_last, - x_datatype=x_datatype, - unbiased=unbiased, - stat_axes=stat_axes, - ) - - if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]: - out = fd.ops.cast(out, x_datatype) - - fd.add_output(out) - fd.add_output(mean) - fd.add_output(invstd) - - out, mean, invstd = fd.execute(inputs) - - ctx.stat_axes = stat_axes - ctx.use_input_stats = use_input_stats - ctx.channels_last = channels_last - # saving for backward in "explicit channels-last format" - ctx.save_for_backward(x, weight, bias, running_mean, running_var, mean, invstd) - if channels_last: - order = [0, len(x.shape) - 1] + list(range(1, len(x.shape) - 1)) - out = out.permute(order) - if len(out.shape) == 4: - assert out.is_contiguous(memory_format=torch.channels_last) - assert xorig.is_contiguous(memory_format=torch.channels_last) - elif len(out.shape) == 5: - assert out.is_contiguous(memory_format=torch.channels_last_3d) - assert xorig.is_contiguous(memory_format=torch.channels_last_3d) - else: - raise RuntimeError( - "unhandled channels_last format variation in forward" - ) - return out - - @staticmethod - def backward( - ctx: Any, grad_output: torch.Tensor - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - None, - None, - None, - None, - None, - None, - None, - ]: - """Instance norm backward using NVFuser""" - if ctx.channels_last: - order = [0] + list(range(2, len(grad_output.shape))) + [1] - grad_output = grad_output.permute(order) - # input was saved in "explicit channels-last format" - # assert ctx.saved_tensors[0].is_contiguous() - # grad_output = grad_output.contiguous() - x, weight, bias, running_mean, running_var, mean, invstd = ctx.saved_tensors - - with nvfuser.FusionDefinition() as fd: - tv_x = partially_contig_tensor(fd, x) - if x.dtype in [torch.half, torch.bfloat16]: - tv_x = fd.ops.cast(tv_x, nvfuser.DataType.Float) - inputs = [x] - if weight is not None: - tv_weight = partially_contig_tensor(fd, weight) - if weight.dtype in [torch.half, torch.bfloat16]: - tv_weight = fd.ops.cast(tv_weight, nvfuser.DataType.Float) - inputs.append(weight) - else: - tv_weight = None - if bias is not None: - tv_bias = partially_contig_tensor(fd, bias) - if bias.dtype in [torch.half, torch.bfloat16]: - tv_bias = fd.ops.cast(tv_bias, nvfuser.DataType.Float) - inputs.append(bias) - else: - tv_bias = None - if running_mean is not None: - tv_running_mean = partially_contig_tensor(fd, running_mean) - if running_mean.dtype in [torch.half, torch.bfloat16]: - tv_running_mean = fd.ops.cast( - tv_running_mean, nvfuser.DataType.Float - ) - inputs.append(running_mean) - else: - tv_running_mean = None - if running_var is not None: - tv_running_var = partially_contig_tensor(fd, running_var) - if running_var.dtype in [torch.half, torch.bfloat16]: - tv_running_var = fd.ops.cast(tv_running_var, nvfuser.DataType.Float) - inputs.append(running_var) - else: - tv_running_var = None - - tv_mean = partially_contig_tensor(fd, mean) - if mean.dtype in [torch.half, torch.bfloat16]: - tv_mean = fd.ops.cast(tv_mean, nvfuser.DataType.Float) - inputs.append(mean) - tv_invstd = partially_contig_tensor(fd, invstd) - if invstd.dtype in [torch.half, torch.bfloat16]: - tv_invstd = fd.ops.cast(tv_invstd, nvfuser.DataType.Float) - inputs.append(invstd) - - tv_grad_output = partially_contig_tensor(fd, grad_output) - if grad_output.dtype in [torch.half, torch.bfloat16]: - tv_grad_output = fd.ops.cast(tv_grad_output, nvfuser.DataType.Float) - inputs.append(grad_output) - - x_datatype = torch_dtype_to_nvfuser_dtype(x.dtype) - - grad_input, grad_weight, grad_bias = norm_fusion_backward( - fd, - inputs, - tv_x, - tv_grad_output, - tv_mean, - tv_invstd, - tv_weight, - tv_bias, - tv_running_mean, - tv_running_var, - ctx.use_input_stats, - ctx.channels_last, - x_datatype=x_datatype, - stat_axes=ctx.stat_axes, - ) - - if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]: - grad_input = fd.ops.cast(grad_input, x_datatype) - fd.add_output(grad_input) - - if weight is not None: - if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]: - grad_weight = fd.ops.cast(grad_weight, x_datatype) - fd.add_output(grad_weight) - - if bias is not None: - if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]: - grad_bias = fd.ops.cast(grad_bias, x_datatype) - fd.add_output(grad_bias) - - res = fd.execute(inputs) - grad_input = res[0] - c = 1 - if weight is not None: - grad_weight = res[c] - c += 1 - else: - grad_weight = None - if bias is not None: - grad_bias = res[c] - c += 1 - else: - grad_bias = None - - if ctx.channels_last: - order = [0, len(grad_input.shape) - 1] + list( - range(1, len(grad_input.shape) - 1) - ) - grad_input = grad_input.permute(order) - if len(grad_input.shape) == 4: - assert grad_input.is_contiguous(memory_format=torch.channels_last) - elif len(grad_input.shape) == 5: - assert grad_input.is_contiguous(memory_format=torch.channels_last_3d) - else: - raise RuntimeError( - "unhandled channels_last format variation in backward" - ) - return ( - grad_input, - grad_weight, - grad_bias, - None, - None, - None, - None, - None, - None, - None, - ) - - -class _NormNVFuserBase(torch.nn.modules.batchnorm._NormBase): - stat_axes: Optional[List[NamedAxis]] = None - - def __init__( - self, - num_features: int, - eps: float = 1e-5, - momentum: float = 0.1, - affine: bool = False, - track_running_stats: bool = False, - device: torch.device = None, - dtype: torch.dtype = None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__( - num_features, eps, momentum, affine, track_running_stats, **factory_kwargs - ) - - def _check_input_dim(self, input: torch.Tensor) -> None: - raise NotImplementedError - - def _load_from_state_dict( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Any, - strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], - ) -> None: - version = local_metadata.get("version", None) - # at version 1: removed running_mean and running_var when - # track_running_stats=False (default) - if version is None and not self.track_running_stats: - running_stats_keys = [] - for name in ("running_mean", "running_var"): - key = prefix + name - if key in state_dict: - running_stats_keys.append(key) - if len(running_stats_keys) > 0: - error_msgs.append( - "Unexpected running stats buffer(s) {names} for {klass} " - "with track_running_stats=False. If state_dict is a " - "checkpoint saved before 0.4.0, this may be expected " - "because {klass} does not track running stats by default " - "since 0.4.0. Please remove these keys from state_dict. If " - "the running stats are actually needed, instead set " - "track_running_stats=True in {klass} to enable them. See " - "the documentation of {klass} for details.".format( - names=" and ".join( - '"{}"'.format(k) for k in running_stats_keys - ), - klass=self.__class__.__name__, - ) - ) - for key in running_stats_keys: - state_dict.pop(key) - - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - - def forward(self, input: nvfuser.Tensor) -> nvfuser.Tensor: - assert input.is_cuda, "NVFuser InstanceNorm is CUDA only" - self._check_input_dim(input) - out = NormNVFuserFunction.apply( - input, - self.weight, - self.bias, - self.running_mean, - self.running_var, - self.training or not self.track_running_stats, - self.momentum, - self.eps, - False, # unbiased=False to match PyTorch functionality - self.stat_axes, - ) - return out - - -class _InstanceNormNVFuser(_NormNVFuserBase): - stat_axes = [NamedAxis.BATCH, NamedAxis.CHANNEL] - - -class _BatchNormNVFuser(_NormNVFuserBase): - stat_axes = [NamedAxis.CHANNEL] - - -class _LayerNormNVFuser(_NormNVFuserBase): - stat_axes = [NamedAxis.BATCH] - - -class InstanceNorm1dNVFuser(_InstanceNormNVFuser): - def _check_input_dim(self, input: torch.Tensor) -> None: - if input.dim() != 3: - raise ValueError("expected 3D input (got {}D input)".format(input.dim())) - - -class InstanceNorm2dNVFuser(_InstanceNormNVFuser): - def _check_input_dim(self, input: torch.Tensor) -> None: - if input.dim() != 4: - raise ValueError("expected 4D input (got {}D input)".format(input.dim())) - - -class InstanceNorm3dNVFuser(_InstanceNormNVFuser): - def _check_input_dim(self, input: torch.Tensor) -> None: - if input.dim() != 5: - raise ValueError("expected 5D input (got {}D input)".format(input.dim())) diff --git a/python/nvfuser/nvfuser_version.py b/python/nvfuser/nvfuser_version.py deleted file mode 100644 index 3c48e45e95f..00000000000 --- a/python/nvfuser/nvfuser_version.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -from typing import Any -from .version import _version_str - -__all__ = ["NvfuserVersion", "Version"] - - -class _LazyImport: - """Wraps around classes lazy imported from packaging.version - Output of the function v in following snippets are identical: - from packaging.version import Version - def v(): - return Version('1.2.3') - and - Version = _LazyImport('Version') - def v(): - return Version('1.2.3') - The difference here is that in later example imports - do not happen until v is called - """ - - def __init__(self, cls_name: str) -> None: - self._cls_name = cls_name - - def get_cls(self): - try: - import packaging.version # type: ignore[import] - except ImportError: - # If packaging isn't installed, try and use the vendored copy - # in pkg_resources - from pkg_resources import packaging # type: ignore[attr-defined, no-redef] - return getattr(packaging.version, self._cls_name) - - def __call__(self, *args, **kwargs): - return self.get_cls()(*args, **kwargs) - - def __instancecheck__(self, obj): - return isinstance(obj, self.get_cls()) - - -Version = _LazyImport("Version") - - -class NvfuserVersion(str): - @classmethod - def _convert_to_version(cls, ver: Any) -> Version: - if isinstance(ver, str): - return Version(ver.split("+")[0]) - elif isinstance(ver, Version.get_cls()): - return ver - else: - raise ValueError("can't convert {} to Version".format(ver)) - - def _cmp_version(self, other: Any, method: str) -> Version: - return getattr(NvfuserVersion._convert_to_version(self), method)( - NvfuserVersion._convert_to_version(other) - ) - - -for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]: - setattr( - NvfuserVersion, - cmp_method, - lambda x, y, method=cmp_method: x._cmp_version(y, method), - ) - -__version__ = NvfuserVersion(_version_str) diff --git a/python/nvfuser/pytorch_utils.py b/python/nvfuser/pytorch_utils.py deleted file mode 100644 index afe7cc0c4a5..00000000000 --- a/python/nvfuser/pytorch_utils.py +++ /dev/null @@ -1,190 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -import torch - -from ._C import DataType - -import ctypes -import functools -import gc -from typing import Type, Union, Tuple - -NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]] - -_torch_dtype_to_nvfuser_dtype_map = { - torch.cdouble: DataType.ComplexDouble, - torch.cfloat: DataType.ComplexFloat, - torch.double: DataType.Double, - torch.float: DataType.Float, - torch.half: DataType.Half, - torch.bfloat16: DataType.BFloat16, - torch.float8_e4m3fn: DataType.Float8_e4m3fn, - torch.float8_e5m2: DataType.Float8_e5m2, - torch.float8_e8m0fnu: DataType.Float8_e8m0fnu, - torch.long: DataType.Int, - torch.int: DataType.Int32, - torch.bool: DataType.Bool, - # Python scalars - complex: DataType.ComplexDouble, - float: DataType.Double, - int: DataType.Int, - bool: DataType.Bool, -} - -if hasattr(torch, "float4_e2m1fn_x2"): - _torch_dtype_to_nvfuser_dtype_map[ - torch.float4_e2m1fn_x2 - ] = DataType.Float4_e2m1fn_x2 - - -def python_scalar_to_nvfuser_dtype(a: Union[int, float, complex, bool]): - return _torch_dtype_to_nvfuser_dtype_map[type(a)] - - -def torch_dtype_to_nvfuser_dtype(dtype: Union[torch.dtype, NumberTypeType]): - """ - Translates from torch.dtype to nvFuser's DataType enum - """ - return _torch_dtype_to_nvfuser_dtype_map[dtype] - - -def get_device_properties() -> Tuple[int, float]: - """ - Computes device properties using ctypes and cuda. - Note: Consider using CUDA-Python when CUDA support >= 12.0. - """ - libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll", "cuda.dll") - for libname in libnames: - try: - cuda = ctypes.CDLL(libname) - except OSError: - continue - else: - break - else: - raise OSError("could not load any of: " + " ".join(libnames)) - - # Device attribute enums (taken from cuda.h) - # https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1ge12b8a782bebe21b1ac0091bf9f4e2a3 - - CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 1 - CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK = 8 - CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK = 12 - CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13 - CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36 - CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH = 37 - CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE = 38 - CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39 - - device_properties = {} - device = torch.cuda.current_device() - cuda_properties = torch.cuda.get_device_properties(device) - - device_properties["gpu_name"] = cuda_properties.name - device_properties["gpu_compute_capability_major"] = cuda_properties.major - device_properties["gpu_compute_capability_minor"] = cuda_properties.minor - device_properties["gpu_gmem_bytes"] = cuda_properties.total_memory - device_properties["gpu_sm_count"] = cuda_properties.multi_processor_count - - max_threads_per_block = ctypes.c_int() - cuda.cuDeviceGetAttribute( - ctypes.byref(max_threads_per_block), - CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK, - device, - ) - device_properties["gpu_max_threads_per_block"] = max_threads_per_block.value - - smem_per_block = ctypes.c_int() - cuda.cuDeviceGetAttribute( - ctypes.byref(smem_per_block), - CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, - device, - ) - device_properties["gpu_smem_bytes_per_block"] = smem_per_block.value - - max_reg_per_block = ctypes.c_int() - cuda.cuDeviceGetAttribute( - ctypes.byref(max_reg_per_block), - CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, - device, - ) - device_properties["gpu_regs_per_block"] = max_reg_per_block.value - - max_clock_khz = ctypes.c_int() - cuda.cuDeviceGetAttribute( - ctypes.byref(max_clock_khz), - CU_DEVICE_ATTRIBUTE_CLOCK_RATE, - device, - ) - device_properties["gpu_clock_rate_khz"] = max_clock_khz.value - - l2_cache_size = ctypes.c_int() - cuda.cuDeviceGetAttribute( - ctypes.byref(l2_cache_size), CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, device - ) - device_properties["gpu_l2_bytes"] = l2_cache_size.value - - memory_clock_rate = ctypes.c_int() - cuda.cuDeviceGetAttribute( - ctypes.byref(memory_clock_rate), CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device - ) - device_properties["gpu_mem_clock_khz"] = memory_clock_rate.value - - memory_bus_width = ctypes.c_int() - cuda.cuDeviceGetAttribute( - ctypes.byref(memory_bus_width), - CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, - device, - ) - device_properties["gpu_mem_bus_width"] = memory_bus_width.value - - max_threads_per_sm = ctypes.c_int() - cuda.cuDeviceGetAttribute( - ctypes.byref(max_threads_per_sm), - CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR, - device, - ) - device_properties["gpu_max_threads_per_sm"] = max_threads_per_sm.value - - # Compute peak bandwidth in GBps - peak_bandwidth = (2 * memory_bus_width.value * memory_clock_rate.value) / (1e6 * 8) - device_properties["gpu_peak_bandwidth_gbps"] = peak_bandwidth - - return device_properties - - -DEVICE_PROPERTIES = None -if torch.cuda.is_available(): - # Loading libraries will raise errors on non-CUDA machines. - DEVICE_PROPERTIES = get_device_properties() - - -def retry_on_oom_or_skip_test(func): - """Decorator: upon torch.OutOfMemoryError clear the cache and retry test""" - - @functools.wraps(func) - def retried_func(*args, **kwargs): - try: - output = func(*args, **kwargs) - except torch.OutOfMemoryError: - pass - else: - return output - - # We have hit an OOM error, so clear the cache and retry - gc.collect() - torch.cuda.empty_cache() - - try: - output = func(*args, **kwargs) - except torch.OutOfMemoryError as e: - # If we hit an OOM this time, then skip the test - import pytest - - pytest.skip(f"Test failed due to OutOfMemoryError: {e}") - return - - return output - - return retried_func diff --git a/python/python_frontend/fusion_cache.cpp b/python/python_frontend/fusion_cache.cpp deleted file mode 100644 index 26bc0e59680..00000000000 --- a/python/python_frontend/fusion_cache.cpp +++ /dev/null @@ -1,953 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include - -#include -#include -#include -#include -#include -#include -#include "base.h" - -#include -namespace fs = std::filesystem; - -#ifdef _WIN32 -#include -#else -#include -#include -#endif - -namespace nvfuser::python_frontend { - -namespace { -using BinaryBuffer = std::vector; - -// Generate temporary file for this FusionCacheBuffer -std::string getSerdeTmpFile() { -#ifdef _WIN32 - const unsigned int pid = GetCurrentProcessId(); -#else - const unsigned int pid = getpid(); -#endif // _WIN32 - std::stringstream ss; - ss << "nvf_serde_tmp_" << pid; - return ss.str(); -} - -std::string getSerdeFile(std::optional device_id) { - auto device_prop = (device_id.has_value()) - ? at::cuda::getDeviceProperties( - static_cast(device_id.value())) - : at::cuda::getCurrentDeviceProperties(); - int cuda_major = 0; - int cuda_minor = 0; - NVFUSER_NVRTC_SAFE_CALL(nvrtcVersion(&cuda_major, &cuda_minor)); - - std::stringstream ss; - ss << "nvf_serde"; - if (device_id.has_value()) { - ss << "_rank" << device_id.value(); - } - ss << "_device" << device_prop->major << "_" << device_prop->minor; - ss << "_cuda" << cuda_major << "_" << cuda_minor; - return ss.str(); -} - -// Get std::filesystem::path to specified file in nvfuser kernel database -// directory. -fs::path getSerdeFilePath(const std::string& file_name) { - fs::path kernel_db_path = fs::temp_directory_path() / "nvfuser_kernel_db"; - if (!fs::is_directory(kernel_db_path)) { - try { - fs::create_directory(kernel_db_path); - } catch (const std::exception& e) { - NVF_ERROR( - "Unable to create nvFuser Kernel DB directory! ", - kernel_db_path.string(), - e.what()); - } - } - return kernel_db_path / file_name; -} - -BinaryBuffer openFusionCache(std::string filename) { - FUSER_PERF_SCOPE("Flatbuffers::openFusionCache"); - auto file_handle = std::fopen(filename.c_str(), "rb"); - NVF_CHECK(file_handle != nullptr, "Failed to open FusionCache buffer."); - - auto file_path = fs::path(filename.c_str()); - auto file_size = fs::file_size(file_path); - NVF_CHECK(file_size > 0, "FusionCache buffer is empty."); - - BinaryBuffer buffer(file_size); - size_t read_status = - std::fread(buffer.data(), sizeof(uint8_t), file_size, file_handle); - NVF_CHECK( - read_status == file_size, "Failed to read entire FusionCache buffer.\n"); - std::fclose(file_handle); - return buffer; -} - -// This check function only throws errors if strict flag is enabled. -const serde::FusionCache* verifyFusionCache( - const BinaryBuffer& buffer, - std::optional device_id) { - FUSER_PERF_SCOPE("Flatbuffers::verifyFusionCache"); - auto fusion_cache_buffer = serde::GetFusionCache(buffer.data()); - - // Check flatbuffer integrity - flatbuffers::Verifier v(buffer.data(), buffer.size()); - NVF_CHECK( - fusion_cache_buffer->Verify(v), - "Failed to verify the integrity of FusionCache buffer."); - - // Check schema version - NVF_CHECK( - serde::FusionCacheBufferHasIdentifier(buffer.data()), - "Failed to verify the schema version of the FusionCache buffer"); - - // Check device major and minor versions - auto device_prop = (device_id.has_value()) - ? at::cuda::getDeviceProperties( - static_cast(device_id.value())) - : at::cuda::getCurrentDeviceProperties(); - NVF_CHECK( - device_prop->major == fusion_cache_buffer->device_major() && - device_prop->minor == fusion_cache_buffer->device_minor(), - false, - "Expected cuda version ", - device_prop->major, - ".", - device_prop->minor, - " but flatbuffer has cuda version ", - fusion_cache_buffer->device_major(), - ".", - fusion_cache_buffer->device_minor()); - - // Check cuda installation - int cuda_major = 0; - int cuda_minor = 0; - NVFUSER_NVRTC_SAFE_CALL(nvrtcVersion(&cuda_major, &cuda_minor)); - NVF_CHECK( - cuda_major == fusion_cache_buffer->cuda_major() && - cuda_minor == fusion_cache_buffer->cuda_minor(), - "Expected cuda version ", - cuda_major, - ".", - cuda_minor, - " but flatbuffer has cuda version ", - fusion_cache_buffer->cuda_major(), - ".", - fusion_cache_buffer->cuda_minor()); - - return fusion_cache_buffer; -} - -} // namespace - -void serialize() { - auto tmp_file_path = getSerdeFilePath(getSerdeTmpFile()); - FusionCache::get()->serialize(tmp_file_path); - - // Save to a per-process temporary file to avoid multi-process contention. - // Then, rename the temporary file to the actual file. If the actual file - // already exists, then the rename may fail or replace the actual file. - // Files replaced through this process should remain extant if they are being - // read because of UNIX filesystem properties, but this behavior is - // unverified. - auto file_path = - getSerdeFilePath(getSerdeFile(FusionCache::get()->deviceId())); - std::error_code rename_ec; - fs::rename(tmp_file_path, file_path, rename_ec); - - // Failed to replace common workspace, so remove the temporary file. - if (rename_ec) { - try { - fs::remove(tmp_file_path); - std::cout << "Removed temporary file because we could not replace common " - "workspace. Exception:\t" - << rename_ec.message() << std::endl; - } catch (const std::exception& e) { - std::cout << "Failed to delete temporary file. Exception:\t" << e.what() - << std::endl; - } - } -} - -// FusionCache static data member definitions for singleton usage -std::mutex FusionCache::singleton_lock_; -FusionCache* FusionCache::singleton_ = nullptr; - -UserSchedule::UserSchedule(int64_t fusion_id, int64_t device_id) - : scheduled_fusion(nullptr), - executor(nullptr), - fusion_id_(fusion_id), - device_id_(device_id) { - scheduled_fusion = std::make_unique(); - executor = - std::make_unique(fusion_id, /*concrete_id=*/device_id); -} - -bool UserSchedule::canSchedule(const SchedulerType& scheduler_type) { - return Schedule::canSchedule(scheduler_type, fusion(), *runtimeInfo()); -} - -std::tuple UserSchedule::canScheduleDebug( - const SchedulerType& scheduler_type) { - // Enable collection of messages from canScheduleRejectReason - DebugDumpOptionsGuard debug_dump_options_guard; - DebugDumpOptionsGuard::getCurOptions().set( - DebugDumpOption::FusionSegmenterLog); - - // Send debug messages to stringstream - std::stringstream ss; - DebugStreamGuard dsg(ss); - - bool can_schedule = canSchedule(scheduler_type); - return std::make_tuple(can_schedule, ss.str()); -} - -HeuristicParams* UserSchedule::computeHeuristics(SchedulerType scheduler_type) { - NVF_CHECK( - scheduler == nullptr, - "Scheduler is already defined for this UserSchedule"); - scheduler = SchedulerEntry::makeSchedulerInstance(scheduler_type); - SchedulerRuntimeInfo& runtime_info_ref = *runtimeInfo(); - - NVF_ERROR( - scheduler->canScheduleCompileTime(fusion()) && - scheduler->canScheduleRunTime(fusion(), runtime_info_ref), - "Could not schedule fusion with ", - scheduler_type, - " scheduler."); - - NVF_CHECK( - heuristic_params == nullptr, - "Heuristic Scheduler is already defined for this UserSchedule"); - - // Set scheduler hyperparameters if available for InnerOuterPersistent - // scheduler - // TODO:: extend to other schedulers if necessary - if (scheduler_type == SchedulerType::InnerOuterPersistent && - scheduler_hyperparams) { - scheduler->setSchedulerHyperParameters(scheduler_hyperparams.get()); - } - - heuristic_params = scheduler->computeHeuristics( - fusion(), runtime_info_ref, data_cache.get()); - return heuristic_params.get(); -} - -void UserSchedule::schedule() { - NVF_CHECK( - scheduler != nullptr, "Scheduler is not defined for this UserSchedule"); - NVF_CHECK( - heuristic_params != nullptr, - "Heuristic Scheduler is not defined for this UserSchedule"); - scheduler->schedule(fusion(), heuristic_params.get()); -} - -void UserSchedule::scheduleWithType(SchedulerType scheduler_type) { - // Get default heuristics for scheduler and then schedule fusion. - computeHeuristics(scheduler_type); - schedule(); -} - -FusionSchedules::FusionSchedules(int64_t fusion_id) - : auto_gen_schedules(nullptr), - user_def_schedules(), - last_user_def_scheduled_ir(nullptr), - last_user_def_executor(nullptr), - scheds_lock(), - fusion_id_(fusion_id) { - presched_fusion_ = std::make_unique(); -} - -Fusion* FusionSchedules::preschedFusion() { - if (presched_fusion_ != nullptr) { - return presched_fusion_.get(); - } - - // Ideally, we shouldn't have to access FusionExecutorCache::fusion() so - // FusionExecutorCache has the flexibility to modify it in place or even - // delete it. Currently, this is only needed for cloning an - // nvfuser.FusionDefinition. See exec_nvfuser's is_clonable parameter. After - // FusionDefinition.__exit__, FusionSchedules.presched_fusion_ is moved to - // FusionExecutorCache and therefore becomes null. - if (auto_gen_schedules != nullptr) { - return auto_gen_schedules->fusion(); - } - - NVF_THROW("Prescheduled Fusion is unexpectedly null!"); -} - -void FusionSchedules::createExecutorIfNotExists() { - if (auto_gen_schedules == nullptr) { - auto_gen_schedules = std::make_unique( - std::move(presched_fusion_), fusion_id_); - presched_fusion_ = nullptr; - } -} - -TrieNode::TrieNode(RecordFunctor* rec, TrieNode* _parent, size_t _fusion_id) - : record(rec), - children(), - fusion_id(_fusion_id), - visits(0), - parent(_parent), - trie_node_lock() {} - -bool TrieNode::isTerminal() const { - return (record.get()->recordType() == serde::RecordType::End); -} - -void TrieNode::setException(const char* e) { - std::lock_guard guard(trie_node_lock); - exception = e; -} - -std::optional TrieNode::getException() { - std::lock_guard guard(trie_node_lock); - return exception; -} - -flatbuffers::Offset TrieNode::serialize( - flatbuffers::FlatBufferBuilder& builder, - const std::map& - map_record_functor_to_trie_node_id) { - // Map children TrieNode to its corresponding Integer index - std::vector children_trie_node_ids; - children_trie_node_ids.reserve(children.size()); - for (auto&& c : children) { - size_t id = map_record_functor_to_trie_node_id.at(c.first); - children_trie_node_ids.push_back(id); - } - - return serde::CreateTrieNodeDirect( - builder, - record->serialize(builder), - &children_trie_node_ids, - fusion_id, - visits, - isTerminal()); -} - -FusionCache* FusionCache::get( - size_t max_fusions, - std::optional selected_device, - bool load_from_default_workspace) { - FUSER_PERF_SCOPE("FusionCache::get"); - std::lock_guard guard(singleton_lock_); - if (singleton_ == nullptr) { - singleton_ = new FusionCache(max_fusions, selected_device); - - // Deserialize cache hierarchy from common workspace automatically - auto file_path = - getSerdeFilePath(getSerdeFile(singleton_->deviceId())).native(); - if (load_from_default_workspace && fs::exists(file_path)) { - try { - singleton_->deserialize(file_path); - - // Check if deserialized cache exceeds max_fusions limit - if (singleton_->fusions_.size() > max_fusions) { - std::cout - << "Warning: Deserialized cache contains " - << singleton_->fusions_.size() - << " fusions, which exceeds the requested max_fusions limit of " - << max_fusions << ". Resetting cache." << std::endl; - - // Delete incompatible workspace - std::error_code remove_ec; - fs::remove(file_path, remove_ec); - if (remove_ec) { - std::cout << "Failed to delete common workspace. Exception:\t" - << remove_ec.message() << std::endl; - } - - // Reset FusionCache - delete singleton_; - singleton_ = new FusionCache(max_fusions, selected_device); - } - } catch (const std::exception& deserialize_exception) { - // The saved workspace can become out-of-date between nvfuser updates. - // Send warning and delete the incompatible workspace. - // A new workspace will be saved upon program exit. - std::cout << "Warning: Failed to deserialize common workspace.\n" - << "A new workspace will be saved upon program exit after " - "deleting incompatible workspace." - << std::endl; - - // Hide exception message by default because it should be resolved by - // saving a new workspace. - if (isOptionEnabled(EnableOption::ParallelSerde)) { - std::cout << "Remove `parallel_serde` from NVFUSER_ENABLE " - "environment variable to print exception message." - << std::endl; - } else { - std::cout << deserialize_exception.what() << std::endl; - } - - // Delete incompatible workspace - std::error_code remove_ec; - fs::remove(file_path, remove_ec); - if (remove_ec) { - std::cout << "Failed to delete common workspace. Exception:\t" - << remove_ec.message() << std::endl; - } - - // Reset FusionCache if there is an issue with the current workspace. - delete singleton_; - singleton_ = new FusionCache(max_fusions, selected_device); - } - } - } - NVF_CHECK( - max_fusions >= singleton_->fusions_.size(), - "The max fusions is set less than the number of fusions in the cache."); - singleton_->max_fusions_ = max_fusions; - return singleton_; -} - -size_t FusionCache::numFusions() const { - return fusions_.size(); -} - -std::optional FusionCache::deviceId() const { - return device_id_; -} - -void FusionCache::print(std::ostream& os) const { - os << "Fusions by id:" << std::endl; - std::vector stack; - stack.push_back(root_.get()); - - while (!stack.empty()) { - TrieNode* node = stack.back(); - stack.pop_back(); - - if (node->isTerminal()) { - std::vector rev_fusion_records; - TrieNode* end = node->parent; - while (end) { - if (end->record->recordType() != serde::RecordType::Start) { - rev_fusion_records.emplace_back(end); - } - end = end->parent; - } - - os << node->fusion_id << ":" << std::endl; - std::for_each( - rev_fusion_records.rbegin(), - rev_fusion_records.rend(), - [&os](const auto elem) { - os << " "; - elem->record->print(os); - os << std::endl; - }); - } else { - for (auto& iter : node->children) { - stack.push_back(iter.second.get()); - } - } - } -} - -void FusionCache::stats(std::ostream& os) const { - os << "Total Fusions: " << fusions_.size() << "\n"; - - // Does not make sense to print stats if the cache is disabled. - if (!fusions_.empty()) { - os << "Cache Hits by Fusion Id:\n"; - size_t total_cache_hits = 0; - for (size_t i = 0; i < terminal_nodes_.size(); ++i) { - // The first visit is a miss! - auto visits = terminal_nodes_[i]->visits - 1; - total_cache_hits += visits; - os << "\t" << i << " -> " << visits << " hits\n"; - } - - auto hit_rate = static_cast(total_cache_hits) / - static_cast(root_->visits) * 100.0; - os << "Cache Lookups: " << root_->visits; - os << " Cache Hits: " << total_cache_hits; - os << " Hit Rate: " << hit_rate << "%\n"; - } -} - -void FusionCache::reset() { - std::lock_guard guard(singleton_lock_); - if (singleton_ != nullptr) { - size_t max_fusions = singleton_->max_fusions_; - std::optional device_id = singleton_->device_id_; - delete singleton_; - singleton_ = new FusionCache(max_fusions, device_id); - } -} - -FusionCache::FusionCache( - size_t max_fusions, - std::optional selected_device) - : max_fusions_(max_fusions), - device_id_(selected_device), - root_(nullptr), - fusions_(), - terminal_nodes_(), - user_def_input_encodings_() { - RecordFunctor* start = new StartRecord(); - root_ = std::make_unique(start); -} - -// In order to keep queries fast, this method does not lock. -// In the worst case, the query should fail and if you try to create a child, -// it should give you back an already created child if two threads are walking -// the trie at the same time with the same definition. -std::optional FusionCache::queryChildren( - TrieNode* node, - RecordFunctor* rec) const { - NVF_CHECK( - !node->isTerminal(), "There should be no children from a Terminal Node!"); - NVF_CHECK(rec, "Record is null!"); - auto trie_node = node->children.find(rec); - if (trie_node == std::end(node->children)) { - return std::nullopt; - } else { - ++(trie_node->second.get()->visits); - return std::optional(trie_node->second.get()); - } -} - -FusionSchedules* FusionCache::queryFusionSchedules(size_t fusion_id) const { - NVF_CHECK( - fusion_id < fusions_.size(), - "Invalid scheduler query for id: ", - fusion_id); - FusionSchedules* ptr = fusions_.at(fusion_id).get(); - NVF_CHECK(ptr != nullptr, "Unexpected null FusionSchedules object."); - return ptr; -} - -std::optional FusionCache::queryUserScheduleId( - const FusionSchedules* scheds, - const KernelArgumentHolder& args) { - std::optional result = std::nullopt; - - auto& user_scheds = scheds->user_def_schedules; - if (!user_scheds.empty()) { - auto input_id = user_def_input_encodings_.lookupId(args); - auto user_sched = user_scheds.find(input_id.id); - if (user_sched != user_scheds.end()) { - return std::optional(user_sched->first); - } - } - return result; -} - -const UserSchedule& FusionCache::queryUserSchedule( - const FusionSchedules* scheds, - size_t id, - int device) const { - auto& user_scheds = scheds->user_def_schedules; - NVF_CHECK( - !user_scheds.empty(), - "Expecting there to be at least one user schedule!"); - auto user_sched = user_scheds.find(id); - NVF_CHECK( - user_sched != user_scheds.end(), "Lookup of non-existent user schedule!"); - return user_sched->second.at(device); -} - -bool FusionCache::existUserSchedule( - const FusionSchedules* scheds, - KernelArgumentHolder args, - int device) { - // Short-Circuit: No user schedules - if (scheds->user_def_schedules.empty()) { - return false; - } - args.setDeviceIndex(device); - // Short-Circuit: User schedule does not exist for fusion and args. - InputsIdLookup::IdLookupReturn input_id = - user_def_input_encodings_.lookupId(args); - auto user_sched_iter = scheds->user_def_schedules.find(input_id.id); - if (user_sched_iter == scheds->user_def_schedules.end()) { - return false; - } - - // A vector of user schedules exists for fusion and inputs. - // Now, check that user schedule exists for specified device. - return device < (int)user_sched_iter->second.size(); -} - -TrieNode* FusionCache::createChild(TrieNode* node, RecordFunctor* rec) { - FUSER_PERF_SCOPE("FusionCache::createChild"); - TrieNode* child = nullptr; - NVF_CHECK( - !node->isTerminal(), "Cannot create a trie node from a terminal node!"); - NVF_CHECK(rec, "Record is null!"); - - std::lock_guard guard(node->trie_node_lock); - - // As a thread-safety compromise for fast queries, the node is re-queried - // prior to child creation incase another thread slipped in the node. - auto child_node = queryChildren(node, rec); - if (child_node.has_value()) { - child = child_node.value(); - } else { - size_t fusion_id = 0; - if (rec->recordType() == serde::RecordType::End) { - NVF_CHECK( - (fusions_.size() + 1) <= max_fusions_, - "The number of fusions in nvfuser has exceeded ", - max_fusions_, - "fusions. The max_fusions for the FusionCache might need to be ", - "increased if the max number is not being exceeded due to an error."); - fusion_id = fusions_.size(); - fusions_.emplace_back(std::make_unique(fusion_id)); - } - - // Copying the record owned by the FusionDefinition that calls this function - // so the trie owns a copy when the FusionDefinition gets destroyed rather - // than managing a shared pointer that would only share with - // FusionDefinition that creates a trie node but not cache lookups - RecordFunctor* new_rec = rec->clone(); - node->children[new_rec] = - std::make_unique(new_rec, node, fusion_id); - child = node->children[new_rec].get(); - NVF_CHECK(child, "Created child of TrieNode should not be null!"); - ++(child->visits); - if (rec->recordType() == serde::RecordType::End) { - terminal_nodes_.push_back(node->children[new_rec].get()); - } - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - std::stringstream ss; - new_rec->print(ss); - debug() << "\nFusionDefinition: Create new trie node for: " << ss.str() - << "\n"; - } - } - return child; -} - -UserSchedule* FusionCache::createUserSchedule( - FusionSchedules* scheds, - KernelArgumentHolder args, - int device, - bool overwrite_existing_schedule) { - FUSER_PERF_SCOPE("FusionCache::createUserSchedule"); - std::lock_guard guard(scheds->scheds_lock); - args.setDeviceIndex(device); - auto& user_scheds = scheds->user_def_schedules; - auto input_id = user_def_input_encodings_.lookupId(args); - - // Create UserSchedule for device - if (user_scheds[input_id.id].count(device) == 0) { - user_scheds[input_id.id].emplace( - device, UserSchedule(scheds->fusion_id_, device)); - } else { - if (!overwrite_existing_schedule) { - TORCH_WARN( - "You are overwriting the current user schedule for a definition!"); - } - user_scheds[input_id.id].at(device) = - UserSchedule(scheds->fusion_id_, device); - } - - return &user_scheds[input_id.id].at(device); -} - -TrieNode* FusionCache::rootTriePtr() { - ++(root_.get()->visits); - return root_.get(); -} - -void FusionCache::serialize(std::string filename) const { - FUSER_PERF_SCOPE("FusionCache::serialize"); - flatbuffers::FlatBufferBuilder builder(1024); - // TODO: Serialize Fusion IR containers - - // 1. Flattened the TrieStructure using breadth-first search - // 2. Map RecordFunctor pointer to its position in flattened order - std::map map_record_functor_to_trie_node_id; - std::vector bfs_order; - std::deque queue = {root_.get()}; - while (!queue.empty()) { - TrieNode* current_node = queue.front(); - queue.pop_front(); - - map_record_functor_to_trie_node_id.emplace( - current_node->record.get(), bfs_order.size()); - bfs_order.push_back(current_node); - - for (auto&& child : current_node->children) { - queue.push_back(child.second.get()); - } - } - - // 3. Serialize TrieNode in Breadth-First Search (BFS) order - // - // Note 1) All TrieNode pointers are mapped to their corresponding index in - // BFS traversal order. - // - // Note 2) We cannot create nested Flatbuffer objects. e.g., All Flatbuffer - // objects MUST be created before the start of the table they are referenced - // in. - // - // Thus, it is simplier to get the entire BFS order first, and then serialize - // the flattened Trie structure. - std::vector> fb_nodes; - for (TrieNode* node : bfs_order) { - auto serialized_trie_node = - node->serialize(builder, map_record_functor_to_trie_node_id); - fb_nodes.push_back(serialized_trie_node); - } - - // 4. Map the terminal nodes to their BFS positions. - // 5. Serialize each FusionExecutorCache for each fusion. - std::vector terminal_node_idx; - terminal_node_idx.reserve(terminal_nodes_.size()); - - using fb_fusion_executor_cache = - flatbuffers::Offset; - std::vector fb_auto_gen_schedules; - fb_auto_gen_schedules.reserve(terminal_nodes_.size()); - - for (TrieNode* node : terminal_nodes_) { - if (node->getException().has_value()) { - // Skip error nodes, which don't map to any FusionSchedules in the cache. - // Without this, queryFusionSchedules creates an empty FusionSchedules - // that's not executable. - continue; - } - - FusionSchedules* schedule = queryFusionSchedules(node->fusion_id); - if (schedule->auto_gen_schedules == nullptr) { - // This fusion has been created but never executed. It doesn't save us - // anything to serialize that. - continue; - } - - terminal_node_idx.push_back( - map_record_functor_to_trie_node_id.at(node->record.get())); - - fb_auto_gen_schedules.push_back( - schedule->auto_gen_schedules->serialize(builder)); - } - - auto device_prop = at::cuda::getCurrentDeviceProperties(); - int cuda_major = 0; - int cuda_minor = 0; - NVFUSER_NVRTC_SAFE_CALL(nvrtcVersion(&cuda_major, &cuda_minor)); - - // 6. Build FusionCache flatbuffer object - // See table definition for FusionCache in serde/fusion_cache.fbs - auto fusion_cache = serde::CreateFusionCacheDirect( - builder, - max_fusions_, - &fb_nodes, - &terminal_node_idx, - &fb_auto_gen_schedules, - KernelExecutor::getGlobalFusionCount(), - device_prop->major, - device_prop->minor, - cuda_major, - cuda_minor); - builder.Finish(fusion_cache, /*file_identifier=*/"NV01"); - - // 6. Write flatbuffer binary to file - auto fb = builder.GetBufferSpan(); - auto file_handle = std::fopen(filename.c_str(), "wb"); - size_t write_status = - std::fwrite(fb.data(), sizeof(uint8_t), fb.size(), file_handle); - NVF_ERROR( - write_status == fb.size(), - "Failed to write entire FusionCache Flatbuffer.\n"); - std::fclose(file_handle); -} - -void FusionCache::deserialize(std::string filename) { - // See table definition for FusionCache in serde/fusion_cache.fbs - // 0. Load flatbuffer binary from file - FUSER_PERF_SCOPE("FusionCache::deserialize"); - NVF_CHECK( - fusions_.empty(), - "Deserialization is prohibited if FusionCache is already populated."); - const BinaryBuffer& buffer = openFusionCache(filename); - const serde::FusionCache* fusion_cache_buffer = - verifyFusionCache(buffer, device_id_); - - // See table definition for FusionCache in serde/fusion_cache.fbs - FUSER_PERF_SCOPE("FusionCache::deserialize"); - NVF_CHECK(fusion_cache_buffer != nullptr, "Fusion Cache buffer is invalid."); - - // 0. Set static fusion count in Fusion Executor - KernelExecutor::setGlobalFusionCount( - fusion_cache_buffer->global_fusion_count()); - - // 1. Deserialize max_fusions field - max_fusions_ = fusion_cache_buffer->max_fusions(); - - // 2. Deserialize fusions: (Fusion) and structure: (TrieNode) fields - int64_t num_fusions = 0; - for (const auto i : - arange(fusion_cache_buffer->auto_gen_schedules()->size())) { - num_fusions = std::max( - num_fusions, - fusion_cache_buffer->auto_gen_schedules()->Get(i)->fusion_id() + 1); - } - std::generate_n(std::back_inserter(fusions_), num_fusions, [] { - return std::make_unique(); - }); - - serde::RecordFunctorFactory record_functor_factory; - - using BfsState = std::pair; - std::deque queue = { - {root_.get() /* TrieNode pointer */, 0 /* structure_idx */}}; - - // state_queue holds the FusionState for each BfsState in the queue. - std::deque> state_queue; - - // Create empty fusion container for root node - state_queue.emplace_back(std::make_unique()); - - // bfs_order is used to map indices in the structure field to their - // corresponding TrieNode pointers. It is used to reconstruct the - // terminal_nodes vector. - std::vector bfs_order; - - // Starting from the root node, we build the Trie structure in breadth-first - // (BFS) order. - while (!queue.empty()) { - auto& [trie_ptr, structure_idx] = queue.front(); - - // Update BFS order - bfs_order.push_back(trie_ptr); - - // Get corresponding flatbuffer object for current TrieNode - auto fb_trie_node = fusion_cache_buffer->structure()->Get(structure_idx); - - // While traversing the Trie Structure, build the Fusion Container by - // adding the TrieNode's RecordFunctor - auto state = state_queue.front().get(); - state->addRecord(trie_ptr->record.get()->clone()); - - // Deserialize Table TrieNode => Field: visits (ulong) - trie_ptr->visits = fb_trie_node->visits(); - - // Build fusion container if current node is a terminal node - if (fb_trie_node->is_terminal()) { - NVF_CHECK( - fb_trie_node->children()->size() == 0, - "This terminal node should not have any children.") - NVF_CHECK( - fb_trie_node->record()->type() == serde::RecordType::End, - "This terminal node should have an EndRecord RecordFunctor") - NVF_CHECK( - trie_ptr->fusion_id == fb_trie_node->fusion_id(), - "The fusion id for this TrieNode should already be set.") - FusionSchedules* fs = queryFusionSchedules(fb_trie_node->fusion_id()); - Fusion* fusion = fs->preschedFusion(); - try { - // There could be bad fusion in the serialization. - state->buildFusionIr(fusion); - } catch (const std::exception& e) { - // catch exception and setException for the terminal node - trie_ptr->setException(e.what()); - } - // The FusionState creates a mapping from CPP Fusion to its State objects. - // Since the CPP Fusion is cached in FusionCache and the FusionState is - // temporary, the information linking CPP Fusion and Python - // FusionDefinition is stored in FusionCache. - fs->inputs_fid_ = state->inputs(); - fs->outputs_fid_ = state->outputs(); - fs->extents_fid_ = state->extents(); - fs->map_value_to_fid_ = state->getValueMap(); - } - - // Table TrieNode => Field: children: [ulong] - // Create Children TrieNode - for (auto child_bfs_idx : *fb_trie_node->children()) { - auto fb_child_trie_node = - fusion_cache_buffer->structure()->Get(child_bfs_idx); - - // Create child RecordFunctor - auto serde_buffer = fb_child_trie_node->record(); - auto rec = - record_functor_factory.parse(serde_buffer->type(), serde_buffer); - - // Deserialize the record and fusion id fields in the TrieNode table - auto status = trie_ptr->children.emplace( - rec, - std::make_unique( - rec, trie_ptr, fb_child_trie_node->fusion_id())); - NVF_CHECK( - status.second, - "Fusion-Cache Deserialization: Failed to add child to the current " - "TrieNode."); - - // Add child TrieNode to BFS queue - queue.emplace_back( - status.first->second.get() /* TrieNode pointer */, child_bfs_idx); - state_queue.emplace_back(state->clone()); - } - - // Destroy current fusion state - queue.pop_front(); - state_queue.pop_front(); - } - - std::atomic detect_exception_in_thread_pool{false}; - // Deserialize terminal_nodes field in the FusionCache table - for (auto idx : arange(fusion_cache_buffer->terminal_nodes()->size())) { - auto node_idx = fusion_cache_buffer->terminal_nodes()->Get(idx); - auto trie_node = bfs_order.at(node_idx); - terminal_nodes_.push_back(trie_node); - - auto fb_fec_node = fusion_cache_buffer->auto_gen_schedules()->Get(idx); - FusionSchedules* fusion_schedule = - queryFusionSchedules(trie_node->fusion_id); - // Create an executor so the following code can deserialize it. - fusion_schedule->createExecutorIfNotExists(); - - if (isOptionEnabled(EnableOption::ParallelSerde)) { - // Parallelize the deserialization of each FusionExecutorCache. - getThreadPool()->run([=, &detect_exception_in_thread_pool]() { - FUSER_PERF_SCOPE("FusionCache::deserializeFusionParallel"); - try { - fusion_schedule->auto_gen_schedules->deserialize( - fb_fec_node, (int64_t)trie_node->fusion_id); - } catch (const std::exception& e) { - // Set flag inside lambda so we can throw an exception after thread - // pool completes its work. - detect_exception_in_thread_pool.store(true); - } - }); - } else { - FUSER_PERF_SCOPE("FusionCache::deserializeFusionSerial"); - fusion_schedule->auto_gen_schedules->deserialize( - fb_fec_node, (int64_t)trie_node->fusion_id); - } - } - - if (isOptionEnabled(EnableOption::ParallelSerde)) { - // Wait until all fusion executor caches are deserialized - getThreadPool()->waitWorkComplete(); - NVF_ERROR( - !detect_exception_in_thread_pool.load(), - "Detected exception while deserializing fusions in parallel.\n", - "Print the exception message by disabling parallel serialization. " - "i.e., Remove `parallel_serde` from NVFUSER_ENABLE environment " - "variable."); - } -} - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/fusion_cache.h b/python/python_frontend/fusion_cache.h deleted file mode 100644 index aa551ae6196..00000000000 --- a/python/python_frontend/fusion_cache.h +++ /dev/null @@ -1,320 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once -#include -#include - -#include -#include -#include -#include -#include - -#include -#include - -namespace nvfuser::python_frontend { - -//! \struct UserSchedule -//! \brief A container to hold a scheduled Fusion IR as well as an executor -//! to contain the corresponding generated kernel. -struct UserSchedule { - UserSchedule(int64_t fusion_id, int64_t device_id); - - //! Runtime information for schedulers - std::unique_ptr runtime_info; - - //! The scheduler heuristic for this UserSchedule - std::unique_ptr scheduler; - - //! The parameters for scheduler heuristic. - std::unique_ptr heuristic_params; - - //! The compile-time data cache. - std::unique_ptr data_cache; - - //! Scheduler hyperparameters for normalization schedulers - std::unique_ptr - scheduler_hyperparams; - - //! Concretized, Scheduled Fusion IR - std::unique_ptr scheduled_fusion; - - //! Generated kernel container - std::unique_ptr executor; - - //! ID of fusion in python frontend fusion cache - int64_t fusion_id_ = -1; - - //! device ID for this user schedule - int64_t device_id_ = -1; - - //! Get scheduler runtime info for UserSchedule - SchedulerRuntimeInfo* runtimeInfo() { - NVF_ERROR( - runtime_info != nullptr, - "Requires SchedulerRuntimeInfo to use heuristic schedulers"); - return runtime_info.get(); - } - - //! Get Fusion for UserSchedule - Fusion* fusion() { - NVF_ERROR( - scheduled_fusion != nullptr, - "Requires Fusion to use heuristic schedulers"); - return scheduled_fusion.get(); - } - - //! Return if we can schedule FusionDefinition with heuristic. - NVF_API bool canSchedule(const SchedulerType& heuristic); - - //! Return if we can schedule FusionDefinition with heuristic along with any - //! debug messages from canScheduleRejectReason. - NVF_API std::tuple canScheduleDebug( - const SchedulerType& scheduler_type); - - //! Create scheduler and get heuristic parameters for fusion. - NVF_API HeuristicParams* computeHeuristics(SchedulerType scheduler_type); - - //! Schedule fusion with selected heuristics and scheduler. - NVF_API void schedule(); - - //! Schedule fusion with heuristic. - NVF_API void scheduleWithType(SchedulerType scheduler_type); -}; - -//! \struct FusionSchedules -//! \brief A container for auto generated and user defined schedules -//! that correspond to compiled kernels for each complete Fusion Definition. -class FusionSchedules { - public: - explicit FusionSchedules(int64_t fusion_id = 0); - - Fusion* preschedFusion(); - - //! Called during execution to create a FusionExecutorCache. It's created - //! during execution instead of by finalizeDefinition because - //! finalizeDefinition may be followed by finalizeMultideviceSchedule which - //! can modify presched_fusion_. The if-not-exists check is necessary because - //! multiple FusionDefinitions may map to the same FusionSchedules. In that - //! case, we want to reuse the same executor. - void createExecutorIfNotExists(); - - //! Schedules Automatically generated by nvFuser for dynamic inputs. (default) - //! NOTE: The FusionExecutorCache also holds the Unscheduled Fusion IR - std::unique_ptr auto_gen_schedules; - //! Schedules defined by the user for specific input sizes. - //! They are also generated per device as all devices may not be the same. - //! Key: Input Encoding hash of Fusion inputs as is created by the - //! InputsIdLookup struct found inside of the FusionCache. - //! Value: A vector based on device_id of User Defined Fusion Schedules. - std::unordered_map> - user_def_schedules; - //! Keeps a pointer to the last scheduled Fusion IR for printing - Fusion* last_user_def_scheduled_ir; - //! Keeps a pointer to the last executed executor for printing its cuda kernel - KernelExecutor* last_user_def_executor; - //! For thread-Safe locking of Fusion Schedules - std::mutex scheds_lock; - //! ID of fusion in python frontend fusion cache - int64_t fusion_id_ = -1; - //! Fusion IDs of input arguments for FusionState - std::vector inputs_fid_; - //! IDs for Extents for TensorView input arguments for FusionState - std::vector extents_fid_; - //! Fusion IDs of output arguments for FusionState - std::vector outputs_fid_; - //! Map Fusion Val to its corresponding FusionDefinition index - std::unordered_map map_value_to_fid_; - //! stores the executor if FusionDefinition::use_multidevice_executor_ is true - std::unique_ptr multi_device_executor; - - private: - //! Holds the presched fusion that will be `std::move`d to a - //! FusionExecutorCache or MultiDeviceExecutor at first execution. - std::unique_ptr presched_fusion_; -}; - -//! \struct TrieNode -//! \brief Is the container for a Node in a prefix tree or trie -//! where each node represents a statement in a fusion definition and -//! the leaf Nodes represent a complete Fusion that is cached. - -struct TrieNode { - TrieNode( - RecordFunctor* rec, - TrieNode* _parent = nullptr, - size_t _fusion_id = 0); - - // Queries whether the entry denotes a leaf node which also represents - // a the end of Fusion entry in the cache. - bool isTerminal() const; - //! getException returns the cached Exception raise during construction of - //! Fusion. It returns std::nullopt if the no error thrown. This function is - //! called at the end of FusionDefinition::finalizeDefinition to avoid - //! silently using a bad FusionDefinition cached in FusionCache. - std::optional getException(); - //! setException is called to record exception message thrown during - //! construction of Fusion. - void setException(const char* e); - //! Serialize TrieNode using flatbuffers - NVF_API flatbuffers::Offset serialize( - flatbuffers::FlatBufferBuilder& builder, - const std::map& - map_record_functor_to_trie_node_id); - - //! An entry's primary data is the record it holds - std::unique_ptr record; - //! A hash map of the children for the current node. - //! The hash map hashes a pointer to a RecordFunctor because - //! the hash function is virtual. - std::unordered_map> children; - //! An index into FusionCache's vector of nvFuser object that holds an - //! unscheduled Fusion. The id is only valid if the entry is terminal. - size_t fusion_id; - //! Count of times the Entry is traversed - size_t visits; - //! Parent node for printing - TrieNode* parent; - //! For thread-Safe locking of a node - std::mutex trie_node_lock; - //! exception is used to track if we failed to create a valid fusion for - //! FusionDefinition at this given TrieNode - std::optional exception = std::nullopt; -}; - -//! \class FusionCache -//! \brief A singleton class used in the nvFuser python interface -//! to manage the caching of fusions. -//! -//! The fusion cache implements a prefix tree (trie) of records in order to -//! cache fusions. A leaf of the tree with a terminal node contains a -//! container for caching the kernels generated for specific fusions. -//! -//! \todo -//! Add the ability to evict a fusion. There is currently a max number -//! of fusions that is checked to prevent a runaway case. -//! -//! \note -//! Thread-Safety is assured by the Python GIL. If a no-GIL python is used -//! then further scrutiny needs to be applied to the mutexes used to limit -//! acccess to the singleton pointer, node creation, and user schedule -//! creation. Otherwise, the Python GIL provides a natural thread based mutex -//! that does not allow for multiple threads to interact. - -class FusionCache { - //! The constructor is private given the FusionCache is only constructed - //! as a singleton. - FusionCache(size_t max_fusions, std::optional selected_device); - - public: - //! Copy and Assignment of the FusionCache is not supported - //! clang-tidy: deleted member function should be public - FusionCache(const FusionCache&) = delete; - FusionCache& operator=(const FusionCache&) = delete; - - //! The next 4 public methods are the python interface methods - - //! Gets a pointer to the singleton and creates a new one if necessary - NVF_API static FusionCache* get( - size_t max_fusions = 16384, - std::optional selected_device = std::nullopt, - bool load_from_default_workspace = true); - //! Number of fusions cached - NVF_API size_t numFusions() const; - //! Get device associated with this FusionCache - NVF_API std::optional deviceId() const; - //! print cache contents - NVF_API void print(std::ostream& os) const; - //! print cache stats - NVF_API void stats(std::ostream& os) const; - //! Reset Cache to an empty state - NVF_API static void reset(); - - //! Serialize Fusion Cache using flatbuffers - NVF_API void serialize(std::string filename) const; - //! Deserialize Fusion Cache using flatbuffers - NVF_API void deserialize(std::string filename); - - //! The rest of the public methods are only used in C++ - - //! Thread-Unsafe: Queries the current trie node to see if a record matches - //! one of its children - NVF_API std::optional queryChildren( - TrieNode* node, - RecordFunctor* rec) const; - //! Query a Fusion's Schedules based on fusion id or cache id - FusionSchedules* queryFusionSchedules(size_t fusion_id) const; - //! Determine if a user schedule exists for given inputs. - bool existUserSchedule( - const FusionSchedules* scheds, - KernelArgumentHolder args, - int device); - //! Lookup the User Schedule Id and return null if one does not exist. - //! NOTE: this method cannot be const because the InputsIdLookup can - //! cause a modification to that data member for cache eviction. - std::optional queryUserScheduleId( - const FusionSchedules* scheds, - const KernelArgumentHolder& args); - //! Lookup the User Schedule based on Id - const UserSchedule& queryUserSchedule( - const FusionSchedules* scheds, - size_t id, - int device) const; - //! Thread-Safe: Creates a child node for the current cache entry and an - //! optional fusion_id is returned if the new entry is terminal - NVF_API TrieNode* createChild(TrieNode* node, RecordFunctor* rec); - //! Lookup the User Schedule based on Id - UserSchedule* createUserSchedule( - FusionSchedules* scheds, - KernelArgumentHolder args, - int device, - bool overwrite_existing_schedule = false); - //! Get the root Trie ptr - NVF_API TrieNode* rootTriePtr(); - - private: - //! The static pointer to the FusionCache - static FusionCache* singleton_; - //! Lock for accessing the singleton by multiple threads - static std::mutex singleton_lock_; - - //! The max allowed number of fusions in the cache - size_t max_fusions_; - //! A separate process is created for each device in a distributed setting. - //! Each FusionCache becomes associated with a device. - std::optional device_id_; - //! The root (start) of the prefix tree to start a cache look up of a given - //! fusion definition. - std::unique_ptr root_; - //! A vector of nvFuser Fusion IR fusions. - std::vector> fusions_; - //! A vector of Terminal trie nodes for Stats collection - std::vector terminal_nodes_; - - //! Items specifically to aid user defined schedules these data members - //! are for the mechanics of user schedule usage and don't make sense as - //! part of an abstraction - - // Inputs for user defined schedules are encoded into an integer Id - // NOTE: I would prefer this be per FusionSchedules object but the container - // is not allowed to be copied or moved. - InputsIdLookup user_def_input_encodings_; -}; - -//! Serialize Fusion Cache to common workspace -//! /tmp/nvfuser_kernel_db/nvf_serde_[cuda_major]_[cuda_minor]_[nvrtc_major]_[nvrtc_minor] -//! -//! '''python -//! # Use atexit to automatically call serialize on program exit -//! import atexit -//! atexit.register(nvfuser.serialize) -//! ''' -NVF_API void serialize(); - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/fusion_definition.cpp b/python/python_frontend/fusion_definition.cpp deleted file mode 100644 index afe3fea460c..00000000000 --- a/python/python_frontend/fusion_definition.cpp +++ /dev/null @@ -1,769 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "base.h" - -// Require namespace for perf scope instrumentation -using namespace nvfuser::inst; - -namespace nvfuser::python_frontend { - -FusionDefinition::FusionDefinition( - std::optional id, - size_t max_length, - bool use_multidevice_executor, - CommunicatorBackend backend_type) - : FusionState(), - max_length_(max_length), - fusion_id_(id), - fusion_cache_(FusionCache::get()), - trie_node_(nullptr), - prev_fusion_(nullptr), - user_sched_(nullptr), - ops(this), - sched(this), - use_multidevice_executor_(use_multidevice_executor), - backend_type_(backend_type) {} - -FusionCache* FusionDefinition::fusionCache() const { - NVF_ERROR(fusion_cache_ != nullptr, "FusionCache pointer is null!"); - return fusion_cache_; -} - -FusionDefinition* FusionDefinition::setupDefinition() { - NVF_CHECK(max_length_ > 0, "Can't make a FusionDefinition with 0 records!"); - NVF_CHECK(!id().has_value(), "Fusion Schedule is already found!"); - trie_node_ = fusionCache()->rootTriePtr(); - return this; -} - -void FusionDefinition::finalizeDefinition() { - FUSER_PERF_SCOPE("FusionDefinition::finalizeDefinition"); - auto child_node = fusionCache()->queryChildren(trie_node_, end_record_.get()); - if (!child_node.has_value()) { - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << "\nFusionDefinition: Terminal Node not found.\n"; - } - trie_node_ = fusionCache()->createChild(trie_node_, end_record_.get()); - fusion_id_ = std::optional(trie_node_->fusion_id); - try { - NVF_CHECK(id().has_value(), "Invalid fusion id!"); - - if (isDebugDumpEnabled(DebugDumpOption::PythonDefinition)) { - print(debug()); - } - - buildFusionIr(preschedFusion()); - verifyTensorDimensions(); - } catch (const std::exception& e) { - // Exception thrown after fusionCache()->createChild wouldn't be visible - // by fusion cache, if the exception is suppressed on the python side. We - // explicitly set the exception message on the terminal trie node, so - // we'll be able to throw the same exception again when user tries to - // create the same fusion entry. - trie_node_->setException(e.what()); - fusion_id_ = std::nullopt; - throw; - } - - // The FusionState creates a mapping from CPP Fusion to its State objects. - // Since the CPP Fusion is cached in FusionCache and the FusionState is - // temporary, the information linking CPP Fusion and Python - // FusionDefinition is stored in FusionCache. - FusionSchedules* fs = - fusionCache()->queryFusionSchedules(fusion_id_.value()); - fs->inputs_fid_ = inputs(); - fs->outputs_fid_ = outputs(); - fs->extents_fid_ = extents(); - fs->map_value_to_fid_ = getValueMap(); - - if (isDebugDumpEnabled(DebugDumpOption::FusionIrOriginal)) { - printIr(); - } - } else { - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << "\nFusionDefinition: Terminal Node found!\n"; - } - trie_node_ = child_node.value(); - std::optional opt_e = trie_node_->getException(); - // rethrow the exception message if the cached FusionDefinition fails to - // build a proper fusion earlier. - NVF_CHECK(!opt_e.has_value(), opt_e.value()); - fusion_id_ = std::optional(trie_node_->fusion_id); - - // A CPP fusion already exists in the FusionCache for this FusionDefinition. - // In this case, a new CPP Fusion is not created, so the mapping from CPP - // fusion to Python FusionDefinition is not initialized. This state is - // stored within FusionSchedules and is retrieved for this FusionDefinition. - FusionSchedules* fs = - fusionCache()->queryFusionSchedules(fusion_id_.value()); - inputs_fid_ = fs->inputs_fid_; - outputs_fid_ = fs->outputs_fid_; - extents_fid_ = fs->extents_fid_; - map_value_to_fid_ = fs->map_value_to_fid_; - } - - NVF_ERROR( - num_recording_states_presched_ == 0, - "Expected number of recording states for prescheduled fusion to be " - "uninitialized."); - num_recording_states_presched_ = (int64_t)recording_state_.size(); -} - -void FusionDefinition::findHiddenTensorViews(Fusion* fusion) { - NVF_ERROR(fusion != nullptr); - - // Filter Tensor states - std::vector tensor_states; - std::copy_if( - recording_state_.begin(), - recording_state_.end(), - std::back_inserter(tensor_states), - [](const State& s) { return s.stype == serde::StateType::Tensor; }); - - // Get corresponding CPP values and add to set for membership check. - std::unordered_set known_tensor_vals; - std::transform( - tensor_states.begin(), - tensor_states.end(), - std::inserter(known_tensor_vals, known_tensor_vals.end()), - [this](State s) { return getFusionState(s.index); }); - - // Get set difference between CPP Fusion and Python FusionDefinition - std::vector all_vals = fusion->usedMathVals(); - std::vector new_fusion_tvs; - std::copy_if( - all_vals.begin(), - all_vals.end(), - std::back_inserter(new_fusion_tvs), - [&](Val* v) { - return v->isA() && known_tensor_vals.count(v) == 0; - }); - - // Short-Circuit: No new TensorViews found - if (new_fusion_tvs.empty()) { - return; - } - - // Add missing TensorViews to FusionDefinition - for (Val* v : new_fusion_tvs) { - addTensor(v->as()); - } -} - -void FusionDefinition::updateSymbolicStates( - const std::unordered_map& symbolic_to_concretized_map) { - for (const State& s : recording_state_) { - // Only update Tensor and Scalar states - if (s.stype != serde::StateType::Tensor && - s.stype != serde::StateType::Scalar) { - continue; - } - - Val* old_value = getFusionState(s.index); - - // Skip replacement if unnecessary - if (symbolic_to_concretized_map.count(old_value) == 0) { - continue; - } - - // Update symbolic states with new concretized values - setFusionState(s.index, symbolic_to_concretized_map.at(old_value)); - } -} - -void FusionDefinition::verifyTensorDimensions() { - NVF_CHECK(id().has_value(), "Invalid fusion id!"); - - std::vector all_tensors = tensors(); - for (const Tensor& t : all_tensors) { - Val* v = getFusionState(t.index); - NVF_ERROR(v->isA(), v->toString()); - const int64_t tv_ndims = v->as()->nDims(); - NVF_ERROR( - tv_ndims == (int64_t)t.dims, - "Expected TensorView to have same number of dimensions as Tensor but " - "got: ", - tv_ndims, - " and ", - t.dims); - } -} - -bool FusionDefinition::existSchedule(const KernelArgumentHolder& args) { - FUSER_PERF_SCOPE("FusionDefinition::existsSchedule"); - NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!"); - FusionSchedules* scheds = fusionCache()->queryFusionSchedules(id().value()); - int8_t device = getCommonDeviceCUDA(args); - NVF_CHECK( - args.empty() || device > -1, "Inputs are not all on the same device!"); - return fusionCache()->existUserSchedule(scheds, args, device); -} - -void FusionDefinition::setupSchedule( - const KernelArgumentHolder& args, - bool overwrite_existing_schedule) { - FUSER_PERF_SCOPE("FusionDefinition::setupSchedule"); - NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!"); - FusionSchedules* scheds = fusionCache()->queryFusionSchedules(id().value()); - int8_t device = getCommonDeviceCUDA(args); - NVF_CHECK( - args.empty() || (device > -1 && device == args.getDeviceIndex()), - "Inputs are not all on the same device!"); - - // NOTE: Clear user schedule state in setupSchedule. - // Scheduling the fusion can add states to recording_state. - // Remove any schedule-only states before applying new schedule. - size_t num_states_to_remove = - recording_state_.size() - num_recording_states_presched_; - for (size_t rnd = 0; rnd < num_states_to_remove; ++rnd) { - recording_state_.pop_back(); - } - - user_sched_ = fusionCache()->createUserSchedule( - scheds, args, device, overwrite_existing_schedule); - - // Create scheduler data cache - user_sched_->data_cache = std::make_unique(); - - // Building a new Fusion container for scheduling with definition such that - // the definition's tensor data members refer to the corresponding IR objects - // needed for scheduling. A simple copy of the container would mean the data - // members that represent tensors would refer to the IR objects in the - // original and not the copy needed for scheduling. - buildFusionIr(user_sched_->scheduled_fusion.get()); - - // Add TensorViews created by composite operations to Python FusionDefinition. - findHiddenTensorViews(user_sched_->scheduled_fusion.get()); - - // Concretize fusion - std::unordered_map symbolic_to_concrete_map = - DynamicTransform::concretizeFusion( - user_sched_->scheduled_fusion.get(), args); - - // Update symbolic values to their new concretized values. - // Users will access concretized values in schedule function. - updateSymbolicStates(symbolic_to_concrete_map); - - // Create runtime info for schedulers - Fusion* user_schedule_fusion = user_sched_->scheduled_fusion.get(); - user_sched_->runtime_info = std::make_unique( - user_schedule_fusion, - args, - /*precomuted_values=*/nullptr, - user_schedule_fusion->allTvs()); - - // Manually setting the fusion guard as there is not a good way of using a - // guard in a local scope across the schedule function - prev_fusion_ = FusionGuard::getCurFusion(); - FusionGuard::setCurFusion(user_sched_->scheduled_fusion.get()); -} - -void FusionDefinition::finalizeSchedule(const KernelArgumentHolder& args) { - FUSER_PERF_SCOPE("FusionDefinition::finalizeSchedule"); - - FusionGuard::setCurFusion(prev_fusion_); - user_sched_->runtime_info.reset(); - prev_fusion_ = nullptr; - - // NOTE: Clear user schedule state in setupSchedule. - // Users can access schedule objects after scheduling the fusion. -} - -void FusionDefinition::setupMultideviceSchedule() { - // FusionDefinition.multidevice_schedule may create new Exprs (e.g. DID - // splits), which will be added to the presched fusion. - prev_fusion_ = FusionGuard::getCurFusion(); - FusionGuard::setCurFusion(preschedFusion()); -} - -void FusionDefinition::finalizeMultideviceSchedule() { - FusionGuard::setCurFusion(prev_fusion_); -} - -void FusionDefinition::print(std::ostream& os) const { - if (id().has_value()) { - os << "\ndef nvfuser_fusion_id" << id().value(); - } else { - os << "\ndef nvfuser_incomplete_fusion"; - } - os << "(fd : FusionDefinition) -> None :\n"; - os << std::dec; - for (auto& rec : recording_) { - // Skip inline defined records - if (!rec.get()->inlineDef()) { - os << " "; - rec->print(os); - os << "\n"; - } - } - os << std::endl; -} - -std::pair> FusionDefinition:: - execute( - KernelArgumentHolder args, - std::optional selected_device, - bool override_user_schedule, - bool capture_debug_output, - bool profile, - std::vector _enable_options, - std::vector _disable_options) const { - debug_output_ = std::nullopt; - std::stringstream debug_ss; - DebugStreamGuard dsg(capture_debug_output ? debug_ss : std::cout); - args.setDeviceIndex(selected_device); - NVF_CHECK(id().has_value(), "Valid fusion schedule is not available!"); - - auto scheds = fusionCache()->queryFusionSchedules(id().value()); - - if (profile) { - ProfilerOptionsGuard::getCurOptions().set(ProfilerOption::Enable); - } - - EnableOptionsGuard enable_opt_guard; - for (const auto& _enable_option : _enable_options) { - std::optional opt = stringToEnableOption(_enable_option); - NVF_CHECK(opt.has_value(), "Unrecognized enable_option: ", _enable_option); - EnableOptionsGuard::getCurOptions().set(opt.value()); - } - - DisableOptionsGuard disable_opt_guard; - for (const auto& _disable_option : _disable_options) { - std::optional opt = stringToDisableOption(_disable_option); - NVF_CHECK( - opt.has_value(), "Unrecognized disable_option: ", _disable_option); - DisableOptionsGuard::getCurOptions().set(opt.value()); - } - - auto find_user_schedule = [&]() -> const UserSchedule* { - if (override_user_schedule) { - return nullptr; - } - - auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, args); - if (!user_sched_id.has_value()) { - return nullptr; - } - - NVF_CHECK( - args.empty() || args.getDeviceIndex() > -1, - "Inputs are not all on the same device or don't match selection!"); - const UserSchedule& user_sched = fusionCache()->queryUserSchedule( - scheds, user_sched_id.value(), args.getDeviceIndex()); - return &user_sched; - }; - const auto* user_sched = find_user_schedule(); - - KernelArgumentHolder outputs; - if (user_sched == nullptr) { - if (use_multidevice_executor_) { - if (scheds->multi_device_executor == nullptr) { - MultiDeviceExecutorParams params; - params.lower.communicator_backend = backend_type_; - scheds->multi_device_executor = std::make_unique( - std::make_unique(*scheds->preschedFusion()), - Communicator::getInstance(), - std::move(params)); - } - outputs = scheds->multi_device_executor->runWithInput(args); - } else { - scheds->createExecutorIfNotExists(); - outputs = scheds->auto_gen_schedules->runFusionWithInputs( - args, std::nullopt, args.getDeviceIndex()); - } - } else { - NVF_ERROR( - !use_multidevice_executor_, - "multidevice_executor is not supported " - "for user-defined schedules."); - if (isProfilerEnabledWithCupti()) { - FusionProfiler::start(); - FusionProfiler::createSegments(1); - } - - scheds->last_user_def_scheduled_ir = user_sched->scheduled_fusion.get(); - scheds->last_user_def_executor = user_sched->executor.get(); - - if (user_sched->heuristic_params == nullptr) { - // Manual schedule - if (!user_sched->executor->isCompiled()) { - user_sched->executor->compile(user_sched->scheduled_fusion.get(), args); - } - outputs = user_sched->executor->run(args); - } else { - // Automatic scheduler was used for UserSchedule. - // Pass launch and compile params to compileFusion and runFusion. - if (!user_sched->executor->isCompiled()) { - user_sched->executor->compile( - user_sched->scheduled_fusion.get(), - args, - user_sched->heuristic_params->lparams, - user_sched->heuristic_params->cparams, - user_sched->heuristic_params->scheduler_type); - } - outputs = user_sched->executor->run( - args, - {}, - user_sched->heuristic_params->lparams, - user_sched->heuristic_params->cparams); - } - - if (isProfilerEnabledWithCupti()) { - FusionProfiler::segment(0).scheduler("user"); - FusionProfiler::stop(); - if (isProfilerPrintingEnabled()) { - debug() << FusionProfiler::profile(); - } - } - } - - if (profile) { - ProfilerOptionsGuard::getCurOptions().unset(ProfilerOption::Enable); - } - - if (capture_debug_output) { - debug_output_ = debug_ss.str(); - } - - std::vector output_shardings; - if (user_sched == nullptr) { - Fusion* fusion = use_multidevice_executor_ - ? scheds->preschedFusion() - : scheds->auto_gen_schedules->getMostRecentKernelRuntime() - ->fusionSegments() - ->completeFusion(); - output_shardings = getOutputShardings(fusion); - NVF_ERROR( - output_shardings.empty() || - std::ssize(output_shardings) == outputs.size(), - "Found ", - std::ssize(output_shardings), - " output shardings but expected ", - outputs.size(), - " or 0."); - } - - return std::make_pair(std::move(outputs), std::move(output_shardings)); -} - -std::string FusionDefinition::fusionIr() { - NVF_CHECK(id().has_value(), "Invalid fusion definition!"); - std::stringstream ss; - preschedFusion()->print(ss, false); - return ss.str(); -} - -UserSchedule* FusionDefinition::userSchedule() { - NVF_CHECK(id().has_value(), "Invalid fusion definition!"); - - if (user_sched_ == nullptr) { - NVF_THROW("User schedule is not defined."); - } - return user_sched_; -} - -std::string FusionDefinition::userScheduleIr() { - NVF_CHECK(id().has_value(), "Invalid fusion definition!"); - - if (user_sched_ == nullptr) { - return "User schedule is not defined."; - } - - std::stringstream ss; - user_sched_->scheduled_fusion->print(ss, false); - return ss.str(); -} - -std::string FusionDefinition::lastCudaCode( - bool intrinsic_code, - bool override_user_schedule) const { - std::string result; - NVF_CHECK(id().has_value(), "Invalid fusion definition!"); - auto scheds = fusionCache()->queryFusionSchedules(id().value()); - auto user_exec = scheds->last_user_def_executor; - - if (!override_user_schedule && (user_exec != nullptr)) { - if (intrinsic_code) { - result = user_exec->compiledKernel()->getStructuredCode(); - } else { - result = user_exec->compiledKernel()->kernelString(); - } - } else { - NVF_CHECK( - scheds->auto_gen_schedules != nullptr, - "Fusion ", - *id(), - " has never been executed via FusionExecutorCache."); - result = scheds->auto_gen_schedules->getMostRecentCode(intrinsic_code); - } - return result; -} - -std::string FusionDefinition::cudaCodeFor( - KernelArgumentHolder args, - bool intrinsic_code, - bool override_user_schedule) const { - NVF_CHECK(id().has_value(), "Invalid fusion definition!"); - auto scheds = fusionCache()->queryFusionSchedules(id().value()); - - if (!override_user_schedule) { - auto device = getCommonDeviceCUDA(args); - NVF_CHECK( - args.empty() || device > -1, "Inputs are not all on the same device!"); - auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, args); - if (user_sched_id.has_value()) { - auto& user_sched = fusionCache()->queryUserSchedule( - scheds, user_sched_id.value(), device); - auto user_exec = user_sched.executor.get(); - if (intrinsic_code) { - return user_exec->compiledKernel()->getStructuredCode(); - } else { - return user_exec->compiledKernel()->kernelString(); - } - } - } - NVF_CHECK( - scheds->auto_gen_schedules != nullptr, - "Fusion ", - *id(), - " has never been executed via FusionExecutorCache."); - return scheds->auto_gen_schedules->getCodeFor(args, intrinsic_code); -} - -std::string FusionDefinition::lastScheduledFusionIr( - bool tensor_transforms, - bool override_user_schedule) const { - std::string result; - NVF_CHECK(id().has_value(), "Invalid fusion definition!"); - auto scheds = fusionCache()->queryFusionSchedules(id().value()); - auto user_sched_ir = scheds->last_user_def_scheduled_ir; - - if (!override_user_schedule && (user_sched_ir != nullptr)) { - std::stringstream ss; - user_sched_ir->print(ss, tensor_transforms); - result = ss.str(); - } else { - NVF_CHECK( - scheds->auto_gen_schedules != nullptr, - "Fusion ", - *id(), - " has never been executed via FusionExecutorCache."); - result = - scheds->auto_gen_schedules->getMostRecentScheduledIr(tensor_transforms); - } - return result; -} - -std::string FusionDefinition::scheduledFusionIrFor( - const KernelArgumentHolder& args, - bool tensor_transforms, - bool override_user_schedule) const { - NVF_CHECK(id().has_value(), "Invalid fusion definition!"); - auto scheds = fusionCache()->queryFusionSchedules(id().value()); - - if (!override_user_schedule) { - auto device = getCommonDeviceCUDA(args); - NVF_CHECK( - args.empty() || (device > -1 && device == args.getDeviceIndex()), - "Inputs are not all on the same device!"); - auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, args); - if (user_sched_id.has_value()) { - auto& user_sched = fusionCache()->queryUserSchedule( - scheds, user_sched_id.value(), device); - auto user_sched_ir = user_sched.scheduled_fusion.get(); - std::stringstream ss; - user_sched_ir->print(ss, tensor_transforms); - return ss.str(); - } - } - NVF_CHECK( - scheds->auto_gen_schedules != nullptr, - "Fusion ", - *id(), - " has never been executed via FusionExecutorCache."); - return scheds->auto_gen_schedules->getScheduledIrFor(args, tensor_transforms); -} - -std::optional FusionDefinition::id() const { - return fusion_id_; -} - -Scalar FusionDefinition::defineScalar() { - FUSER_PERF_SCOPE("FusionDefinition::defineScalar"); - NVF_CHECK( - trie_node_ != nullptr, - "define_scalar() must be called from an initialized definition via a " - "python context manager or a child class' definition() method."); - Scalar out(recording_state_.size(), this); - recording_state_.emplace_back(out(), serde::StateType::Scalar); - return out; -} - -Tensor FusionDefinition::addTensor(TensorView* tv) { - FUSER_PERF_SCOPE("FusionDefinition::addTensor"); - NVF_CHECK( - trie_node_ != nullptr, - "addTensor() must be called from an initialized definition via a python " - "context manager or a child class' definition() method."); - Tensor output = defineTensor(tv->nDims()); - NVF_CHECK( - output.index == numFusionStates(), - "Fusion State index does not match the size!"); - addFusionState(tv); - return output; -} - -Tensor FusionDefinition::defineTensor(size_t dims) { - FUSER_PERF_SCOPE("FusionDefinition::defineTensor"); - NVF_CHECK( - trie_node_ != nullptr, - "define_tensor() must be called from an initialized definition via a " - "python context manager or a child class' definition() method."); - Tensor out(recording_state_.size(), dims, this); - recording_state_.emplace_back(out(), serde::StateType::Tensor); - return out; -} - -Vector FusionDefinition::defineVector(size_t size) { - FUSER_PERF_SCOPE("FusionDefinition::defineVector"); - NVF_CHECK( - trie_node_ != nullptr, - "define_vector() must be called from an initialized definition via a " - "python context manager or a child class' definition() method."); - Vector out(recording_state_.size(), size, this); - recording_state_.emplace_back(out(), serde::StateType::Vector); - return out; -} - -void FusionDefinition::defineRecord(RecordFunctor* record) { - FUSER_PERF_SCOPE("FusionDefinition::defineRecord"); - NVF_CHECK( - trie_node_ != nullptr, - "defineRecord() must be called from an initialized definition via a " - "python context manager or a child class' definition() method."); - NVF_CHECK( - (recording_.size() + 1) <= max_length_, - "The fusion definition has exceeded ", - max_length_, - "operations. The max_length for FusionDefintion's might need to be ", - "increased if the definition is created as expected."); - addRecord(record); - auto child_node = - fusionCache()->queryChildren(trie_node_, recording_.back().get()); - // If the Record is found in the cache, the FusionDefinition and the Cache - // will not share Record given the Record had to be created in order to - // match it but it also already existed in the cache. - if (child_node.has_value()) { - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << "\nFusionDefinition: Record (hash: 0x" << std::hex - << record->hash() << ") hit in Fusion Cache.\n"; - } - trie_node_ = child_node.value(); - // The FusionDefinition and the Cache will share the Record - } else { - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << "\nFusionDefinition: Record (hash: 0x" << std::hex - << record->hash() << ") missed in Fusion Cache.\n"; - } - trie_node_ = - fusionCache()->createChild(trie_node_, recording_.back().get()); - } -} - -Fusion* FusionDefinition::preschedFusion() { - NVF_CHECK( - fusion_id_.has_value(), - "FusionDefinition does not contain a definition, yet!"); - return fusionCache() - ->queryFusionSchedules(fusion_id_.value()) - ->preschedFusion(); -} - -void FusionDefinition::printMathIr() { - return preschedFusion()->printMath(); -} - -State FusionDefinition::recordingState(size_t index) const { - return recording_state_.at(index); -} - -std::vector FusionDefinition::tensors() { - // Filter TensorView states - std::vector tensor_states; - std::copy_if( - recording_state_.begin(), - recording_state_.end(), - std::back_inserter(tensor_states), - [](const State& s) { return s.stype == serde::StateType::Tensor; }); - - // Reconstruct Tensors - std::vector all_tensors; - all_tensors.reserve(tensor_states.size()); - std::transform( - tensor_states.begin(), - tensor_states.end(), - std::back_inserter(all_tensors), - [this](const State& s) { - return Tensor( - s.index, getFusionState(s.index)->as()->nDims(), this); - }); - return all_tensors; -} - -std::vector> FusionDefinition::getValTolerances( - const KernelArgumentHolder& args) { - return nvfuser::getValTolerances(preschedFusion(), args); -} - -void FusionDefinition::validate_with_auto_inferred_outputs( - const KernelArgumentHolder& fusion_outputs, - const KernelArgumentHolder& args) { - return testValidate(preschedFusion(), fusion_outputs, args); -} - -int64_t FusionDefinition::setupSegmentation(const KernelArgumentHolder& args) { - NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!"); - NVF_ERROR( - segmentation_state_ == nullptr, "SegmentationState already exists!"); - segmentation_state_ = std::make_unique(); - return segmentation_state_->setupSegmentation( - preschedFusion(), map_value_to_fid_, args); -} - -std::unordered_map FusionDefinition::buildSegment( - FusionDefinition& segment_fd, - int64_t segment_id) { - NVF_CHECK(id().has_value(), "FusionDefinition does not exist!"); - NVF_CHECK( - segmentation_state_ != nullptr, - "Run setupSegmentation first before trying to build segments!"); - return segmentation_state_->buildSegment(segment_fd, segment_id); -} - -void FusionDefinition::finalizeSegmentation() { - // Destroy SegmentedState - segmentation_state_.reset(); -} - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/fusion_definition.h b/python/python_frontend/fusion_definition.h deleted file mode 100644 index 9008a7b5f1a..00000000000 --- a/python/python_frontend/fusion_definition.h +++ /dev/null @@ -1,389 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace nvfuser::python_frontend { - -class FusionCache; -class FusionDefinition; -class FusionInterface; -class FusionState; -struct RecordFunctor; -class SegmentationState; -struct TrieNode; -struct UserSchedule; - -//! The Tensor and Scalar classes are used to define separate function -//! signatures in the FusionDefinition to identify the appropriate Operator -//! function. -//! -//! Example: -//! -//! add(Tensor* arg1, Tensor* arg2) -> Tensor* -//! add(Tensor* arg1, Val* arg2) -> Tensor* -//! add(Val* arg1, Val* arg2) -> Val* -struct Tensor { - Tensor(size_t _index, size_t _dims, FusionDefinition* _fd) - : index(_index), dims(_dims), fusion_definition(_fd) {} - - size_t operator()() const { - return index; - } - - bool operator==(const Tensor& other) const { - if (index != other.index) { - return false; - } - - if (dims != other.dims) { - return false; - } - - if (fusion_definition != other.fusion_definition) { - return false; - } - return true; - } - - bool operator!=(const Tensor& other) const { - return !(*this == other); - } - - //! A unique index to identifiy each recorded state item. - size_t index; - size_t dims; - - //! Pointer to the FusionDefinition used to create this tensor - //! The FusionDefinition pointer is necessary to enable special - //! dunder operations (ie __add__()) from the python API. - FusionDefinition* fusion_definition; -}; - -struct Scalar { - Scalar(size_t _index, FusionDefinition* _fd) - : index(_index), fusion_definition(_fd) {} - - size_t operator()() const { - return index; - } - - bool operator==(const Scalar& other) const { - if (index != other.index) { - return false; - } - - if (fusion_definition != other.fusion_definition) { - return false; - } - return true; - } - - bool operator!=(const Scalar& other) const { - return !(*this == other); - } - - //! A unique index to identifiy each recorded state item. - size_t index; - - //! Pointer to the FusionDefinition used to create this scalar - //! The FusionDefinition pointer is necessary to enable special - //! dunder operations (ie __add__()) from the python API. - FusionDefinition* fusion_definition; -}; - -struct Vector { - Vector(size_t _index, size_t _size, FusionDefinition* _fd) - : index(_index), size(_size), fusion_definition(_fd) {} - - size_t operator()() const { - return index; - } - - bool operator==(const Vector& other) const { - if (index != other.index) { - return false; - } - - if (size != other.size) { - return false; - } - - if (fusion_definition != other.fusion_definition) { - return false; - } - return true; - } - - bool operator!=(const Vector& other) const { - return !(*this == other); - } - - //! A unique index to identifiy each recorded state item. - size_t index; - //! Elements in the vector - size_t size; - - //! Pointer to the FusionDefinition used to create this scalar - FusionDefinition* fusion_definition; -}; - -//! FusionDefinition defines the C++ side of a Python Context manager to -//! encapsulate the definition of fusion operations. -//! -//! The FusionDefinition records the state definitions and operations prior -//! to exiting the context manager. Upon exit, the operations are queried -//! in a cache and the recorded records are used to build an nvFuser Fusion -//! object if the definition missed in the cache. -//! -//! The nested Operators class was designed to allow the user to query all the -//! available Operators in the FusionDefinition via python help. -//! -//! Example: -//! help(FusionDefinition.Operators) -//! -//! (Experimental) `use_multidevice_executor` toggles using MultiDeviceExecutor -//! directly instead of the main stack -//! -//! (Experimental) `backend_type` selects the communicator backend for -//! MultiDeviceExecutor if `use_multidevice_executor` is true -class NVF_API FusionDefinition : public FusionState { - public: - FusionDefinition( - std::optional id, - size_t max_length = 256, - bool use_multidevice_executor = false, - CommunicatorBackend backend_type = CommunicatorBackend::kNccl); - - // The copy/move/assign constructors/operators are removed - FusionDefinition(const FusionDefinition& fd) = delete; - FusionDefinition(FusionDefinition&& fd) = delete; - FusionDefinition& operator=(const FusionDefinition& fd) = delete; - FusionDefinition& operator=(FusionDefinition&& fd) = delete; - - //! Enter Python Context Manager -- Reset trie for new cache lookup - NVF_API FusionDefinition* setupDefinition(); - //! Exit Python Context Manager -- Triggers Fusion IR build if it is not - //! cached - NVF_API void finalizeDefinition(); - //! Check that a user schedule exists for FusionDefinition and input - //! arguments on device. - NVF_API bool existSchedule(const KernelArgumentHolder& args); - //! Setup user scheduling of a fusion - //! Copies fusion object and sets up FusionGuard - NVF_API void setupSchedule( - const KernelArgumentHolder& args, - bool overwrite_existing_schedule = false); - //! Finalized use scheduling of a fusion - //! resets FusionGuard, lowers IR to a kernel, compiles kernel - NVF_API void finalizeSchedule(const KernelArgumentHolder& args); - //! A hook that gets called right before - //! FusionDefinition.multidevice_schedule. - NVF_API void setupMultideviceSchedule(); - //! A hook that gets called right after FusionDefinition.multidevice_schedule. - NVF_API void finalizeMultideviceSchedule(); - //! Prints a python function representing the definition - NVF_API void print(std::ostream& os) const; - //! Executes a fusion if a valid definition or cache lookup occurred prior. - //! - //! This method returns a KernelArgumentHolder for output tensors and a list - //! of output shardings. If it was a single-GPU execution, output shardings - //! will be empty. - //! - //! Alternatives considered: - //! 1. Return std::vector>. - //! Because DistributedTensor can also represent a non-distributed tensor, I - //! chose the current API for simplicity -- C++ is more verbose than Python - //! when dealing with dynamic types. - //! 2. Return std::variant, - //! std::vector>. Same reason. - //! 3. Store output shardings (i.e. the mesh and the mesh-to-tenseor-axis - //! mapping) to a field of FusionDefinition and retrieve it using another - //! method. This would be similar to getDebugOutput. I didn't choose that - //! because it introduced a new state in the class that could get out of sync. - //! 4. Return a list of `DistributedTensor`s. Each - //! `DistributedTensor` is either the local view of a distributed tensor - //! (when the mesh is non-empty) or a non-distributed tensor - //! (when the mesh is empty). This enforces Python to unpack - //! DistributedTensor, which is confirmed to be slow. - NVF_API std::pair> execute( - KernelArgumentHolder inputs, - std::optional device, - bool override_user_schedule, - bool capture_debug_output, - bool profile, - std::vector _enable_options, - std::vector _disable_options) const; - - //! Return debugging output captured through exeuction with - //! capture_debug_output=true - std::optional getDebugOutput() const { - return debug_output_; - } - // Returns the tolerances values based on reduction sizes. - NVF_API std::vector> getValTolerances( - const KernelArgumentHolder& inputs); - - // Validate the fusion outputs against auto inferred outputs. - NVF_API void validate_with_auto_inferred_outputs( - const KernelArgumentHolder& fusion_outputs, - const KernelArgumentHolder& inputs); - - //! Return the unscheduled Fusion IR - NVF_API std::string fusionIr(); - //! Return the user scheduled FusionIR; - NVF_API std::string userScheduleIr(); - //! Return the Cuda code for the last executed set of inputs - NVF_API std::string lastCudaCode( - bool intrinsic_code, - bool override_user_schedule) const; - //! Return the Cuda code for the given inputs - NVF_API std::string cudaCodeFor( - KernelArgumentHolder inputs, - bool intrinsic_code, - bool override_user_schedule) const; - //! Return the Cuda code for the last executed set of inputs - NVF_API std::string lastScheduledFusionIr( - bool tensor_transforms, - bool override_user_schedule) const; - //! Return the Cuda code for the given inputs - NVF_API std::string scheduledFusionIrFor( - const KernelArgumentHolder& inputs, - bool tensor_transforms, - bool override_user_schedule) const; - //! Return fusion id of defined FusionDefinition - NVF_API std::optional id() const; - //! Prints the Prescheduled Fusion IR representation - void printMathIr(); - - bool completed() { - return id().has_value(); - } - - //! Return a prescheduled Fusion object - Fusion* preschedFusion(); - - //! Return UserSchedule struct if it exists - UserSchedule* userSchedule(); - - //! These methods are used to record the FusionDefinition for cache lookup - - //! Defines a Tensor State Record - NVF_API Tensor addTensor(TensorView* tv); - //! Defines a Scalar State Record - NVF_API Scalar defineScalar(); - //! Defines a Tensor State Record - NVF_API Tensor defineTensor(size_t dims); - //! Defines a Vector State Record - NVF_API Vector defineVector(size_t size); - //! Defines a Record that records the operation required to - //! build the corresponding Fusion IR operation on cache miss. - NVF_API void defineRecord(RecordFunctor* record); - //! Gets a Record State object - NVF_API State recordingState(size_t index) const; - //! Get all Tensors in FusionState. - NVF_API std::vector tensors(); - - //! Run segmentation algorithm on FusionDefinition. Returns the number of - //! segments. - NVF_API int64_t setupSegmentation(const KernelArgumentHolder& inputs); - //! Given an empty FusionDefinition and a segment id, buildSegment creates the - //! CPP Fusion, translates it to the python FusionDefinition, then return a - //! mapping from segment fusion state indices to the original fusion state - //! indices. - NVF_API std::unordered_map buildSegment( - FusionDefinition& segment_fd, - int64_t segment_id); - //! After creating segments, destroy SegmentationState. - NVF_API void finalizeSegmentation(); - - private: - //! Returns the FusionCache Ptr that holds the cache of Fusions - FusionCache* fusionCache() const; - //! Composite operations can create hidden TensorViews in the CPP fusion - //! These TensorViews are not visible from python definition. This function - //! finds and adds them to FusionDefinition - void findHiddenTensorViews(Fusion* fusion); - //! Update Symbolic FusionStates after DynamicTransform pass - void updateSymbolicStates( - const std::unordered_map& symbolic_to_concretized_map); - // Check that the NvFuser TensorView and the Python Tensor dimensions match. - // Apply after buildFusionIr - void verifyTensorDimensions(); - - //! Holds the defined maximum length of a FusionDefinition in order to - //! prevent a run away error. The user should feel free to increase this - //! number as appropriate. - size_t max_length_; - //! Fusion Cache Id for Scheduled Fusion. - std::optional fusion_id_; - //! A pointer to the FusionCache. - FusionCache* fusion_cache_; - //! Current pointer to node in FusionCache. - TrieNode* trie_node_; - - // Book keeping data members for user created schedules - - //! Data member for holding previous fusion container when manually setting - //! the fusion guard. - Fusion* prev_fusion_; - //! Data member for holding the current user schedule object - UserSchedule* user_sched_; - //! Number of recording_states_ before applying user schedule - int64_t num_recording_states_presched_ = 0; - //! Data member that creates SegmentedFusion from cloned, prescheduled Fusion - //! then translates the segments to python FusionDefinitions. - std::unique_ptr segmentation_state_; - - public: - //! The Operators are not directly defined in this header. They are defined - //! in the python bindings through lambda functions so the user only needs to - //! define new operators in one place. - //! Operators define what operations are fused. - struct Operators { - Operators(FusionDefinition* fd) : fusion_definition(fd) {} - bool validUse() const { - return !fusion_definition->completed(); - } - - FusionDefinition* fusion_definition; - }; - - //! The SchedOperators are not directly defined in this header. They are - //! defined in the python bindings through lambda functions so the user only - //! needs to define new operators in one place. - //! SchedOperators allow the user to define how a fusion should be blocked - //! for execution. - struct SchedOperators { - SchedOperators(FusionDefinition* fd) : fusion_definition(fd) {} - bool validUse() const { - return fusion_definition->completed(); - } - - FusionDefinition* fusion_definition; - }; - - Operators ops; - SchedOperators sched; - - private: - mutable std::optional debug_output_ = std::nullopt; - const bool use_multidevice_executor_; - const CommunicatorBackend backend_type_; -}; - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/fusion_record.h b/python/python_frontend/fusion_record.h deleted file mode 100644 index 0a2e84aecb4..00000000000 --- a/python/python_frontend/fusion_record.h +++ /dev/null @@ -1,3675 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "base.h" - -namespace nvfuser::python_frontend { - -//! RecordFunctor is the base class record for operations recorded by -//! the FusionState. It is, in essence, a node in the graph with -//! input edges, args, and output edges where the stored -//! values are indices into the recorded state. -//! -//! The virtual functor operator is executed on a cache miss to build the -//! appropriate part of the nvFuser Fusion IR for a given record. -//! -//! The hash and equality operators are used to facilitate the hashing of -//! RecordFunctors in a hash map given those operators need to be -//! specified for custom objects. -//! -//! The print function is used to print the given Record as a statement -//! in a python formated function. - -struct RecordFunctor { - RecordFunctor( - std::vector _args, - std::vector _outputs, - std::string _name, - serde::RecordType _record_type, - bool _inline_def = false) - : args_(std::move(_args)), - arg_names_(args_.size()), - outputs_(std::move(_outputs)), - name_(std::move(_name)), - record_type_(_record_type), - inline_def_( - _inline_def && - !isOptionDisabled(DisableOption::PythonInlineDefinitions)) { - // Set this Record as the parent of each output - if (inline_def_) { - for (auto& out : outputs_) { - out.setInlineDefRecord(this); - } - } - } - RecordFunctor(const RecordFunctor& other) - : args_(other.args_), - arg_names_(other.arg_names_), - outputs_(other.outputs_), - name_(other.name_), - record_type_(other.record_type_), - inline_def_(other.inline_def_) { - // Set this Record as the parent of each output - if (inline_def_) { - for (auto& out : outputs_) { - out.setInlineDefRecord(this); - } - } - } - virtual ~RecordFunctor() = default; - //! Allows for copying of Child Class objects with RecordFunctor pointers. - virtual RecordFunctor* clone() = 0; - - //! The base class is placing the type, outputs, and args hashed as follows: - //! | 63 - 56 | 55 - 48 | 47 ----------- 32 | 32 ------------------------ 0 | - //! | Type | Outputs | Args | Child Class Specified | - virtual size_t hash() const { - size_t arg_hash = 0; - for (auto arg : args_) { - arg_hash ^= ((arg.index << 1) ^ static_cast(arg.stype)); - } - size_t output_hash = 0; - for (auto output : outputs_) { - output_hash ^= ((output.index << 1) ^ static_cast(output.stype)); - } - // NOTE: The inline_def is not part of the hash as it is not used for - // comparison - return ((static_cast(record_type_) & 0xff) << 56) | - ((output_hash & 0xff) << 48) | ((arg_hash & 0xffff) << 32); - } - - //! The base virtual equality operator is defined so all child - //! classes can utilize the check for the same args and outputs. - virtual bool operator==(const RecordFunctor& other) const { - auto result = (record_type_ == other.record_type_); - result = result && (args_.size() == other.args_.size()) && - (outputs_.size() == other.outputs_.size()); - result = result && (arg_names_ == other.arg_names_); - if (result) { - for (size_t i = 0; i < args_.size(); ++i) { - if ((args_[i].index != other.args_[i].index) || - (args_[i].stype != other.args_[i].stype)) { - result = false; - break; - } - } - } - if (result) { - for (size_t i = 0; i < outputs_.size(); ++i) { - if ((outputs_[i].index != other.outputs_[i].index) || - (outputs_[i].stype != other.outputs_[i].stype)) { - result = false; - break; - } - } - } - // NOTE: The inline_def is not part of the equality operator as it is not - // used for comparison - return result; - } - - //! Abstraction for an operation to build this record's nvFuser Fusion IR - //! piece if the recording has a cache miss. - virtual void operator()(FusionState& fd) = 0; - - //! Abstraction for storing data specific to a record functor. - virtual std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const { - return {serde::RecordData::NONE, flatbuffers::Offset()}; - } - - //! The base serialize function that handles args, outputs, name and - //! recordType. Child recordFunctors should overload the recordData function - //! if has supplementary attributes. - virtual flatbuffers::Offset serialize( - flatbuffers::FlatBufferBuilder& builder) const { - // See table definition for RecordFunctor in serde/fusion_cache.fbs - - std::vector fb_args; - fb_args.reserve(args_.size()); - for (auto& it : args_) { - fb_args.emplace_back(it.index, it.stype); - } - auto args_fb = - builder.CreateVectorOfStructs(fb_args.data(), fb_args.size()); - - std::vector fb_outputs; - fb_outputs.reserve(outputs_.size()); - for (auto& it : outputs_) { - fb_outputs.emplace_back(it.index, it.stype); - } - auto outputs_fb = - builder.CreateVectorOfStructs(fb_outputs.data(), fb_outputs.size()); - - auto&& [record_data_type, record_data] = recordData(builder); - - return serde::CreateRecordFunctor( - builder, - args_fb, - outputs_fb, - builder.CreateString(name_), - recordType(), - record_data_type, - record_data); - } - - //! The base print function when printing Record for a given FusionState - //! in python formated code. - virtual void print(std::ostream& os, bool close_function = true) const { - NVF_ERROR( - !inline_def_, - "The default print function does not handle inline definitions!"); - bool first_output = true; - for (auto& output : outputs_) { - if (first_output) { - first_output = false; - } else { - os << ", "; - } - if (output.stype == serde::StateType::None) { - os << "_"; - } else { - os << output; - } - } - if (always_returns_tuple_) { - os << ","; - } - if (!outputs_.empty()) { - os << " = " << "fd." << name_ << "("; - } else { - os << "fd." << name_ << "("; - } - bool first_arg = true; - size_t idx = 0; - for (auto& arg : args_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - if (!arg_names_[idx].empty()) { - os << arg_names_[idx] << "="; - } - ++idx; - os << arg; - } - if (close_function) { - os << ")"; - } - } - - size_t numOutputs() const { - return outputs_.size(); - } - - const std::vector& outputs() const { - return outputs_; - } - std::vector& args() { - return args_; - } - - serde::RecordType recordType() const { - return record_type_; - } - - bool inlineDef() const { - return inline_def_; - } - - //! Set the name of an argument. If given, it will be listed as a keyword - //! argument during printing using the given name as the key. Unnamed - //! arguments are the default, and are listed as positional arguments before - //! any named arguments. - void setArgName(size_t pos, std::string name) { - arg_names_.at(pos) = name; - } - - protected: - //! Inputs that are indices into the FusionState's Recorded State. - std::vector args_; - //! String name to print for arg in Python, if any. Defaults to empty. - std::vector arg_names_; - //! Outputs that are indices into the FusionState's Recorded State. - std::vector outputs_; - //! Record Name - std::string name_; - //! Record Type of child class used for hashing - //! enum class RecordType is defined in flatbuffer schema - serde::RecordType record_type_; - //! Indicates if a record was defined inline with another record for printing - bool inline_def_; - //! Whether this record type returns a tuple of unknown length. This is only - //! used for TensorSizesRecord. - bool always_returns_tuple_ = false; -}; - -//! The OpRecord RecordFunctor is the most widely used child class because -//! it utilizes varidiac template arguments to represent unary, binary, -//! ternary, and other similar flavors of operations in nvFuser that have -//! a mix of Tensor and Scalar arguments only. -//! -//! The additional data memeber of this child class records the function -//! signature of the nvFuser Arith Operation to be replayed upon a cache -//! miss by the functor operator() call. - -template -struct OpRecord : RecordFunctor { - OpRecord( - std::vector _args, - std::vector _outputs, - std::string _name, - serde::RecordType record_type, - std::function fusion_op) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - _name, - record_type), - fusion_op_(fusion_op) {} - ~OpRecord() override = default; - RecordFunctor* clone() final { - return new OpRecord(*this); - } - - //! Child specific hash function in lower 32 bits.= at::Symbol - //! | 31 ------------------------------------- 0 | - //! | Arith Function Sigs hash code | - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (fusion_op_.target_type().hash_code() & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - // A succesfull cast indicates a RecordFunctor of the same child class - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - if (result) { - // Match the nvFuser arith function types - result = result && - (fusion_op_.target_type() == child_ptr->fusion_op_.target_type()); - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << "\nOpRecord: " << name_ << " Target Type [self: 0x" - << fusion_op_.target_type().name() << "] [other: 0x" - << child_ptr->fusion_op_.target_type().name() << "] "; - } - // Match the nvFuser arith function pointers - // IMPORTANT! you need to dereference the target pointer in order - // to match the function - result = result && - (*fusion_op_.template target() == - *child_ptr->fusion_op_ - .template target()); - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() - << "Target Ptr [self: 0x" << std::hex - << (size_t)*fusion_op_.template target() - << "] [other: 0x" << std::hex - << (size_t)*child_ptr->fusion_op_ - .template target() - << "]\n"; - } - } - } - return result; - } - - //! The variadic set of indices for the number of args for this op are - //! deduced by providing the index_sequence as a parameter. Similarly, - //! the tuple type is also deduced. - //! - //! The tuple type is used to decide whether to cast the input argument - //! to a Fusion IR TensorView or leave it as a Fusion IR Val (Scalar). - //! - //! A deduced binary op could look like: - //! OutType opFunc, 0, 1> - //! A deduced ternary op could look like: - //! OutTupe opFunc, 0, 1, 2> - template - OutType opFunc(FusionState& fd, TupleType& tp, std::index_sequence) { - return fusion_op_( - dynamic_cast::type>( - fd.getFusionState(args_.at(Is).index))...); - } - - void operator()(FusionState& fd) final { - using arg_tuple_t = std::tuple; - auto indices = - std::make_index_sequence::value>(); - // The tuple variable is never populated, it is passed for its type. - arg_tuple_t inputs; - auto output = opFunc(fd, inputs, indices); - fd.setFusionState(outputs_.at(0).index, output); - } - - private: - //! An nvFuser Arith Operation function signature - std::function fusion_op_; -}; - -struct SliceOpRecord : RecordFunctor { - SliceOpRecord( - std::vector _args, - std::vector _outputs, - bool manual_normalization) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.slice", - serde::RecordType::SliceOp), - manual_normalization_(manual_normalization) { - arg_names_[1] = "start_indices"; - arg_names_[2] = "end_indices"; - arg_names_[3] = "strides"; - } - ~SliceOpRecord() override = default; - RecordFunctor* clone() final { - return new SliceOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 | 30 ------------------------ 0 | - //! | manual_normalization? | other | - size_t hash() const final { - auto result = RecordFunctor::hash(); - result |= ((static_cast(manual_normalization_) & 0x1) << 31); - return result; - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = - result && (manual_normalization_ == child_ptr->manual_normalization_); - } - return result; - } - - void operator()(FusionState& fd) final { - auto* arg = fd.getFusionState(args_.at(0).index)->as(); - const std::vector& start = fd.getFusionStateVector(args_.at(1).index); - const std::vector& end = fd.getFusionStateVector(args_.at(2).index); - const std::vector& stride = - fd.getFusionStateVector(args_.at(3).index); - std::vector vec_slice; - for (auto [start_idx, end_idx, stride_idx] : zip(start, end, stride)) { - // NOTE: there's an extra move, we can use emplace_back if we go write - // some constructors for Slice. - NVF_CHECK( - !start_idx->isConstInt() || start_idx->evaluate().as() >= 0, - "Slice operation start_indices must be greater than or equal to 0. " - "Start Indices: ", - start_idx->evaluate().as()); - NVF_CHECK( - !start_idx->isConstInt() || !end_idx->isConstInt() || - end_idx->evaluate().as() >= - start_idx->evaluate().as(), - "Slice operation end_indices must be greater than or equal to " - "start_indices. Start Indices: ", - start_idx->evaluate().as(), - " End Indices: ", - end_idx->evaluate().as()); - NVF_CHECK( - stride_idx->isConstInt() && stride_idx->evaluate().as() == 1, - "nvFuser Limitation: All slice operation strides must be of const " - "size 1."); - vec_slice.push_back({start_idx, end_idx, stride_idx}); - } - auto output = slice(arg, vec_slice, manual_normalization_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", manual_normalization=" << manual_normalization_; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Slice, - serde::CreateSlice(builder, manual_normalization_).Union()}; - } - - private: - //! A flag to skip slice normalization step in composite operation. - bool manual_normalization_; -}; - -struct ReshapeOpRecord : RecordFunctor { - ReshapeOpRecord(std::vector _args, std::vector _outputs) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.reshape", - serde::RecordType::ReshapeOp) { - arg_names_[1] = "new_shape"; - } - ~ReshapeOpRecord() override = default; - RecordFunctor* clone() final { - return new ReshapeOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto* arg = fd.getFusionState(args_.at(0).index)->as(); - const std::vector& new_shape = - fd.getFusionStateVector(args_.at(1).index); - auto output = reshape(arg, new_shape); - fd.setFusionState(outputs_.at(0).index, output); - } -}; - -struct PadOpRecord : RecordFunctor { - PadOpRecord(std::vector _args, std::vector _outputs) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.pad", - serde::RecordType::PadOp) {} - ~PadOpRecord() override = default; - RecordFunctor* clone() final { - return new PadOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - const std::vector& val_widths = - fd.getFusionStateVector(args_.at(1).index); - - TensorView* output = nullptr; - if (args_.at(2).stype == serde::StateType::Scalar) { - output = pad(arg, val_widths, fd.getFusionState(args_.at(2).index)); - } else { // default: None - NVF_ERROR(args_.at(2).stype == serde::StateType::None); - output = pad(arg, val_widths); - } - - fd.setFusionState(outputs_.at(0).index, output); - } -}; - -template -struct DimsOpRecord : RecordFunctor { - DimsOpRecord( - std::vector _args, - std::vector _outputs, - std::vector dims, - std::string name) - : RecordFunctor(std::move(_args), std::move(_outputs), name, op_type) { - int64_t rank = (int64_t)dims.size(); - dims_.reserve(rank); - std::unordered_set dims_set; - for (auto dim : dims) { - dims_set.insert(dim); - if (dim < 0) { - NVF_CHECK( - dim >= -rank, - name + " dims argument is out of range, expects >= -" + - std::to_string(rank) + ", but got: " + std::to_string(dim)); - dim += rank; - } else { - NVF_CHECK( - dim < rank, - name + " dims argument is out of range, expects < " + - std::to_string(rank) + ", but got: " + std::to_string(dim)); - } - dims_.push_back(dim); - } - NVF_CHECK( - dims_set.size() == dims.size(), - name + " got duplicated dimension entries: " + toDelimitedString(dims)); - } - ~DimsOpRecord() override = default; - RecordFunctor* clone() final { - return new DimsOpRecord(*this); - } - - size_t hash() const final { - auto result = RecordFunctor::hash(); - size_t dims_hash = 0; - for (auto dim : dims_) { - hashCombine(dims_hash, static_cast(dim)); - } - return result | (dims_hash & 0xffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - if (result) { - result = (dims_.size() == child_ptr->dims_.size()); - if (result) { - for (size_t i = 0; i < dims_.size(); ++i) { - if (dims_[i] != child_ptr->dims_[i]) { - result = false; - break; - } - } - } - } - } - return result; - } - - void operator()(FusionState& fd) final { - if constexpr (op_type == serde::RecordType::PermuteOp) { - auto arg = - fd.getFusionState(args_.at(0).index)->template as(); - auto output = permute(arg, dims_); - fd.setFusionState(outputs_.at(0).index, output); - } else if constexpr (op_type == serde::RecordType::StrideOrderOp) { - auto arg = - fd.getFusionState(args_.at(0).index)->template as(); - auto output = set(arg); - std::vector allocation_domain = - ir_utils::strideOrderToAllocation(output->getLogicalDomain(), dims_); - output->setAllocationDomain(allocation_domain, true); - fd.setFusionState(outputs_.at(0).index, output); - } else { - NVF_THROW("op_type is not recognized by dims operator."); - } - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - if constexpr (op_type == serde::RecordType::PermuteOp) { - os << ", dims=["; - } else if constexpr (op_type == serde::RecordType::StrideOrderOp) { - os << ", stride_order=["; - } else { - NVF_THROW("op_type is not recognized by dims operator."); - } - bool first_arg = true; - for (auto dim : dims_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << dim; - } - os << "]"; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Dims, - serde::CreateDimsDirect(builder, &dims_).Union()}; - } - - private: - //! Represents the mapping from the original shape to the new shape - std::vector dims_; -}; - -struct SqueezeOpRecord : RecordFunctor { - SqueezeOpRecord( - std::vector _args, - std::vector _outputs, - std::vector dims, - bool squeeze_expanded = false) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.squeeze", - serde::RecordType::SqueezeOp), - dims_(std::move(dims)), - squeeze_expanded_(squeeze_expanded) {} - ~SqueezeOpRecord() override = default; - RecordFunctor* clone() final { - return new SqueezeOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 | 30 -------------------------------- 0 | - //! | squeeze_expanded? | Squeeze Dim hash | - size_t hash() const final { - auto result = RecordFunctor::hash(); - size_t squeeze_dims_hash = 0; - for (auto dim : dims_) { - squeeze_dims_hash ^= static_cast(dim); - } - result = result | (squeeze_dims_hash & 0x7fffffff); - result |= ((static_cast(squeeze_expanded_) & 0x1) << 31); - return result; - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && (dims_ == child_ptr->dims_); - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - // In pytorch, the squeeze operation cannot remove expanded dimensions. - // In nvfuser, for reduction operations, we apply squeeze to remove - // broadcast and expanded iterDomains. The squeeze_expanded_ flag bypasses - // assertion used to match pytorch's behavior. - auto output = squeeze(arg, dims_, squeeze_expanded_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dims=["; - bool first_arg = true; - for (auto dim : dims_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << dim; - } - os << "], squeeze_expanded=" << (squeeze_expanded_ ? "True" : "False"); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Squeeze, - serde::CreateSqueezeDirect(builder, &dims_, squeeze_expanded_).Union()}; - } - - private: - //! Dimension to squeeze. - std::vector dims_; - //! Option to remove expanded dimensions - bool squeeze_expanded_; -}; - -//! Specialized Record Functor for the FusionState's broadcast_in_dim op. -// NOTE: output_ndims gives the rank of the output tensor. This size can be -// found from the State after the definition is read and the Fusion IR is in the -// process of being created. However, pior to that point, the size is needed -// for matching a Fusion Record node in the Trie used to cache definitions. -struct BroadcastInDimOpRecord : RecordFunctor { - BroadcastInDimOpRecord( - std::vector _args, - std::vector _outputs, - size_t output_ndims, - std::vector broadcast_dims) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.broadcast_in_dim", - serde::RecordType::BroadcastInDim), - output_ndims_(output_ndims), - broadcast_dims_(std::move(broadcast_dims)) { - arg_names_[1] = "shape"; - } - ~BroadcastInDimOpRecord() override = default; - RecordFunctor* clone() final { - return new BroadcastInDimOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 ------------------------------------- 0 | - //! | broadcast_dims hash | - size_t hash() const final { - auto result = RecordFunctor::hash(); - size_t broadcast_dims_hash = 0; - for (auto dim : broadcast_dims_) { - broadcast_dims_hash |= 1 << ((output_ndims_ - 1) - dim); - } - return result | (broadcast_dims_hash & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - if (result) { - result = - ((output_ndims_ == child_ptr->output_ndims_) && - (broadcast_dims_.size() == child_ptr->broadcast_dims_.size())); - if (result) { - for (size_t i = 0; i < broadcast_dims_.size(); ++i) { - if (broadcast_dims_[i] != child_ptr->broadcast_dims_[i]) { - result = false; - break; - } - } - } - } - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - const std::vector& output_shape = - fd.getFusionStateVector(args_.at(1).index); - - const auto arg_ndims = std::ranges::distance( - arg->getLoopDomain() | TensorDomain::kNoReductions); - NVF_CHECK( - static_cast(output_ndims_) >= arg_ndims, - "The new shape is expected to be greater-then-or-equal to the input: ", - output_ndims_, - " vs ", - arg_ndims); - NVF_CHECK( - arg_ndims == std::ssize(broadcast_dims_), - "The broadcast dimensions should match the input dimensions: ", - arg_ndims, - " vs ", - broadcast_dims_.size(), - ". arg = ", - arg->toString()); - - std::vector is_broadcast_dim(output_ndims_, true); - for (const auto idx : arange(broadcast_dims_.size())) { - if (idx > 0) { - NVF_CHECK( - broadcast_dims_[idx - 1] < broadcast_dims_[idx], - "Broadcast dimension is not greater than the previous value."); - } - NVF_CHECK( - broadcast_dims_[idx] < static_cast(output_ndims_), - "Invalid broadcast_dims value."); - is_broadcast_dim.at(broadcast_dims_[idx]) = false; - } - - auto output = broadcast(arg, is_broadcast_dim); - auto expanded_output = expand(output, output_shape); - - fd.setFusionState(outputs_.at(0).index, expanded_output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", broadcast_dims=["; - bool first_arg = true; - for (auto dim : broadcast_dims_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << dim; - } - os << "]"; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::BroadcastInDim, - serde::CreateBroadcastInDimDirect( - builder, output_ndims_, &broadcast_dims_) - .Union()}; - }; - - private: - //! Number of dims of shape Vector used to communicate the output tensor shape - size_t output_ndims_; - //! Communicates which dimensions of the output the input tensor maps. - //! For instance, for output [2, 3, 4] and input [3]. This vector would - //! contain [1]. - std::vector broadcast_dims_; -}; - -//! Specialized Record Functor for the FusionState's broadcast op. - -struct BroadcastOpRecord : RecordFunctor { - BroadcastOpRecord( - std::vector _args, - std::vector _outputs, - std::string _name, - std::vector is_broadcast_dim) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - _name, - serde::RecordType::BroadcastOp), - is_broadcast_dim_(std::move(is_broadcast_dim)) {} - ~BroadcastOpRecord() override = default; - RecordFunctor* clone() final { - return new BroadcastOpRecord(*this); - } - - size_t hash() const final { - auto result = RecordFunctor::hash(); - size_t is_broadcast_dim_hash = 0; - for (size_t i = 0; i < is_broadcast_dim_.size(); ++i) { - is_broadcast_dim_hash |= - (is_broadcast_dim_[i] << (is_broadcast_dim_.size() - 1 - i)); - } - return result | (is_broadcast_dim_hash & 0xfff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && - std::equal( - is_broadcast_dim_.begin(), - is_broadcast_dim_.end(), - child_ptr->is_broadcast_dim_.begin()); - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - auto output = broadcast(arg, is_broadcast_dim_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", is_broadcast_dim=["; - bool first_arg = true; - for (auto dim : is_broadcast_dim_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << (dim ? "True" : "False"); - } - os << "]"; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - auto fb_broadcast_dims = builder.CreateVector(is_broadcast_dim_); - - serde::BroadcastBuilder bcast_builder(builder); - bcast_builder.add_broadcast_dims(fb_broadcast_dims); - auto expr_data = bcast_builder.Finish(); - return {serde::RecordData::Broadcast, expr_data.Union()}; - } - - private: - //! Communicates which dimensions in the output are broadcasted. - std::vector is_broadcast_dim_; -}; - -//! Specialized Record Functor for the FusionState's expand op. -struct ExpandOpRecord : RecordFunctor { - ExpandOpRecord(std::vector _args, std::vector _outputs) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.expand", - serde::RecordType::ExpandOp) { - arg_names_[1] = "shape"; - } - ~ExpandOpRecord() override = default; - RecordFunctor* clone() final { - return new ExpandOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 --------------------------------------- 0 | - //! | None | - size_t hash() const final { - return RecordFunctor::hash(); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - const std::vector& output_shape = - fd.getFusionStateVector(args_.at(1).index); - - const auto arg_ndims = std::ranges::distance( - arg->getLoopDomain() | TensorDomain::kNoReductions); - NVF_CHECK( - std::ssize(output_shape) == arg_ndims, - "The new shape is expected to be equal to the input: ", - output_shape.size(), - " vs ", - arg_ndims); - auto expanded_output = expand(arg, output_shape); - - fd.setFusionState(outputs_.at(0).index, expanded_output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - if (close_function) { - os << ")"; - } - } -}; - -template -struct CastOpRecord : RecordFunctor { - CastOpRecord( - std::vector _args, - std::vector _outputs, - std::string _name, - serde::RecordType record_type, - std::function fusion_op, - PrimDataType dtype) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - _name, - record_type), - fusion_op_(fusion_op), - dtype_(dtype) {} - ~CastOpRecord() override = default; - RecordFunctor* clone() final { - return new CastOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 --- 24 | 23 -------------------------- 0 | - //! | Dtype | Arith Function Sig hash code | - size_t hash() const final { - auto result = RecordFunctor::hash(); - result |= ((static_cast(dtype_) & 0xff) << 24); - result |= (fusion_op_.target_type().hash_code() & 0xffffff); - return result; - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - if (result) { - result = result && - (fusion_op_.target_type() == child_ptr->fusion_op_.target_type()); - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << "\nCastOpRecord: " << name_ << " Target Type [self: 0x" - << fusion_op_.target_type().name() << "] [other: 0x" - << child_ptr->fusion_op_.target_type().name() << "]"; - } - // IMPORTANT! you need to dereference the target pointer in order - // to match the function - result = result && - (*fusion_op_.template target() == - *child_ptr->fusion_op_ - .template target()); - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << " Target Ptr [self: 0x" << std::hex - << (size_t)*fusion_op_ - .template target() - << "] [other: 0x" << std::hex - << (size_t)*child_ptr->fusion_op_ - .template target() - << "]\n"; - } - result = result && (dtype_ == child_ptr->dtype_); - } - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = dynamic_cast(fd.getFusionState(args_.at(0).index)); - auto output = fusion_op_(dtype_, arg); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dtype=" << dtypeToPyString(dtype_); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Dtype, - serde::CreateDtype(builder, nvfuser::toUnderlying(dtype_)).Union()}; - } - - private: - //! nvFuser arith function signature - std::function fusion_op_; - //! Type to cast to. - PrimDataType dtype_; -}; - -struct CatOpRecord : RecordFunctor { - CatOpRecord( - std::vector _args, - std::vector _outputs, - int64_t dim, - bool manual_padding) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.cat", - serde::RecordType::CatOp), - dim_(dim), - manual_padding_(manual_padding) {} - ~CatOpRecord() override = default; - RecordFunctor* clone() final { - return new CatOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 | 30 ------------------------ 0 | - //! | manual_padding? | dim | - size_t hash() const final { - auto result = RecordFunctor::hash(); - result |= ((static_cast(manual_padding_) & 0x1) << 31); - result |= (static_cast(dim_) & 0x7fff); - return result; - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && (dim_ == child_ptr->dim_); - result = result && (manual_padding_ == child_ptr->manual_padding_); - } - return result; - } - - void operator()(FusionState& fd) final { - std::vector input_tvs; - input_tvs.reserve(args_.size()); - for (auto& a : args_) { - input_tvs.push_back( - fd.getFusionState(a.index)->template as()); - } - auto output = - cat(input_tvs, dim_, /*iter_type_opt=*/std::nullopt, manual_padding_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - // Similar to RecordFunctor::print(os, false), but don't print args - bool first_output = true; - for (auto& output : outputs_) { - if (first_output) { - first_output = false; - } else { - os << ", "; - } - os << output; - } - if (always_returns_tuple_) { - os << ","; - } - if (!outputs_.empty()) { - os << " = " << "fd." << name_ << "("; - } else { - os << "fd." << name_ << "("; - } - os << "["; - bool first_arg = true; - for (auto& arg : args_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << arg; - } - os << "], dim=" << dim_; - os << ", manual_padding=" << manual_padding_; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Cat, - serde::CreateCat(builder, dim_, manual_padding_).Union()}; - } - - private: - //! The dimension along which we will concatenate - int64_t dim_; - //! A flag to skip the pad operation in the cat composite operation. - bool manual_padding_; -}; - -//! Specialized Record Functor for recording FusionState End. -//! The accompanying Fusion Cache Entry holds a Fusion Object. - -struct EndRecord : RecordFunctor { - EndRecord() : RecordFunctor({}, {}, "end", serde::RecordType::End) {} - ~EndRecord() override = default; - RecordFunctor* clone() final { - return new EndRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 --------------------------------------- 0 | - //! | None | - size_t hash() const final { - return RecordFunctor::hash(); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - } - return result; - } - - void operator()(FusionState& fd) final {} -}; - -//! Specialized Record Functor for recording FusionState input tensors. - -struct TensorRecord : RecordFunctor { - TensorRecord( - std::vector _outputs, - std::vector _shape, - std::vector> _contiguity, - PrimDataType _dtype, - bool _is_cpu = false, - std::vector _stride_order = {}) - : RecordFunctor( - {}, - std::move(_outputs), - "define_tensor", - serde::RecordType::Tensor), - shape_(std::move(_shape)), - contiguity_(std::move(_contiguity)), - stride_order_(std::move(_stride_order)), - dtype_(_dtype), - is_cpu_(_is_cpu) { - normalizeStrideOrder(stride_order_); - } - ~TensorRecord() override = default; - RecordFunctor* clone() final { - return new TensorRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 | 30 --- 24 | 23 --------- 12 | 11 ------------------------ 0 | - //! | CPU? | Dtype | Symbolic Sizes | Contiguous Info & stride_order | - size_t hash() const final { - auto result = RecordFunctor::hash(); - size_t ssize_hash = 0; - for (size_t i = 0; i < shape_.size(); ++i) { - size_t ssize = 0; - if (shape_[i] == -1) { - ssize = 1; - } - ssize_hash |= (ssize << (shape_.size() - 1 - i)); - } - size_t contig_stride_hash = 0; - for (size_t i = 0; i < contiguity_.size(); ++i) { - auto contiguity_value = contiguity_[i]; - contig_stride_hash |= - ((contiguity_value.has_value() && contiguity_value.value()) - << (contiguity_.size() - 1 - i)); - } - for (size_t i = 0; i < stride_order_.size(); ++i) { - contig_stride_hash ^= (stride_order_[i] << i); - } - - result |= ((static_cast(is_cpu_) & 0x1) << 31); - result |= ((static_cast(dtype_) & 0x7f) << 24); - return result | ((ssize_hash & 0xfff) << 12) | (contig_stride_hash & 0xfff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && (dtype_ == child_ptr->dtype_); - result = result && (is_cpu_ == child_ptr->is_cpu_); - if (result) { - result = - ((shape_.size() == child_ptr->shape_.size()) && - (stride_order_.size() == child_ptr->stride_order_.size()) && - (contiguity_.size() == child_ptr->contiguity_.size())); - if (result) { - for (size_t i = 0; i < shape_.size(); ++i) { - if (shape_[i] != child_ptr->shape_[i]) { - result = false; - break; - } - } - } - if (result) { - for (size_t i = 0; i < stride_order_.size(); ++i) { - if (stride_order_[i] != child_ptr->stride_order_[i]) { - result = false; - break; - } - } - } - if (result) { - for (size_t i = 0; i < contiguity_.size(); ++i) { - if (contiguity_[i] != child_ptr->contiguity_[i]) { - result = false; - break; - } - } - } - } - } - return result; - } - - void operator()(FusionState& fd) final { - TensorView* tv = - TensorViewBuilder() - .contiguity(contiguity_) - .shape(shape_) - .dtype(dtype_) - .expanded(getExpanded(shape_, contiguity_, stride_order_)) - .strideOrder(stride_order_) - .build(); - - if (shape_.empty() && is_cpu_) { - tv->setCpuScalar(true); - } else { - NVF_CHECK(!is_cpu_, "CPU non-scalar tensor is not supported!"); - } - - fd.setFusionState(outputs_.at(0).index, tv); - fd.addInput(tv, outputs_.at(0).index); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << "shape=["; - bool first_arg = true; - for (auto ss : shape_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << ss; - } - os << "], contiguity=["; - first_arg = true; - for (auto ci : contiguity_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - if (!ci.has_value()) { - os << "None"; - } else { - if (*ci) { - os << "True"; - } else { - os << "False"; - } - } - } - os << "], dtype=" << dtypeToPyString(dtype_); - os << ", is_cpu=" << (is_cpu_ ? "True" : "False"); - if (!stride_order_.empty()) { - os << ", stride_order=["; - bool first_arg = true; - for (auto item : stride_order_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << item; - } - os << "]"; - } - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - auto fb_sizes = builder.CreateVector(shape_); - - auto mapOptionalToEnum = [](std::optional v) -> serde::Contiguity { - if (!v.has_value()) { - return serde::Contiguity::None; - } else if (v.value()) { - return serde::Contiguity::Contiguous; - } else { - return serde::Contiguity::Strided; - } - }; - std::vector contiguity_enum; - std::transform( - contiguity_.cbegin(), - contiguity_.cend(), - std::back_inserter(contiguity_enum), - mapOptionalToEnum); - auto fb_contiguity_enum = builder.CreateVector(contiguity_enum); - auto fb_stride_order = builder.CreateVector(stride_order_); - - serde::TensorBuilder tensor_builder(builder); - tensor_builder.add_sizes(fb_sizes); - tensor_builder.add_contiguity(fb_contiguity_enum); - tensor_builder.add_stride_order(fb_stride_order); - tensor_builder.add_dtype(toUnderlying(dtype_)); - tensor_builder.add_is_cpu(is_cpu_); - auto expr_data = tensor_builder.Finish(); - return {serde::RecordData::Tensor, expr_data.Union()}; - } - - private: - //! A vector of tensor dimension sizes. - //! This vector only captures sizes of -1 or 1 to indicate a symbolic - //! dimension (-1) or a broadcast dimension (1). - std::vector shape_; - //! A vector to indicate whether the a tensor dimension is contiguous - //! with the dimension just to its right. - std::vector> contiguity_; - //! A vector to indicate stride order of tensor - std::vector stride_order_; - //! Tensor data type. - PrimDataType dtype_; - //! Notes a scalar CPU Tensor - bool is_cpu_; -}; - -//! Specialized Record Functor for recording FusionState outputs. - -template -struct OutputRecord : RecordFunctor { - OutputRecord( - std::vector _args, - serde::RecordType record_type, - std::vector stride_order = {}) - : RecordFunctor(std::move(_args), {}, "add_output", record_type) { - if (!stride_order.empty()) { - stride_order_ = stride_order; - } - } - ~OutputRecord() override = default; - RecordFunctor* clone() final { - return new OutputRecord(*this); - } - - //! Nothing extra necessary in hash - //! Child specific hash function in lower 32 bits. - //! | 31 ---------------------------------------- 0 | - //! | stride_order hash | - size_t hash() const final { - size_t stride_order_hash = 0; - for (auto i : arange(stride_order_.size())) { - stride_order_hash = (stride_order_hash << 4) | stride_order_[i]; - } - return RecordFunctor::hash() | (stride_order_hash & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - if (result) { - result = (stride_order_.size() == child_ptr->stride_order_.size()); - if (result) { - for (size_t i = 0; i < stride_order_.size(); ++i) { - if (stride_order_[i] != child_ptr->stride_order_[i]) { - result = false; - break; - } - } - } - } - } - return result; - } - - void operator()(FusionState& fd) final { - auto output = fd.getFusionState(args_.at(0).index); - Val* alias_input = nullptr; - if (args_.size() == 2) { - alias_input = fd.getFusionState(args_.at(1).index); - } - - if (alias_input) { - NVF_CHECK( - stride_order_.empty(), - "stride_order can't be dictated for aliased outputs."); - if constexpr (std::is_same_v) { - fd.aliasOutputToInput(output, alias_input); - } else { - NVF_THROW("Scalar outputs should not alias inputs."); - } - } else { - if constexpr (std::is_same_v) { - auto tv_output = output->template as(); - if (!stride_order_.empty()) { - auto logical_domain = tv_output->getLogicalDomain(); - std::vector allocation_domain = - ir_utils::strideOrderToAllocation(logical_domain, stride_order_); - tv_output->setAllocationDomain(allocation_domain, true); - } - fd.addOutput(tv_output, args_.at(0).index); - } else { - NVF_CHECK( - stride_order_.empty(), - "stride_order can't be dictated for scalar outputs."); - fd.addOutput(output, args_.at(0).index); - } - } - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - if (!stride_order_.empty()) { - os << ", stride_order=["; - bool first_arg = true; - for (auto item : stride_order_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << item; - } - os << "]"; - } - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Output, - serde::CreateOutputDirect(builder, &stride_order_).Union()}; - } - - private: - //! The tensor dimensions to reduce - std::vector stride_order_; -}; - -//! Specialized Record Functor for the FusionState's sum/min/max ops. - -struct ReductionOpRecord : RecordFunctor { - ReductionOpRecord( - std::vector _args, - std::vector _outputs, - std::string _name, - serde::RecordType record_type, - std::function< - TensorView*(TensorView*, const std::vector&, bool, DataType)> - fusion_op, - std::vector axes, - bool keep_dim, - PrimDataType dtype) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - _name, - record_type), - fusion_op_(std::move(fusion_op)), - axes_(std::move(axes)), - keep_dim_(keep_dim), - dtype_(dtype) {} - ~ReductionOpRecord() override = default; - RecordFunctor* clone() final { - return new ReductionOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 -- 28 | 27 --- 20 | 19 ----------------- 0 | - //! | keep_dim | Dtype | Axes Hash | - size_t hash() const final { - auto result = RecordFunctor::hash(); - size_t axes_hash = 0; - // Normally I would make a little endian hash of the axes but I do not - // know the size of the tensor based on just the record information. - for (auto i : arange(axes_.size())) { - axes_hash |= (1 << axes_[i]); - } - - return result | (static_cast(keep_dim_) << 28) | - ((static_cast(dtype_) & 0xff) << 20) | (axes_hash & 0xfffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - if (result) { - result = result && - (fusion_op_.target_type() == child_ptr->fusion_op_.target_type()); - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << "\nReductionOpRecord: " << name_ - << " Target Type [self: 0x" << fusion_op_.target_type().name() - << "] [other: 0x" - << child_ptr->fusion_op_.target_type().name() << "]"; - } - // IMPORTANT! you need to dereference the target pointer in order - // to match the function - result = result && - (*fusion_op_.template target< - - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>() == - *child_ptr->fusion_op_.template target< - - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>()); - if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << " Target Ptr [self: 0x" << std::hex - << (size_t)*fusion_op_.template target< - - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>() - << "] [other: 0x" << std::hex - << (size_t)*child_ptr->fusion_op_.template target< - - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>() - << "]\n"; - } - result = result && (keep_dim_ == child_ptr->keep_dim_); - result = result && (dtype_ == child_ptr->dtype_); - if (result) { - result = (axes_.size() == child_ptr->axes_.size()); - if (result) { - for (size_t i = 0; i < axes_.size(); ++i) { - if (axes_[i] != child_ptr->axes_[i]) { - result = false; - break; - } - } - } - } - } - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - auto output = fusion_op_(arg, axes_, keep_dim_, dtype_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dims=["; - bool first_arg = true; - for (auto axis : axes_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << axis; - } - os << "]"; - os << ", keepdim=" << (keep_dim_ ? "True" : "False"); - os << ", dtype=" << dtypeToPyString(dtype_); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - // TODO add dtype - return { - serde::RecordData::Reduction, - serde::CreateReductionDirect( - builder, &axes_, keep_dim_, toUnderlying(dtype_)) - .Union()}; - } - - private: - //! nvFuser arith function signature for a given reduction operation - std::function< - TensorView*(TensorView*, const std::vector&, bool, DataType)> - fusion_op_; - //! The tensor dimensions to reduce - std::vector axes_; - //! Indicates whether to keep the reduced dimension(s). - bool keep_dim_; - //! The output data type. - PrimDataType dtype_; -}; - -struct ScanOpRecord : RecordFunctor { - ScanOpRecord( - std::vector _args, - std::vector _outputs, - std::string _name, - serde::RecordType record_type, - std::function fusion_op, - int64_t dim, - BinaryOpType op_type) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - _name, - record_type), - fusion_op_(std::move(fusion_op)), - dim_(dim), - op_type_(op_type) {} - ~ScanOpRecord() override = default; - RecordFunctor* clone() final { - return new ScanOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 7 --- 4 | 3 --- 0 | - //! | op_type | dim | - size_t hash() const final { - auto result = RecordFunctor::hash(); - result |= ((static_cast(op_type_) & 0xf) << 4); - return result | (static_cast(dim_) & 0xf); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - if (result) { - result = result && - (*fusion_op_ - .template target() == - *child_ptr->fusion_op_ - .template target()); - result = result && (dim_ == child_ptr->dim_); - result = result && (op_type_ == child_ptr->op_type_); - } - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - auto output = fusion_op_(arg, dim_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dim=" << dim_; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::ScanOp, serde::CreateScanOp(builder, dim_).Union()}; - } - - private: - std::function fusion_op_; - int64_t dim_; - BinaryOpType op_type_; -}; - -struct IndexSelectOpRecord : RecordFunctor { - IndexSelectOpRecord( - std::vector _args, - std::vector _outputs, - int64_t dim) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.index_select", - serde::RecordType::IndexSelectOp), - dim_(dim) {} - ~IndexSelectOpRecord() override = default; - RecordFunctor* clone() final { - return new IndexSelectOpRecord(*this); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && dim_ == child_ptr->dim_; - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg1 = fd.getFusionState(args_.at(0).index)->template as(); - auto arg3 = fd.getFusionState(args_.at(1).index)->template as(); - - Val* output = indexSelect(arg1, dim_, arg3); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dim=" << dim_; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Dimension, - serde::CreateDimension(builder, dim_).Union()}; - } - - private: - //! Dimension to select. - int64_t dim_; -}; - -// TODO Merge IndexSelectOpRecord and SelectOpRecord for cleaner interface. -// If the index TensorView is a scalar, then use select operation. -struct SelectOpRecord : RecordFunctor { - SelectOpRecord( - std::vector _args, - std::vector _outputs, - int64_t dim) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.select", - serde::RecordType::SelectOp), - dim_(dim) {} - ~SelectOpRecord() override = default; - RecordFunctor* clone() final { - return new SelectOpRecord(*this); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && dim_ == child_ptr->dim_; - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg1 = fd.getFusionState(args_.at(0).index)->template as(); - auto arg3 = fd.getFusionState(args_.at(1).index); - - Val* output = select(arg1, dim_, arg3); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dim=" << dim_; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Dimension, - serde::CreateDimension(builder, dim_).Union()}; - } - - private: - //! Dimension to select. - int64_t dim_; -}; - -struct ScatterOpRecord : RecordFunctor { - ScatterOpRecord( - std::vector _args, - std::vector _outputs, - int64_t dim) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.scatter", - serde::RecordType::ScatterOp), - dim_(dim) {} - ~ScatterOpRecord() override = default; - RecordFunctor* clone() final { - return new ScatterOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto arg1 = fd.getFusionState(args_.at(0).index)->template as(); - auto arg3 = fd.getFusionState(args_.at(1).index)->template as(); - auto arg4 = fd.getFusionState(args_.at(2).index)->template as(); - - Val* output = scatter(arg1, dim_, arg3, arg4); - fd.setFusionState(outputs_.at(0).index, output); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && dim_ == child_ptr->dim_; - } - return result; - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dim=" << dim_; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Dimension, - serde::CreateDimension(builder, dim_).Union()}; - } - - private: - //! Dimension to select. - int64_t dim_; -}; - -struct GatherOpRecord : RecordFunctor { - GatherOpRecord( - std::vector _args, - std::vector _outputs, - int64_t dim) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.gather", - serde::RecordType::GatherOp), - dim_(dim) {} - ~GatherOpRecord() override = default; - RecordFunctor* clone() final { - return new GatherOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto arg1 = fd.getFusionState(args_.at(0).index)->template as(); - auto arg3 = fd.getFusionState(args_.at(1).index)->template as(); - - Val* output = gather(arg1, dim_, arg3); - fd.setFusionState(outputs_.at(0).index, output); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && dim_ == child_ptr->dim_; - } - return result; - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dim=" << dim_; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Dimension, - serde::CreateDimension(builder, dim_).Union()}; - } - - private: - //! Dimension to select. - int64_t dim_; -}; - -//! Similar to GatherOpRecord but enforces that non-index dimension -//! extents match between index tensor and value tensor. -struct TakeAlongAxisOpRecord : RecordFunctor { - TakeAlongAxisOpRecord( - std::vector _args, - std::vector _outputs, - int64_t dim) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.take_along_axis", - serde::RecordType::TakeAlongAxisOp), - dim_(dim) {} - ~TakeAlongAxisOpRecord() override = default; - RecordFunctor* clone() final { - return new TakeAlongAxisOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto arg1 = fd.getFusionState(args_.at(0).index)->template as(); - auto arg3 = fd.getFusionState(args_.at(1).index)->template as(); - - Val* output = takeAlongAxis(arg1, arg3, dim_); - fd.setFusionState(outputs_.at(0).index, output); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && dim_ == child_ptr->dim_; - } - return result; - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dim=" << dim_; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Dimension, - serde::CreateDimension(builder, dim_).Union()}; - } - - private: - //! Dimension to select. - int64_t dim_; -}; - -//! Specialized Record Functor for recording FusionState scalars for both -//! inputs and constants. - -struct ScalarRecord : RecordFunctor { - ScalarRecord( - std::vector _outputs, - PolymorphicValue value, - std::optional dtype, - bool inline_def = false) - : RecordFunctor( - {}, - std::move(_outputs), - "define_scalar", - serde::RecordType::Scalar, - inline_def), - value_( - dtype.has_value() ? castToDtype(std::move(value), dtype.value()) - : std::move(value)), - dtype_( - dtype.has_value() - ? dtype.value() - : std::get(getDataType(value_).type)) {} - ~ScalarRecord() override = default; - RecordFunctor* clone() final { - return new ScalarRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 --------------------------------------- 0 | - //! | Dtype | - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (static_cast(dtype_) & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - if (auto child_ptr = dynamic_cast(&other)) { - if (RecordFunctor::operator==(other)) { - if (value_.hasValue() != child_ptr->value_.hasValue() || - dtype_ != child_ptr->dtype_) { - return false; - } - if (value_.hasValue()) { - if (value_.is() && std::isnan(value_.as()) && - std::isnan(child_ptr->value_.as())) { - return true; - } else { - return value_ == child_ptr->value_; - } - } else { - return true; - } - } - } - return false; - } - - void operator()(FusionState& fd) final { - Val* output = IrBuilder::create(value_, dtype_); - if (!value_.hasValue()) { - fd.addInput(output, outputs_.at(0).index); - } - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - if (inline_def_) { - NVF_CHECK( - value_.hasValue(), - "Only ScalarRecords with values support inline definitions!"); - if (value_.is()) { - NVF_CHECK( - dtype_ == PrimDataType::Bool, - "A ScalarRecord for Bool inline definition not have a matching " - "data type!"); - os << ((bool)value_ ? "True" : "False"); - } else if (value_.is()) { - NVF_CHECK( - dtype_ == PrimDataType::Double, - "A ScalarRecord for Double inline definition not have a matching " - "data type!"); - if (std::isinf(value_.as())) { - if (std::signbit(value_.as())) { - os << "float(\"-inf\")"; - } else { - os << "float(\"inf\")"; - } - } else if (std::isnan(value_.as())) { - os << "float(\"nan\")"; - } else { - os << std::showpoint << value_.as(); - } - } else if (value_.is()) { - NVF_CHECK( - dtype_ == PrimDataType::Int, - "A ScalarRecord for Int inline definition not have a matching data " - "type!"); - os << value_; - } else { - NVF_THROW("A ScalarRecord with an unsupported inline definition type!"); - } - // NOTE: close_function is not relevant for the inline definition as the - // printing is specific to each operator and not partially done with the - // base class print method. - } else { - RecordFunctor::print(os, false); - if (value_.hasValue()) { - if (value_.is()) { - os << ((bool)value_ ? "True" : "False"); - } else if (value_.is>()) { - os << std::showpoint << std::real(value_.as>()) - << "+" << std::showpoint - << std::imag(value_.as>()) << "j"; - } else if (value_.is()) { - if (std::isinf(value_.as())) { - if (std::signbit(value_.as())) { - os << "float(\"-inf\")"; - } else { - os << "float(\"inf\")"; - } - } else if (std::isnan(value_.as())) { - os << "float(\"nan\")"; - } else { - os << std::showpoint << value_.as(); - } - } else if (value_.is()) { - os << value_; - } else { - NVF_CHECK(false, "Unsupported dtype."); - } - } else { - os << "None"; - } - - os << ", dtype=" << dtypeToPyString(dtype_); - - if (close_function) { - os << ")"; - } - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Scalar, - serde::serializeScalar(builder, value_, dtype_).Union()}; - } - - inline std::pair> valueRecordData( - flatbuffers::FlatBufferBuilder& builder, - PolymorphicValue value) const; - - private: - //! The scalar's value, an input is a nullopt - PolymorphicValue value_; - //! Scalar data type. - PrimDataType dtype_; -}; - -//! Specialized Record Functor for recording FusionDefinition Start. -//! There should only ever be one instance of this Record in the -//! Fusion Cache. - -struct StartRecord : RecordFunctor { - StartRecord() : RecordFunctor({}, {}, "start", serde::RecordType::Start) {} - ~StartRecord() override = default; - RecordFunctor* clone() final { - return new StartRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 --------------------------------------- 0 | - //! | None | - size_t hash() const final { - return RecordFunctor::hash(); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - } - return result; - } - - void operator()(FusionState& fd) final {} -}; - -//! Specialized Record Functors for Normalization based ops. - -struct NormOpRecord : RecordFunctor { - NormOpRecord( - std::vector args, - std::vector outputs, - std::string name, - serde::RecordType type, - std::vector axes, - int64_t correction, - bool keep_dim) - : RecordFunctor(std::move(args), std::move(outputs), name, type), - axes_(std::move(axes)), - correction_(correction), - keep_dim_(keep_dim) {} - ~NormOpRecord() override = default; - RecordFunctor* clone() override = 0; - - // I am skipping the bassel's correction value in the hash because - // I suspect we might change it to a bool from a 64-bit value - //! Child specific hash function in lower 32 bits. - //! | 31 -- 28 | 27 ----------------------------- 0 | - //! | keep_dim | Axes Hash | - size_t hash() const final { - auto result = RecordFunctor::hash(); - size_t axes_hash = 0; - // Normally I would make a little endian hash of the axes but I do not - // know the size of the tensor based on just the record information. - for (auto i : arange(axes_.size())) { - axes_hash |= (1 << axes_[i]); - } - return result | (static_cast(keep_dim_) << 28) | - (axes_hash & 0xfffffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && (correction_ == child_ptr->correction_); - result = result && (keep_dim_ == child_ptr->keep_dim_); - if (result) { - result = (axes_.size() == child_ptr->axes_.size()); - if (result) { - for (size_t i = 0; i < axes_.size(); ++i) { - if (axes_[i] != child_ptr->axes_[i]) { - result = false; - break; - } - } - } - } - } - return result; - } - - //! Each NormOp Child should define the operator() to build the IR - void operator()(FusionState& fd) override = 0; - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dims=["; - bool first_arg = true; - for (auto axis : axes_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << axis; - } - os << "]"; - os << ", correction=" << correction_; - os << ", keepdim=" << (keep_dim_ ? "True" : "False"); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Norm, - serde::CreateNormDirect(builder, &axes_, correction_, keep_dim_) - .Union()}; - } - - protected: - //! Dimensions of tensor to reduce for variance calculation - std::vector axes_; - //! Bessel's correction value - int64_t correction_; - //! Indicates whether to keep the reduced dimension(s). - bool keep_dim_; -}; - -struct VarianceOpRecord : NormOpRecord { - VarianceOpRecord( - std::vector args, - std::vector outputs, - std::vector axes, - int64_t correction, - bool keep_dim) - : NormOpRecord( - std::move(args), - std::move(outputs), - "ops.var", - serde::RecordType::VarianceOp, - std::move(axes), - correction, - keep_dim) {} - ~VarianceOpRecord() override = default; - RecordFunctor* clone() final { - return new VarianceOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->as(); - auto output = variance(arg, axes_, correction_, keep_dim_); - fd.setFusionState(outputs_.at(0).index, output); - } -}; - -//! VarianceMean requires a separate Record because nvFuser defines the output -//! of var_mean as a custom struct. -struct VarianceMeanOpRecord : NormOpRecord { - VarianceMeanOpRecord( - std::vector args, - std::vector outputs, - std::vector axes, - int64_t correction, - bool keep_dim) - : NormOpRecord( - std::move(args), - std::move(outputs), - "ops.var_mean", - serde::RecordType::VarianceMeanOp, - std::move(axes), - correction, - keep_dim) {} - ~VarianceMeanOpRecord() override = default; - RecordFunctor* clone() final { - return new VarianceMeanOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->as(); - auto output = variance_mean(arg, axes_, correction_, keep_dim_); - fd.setFusionState(outputs_.at(0).index, output.var); - fd.setFusionState(outputs_.at(1).index, output.mean); - } -}; - -struct WelfordOpRecord : RecordFunctor { - WelfordOpRecord( - std::vector _args, - std::vector _outputs, - std::vector axes) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.welford", - serde::RecordType::WelfordOp), - axes_(std::move(axes)) {} - ~WelfordOpRecord() override = default; - RecordFunctor* clone() final { - return new WelfordOpRecord(*this); - } - - size_t hash() const final { - auto result = RecordFunctor::hash(); - size_t axes_hash = 0; - for (auto axis : axes_) { - hashCombine(axes_hash, static_cast(axis)); - } - return result | (axes_hash & 0xffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - if (result) { - result = (axes_.size() == child_ptr->axes_.size()); - if (result) { - for (size_t i = 0; i < axes_.size(); ++i) { - if (axes_[i] != child_ptr->axes_[i]) { - result = false; - break; - } - } - } - } - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - auto output = WelfordRaw(arg, axes_); - fd.setFusionState(outputs_.at(0).index, output.avg); - fd.setFusionState(outputs_.at(1).index, output.var_sum); - fd.setFusionState(outputs_.at(2).index, output.n); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dims=["; - bool first_arg = true; - for (auto axis : axes_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << axis; - } - os << "]"; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Welford, - serde::CreateWelfordDirect(builder, &axes_).Union()}; - } - - private: - //! The tensor dimensions to reduce - std::vector axes_; -}; - -struct BatchNormOpRecord : RecordFunctor { - BatchNormOpRecord( - std::vector args, - std::vector outputs, - bool training, - bool channels_last) - : RecordFunctor( - std::move(args), - std::move(outputs), - "ops.batch_norm", - serde::RecordType::BatchNormOp), - training_(training), - channels_last_(channels_last) {} - ~BatchNormOpRecord() override = default; - RecordFunctor* clone() final { - return new BatchNormOpRecord(*this); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && (training_ == child_ptr->training_); - result = result && (channels_last_ == child_ptr->channels_last_); - } - return result; - } - - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (static_cast(training_) << 28) | - (static_cast(channels_last_) << 29); - } - - void operator()(FusionState& fd) final { - auto x = fd.getFusionState(args_.at(0).index)->as(); - auto weight = (args_.at(1).stype == serde::StateType::Tensor) - ? fd.getFusionState(args_.at(1).index)->as() - : nullptr; - auto bias = (args_.at(2).stype == serde::StateType::Tensor) - ? fd.getFusionState(args_.at(2).index)->as() - : nullptr; - auto running_mean = (args_.at(3).stype == serde::StateType::Tensor) - ? fd.getFusionState(args_.at(3).index)->as() - : nullptr; - auto running_var = (args_.at(4).stype == serde::StateType::Tensor) - ? fd.getFusionState(args_.at(4).index)->as() - : nullptr; - auto momentum = fd.getFusionState(args_.at(5).index)->as(); - auto eps = fd.getFusionState(args_.at(6).index)->as(); - auto output = batch_norm( - x, - weight, - bias, - running_mean, - running_var, - training_, - momentum, - eps, - channels_last_); - fd.setFusionState(outputs_.at(0).index, output.output); - fd.setFusionState(outputs_.at(1).index, output.mean); - fd.setFusionState(outputs_.at(2).index, output.invstd); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", training=" << (training_ ? "True" : "False"); - os << ", channels_last=" << (channels_last_ ? "True" : "False"); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::BatchNorm, - serde::CreateBatchNorm(builder, training_, channels_last_).Union()}; - } - - private: - bool training_; - bool channels_last_; -}; - -//! Specialized Record Functor for the FusionState's tensor_size op. -//! Uses the default hash() and print() methods of Record Functor - -struct TensorSizesRecord : RecordFunctor { - TensorSizesRecord(std::vector args, std::vector outputs) - : RecordFunctor( - std::move(args), - std::move(outputs), - "ops.tensor_sizes", - serde::RecordType::TensorSizes) { - always_returns_tuple_ = true; - } - ~TensorSizesRecord() override = default; - RecordFunctor* clone() final { - return new TensorSizesRecord(*this); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->as(); - auto sizes = shape(arg); - for (const auto idx : arange(sizes.size())) { - fd.setFusionState(outputs_.at(idx).index, sizes[idx]); - } - } -}; - -//! Specialized Record Functor for the shape op. -//! Uses the default hash() and print() methods of Record Functor - -struct ShapeOpRecord : RecordFunctor { - ShapeOpRecord(std::vector args, std::vector outputs) - : RecordFunctor( - std::move(args), - std::move(outputs), - "ops.shape", - serde::RecordType::ShapeOp) {} - ~ShapeOpRecord() override = default; - RecordFunctor* clone() final { - return new ShapeOpRecord(*this); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->as(); - auto result = shape(arg); - fd.setFusionStateVector(outputs_.at(0).index, result); - } -}; - -//! Specialized Record Functor for the size op. -//! Uses the default hash() and print() methods of Record Functor - -struct SizeOpRecord : RecordFunctor { - SizeOpRecord(std::vector args, std::vector outputs, int64_t dim) - : RecordFunctor( - std::move(args), - std::move(outputs), - "ops.size", - serde::RecordType::SizeOp), - dim_(dim) {} - ~SizeOpRecord() override = default; - RecordFunctor* clone() final { - return new SizeOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 -------------------------------------- 0 | - //! | dim | - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (static_cast(dim_) & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && (dim_ == child_ptr->dim_); - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->as(); - auto result = size(arg, dim_); - fd.setFusionState(outputs_.at(0).index, result); - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return {serde::RecordData::Size, serde::CreateSize(builder, dim_).Union()}; - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dim=" << dim_; - if (close_function) { - os << ")"; - } - } - - private: - int64_t dim_; -}; - -//! Specialized Record Functor for the at() op. -//! Uses the default hash() and print() methods of Record Functor - -struct AtOpRecord : RecordFunctor { - AtOpRecord(std::vector args, std::vector outputs, int64_t index) - : RecordFunctor( - std::move(args), - std::move(outputs), - "ops.at", - serde::RecordType::AtOp), - index_(index) {} - ~AtOpRecord() override = default; - RecordFunctor* clone() final { - return new AtOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 -------------------------------------- 0 | - //! | index | - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (static_cast(index_) & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && (index_ == child_ptr->index_); - } - return result; - } - - void operator()(FusionState& fd) final { - NVF_CHECK( - args_.at(0).stype == serde::StateType::Vector, - "Expected Vector State!"); - const std::vector& arg = fd.getFusionStateVector(args_.at(0).index); - auto result = at(arg, index_); - fd.setFusionState(outputs_.at(0).index, result); - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return {serde::RecordData::At, serde::CreateAt(builder, index_).Union()}; - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", index=" << index_; - if (close_function) { - os << ")"; - } - } - - private: - int64_t index_; -}; - -struct FullOpRecord : RecordFunctor { - FullOpRecord( - std::vector _args, - std::vector _outputs, - PrimDataType dtype) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.full", - serde::RecordType::FullOp), - dtype_(dtype) { - setArgName(0, "shape"); - setArgName(1, "fill_value"); - } - ~FullOpRecord() override = default; - RecordFunctor* clone() final { - return new FullOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 -------------------------------------- 0 | - //! | Dtype | - size_t hash() const final { - auto result = RecordFunctor::hash(); - result |= (static_cast(dtype_) & 0xffffffff); - return result; - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && dtype_ == child_ptr->dtype_; - } - return result; - } - - void operator()(FusionState& fd) final { - const std::vector& shape = fd.getFusionStateVector(args_.at(0).index); - auto fill_value = fd.getFusionState(args_.at(1).index); - - auto output = full(shape, fill_value, dtype_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const override { - RecordFunctor::print(os, false); - os << ", dtype=" << dtypeToPyString(dtype_); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::TensorCreationSymbolic, - serde::CreateTensorCreationSymbolic(builder, toUnderlying(dtype_)) - .Union()}; - } - - private: - //! Type of output - PrimDataType dtype_; -}; - -struct IotaOpRecord : RecordFunctor { - IotaOpRecord( - std::vector _args, - std::vector _outputs, - PrimDataType dtype) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.iota", - serde::RecordType::IotaOp), - dtype_(dtype) {} - ~IotaOpRecord() override = default; - RecordFunctor* clone() final { - return new IotaOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 -------------------------------------- 0 | - //! | Dtype | - size_t hash() const final { - return RecordFunctor::hash() | static_cast(dtype_); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && dtype_ == child_ptr->dtype_; - } - return result; - } - - void operator()(FusionState& fd) final { - auto length = fd.getFusionState(args_.at(0).index); - auto start = (args_.at(1).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(1).index)->as() - : nullptr; - auto step = (args_.at(2).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(2).index)->as() - : nullptr; - auto output = iota(length, start, step, dtype_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const override { - RecordFunctor::print(os, false); - os << ", dtype=" << dtypeToPyString(dtype_); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Dtype, - serde::CreateDtype(builder, nvfuser::toUnderlying(dtype_)).Union()}; - } - - private: - //! Type of output - PrimDataType dtype_; -}; - -//! Specialized Record Functors for random ops. -template -struct RandomDistOpRecord : RecordFunctor { - RandomDistOpRecord( - std::vector _args, - std::vector _outputs, - PrimDataType dtype) - : RecordFunctor(std::move(_args), std::move(_outputs), "", RType), - dtype_(dtype) { - if constexpr (RType == serde::RecordType::UniformDistOp) { - name_ = "ops.uniform"; - } else if constexpr (RType == serde::RecordType::NormalDistOp) { - name_ = "ops.normal"; - } else { - static_assert( - (RType == serde::RecordType::NormalDistOp) || - (RType == serde::RecordType::UniformDistOp)); - } - setArgName(2, "shape"); - if (args_.size() == 5) { - setArgName(3, "rng_seed"); - setArgName(4, "rng_offset"); - } - } - ~RandomDistOpRecord() override = default; - RecordFunctor* clone() final { - return new RandomDistOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 --------------------------------------- 0 | - //! | Dtype | - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (static_cast(dtype_) & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && (dtype_ == child_ptr->dtype_); - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg1 = fd.getFusionState(args_.at(0).index); - auto arg2 = fd.getFusionState(args_.at(1).index); - const std::vector& output_shape = - fd.getFusionStateVector(args_.at(2).index); - - Val* output = nullptr; - if constexpr (RType == serde::RecordType::UniformDistOp) { - if (args_.size() == 3) { // stochastic uniform - output = uniform(output_shape, arg1, arg2, dtype_); - } else if (args_.size() == 5) { // provided seed and offset - auto seed = fd.getFusionState(args_.at(3).index); - auto offset = fd.getFusionState(args_.at(4).index); - output = uniform(output_shape, arg1, arg2, dtype_, seed, offset); - } - } else if constexpr (RType == serde::RecordType::NormalDistOp) { - if (args_.size() == 3) { // stochastic normal - output = normal(output_shape, arg1, arg2, dtype_); - } else if (args_.size() == 5) { // provided seed and offset - auto seed = fd.getFusionState(args_.at(3).index); - auto offset = fd.getFusionState(args_.at(4).index); - output = normal(output_shape, arg1, arg2, dtype_, seed, offset); - } - } else { - static_assert( - (RType == serde::RecordType::NormalDistOp) || - (RType == serde::RecordType::UniformDistOp)); - } - - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dtype=" << dtypeToPyString(dtype_); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::TensorCreationSymbolic, - serde::CreateTensorCreationSymbolic(builder, toUnderlying(dtype_)) - .Union()}; - } - - private: - //! DataType of output - PrimDataType dtype_; -}; - -//! Specialized Record Functor for recording Vector of Scalars - -struct VectorRecord : RecordFunctor { - VectorRecord( - std::vector _args, - std::vector _outputs, - PrimDataType dtype, - bool inline_def = false) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "define_vector", - serde::RecordType::Vector, - inline_def), - dtype_(dtype) {} - ~VectorRecord() override = default; - RecordFunctor* clone() final { - return new VectorRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 --------------------------------------- 0 | - //! | Dtype | - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (static_cast(dtype_) & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && (dtype_ == child_ptr->dtype_); - } - return result; - } - - void operator()(FusionState& fd) final { - std::vector output(args_.size(), nullptr); - NVF_CHECK( - dtype_ == DataType::Int, - "Only Int Dtype is not supported by a vector of sizes: ", - dtype_); - for (size_t i = 0; i < args_.size(); ++i) { - NVF_CHECK( - args_.at(i).stype == serde::StateType::Scalar, - "Unsupported State type!"); - output.at(i) = fd.getFusionState(args_.at(i).index); - } - fd.setFusionStateVector(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - if (inline_def_) { - bool first_arg = true; - NVF_CHECK(outputs_.size() == 1, "VectorRecord's does not have 1 output!"); - os << "["; - for (auto& arg : args_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << arg; - } - os << "]"; - } else { - bool first_output = true; - for (auto& output : outputs_) { - if (first_output) { - first_output = false; - } else { - os << ", "; - } - os << output; - } - os << " = fd." << name_ << "(["; - bool first_arg = true; - for (auto& arg : args_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << arg; - } - os << "], dtype=" << dtypeToPyString(dtype_); - if (close_function) { - os << ")"; - } - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Vector, - serde::CreateVector(builder, nvfuser::toUnderlying(dtype_)).Union()}; - }; - - private: - //! Scalar data type. - PrimDataType dtype_; -}; - -struct SdpaFwdOpRecord : RecordFunctor { - SdpaFwdOpRecord(std::vector args, std::vector outputs) - : RecordFunctor( - std::move(args), - std::move(outputs), - "ops.sdpfa_fwd", - serde::RecordType::SdpaFwdOp) {} - ~SdpaFwdOpRecord() override = default; - RecordFunctor* clone() final { - return new SdpaFwdOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto query = fd.getFusionState(args_.at(0).index)->as(); - auto key = fd.getFusionState(args_.at(1).index)->as(); - auto value = fd.getFusionState(args_.at(2).index)->as(); - auto bias = (args_.at(3).stype == serde::StateType::Tensor) - ? fd.getFusionState(args_.at(3).index)->as() - : nullptr; - auto mask = (args_.at(4).stype == serde::StateType::Tensor) - ? fd.getFusionState(args_.at(4).index)->as() - : nullptr; - auto dropout_p = (args_.at(5).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(5).index)->as() - : nullptr; - auto is_causal = (args_.at(6).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(6).index)->as() - : nullptr; - auto scale = (args_.at(7).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(7).index)->as() - : nullptr; - auto output = - sdpfa_fwd(query, key, value, bias, mask, dropout_p, is_causal, scale); - fd.setFusionState(outputs_.at(0).index, output.output); - fd.setFusionState(outputs_.at(1).index, output.logsumexp); - fd.setFusionState(outputs_.at(2).index, output.philox_seed); - fd.setFusionState(outputs_.at(3).index, output.philox_offset); - } -}; - -struct SdpaBwdOpRecord : RecordFunctor { - SdpaBwdOpRecord(std::vector args, std::vector outputs) - : RecordFunctor( - std::move(args), - std::move(outputs), - "ops.sdpfa_bwd", - serde::RecordType::SdpaBwdOp) {} - ~SdpaBwdOpRecord() override = default; - RecordFunctor* clone() final { - return new SdpaBwdOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto grad_output = fd.getFusionState(args_.at(0).index)->as(); - auto query = fd.getFusionState(args_.at(1).index)->as(); - auto key = fd.getFusionState(args_.at(2).index)->as(); - auto value = fd.getFusionState(args_.at(3).index)->as(); - auto output = fd.getFusionState(args_.at(4).index)->as(); - auto logsumexp = fd.getFusionState(args_.at(5).index)->as(); - - auto dropout_p = (args_.at(6).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(6).index)->as() - : nullptr; - auto is_causal = (args_.at(7).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(7).index)->as() - : nullptr; - - auto philox_seed = fd.getFusionState(args_.at(8).index)->as(); - auto philox_offset = fd.getFusionState(args_.at(9).index)->as(); - - auto scale = (args_.at(10).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(10).index)->as() - : nullptr; - - auto grad = sdpfa_bwd( - grad_output, - query, - key, - value, - output, - logsumexp, - dropout_p, - is_causal, - philox_seed, - philox_offset, - scale); - fd.setFusionState(outputs_.at(0).index, grad.grad_query); - fd.setFusionState(outputs_.at(1).index, grad.grad_key); - fd.setFusionState(outputs_.at(2).index, grad.grad_value); - } -}; - -struct EmbeddingFwdOpRecord : RecordFunctor { - EmbeddingFwdOpRecord(std::vector args, std::vector outputs) - : RecordFunctor( - std::move(args), - std::move(outputs), - "ops.embedding_fwd", - serde::RecordType::EmbeddingFwdOp) {} - ~EmbeddingFwdOpRecord() override = default; - RecordFunctor* clone() final { - return new EmbeddingFwdOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto input = fd.getFusionState(args_.at(0).index)->as(); - auto weight = fd.getFusionState(args_.at(1).index)->as(); - auto padding_idx = (args_.at(2).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(2).index)->as() - : nullptr; - auto max_norm = (args_.at(3).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(3).index)->as() - : nullptr; - auto norm_type = (args_.at(4).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(4).index)->as() - : nullptr; - auto scale_grad_by_freq = (args_.at(5).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(5).index)->as() - : nullptr; - auto sparse = (args_.at(6).stype == serde::StateType::Scalar) - ? fd.getFusionState(args_.at(6).index)->as() - : nullptr; - - auto output = embedding_fwd( - input, - weight, - padding_idx, - max_norm, - norm_type, - scale_grad_by_freq, - sparse); - fd.setFusionState(outputs_.at(0).index, output); - } -}; - -struct IndexPutAccumulateOpRecord : RecordFunctor { - IndexPutAccumulateOpRecord( - std::vector args, - std::vector outputs) - : RecordFunctor( - std::move(args), - std::move(outputs), - "ops.index_put_accumulate", - serde::RecordType::IndexPutAccumulateOp) {} - ~IndexPutAccumulateOpRecord() override = default; - RecordFunctor* clone() final { - return new IndexPutAccumulateOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto acc = fd.getFusionState(args_.at(0).index)->as(); - auto index = fd.getFusionState(args_.at(1).index)->as(); - auto value = fd.getFusionState(args_.at(2).index)->as(); - - auto output = indexPutAccumulate(acc, index, value); - fd.setFusionState(outputs_.at(0).index, output); - } -}; - -struct ArgsortOpRecord : RecordFunctor { - ArgsortOpRecord( - std::vector _args, - std::vector _outputs, - int64_t dim, - bool descending, - bool stable) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.argsort", - serde::RecordType::ArgsortOp), - dim_(dim), - descending_(descending), - stable_(stable) {} - ~ArgsortOpRecord() override = default; - RecordFunctor* clone() final { - return new ArgsortOpRecord(*this); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto other_argsort = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && - dim_ == other_argsort->dim_ && - descending_ == other_argsort->descending_ && - stable_ == other_argsort->stable_; - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - Val* output = argsort(arg, dim_, descending_, stable_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dim=" << dim_ - << ", descending=" << (descending_ ? "True" : "False") - << ", stable=" << (stable_ ? "True" : "False"); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Sort, - serde::CreateSort(builder, dim_, descending_, stable_).Union()}; - } - - private: - int64_t dim_; - bool descending_; - bool stable_; -}; - -//! Record for TopK operation in fusion cache and Python frontend -//! -//! Stores the parameters needed to recreate a TopK operation: -//! - dim: dimension along which to find top-k elements -//! - largest: whether to find largest (true) or smallest (false) elements -//! - sorted: whether the output should be sorted -//! -//! The operation takes two inputs: the tensor and k (number of elements) -//! and produces two outputs: values and indices tensors. -struct TopKOpRecord : RecordFunctor { - TopKOpRecord( - std::vector _args, - std::vector _outputs, - int64_t dim, - bool largest, - bool sorted) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.topk", - serde::RecordType::TopKOp), - dim_(dim), - largest_(largest), - sorted_(sorted) {} - ~TopKOpRecord() override = default; - RecordFunctor* clone() final { - return new TopKOpRecord(*this); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto other_topk = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && dim_ == other_topk->dim_ && - largest_ == other_topk->largest_ && sorted_ == other_topk->sorted_; - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->template as(); - auto k = fd.getFusionState(args_.at(1).index); - auto output = topk(arg, k, dim_, largest_, sorted_); - fd.setFusionState(outputs_.at(0).index, output.values); - fd.setFusionState(outputs_.at(1).index, output.indices); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dim=" << dim_ << ", largest=" << (largest_ ? "True" : "False") - << ", sorted=" << (sorted_ ? "True" : "False"); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::TopK, - serde::CreateTopK(builder, dim_, largest_, sorted_).Union()}; - } - - private: - int64_t dim_; - bool largest_; - bool sorted_; -}; - -struct ScaledGroupedMmaOpRecord : RecordFunctor { - ScaledGroupedMmaOpRecord( - std::vector _args, - std::vector _outputs, - PrimDataType dtype, - int64_t out_block_scale_size, - PrimDataType out_block_scale_dtype, - bool out_gamma) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.grouped_mm", - serde::RecordType::ScaledGroupedMmaOp), - dtype_(dtype), - out_block_scale_size_(out_block_scale_size), - out_block_scale_dtype_(out_block_scale_dtype), - out_gamma_(out_gamma) {} - ~ScaledGroupedMmaOpRecord() override = default; - RecordFunctor* clone() final { - return new ScaledGroupedMmaOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 --------------------------------------- 0 | - //! | Dtype | - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (static_cast(dtype_) & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = - dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - result = result && (dtype_ == child_ptr->dtype_); - } - return result; - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dtype=" << dtypeToPyString(dtype_); - os << ", output_block_scale_size=" << out_block_scale_size_; - os << ", output_block_scale_dtype=" - << dtypeToPyString(out_block_scale_dtype_); - os << ", output_gamma=" << (out_gamma_ ? "True" : "False"); - if (close_function) { - os << ")"; - } - } - - void operator()(FusionState& fd) final { - auto mat1 = fd.getFusionState(args_.at(0).index)->template as(); - auto mat2 = fd.getFusionState(args_.at(1).index)->template as(); - auto offsets = - fd.getFusionState(args_.at(2).index)->template as(); - auto scale1 = - fd.getFusionState(args_.at(3).index)->template as(); - auto scale2 = - fd.getFusionState(args_.at(4).index)->template as(); - auto alpha = (args_.at(5).stype == serde::StateType::Tensor) - ? fd.getFusionState(args_.at(5).index)->as() - : nullptr; - auto bias = (args_.at(6).stype == serde::StateType::Tensor) - ? fd.getFusionState(args_.at(6).index)->as() - : nullptr; - auto beta = (args_.at(7).stype == serde::StateType::Tensor) - ? fd.getFusionState(args_.at(7).index)->as() - : nullptr; - auto [output_mat, output_scale, output_gamma] = grouped_mm( - mat1, - mat2, - offsets, - scale1, - scale2, - alpha, - bias, - beta, - dtype_, - out_block_scale_size_, - out_block_scale_dtype_, - out_gamma_); - fd.setFusionState(outputs().at(0).index, output_mat); - if (out_block_scale_size_ > 0) { - NVF_CHECK(output_scale != nullptr, "Output scale is null"); - NVF_CHECK( - outputs().at(1).stype != serde::StateType::None, - "Output scale is expected but is null"); - fd.setFusionState(outputs().at(1).index, output_scale); - } - if (out_gamma_) { - NVF_CHECK(output_gamma != nullptr, "Output gamma is null"); - NVF_CHECK( - outputs().at(2).stype != serde::StateType::None, - "Output gamma is expected but is null"); - fd.setFusionState(outputs().at(2).index, output_gamma); - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::ScaledOp, - serde::CreateScaledOp( - builder, - nvfuser::toUnderlying(dtype_), - out_block_scale_size_, - nvfuser::toUnderlying(out_block_scale_dtype_), - out_gamma_) - .Union()}; - }; - - PrimDataType dtype_; - int64_t out_block_scale_size_; - PrimDataType out_block_scale_dtype_; - bool out_gamma_; -}; - -struct ScaledMmaOpRecord : RecordFunctor { - ScaledMmaOpRecord( - std::vector _args, - std::vector _outputs, - PrimDataType dtype, - int64_t output_block_scale_size, - PrimDataType output_block_scale_dtype, - bool output_gamma) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.scaled_mm", - serde::RecordType::ScaledMmaOp), - dtype_(dtype), - out_block_scale_size_(output_block_scale_size), - out_block_scale_dtype_(output_block_scale_dtype), - out_gamma_(output_gamma) {} - ~ScaledMmaOpRecord() override = default; - RecordFunctor* clone() final { - return new ScaledMmaOpRecord(*this); - } - - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (static_cast(dtype_) & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - if (!RecordFunctor::operator==(other)) { - return false; - } - auto other_scaled_mma = static_cast(other); - return (dtype_ == other_scaled_mma.dtype_) && - (out_block_scale_size_ == other_scaled_mma.out_block_scale_size_) && - (out_block_scale_dtype_ == other_scaled_mma.out_block_scale_dtype_) && - (out_gamma_ == other_scaled_mma.out_gamma_); - } - - void operator()(FusionState& fd) final { - auto mat1 = fd.getFusionState(args_[0].index)->template as(); - auto mat2 = fd.getFusionState(args_[1].index)->template as(); - auto scale1 = fd.getFusionState(args_[2].index)->template as(); - auto scale2 = fd.getFusionState(args_[3].index)->template as(); - auto alpha = args_[4].stype == serde::StateType::None - ? nullptr - : fd.getFusionState(args_[4].index)->template as(); - auto bias = args_[5].stype == serde::StateType::None - ? nullptr - : fd.getFusionState(args_[5].index)->template as(); - auto beta = args_[6].stype == serde::StateType::None - ? nullptr - : fd.getFusionState(args_[6].index)->template as(); - - auto [output_mat, output_scale, output_gamma] = scaled_mm( - mat1, - mat2, - scale1, - scale2, - alpha, - bias, - beta, - dtype_, - out_block_scale_size_, - out_block_scale_dtype_, - out_gamma_); - - fd.setFusionState(outputs().at(0).index, output_mat); - if (out_block_scale_size_ > 0) { - NVF_CHECK(output_scale != nullptr, "Output scale is null"); - NVF_CHECK( - outputs().at(1).stype != serde::StateType::None, - "Output scale is expected but is null"); - fd.setFusionState(outputs().at(1).index, output_scale); - } - if (out_gamma_) { - NVF_CHECK(output_gamma != nullptr, "Output gamma is null"); - NVF_CHECK( - outputs().at(2).stype != serde::StateType::None, - "Output gamma is expected but is null"); - fd.setFusionState(outputs().at(2).index, output_gamma); - } - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dtype=" << dtypeToPyString(dtype_); - os << ", output_block_scale_size=" << out_block_scale_size_; - os << ", output_block_scale_dtype=" - << dtypeToPyString(out_block_scale_dtype_); - os << ", output_gamma=" << (out_gamma_ ? "True" : "False"); - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::ScaledOp, - serde::CreateScaledOp( - builder, - nvfuser::toUnderlying(dtype_), - out_block_scale_size_, - nvfuser::toUnderlying(out_block_scale_dtype_), - out_gamma_) - .Union()}; - }; - - PrimDataType dtype_; - int64_t out_block_scale_size_; - PrimDataType out_block_scale_dtype_; - bool out_gamma_; -}; - -struct CutlassNvfp4GroupedMmaOpRecord : RecordFunctor { - CutlassNvfp4GroupedMmaOpRecord( - std::vector _args, - std::vector _outputs, - PrimDataType dtype) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.cutlass_nvfp4_grouped_mm", - serde::RecordType::CutlassNvfp4GroupedMmaOp), - dtype_(dtype) {} - - ~CutlassNvfp4GroupedMmaOpRecord() override = default; - - size_t hash() const final { - auto result = RecordFunctor::hash(); - return result | (static_cast(dtype_) & 0xffffffff); - } - - bool operator==(const RecordFunctor& other) const final { - if (!RecordFunctor::operator==(other)) { - return false; - } - auto other_cutlass_nvfp4_grouped_mma = - static_cast(other); - return (dtype_ == other_cutlass_nvfp4_grouped_mma.dtype_); - } - - RecordFunctor* clone() final { - return new CutlassNvfp4GroupedMmaOpRecord(*this); - } - - void operator()(FusionState& fd) final { - auto mat1 = fd.getFusionState(args_[0].index)->template as(); - auto mat2 = fd.getFusionState(args_[1].index)->template as(); - auto scale1 = fd.getFusionState(args_[2].index)->template as(); - auto scale2 = fd.getFusionState(args_[3].index)->template as(); - auto alpha = fd.getFusionState(args_[4].index)->template as(); - auto problem_sizes = - fd.getFusionState(args_[5].index)->template as(); - auto expert_offsets = - fd.getFusionState(args_[6].index)->template as(); - auto sf_offsets = - fd.getFusionState(args_[7].index)->template as(); - - auto result = cutlass_nvfp4_grouped_mm( - mat1, - mat2, - scale1, - scale2, - alpha, - problem_sizes, - expert_offsets, - sf_offsets, - dtype_); - fd.setFusionState(outputs().at(0).index, result); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", dtype=" << dtypeToPyString(dtype_); - if (close_function) { - os << ")"; - } - } - - PrimDataType dtype_; -}; - -} // namespace nvfuser::python_frontend - -//! Creating the template specialized hash and equal_to functions for a -//! RecordFunctor object in order to use hash maps (unordered_maps) in STL. -namespace std { -using namespace nvfuser::python_frontend; - -template <> -struct hash { - size_t operator()(const RecordFunctor* p) const { - NVF_CHECK(p, "The RecordFunctor Pointer for hashing is null!"); - return p->hash(); - } -}; -template <> -struct equal_to { - bool operator()(const RecordFunctor* p, const RecordFunctor* q) const { - NVF_CHECK( - p, - "The RecordFunctor Pointer on the lhs of an equality check is null!"); - NVF_CHECK( - q, - "The RecordFunctor Pointer on the rhs of an equality check is null!"); - return p->operator==(*q); - } -}; -} // namespace std diff --git a/python/python_frontend/fusion_state.cpp b/python/python_frontend/fusion_state.cpp deleted file mode 100644 index 0d6dea03781..00000000000 --- a/python/python_frontend/fusion_state.cpp +++ /dev/null @@ -1,297 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include -#include -#include "base.h" - -// Require namespace for perf scope instrumentation -using namespace nvfuser::inst; - -namespace nvfuser::python_frontend { - -bool State::inlineDef() const { - return inline_def_record_.has_value(); -} -void State::setInlineDefRecord(const RecordFunctor* record) { - NVF_CHECK( - record, "Attemped to set the record for an inline definition as Null!"); - inline_def_record_ = std::optional(record); -} -const RecordFunctor* State::inlineDefRecord() const { - NVF_CHECK( - inlineDef(), - "Attempting to query the inline definition Record State that is not " - "inline defined!"); - NVF_CHECK(inline_def_record_.value(), "Inline definition Record is Null!"); - return inline_def_record_.value(); -} - -bool State::operator==(const State& other) const { - NVF_ERROR( - (index == other.index ? (stype == other.stype) : true), - "State indices should not match with different State Types!"); - return (index == other.index) && (stype == other.stype); -} - -bool State::operator!=(const State& other) const { - NVF_ERROR( - (index == other.index ? (stype == other.stype) : true), - "State indices should not match with different State Types!"); - return (index != other.index) || (stype != other.stype); -} - -// Generalized printing of State -std::ostream& operator<<(std::ostream& os, const State& state) { - if (state.inlineDef()) { - NVF_CHECK( - state.inlineDefRecord()->inlineDef(), - "The State Object's definition record is not set with an inline " - "definition!"); - state.inlineDefRecord()->print(os); - } else { - if (state.stype == serde::StateType::Scalar) { - os << "S" << state.index; - } else if (state.stype == serde::StateType::Tensor) { - os << "T" << state.index; - } else if (state.stype == serde::StateType::Vector) { - os << "V" << state.index; - } else if (state.stype == serde::StateType::None) { - os << "None"; - } else { - NVF_THROW("Unsupported StateType"); - } - } - return os; -} - -std::vector getExtents(Fusion* fusion) { - NVF_CHECK(fusion != nullptr, "Fusion is undefined."); - - std::vector extents; - for (Val* v : fusion->inputs()) { - // short-circuit: skip if not TensorView - if (!v->isA()) { - continue; - } - auto* tv = v->as(); - std::vector logical_dom = - TensorDomain::noReductions(tv->getLogicalDomain()); - std::transform( - logical_dom.begin(), - logical_dom.end(), - std::back_inserter(extents), - [](IterDomain* id) { return id->getMaybeExpandedExtent(); }); - } - return extents; -} - -FusionState::FusionState() - : end_record_(new EndRecord()), - recording_(), - recording_state_(), - fusion_(nullptr), - fusion_state_(), - num_recording_states_(0) {} - -std::unique_ptr FusionState::clone() { - auto state = std::make_unique(); - for (auto&& rf : recording_) { - state->recording_.emplace_back(rf->clone()); - } - state->fusion_ = fusion_; - state->fusion_state_.insert( - state->fusion_state_.end(), fusion_state_.begin(), fusion_state_.end()); - state->num_recording_states_ = num_recording_states_; - std::copy( - inputs_fid_.begin(), - inputs_fid_.end(), - std::back_inserter(state->inputs_fid_)); - std::copy( - outputs_fid_.begin(), - outputs_fid_.end(), - std::back_inserter(state->outputs_fid_)); - std::copy( - extents_fid_.begin(), - extents_fid_.end(), - std::back_inserter(state->extents_fid_)); - std::copy( - map_value_to_fid_.begin(), - map_value_to_fid_.end(), - std::inserter(state->map_value_to_fid_, state->map_value_to_fid_.end())); - return state; -} - -void FusionState::buildFusionIr(Fusion* fusion) { - FUSER_PERF_SCOPE("FusionContainer::buildFusionIr"); - NVF_CHECK(fusion != nullptr, "Fusion is undefined."); - resetFusionState(fusion, num_recording_states_); - auto fusion_guard = FusionGuard(fusion); - for (auto& record : recording_) { - auto functor = record.get(); - try { - (*functor)(*this); - } catch (const std::exception& e) { - std::stringstream ss; - record->print(ss); - - NVF_THROW( - "\nDetected exception while building Fusion Ir. The failing " - "RecordFunctor is: ", - ss.str(), - "\nNvFuser error message is: ", - e.what()); - } - } - addExtents(); -} - -void FusionState::addRecord(RecordFunctor* record) { - FUSER_PERF_SCOPE("FusionContainer::addRecord"); - recording_.emplace_back(record); - num_recording_states_ += record->numOutputs(); - RecordFunctor* state_record = recording_.back().get(); - - // NOTE: when the outputs are added to the Record constructor, - // the Record is not constructed. Therefore, the information has to be - // propagated when the Record is added to the FusionState. - for (const auto& out : state_record->outputs()) { - if (state_record->inlineDef()) { - NVF_CHECK( - out.index < recording_state_.size(), - "Output state is not found in recording_state! Index: ", - out.index, - " Size: ", - recording_state_.size()); - recording_state_.at(out.index).setInlineDefRecord(state_record); - } - } -} - -Fusion* FusionState::fusion() { - NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); - return fusion_; -} - -void FusionState::printIr() const { - NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); - fusion_->printMath(); -} - -void FusionState::resetFusionState(Fusion* fusion, size_t size) { - NVF_CHECK(fusion != nullptr, "Fusion is undefined."); - fusion_ = fusion; - fusion_state_.clear(); - fusion_state_.resize(size, {}); - inputs_fid_.clear(); - outputs_fid_.clear(); - extents_fid_.clear(); - map_value_to_fid_.clear(); -} - -void FusionState::addFusionState(Val* val) { - fusion_state_.push_back({val}); -} - -void FusionState::addFusionStateVector(std::vector val) { - for (auto v : val) { - NVF_CHECK( - !v->isA(), - "TensorViews should not be added to State Vectors!"); - } - fusion_state_.push_back(val); -} - -Val* FusionState::getFusionState(size_t index) const { - const auto& ret = fusion_state_.at(index); - NVF_CHECK(ret.size() == 1, "Expecting to return only one Val*."); - return ret.front(); -} - -const std::vector& FusionState::getFusionStateVector(size_t index) const { - return fusion_state_.at(index); -} - -size_t FusionState::numFusionStates() const { - return fusion_state_.size(); -} - -void FusionState::setFusionState(size_t index, Val* val) { - fusion_state_.at(index) = {val}; - map_value_to_fid_.emplace(val, (int64_t)index); -} - -void FusionState::setFusionStateVector(size_t index, std::vector val) { - for (auto v : val) { - NVF_CHECK( - !v->isA(), - "TensorViews should not be added to State Vectors!"); - } - fusion_state_.at(index) = {val}; -} - -void FusionState::addInput(Val* input, size_t index) { - NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); - fusion_->addInput(input); - map_value_to_fid_.emplace(input, (int64_t)index); - inputs_fid_.push_back((int64_t)index); -} - -void FusionState::addOutput(Val* output, size_t index) { - NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); - fusion_->addOutput(output); - map_value_to_fid_.emplace(output, (int64_t)index); - outputs_fid_.push_back((int64_t)index); -} - -void FusionState::aliasOutputToInput(Val* output, Val* input) { - NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); - // We haven't exposed AllocationType to Python API. For now, use - // ReuseBuffer to preserve the old behavior. - fusion_->aliasOutputToInput(output, input, AllocationType::ReuseBuffer); -} - -const std::unordered_map& FusionState::getValueMap() - const { - return map_value_to_fid_; -} - -const std::vector& FusionState::inputs() const { - return inputs_fid_; -} - -const std::vector& FusionState::outputs() const { - return outputs_fid_; -} - -const std::vector& FusionState::extents() const { - return extents_fid_; -} - -void FusionState::addExtents() { - NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); - - // The size of the tensor dimensions can be used as an input of the - // segments. NvFuser does not support returning scalar values. Segmentation - // must pass those sizes as segment arguments manually. - std::vector extents = getExtents(fusion_); - for (Val* extent : extents) { - int64_t num_extents = (int64_t)extents_fid_.size(); - // Use negative numbers to represent extent of iterDomains to avoid conflict - // with non-negative numbers used for scalars, vectors, and tensors. - // The extents are ordered based on the order of the fusion's inputs. - int64_t extent_fid = -num_extents - 1; - extents_fid_.push_back(extent_fid); - // The extent can already exist in the fusion. However, since scalars cannot - // be passed between segments, always overwrited existing fids. The original - // fusion definition will provide scalar extents. - map_value_to_fid_[extent] = extent_fid; - } -} - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/fusion_state.h b/python/python_frontend/fusion_state.h deleted file mode 100644 index 9e512a0e2a6..00000000000 --- a/python/python_frontend/fusion_state.h +++ /dev/null @@ -1,143 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once -#include -#include -#include -#include - -namespace nvfuser::python_frontend { - -struct RecordFunctor; - -struct State { - State() - : index(0), - stype(serde::StateType::None), - inline_def_record_(std::nullopt) {} - State( - size_t _index, - serde::StateType _stype, - std::optional inline_def_record = std::nullopt) - : index(_index), stype(_stype), inline_def_record_(inline_def_record) {} - - bool inlineDef() const; - NVF_API void setInlineDefRecord(const RecordFunctor* record); - const RecordFunctor* inlineDefRecord() const; - - bool operator==(const State& other) const; - bool operator!=(const State& other) const; - - //! A unique index to identifiy each recorded state item. - size_t index; - //! StateType is either: Tensor, Scalar, or Vector - serde::StateType stype; - - private: - // This data member is only set if this state is inline defined! - std::optional inline_def_record_; -}; - -NVF_API std::ostream& operator<<(std::ostream& os, const State& state); - -//! Get extents for TensorView inputs in Fusion -std::vector getExtents(Fusion* fusion); - -//! FusionState contains the information used to build a new cpp Fusion object. -//! Unlike FusionDefinition, it does not modify the FusionCache Trie structure. -class FusionState { - public: - FusionState(); - - // The copy/move/assign constructors/operators are removed - FusionState(const FusionState& other) = delete; - FusionState(FusionState&& other) noexcept = delete; - FusionState& operator=(const FusionState& other) = delete; - FusionState& operator=(FusionState&& other) noexcept = delete; - virtual ~FusionState() = default; - - //! Get fusion object - Fusion* fusion(); - //! Prints the Fusion IR representation - void printIr() const; - - //! Adds a Fusion IR Tensor/Scalar object - NVF_API void addFusionState(Val* val); - //! Adds a Fusion IR Vector of Scalars - void addFusionStateVector(std::vector val); - //! Gets a Fusion IR Tensor/Scalar object - NVF_API Val* getFusionState(size_t index) const; - //! Gets a Fusion IR Vector of Scalars - NVF_API const std::vector& getFusionStateVector(size_t index) const; - //! Number of fusion states - NVF_API size_t numFusionStates() const; - //! Sets a Fusion IR Tensor/Scalar object - NVF_API void setFusionState(size_t index, Val* val); - //! Sets a Fusion IR Vector of Scalars - NVF_API void setFusionStateVector(size_t index, std::vector val); - - //! Adds a Tensor/Scalar input to the Fusion object - NVF_API void addInput(Val* input, size_t index); - //! Adds a Tensor/Scalar output to the Fusion object - NVF_API void addOutput(Val* output, size_t index); - //! Alias an Output to Input in the Fusion object - NVF_API void aliasOutputToInput(Val* output, Val* input); - - //! Get map between CPP Fusion and Python FusionDefinition - NVF_API const std::unordered_map& getValueMap() const; - //! Get indicies for the inputs of FusionState - NVF_API const std::vector& inputs() const; - //! Get indicies for the outputs of FusionState - NVF_API const std::vector& outputs() const; - //! Get indicies for the extents of TensorView inputs of FusionState - NVF_API const std::vector& extents() const; - - //! Add a Record - void addRecord(RecordFunctor* record); - //! Builds an nvFuser Fusion IR object - void buildFusionIr(Fusion* fusion); - - //! Create clone of FusionState - std::unique_ptr clone(); - - private: - //! Add extents of TensorView inputs to FusionState - void addExtents(); - //! Change the fusion ptr and reset its state - void resetFusionState(Fusion* fusion, size_t size); - - protected: - //! Holds an End Record - std::unique_ptr end_record_; - //! A vector of record operations in the FusionDefintion - std::vector> recording_; - //! A vector of state that represents Tensors/Vectors/Scalars - std::vector recording_state_; - //! Input arguments for FusionState - std::vector inputs_fid_; - //! Output arguments for FusionState - std::vector outputs_fid_; - //! Extents for TensorView input arguments for FusionState - std::vector extents_fid_; - //! Map Fusion Val to its corresponding FusionDefinition index - std::unordered_map map_value_to_fid_; - - private: - //! A ptr to the container used when building the Fusion IR from a definition - Fusion* fusion_ = nullptr; - //! A vector of nvFuser Fusion IR TensorViews/Vectors/Scalars for building the - //! Fusion IR graph. - //! NOTE: Vectors are represented by a vector. This could - //! be another child class of Val in the IR, similar to TensorView. - std::vector> fusion_state_; - //! The number of states in Fusion Container - //! A sum of all outputs for each RecordFunctor - size_t num_recording_states_; -}; - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/multidevice_bindings.cpp b/python/python_frontend/multidevice_bindings.cpp deleted file mode 100644 index 5dbe199101f..00000000000 --- a/python/python_frontend/multidevice_bindings.cpp +++ /dev/null @@ -1,103 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include - -#include -#include -#include -#include - -namespace nvfuser::python_frontend { - -namespace { -void bindCommunicator(py::module& nvfuser) { - // py::nodelete is necessary because Communicator doesn't have a destructor: - // https://pybind11.readthedocs.io/en/stable/advanced/classes.html#non-public-destructors - py::class_> - communicator(nvfuser, "Communicator"); - communicator.def( - "instance", - &Communicator::getInstance, - "Returns the singleton communicator instance.", - py::return_value_policy::reference); - communicator.def( - "size", - &Communicator::size, - "Returns the number of processes in the communicator."); - communicator.def( - "rank", - &Communicator::deviceId, - "Returns the device ID associated with the current process."); - communicator.def( - "local_size", - &Communicator::local_size, - "Returns the number of processes within the node."); - communicator.def( - "local_rank", - &Communicator::local_rank, - "Returns the in-node rank associated with the current process."); - communicator.def( - "barrier", - [](Communicator& self) { - // Communicator::barrier takes an optional backend argument, which we - // don't use yet. - self.barrier(); - }, - "Performs a blocking barrier across all ranks."); -} - -void bindDeviceMesh(py::module& nvfuser) { - py::class_ device_mesh(nvfuser, "DeviceMesh", py::module_local()); - device_mesh.def(py::init([](const std::vector& devices) { - return new DeviceMesh(at::tensor(devices)); - })); - device_mesh.def("__repr__", [](const DeviceMesh& self) { - std::stringstream ss; - ss << self; - return ss.str(); - }); - device_mesh.def_property_readonly( - "size", - [](const DeviceMesh& self) -> int64_t { return self.size(); }, - "Returns the size of the mesh."); - device_mesh.def( - "shard_tensor", - [](const DeviceMesh& self, at::Tensor tensor, const int64_t axis) - -> at::Tensor { return shardTensor1D(tensor, axis, self); }, - py::arg("tensor"), - py::arg("axis")); -} - -void bindDistributedTensor(py::module& nvfuser) { - py::class_ distributed_tensor( - nvfuser, "Sharding", py::module_local()); - distributed_tensor.def_property_readonly( - "mesh", - &Sharding::mesh, - "Returns the device mesh.", - py::return_value_policy::reference); - distributed_tensor.def( - "axis_sharded_on", - &Sharding::axisShardedOn, - R"( - Returns the axis sharded on the given parallel type. - - If the distributed tensor is replicated on that parallel type, returns -1. - )", - py::arg("parallel_type")); -} - -} // namespace - -void bindMultidevice(py::module& nvfuser) { - bindCommunicator(nvfuser); - bindDeviceMesh(nvfuser); - bindDistributedTensor(nvfuser); -} - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/python_bindings.cpp b/python/python_frontend/python_bindings.cpp deleted file mode 100644 index 791f19f9ea2..00000000000 --- a/python/python_frontend/python_bindings.cpp +++ /dev/null @@ -1,4196 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include - -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace nvfuser::python_frontend { - -// Set of local functions that are used to compose python FusionDefinition -// bindings. Ideally, these would be templated lambda functions but those -// are not available without C++20. -namespace { -Vector define_vector_base_fn( - FusionDefinition& fd, - std::vector& args, - bool inline_def = false) { - FUSER_PERF_SCOPE("python_frontend::define_vector_base_fn"); - NVF_CHECK(!fd.completed(), "Attempting to add to a completed definition!"); - std::vector inputs; - inputs.reserve(args.size()); - for (const auto& arg : args) { - inputs.push_back(fd.recordingState(arg())); - } - Vector out = fd.defineVector(inputs.size()); - fd.defineRecord(new VectorRecord( - inputs, {fd.recordingState(out())}, DataType::Int, inline_def)); - return out; -} - -template -Vector define_vector_fn( - FusionDefinition& self, - ITERABLE& values, - bool inline_def, - bool shape_check) { - FUSER_PERF_SCOPE("python_frontend::define_vector_fn"); - std::vector args; - size_t idx = 0; - for (const auto& item : values) { - if (py::isinstance(item)) { - auto int_value = py::cast(item); - NVF_CHECK( - !shape_check || int_value >= -1, - "The value ", - int_value, - " at index ", - idx, - " was neither symbolic(-1), zero_element(0), broadcast(1), or " - "static(>1)."); - Scalar out = self.defineScalar(); - self.defineRecord(new ScalarRecord( - {self.recordingState(out())}, - py::cast(item), - DataType::Int, - /*inline_def=*/true)); - args.emplace_back(out); - } else if (py::isinstance(item)) { - args.emplace_back(py::cast(item)); - } else { - NVF_CHECK( - false, - "Unsupported iterable object type for define_vector! Index:", - idx); - } - ++idx; - } - return define_vector_base_fn(self, args, inline_def); -} - -template -Vector define_vector_explicit_fn( - FusionDefinition& self, - ITERABLE& values, - PrimDataType dtype = DataType::Int) { - return define_vector_fn( - self, values, /*inline_def=*/false, /*shape_check=*/true); -} - -template -Vector SequenceAsVector( - ShapeType shape, - FusionDefinition& fd, - bool shape_check = true) { - static_assert( - std::is_same_v || - std::is_same_v || - std::is_same_v); - if constexpr (std::is_same_v) { - return shape; - } else { - // It's important to call define_vector_fn in the if-else branch. - // - // ``` - // if constexpr (std::is_same_v) { - // return shape; - // } - // return define_vector_fn(fd, shape); - // ``` - // would not work because the compiler would try to instantiate - // define_vector_fn and fail. - return define_vector_fn( - fd, shape, /*inline_def=*/true, /*shape_check=*/shape_check); - } -} - -template -Tensor broadcast_in_dim_fn( - FusionDefinition::Operators& op, - Tensor arg, - ShapeType generic_output_shape, - std::vector& broadcast_dims) { - FUSER_PERF_SCOPE("Operators.broadcast_in_dim"); - FusionDefinition* fd = op.fusion_definition; - NVF_CHECK(op.validUse(), "Attempting to add to a completed definition!"); - Vector output_shape = SequenceAsVector(generic_output_shape, *fd); - NVF_CHECK( - output_shape.size >= broadcast_dims.size(), - "broadcast_dims vector size is too big for output shape!"); - - Tensor output = fd->defineTensor(output_shape.size); - fd->defineRecord(new BroadcastInDimOpRecord( - {fd->recordingState(arg()), fd->recordingState(output_shape())}, - {fd->recordingState(output())}, - output_shape.size, - broadcast_dims)); - return output; -} - -template -Tensor expand_fn( - FusionDefinition::Operators& op, - Tensor arg, - ShapeType generic_output_shape) { - FUSER_PERF_SCOPE("Operators.expand"); - FusionDefinition* fd = op.fusion_definition; - NVF_CHECK(op.validUse(), "Attempting to add to a completed definition!"); - Vector output_shape = SequenceAsVector(generic_output_shape, *fd); - - Tensor output = fd->defineTensor(output_shape.size); - fd->defineRecord(new ExpandOpRecord( - {fd->recordingState(arg()), fd->recordingState(output_shape())}, - {fd->recordingState(output())})); - return output; -} - -template -Tensor full_op_fn( - FusionDefinition::Operators& self, - ShapeType generic_output_shape, - Scalar fill_value, - PrimDataType dtype) { - NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Vector output_shape = SequenceAsVector(generic_output_shape, *fd); - Tensor output = fd->defineTensor(output_shape.size); - fd->defineRecord(new FullOpRecord( - {fd->recordingState(output_shape()), fd->recordingState(fill_value())}, - {fd->recordingState(output())}, - dtype)); - return output; -} - -template -Tensor reshape_fn( - FusionDefinition::Operators& self, - Tensor arg, - ShapeType generic_new_shape) { - NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); - - FusionDefinition* fd = self.fusion_definition; - Vector new_shape = SequenceAsVector(generic_new_shape, *fd); - - Tensor output = fd->defineTensor(new_shape.size); - fd->defineRecord(new ReshapeOpRecord( - {fd->recordingState(arg()), fd->recordingState(new_shape())}, - {fd->recordingState(output())})); - return output; -} - -template -Tensor pad_fn( - FusionDefinition::Operators& self, - Tensor arg, - ShapeType generic_pad_widths, - std::optional value) { - NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); - - FusionDefinition* fd = self.fusion_definition; - Vector pad_widths = - SequenceAsVector(generic_pad_widths, *fd, /*shape_check=*/false); - - NVF_CHECK( - pad_widths.size <= 2 * arg.dims, - "Number of pad widths must be at most twice the input dimension"); - - State value_state = value.has_value() ? fd->recordingState(value.value()()) - : State(0, serde::StateType::None); - - Tensor output = fd->defineTensor(arg.dims); - fd->defineRecord(new PadOpRecord( - {fd->recordingState(arg()), - fd->recordingState(pad_widths()), - value_state}, - {fd->recordingState(output())})); - return output; -} - -template -Tensor random_dist_op_fn( - FusionDefinition::Operators& self, - Scalar arg1, - Scalar arg2, - ShapeType generic_new_shape, - std::optional rng_seed, - std::optional rng_offset, - PrimDataType dtype) { - static_assert( - (RType == serde::RecordType::NormalDistOp) || - (RType == serde::RecordType::UniformDistOp)); - NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); - NVF_CHECK( - isFloatingPointType(dtype), - "Random distributions only create floating point types! ", - dtype); - FusionDefinition* fd = self.fusion_definition; - Vector new_shape = SequenceAsVector(generic_new_shape, *fd); - - Tensor output = fd->defineTensor(new_shape.size); - std::vector arg_states = { - fd->recordingState(arg1()), - fd->recordingState(arg2()), - fd->recordingState(new_shape()), - }; - if (rng_seed.has_value() && rng_offset.has_value()) { - arg_states.push_back(fd->recordingState(rng_seed.value()())); - arg_states.push_back(fd->recordingState(rng_offset.value()())); - } else { - NVF_CHECK( - !rng_seed.has_value() && !rng_offset.has_value(), - "rng_seed and rng_offset must be provided together!"); - } - - fd->defineRecord(new RandomDistOpRecord( - arg_states, {fd->recordingState(output())}, dtype)); - - return output; -} - -template -Tensor slice_fn( - FusionDefinition::Operators& self, - Tensor arg, - ShapeType start, - ShapeType end, - std::optional strides, - bool manual_normalization) { - NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); - - FusionDefinition* fd = self.fusion_definition; - Vector new_start = SequenceAsVector(start, *fd, /*shape_check=*/false); - Vector new_end = SequenceAsVector(end, *fd, /*shape_check=*/false); - size_t stride_index = 0; - - if (strides.has_value()) { - Vector new_stride = - SequenceAsVector(strides.value(), *fd, /*shape_check=*/false); - NVF_CHECK( - new_start.size == new_stride.size, - "Slice start_indices and strides don't match! Start Indices: ", - new_start.size, - " Strides: ", - new_stride.size); - stride_index = new_stride(); - } else { - // set stride with default value; - std::vector stride_vec; - stride_vec.reserve(new_start.size); - // Note: we cannot re-use the same ScalarRecord, otherwise, serialized - // python program uses `define_vector`, which would create multiple - // ScalarRecord, causing a cache miss. - for (auto i : arange(new_start.size)) { - (void)i; // Supress unused variable warning - Scalar out = fd->defineScalar(); - fd->defineRecord(new ScalarRecord( - {fd->recordingState(out())}, - 1, - DataType::Int, - /*inline_def=*/true)); - stride_vec.push_back(out); - } - // Cannot inline definition with `Vector` here, since - // `FusionDefinition.ops.slice` expects start/end/stride to have the same - // type. - Vector default_stride = define_vector_base_fn( - *fd, stride_vec, !std::is_same_v); - stride_index = default_stride(); - } - - NVF_CHECK( - arg.dims == new_start.size, - "Number of tensor dimensions does not match slice dimensions! " - "Tensor-dims: ", - arg.dims, - " Slice-dims: ", - new_start.size); - NVF_CHECK( - new_start.size == new_end.size, - "Slice indexing attribute dimensions don't match! Start Indices: ", - new_start.size, - " End Indices: ", - new_end.size); - - Tensor output = fd->defineTensor(arg.dims); - fd->defineRecord(new SliceOpRecord( - {fd->recordingState(arg()), - fd->recordingState(new_start()), - fd->recordingState(new_end()), - fd->recordingState(stride_index)}, - {fd->recordingState(output())}, - manual_normalization)); - return output; -} - -} // namespace - -std::vector> computeContiguity( - const std::vector& sizes, - const std::vector& strides) { - NVF_CHECK( - sizes.size() == strides.size(), - "compute_contiguity: Sizes and strides must have the same number of " - "dimensions"); - // Not a broadcast means neither the stride == 0 (size can be non-zero) - // or the size == 1 that each can indicate a broadcast - auto not_broadcast = [&](auto i) { return strides[i] != 0 && sizes[i] != 1; }; - // Contiguity defaults to vector of all None's - std::vector> contiguity(sizes.size(), std::nullopt); - if (contiguity.empty()) { // zero-dim tensor - return contiguity; - } - int64_t last = (int64_t)sizes.size() - 1; // inner most dimension - // Contiguity normallly is determined by the current dimension and one - // dimension to the right. The innermost dimension, that is not broadcasted, - // does not have any dimension to it's right and needs to be specially marked - // contiguous. - for (; last >= 0; --last) { - if (not_broadcast(last)) { - contiguity[last] = (strides.at(last) == 1); - break; - } - } - // Dimensions are marked contiguous by inspecting the current dimension and - // one to the right towards the inner dimension while skipping over broadcast - // dimensions. - for (int64_t i = 0; i < last;) { - if (not_broadcast(i)) { - auto l = i++; - for (; i <= last; i++) { - if (not_broadcast(i)) { - break; - } - } - contiguity[l] = (strides[l] == strides[i] * sizes[i]); - } else { - i++; - } - } - return contiguity; -} - -// Copy definition from a FusionDefinion's pre-scheduled CPP fusion to a blank -// FusionDefinition. Primarily for testing purposes to check that the -// translation from CPP fusion is correct. -void clone(FusionDefinition& from, FusionDefinition& to) { - NVF_CHECK(from.completed(), "FusionDefinition definition does not exist!"); - NVF_ERROR( - !to.completed(), "Expected an incomplete definition before translation."); - translate(from.preschedFusion(), &to); -} - -namespace { -void defineHeuristicParamBindings(py::module& nvfuser) { - py::class_ launch_parameters( - nvfuser, "LaunchParams", py::module_local()); - launch_parameters.def( - py::init()); - launch_parameters.def( - "__repr__", [](const LaunchParams& self) { return self.toString(); }); - launch_parameters.def_property( - "bdimx", - [](LaunchParams& self) { return self.bdimx(); }, - [](LaunchParams& self, int64_t val) { - self.bindUnsafe(val, ParallelType::TIDx); - }); - launch_parameters.def_property( - "bdimy", - [](LaunchParams& self) { return self.bdimy(); }, - [](LaunchParams& self, int64_t val) { - self.bindUnsafe(val, ParallelType::TIDy); - }); - launch_parameters.def_property( - "bdimz", - [](LaunchParams& self) { return self.bdimz(); }, - [](LaunchParams& self, int64_t val) { - self.bindUnsafe(val, ParallelType::TIDz); - }); - launch_parameters.def_property( - "gdimx", - [](LaunchParams& self) { return self.gdimx(); }, - [](LaunchParams& self, int64_t val) { - self.bindUnsafe(val, ParallelType::BIDx); - }); - launch_parameters.def_property( - "gdimy", - [](LaunchParams& self) { return self.gdimy(); }, - [](LaunchParams& self, int64_t val) { - self.bindUnsafe(val, ParallelType::BIDy); - }); - launch_parameters.def_property( - "gdimz", - [](LaunchParams& self) { return self.gdimz(); }, - [](LaunchParams& self, int64_t val) { - self.bindUnsafe(val, ParallelType::BIDz); - }); - -#define DEFINECLASS(type) py::class_(nvfuser, #type, py::module_local()) - -#define TOSTRINGTOPLEVEL(type) \ - def("__repr__", [](const type& self) { return toString(self); }) -#define TOSTRINGMETHOD(type) \ - def("__repr__", [](const type& self) { return self.toString(); }) - -#define PARAM(internal_type, name) def_readwrite(#name, &internal_type::name) - - DEFINECLASS(CompileParams) - .PARAM(CompileParams, index_type) - .PARAM(CompileParams, maxrregcount) - .PARAM(CompileParams, enable_magic_zero) - .PARAM(CompileParams, enable_ptxas_verbose) - .TOSTRINGMETHOD(CompileParams); - - DEFINECLASS(GemmTile) - .def(py::init()) - .PARAM(GemmTile, m) - .PARAM(GemmTile, n) - .PARAM(GemmTile, k) - .TOSTRINGTOPLEVEL(GemmTile); - - DEFINECLASS(MatMulTileOptions) - .def(py::init()) - .PARAM(MatMulTileOptions, cta_tile) - .PARAM(MatMulTileOptions, warp_tile) - .TOSTRINGTOPLEVEL(MatMulTileOptions); - - py::class_( - nvfuser, "CircularBufferOptions") - .def(py::init()) - .PARAM(MatmulParams::CircularBufferOptions, circular_buffer_smem_read) - .PARAM(MatmulParams::CircularBufferOptions, circular_buffer_smem_write) - .PARAM(MatmulParams::CircularBufferOptions, smem_circular_buffer_stage) - .PARAM( - MatmulParams::CircularBufferOptions, - smem_circular_buffer_prefetch_gap) - .TOSTRINGMETHOD(MatmulParams::CircularBufferOptions); - - py::class_( - nvfuser, "SupportedVectorization") - .def(py::init()) - .PARAM(MatmulParams::SupportedVectorization, a) - .PARAM(MatmulParams::SupportedVectorization, b) - .PARAM(MatmulParams::SupportedVectorization, epilogue) - .TOSTRINGMETHOD(MatmulParams::SupportedVectorization); - - py::enum_( - nvfuser, "MatmulTileRasterizationOrder") - .value("column_major", MatmulParams::TileRasterizationOrder::ColumnMajor) - .value("row_major", MatmulParams::TileRasterizationOrder::RowMajor); - - py::class_(nvfuser, "ClusterDims") - .def(py::init()) - .PARAM(MatmulParams::ClusterDims, m) - .PARAM(MatmulParams::ClusterDims, n) - .TOSTRINGMETHOD(MatmulParams::ClusterDims); - - py::enum_(nvfuser, "MmaMacroArch") - .value("no_mma", MmaMacroEncode::Arch::NoMma) - .value("volta", MmaMacroEncode::Arch::Volta) - .value("turing", MmaMacroEncode::Arch::Turing) - .value("ampere", MmaMacroEncode::Arch::Ampere) - .value("hopper", MmaMacroEncode::Arch::Hopper); - - DEFINECLASS(MmaMacroEncode) - .def(py::init()) - .def("mma_macro", &MmaMacroEncode::operator MmaMacro) - .PARAM(MmaMacroEncode, arch) - .PARAM(MmaMacroEncode, m) - .PARAM(MmaMacroEncode, n) - .PARAM(MmaMacroEncode, k); - - // NOTE: MmaMacro is a uint64_t. To modify it, we convert to and from - // MmaMacroEncode -#define MMAMACROPROP(prop, type) \ - def_property( \ - #prop, \ - [](const MmaMacro& self) { return MmaMacroEncode(self).prop; }, \ - [](MmaMacro& self, type x) { \ - auto enc = MmaMacroEncode(self); \ - enc.prop = x; \ - self = enc; \ - }) - DEFINECLASS(MmaMacro) - .MMAMACROPROP(arch, MmaMacroEncode::Arch) - .MMAMACROPROP(m, uint16_t) - .MMAMACROPROP(n, uint16_t) - .MMAMACROPROP(k, uint16_t) - .TOSTRINGTOPLEVEL(MmaMacro); -#undef MMAMACROPROP - - py::enum_(nvfuser, "MatmulTilingStrategy") - .value("one_tile_per_cta", MatmulParams::TilingStrategy::OneTilePerCTA) - .value( - "distribute_tiles_across_sms", - MatmulParams::TilingStrategy::DistributeTilesAcrossSMs) - .value( - "distribute_stages_across_sms", - MatmulParams::TilingStrategy::DistributeStagesAcrossSMs); - py::enum_( - nvfuser, "MatmulBufferingLoopLevel") - .value("cta_tiles", MatmulParams::BufferingLoopLevel::CTATiles) - .value("warp_tiles", MatmulParams::BufferingLoopLevel::WarpTiles); - py::enum_( - nvfuser, "MatmulCircularBufferingStrategy") - .value("pipelined", MatmulParams::CircularBufferingStrategy::Pipelined) - .value( - "warp_specialized", - MatmulParams::CircularBufferingStrategy::WarpSpecialized); - - // Base class for scheduler parameters - DEFINECLASS(HeuristicParams) - .TOSTRINGMETHOD(HeuristicParams) - .PARAM(HeuristicParams, lparams) - .PARAM(HeuristicParams, cparams); - -#define INITHEURISTICPARAMS(internal_type) \ - py::class_(nvfuser, #internal_type) \ - .def(py::init()) \ - .def("__repr__", [](const internal_type& self) { \ - return self.toString(); \ - }) - - // Pointwise scheduler parameters - INITHEURISTICPARAMS(PointwiseParams) - .PARAM(PointwiseParams, break_point) - .PARAM(PointwiseParams, split_block) - .PARAM(PointwiseParams, split_grid_y_dim) - .PARAM(PointwiseParams, flip_grid_binding) - .PARAM(PointwiseParams, vectorization_factor) - .PARAM(PointwiseParams, unroll_factor_inner) - .PARAM(PointwiseParams, unroll_factor_outer); - - // Reduction scheduler parameters - INITHEURISTICPARAMS(ReductionParams) - .PARAM(ReductionParams, fastest_dim) - .PARAM(ReductionParams, persistent_kernel) - .PARAM(ReductionParams, project_persistent_buffers) - .PARAM(ReductionParams, schedule_3d) - .PARAM(ReductionParams, flip_grid) - .PARAM(ReductionParams, cross_block_inner_reduction) - .PARAM(ReductionParams, cross_grid_inner_reduction) - .PARAM(ReductionParams, unroll_factor_inner_reduction) - .PARAM(ReductionParams, unroll_factor_top_of_vectorization) - .PARAM(ReductionParams, vectorize_inner_reduction) - .PARAM(ReductionParams, split_grid_dim_inner_reduction) - .PARAM(ReductionParams, pad_inner_reduction_to_warp) - .PARAM(ReductionParams, batches_per_block_inner_reduction) - .PARAM(ReductionParams, block_dim_inner_reduction) - .PARAM(ReductionParams, grid_dim_inner_reduction) - .PARAM(ReductionParams, multiple_reds_per_blk) - .PARAM(ReductionParams, unroll_factor_iter_dom) - .PARAM(ReductionParams, vectorize_iter_dom) - .PARAM(ReductionParams, split_grid_dim_iter_dom_inner) - .PARAM(ReductionParams, split_grid_dim_iter_dom_outer) - .PARAM(ReductionParams, block_dim_iter_dom) - .PARAM(ReductionParams, grid_dim_iter_dom) - .PARAM(ReductionParams, cross_block_outer_reduction) - .PARAM(ReductionParams, cross_grid_outer_reduction) - .PARAM(ReductionParams, batches_per_block_outer_reduction) - .PARAM(ReductionParams, unroll_factor_outer_reduction) - .PARAM(ReductionParams, block_dim_outer_reduction) - .PARAM(ReductionParams, grid_dim_outer_reduction) - .PARAM(ReductionParams, compute_persistent_buffer_with_first_consumer) - .PARAM(ReductionParams, static_bdimx) - .PARAM(ReductionParams, static_bdimy) - .PARAM(ReductionParams, combined_inner_outer) - .PARAM(ReductionParams, tidx_for_outer_reduction) - .PARAM(ReductionParams, pad_outer_reduction_to_warp) - .PARAM(ReductionParams, combined_split_grid_inner_dim) - .PARAM(ReductionParams, vectorization_factor_outer) - .PARAM(ReductionParams, vectorization_factor_tmp_gmem_write) - .PARAM(ReductionParams, block_dim_inner_reduction_extra); - - // Matmul scheduler parameters - INITHEURISTICPARAMS(MatmulParams) - .PARAM(MatmulParams, tile_sizes) - .PARAM(MatmulParams, circular_buffer_options) - .PARAM(MatmulParams, supported_vec_size) - .PARAM(MatmulParams, async_gmem_load_operands) - .PARAM(MatmulParams, grid_traversal_factor) - .PARAM(MatmulParams, use_smem_epilogue) - .PARAM(MatmulParams, use_ldst_matrix) - .PARAM(MatmulParams, promote_prologue_smem_reuse) - .PARAM(MatmulParams, splitk_factor) - .PARAM(MatmulParams, tiling_strategy) - .PARAM(MatmulParams, buffering_loop_level) - .PARAM(MatmulParams, circular_buffering_strategy) - .PARAM(MatmulParams, cta_order) - .PARAM(MatmulParams, cluster_dims) - .PARAM(MatmulParams, mma_macro); - -#undef PARAM -#undef INITPARAMS -} - -} // namespace - -void initNvFuserPythonBindings(PyObject* module) { - auto nvfuser = py::handle(module).cast(); - - nvfuser.def("clone", clone); - - //! DataTypes supported by nvFuser in the FusionDefinition - py::enum_(nvfuser, "DataType", py::module_local()) - .value("Double", DataType::Double) - .value("Float", DataType::Float) - .value("Half", DataType::Half) - .value("Int", DataType::Int) - .value("Int32", DataType::Int32) - .value("UInt64", DataType::UInt64) - .value("Bool", DataType::Bool) - .value("BFloat16", DataType::BFloat16) - .value("Float8_e4m3fn", DataType::Float8_e4m3fn) - .value("Float8_e5m2", DataType::Float8_e5m2) - .value("Float8_e8m0fnu", DataType::Float8_e8m0fnu) - .value("Float4_e2m1fn", DataType::Float4_e2m1fn) - .value("Float4_e2m1fn_x2", DataType::Float4_e2m1fn_x2) - .value("ComplexFloat", DataType::ComplexFloat) - .value("ComplexDouble", DataType::ComplexDouble) - .value("Null", DataType::Null); - - //! ParallelType used for scheduling - py::enum_(nvfuser, "ParallelType", py::module_local()) - .value("mesh_x", ParallelType::DIDx) - .value("grid_x", ParallelType::BIDx) - .value("grid_y", ParallelType::BIDy) - .value("grid_z", ParallelType::BIDz) - .value("block_x", ParallelType::TIDx) - .value("block_y", ParallelType::TIDy) - .value("block_z", ParallelType::TIDz) - .value("mma", ParallelType::Mma) - .value("serial", ParallelType::Serial) - .value("tma", ParallelType::Bulk) - .value("unroll", ParallelType::Unroll) - .value("unswitch", ParallelType::Unswitch) - .value("vectorize", ParallelType::Vectorize) - .value("stream", ParallelType::Stream); - - //! LoadStoreOpType used for scheduling - py::enum_(nvfuser, "LoadStoreOpType") - .value("set", LoadStoreOpType::Set) - .value("load_matrix", LoadStoreOpType::LdMatrix) - .value("cp_async", LoadStoreOpType::CpAsync) - .value("tma", LoadStoreOpType::CpAsyncBulkTensorTile); - - //! CacheOp used for scheduling - py::enum_(nvfuser, "CacheOp") - .value("unspecified", CacheOp::Unspecified) - .value("all_levels", CacheOp::AllLevels) - .value("streaming", CacheOp::Streaming) - .value("global", CacheOp::Global); - - //! MemoryType used for scheduling - py::enum_(nvfuser, "MemoryType") - .value("local", MemoryType::Local) - .value("shared", MemoryType::Shared) - .value("global", MemoryType::Global); - - //! Scheduler Type for scheduling - py::enum_(nvfuser, "SchedulerType", py::module_local()) - .value("none", SchedulerType::None) - .value("no_op", SchedulerType::NoOp) - .value("pointwise", SchedulerType::PointWise) - .value("matmul", SchedulerType::Matmul) - .value("reduction", SchedulerType::Reduction) - .value("inner_persistent", SchedulerType::InnerPersistent) - .value("inner_outer_persistent", SchedulerType::InnerOuterPersistent) - .value("outer_persistent", SchedulerType::OuterPersistent) - .value("transpose", SchedulerType::Transpose) - .value("expr_eval", SchedulerType::ExprEval) - .value("resize", SchedulerType::Resize); - - py::enum_( - nvfuser, "CommunicatorBackend", py::module_local()) - .value("nccl", CommunicatorBackend::kNccl) - .value("ucc", CommunicatorBackend::kUcc); - - nvfuser.def("compute_contiguity", computeContiguity); - nvfuser.def("compute_tensor_descriptor", computeTensorDescriptor); - nvfuser.def("serialize", serialize); - - //! Binding the FusionCache that holds a cache of Fusions - //! This is only bound to provide an interface to get the number of fusions - //! that are cached. - py::class_ fusion_cache(nvfuser, "FusionCache"); - fusion_cache - .def_static( - "get", - &FusionCache::get, - py::arg("max_fusions") = int(16384), - py::arg("selected_device") = py::none(), - py::arg("load_from_default_workspace") = true, - py::return_value_policy::reference) - .def("num_fusions", &FusionCache::numFusions) - .def_static( - "reset", &FusionCache::reset, py::return_value_policy::reference) - .def( - "serialize", - [](FusionCache& self, std::string filename) { - FUSER_PERF_SCOPE("FusionCache.serialize (string)"); - self.serialize(filename); - }, - py::arg("filename")) - .def( - "deserialize", - [](FusionCache& self, std::string filename) { - FUSER_PERF_SCOPE("FusionCache.deserialize (string)"); - self.deserialize(filename); - }, - py::arg("filename")) - .def( - "__repr__", - [](FusionCache& self) { - std::stringstream ss; - self.print(ss); - return ss.str(); - }) - .def("stats", [](FusionCache& self) { - std::stringstream ss; - self.stats(ss); - return ss.str(); - }); - - defineHeuristicParamBindings(nvfuser); - - py::class_ hyperparameters( - nvfuser, "SchedulerHyperParameters"); - hyperparameters.def(py::init()); - hyperparameters.def_property( - "vectorize_factor", - [](scheduler_utils::SchedulerHyperParameters& self) { - return self.vectorize_factor; - }, - [](scheduler_utils::SchedulerHyperParameters& self, - int64_t vectorize_factor_) { - self.vectorize_factor = vectorize_factor_; - }); - hyperparameters.def_property( - "unroll_factor", - [](scheduler_utils::SchedulerHyperParameters& self) { - return self.unroll_factor; - }, - [](scheduler_utils::SchedulerHyperParameters& self, - int64_t unroll_factor_) { self.unroll_factor = unroll_factor_; }); - hyperparameters.def_property( - "threads_per_block_min", - [](scheduler_utils::SchedulerHyperParameters& self) { - return self.threads_per_block_min; - }, - [](scheduler_utils::SchedulerHyperParameters& self, - int64_t threads_per_block_min_) { - self.threads_per_block_min = threads_per_block_min_; - }); - hyperparameters.def_property( - "threads_per_block_max", - [](scheduler_utils::SchedulerHyperParameters& self) { - return self.threads_per_block_max; - }, - [](scheduler_utils::SchedulerHyperParameters& self, - int64_t threads_per_block_max_) { - self.threads_per_block_max = threads_per_block_max_; - }); - hyperparameters.def_property( - "is_warp_specialized", - [](scheduler_utils::SchedulerHyperParameters& self) { - return self.is_warp_specialized; - }, - [](scheduler_utils::SchedulerHyperParameters& self, - int64_t is_warp_specialized_) { - self.is_warp_specialized = is_warp_specialized_; - }); - //! KernelProfiles are encapsulated in FusionProfiles where each KP - //! is associated with a segment. - py::class_ kernel_prof(nvfuser, "KernelProfile"); - kernel_prof.def_property_readonly( - "name", [](KernelProfile& self) { return self.name; }); - kernel_prof.def_property_readonly( - "segment_id", [](KernelProfile& self) { return self.segment_id; }); - kernel_prof.def_property_readonly( - "device", [](KernelProfile& self) { return self.device; }); - kernel_prof.def_property_readonly( - "stream", [](KernelProfile& self) { return self.stream; }); - kernel_prof.def_property_readonly("correlation_id", [](KernelProfile& self) { - return self.correlation_id; - }); - kernel_prof.def_property_readonly("compile_time_ms", [](KernelProfile& self) { - return self.compile_time_ms; - }); - kernel_prof.def_property_readonly( - "time_ms", [](KernelProfile& self) { return self.time_ms; }); - kernel_prof.def_property_readonly( - "effective_bandwidth_gbs", - [](KernelProfile& self) { return self.effective_bandwidth_gbs; }); - kernel_prof.def_property_readonly( - "percentage_peak_bandwidth", - [](KernelProfile& self) { return self.percentage_peak_bandwidth; }); - kernel_prof.def_property_readonly( - "grid_str", [](KernelProfile& self) { return self.grid_str; }); - kernel_prof.def_property_readonly( - "block_str", [](KernelProfile& self) { return self.block_str; }); - kernel_prof.def_property_readonly( - "cluster_str", [](KernelProfile& self) { return self.cluster_str; }); - kernel_prof.def_property_readonly("shared_mem_str", [](KernelProfile& self) { - return self.shared_mem_str; - }); - kernel_prof.def_property_readonly( - "registers", [](KernelProfile& self) { return self.registers; }); - kernel_prof.def_property_readonly( - "input_bytes", [](KernelProfile& self) { return self.input_bytes; }); - kernel_prof.def_property_readonly( - "output_bytes", [](KernelProfile& self) { return self.output_bytes; }); - kernel_prof.def_property_readonly( - "scheduler", [](KernelProfile& self) { return self.scheduler; }); - - //! A fusion profile is generated for FusionDefinition. - py::class_ fusion_prof(nvfuser, "FusionProfile"); - fusion_prof.def_property_readonly( - "verbose", [](FusionProfile& self) { return self.verbose; }); - fusion_prof.def_property_readonly( - "fusion_id", [](FusionProfile& self) { return self.fusion_id; }); - fusion_prof.def_property_readonly( - "segments", [](FusionProfile& self) { return self.segments; }); - fusion_prof.def_property_readonly( - "cuda_evt_time_ms", - [](FusionProfile& self) { return self.cuda_evt_time_ms; }); - fusion_prof.def_property_readonly( - "host_time_ms", [](FusionProfile& self) { return self.host_time_ms; }); - fusion_prof.def_property_readonly("compile_time_ms", [](FusionProfile& self) { - return self.compile_time_ms; - }); - fusion_prof.def_property_readonly("kernel_time_ms", [](FusionProfile& self) { - return self.kernel_time_ms; - }); - fusion_prof.def_property_readonly( - "effective_bandwidth_gbs", - [](FusionProfile& self) { return self.effective_bandwidth_gbs; }); - fusion_prof.def_property_readonly( - "percentage_peak_bandwith", - [](FusionProfile& self) { return self.percentage_peak_bandwidth; }); - fusion_prof.def_property_readonly( - "input_bytes", [](FusionProfile& self) { return self.input_bytes; }); - fusion_prof.def_property_readonly( - "output_bytes", [](FusionProfile& self) { return self.output_bytes; }); - fusion_prof.def_property_readonly("kernel_profiles", [](FusionProfile& self) { - return self.kernel_profiles; - }); - - //! These are the FusionDefinition supported object types that are either - //! defined as inputs or the output of an operation. - py::class_ tensor_class(nvfuser, "Tensor"); - tensor_class.def("__repr__", [](Tensor& self) { - std::stringstream ss; - ss << "Tensor(index=" << self.index << ", ndim=" << self.dims << ")"; - return ss.str(); - }); - tensor_class.def_property_readonly( - "ndim", - [](Tensor& self) { return self.dims; }, - "Returns the rank of the tensor."); - tensor_class.def_property_readonly( - "index", - [](Tensor& self) { return self.index; }, - "Returns the index of the tensor as in " - "FusionDefinition.sched.tensors()."); - tensor_class.def("_get_fusion_definition", [](Tensor& self) { - return self.fusion_definition; - }); - tensor_class.def(pybind11::self == pybind11::self); - tensor_class.def(pybind11::self != pybind11::self); - - py::class_ scalar_class(nvfuser, "Scalar"); - scalar_class.def("__repr__", [](Scalar& self) { - std::stringstream ss; - ss << "Scalar(index=" << self.index << ")"; - return ss.str(); - }); - scalar_class.def(pybind11::self == pybind11::self); - scalar_class.def(pybind11::self != pybind11::self); - - py::class_ vector_class(nvfuser, "Vector"); - vector_class.def("__repr__", [](Vector& self) { - std::stringstream ss; - ss << "Vector(index=" << self.index << ", size=" << self.size << ")"; - return ss.str(); - }); - vector_class.def_property_readonly( - "size", [](Vector& self) { return self.size; }); - vector_class.def(pybind11::self == pybind11::self); - vector_class.def(pybind11::self != pybind11::self); - - //! The FusionDefinition is a context manager in Python where the user will - //! define the set the operations and connections between operations for - //! nvFuser to create. - py::class_ fusion_def(nvfuser, "_FusionDefinition"); - fusion_def - .def( - py::init, size_t, bool, CommunicatorBackend>(), - py::arg("id") = py::none(), - py::arg("max_length") = int(1024), - py::arg("use_multidevice_executor") = false, - py::arg("backend_type") = CommunicatorBackend::kNccl) - .def_readwrite("ops", &FusionDefinition::ops) - .def_readwrite("sched", &FusionDefinition::sched) - .def( - "_setup_definition", - [](FusionDefinition& self) -> FusionDefinition* { - // Instrumentation to mark the beginning of a FusionDefinition - inst::Trace::instance()->beginEvent("FusionDefinition Definition"); - return self.setupDefinition(); - }) - .def( - "_finalize_definition", - [](FusionDefinition& self) { - self.finalizeDefinition(); - // Mark the end of a definition - inst::Trace::instance()->endEvent("FusionDefinition Definition"); - }) - .def( - "_exist_schedule", - [](FusionDefinition& self, const py::iterable& iter) { - KernelArgumentHolder args; - for (py::handle obj : iter) { - args.push(torch::jit::toIValue(obj, c10::AnyType::get())); - } - return self.existSchedule(args); - }) - .def( - "_setup_schedule", - [](FusionDefinition& self, - const py::iterable& iter, - bool overwrite_existing_schedule) { - // Instrumentation to mark the beginning of a schedule - inst::Trace::instance()->beginEvent("FusionDefinition Schedule"); - KernelArgumentHolder args; - for (py::handle obj : iter) { - args.push(torch::jit::toIValue(obj, c10::AnyType::get())); - } - self.setupSchedule(args, overwrite_existing_schedule); - }, - py::arg("inputs"), - py::kw_only(), - py::arg("overwrite_existing_schedule") = false) - .def( - "_finalize_schedule", - [](FusionDefinition& self, const py::iterable& iter) { - KernelArgumentHolder args; - for (py::handle obj : iter) { - args.push(torch::jit::toIValue(obj, c10::AnyType::get())); - } - self.finalizeSchedule(args); - // Mark the end of a schedule - inst::Trace::instance()->endEvent(nullptr); - }) - .def( - "_setup_multidevice_schedule", - [](FusionDefinition& self) { self.setupMultideviceSchedule(); }) - .def( - "_finalize_multidevice_schedule", - [](FusionDefinition& self) { self.finalizeMultideviceSchedule(); }) - .def("inputs", [](FusionDefinition& self) { return self.inputs(); }) - .def("outputs", [](FusionDefinition& self) { return self.outputs(); }) - .def("extents", [](FusionDefinition& self) { return self.extents(); }) - .def( - "_setup_segmentation", - [](FusionDefinition& self, const py::iterable& iter) { - // Instrumentation to mark the beginning of segmentation - inst::Trace::instance()->beginEvent( - "FusionDefinition Segmentation"); - KernelArgumentHolder args; - for (py::handle obj : iter) { - // Allows for a Vector of Sizes to be inputed as a list/tuple - if (py::isinstance(obj) || - py::isinstance(obj)) { - for (py::handle item : obj) { - args.push(torch::jit::toIValue(item, c10::AnyType::get())); - } - } else { - args.push(torch::jit::toIValue(obj, c10::AnyType::get())); - } - } - return self.setupSegmentation(args); - }) - .def( - "_build_segment", - [](FusionDefinition& self, - FusionDefinition& other, - int64_t segment_id) { - return self.buildSegment(other, segment_id); - }) - .def( - "_finalize_segmentation", - [](FusionDefinition& self) { - self.finalizeSegmentation(); - // Mark the end of segmentation - inst::Trace::instance()->endEvent(nullptr); - }) - .def("inputs", [](FusionDefinition& self) { return self.inputs(); }) - .def("outputs", [](FusionDefinition& self) { return self.outputs(); }) - .def("extents", [](FusionDefinition& self) { return self.extents(); }) - .def( - "__repr__", - [](FusionDefinition& self) { - std::stringstream ss; - self.print(ss); - return ss.str(); - }) - .def( - "_execute", - [](FusionDefinition& self, - const py::iterable& iter, - std::optional device, - bool override_user_schedule, - bool capture_debug_output, - bool profile, - std::vector _enable_options, - std::vector _disable_options) - -> std::pair, std::vector> { - KernelArgumentHolder ins; - for (py::handle obj : iter) { - // Allows for a Vector of Sizes to be inputed as a list/tuple - if (py::isinstance(obj) || - py::isinstance(obj)) { - for (py::handle item : obj) { - ins.push(torch::jit::toIValue(item, c10::AnyType::get())); - } - } else { - ins.push(torch::jit::toIValue(obj, c10::AnyType::get())); - } - } - std::optional int8_device = std::nullopt; - if (device.has_value()) { - NVF_CHECK(device.value() < 256, "Maximum device index is 255"); - int8_device = (int8_t)device.value(); - } - auto&& [outs, out_shardings] = self.execute( - ins, - int8_device, - override_user_schedule, - capture_debug_output, - profile, - _enable_options, - _disable_options); - - std::vector out_tensors; - out_tensors.reserve(outs.size()); - for (const auto& out : outs) { - // Should we append toIValue(out) instead? - out_tensors.push_back(out.as()); - } - return std::make_pair( - std::move(out_tensors), std::move(out_shardings)); - }, - py::arg("inputs"), - py::kw_only(), - py::arg("device") = py::none(), - py::arg("override_user_schedule") = false, - py::arg("capture_debug_output") = false, - py::arg("profile") = false, - py::arg("_enable_options") = py::none(), - py::arg("_disable_options") = py::none(), - py::return_value_policy::reference) - .def_static( - "_profile", - &FusionProfiler::profile, - py::return_value_policy::reference) - .def( - "_debug_output", - [](FusionDefinition& self) { return self.getDebugOutput(); }, - py::return_value_policy::reference) - .def( - "_fusion_ir", - [](FusionDefinition& self) { return self.fusionIr(); }, - py::return_value_policy::reference) - .def( - "_user_schedule_ir", - [](FusionDefinition& self) { return self.userScheduleIr(); }, - py::return_value_policy::reference) - .def( - "_last_cuda_code", - [](FusionDefinition& self, - bool intrinsic_code, - bool override_user_schedule) { - return self.lastCudaCode(intrinsic_code, override_user_schedule); - }, - py::arg("intrinsic_code") = false, - py::arg("override_user_schedule") = false, - py::return_value_policy::reference) - .def( - "_cuda_code_for", - [](FusionDefinition& self, - const py::iterable& iter, - bool intrinsic_code, - bool override_user_schedule) { - KernelArgumentHolder args; - for (py::handle obj : iter) { - args.push(torch::jit::toIValue(obj, c10::AnyType::get())); - } - return self.cudaCodeFor( - args, intrinsic_code, override_user_schedule); - }, - py::arg("inputs"), - py::arg("intrinsic_code") = false, - py::arg("override_user_schedule") = false, - py::return_value_policy::reference) - .def( - "_last_scheduled_fusion_ir", - [](FusionDefinition& self, - bool tensor_transforms, - bool override_user_schedule) { - return self.lastScheduledFusionIr( - tensor_transforms, override_user_schedule); - }, - py::arg("tensor_transforms") = false, - py::arg("override_user_schedule") = false, - py::return_value_policy::reference) - .def( - "_scheduled_fusion_ir_for", - [](FusionDefinition& self, - const py::iterable& iter, - bool tensor_transforms, - bool override_user_schedule) { - KernelArgumentHolder args; - for (py::handle obj : iter) { - args.push(torch::jit::toIValue(obj, c10::AnyType::get())); - } - return self.scheduledFusionIrFor( - args, tensor_transforms, override_user_schedule); - }, - py::arg("inputs"), - py::arg("tensor_transforms") = false, - py::arg("override_user_schedule") = false, - py::return_value_policy::reference) - .def( - "id", - [](FusionDefinition& self) -> std::optional { - return self.id(); - }) - .def( - "add_output", - [](FusionDefinition& self, Scalar output) { - FUSER_PERF_SCOPE("FusionDefinition.add_output (scalar)"); - NVF_CHECK( - !self.completed(), - "Attempting to add to a completed definition!"); - self.defineRecord(new OutputRecord( - {self.recordingState(output())}, serde::RecordType::OutputVal)); - }, - py::arg("output")) - .def( - "add_output", - [](FusionDefinition& self, - Tensor output, - std::optional alias_input = std::nullopt) { - FUSER_PERF_SCOPE("FusionDefinition.add_output (tensor)"); - NVF_CHECK( - !self.completed(), - "Attempting to add to a completed definition!"); - if (alias_input.has_value()) { - self.defineRecord(new OutputRecord( - {self.recordingState(output()), - self.recordingState(alias_input.value()())}, - serde::RecordType::OutputTv)); - } else { - self.defineRecord(new OutputRecord( - {self.recordingState(output())}, - serde::RecordType::OutputTv)); - } - }, - py::arg("output"), - py::arg("alias_input") = py::none()) - .def( - "add_output", - [](FusionDefinition& self, - Tensor output, - std::vector stride_order) { - FUSER_PERF_SCOPE("FusionDefinition.add_output (tensor)"); - NVF_CHECK( - !self.completed(), - "Attempting to add to a completed definition!"); - NVF_CHECK( - stride_order.empty() || output.dims == stride_order.size(), - "stride_order needs to be either empty or the same length of " - "Tensor `output`"); - int64_t duplicate_check = 0; - for (const auto& v : stride_order) { - NVF_CHECK( - v >= 0 && v < (int64_t)stride_order.size(), - "stride_order elements need to be within [0, " - "stride_order.size())"); - duplicate_check |= 1 << v; - } - NVF_CHECK( - duplicate_check == (1 << stride_order.size()) - 1, - "duplicated elements in stride_order detected!"); - self.defineRecord(new OutputRecord( - {self.recordingState(output())}, - serde::RecordType::OutputTv, - stride_order)); - }, - py::arg("output"), - py::arg("stride_order")) - // This version of define_tensor is the canonical version - // that displays the values as they are passed to the IR's - // TensorViewBuilder. - // Each dimension can be of value: - // -1 : Symbolic for Dynamic usage - // 0 : Zero-element - // 1 : Broadcast - // >1 : Static size - // NOTE: A Tensor defined for dynamic shape usage should only - // contain either symbolic(-1) or broadcast(1) defined dimensions. - .def( - "define_tensor", - [](FusionDefinition& self, - const std::vector& shape, - const std::vector>& contiguity, - const PrimDataType dtype = DataType::Float, - const bool is_cpu = false, - const std::vector& stride_order = {}) -> Tensor { - FUSER_PERF_SCOPE( - "FusionDefinition.define_tensor (contiguity as vector)"); - NVF_CHECK( - !self.completed(), - "Attempting to add to a completed definition!"); - - verifyShape(shape); - - Tensor out = self.defineTensor(shape.size()); - self.defineRecord(new TensorRecord( - {self.recordingState(out())}, - shape, - contiguity, - dtype, - is_cpu, - stride_order)); - - return out; - }, - py::arg("shape"), - py::arg("contiguity"), - py::arg("dtype") = DataType::Float, - py::arg("is_cpu") = false, - py::arg("stride_order") = py::list(), - py::return_value_policy::reference) - .def( - "define_tensor", - [](FusionDefinition& self, - const std::vector& shape, - // Contiguity for non-broadcast dimensions. - const bool contiguity = false, - const PrimDataType dtype = DataType::Float, - const bool is_cpu = false, - const std::vector& stride_order = {}) -> Tensor { - FUSER_PERF_SCOPE( - "FusionDefinition.define_tensor (contiguity as bool)"); - NVF_CHECK( - !self.completed(), - "Attempting to add to a completed definition!"); - - verifyShape(shape); - Tensor out = self.defineTensor(shape.size()); - self.defineRecord(new TensorRecord( - {self.recordingState(out())}, - shape, - getContiguityVec(shape, stride_order, contiguity), - dtype, - is_cpu, - stride_order)); - - return out; - }, - py::arg("shape"), - py::arg("contiguity") = false, - py::arg("dtype") = DataType::Float, - py::arg("is_cpu") = false, - py::arg("stride_order") = py::list(), - py::return_value_policy::reference) - .def( - "define_tensor", - [](FusionDefinition& self, - const std::vector& sizes, - const std::vector& strides, - const PrimDataType dtype = DataType::Float, - const bool static_sizes = false, - const bool is_cpu = false) -> Tensor { - FUSER_PERF_SCOPE("FusionDefinition.define_tensor (integration)"); - NVF_CHECK( - !self.completed(), - "Attempting to add to a completed definition!"); - NVF_CHECK( - sizes.size() == strides.size(), - "The number of sizes does not match the number of strides.", - sizes.size(), - strides.size()); - Tensor out = self.defineTensor(sizes.size()); - std::vector> contiguity; - std::vector stride_order; - std::tie(contiguity, stride_order) = - computeTensorDescriptor(sizes, strides); - self.defineRecord(new TensorRecord( - {self.recordingState(out())}, - getTensorViewBuilderSizes(sizes, static_sizes), - contiguity, - dtype, - is_cpu, - stride_order)); - return out; - }, - py::arg("sizes"), - py::arg("strides"), - py::arg("dtype") = DataType::Float, - py::arg("static_sizes") = false, - py::arg("is_cpu") = false, - py::return_value_policy::reference) - .def( - "define_scalar", - [](FusionDefinition& self, - PrimDataType dtype = DataType::Double) -> Scalar { - FUSER_PERF_SCOPE("FusionDefinition.define_scalar (input_specific)"); - NVF_CHECK( - !self.completed(), - "Attempting to add to a completed definition!"); - Scalar out = self.defineScalar(); - self.defineRecord(new ScalarRecord( - {self.recordingState(out())}, std::monostate{}, dtype)); - return out; - }, - py::arg("dtype") = DataType::Double, - py::return_value_policy::reference); - fusion_def.def( - "define_scalar", - [](FusionDefinition& self, - PolymorphicValue::VariantType value, - std::optional dtype) -> Scalar { - FUSER_PERF_SCOPE("FusionDefinition.define_scalar"); - Scalar out = self.defineScalar(); - self.defineRecord( - new ScalarRecord({self.recordingState(out())}, value, dtype)); - return out; - }, - py::arg("value"), - py::arg("dtype") = std::nullopt, - py::return_value_policy::reference); - fusion_def.def( - "define_constant", - [](FusionDefinition& self, - PolymorphicValue::VariantType value, - std::optional dtype) -> Scalar { - FUSER_PERF_SCOPE("FusionDefinition.define_contant"); - TORCH_WARN_ONCE( - "Deprecating define_constant functions in favor of define_scalar " - "for constants."); - Scalar out = self.defineScalar(); - self.defineRecord( - new ScalarRecord({self.recordingState(out())}, value, dtype)); - return out; - }, - py::arg("value"), - py::arg("dtype") = std::nullopt, - py::return_value_policy::reference); - - // This is the input version of define_vector - fusion_def.def( - "define_vector", - [](FusionDefinition& self, size_t size) -> Vector { - std::vector args; - args.reserve(size); - for (size_t i = 0; i < size; ++i) { - Scalar out = self.defineScalar(); - self.defineRecord(new ScalarRecord( - {self.recordingState(out())}, std::monostate{}, DataType::Int)); - args.emplace_back(out); - } - return define_vector_base_fn(self, args); - }, - py::arg("size"), - py::return_value_policy::reference); - // This is the constant version of define_vector when given a vector - // of constant values. - fusion_def.def( - "define_vector", - define_vector_explicit_fn, - py::arg("values"), - py::arg("dtype") = DataType::Int, - py::return_value_policy::reference); - fusion_def.def( - "define_vector", - define_vector_explicit_fn, - py::arg("values"), - py::arg("dtype") = DataType::Int, - py::return_value_policy::reference); - - fusion_def.def( - "getValTolerances", - [](FusionDefinition& self, const py::iterable& input_iter) { - KernelArgumentHolder args; - for (py::handle obj : input_iter) { - args.push(torch::jit::toIValue(obj, c10::AnyType::get())); - } - return self.getValTolerances(args); - }, - py::return_value_policy::reference); - - fusion_def.def( - "validate_with_auto_inferred_outputs", - [](FusionDefinition& self, - const py::iterable& fusion_outputs, - const py::iterable& inputs) { - KernelArgumentHolder fusion_outputs_holder; - for (py::handle obj : fusion_outputs) { - fusion_outputs_holder.push( - torch::jit::toIValue(obj, c10::AnyType::get())); - } - KernelArgumentHolder inputs_holder; - for (py::handle obj : inputs) { - inputs_holder.push(torch::jit::toIValue(obj, c10::AnyType::get())); - } - return self.validate_with_auto_inferred_outputs( - fusion_outputs_holder, inputs_holder); - }, - py::return_value_policy::reference, - R"doc( - Validates the fusion outputs against the inputs with auto-inferred outputs. - - Parameters - ---------- - fusion_outputs : iterable - The outputs of the fusion to validate. - inputs : iterable - The inputs to the fusion. - Example - ------- - >>> fd.validate_with_auto_inferred_outputs(fusion_outputs, inputs) - )doc"); - - //! The Operators class is a nested class of FusionDefinition to allow the - //! user to query the class for the list of operators. - //! - //! Example: - //! help(FusionDefinition.Operators) - //! - //! Additional operators are expected to be defined below as needed. They - //! may require defining a new RecordFunctor child class if they are unique. - py::class_ nvf_ops(fusion_def, "Operators"); - nvf_ops.def(py::init()); - -// ******************** INSERT OP BINDINGS BELOW HERE ******************** -#define OP_PREFIX "Operators." -#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, Tensor input) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(input.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(input())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Unary_TV, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, Scalar input) -> Scalar { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Scalar output = fd->defineScalar(); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(input())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Unary_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_UNARY_OP("abs", abs) - NVFUSER_PYTHON_BINDING_UNARY_OP("acos", acos) - NVFUSER_PYTHON_BINDING_UNARY_OP("acosh", acosh) - NVFUSER_PYTHON_BINDING_UNARY_OP("asin", asin) - NVFUSER_PYTHON_BINDING_UNARY_OP("asinh", asinh) - NVFUSER_PYTHON_BINDING_UNARY_OP("atan", atan) - NVFUSER_PYTHON_BINDING_UNARY_OP("atanh", atanh) - NVFUSER_PYTHON_BINDING_UNARY_OP("ceil", ceil) - NVFUSER_PYTHON_BINDING_UNARY_OP("cos", cos) - NVFUSER_PYTHON_BINDING_UNARY_OP("cosh", cosh) - NVFUSER_PYTHON_BINDING_UNARY_OP("exp", exp) - NVFUSER_PYTHON_BINDING_UNARY_OP("exp2", exp2) - NVFUSER_PYTHON_BINDING_UNARY_OP("expm1", expm1) - NVFUSER_PYTHON_BINDING_UNARY_OP("erf", erf) - NVFUSER_PYTHON_BINDING_UNARY_OP("erfc", erfc) - NVFUSER_PYTHON_BINDING_UNARY_OP("erfinv", erfinv) - NVFUSER_PYTHON_BINDING_UNARY_OP("erfcinv", erfcinv) - NVFUSER_PYTHON_BINDING_UNARY_OP("floor", floor) - NVFUSER_PYTHON_BINDING_UNARY_OP("frac", frac) - NVFUSER_PYTHON_BINDING_UNARY_OP("lgamma", lgamma) - NVFUSER_PYTHON_BINDING_UNARY_OP("log", log) - NVFUSER_PYTHON_BINDING_UNARY_OP("log10", log10) - NVFUSER_PYTHON_BINDING_UNARY_OP("log1p", log1p) - NVFUSER_PYTHON_BINDING_UNARY_OP("log2", log2) - NVFUSER_PYTHON_BINDING_UNARY_OP("neg", neg) - NVFUSER_PYTHON_BINDING_UNARY_OP("logical_not", logical_not) - NVFUSER_PYTHON_BINDING_UNARY_OP("bitwise_not", bitwise_not) - NVFUSER_PYTHON_BINDING_UNARY_OP("relu", relu) - NVFUSER_PYTHON_BINDING_UNARY_OP("rand_like", rand_like) - NVFUSER_PYTHON_BINDING_UNARY_OP("randn_like", randn_like) - NVFUSER_PYTHON_BINDING_UNARY_OP("reciprocal", reciprocal) - NVFUSER_PYTHON_BINDING_UNARY_OP("round", round) - NVFUSER_PYTHON_BINDING_UNARY_OP("rsqrt", rsqrt) - NVFUSER_PYTHON_BINDING_UNARY_OP("set", set) - NVFUSER_PYTHON_BINDING_UNARY_OP("segment_set", segment_set) - NVFUSER_PYTHON_BINDING_UNARY_OP("sign", sign) - NVFUSER_PYTHON_BINDING_UNARY_OP("sigmoid", sigmoid) - NVFUSER_PYTHON_BINDING_UNARY_OP("signbit", signbit) - NVFUSER_PYTHON_BINDING_UNARY_OP("silu", silu) - NVFUSER_PYTHON_BINDING_UNARY_OP("sin", sin) - NVFUSER_PYTHON_BINDING_UNARY_OP("sinh", sinh) - NVFUSER_PYTHON_BINDING_UNARY_OP("sqrt", sqrt) - NVFUSER_PYTHON_BINDING_UNARY_OP("tan", tan) - NVFUSER_PYTHON_BINDING_UNARY_OP("tanh", tanh) - NVFUSER_PYTHON_BINDING_UNARY_OP("trunc", trunc) - NVFUSER_PYTHON_BINDING_UNARY_OP("isfinite", isfinite) - NVFUSER_PYTHON_BINDING_UNARY_OP("isinf", isinf) - NVFUSER_PYTHON_BINDING_UNARY_OP("isnan", isnan) - NVFUSER_PYTHON_BINDING_UNARY_OP("isneginf", isneginf) - NVFUSER_PYTHON_BINDING_UNARY_OP("isposinf", isposinf) - NVFUSER_PYTHON_BINDING_UNARY_OP("isreal", isreal) - NVFUSER_PYTHON_BINDING_UNARY_OP("real", real) - NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag) -#undef NVFUSER_PYTHON_BINDING_UNARY_OP - - nvf_ops.def( - "triu", - [](FusionDefinition::Operators& self, - Tensor input, - int64_t diagonal) -> Tensor { - FUSER_PERF_SCOPE("Operators.triu"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(input.dims); - - auto diagonal_ = fd->defineScalar(); - fd->defineRecord(new ScalarRecord( - {fd->recordingState(diagonal_())}, diagonal, DataType::Int, true)); - - fd->defineRecord(new OpRecord( - {fd->recordingState(input()), fd->recordingState(diagonal_())}, - {fd->recordingState(output())}, - ("ops.triu"), - serde::RecordType::Binary_TV_VAL, - static_cast(triu))); - - return output; - }, - py::arg("input"), - py::arg("diagonal") = 0, - py::return_value_policy::reference, - R"doc( - Returns the upper triangular part of a 2+D tensor. - - Parameters - ---------- - input : Tensor - The input tensor. - diagonal : int, optional - The diagonal to consider. Default is 0. - - Returns - ------- - Tensor - The upper triangular part of the input tensor. - - >>> a = torch.randn(3, 3) - >>> fd.ops.triu(a) - )doc"); - - // overload to - nvf_ops.def( - "stride_order", - [](FusionDefinition::Operators& self, - Tensor arg, - std::vector& stride_order) -> Tensor { - FUSER_PERF_SCOPE("Operators.stride_order"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - NVF_CHECK( - arg.dims == stride_order.size(), - "Operator stride_order expects `stride_order` argument to have the " - "same length as input!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg.dims); - fd->defineRecord(new DimsOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(output())}, - std::move(stride_order), - "ops.stride_order")); - return output; - }, - py::arg("arg"), - py::arg("stride_order"), - py::return_value_policy::reference); - -// rand_like and randn_like are normally used with a single TensorView argument, -// like a UnaryOp. However, they also take an optional pair (rng_seed, -// rng_offset) which converts them to deterministic ops. When those args are -// provided, and they must both be provided if either is, then the op behaves -// like a ternary op. We handle the UnaryOp case above and the TernaryOp case -// here. -#define NVFUSER_PYTHON_BINDING_TERNARY_RANDOM_OP(op_str, op_name) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor input, \ - Scalar rng_seed, \ - Scalar rng_offset) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(input.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(input()), \ - fd->recordingState(rng_seed()), \ - fd->recordingState(rng_offset())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_TV_VAL_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::arg("arg"), \ - py::kw_only(), \ - py::arg("rng_seed"), \ - py::arg("rng_offset"), \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_TERNARY_RANDOM_OP("rand_like", rand_like) - NVFUSER_PYTHON_BINDING_TERNARY_RANDOM_OP("randn_like", randn_like) - -#undef NVFUSER_PYTHON_BINDING_UNARY_RANDOM_OP - -#define NVFUSER_PYTHON_BINDING_UNARY_OP_SPECIAL(op_str, op_name) \ - tensor_class.def( \ - "__" op_str "__", \ - [](Tensor input) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - FusionDefinition* fd = input.fusion_definition; \ - NVF_CHECK( \ - !fd->completed(), "Attempting to add to a completed definition!"); \ - Tensor output = fd->defineTensor(input.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(input())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Unary_TV, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - scalar_class.def( \ - "__" op_str "__", \ - [](Scalar input) -> Scalar { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - FusionDefinition* fd = input.fusion_definition; \ - NVF_CHECK( \ - !fd->completed(), "Attempting to add to a completed definition!"); \ - Scalar output = fd->defineScalar(); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(input())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Unary_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); - NVFUSER_PYTHON_BINDING_UNARY_OP_SPECIAL("abs", abs) - NVFUSER_PYTHON_BINDING_UNARY_OP_SPECIAL("neg", neg) -#undef NVFUSER_PYTHON_BINDING_UNARY_OP_SPECIAL - -#define NVFUSER_PYTHON_BINDING_MATMUL_OP(op_str, op_name) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Tensor arg2) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - /* Per https://pytorch.org/docs/stable/generated/torch.matmul.html */ \ - size_t out_ndims; \ - if (arg1.dims <= 2 && arg2.dims <= 2) { \ - out_ndims = arg1.dims + arg2.dims - 2; \ - } else { \ - /* batch matmul */ \ - out_ndims = std::max(arg1.dims, arg2.dims); \ - } \ - Tensor output = fd->defineTensor(out_ndims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Binary_TV, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); - NVFUSER_PYTHON_BINDING_MATMUL_OP("matmul", matmul) -#undef NVFUSER_PYTHON_BINDING_MATMUL_OP - - nvf_ops.def( - "linear", - [](FusionDefinition::Operators& self, - Tensor arg1, - Tensor arg2, - std::optional bias = std::nullopt) -> Tensor { - FUSER_PERF_SCOPE("Operators.linear"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - // See newForLinear for how the output rank is computed. - Tensor output = fd->defineTensor(arg1.dims + arg2.dims - 2); - - if (bias.has_value()) { - fd->defineRecord( - new OpRecord( - {fd->recordingState(arg1()), - fd->recordingState(arg2()), - fd->recordingState(bias.value()())}, - {fd->recordingState(output())}, - ("ops.linear"), - serde::RecordType::Ternary_TV, - static_cast< - TensorView* (*)(TensorView*, TensorView*, TensorView*)>( - linear))); - } else { - fd->defineRecord(new OpRecord( - {fd->recordingState(arg1()), fd->recordingState(arg2())}, - {fd->recordingState(output())}, - ("ops.linear"), - serde::RecordType::Binary_TV, - static_cast(linear))); - } - return output; - }, - py::arg("arg1"), - py::arg("arg2"), - py::arg("bias") = std::nullopt, - py::return_value_policy::reference); - -#define NVFUSER_PYTHON_BINDING_BINARY_OP(op_str, op_name) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Tensor arg2) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Binary_TV, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Scalar arg2) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Binary_TV_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Tensor arg2) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg2.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Binary_VAL_TV, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Scalar arg2) -> Scalar { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Scalar output = fd->defineScalar(); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Binary_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_BINARY_OP("add", add) - NVFUSER_PYTHON_BINDING_BINARY_OP("atan2", atan2) - NVFUSER_PYTHON_BINDING_BINARY_OP("div", div) - NVFUSER_PYTHON_BINDING_BINARY_OP("truediv", truediv) - NVFUSER_PYTHON_BINDING_BINARY_OP("fmod", fmod) - NVFUSER_PYTHON_BINDING_BINARY_OP("mul", mul) - NVFUSER_PYTHON_BINDING_BINARY_OP("nextafter", nextafter) - NVFUSER_PYTHON_BINDING_BINARY_OP("pow", pow) - NVFUSER_PYTHON_BINDING_BINARY_OP("remainder", remainder) - NVFUSER_PYTHON_BINDING_BINARY_OP("sub", sub) - NVFUSER_PYTHON_BINDING_BINARY_OP("minimum", minimum) - NVFUSER_PYTHON_BINDING_BINARY_OP("maximum", maximum) - NVFUSER_PYTHON_BINDING_BINARY_OP("mod", mod) - NVFUSER_PYTHON_BINDING_BINARY_OP("eq", eq) - NVFUSER_PYTHON_BINDING_BINARY_OP("ge", ge) - NVFUSER_PYTHON_BINDING_BINARY_OP("gt", gt) - NVFUSER_PYTHON_BINDING_BINARY_OP("le", le) - NVFUSER_PYTHON_BINDING_BINARY_OP("lt", lt) - NVFUSER_PYTHON_BINDING_BINARY_OP("ne", ne) - NVFUSER_PYTHON_BINDING_BINARY_OP("logical_and", logical_and) - NVFUSER_PYTHON_BINDING_BINARY_OP("logical_or", logical_or) - NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_and", bitwise_and) - NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_or", bitwise_or) - NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_xor", bitwise_xor) - NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_left_shift", bitwise_left_shift) - NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_right_shift", bitwise_right_shift) - NVFUSER_PYTHON_BINDING_BINARY_OP("logical_right_shift", logical_right_shift) - NVFUSER_PYTHON_BINDING_BINARY_OP("gcd", gcd) -#undef NVFUSER_PYTHON_BINDING_BINARY_OP - -#define NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL(py_op, op_str, op_name) \ - tensor_class.def( \ - py_op, \ - [](Tensor arg1, Tensor arg2) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - FusionDefinition* fd = arg1.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Binary_TV, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - tensor_class.def( \ - py_op, \ - [](Tensor arg1, Scalar arg2) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - FusionDefinition* fd = arg1.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Binary_TV_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - scalar_class.def( \ - py_op, \ - [](Scalar arg1, Tensor arg2) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - FusionDefinition* fd = arg1.fusion_definition; \ - Tensor output = fd->defineTensor(arg2.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Binary_VAL_TV, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - scalar_class.def( \ - py_op, \ - [](Scalar arg1, Scalar arg2) -> Scalar { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - FusionDefinition* fd = arg2.fusion_definition; \ - Scalar output = fd->defineScalar(); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Binary_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__add__", "add", add) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__mul__", "mul", mul) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__pow__", "pow", pow) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__sub__", "sub", sub) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__mod__", "mod", mod) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__eq__", "eq", eq) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__ge__", "ge", ge) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__gt__", "gt", gt) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__le__", "le", le) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__lt__", "lt", lt) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__ne__", "ne", ne) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL( - "__and__", "bitwise_and", bitwise_and) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__or__", "bitwise_or", bitwise_or) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL( - "__xor__", "bitwise_xor", bitwise_xor) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL( - "__lshift__", "bitwise_left_shift", bitwise_left_shift) - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL( - "__rshift__", "bitwise_right_shift", bitwise_right_shift) - // In python, __truediv__ (/) always returns a float regardless of whether - // the input arguments are float or integer. __truediv__ (/) corresponds with - // pytorch torch.true_divide(a, b). The __div__ operator is deprecated in - // python 3. - // - // In nvfuser, truediv function in csrc/ops/arith.h has the same semantics as - // python's operator __truediv__ (/). The div function in csrc/ops/arith.h - // truncates the result instead of promoting it to float. It has the same - // semantics as the C++'s (/) operator. In pytorch, - // torch.div(a, b, rounding_mode='trunc') corresponds C-style integer - // division. - // - // Hence, in the python frontend, the __truediv__ (/) python operator maps to - // trunc division. - NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL("__truediv__", "div", div) -#undef NVFUSER_PYTHON_BINDING_BINARY_OP_SPECIAL - -#define NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP(op_str, op_name) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Tensor arg2, \ - Scalar arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_TV_TV_VAL, \ - static_cast( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Scalar arg2, \ - Scalar arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_TV_VAL_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Tensor arg2, \ - Scalar arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg2.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_VAL_TV_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Scalar arg2, \ - Scalar arg3) -> Scalar { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Scalar output = fd->defineScalar(); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("add_alpha", add_alpha) - NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("sub_alpha", sub_alpha) -#undef NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP - -#define NVFUSER_PYTHON_BINDING_TERNARY_OP(op_str, op_name) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Scalar arg2, \ - Scalar arg3) -> Scalar { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Scalar output = fd->defineScalar(); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Tensor arg2, \ - Tensor arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_TV, \ - static_cast< \ - TensorView* (*)(TensorView*, TensorView*, TensorView*)>( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Tensor arg2, \ - Scalar arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_TV_TV_VAL, \ - static_cast( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Scalar arg2, \ - Tensor arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_TV_VAL_TV, \ - static_cast( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Tensor arg2, \ - Tensor arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg2.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_VAL_TV_TV, \ - static_cast( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Scalar arg2, \ - Tensor arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg3.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_VAL_VAL_TV, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Scalar arg2, \ - Scalar arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_TV_VAL_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Tensor arg2, \ - Scalar arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg2.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_VAL_TV_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_TERNARY_OP("lerp", lerp) - NVFUSER_PYTHON_BINDING_TERNARY_OP("where", where) -#undef NVFUSER_PYTHON_BINDING_TERNARY_OP - -#define NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP(op_str, op_name) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Scalar arg2, \ - Scalar arg3) -> Scalar { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - !self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Scalar output = fd->defineScalar(); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Scalar arg2, \ - Scalar arg3) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - !self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_TV_VAL_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP("clamp", clamp) - NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP("threshold", threshold) -#undef NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP - -#define NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP(op_str, op_name) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Scalar arg2, \ - Scalar arg3, \ - Scalar arg4) -> Scalar { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Scalar output = fd->defineScalar(); \ - fd->defineRecord(new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3()), \ - fd->recordingState(arg4())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_Alpha_VAL, \ - static_cast(op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Tensor arg2, \ - Tensor arg3, \ - Scalar arg4) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord(new OpRecord< \ - TensorView*, \ - TensorView*, \ - TensorView*, \ - TensorView*, \ - Val*>( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3()), \ - fd->recordingState(arg4())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_Alpha_TV, \ - static_cast< \ - TensorView* (*)(TensorView*, TensorView*, TensorView*, Val*)>( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Tensor arg2, \ - Scalar arg3, \ - Scalar arg4) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3()), \ - fd->recordingState(arg4())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_Alpha_TV_TV_VAL, \ - static_cast< \ - TensorView* (*)(TensorView*, TensorView*, Val*, Val*)>( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Scalar arg2, \ - Tensor arg3, \ - Scalar arg4) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3()), \ - fd->recordingState(arg4())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_Alpha_TV_VAL_TV, \ - static_cast< \ - TensorView* (*)(TensorView*, Val*, TensorView*, Val*)>( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Tensor arg2, \ - Tensor arg3, \ - Scalar arg4) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg2.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3()), \ - fd->recordingState(arg4())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_Alpha_VAL_TV_TV, \ - static_cast< \ - TensorView* (*)(Val*, TensorView*, TensorView*, Val*)>( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Scalar arg2, \ - Tensor arg3, \ - Scalar arg4) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg3.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3()), \ - fd->recordingState(arg4())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_Alpha_VAL_VAL_TV, \ - static_cast( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg1, \ - Scalar arg2, \ - Scalar arg3, \ - Scalar arg4) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg1.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3()), \ - fd->recordingState(arg4())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_Alpha_TV_VAL_VAL, \ - static_cast( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg1, \ - Tensor arg2, \ - Scalar arg3, \ - Scalar arg4) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg2.dims); \ - fd->defineRecord( \ - new OpRecord( \ - {fd->recordingState(arg1()), \ - fd->recordingState(arg2()), \ - fd->recordingState(arg3()), \ - fd->recordingState(arg4())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::Ternary_Alpha_VAL_TV_VAL, \ - static_cast( \ - op_name))); \ - return output; \ - }, \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP("addcmul", addcmul) -#undef NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP - -#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name, record_type) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = 0; \ - std::vector dims(arg.dims); \ - std::iota(dims.begin(), dims.end(), 0); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast&, \ - bool, \ - DataType)>(op_name), \ - dims, \ - false, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - int dim, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - 1); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast&, \ - bool, \ - DataType)>(op_name), \ - {dim}, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("dim"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - const std::vector& dims, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - dims.size()); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast&, \ - bool, \ - DataType)>(op_name), \ - dims, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("dims"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_REDUCTION_OP( - "max", max, serde::RecordType::ReductionMax) - NVFUSER_PYTHON_BINDING_REDUCTION_OP( - "min", min, serde::RecordType::ReductionMin) - NVFUSER_PYTHON_BINDING_REDUCTION_OP( - "prod", prod, serde::RecordType::ReductionProd) - NVFUSER_PYTHON_BINDING_REDUCTION_OP( - "sum", sum, serde::RecordType::ReductionSum) -#undef NVFUSER_PYTHON_BINDING_REDUCTION_OP - -#define NVFUSER_PYTHON_BINDING_CAST_OP(op_str, op_name) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Tensor output = fd->defineTensor(arg.dims); \ - fd->defineRecord(new CastOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::CastTv, \ - static_cast(op_name), \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("dtype"), \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Scalar arg, \ - PrimDataType dtype) -> Scalar { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - Scalar output = fd->defineScalar(); \ - fd->defineRecord(new CastOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - serde::RecordType::CastVal, \ - static_cast(op_name), \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("dtype"), \ - py::return_value_policy::reference); - - NVFUSER_PYTHON_BINDING_CAST_OP("cast", castOp) -#undef NVFUSER_PYTHON_BINDING_CAST_OP - -#define NVFUSER_ALL_VECTOR_TYPES(fn, ...) \ - fn(Vector, __VA_ARGS__); \ - fn(py::list, __VA_ARGS__); \ - fn(py::tuple, __VA_ARGS__); - -#define NVFUSER_RANDOM_DIST_OP_HELPER( \ - vec_type, op_str, op_type, arg1_str, arg2_str) \ - nvf_ops.def( \ - op_str, \ - random_dist_op_fn, \ - py::arg(arg1_str), \ - py::arg(arg2_str), \ - py::arg("shape"), \ - py::kw_only(), \ - py::arg("rng_seed") = py::none(), \ - py::arg("rng_offset") = py::none(), \ - py::arg("dtype") = DataType::Float, \ - py::return_value_policy::reference); - -#define NVFUSER_PYTHON_BINDING_RANDOM_DIST_OP(...) \ - NVFUSER_ALL_VECTOR_TYPES(NVFUSER_RANDOM_DIST_OP_HELPER, __VA_ARGS__) - - NVFUSER_PYTHON_BINDING_RANDOM_DIST_OP( - "normal", serde::RecordType::NormalDistOp, "mean", "std") - NVFUSER_PYTHON_BINDING_RANDOM_DIST_OP( - "uniform", serde::RecordType::UniformDistOp, "minval", "maxval") -#undef NVFUSER_PYTHON_BINDING_RANDOM_DIST_OP -#undef NVFUSER_RANDOM_DIST_OP_HELPER - -#define NVFUSER_FULL_OP_HELPER(vec_type, ...) \ - nvf_ops.def( \ - "full", \ - full_op_fn, \ - py::arg("shape"), \ - py::arg("fill_value"), \ - py::arg("dtype"), \ - py::return_value_policy::reference); - - // NOTE: The second argument is a dummy to satisfy the macro - NVFUSER_ALL_VECTOR_TYPES(NVFUSER_FULL_OP_HELPER, false) -#undef NVFUSER_FULL_OP_HELPER - - nvf_ops.def( - "batch_norm", - [](FusionDefinition::Operators& self, - Tensor arg, - std::optional weight, - std::optional bias, - std::optional running_mean, - std::optional running_var, - Scalar momentum, - Scalar eps, - bool training, - bool channels_last) -> decltype(auto) { - FUSER_PERF_SCOPE("Operators.batch_norm"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg.dims); - Tensor mean = fd->defineTensor(1); - Tensor invstd = fd->defineTensor(1); - auto weight_state = weight.has_value() - ? fd->recordingState(weight.value()()) - : State(0, serde::StateType::None); - auto bias_state = bias.has_value() ? fd->recordingState(bias.value()()) - : State(0, serde::StateType::None); - auto running_mean_state = running_mean.has_value() - ? fd->recordingState(running_mean.value()()) - : State(0, serde::StateType::None); - auto running_var_state = running_var.has_value() - ? fd->recordingState(running_var.value()()) - : State(0, serde::StateType::None); - fd->defineRecord(new BatchNormOpRecord( - {fd->recordingState(arg()), - weight_state, - bias_state, - running_mean_state, - running_var_state, - fd->recordingState(momentum()), - fd->recordingState(eps())}, - {fd->recordingState(output()), - fd->recordingState(mean()), - fd->recordingState(invstd())}, - training, - channels_last)); - return std::make_tuple(output, mean, invstd); - }, - py::arg("arg"), - py::arg("weight").none(true), - py::arg("bias").none(true), - py::arg("running_mean").none(true), - py::arg("running_var").none(true), - py::arg("momentum"), - py::arg("eps"), - py::arg("training"), - py::arg("channels_last") = false, - py::return_value_policy::reference); - nvf_ops.def( - "broadcast_in_dim", - broadcast_in_dim_fn, - py::arg("arg"), - py::arg("shape"), - py::arg("broadcast_dims"), - py::return_value_policy::reference); - nvf_ops.def( - "broadcast_in_dim", - broadcast_in_dim_fn, - py::arg("arg"), - py::arg("shape"), - py::arg("broadcast_dims"), - py::return_value_policy::reference); - // NOTE: Tuple support was added to facilitate the direct usage of Pytorch's - // Tensor.size() function that returns a child class of a Tuple. - nvf_ops.def( - "broadcast_in_dim", - broadcast_in_dim_fn, - py::arg("arg"), - py::arg("shape"), - py::arg("broadcast_dims"), - py::return_value_policy::reference); - nvf_ops.def( - "broadcast", - [](FusionDefinition::Operators& self, - Tensor arg, - std::vector& is_broadcast_dim) -> Tensor { - FUSER_PERF_SCOPE("Operators.broadcast"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(is_broadcast_dim.size()); - fd->defineRecord(new BroadcastOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(output())}, - "ops.broadcast", - std::move(is_broadcast_dim))); - return output; - }, - py::arg("arg"), - py::arg("is_broadcast_dim"), - py::return_value_policy::reference); - nvf_ops.def( - "cat", - [](FusionDefinition::Operators& self, - std::vector tensors, - int64_t dim, - bool manual_padding) -> Tensor { - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - NVF_CHECK( - !tensors.empty(), "Attempting to concatenate empty list of tensors") - Tensor output = fd->defineTensor(tensors[0].dims); - std::vector tensor_states; - tensor_states.reserve(tensors.size()); - for (auto& t : tensors) { - tensor_states.push_back(fd->recordingState(t())); - } - self.fusion_definition->defineRecord(new CatOpRecord( - tensor_states, - {fd->recordingState(output())}, - dim, - manual_padding)); - return output; - }, - py::arg("tensors"), - py::arg("dim") = 0, - py::arg("manual_padding") = false, - py::return_value_policy::reference); - nvf_ops.def( - "expand", - expand_fn, - py::arg("arg"), - py::arg("shape"), - py::return_value_policy::reference); - nvf_ops.def( - "expand", - expand_fn, - py::arg("arg"), - py::arg("shape"), - py::return_value_policy::reference); - // NOTE: Tuple support was added to facilitate the direct usage of Pytorch's - // Tensor.size() function that returns a child class of a Tuple. - nvf_ops.def( - "expand", - expand_fn, - py::arg("arg"), - py::arg("shape"), - py::return_value_policy::reference); - nvf_ops.def( - "index_select", - [](FusionDefinition::Operators& self, - Tensor arg, - Tensor index, - int64_t dim) -> Tensor { - FUSER_PERF_SCOPE("Operators.index_select"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg.dims); - fd->defineRecord(new IndexSelectOpRecord( - { - fd->recordingState(arg()), - fd->recordingState(index()), - }, - {fd->recordingState(output())}, - dim)); - return output; - }, - py::arg("arg"), - py::arg("index"), - py::arg("dim"), - py::return_value_policy::reference); - nvf_ops.def( - "index_put_accumulate", - [](FusionDefinition::Operators& self, - Tensor acc, - Tensor index, - Tensor value) -> Tensor { - FUSER_PERF_SCOPE("Operators.index_put_accumulate"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(acc.dims); - fd->defineRecord(new IndexPutAccumulateOpRecord( - { - fd->recordingState(acc()), - fd->recordingState(index()), - fd->recordingState(value()), - }, - {fd->recordingState(output())})); - return output; - }, - py::arg("acc"), - py::arg("index"), - py::arg("value"), - py::return_value_policy::reference, - R"doc( - Accumulates values into a tensor at specified indices. - - This function performs a restricted version of `torch.index_put`. - It adds the values from `value_tv` to the elements of `acc_tv` at the indices - specified by `index_tv`. - - acc_tv: The tensor to accumulate into (in-place modification). - index_tv: The tensor containing the indices. - value_tv: The tensor containing the values to accumulate. - - Returns: - An alias to the modified `acc_tv` tensor. - - Note: - This is a restricted version and may not support all features of the - full `torch.index_put(..., accumulate=true)` function. - )doc"); - nvf_ops.def( - "select", - [](FusionDefinition::Operators& self, - Tensor arg, - Scalar index, - int64_t dim) -> Tensor { - FUSER_PERF_SCOPE("Operators.select"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg.dims); - fd->defineRecord(new SelectOpRecord( - { - fd->recordingState(arg()), - fd->recordingState(index()), - }, - {fd->recordingState(output())}, - dim)); - return output; - }, - py::arg("arg"), - py::arg("index"), - py::arg("dim"), - py::return_value_policy::reference); - nvf_ops.def( - "scatter", - [](FusionDefinition::Operators& self, - Tensor arg1, - Tensor index, - Tensor src, - int64_t dim) -> Tensor { - FUSER_PERF_SCOPE("Operators.scatter"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - NVF_CHECK( - arg1.dims == index.dims && arg1.dims == src.dims, - "Tensor arguments have different dimensions ", - arg1.dims, - ", ", - index.dims, - " and ", - src.dims); - auto num_dims = (int64_t)arg1.dims; - NVF_CHECK( - dim >= -num_dims && dim < num_dims, - "Tensor arguments have dimension ", - num_dims, - " so dim argument must satisfy ", - -num_dims, - " <= dim < ", - num_dims, - ", but received ", - dim); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(num_dims); - fd->defineRecord(new ScatterOpRecord( - { - fd->recordingState(arg1()), - fd->recordingState(index()), - fd->recordingState(src()), - }, - {fd->recordingState(output())}, - dim)); - return output; - }, - py::arg("arg1"), - py::arg("index"), - py::arg("src"), - py::arg("dim"), - py::return_value_policy::reference); - nvf_ops.def( - "scatter", - [](FusionDefinition::Operators& self, - Tensor arg1, - Tensor index, - Scalar src, - int64_t dim) -> Tensor { - FUSER_PERF_SCOPE("Operators.scatter"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - auto num_dims = (int64_t)arg1.dims; - NVF_CHECK( - dim >= -num_dims && dim < num_dims, - "Tensor arguments have dimension ", - num_dims, - " so dim argument must satisfy ", - -num_dims, - " <= dim < ", - num_dims, - ", but received ", - dim); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(num_dims); - fd->defineRecord(new ScatterOpRecord( - { - fd->recordingState(arg1()), - fd->recordingState(index()), - fd->recordingState(src()), - }, - {fd->recordingState(output())}, - dim)); - return output; - }, - py::arg("arg1"), - py::arg("index"), - py::arg("src"), - py::arg("dim"), - py::return_value_policy::reference); - nvf_ops.def( - "gather", - [](FusionDefinition::Operators& self, - Tensor arg1, - Tensor index, - int64_t dim) -> Tensor { - FUSER_PERF_SCOPE("Operators.gather"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - NVF_CHECK( - arg1.dims == index.dims, - "Tensor arguments have different dimensions ", - arg1.dims, - " and ", - index.dims); - auto num_dims = (int64_t)arg1.dims; - NVF_CHECK( - dim >= -num_dims && dim < num_dims, - "Tensor arguments have dimension ", - num_dims, - " so dim argument must satisfy ", - -num_dims, - " <= dim < ", - num_dims, - ", but received ", - dim); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg1.dims); - fd->defineRecord(new GatherOpRecord( - { - fd->recordingState(arg1()), - fd->recordingState(index()), - }, - {fd->recordingState(output())}, - dim)); - return output; - }, - R"pbdoc( - Index arg1 in dim at positions given by index. - - The dimension of arg1 and index must match. For all axes other than dim - the extent of index in that axis need not be equal to its counterpart - in arg1 but must not be greater than it. - - Args: - arg1 (Tensor): Tensor of shape `(Ni...,M,Nk...)` where `M` is the - extent of `arg1` in the dimension `dim`. - index (Tensor): Tensor of dtype `DataType::Int` of shape - `(Mi...,J,Mk...)` where all the extents other than `J` are less - than or equal to their counterparts in `arg1`; for example `Mk - <= Nk`. - dim (int): Which position to index along. - - Returns: - (Tensor): Tensor of same dtype as `arg1` and of shape - `(Mi...,J,Mk...)` where the element at position `(i...,j,k...)` - is equal to `arg1[i,...,index[i,...,j,k,...],k,...]`. - )pbdoc", - py::arg("arg1"), - py::arg("index"), - py::arg("dim"), - py::return_value_policy::reference); - nvf_ops.def( - "pad", - pad_fn, - py::arg("arg"), - py::arg("pad_widths"), - py::arg("value") = py::none(), - py::return_value_policy::reference); - nvf_ops.def( - "pad", - pad_fn, - py::arg("arg"), - py::arg("pad_widths"), - py::arg("value") = py::none(), - py::return_value_policy::reference); - nvf_ops.def( - "pad", - pad_fn, - py::arg("arg"), - py::arg("pad_widths"), - py::arg("value") = py::none(), - py::return_value_policy::reference); - nvf_ops.def( - "take_along_axis", - [](FusionDefinition::Operators& self, - Tensor arg1, - Tensor index, - int64_t dim) -> Tensor { - FUSER_PERF_SCOPE("Operators.take_along_axis"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - NVF_CHECK( - arg1.dims == index.dims, - "Tensor arguments have different dimensions ", - arg1.dims, - " and ", - index.dims); - auto num_dims = (int64_t)arg1.dims; - NVF_CHECK( - dim >= -num_dims && dim < num_dims, - "Tensor arguments have dimension ", - num_dims, - " so dim argument must satisfy ", - -num_dims, - " <= dim < ", - num_dims, - ", but received ", - dim); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg1.dims); - fd->defineRecord(new TakeAlongAxisOpRecord( - { - fd->recordingState(arg1()), - fd->recordingState(index()), - }, - {fd->recordingState(output())}, - dim)); - return output; - }, - R"pbdoc( - Index arg1 in dim at positions given by index. - - This operation is very similar to :meth:'gather' but enforces that all - dimensions other than dim must be equal between arg1 and index. - - Args: - arg1 (Tensor): Tensor of shape `(Ni...,M,Nk...)` where `M` is the - extent of `arg1` in the dimension `dim`. - index (Tensor): Tensor of dtype `DataType::Int` of shape - `(Ni...,J,Nk...)`. - dim (int): Which position to index along. - - Returns: - (Tensor): Tensor of same dtype as `arg1` and of shape - `(Ni...,J,Nk...)` where the element at position `(i...,j,k...)` - is equal to `arg1[i,...,index[i,...,j,k,...],k,...]`. - )pbdoc", - py::arg("arg1"), - py::arg("index"), - py::arg("dim"), - py::return_value_policy::reference); - nvf_ops.def( - "permute", - [](FusionDefinition::Operators& self, - Tensor arg, - std::vector& dims) -> Tensor { - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - NVF_CHECK( - arg.dims == dims.size(), - "Operator permute expects `dims` argument to have the same length " - "as input!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg.dims); - self.fusion_definition->defineRecord( - new DimsOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(output())}, - std::move(dims), - "ops.permute")); - return output; - }, - py::arg("arg"), - py::arg("dims"), - py::return_value_policy::reference); - - auto shape_def = [](Tensor arg) -> Vector { - FUSER_PERF_SCOPE("Operators.shape"); - auto fd = arg.fusion_definition; - NVF_CHECK( - fd->ops.validUse(), "Attempting to add to a completed definition!"); - Vector output = fd->defineVector(arg.dims); - fd->defineRecord(new ShapeOpRecord( - {fd->recordingState(arg())}, {fd->recordingState(output())})); - return output; - }; - - tensor_class.def( - "shape", - [&shape_def](Tensor arg) -> Vector { return shape_def(arg); }, - py::return_value_policy::reference); - nvf_ops.def( - "shape", - [&shape_def](FusionDefinition::Operators& self, Tensor arg) -> Vector { - return shape_def(arg); - }, - py::arg("arg"), - py::return_value_policy::reference); - - auto size_def = [](Tensor arg, int64_t dim) -> Scalar { - FUSER_PERF_SCOPE("Operators.size"); - auto fd = arg.fusion_definition; - NVF_CHECK( - fd->ops.validUse(), "Attempting to add to a completed definition!"); - Scalar output = fd->defineScalar(); - fd->defineRecord(new SizeOpRecord( - {fd->recordingState(arg())}, {fd->recordingState(output())}, dim)); - return output; - }; - - tensor_class.def( - "size", - [&size_def](Tensor arg, int64_t dim) -> Scalar { - return size_def(arg, dim); - }, - py::return_value_policy::reference); - nvf_ops.def( - "size", - [&size_def](FusionDefinition::Operators& self, Tensor arg, int64_t dim) - -> Scalar { return size_def(arg, dim); }, - py::arg("arg"), - py::arg("dim"), - py::return_value_policy::reference); - - auto at_def = [](Vector arg, int64_t index) -> Scalar { - FUSER_PERF_SCOPE("Operators.at"); - auto fd = arg.fusion_definition; - NVF_CHECK( - fd->ops.validUse(), "Attempting to add to a completed definition!"); - Scalar output = fd->defineScalar(); - fd->defineRecord(new AtOpRecord( - {fd->recordingState(arg())}, {fd->recordingState(output())}, index)); - return output; - }; - - vector_class.def( - "at", - [&at_def](Vector arg, int64_t index) -> Scalar { - return at_def(arg, index); - }, - py::return_value_policy::reference); - vector_class.def( - "__getitem__", - [&at_def](Vector arg, int64_t index) -> Scalar { - return at_def(arg, index); - }, - py::return_value_policy::reference); - nvf_ops.def( - "at", - [&at_def](FusionDefinition::Operators& self, Vector arg, int64_t index) - -> Scalar { return at_def(arg, index); }, - py::arg("arg"), - py::arg("index"), - py::return_value_policy::reference); - - nvf_ops.def( - "slice", - slice_fn, - py::arg("arg"), - py::arg("start_indices"), - py::arg("end_indices"), - py::arg("strides") = py::none(), - py::arg("manual_normalization") = false, - py::return_value_policy::reference); - nvf_ops.def( - "slice", - slice_fn, - py::arg("arg"), - py::arg("start_indices"), - py::arg("end_indices"), - py::arg("strides") = py::none(), - py::arg("manual_normalization") = false, - py::return_value_policy::reference); - nvf_ops.def( - "slice", - slice_fn, - py::arg("arg"), - py::arg("start_indices"), - py::arg("end_indices"), - py::arg("strides") = py::none(), - py::arg("manual_normalization") = false, - py::return_value_policy::reference); - nvf_ops.def( - "squeeze", - [](FusionDefinition::Operators& self, - Tensor arg, - std::vector dims, - const bool squeeze_expanded) -> Tensor { - FUSER_PERF_SCOPE("Operators.squeeze"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg.dims - dims.size()); - fd->defineRecord(new SqueezeOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(output())}, - std::move(dims), - squeeze_expanded)); - return output; - }, - py::arg("arg"), - py::arg("dims"), - py::arg("squeeze_expanded") = false, - py::return_value_policy::reference); - nvf_ops.def( - "tensor_sizes", - [](FusionDefinition::Operators& self, Tensor arg) -> std::vector { - FUSER_PERF_SCOPE("Operators.tensor_sizes"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - std::vector outputs; - std::vector output_state; - for (const auto idx : arange(arg.dims)) { - outputs.push_back(fd->defineScalar()); - output_state.push_back(fd->recordingState(outputs[idx]())); - } - fd->defineRecord( - new TensorSizesRecord({fd->recordingState(arg())}, output_state)); - return outputs; - }, - py::arg("arg"), - py::return_value_policy::reference); - nvf_ops.def( - "reshape", - reshape_fn, - py::arg("arg"), - py::arg("new_shape"), - py::return_value_policy::reference); - nvf_ops.def( - "reshape", - reshape_fn, - py::arg("arg"), - py::arg("new_shape"), - py::return_value_policy::reference); - nvf_ops.def( - "reshape", - reshape_fn, - py::arg("arg"), - py::arg("new_shape"), - py::return_value_policy::reference); - nvf_ops.def( - "iota", - [](FusionDefinition::Operators& self, - Scalar length, - std::optional start, - std::optional step, - PrimDataType dtype) -> Tensor { - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(1); - auto start_state = start.has_value() - ? fd->recordingState(start.value()()) - : State(0, serde::StateType::None); - auto step_state = step.has_value() ? fd->recordingState(step.value()()) - : State(0, serde::StateType::None); - fd->defineRecord(new IotaOpRecord( - {fd->recordingState(length()), start_state, step_state}, - {fd->recordingState(output())}, - dtype)); - return output; - }, - py::arg("length"), - py::arg("start").none(true), - py::arg("step").none(true), - py::arg("dtype") = DataType::Int, - py::return_value_policy::reference); - nvf_ops.def( - "var", - [](FusionDefinition::Operators& self, - Tensor arg, - std::vector& dims, - int64_t correction, - bool keepdim) -> Tensor { - FUSER_PERF_SCOPE("Operators.var"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - size_t ndims = keepdim ? arg.dims : (arg.dims - dims.size()); - Tensor output = fd->defineTensor(ndims); - fd->defineRecord(new VarianceOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(output())}, - std::move(dims), - correction, - keepdim)); - return output; - }, - py::arg("arg"), - py::arg("dims"), - py::arg("correction"), - py::arg("keepdim") = false, - py::return_value_policy::reference); - nvf_ops.def( - "var_mean", - [](FusionDefinition::Operators& self, - Tensor arg, - std::vector& dims, - int64_t correction, - bool keepdim) -> decltype(auto) { - FUSER_PERF_SCOPE("Operators.var_mean"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - size_t ndims = keepdim ? arg.dims : (arg.dims - dims.size()); - Tensor var = fd->defineTensor(ndims); - Tensor mean = fd->defineTensor(ndims); - fd->defineRecord(new VarianceMeanOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(var()), fd->recordingState(mean())}, - std::move(dims), - correction, - keepdim)); - return std::make_tuple(var, mean); - }, - py::arg("arg"), - py::arg("dims"), - py::arg("correction") = 1, - py::arg("keepdim") = false, - py::return_value_policy::reference); - nvf_ops.def( - "welford", - [](FusionDefinition::Operators& self, - Tensor arg, - const std::vector& dims) -> decltype(auto) { - FUSER_PERF_SCOPE("Operators.welford"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - size_t ndims = (arg.dims - dims.size()); - Tensor avg = fd->defineTensor(ndims); - Tensor var_sum = fd->defineTensor(ndims); - Tensor n = fd->defineTensor(ndims); - fd->defineRecord(new WelfordOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(avg()), - fd->recordingState(var_sum()), - fd->recordingState(n())}, - dims)); - return std::make_tuple(avg, var_sum, n); - }, - py::arg("arg"), - py::arg("dims"), - py::return_value_policy::reference); - nvf_ops.def( - "sdpfa_bwd", - [](FusionDefinition::Operators& self, - Tensor grad_output, - Tensor query, - Tensor key, - Tensor value, - Tensor output, - Tensor logsumexp, - std::optional dropout_p, - std::optional is_causal, - Tensor philox_seed, - Tensor philox_offset, - std::optional scale) -> decltype(auto) { - FUSER_PERF_SCOPE("Operators.sdpfa_bwd"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - size_t ndims = query.dims; - Tensor grad_query = fd->defineTensor(/*dims=*/ndims); - Tensor grad_key = fd->defineTensor(/*dims=*/ndims); - Tensor grad_value = fd->defineTensor(/*dims=*/ndims); - - auto dropout_p_state = dropout_p.has_value() - ? fd->recordingState(dropout_p.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto is_causal_state = is_causal.has_value() - ? fd->recordingState(is_causal.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto scale_state = scale.has_value() - ? fd->recordingState(scale.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - fd->defineRecord(new SdpaBwdOpRecord( - {fd->recordingState(grad_output()), - fd->recordingState(query()), - fd->recordingState(key()), - fd->recordingState(value()), - fd->recordingState(output()), - fd->recordingState(logsumexp()), - dropout_p_state, - is_causal_state, - fd->recordingState(philox_seed()), - fd->recordingState(philox_offset()), - scale_state}, - {fd->recordingState(grad_query()), - fd->recordingState(grad_key()), - fd->recordingState(grad_value())})); - return std::make_tuple(grad_query, grad_key, grad_value); - }, - py::arg("grad_output"), - py::arg("query"), - py::arg("key"), - py::arg("value"), - py::arg("output"), - py::arg("logsumexp"), - py::arg("dropout_p").none(true) = py::none(), - py::arg("is_causal").none(true) = py::none(), - py::arg("philox_seed"), - py::arg("philox_offset"), - py::arg("scale").none(true) = py::none(), - py::return_value_policy::reference); - - nvf_ops.def( - "sdpfa_fwd", - [](FusionDefinition::Operators& self, - Tensor query, - Tensor key, - Tensor value, - std::optional bias, - std::optional mask, - std::optional dropout_p, - std::optional is_causal, - std::optional scale) -> decltype(auto) { - FUSER_PERF_SCOPE("Operators.sdpfa_fwd"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - size_t ndims = query.dims; - Tensor output = fd->defineTensor(/*dims=*/ndims); - Tensor logsumexp = fd->defineTensor(/*dims=*/ndims - 1); -#if NVF_TORCH_VERSION_NO_LESS(2, 7, 0) - int64_t philox_ndims = 1; -#else - int64_t philox_ndims = 0; -#endif - Tensor philox_seed = fd->defineTensor(philox_ndims); - Tensor philox_offset = fd->defineTensor(/*dims=*/0); - - auto bias_state = bias.has_value() - ? fd->recordingState(bias.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto mask_state = mask.has_value() - ? fd->recordingState(mask.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto dropout_p_state = dropout_p.has_value() - ? fd->recordingState(dropout_p.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto is_causal_state = is_causal.has_value() - ? fd->recordingState(is_causal.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto scale_state = scale.has_value() - ? fd->recordingState(scale.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - fd->defineRecord(new SdpaFwdOpRecord( - {fd->recordingState(query()), - fd->recordingState(key()), - fd->recordingState(value()), - bias_state, - mask_state, - dropout_p_state, - is_causal_state, - scale_state}, - {fd->recordingState(output()), - fd->recordingState(logsumexp()), - fd->recordingState(philox_seed()), - fd->recordingState(philox_offset())})); - return std::make_tuple(output, logsumexp, philox_seed, philox_offset); - }, - py::arg("query"), - py::arg("key"), - py::arg("value"), - py::arg("bias").none(true) = py::none(), - py::arg("mask").none(true) = py::none(), - py::arg("dropout_p").none(true) = py::none(), - py::arg("is_causal").none(true) = py::none(), - py::arg("scale").none(true) = py::none(), - py::return_value_policy::reference); - - nvf_ops.def( - "embedding_fwd", - [](FusionDefinition::Operators& self, - Tensor input, - Tensor weight, - std::optional padding_idx, - std::optional max_norm, - std::optional norm_type, - std::optional scale_grad_by_freq, - std::optional sparse) -> decltype(auto) { - FUSER_PERF_SCOPE("Operators.embedding_fwd"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - size_t ndims = input.dims + 1; - Tensor output = fd->defineTensor(/*dims=*/ndims); - - auto padding_idx_state = padding_idx.has_value() - ? fd->recordingState(padding_idx.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto max_norm_state = max_norm.has_value() - ? fd->recordingState(max_norm.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto norm_type_state = norm_type.has_value() - ? fd->recordingState(norm_type.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto scale_grad_by_freq_state = scale_grad_by_freq.has_value() - ? fd->recordingState(scale_grad_by_freq.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - auto sparse_state = sparse.has_value() - ? fd->recordingState(sparse.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - fd->defineRecord(new EmbeddingFwdOpRecord( - {fd->recordingState(input()), - fd->recordingState(weight()), - padding_idx_state, - max_norm_state, - norm_type_state, - scale_grad_by_freq_state, - sparse_state}, - {fd->recordingState(output())})); - return output; - }, - py::arg("input"), - py::arg("weight"), - py::arg("padding_idx").none(true) = py::none(), - py::arg("max_norm").none(true) = py::none(), - py::arg("norm_type").none(true) = py::none(), - py::arg("scale_grad_by_freq").none(true) = py::none(), - py::arg("sparse").none(true) = py::none(), - py::return_value_policy::reference); - - nvf_ops.def( - "argsort", - [](FusionDefinition::Operators& self, - Tensor arg, - int64_t dim, - bool descending, - bool stable) -> Tensor { - FUSER_PERF_SCOPE("Operators.argsort"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg.dims); - fd->defineRecord(new ArgsortOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(output())}, - dim, - descending, - stable)); - return output; - }, - py::arg("arg"), - py::arg("dim"), - py::arg("descending") = false, - py::arg("stable") = false, - py::return_value_policy::reference); - - nvf_ops.def( - "cumsum", - [](FusionDefinition::Operators& self, Tensor arg, int64_t dim) -> Tensor { - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg.dims); - fd->defineRecord(new ScanOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(output())}, - ("ops.cumsum"), - serde::RecordType::ScanOpCumsum, - static_cast(cumsum), - dim, - BinaryOpType::Add)); - - return output; - }, - py::arg("arg"), - py::arg("dim"), - py::return_value_policy::reference, - R"doc( - Computes the cumulative sum of elements along a given dimension. - Args: - arg (Tensor): Input tensor. - dim (int): Dimension along which to compute the cumulative sum. - Returns: - Tensor: Tensor of the same shape as input with cumulative sums computed along the specified dimension. - Example: - >>> fd.ops.cumsum(tensor, dim=0) - )doc"); - - nvf_ops.def( - "grouped_mm", - [](FusionDefinition::Operators& self, - Tensor mat1, - Tensor mat2, - Tensor offsets) -> Tensor { - FUSER_PERF_SCOPE("Operators.grouped_mm"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - - // Calculate output dimensions based on mat1 & mat2 rank - size_t output_dims = mat1.dims == 2 && mat2.dims == 2 ? 3 : 2; - Tensor output = fd->defineTensor(output_dims); - fd->defineRecord( - new OpRecord( - {fd->recordingState(mat1()), - fd->recordingState(mat2()), - fd->recordingState(offsets())}, - {fd->recordingState(output())}, - ("ops.grouped_mm"), - serde::RecordType::Ternary_TV, - static_cast< - TensorView* (*)(TensorView*, TensorView*, TensorView*)>( - [](TensorView* mat1, - TensorView* mat2, - TensorView* offsets) { - ScaledTensorView scaled_out = - grouped_mm(mat1, mat2, offsets); - return scaled_out.tv; - }))); - return output; - }, - R"( - Grouped matrix multiplication. - - Performs matrix multiplication on grouped sets of matrices using offsets - to define variable-sized groups. - - Args: - mat1 (Tensor): First set of matrices - mat2 (Tensor): Second set of matrices - offsets (Tensor): Offsets tensor defining group boundaries - - Returns: - Tensor: Result of grouped matrix multiplication - )", - py::arg("mat1"), - py::arg("mat2"), - py::arg("offsets"), - py::return_value_policy::reference); - - nvf_ops.def( - "grouped_mm", - [](FusionDefinition::Operators& self, - Tensor mat1, - Tensor mat2, - Tensor offsets, - Tensor scale1, - Tensor scale2, - std::optional alpha, - std::optional bias, - std::optional beta, - PrimDataType dtype, - int64_t output_block_scale_size, - PrimDataType output_block_scale_dtype, - bool output_gamma) - -> std::tuple, std::optional> { - FUSER_PERF_SCOPE("Operators.grouped_mm"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - - // Calculate output dimensions based on mat1 & mat2 rank - size_t output_dims = mat1.dims == 2 && mat2.dims == 2 ? 3 : 2; - Tensor output = fd->defineTensor(output_dims); - std::optional out_scale = std::nullopt; - std::optional out_gamma = std::nullopt; - if (output_block_scale_size > 0) { - out_scale = fd->defineTensor(output_dims); - } - if (output_gamma) { - // TODO: would out_gamma should be a vector when both inputs are 2d. - out_gamma = fd->defineTensor(0); - } - - fd->defineRecord(new ScaledGroupedMmaOpRecord( - {fd->recordingState(mat1()), - fd->recordingState(mat2()), - fd->recordingState(offsets()), - fd->recordingState(scale1()), - fd->recordingState(scale2()), - alpha.has_value() - ? fd->recordingState(alpha.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - bias.has_value() - ? fd->recordingState(bias.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - beta.has_value() - ? fd->recordingState(beta.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None)}, - {fd->recordingState(output()), - out_scale.has_value() - ? fd->recordingState(out_scale.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - out_gamma.has_value() - ? fd->recordingState(out_gamma.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None)}, - dtype, - output_block_scale_size, - output_block_scale_dtype, - output_gamma)); - - if (output_gamma) { - NVF_CHECK( - output_block_scale_size > 0, - "output_block_scale_size must be greater than 0 when " - "output_gamma is " - "true"); - return std::make_tuple(output, out_scale, out_gamma); - } else if (output_block_scale_size > 0) { - return std::make_tuple(output, out_scale, std::nullopt); - } - return std::make_tuple(output, std::nullopt, std::nullopt); - }, - R"( - Scaled Grouped matrix multiplication. - - Performs matrix multiplication on grouped sets of matrices using offsets - to define variable-sized groups. - - Args: - mat1 (Tensor): First set of matrices - mat2 (Tensor): Second set of matrices - offsets (Tensor): Offsets tensor defining group boundaries - scale1 (Tensor): Scale tensor for mat1 - scale2 (Tensor): Scale tensor for mat2 - alpha (Tensor): Alpha tensor [optional] - bias (Tensor): Bias tensor [optional] - beta (Tensor): Beta tensor [optional] - dtype (ScalarType): Output tensor type [optional] - output_block_scale_size (int): Output block scale size [optional] - output_block_scale_dtype (ScalarType): Output block scale dtype [optional] - output_gamma (bool): Output gamma [optional, default: False] - - The math operation is roughly two steps: - out = alpha * grouped_mm(dequant(mat1, scale1), dequant(mat2, scale2), offsets) + beta * bias - - (out_mat, out_scale, out_gamma) = Quantization( - out, - dtype, - output_block_scale_size, - output_block_scale_dtype, - output_gamma) - - Note 1: The post quantization only applies when output_block_scale_size > 0, - which would produce out_scale tensor. Otherwise, None will be returned; - Note 2: When output_gamma is set to True, it should produce global scaling factor out_gamma tensor. - Otherwise, None will be returned. - - Returns: - Tensor: Result of matrix multiplication - Tensor: Output block scale tensor [optional] - Tensor: Output gamma tensor [optional] - )", - py::arg("mat1"), - py::arg("mat2"), - py::arg("offsets"), - py::arg("scale1"), - py::arg("scale2"), - py::arg("alpha") = std::nullopt, - py::arg("bias") = std::nullopt, - py::arg("beta") = std::nullopt, - py::arg("dtype") = DataType::BFloat16, - py::arg("output_block_scale_size") = 0, - py::arg("output_block_scale_dtype") = DataType::BFloat16, - py::arg("output_gamma") = false, - py::return_value_policy::reference); - - nvf_ops.def( - "scaled_mm", - [](FusionDefinition::Operators& self, - Tensor mat1, - Tensor mat2, - Tensor scale1, - Tensor scale2, - std::optional alpha, - std::optional bias, - std::optional beta, - PrimDataType dtype, - int64_t output_block_scale_size, - PrimDataType output_block_scale_dtype, - bool output_gamma) - -> std::tuple, std::optional> { - FUSER_PERF_SCOPE("Operators.scaled_mm"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - - /* Per https://pytorch.org/docs/stable/generated/torch.matmul.html */ - size_t out_ndims; - if (mat1.dims <= 2 && mat2.dims <= 2) { - out_ndims = mat1.dims + mat2.dims - 2; - } else { - /* batch matmul */ - out_ndims = std::max(mat1.dims, mat2.dims); - } - Tensor output = fd->defineTensor(out_ndims); - // - std::optional out_scale = std::nullopt; - std::optional out_gamma = std::nullopt; - if (output_block_scale_size > 0) { - out_scale = fd->defineTensor(out_ndims); - } - if (output_gamma) { - // out_gamma is a scalar tensor - out_gamma = fd->defineTensor(0); - } - - fd->defineRecord(new ScaledMmaOpRecord( - {fd->recordingState(mat1()), - fd->recordingState(mat2()), - fd->recordingState(scale1()), - fd->recordingState(scale2()), - alpha.has_value() - ? fd->recordingState(alpha.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - bias.has_value() - ? fd->recordingState(bias.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - beta.has_value() - ? fd->recordingState(beta.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None)}, - {fd->recordingState(output()), - out_scale.has_value() - ? fd->recordingState(out_scale.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - out_gamma.has_value() - ? fd->recordingState(out_gamma.value()()) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None)}, - dtype, - output_block_scale_size, - output_block_scale_dtype, - output_gamma)); - - if (output_gamma) { - NVF_CHECK( - output_block_scale_size > 0, - "output_block_scale_size must be greater than 0 when " - "output_gamma is " - "true"); - return std::make_tuple(output, out_scale, out_gamma); - } else if (output_block_scale_size > 0) { - return std::make_tuple(output, out_scale, std::nullopt); - } - return std::make_tuple(output, std::nullopt, std::nullopt); - }, - R"( - Scaled matrix multiplication. - - Args: - mat1 (Tensor): First set of matrices - mat2 (Tensor): Second set of matrices - scale1 (Tensor): Scale tensor for mat1 - scale2 (Tensor): Scale tensor for mat2 - alpha (Tensor): Alpha tensor [optional] - bias (Tensor): Bias tensor [optional] - beta (Tensor): Beta tensor [optional] - dtype (ScalarType): Output tensor type [optional] - output_block_scale_size (int): Output block scale size [optional, default 0] - output_block_scale_dtype (ScalarType): Output block scale dtype [optional] - output_gamma (bool): Output gamma [optional, default: False] - - Note 1: The post quantization only applies when output_block_scale_size > 0, - which would produce out_scale tensor. Otherwise, None will be returned; - Note 2: When output_gamma is set to True, it should produce global scaling factor out_gamma tensor. - Otherwise, None will be returned. - - Returns: - Tensor: Result of grouped matrix multiplication - Tensor: Output block scale tensor [optional] - Tensor: Output gamma tensor [optional] - )", - py::arg("mat1"), - py::arg("mat2"), - py::arg("scale1"), - py::arg("scale2"), - py::arg("alpha") = std::nullopt, - py::arg("bias") = std::nullopt, - py::arg("beta") = std::nullopt, - py::arg("dtype") = DataType::BFloat16, - py::arg("output_block_scale_size") = 0, - py::arg("output_block_scale_dtype") = DataType::BFloat16, - py::arg("output_gamma") = false, - py::return_value_policy::reference); - - nvf_ops.def( - "cutlass_nvfp4_grouped_mm", - [](FusionDefinition::Operators& self, - Tensor mat1, - Tensor mat2, - Tensor scale1, - Tensor scale2, - Tensor alpha, - Tensor problem_sizes, - Tensor expert_offsets, - Tensor sf_offsets, - PrimDataType dtype) -> Tensor { - FUSER_PERF_SCOPE("Operators.cutlass_nvfp4_grouped_mm"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - - Tensor output = fd->defineTensor(3); - - fd->defineRecord(new CutlassNvfp4GroupedMmaOpRecord( - {fd->recordingState(mat1()), - fd->recordingState(mat2()), - fd->recordingState(scale1()), - fd->recordingState(scale2()), - fd->recordingState(alpha()), - fd->recordingState(problem_sizes()), - fd->recordingState(expert_offsets()), - fd->recordingState(sf_offsets())}, - {fd->recordingState(output())}, - dtype)); - - return output; - }, - R"( - Cutlass NVFP4 Grouped Matrix Multiplication. - - Args: - mat1 (Tensor): First set of matrices - mat2 (Tensor): Second set of matrices - scale1 (Tensor): Scale tensor for mat1 - scale2 (Tensor): Scale tensor for mat2 - alpha (Tensor): Alpha tensor - problem_sizes (Tensor): Problem sizes tensor - expert_offsets (Tensor): Expert offsets tensor - sf_offsets (Tensor): SF offsets tensor - dtype (ScalarType): Output tensor type - - Returns: - Tensor: Result of grouped matrix multiplication - )", - py::arg("mat1"), - py::arg("mat2"), - py::arg("scale1"), - py::arg("scale2"), - py::arg("alpha"), - py::arg("problem_sizes"), - py::arg("expert_offsets"), - py::arg("sf_offsets"), - py::arg("dtype") = DataType::BFloat16, - py::return_value_policy::reference); - - nvf_ops.def( - "topk", - [](FusionDefinition::Operators& self, - Tensor arg, - Scalar k, - int64_t dim, - bool largest, - bool sorted) -> py::tuple { - FUSER_PERF_SCOPE("Operators.topk"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - - Tensor values = fd->defineTensor(arg.dims); - Tensor indices = fd->defineTensor(arg.dims); - - fd->defineRecord(new TopKOpRecord( - {fd->recordingState(arg()), fd->recordingState(k())}, - {fd->recordingState(values()), fd->recordingState(indices())}, - dim, - largest, - sorted)); - - return py::make_tuple(values, indices); - }, - R"( - Find the k largest or smallest elements along a dimension. - - Args: - arg (Tensor): Input tensor - k (Scalar): Number of elements to return - dim (int, optional): Dimension along which to find top-k. Defaults to -1. - largest (bool, optional): If True, return largest elements. Defaults to True. - sorted (bool, optional): If True, return elements in sorted order. Defaults to False. - - Returns: - tuple[Tensor, Tensor]: A tuple of (values, indices) where values contains - the k largest/smallest elements and indices contains - their positions in the original tensor. - )", - py::arg("arg"), - py::arg("k"), - py::arg("dim") = -1, - py::arg("largest") = true, - py::arg("sorted") = false, - py::return_value_policy::reference); - - bindSchedule(fusion_def); - - bindMultidevice(nvfuser); -} - -void cleanup() { - auto& c = Communicator::getInstance(); - // In the transition period, both nvfuser and nvfuser_direct may be imported - // and share one Communicator singleton. Without the is_available check, - // each tries to call Communicator::cleanup() at process exit. - if (c.is_available()) { - c.cleanup(); - } -} - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/python_bindings.h b/python/python_frontend/python_bindings.h deleted file mode 100644 index 3bddf48fa92..00000000000 --- a/python/python_frontend/python_bindings.h +++ /dev/null @@ -1,27 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include -#include - -#include -#include - -namespace nvfuser::python_frontend { - -NVF_API void initNvFuserPythonBindings(PyObject* module); - -// Add bindings for multi-GPU capabilities, e.g., DeviceMesh and Communicator. -void bindMultidevice(py::module& nvfuser); - -void bindSchedule(py::class_& fusion_def); - -NVF_API void cleanup(); - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/python_bindings_extension.cpp b/python/python_frontend/python_bindings_extension.cpp deleted file mode 100644 index cafe514e969..00000000000 --- a/python/python_frontend/python_bindings_extension.cpp +++ /dev/null @@ -1,18 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include - -PYBIND11_MODULE(EXTENSION_NAME, m) { - m.doc() = "nvfuser C API python binding"; // optional module docstring - - nvfuser::python_frontend::initNvFuserPythonBindings(m.ptr()); - - auto cleanup = []() -> void { nvfuser::python_frontend::cleanup(); }; - m.add_object("_cleanup", py::capsule(cleanup)); -} diff --git a/python/python_frontend/schedule_bindings.cpp b/python/python_frontend/schedule_bindings.cpp deleted file mode 100644 index 66ff97e5595..00000000000 --- a/python/python_frontend/schedule_bindings.cpp +++ /dev/null @@ -1,517 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace nvfuser::python_frontend { - -void bindSchedule(py::class_& fusion_def) { - //! The SchedOperators class is a nested class of FusionDefinition to allow - //! the user to query the class for the list of schedule operators. - //! - //! Example: - //! help(FusionDefinition.SchedOperators) - //! - //! Additional operators are expected to be defined below as needed. - py::class_ nvf_sched( - fusion_def, "SchedOperators"); - nvf_sched.def(py::init()); - nvf_sched.def( - "to_string", - [](FusionDefinition::SchedOperators& self, Tensor tensor) { - // NOTE: For debugging purposes, print the state of TensorView - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - // Determine if tensor is a result from a reduction operation. - FusionDefinition* fd = self.fusion_definition; - TensorView* tv = - fd->getFusionState(tensor.index)->template as(); - return tv->toString(); - }, - py::arg("tensor")); - nvf_sched.def( - "user_schedule_ir", - [](FusionDefinition::SchedOperators& self) { - return self.fusion_definition->userScheduleIr(); - }, - py::return_value_policy::reference); - //! experimental API for multidevice support - nvf_sched.def( - "_set_device_mesh", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const DeviceMesh& mesh) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto tv = fd->getFusionState(tensor.index)->template as(); - tv->setDeviceMesh(mesh); - }, - py::arg("tensor"), - py::arg("mesh")); - nvf_sched.def( - "parallelize", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - int axis, - const ParallelType& parallel_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto tv = fd->getFusionState(tensor.index)->template as(); - tv->axis(axis)->parallelize(parallel_type); - }, - py::arg("tensor"), - py::arg("axis"), - py::arg("parallel_type")); - nvf_sched.def( - "merge", - [](FusionDefinition::SchedOperators& self, Tensor arg, int dim) { - FUSER_PERF_SCOPE("SchedOperators.merge"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto input_tv = - fd->getFusionState(arg.index)->template as(); - input_tv->merge(dim); - }, - py::arg("arg"), - py::arg("dim")); - auto reduction_factor_func = [](FusionDefinition::SchedOperators& self, - Tensor arg, - const std::vector& dims) -> Tensor { - FUSER_PERF_SCOPE("SchedOperators.reduction_factor"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(arg.index)->template as(); - TensorView* output_tv = input_tv->rFactor(dims); - return fd->addTensor(output_tv); - }; - nvf_sched.def( - "reduction_factor", - reduction_factor_func, - py::arg("arg"), - py::arg("dims")); - nvf_sched.def( - "rfactor", reduction_factor_func, py::arg("arg"), py::arg("dims")); - nvf_sched.def( - "reorder", - [](FusionDefinition::SchedOperators& self, - Tensor arg, - const std::unordered_map& old2new) { - FUSER_PERF_SCOPE("SchedOperators.reorder"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto input_tv = - fd->getFusionState(arg.index)->template as(); - input_tv->reorder(old2new); - }, - py::arg("arg"), - py::arg("old2new")); - nvf_sched.def( - "split", - [](FusionDefinition::SchedOperators& self, - Tensor arg, - int64_t dim, - int64_t factor, - bool inner_split) { - FUSER_PERF_SCOPE("SchedOperators.split"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto input_tv = - fd->getFusionState(arg.index)->template as(); - input_tv->split(dim, factor, inner_split); - }, - py::arg("arg"), - py::arg("dim"), - py::arg("factor"), - py::arg("inner_split") = true); - nvf_sched.def( - "set_allocation_as_loop", - [](FusionDefinition::SchedOperators& self, Tensor arg) { - FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto* tv = fd->getFusionState(arg.index)->template as(); - tv->setAllocationDomain(tv->getLoopDomain(), true); - }, - py::arg("arg")); - nvf_sched.def( - "cache_after", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const LoadStoreOpType& op_type, - const CacheOp& cache_op) -> Tensor { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(tensor.index)->template as(); - TensorView* output_tv = input_tv->cacheAfter(op_type, cache_op); - return fd->addTensor(output_tv); - }, - py::arg("tensor"), - py::arg("op_type") = LoadStoreOpType::Set, - py::arg("cache_op") = CacheOp::Unspecified); - nvf_sched.def( - "cache_before", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const LoadStoreOpType& op_type) -> Tensor { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(tensor.index)->template as(); - TensorView* output_tv = input_tv->cacheBefore(op_type); - return fd->addTensor(output_tv); - }, - py::arg("tensor"), - py::arg("op_type") = LoadStoreOpType::Set); - nvf_sched.def( - "cache_fork", - [](FusionDefinition::SchedOperators& self, Tensor tensor) -> Tensor { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(tensor.index)->template as(); - TensorView* output_tv = input_tv->cacheFork(); - return fd->addTensor(output_tv); - }, - py::arg("tensor")); - nvf_sched.def( - "set_memory_type", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const MemoryType& memory_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* tv = - fd->getFusionState(tensor.index)->template as(); - tv->setMemoryType(memory_type); - }, - py::arg("tensor"), - py::arg("memory_type")); - nvf_sched.def( - "transform_like", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const std::vector& selected_tensors) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - TensorView* reference_tv = - fd->getFusionState(tensor.index)->template as(); - - TransformPropagator propagator(reference_tv); - if (selected_tensors.empty()) { - // Propagate scheduler transformations on reference TensorView to the - // rest of the fusion. - MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator); - } else { - // Propagate scheduler transformations on reference TensorView to the - // subset of the fusion. - std::unordered_set selected_tv_set; - selected_tv_set.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::inserter(selected_tv_set, selected_tv_set.end()), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - SetSelector selector( - {selected_tv_set.begin(), selected_tv_set.end()}); - MaxLogicalDomainInfoSpanningTree(reference_tv, &selector) - .traverse(&propagator); - } - }, - py::arg("tensor"), - py::arg("selected_tensors") = std::vector()); - nvf_sched.def( - "parallelize_like", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - int64_t pos, - const std::vector& selected_tensors, - const std::unordered_set& selected_parallel_types, - bool propagate_padding) { - // Propagate the parallelization from the selected dimensions of the - // reference tensor to their corresponding dimensions in all selected - // tensors in the DAG. - // - // 1. Position `pos` means selecting all the dimensions - // [0, 1, ..., pos - 1]. pos = -1 means selecting all dimensions. - // 2. `selected_tvs` are selected tensors in the DAG. Empty - // `selected_tvs` means selecting all tensors in the fusion of - // `reference_tv`. - // 3. `selected_parallel_types` are the selected parallel types. Empty - // `selected_parallel_types` means selecting all parallel types. - - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - TensorView* reference_tv = - fd->getFusionState(tensor.index)->template as(); - - std::vector selected_tvs; - selected_tvs.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::back_inserter(selected_tvs), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - - nvfuser::scheduler_utils::parallelizeAllLike( - reference_tv, - pos, - selected_tvs, - selected_parallel_types, - propagate_padding); - }, - py::arg("tensor"), - py::arg("pos") = -1, - py::arg("selected_tensors") = std::vector(), - py::arg("selected_parallel_types") = std::unordered_set(), - py::arg("propagate_padding") = true); - nvf_sched.def( - "inline_most", - [](FusionDefinition::SchedOperators& self, - const std::vector& selected_tensors) { - // Inline to the right most allowed position for the selected tensors in - // the current fusion. - - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - - if (selected_tensors.empty()) { - nvfuser::inlineMost(); - } else { - std::vector selected_tvs; - selected_tvs.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::back_inserter(selected_tvs), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - nvfuser::inlineMost(selected_tvs); - } - }, - py::arg("selected_tensors") = std::vector()); - nvf_sched.def( - "inline_at", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - int64_t pos, - bool best_effort, - const std::vector& selected_tensors) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - TensorView* reference_tv = - fd->getFusionState(tensor.index)->template as(); - - if (selected_tensors.empty()) { - // Inline to the position corresponding to the reference position in - // the reference tensor for all tensors in the current fusion. - nvfuser::inlineAllAt(reference_tv, pos, best_effort); - } else { - // Inline to the position corresponding to the reference position in - // the reference tensor for selected tensors in the current fusion. - std::unordered_set selected_tvs; - selected_tvs.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::inserter(selected_tvs, selected_tvs.end()), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - - nvfuser::inlineSelectedAt( - selected_tvs, reference_tv, pos, best_effort); - } - }, - py::arg("tensor"), - py::arg("pos") = -1, - py::arg("best_effort") = false, - py::arg("selected_tensors") = std::vector()); - nvf_sched.def("tensors", [](FusionDefinition::SchedOperators& self) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - // Return all Tensors in FusionDefinition - return self.fusion_definition->tensors(); - }); - nvf_sched.def( - "is_reduction", - [](FusionDefinition::SchedOperators& self, Tensor tensor) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - // Determine if tensor is a result from a reduction operation. - FusionDefinition* fd = self.fusion_definition; - TensorView* tv = - fd->getFusionState(tensor.index)->template as(); - return ( - !tv->isFusionInput() && - std::any_of( - tv->getMaybeRootDomain().begin(), - tv->getMaybeRootDomain().end(), - [](IterDomain* id) { return id->isReduction(); }) && - !isResharding(tv->definition())); - }, - py::arg("tensor")); - nvf_sched.def( - "can_schedule", - [](FusionDefinition::SchedOperators& self, - const SchedulerType& scheduler_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - return self.fusion_definition->userSchedule()->canScheduleDebug( - scheduler_type); - }, - py::arg("scheduler_type")); - nvf_sched.def( - "find_compatible_schedulers", [](FusionDefinition::SchedOperators& self) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - std::vector valid_scheduler_types; - valid_scheduler_types.reserve(all_heuristics_in_priority_order.size()); - std::copy_if( - all_heuristics_in_priority_order.begin(), - all_heuristics_in_priority_order.end(), - std::back_inserter(valid_scheduler_types), - [sched = self.fusion_definition->userSchedule()]( - SchedulerType scheduler_type) { - return sched->canSchedule(scheduler_type); - }); - return valid_scheduler_types; - }); - nvf_sched.def( - "schedule", - [](FusionDefinition::SchedOperators& self, - const SchedulerType& scheduler_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - auto&& [can_schedule, error_msg] = - sched->canScheduleDebug(scheduler_type); - NVF_CHECK(can_schedule, error_msg); - sched->scheduleWithType(scheduler_type); - }, - py::arg("heuristic")); - nvf_sched.def("schedule", [](FusionDefinition::SchedOperators& self) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - sched->schedule(); - }); - nvf_sched.def( - "compute_pointwise_heuristics", - [](FusionDefinition::SchedOperators& self) -> PointwiseParams& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - HeuristicParams* parameters = - sched->computeHeuristics(SchedulerType::PointWise); - return *parameters->as(); - }, - py::return_value_policy::reference); - nvf_sched.def( - "compute_reduction_heuristics", - [](FusionDefinition::SchedOperators& self) -> ReductionParams& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - HeuristicParams* parameters = - sched->computeHeuristics(SchedulerType::Reduction); - return *parameters->as(); - }, - py::return_value_policy::reference); - nvf_sched.def( - "compute_matmul_heuristics", - [](FusionDefinition::SchedOperators& self) -> MatmulParams& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - HeuristicParams* parameters = - sched->computeHeuristics(SchedulerType::Matmul); - return *parameters->as(); - }, - py::return_value_policy::reference); - nvf_sched.def( - "schedule_hyperparameters", - [](FusionDefinition::SchedOperators& self) - -> scheduler_utils::SchedulerHyperParameters& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - if (!sched->scheduler_hyperparams) { - sched->scheduler_hyperparams = - std::make_unique( - /*vectorize_factor=*/1, - /*unroll_factor=*/1, - /*threads_per_block_min=*/1, - /*threads_per_block_max=*/1, - /*is_warp_specialized=*/false); - } - return *sched->scheduler_hyperparams; - }, - py::return_value_policy::reference); -} - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/segmentation.cpp b/python/python_frontend/segmentation.cpp deleted file mode 100644 index 51f5a192a29..00000000000 --- a/python/python_frontend/segmentation.cpp +++ /dev/null @@ -1,369 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include - -namespace nvfuser::python_frontend { - -int64_t SegmentationState::setupSegmentation( - Fusion* fusion, - const std::unordered_map& - map_presched_value_to_original_python_index, - const KernelArgumentHolder& args) { - // Check state - NVF_ERROR(fusion != nullptr); - NVF_ERROR(cloned_original_fusion_ == nullptr); - NVF_ERROR(segmented_fusion_ == nullptr); - NVF_ERROR(group_run_order_.empty()); - NVF_ERROR(map_cloned_concretized_value_to_original_python_index_.empty()); - NVF_ERROR(cloned_original_extents_.empty()); - - // Step 1) Clone preschedFusion CPP Fusion. - cloned_original_fusion_ = std::make_unique(); - - // The IRCloner returned by Fusion::copy acts as map from the original fusion - // to the cloned fusion. - IrCloner original_to_cloned_map = - Fusion::copy(fusion, cloned_original_fusion_.get()); - - // Step 2) Given the map_presched_value_to_original_python_index AND the - // IRCloner returned by Fusion::copy, create a mapping from cloned CPP values - // to original fusion state indices. - std::unordered_map map_cloned_value_to_original_python_index; - map_cloned_value_to_original_python_index.reserve( - map_presched_value_to_original_python_index.size()); - std::transform( - map_presched_value_to_original_python_index.begin(), - map_presched_value_to_original_python_index.end(), - std::inserter( - map_cloned_value_to_original_python_index, - map_cloned_value_to_original_python_index.end()), - [&](const auto& item) { - const Val* original_value = item.first; - int64_t python_index = item.second; - Val* cloned_value = original_to_cloned_map.clone(original_value); - return std::make_pair(cloned_value, python_index); - }); - - // Step 3) Concretize fusion with input arguments. - std::unordered_map symbolic_to_concrete_map = - DynamicTransform::concretizeFusion(cloned_original_fusion_.get(), args); - - // Given the map_cloned_value_to_original_python_index AND the - // symbolic_to_concrete map returned by the concretization pass, create a - // mapping from cloned, concretized CPP values to original fusion state - // indices. - map_cloned_concretized_value_to_original_python_index_.reserve( - map_cloned_value_to_original_python_index.size()); - std::transform( - map_cloned_value_to_original_python_index.begin(), - map_cloned_value_to_original_python_index.end(), - std::inserter( - map_cloned_concretized_value_to_original_python_index_, - map_cloned_concretized_value_to_original_python_index_.end()), - [&](const auto& item) { - Val* maybe_concretized_value = item.first; - int64_t python_index = item.second; - if (symbolic_to_concrete_map.count(maybe_concretized_value) > 0) { - maybe_concretized_value = - symbolic_to_concrete_map.at(maybe_concretized_value); - } - return std::make_pair(maybe_concretized_value, python_index); - }); - - // Track the extents for input TensorViews in cloned CPP Fusion. - cloned_original_extents_ = getExtents(cloned_original_fusion_.get()); - - // Create runtime infomation - SchedulerRuntimeInfo runtime_info( - cloned_original_fusion_.get(), - args, - /*precomputed_values=*/nullptr, - cloned_original_fusion_->allTvs()); - - // Run segmentation algorithm - segmented_fusion_ = SegmentCandidateFinder::segment( - std::move(cloned_original_fusion_), args, runtime_info); - - // Get the order for fusion segments - prepareGroupOrder(); - - // Return the number of segments created by segmentation algorithm. - return (int64_t)segmented_fusion_->groups().size(); -} - -// setupSegmentation transforms the Prescheduled, Symbolic Fusion to Cloned, -// Concretized Fusion. Both CPP fusions corresponds with Original -// FusionDefinition. -// -// The segmentation pass runs on cloned, concretized fusion to create -// SegmentedFusion. Each SegmentedGroup in the SegmentedFusion creates a segment -// CPP fusion that is translated to a python definition. -// -// -// NOTE: Steps 4a through 4d are run for every fusion segment. However, -// sometimes the python definition needs the extents of the original fusion's -// input tensors as extra arguments. Steps 4f to 4l creates mappings for these -// missing extents. -// -// Details: -// 1) Use segment id to get SegmentedGroup from group_run_order_. -// 2) Create CPP Fusion for SegmentedGroup. -// * IrCloner acts as a map from fusion segment to the original fusion. -// 3) Translate CPP Fusion to Python FusionDefinition -// 4) Create map from segment fusion indices to original fusion indices. -// a) Get cloned Vals for SegmentedGroup's inputs and outputs. Map them -// to their original fusion indices. -// b) Map cloned Vals to their segment Vals -// c) Map segment Vals to their fusion indices. -// d) Map original indices to segment indices. -// e) Return map if the number of input arguments for python definition -// matches the number of input arguments for CPP fusion. -// f) Create a map from segment to cloned extents. -// g) Create a map from segment fusion indices to cloned extents. -// h) Find segment inputs that are missing from segment to original -// indices map. -// i) Get segment Vals for the missing segment fusion indices. -// j) Map segment Vals to cloned Vals. -// k) Map cloned Vals to their corresponding fusion indices. -// l) Add missing mappings to segment to original indices map. -// 5) Return the mapping from the segmented FusionDefinition index space to -// original FusionDefinition index space. -std::unordered_map SegmentationState::buildSegment( - FusionDefinition& segment_fd, - int64_t segment_id) { - NVF_ERROR( - !segment_fd.completed(), - "Expected an incomplete definition before translation."); - NVF_ERROR( - segmented_fusion_ != nullptr, - "SegmentedFusion is not initialized. Run setupSegmentation first."); - NVF_ERROR( - segment_id >= 0 && - segment_id < (int64_t)segmented_fusion_->groups().size(), - "The segment id is not valid"); - - // Step 1) Use segment id to get SegmentedGroup from group_run_order_. - SegmentedGroup* sg = group_run_order_.at(segment_id); - NVF_ERROR(sg != nullptr); - - // Step 2) Create CPP Fusion for SegmentedGroup. The IrCloner acts as a map - // from fusion segment to the original fusion. - std::pair> cloner_segment_pair = - segmented_fusion_->makeFusion(sg); - IrCloner cloned_to_segment_map = cloner_segment_pair.first; - std::unique_ptr fusion_segment = - std::move(cloner_segment_pair.second); - - // Step 3) Translate CPP Fusion to Python FusionDefinition - std::unordered_map - map_segment_cpp_value_to_python_index = - translate(fusion_segment.get(), &segment_fd); - - // Step 4) Create map from segment fusion indices to original fusion indices. - // Step 4a) Get FusionDefinition index for cloned inputs and outputs. Map them - // to their original fusion indices. - const std::vector& cloned_inputs = sg->inputs(); - const std::vector& cloned_outputs = sg->outputs(); - - std::vector original_python_index; - original_python_index.reserve(cloned_inputs.size() + cloned_outputs.size()); - - std::transform( - cloned_inputs.begin(), - cloned_inputs.end(), - std::back_inserter(original_python_index), - [&](Val* v) { - return map_cloned_concretized_value_to_original_python_index_.at(v); - }); - - std::transform( - cloned_outputs.begin(), - cloned_outputs.end(), - std::back_inserter(original_python_index), - [&](Val* v) { - return map_cloned_concretized_value_to_original_python_index_.at(v); - }); - - // Step 4b) ir_cloner maps cloned fusion Vals to segment Vals. - std::vector segment_inputs_outputs; - segment_inputs_outputs.reserve(cloned_inputs.size() + cloned_outputs.size()); - - std::transform( - cloned_inputs.begin(), - cloned_inputs.end(), - std::back_inserter(segment_inputs_outputs), - [&](Val* v) { return cloned_to_segment_map.clone(v); }); - - std::transform( - cloned_outputs.begin(), - cloned_outputs.end(), - std::back_inserter(segment_inputs_outputs), - [&](Val* v) { return cloned_to_segment_map.clone(v); }); - - // Step 4c) Map segment Vals to their FusionDefinition index. - std::vector segment_python_index; - segment_python_index.reserve(segment_inputs_outputs.size()); - std::transform( - segment_inputs_outputs.begin(), - segment_inputs_outputs.end(), - std::back_inserter(segment_python_index), - [&](Val* v) { return map_segment_cpp_value_to_python_index.at(v); }); - - // Step 4d) Map original indices to segment indices. - NVF_ERROR(original_python_index.size() == segment_python_index.size()); - std::unordered_map segment_to_original_python_index_map; - for (size_t idx : arange(original_python_index.size())) { - segment_to_original_python_index_map.emplace( - segment_python_index.at(idx), original_python_index.at(idx)); - } - - // Step 4e) short-circuit: No extra extents required for python definition. - if (fusion_segment->inputs().size() == segment_fd.inputs().size()) { - return segment_to_original_python_index_map; - } - - // The python segment can require the size of tensor dimensions from original - // fusion's input arguments, which the CPP segment does not. - - // Step 4f) Create a map from segment to cloned extents. - // Step 4g) Create a map from segment indices to segment extents. - std::unordered_map segment_to_cloned_extents; - std::unordered_map segment_python_index_to_cpp_val; - for (Val* cloned_extent : cloned_original_extents_) { - Val* segment_extent = cloned_to_segment_map.clone(cloned_extent); - - // short-circuit: some extents are not used in segment - if (map_segment_cpp_value_to_python_index.count(segment_extent) == 0) { - continue; - } - - size_t segment_python_index = - map_segment_cpp_value_to_python_index.at(segment_extent); - segment_to_cloned_extents.emplace(segment_extent, cloned_extent); - segment_python_index_to_cpp_val.emplace( - segment_python_index, segment_extent); - } - - // Step 4h) Find the set difference between all segment input indices and - // known input segment indices. - std::vector missing_segment_python_index; - for (int64_t input_python_index : segment_fd.inputs()) { - if (segment_to_original_python_index_map.count(input_python_index) == 0) { - missing_segment_python_index.push_back(input_python_index); - } - } - - // Step 4i) Get segment Val for missing segment input indices. - std::vector missing_segment_val; - missing_segment_val.reserve(missing_segment_python_index.size()); - std::transform( - missing_segment_python_index.begin(), - missing_segment_python_index.end(), - std::back_inserter(missing_segment_val), - [&](int64_t segment_python_index) { - return segment_python_index_to_cpp_val.at(segment_python_index); - }); - - // Step 4j) Map segment Vals to cloned Vals - std::vector missing_cloned_val; - missing_cloned_val.reserve(missing_segment_val.size()); - std::transform( - missing_segment_val.begin(), - missing_segment_val.end(), - std::back_inserter(missing_cloned_val), - [&](Val* segment_val) { - return segment_to_cloned_extents.at(segment_val); - }); - - // Step 4k) Transform cloned Vals to their original fusion indices. - std::vector missing_cloned_python_index; - missing_cloned_python_index.reserve(missing_cloned_val.size()); - std::transform( - missing_cloned_val.begin(), - missing_cloned_val.end(), - std::back_inserter(missing_cloned_python_index), - [&](Val* cloned_val) { - return map_cloned_concretized_value_to_original_python_index_.at( - cloned_val); - }); - - // Step 4l) Add missing mappings from segment to original indices. - for (size_t idx : arange(missing_segment_python_index.size())) { - segment_to_original_python_index_map.emplace( - missing_segment_python_index.at(idx), - missing_cloned_python_index.at(idx)); - } - - // Return the mapping from the index space of segment FusionDefinition to the - // index space of the original FusionDefinition. - return segment_to_original_python_index_map; -} - -void SegmentationState::prepareGroupOrder() { - NVF_ERROR(segmented_fusion_ != nullptr); - - // Gather initial inputs for SegmentedFusion. - std::unordered_set available_input( - segmented_fusion_->inputs().begin(), segmented_fusion_->inputs().end()); - - // The size of the tensor dimensions can be used as an input of the segments. - // NvFuser does not support returning scalar values. Segmentation must pass - // those sizes as segment arguments manually. - std::vector extents = getExtents(segmented_fusion_->completeFusion()); - std::copy( - extents.begin(), - extents.end(), - std::inserter(available_input, available_input.end())); - - // Track the run status of all SegmentedGroups in SegmentedFusion - std::vector group_ran(segmented_fusion_->groups().size(), false); - - // While not all the SegmentedGroups are run: - while (!std::all_of( - group_ran.begin(), group_ran.end(), [](bool b) { return b; })) { - bool ran_any_group = false; - - // Find the first segment with all inputs available to run - for (size_t group_i : arange(segmented_fusion_->groups().size())) { - SegmentedGroup* group = segmented_fusion_->groups().at(group_i); - - // short-circuit: Already ran this segmented group. - if (group_ran.at(group_i)) { - continue; - } - - const std::vector& group_inputs = group->inputs(); - bool ready_to_run = std::all_of( - group_inputs.begin(), - group_inputs.end(), - [&available_input](Val* val) { return available_input.count(val); }); - - // short-circuit: This segmented group is not ready to run. - if (!ready_to_run) { - continue; - } - - // Add SegmentedGroup to group_run_order_. - group_run_order_.push_back(group); - - // Mark all outputs of SegmentedGroup as ready. - const std::vector& group_outputs = group->outputs(); - for (size_t group_out_i : arange(group_outputs.size())) { - available_input.insert(group_outputs.at(group_out_i)); - } - group_ran[group_i] = true; - ran_any_group = true; - } - NVF_ERROR( - ran_any_group, - "Failed to run any group; An error must have occured in segmentation."); - } -} - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/segmentation.h b/python/python_frontend/segmentation.h deleted file mode 100644 index 22c71aa551d..00000000000 --- a/python/python_frontend/segmentation.h +++ /dev/null @@ -1,246 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include -#include -#include - -namespace nvfuser::python_frontend { - -class FusionDefinition; - -//! Overview: -//! Segmentation decomposes a fusion into a directed acyclic graph (DAG) of -//! sub-fusions. After applying the segmentation algorithm, we can translate -//! the sub-fusions into their corresponding python definitions. Then, given the -//! fusion's input arguments, the segments are run in the correct order to -//! produce the output results. -//! -//! Each FusionDefinition contains a set of states representing tensors, vectors -//! and scalars. Every state has a unique index, which matches the insertion -//! order of the state in the FusionDefinition. These indices form a linear -//! index space for each FusionDefinition. -//! -//! The original FusionDefinition stores the sequence of sub-fusions and acts as -//! an argument manager. It gathers the input arguments before running the -//! sub-fusion and stores its results. To perform this function, it requires a -//! map from the segment index space to the original index space. This mapping -//! is generated while creating the python definition for each sub-fusion. -//! -//! Algorithm: -//! Step 1: setupSegmentation runs the segmentation algorithm on the CPP Fusion -//! to create the SegmentedFusion. Then, sub-fusions are ordered according to -//! their dependencies by the prepareGroupOrder function. It returns the number -//! of segments in SegmentedFusion. -//! -//! Step 2: buildSegment creates the CPP Fusion for a given segment id, -//! translates it to a python FusionDefinition, then returns a mapping from the -//! segment fusion state indices to the original fusion state indices. -//! -//! =========================================================================== -//! -//! Example 1: A simple fusion with two iota operations. -//! -//! Original Fusion: -//! def nvfuser_fusion_id1(fd : FusionDefinition) -> None : -//! S0 = fd.define_scalar(2, dtype=DataType.Int) -//! S1 = fd.define_scalar(0, dtype=DataType.Int) -//! S2 = fd.define_scalar(2, dtype=DataType.Int) -//! T3 = fd.ops.iota(S0, S1, S2, dtype=DataType.Int) -//! S4 = fd.define_scalar(3, dtype=DataType.Int) -//! S5 = fd.define_scalar(100, dtype=DataType.Int32) -//! S6 = fd.define_scalar(1, dtype=DataType.Int32) -//! T7 = fd.ops.iota(S4, S5, S6, dtype=DataType.Int32) -//! fd.add_output(T3) -//! fd.add_output(T7) -//! -//! After Segmentation: -//! The original fusion is divided into two segments. There is no dependencies -//! between either segment so they can run in any order. -//! -//! First Segment: -//! def nvfuser_fusion_id2(fd : FusionDefinition) -> None : -//! S0 = fd.define_scalar(2, dtype=DataType.Int) -//! S1 = fd.define_scalar(0, dtype=DataType.Int) -//! S2 = fd.define_scalar(2, dtype=DataType.Int) -//! T3 = fd.ops.iota(S0, S1, S2, dtype=DataType.Int) -//! fd.add_output(T3) -//! -//! Second Segment: -//! def nvfuser_fusion_id3(fd : FusionDefinition) -> None : -//! S0 = fd.define_scalar(3, dtype=DataType.Int) -//! S1 = fd.define_scalar(100, dtype=DataType.Int32) -//! S2 = fd.define_scalar(1, dtype=DataType.Int32) -//! T3 = fd.ops.iota(S0, S1, S2, dtype=DataType.Int32) -//! fd.add_output(T3) -//! -//! The first segment corresponds with [S0, S1, S2, T3] in the original fusion. -//! The second segment corresponds with [S4, S5, S6, S7] in the original fusion. -//! -//! Neither segment requires any input arguments from the original fusion. -//! -//! For the first segment, the segment's T3 is mapped to the original's T3. -//! Segment Index : Original Index Mapping -//! -------------------------------------- -//! T3 : T3 -//! -//! For the second segment the segment's T3 is mapped to the original's T7. -//! Segment Index : Original Index Mapping -//! -------------------------------------- -//! T3 : T7 -//! -//! =========================================================================== -//! -//! Example 2: A reduction + broadcast + pointwise fusion. -//! -//! Original Fusion: -//! def nvfuser_fusion_id1(fd : FusionDefinition) -> None : -//! T0 = fd.define_tensor(shape=[-1, -1], -//! contiguity=[True, True], -//! dtype=DataType.Float, -//! is_cpu=False) -//! T1 = fd.define_tensor(shape=[-1, -1], -//! contiguity=[True, True], -//! dtype=DataType.Float, -//! is_cpu=False) -//! T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float) -//! T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True]) -//! T4 = fd.ops.add(T1, T3) -//! fd.add_output(T4) -//! -//! After Segmentation: -//! The reduction scheduler does not support fusing any operations with an -//! inner reduction, so the original fusion is divided into two segments. -//! Segment 2 depends on Segment 1, so there is a strict segment ordering. -//! -//! First Segment: -//! def nvfuser_fusion_id2(fd : FusionDefinition) -> None : -//! T0 = fd.define_tensor(shape=[-1, -1], -//! contiguity=[True, True], -//! dtype=DataType.Float, -//! is_cpu=False) -//! T1 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float) -//! T2 = fd.ops.broadcast(T1, is_broadcast_dim=[False, True]) -//! fd.add_output(T2) -//! -//! Second Segment: -//! def nvfuser_fusion_id3(fd : FusionDefinition) -> None : -//! T0 = fd.define_tensor(shape=[-1, -1], -//! contiguity=[True, True], -//! dtype=DataType.Float, -//! is_cpu=False) -//! T1 = fd.define_tensor(shape=[-1, 1], -//! contiguity=[True, None], -//! dtype=DataType.Float, -//! is_cpu=False) -//! T2 = fd.ops.add(T0, T1) -//! fd.add_output(T2) -//! -//! The first segment contains the reduction and broadcast operations, which -//! corresponds with [T0, T2, T3] in the original fusion. Therefore, the segment -//! index to original index map has two entries. -//! -//! Segment Index : Original Index Mapping -//! -------------------------------------- -//! T0 : T0 --- The first tensor argument for the original fusion. -//! T2 : T3 --- The broadcasted, reduction tensor is this segment's output. -//! -//! The second segment is the pointwise addition with the broadcasted reduction. -//! It corresponds with [T1, T3, T4] in the original fusion. -//! -//! Segment Index : Original Index Mapping -//! -------------------------------------- -//! T0 : T1 --- The second tensor argument for the original fusion. -//! T1 : T3 --- The broadcasted, reduction tensor, which is the output from the -//! first segment. -//! T2 : T4 --- The pointwise addition, which is the output for the original -//! fusion. -//! =========================================================================== -class SegmentationState { - public: - // setupSegmentation runs the segmentation algorithm on CPP Fusion to create - // SegmentedFusion. It returns the number of segments in SegmentedFusion. - // - // Details: - // 1) Clone preschedFusion CPP Fusion. - // 2) Concretize fusion using input arguments. - // 3) Given the map_presched_value_to_original_python_index, the IRCloner - // returned by Fusion::copy, AND symbolic_to_concrete map returned by - // concretization pass, create a mapping from cloned Vals to original fusion - // state indices. - // 4) Get extents for cloned fusion. - // 5) Create SchedulerRuntimeInfo. - // 6) Run segmentation algorithm using cloned fusion, input arguments, and - // scheduler runtime information. - // 7) Get sequential order of fusion segments using prepareGroupOrder. - // 8) Return the number of segments created by segmentation algorithm. - int64_t setupSegmentation( - Fusion* fusion, - const std::unordered_map& - map_presched_value_to_original_python_index, - const KernelArgumentHolder& inputs); - - // Given an empty FusionDefinition and a segment id, buildSegment creates the - // CPP Fusion, translates it to the python FusionDefinition, then returns a - // mapping from segment fusion state indices to the original fusion state - // indices. - // - // The mapping is constructed from the segment's python definition -> - // segment's CPP Fusion -> original's CPP Fusion -> original's python - // definition. - // - // NOTE: Sometimes the python definition requires the extents from the - // original fusion's input tensors as extra arguments. Therefore, the input - // arguments for the python definition and the CPP Fusion may not exactly - // match. - NVF_API std::unordered_map buildSegment( - FusionDefinition& segment_fd, - int64_t segment_id); - - private: - // prepareGroupOrder is similar to prepareRuntimeOrder. It generates the - // topological order of SegmentedGroups in SegmentedFusion. - // - // Details: - // 1) Gather initial inputs for SegmentedFusion. - // 2) Gather IterDomain extents from the tensor input arguments. - // 3) Track the run status of all SegmentedGroups in SegmentedFusion - // 4) While not all the SegmentedGroups are run: - // 5) For each SegmentedGroup: - // 6) Skip SegmentedGroup if it is already run - // 7) Skip SegmentedGroup if inputs are not ready - // 8) Add SegmentedGroup to group_run_order_. Mark all outputs of - // SegmentedGroup as ready. - // 9) End For - // 10) Fail if none of the SegmentedGroups are available to run. - // 11) End While - void prepareGroupOrder(); - - private: - // Clone of original fusion for segmentation - std::unique_ptr cloned_original_fusion_ = nullptr; - - // This FusionDefinition may require multiple kernels if it cannot be handled - // by a single heuristic scheduler. SegmentedFusion takes a fusion and runs - // the segmentation algorithm. - std::unique_ptr segmented_fusion_ = nullptr; - - // Pre-determined order to run the segmented groups - std::vector group_run_order_; - - // Map values from cloned, concretized fusion to the indices of the original - // python definition. - std::unordered_map - map_cloned_concretized_value_to_original_python_index_; - - // Extents for TensorView input arguments for cloned Fusion - std::vector cloned_original_extents_; -}; - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/translation.cpp b/python/python_frontend/translation.cpp deleted file mode 100644 index c6f75871fe2..00000000000 --- a/python/python_frontend/translation.cpp +++ /dev/null @@ -1,1484 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -// -#include -#include -#include -#include -#include -#include -#include "base.h" - -#include - -namespace nvfuser::python_frontend { - -namespace { - -// Given a CPP Fusion and an empty python_frontend FusionDefinition -// FusionTranslator adds the appropriate RecordFunctors corresponding to the -// CPP values and expressions. -// -// Rather than create a new FusionDefinition from the CPP Fusion, we add -// RecordFunctors to a blank FusionDefinition. This is a design decision because -// of the FusionDefinition python class, which inherits from the -// _C._FusionDefinition class created by pybind11. It is easier to operate on -// the child class directly than to create a new child instance from parent -// instance. -// -// How to add support for an expression not yet overriden by FusionTranslator? -// 1. Create handle function for expression. -// a. void handle(const SomeOp* op) final -// -// 2. Add RecordFunctor corresponding to Statement to FusionDefinition. -// a. fd_->defineRecord(new RecordFunctor(inputs, outputs) -// -// 3. If input argument already exists in FusionDefinition, map expressions -// input values to FusionDefinition State. -// a. map_val_to_fd_index_ maps CPP Val to fusion definition index. -// b. fd_->recordingState(map_val_to_fd_index_.at(op->inputs(...))) -// -// 4. If input argument is a vector, use createVector function. -// -// 5. If input argument is a scalar constant, use createScalar function. -// -// 6. Create output states expressions inputs. -// a. Tensor output = fd_->defineTensor(v->as()->nDims()) -// -// 7. Add CPP Val and output state pair to map_val_to_fd_index_. -class FusionTranslator : public OptInConstDispatch { - public: - // Returns a map from the values in the CPP fusion to its corresponding - // FusionDefinition State index. - // - // Why? - // For segmentation, we divide the original FusionDefinition into its - // segments. Each segment has a separate index namespace. To run a segment, - // we need to pass outputs from prior segments as this segment's input - // arguments. The original FusionDefinition coordinates this argument passing. - // The map returned by this function is used to a global mapping from the - // original FusionDefinition's indicies to this segment's indicies. - static std::unordered_map translate( - Fusion* fusion, - FusionDefinition* fd) { - NVF_ERROR( - !fd->completed(), - "Expected an incomplete definition before fusion translation!"); - FusionTranslator translator(fusion, fd); - translator.translate(); - return translator.map_val_to_fd_index_; - } - - private: - FusionTranslator(Fusion* fusion, FusionDefinition* fd) - : fusion_(fusion), fd_(fd) {} - - bool isScheduledTensorView(TensorView* tv) const { - NVF_ERROR(tv != nullptr); - const std::vector& logical = tv->domain()->logical(); - const std::vector& loop = tv->domain()->loop(); - // short-circuit: check same length - if (logical.size() != loop.size()) { - return true; - } - - if (tv->definition() != nullptr && !tv->definition()->isA()) { - for (size_t idx : arange(logical.size())) { - if (logical.at(idx) != loop.at(idx)) { - return true; - } - } - } - return false; - } - - // The new shape for view operation can be dynamic. Check that all dynamic - // scalar dependencies are handled before the ReshapeOp. - bool checkViewShapeDependency(const ReshapeOp* vop) { - const std::vector& logical_out_domain = - vop->out()->as()->domain()->logical(); - std::vector logical_domain_extents; - std::transform( - logical_out_domain.begin(), - logical_out_domain.end(), - std::back_inserter(logical_domain_extents), - [](IterDomain* id) { return id->getMaybeExpandedExtent(); }); - return std::all_of( - logical_domain_extents.begin(), - logical_domain_extents.end(), - [&](Val* v) { - return v->definition() == nullptr || - map_val_to_fd_index_.count(v) > 0; - }); - } - - // Gather the expressions necessary to create a scalar value. - std::vector gatherScalarExpressions(Val* v) { - NVF_ERROR(v != nullptr); - NVF_ERROR(v->isScalar()); - - // short-circuit: v does not have a definition. - if (v->definition() == nullptr) { - return {}; - } - - std::vector expression_chain; - std::unordered_set visited; - std::vector to_visit = {v->definition()}; - while (!to_visit.empty()) { - Expr* e = to_visit.back(); - to_visit.pop_back(); - - expression_chain.push_back(e); - visited.insert(e); - - for (Val* input : e->inputs()) { - // short-circuit: input does not have a definition. - if (input->definition() == nullptr) { - continue; - } - - // short-circuit: input definition is already visited. - if (visited.count(input->definition()) > 0) { - continue; - } - - to_visit.push_back(input->definition()); - } - } - return expression_chain; - } - - // Gather the scalar expressions necessary to create the logical domain for a - // TensorView. - std::vector gatherScalarExpressions(TensorView* tv) { - NVF_ERROR(tv != nullptr); - std::vector logical_domain_expressions; - const std::vector& logical_out_domain = - tv->domain()->logical(); - for (IterDomain* id : logical_out_domain) { - std::vector extent_definitions = - gatherScalarExpressions(id->getMaybeExpandedExtent()); - logical_domain_expressions.insert( - logical_domain_expressions.end(), - extent_definitions.begin(), - extent_definitions.end()); - } - return logical_domain_expressions; - } - - // Check that all of the expression's inputs are defined in FusionDefinition. - bool checkExpressionDependencies(Expr* e) { - bool check_view_dependency = - !e->isA() || checkViewShapeDependency(e->as()); - return check_view_dependency && - std::all_of(e->inputs().begin(), e->inputs().end(), [&](const Val* v) { - return map_val_to_fd_index_.count(v) > 0; - }); - } - - void translate() { - fd_->setupDefinition(); - - // Add Fusion inputs to FusionDefinition - for (nvfuser::Val* v : fusion_->inputs()) { - dispatch(v); - } - - // Gather all expressions in CPP Fusion. - const std::vector fusion_exprs = fusion_->exprs(); - std::deque to_visit( - fusion_exprs.begin(), fusion_exprs.end()); - - // Scalar expressions are not handled by Fusion::exprs, so gather them - // manually. - for (Expr* e : to_visit) { - if (e->isA() || e->isA() || e->isA()) { - std::vector extent_definitions = - gatherScalarExpressions(e->output(0)->as()); - to_visit.insert( - to_visit.end(), - extent_definitions.begin(), - extent_definitions.end()); - } - } - - // Topological search of Fusion expressions - size_t skip_count = 0; - std::unordered_set visited; - while (!to_visit.empty()) { - Expr* e = to_visit.front(); - to_visit.pop_front(); - - NVF_ERROR( - skip_count <= to_visit.size(), - "Cycle detected: None of the expressions can be processed!"); - - // short-circuit: skip if already visited - if (visited.count(e) > 0) { - continue; - } - - // short-circuit: skip Split and Merge expressions created by Reshape - // short-circuit: skip Resize expressions created by Slice - if (e->isA() || e->isA() || e->isA()) { - visited.insert(e); - continue; - } - - bool is_expr_inputs_valid = - std::all_of(e->inputs().begin(), e->inputs().end(), [this](Val* v) { - return !v->isA() || - !isScheduledTensorView(v->as()); - }); - NVF_ERROR( - is_expr_inputs_valid, - "Found a TensorView with scheduled loop domain."); - - // Handle scalars and constants not generated by separate expression. - std::vector scalars; - std::copy_if( - e->inputs().begin(), - e->inputs().end(), - std::back_inserter(scalars), - [](Val* v) { return v->isScalar(); }); - std::for_each(scalars.begin(), scalars.end(), [this](const Val* v) { - dispatch(v); - }); - - // short-circuit: add to back of stack if not all of the expression's - // dependencies are satisfied. - if (!checkExpressionDependencies(e)) { - ++skip_count; - to_visit.push_back(e); - continue; - } - - // Create RecordFunctor given inputs, outputs, and attributes. - visited.insert(e); - dispatch(e); - skip_count = 0; - } - - // Add tensor outputs and handle aliased outputs - std::unordered_set visited_alias_output; - for (nvfuser::Val* v : fusion_->outputs()) { - NVF_ERROR(v->isA()); - const AliasInfo& alias_info = fusion_->getOutputAlias(v); - switch (alias_info.type) { - case AllocationType::New: { - handleOutput(v->as()); - break; - } - case AllocationType::ReuseBuffer: { - size_t num_visited = visited_alias_output.count(v); - if (num_visited == 0) { - visited_alias_output.insert(v); - handleOutput(v->as(), alias_info); - } - // An alias output can also be returned as a fusion output - // if it is already aliased or if the output is visible. - if (num_visited > 0 || - alias_info.visibility == OutputVisibility::kVisible) { - handleOutput(v->as()); - } - break; - } - default: - NVF_ERROR(false, "Unsupported AllocationType"); - } - } - - fd_->finalizeDefinition(); - } - - // ================================================================================= - // Filter Functions - - // Gather all TensorViews and FusionDefinition indices - std::vector> tensors() { - std::vector> tensors; - std::copy_if( - map_val_to_fd_index_.begin(), - map_val_to_fd_index_.end(), - std::back_inserter(tensors), - [](std::pair&& kv) { - return kv.first->isA(); - }); - return tensors; - } - - // ================================================================================= - // Handle define_scalar and define_tensor variants - - // Create scalar for given nvfuser value. The nvfuser value must not already - // exist and have a definition. It can be a fusion input, a constant, or a - // tensor's extent. - Scalar createScalar(const Val* v) { - NVF_ERROR( - v->definition() == nullptr, - "Value has a definition and should not be created directly."); - - // short-circuit: value already exists in FusionDefinition - if (map_val_to_fd_index_.count(v) > 0) { - return Scalar(map_val_to_fd_index_.at(v), fd_); - } - - Scalar output = fd_->defineScalar(); - map_val_to_fd_index_.emplace(v, output()); - - // Since scalars can come from TensorView dimension sizes, search through - // all TensorViews for an iterDomain whose extent matches the desired - // value and then create SizeOpRecord. - for (auto& kv : tensors()) { - const TensorView* key_tv = kv.first->as(); - - std::vector filtered_logical_domain = - TensorDomain::noReductions(key_tv->domain()->logical()); - // Get extents for each IterDomain - std::vector extents; - extents.reserve(filtered_logical_domain.size()); - std::transform( - filtered_logical_domain.begin(), - filtered_logical_domain.end(), - std::back_inserter(extents), - [](IterDomain* id) { return id->getMaybeExpandedExtent(); }); - - auto iter = std::find(extents.begin(), extents.end(), v); - // Check if value matches iterdomain extent - if (iter == extents.end()) { - continue; - } - - int64_t dim = std::distance(extents.begin(), iter); - fd_->defineRecord(new SizeOpRecord( - {fd_->recordingState(kv.second)}, - {fd_->recordingState(output())}, - dim)); - return output; - } - - // DataType::Index does not exist in python_frontend, so convert to - // DataType::Int - DataType scalar_dtype = - (v->dtype() == DataType::Index) ? DataType::Int : v->dtype(); - - fd_->defineRecord(new ScalarRecord( - {fd_->recordingState(output())}, - v->value(), - std::get(scalar_dtype.type))); - return output; - } - - // Add scalar value to Fusion Definition - void handle(const Val* v) final { - // short-circuit: scalar definition has a definition - if (v->definition() != nullptr) { - return; - } - createScalar(v); - } - - // Create python_frontend Vector from a vector of CPP scalar values. - Vector createVector(std::vector scalars) { - // Add CPP values to Fusion Definition if necessary - std::for_each(scalars.begin(), scalars.end(), [this](const Val* v) { - OptOutConstDispatch::dispatch(v); - }); - - // Get corresponding index for CPP values - std::vector inputs; - std::transform( - scalars.begin(), - scalars.end(), - std::back_inserter(inputs), - [&](Val* v) { - return fd_->recordingState(map_val_to_fd_index_.at(v)); - }); - - // NOTE There is not an equivalent CPP class for python-frontend vector, - // so we do not add it to map_val_to_fd_index_. - Vector output = fd_->defineVector(inputs.size()); - fd_->defineRecord(new VectorRecord( - inputs, {fd_->recordingState(output())}, DataType::Int)); - return output; - } - - // Add Tensor value to Fusion Definition - void handle(const TensorView* tv) final { - // short-circuit: value already exists in FusionDefinition - if (map_val_to_fd_index_.count(tv) > 0) { - return; - } - - Tensor output = fd_->defineTensor(tv->nDims()); - map_val_to_fd_index_.emplace(tv, output()); - - std::vector shape; - std::transform( - tv->domain()->logical().begin(), - tv->domain()->logical().end(), - std::back_inserter(shape), - [](IterDomain* id) { - return (id->getMaybeExpandedExtent()->isConstScalar()) - ? id->getMaybeExpandedExtent()->evaluate().as() - : -1; - }); - - fd_->defineRecord(new TensorRecord( - {fd_->recordingState(output())}, - shape, - tv->domain()->contiguity(), - std::get(tv->dtype().type), - tv->isCpuScalar(), - tv->domain()->strideOrder())); - } - - // ================================================================================= - // Utility functions - - // Create a vector for the logical domain of TensorView. - // Used with ReshapeOp and ExpandOp handlers - Vector getShape(TensorView* tv) { - const std::vector& logical_out_domain = - tv->domain()->logical(); - std::vector logical_domain_extents; - // Use expanded extent if available for IterDomain. - std::transform( - logical_out_domain.begin(), - logical_out_domain.end(), - std::back_inserter(logical_domain_extents), - [](IterDomain* id) { return id->getMaybeExpandedExtent(); }); - return createVector(logical_domain_extents); - } - - // Find integer index corresponding with reduction iterDomains - std::vector getReductionAxes(TensorView* tv) { - std::vector axes; - const std::vector& logical_domain = tv->domain()->logical(); - for (int64_t dim : arange((int64_t)logical_domain.size())) { - if (logical_domain.at(dim)->isReduction()) { - axes.push_back(dim); - } - } - return axes; - } - - // ================================================================================= - // Handle add_output variants - - // Add Tensor output to FusionDefinition - void handleOutput(const TensorView* tv) { - size_t output_index = map_val_to_fd_index_.at(tv); - fd_->defineRecord(new OutputRecord( - {fd_->recordingState(output_index)}, - serde::RecordType::OutputTv, - tv->domain()->strideOrder())); - } - - // Alias output Tensor with input tensor - void handleOutput(const TensorView* tv, const AliasInfo& alias_info) { - size_t output_index = map_val_to_fd_index_.at(tv); - size_t input_index = map_val_to_fd_index_.at(alias_info.aliased_io); - fd_->defineRecord(new OutputRecord( - {fd_->recordingState(output_index), fd_->recordingState(input_index)}, - serde::RecordType::OutputTv)); - } - - // ================================================================================= - // Map CPP Expression classes to corresponding RecordFunctors in - // python_frontend - - // A generic function to map UnaryOp, BinaryOp, and TernaryOp to - // python_frontend OpRecord - template - void handleOpRecord( - const Expr* e, - serde::RecordType record_type, - ResultType result, - ArgTypes... args) { - NVF_ERROR(e->isA()); - std::vector argument_states; - std::transform( - e->inputs().begin(), - e->inputs().end(), - std::back_inserter(argument_states), - [&](auto arg) { - return fd_->recordingState(map_val_to_fd_index_.at(arg)); - }); - - fd_->defineRecord(new OpRecord( - argument_states, - {fd_->recordingState(map_val_to_fd_index_.at(result))}, - "ops." + python::toString(e->as()), - record_type, - getFunction(e->as()))); - } - - // Map UnaryOp to python_frontend OpRecord - void handle(const UnaryOp* uop) final { - // short-circuit: Handle cast operation separately - if (uop->getUnaryOpType() == UnaryOpType::Cast) { - return handleCastOp(uop); - } - - // Map remaining UnaryOp to python_frontend OpRecord - if (uop->in()->isA()) { - Tensor output = fd_->defineTensor(uop->out()->as()->nDims()); - map_val_to_fd_index_.emplace(uop->out(), output()); - handleOpRecord( - uop, - serde::RecordType::Unary_TV, - uop->out()->as(), - uop->in()->as()); - } else { - Scalar output = fd_->defineScalar(); - map_val_to_fd_index_.emplace(uop->out(), output()); - handleOpRecord( - uop, serde::RecordType::Unary_VAL, uop->out(), uop->in()); - } - } - - // Map cast UnaryOp to CastOpRecord - void handleCastOp(const Expr* op) { - bool is_cast_op = op->isA() && - op->as()->getUnaryOpType() == UnaryOpType::Cast; - NVF_ERROR(is_cast_op); - - size_t input_fd_index = map_val_to_fd_index_.at(op->input(0)); - - // DataType::Index does not exist in python_frontend, so convert to - // DataType::Int - DataType scalar_dtype = op->output(0)->dtype(); - if (scalar_dtype == DataType::Index) { - scalar_dtype = DataType::Int; - } - - if (op->input(0)->isA()) { - Tensor output = - fd_->defineTensor(op->output(0)->as()->nDims()); - map_val_to_fd_index_.emplace(op->output(0), output()); - fd_->defineRecord(new CastOpRecord( - {fd_->recordingState(input_fd_index)}, - {fd_->recordingState(output())}, - "ops.cast", - serde::RecordType::CastTv, - static_cast(castOp), - std::get(scalar_dtype.type))); - } else { - Scalar output = fd_->defineScalar(); - map_val_to_fd_index_.emplace(op->output(0), output()); - fd_->defineRecord(new CastOpRecord( - {fd_->recordingState(input_fd_index)}, - {fd_->recordingState(output())}, - "ops.cast", - serde::RecordType::CastVal, - static_cast(castOp), - std::get(scalar_dtype.type))); - } - } - - // Map BinaryOp to python_frontend OpRecord - void handle(const BinaryOp* bop) final { - bool is_lhs_tv = bop->lhs()->isA(); - bool is_rhs_tv = bop->rhs()->isA(); - - if (is_lhs_tv || is_rhs_tv) { - Tensor output = fd_->defineTensor(bop->out()->as()->nDims()); - map_val_to_fd_index_.emplace(bop->out(), output()); - - if (is_lhs_tv && is_rhs_tv) { - handleOpRecord( - bop, - serde::RecordType::Binary_TV, - bop->out()->as(), - bop->lhs()->as(), - bop->rhs()->as()); - } else if (is_lhs_tv && !is_rhs_tv) { - handleOpRecord( - bop, - serde::RecordType::Binary_TV_VAL, - bop->out()->as(), - bop->lhs()->as(), - bop->rhs()); - } else { - handleOpRecord( - bop, - serde::RecordType::Binary_VAL_TV, - bop->out()->as(), - bop->lhs(), - bop->rhs()->as()); - } - } else { - Scalar output = fd_->defineScalar(); - map_val_to_fd_index_.emplace(bop->out(), output()); - handleOpRecord( - bop, - serde::RecordType::Binary_VAL, - bop->out(), - bop->lhs(), - bop->rhs()); - } - } - - // Map TernaryOp to python frontend - void handle(const TernaryOp* top) final { - bool is_in1_tv = top->in1()->isA(); - bool is_in2_tv = top->in2()->isA(); - bool is_in3_tv = top->in3()->isA(); - - if (is_in1_tv || is_in2_tv || is_in3_tv) { - Tensor output = fd_->defineTensor(top->out()->as()->nDims()); - map_val_to_fd_index_.emplace(top->out(), output()); - - if (is_in1_tv && is_in2_tv && is_in3_tv) { - handleOpRecord( - top, - serde::RecordType::Ternary_TV, - top->out()->as(), - top->in1()->as(), - top->in2()->as(), - top->in3()->as()); - } else if (is_in1_tv && is_in2_tv && !is_in3_tv) { - handleOpRecord( - top, - serde::RecordType::Ternary_TV_TV_VAL, - top->out()->as(), - top->in1()->as(), - top->in2()->as(), - top->in3()); - } else if (is_in1_tv && !is_in2_tv && is_in3_tv) { - handleOpRecord( - top, - serde::RecordType::Ternary_TV_VAL_TV, - top->out()->as(), - top->in1()->as(), - top->in2(), - top->in3()->as()); - } else if (is_in1_tv && !is_in2_tv && !is_in3_tv) { - handleOpRecord( - top, - serde::RecordType::Ternary_TV_VAL_VAL, - top->out()->as(), - top->in1()->as(), - top->in2(), - top->in3()); - } else if (!is_in1_tv && is_in2_tv && is_in3_tv) { - handleOpRecord( - top, - serde::RecordType::Ternary_VAL_TV_TV, - top->out()->as(), - top->in1(), - top->in2()->as(), - top->in3()->as()); - } else if (!is_in1_tv && is_in2_tv && !is_in3_tv) { - handleOpRecord( - top, - serde::RecordType::Ternary_VAL_TV_VAL, - top->out()->as(), - top->in1(), - top->in2()->as(), - top->in3()); - } else if (!is_in1_tv && !is_in2_tv && is_in3_tv) { - handleOpRecord( - top, - serde::RecordType::Ternary_VAL_VAL_TV, - top->out()->as(), - top->in1(), - top->in2(), - top->in3()->as()); - } - } else { - Scalar output = fd_->defineScalar(); - map_val_to_fd_index_.emplace(top->out(), output()); - handleOpRecord( - top, - serde::RecordType::Ternary_VAL, - top->out(), - top->in1(), - top->in2(), - top->in3()); - } - } - - // Map ReductionOp to python frontend - void handle(const ReductionOp* rop) final { - auto* out_tv = rop->out()->as(); - - // The min and max reduction operations expect the dtype argument to be - // PrimDataType::Null - PrimDataType dtype = (rop->getReductionOpType() == BinaryOpType::Min || - rop->getReductionOpType() == BinaryOpType::FMin || - rop->getReductionOpType() == BinaryOpType::Max || - rop->getReductionOpType() == BinaryOpType::FMax) - ? PrimDataType::Null - : std::get(rop->out()->dtype().type); - - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(rop->out(), output()); - fd_->defineRecord(new ReductionOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(rop->in()))}, - {fd_->recordingState(output())}, - "ops." + python::toString(rop), - getSerdeType(rop), - getFunction< - TensorView*, - TensorView*, - const std::vector&, - bool, - DataType>(rop), - getReductionAxes(out_tv), - /*keep_dim=*/false, - dtype)); - } - - // Map WelfordOp to python frontend - void handle(const WelfordOp* wop) final { - NVF_ERROR(wop->initAvg()->evaluate().as() == 0.0); - NVF_ERROR(wop->initVar()->evaluate().as() == 0.0); - NVF_ERROR(wop->initN()->evaluate().as() == 0); - - NVF_ERROR(wop->outAvg()->isA()); - auto* out_avg_tv = wop->outAvg()->as(); - Tensor out_avg = fd_->defineTensor(out_avg_tv->nDims()); - map_val_to_fd_index_.emplace(wop->outAvg(), out_avg()); - - NVF_ERROR(wop->outVar()->isA()); - auto* out_var_tv = wop->outVar()->as(); - Tensor out_var = fd_->defineTensor(out_var_tv->nDims()); - map_val_to_fd_index_.emplace(wop->outVar(), out_var()); - - NVF_ERROR(wop->outN()->isA()); - auto* out_N_tv = wop->outN()->as(); - Tensor out_N = fd_->defineTensor(out_N_tv->nDims()); - map_val_to_fd_index_.emplace(wop->outN(), out_N()); - - fd_->defineRecord(new WelfordOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(wop->inAvg()))}, - {fd_->recordingState(out_avg()), - fd_->recordingState(out_var()), - fd_->recordingState(out_N())}, - getReductionAxes(out_avg_tv))); - } - - // If input and output values share the same type, a LoadStoreOp will be - // created instead of a CastOp. - void handle(const LoadStoreOp* lsop) final { - // short-circuit: lsop is a permutation. - if (lsop->out()->isA() && - lsop->out()->as()->hasRoot()) { - return handlePermute(lsop); - } - - // Skip set unary operation - size_t input_fid = map_val_to_fd_index_.at(lsop->in()); - map_val_to_fd_index_.emplace(lsop->out(), input_fid); - } - - // Add DimsOpRecord to create permutation in FusionDefinition - void handlePermute(const LoadStoreOp* lsop) { - auto* out_tv = lsop->out()->as(); - - std::optional> new2old = ir_utils::computePermutation( - out_tv->getRootDomain(), out_tv->getLogicalDomain()); - NVF_ERROR(new2old.has_value(), "Expected permutation"); - - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - fd_->defineRecord(new DimsOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(lsop->in()))}, - {fd_->recordingState(output())}, - std::move(new2old.value()), - "ops.permute")); - } - - // Add Broadcast operation to FusionDefinition - void handle(const BroadcastOp* bcast_op) final { - Tensor output = - fd_->defineTensor(bcast_op->out()->as()->nDims()); - fd_->defineRecord(new BroadcastOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(bcast_op->in()))}, - {fd_->recordingState(output())}, - "ops.broadcast", - bcast_op->getBroadcastDimFlags())); - map_val_to_fd_index_.emplace(bcast_op->out(), output()); - } - - // Map SqueezeOp to python frontend - void handle(const SqueezeOp* sop) final { - std::vector squeeze_dims; - const std::vector& is_squeeze_dims = sop->getSqueezeDimFlags(); - for (int64_t dim : arange((int64_t)is_squeeze_dims.size())) { - if (is_squeeze_dims.at(dim)) { - squeeze_dims.push_back(dim); - } - } - - // Always squeeze_expanded dimensions - Tensor output = fd_->defineTensor(sop->out()->as()->nDims()); - map_val_to_fd_index_.emplace(sop->out(), output()); - fd_->defineRecord(new SqueezeOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(sop->in()))}, - {fd_->recordingState(output())}, - squeeze_dims, - /*squeeze_expanded=*/true)); - } - - // Map ReshapeOp to python frontend - void handle(const ReshapeOp* vop) final { - // Get extent's for output's logical domain - auto* out_tv = vop->out()->as(); - Vector new_shape = getShape(out_tv); - - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - fd_->defineRecord(new ReshapeOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(vop->in())), - fd_->recordingState(new_shape())}, - {fd_->recordingState(output())})); - } - - // Map ExpandOp to python frontend - void handle(const ExpandOp* eop) final { - auto* in_tv = eop->in()->as(); - auto* out_tv = eop->out()->as(); - NVF_ERROR(in_tv->nDims() == out_tv->nDims()); - Vector new_shape = getShape(out_tv); - - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - fd_->defineRecord(new ExpandOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(eop->in())), - fd_->recordingState(new_shape())}, - {fd_->recordingState(output())})); - } - - // Map SliceOp to python frontend - void handle(const SliceOp* sop) final { - std::vector slices = sop->getRanges(); - - std::vector start_indices; - start_indices.reserve(slices.size()); - - std::vector stop_indices; - stop_indices.reserve(slices.size()); - - std::vector strides; - strides.reserve(slices.size()); - - for (const nvfuser::Slice& s : slices) { - start_indices.push_back(s.start); - stop_indices.push_back(s.stop); - strides.push_back(s.step); - } - - Vector new_start = createVector(start_indices); - Vector new_stop = createVector(stop_indices); - Vector new_strides = createVector(strides); - - Tensor output = fd_->defineTensor(sop->out()->as()->nDims()); - map_val_to_fd_index_.emplace(sop->out(), output()); - fd_->defineRecord(new SliceOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(sop->in())), - fd_->recordingState(new_start()), - fd_->recordingState(new_stop()), - fd_->recordingState(new_strides())}, - {fd_->recordingState(output())}, - /*manual_normalization=*/true)); - } - - // Map PadOp to python frontend - void handle(const PadOp* pad_op) final { - Tensor output = fd_->defineTensor(pad_op->out()->as()->nDims()); - map_val_to_fd_index_.emplace(pad_op->out(), output()); - - // Step 1: Get pad widths in normalized order. - std::vector normalized_pad_widths = pad_op->getPadWidths(); - const int64_t total_size = (int64_t)normalized_pad_widths.size(); - - // Step 2: Get indices for normalized pad widths. - std::vector normalized_indices(total_size); - std::iota(normalized_indices.begin(), normalized_indices.end(), 0); - - // Step 3: Transform to indices for original pad widths - std::vector original_indices; - original_indices.reserve(normalized_indices.size()); - std::transform( - normalized_indices.begin(), - normalized_indices.end(), - std::back_inserter(original_indices), - [=](int64_t normalized_idx) { - int64_t offset = total_size - normalized_idx; - int64_t dim = ceilDiv(offset, 2) - 1; - - int64_t original_idx = dim * 2; - // right pad values require an additional offset - if (offset % 2 == 1) { - original_idx += 1; - } - return original_idx; - }); - - // Step 4: Get pad widths in original order. - std::vector original_order_pad_widths(total_size, nullptr); - for (int64_t normalized_idx : normalized_indices) { - original_order_pad_widths.at(original_indices.at(normalized_idx)) = - normalized_pad_widths.at(normalized_idx); - } - - // Check that no pad width values are nullptr. - NVF_ERROR(std::all_of( - original_order_pad_widths.begin(), - original_order_pad_widths.end(), - [](Val* v) { return v != nullptr; })); - - Vector pad_widths = createVector(original_order_pad_widths); - fd_->defineRecord(new PadOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(pad_op->in())), - fd_->recordingState(pad_widths()), - fd_->recordingState(map_val_to_fd_index_.at(pad_op->value()))}, - {fd_->recordingState(output())})); - } - - // Map CatOp to python frontend - void handle(const CatOp* cat_op) final { - Tensor output = - fd_->defineTensor(cat_op->output(0)->as()->nDims()); - map_val_to_fd_index_.emplace(cat_op->output(0), output()); - - std::vector tensor_states; - tensor_states.reserve(cat_op->inputs().size()); - std::transform( - cat_op->inputs().begin(), - cat_op->inputs().end(), - std::back_inserter(tensor_states), - [&](Val* v) { - return fd_->recordingState(map_val_to_fd_index_.at(v)); - }); - - fd_->defineRecord(new CatOpRecord( - tensor_states, - {fd_->recordingState(output())}, - cat_op->concatenatedDim(), - /*manual_padding=*/true)); - } - - // Map RNGOp to RandomDistOpRecord - void handle(const RNGOp* rop) final { - auto* out_tv = rop->output(0)->as(); - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - std::vector arg_states; - - // arg1 and arg2 are minval and maxval for uniform. - // arg1 and arg2 are mean and std for normal. - std::vector params = rop->getParameters(); - if (params.empty()) { - // Default arg1 and arg2 is (0, 1) for both uniform and normal. - Scalar zero_value = createScalar(fusion_->zeroVal()); - Scalar one_value = createScalar(fusion_->oneVal()); - arg_states.push_back(fd_->recordingState(zero_value())); - arg_states.push_back(fd_->recordingState(one_value())); - } else { - NVF_ERROR( - params.size() == 2, - "Expect only two parameters for uniform and normal random ops."); - std::transform( - params.begin(), - params.end(), - std::back_inserter(arg_states), - [&](Val* v) { - return fd_->recordingState(map_val_to_fd_index_.at(v)); - }); - } - - Vector out_shape = createVector(rop->getShape()); - arg_states.push_back(fd_->recordingState(out_shape())); - - // The philox seed and offset are optional. - if (rop->getRNGSeedVal() != nullptr) { - arg_states.push_back( - fd_->recordingState(map_val_to_fd_index_.at(rop->getRNGSeedVal()))); - } - if (rop->getRNGOffsetVal() != nullptr) { - arg_states.push_back( - fd_->recordingState(map_val_to_fd_index_.at(rop->getRNGOffsetVal()))); - } - - switch (rop->getRNGOpType()) { - case RNGOpType::Uniform: - case RNGOpType::UniformRange: - fd_->defineRecord( - new RandomDistOpRecord( - arg_states, - {fd_->recordingState(output())}, - std::get(out_tv->dtype().type))); - break; - case RNGOpType::NormalStandard: - case RNGOpType::NormalGeneral: - fd_->defineRecord( - new RandomDistOpRecord( - arg_states, - {fd_->recordingState(output())}, - std::get(out_tv->dtype().type))); - break; - default: - NVF_ERROR(false, "Unsupported RNGOpType."); - } - } - - // Map LinearOp to python frontend - void handle(const LinearOp* lop) final { - auto* out_tv = lop->out()->as(); - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - if (lop->bias() != nullptr) { - fd_->defineRecord( - new OpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(lop->inA())), - fd_->recordingState(map_val_to_fd_index_.at(lop->inB())), - fd_->recordingState(map_val_to_fd_index_.at(lop->bias()))}, - {fd_->recordingState(output())}, - ("ops.linear"), - serde::RecordType::Ternary_TV, - static_cast< - TensorView* (*)(TensorView*, TensorView*, TensorView*)>( - linear))); - } else { - fd_->defineRecord(new OpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(lop->inA())), - fd_->recordingState(map_val_to_fd_index_.at(lop->inB()))}, - {fd_->recordingState(output())}, - ("ops.linear"), - serde::RecordType::Binary_TV, - static_cast(linear))); - } - } - - // Map FullOp to python frontend - void handle(const FullOp* fop) final { - auto* out_tv = fop->output(0)->as(); - Vector tensor_shape = getShape(out_tv); - - Scalar fill_value = createScalar(fop->getFillValue()); - - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - fd_->defineRecord(new FullOpRecord( - {fd_->recordingState(tensor_shape()), - fd_->recordingState(fill_value())}, - {fd_->recordingState(output())}, - std::get(out_tv->dtype().type))); - } - - // Map IotaOp to python frontend - void handle(const IotaOp* iop) final { - auto* out_tv = iop->output(0)->as(); - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - Scalar length = createScalar(iop->length()); - Scalar start = createScalar(iop->start()); - Scalar step = createScalar(iop->step()); - - fd_->defineRecord(new IotaOpRecord( - {fd_->recordingState(length()), - fd_->recordingState(start()), - fd_->recordingState(step())}, - {fd_->recordingState(output())}, - std::get(iop->dtype().type))); - } - - // Map IndexSelectOp to IndexSelectOpRecord - void handle(const IndexSelectOp* isop) final { - auto* out_tv = isop->output(0)->as(); - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - fd_->defineRecord(new IndexSelectOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(isop->lookupTv())), - fd_->recordingState(map_val_to_fd_index_.at(isop->indexTv()))}, - {fd_->recordingState(output())}, - isop->dim())); - } - - // Map SelectOp to IndexSelectOpRecord - void handle(const SelectOp* sop) final { - auto* out_tv = sop->output(0)->as(); - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - fd_->defineRecord(new SelectOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(sop->lookupTv())), - fd_->recordingState(map_val_to_fd_index_.at(sop->input(1)))}, - {fd_->recordingState(output())}, - sop->dim())); - } - - // Map ScatterOp to python frontend - void handle(const ScatterOp* sop) final { - auto* out_tv = sop->output(0)->as(); - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - fd_->defineRecord(new ScatterOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(sop->in())), - fd_->recordingState(map_val_to_fd_index_.at(sop->index())), - fd_->recordingState(map_val_to_fd_index_.at(sop->src()))}, - {fd_->recordingState(output())}, - sop->dim())); - } - - // Map ArgsortOp to python frontend - void handle(const ArgsortOp* argsortop) final { - auto* out_tv = argsortop->output(0)->as(); - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - fd_->defineRecord(new ArgsortOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(argsortop->in()))}, - {fd_->recordingState(output())}, - argsortop->dim(), - argsortop->isDescending(), - argsortop->isStable())); - } - - // Map GroupedMmaOp to python frontend - void handle(const GroupedMmaOp* gmm_op) final { - TensorView* out_tv = gmm_op->out(); - Tensor output = fd_->defineTensor( - TensorDomain::noReductions(out_tv->getLogicalDomain()).size()); - map_val_to_fd_index_.emplace(out_tv, output()); - - int64_t out_block_scale_size = 0; - PrimDataType out_block_scale_dtype = DataType::BFloat16; - bool out_gamma = false; - - TensorView* out_block_scale_tv = gmm_op->outScale(); - if (out_block_scale_tv != nullptr) { - Tensor output_block_scale = fd_->defineTensor( - TensorDomain::noReductions(out_block_scale_tv->getLogicalDomain()) - .size()); - map_val_to_fd_index_.emplace(out_block_scale_tv, output_block_scale()); - auto block_size_extent = out_block_scale_tv->axis(-1)->extent(); - NVF_CHECK( - block_size_extent->isConstInt(), - "Block size extent needs to be a constant integer"); - out_block_scale_size = block_size_extent->evaluate().as(); - out_block_scale_dtype = - std::get(out_block_scale_tv->dtype().type); - } - TensorView* out_gamma_tv = gmm_op->outGamma(); - if (out_gamma_tv != nullptr) { - Tensor output_gamma = fd_->defineTensor( - TensorDomain::noReductions(out_gamma_tv->getLogicalDomain()).size()); - map_val_to_fd_index_.emplace(out_gamma_tv, output_gamma()); - out_gamma = true; - } - - if (gmm_op->inputs().size() == 3) { - fd_->defineRecord( - new OpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(gmm_op->matrix1())), - fd_->recordingState(map_val_to_fd_index_.at(gmm_op->matrix2())), - fd_->recordingState(map_val_to_fd_index_.at(gmm_op->offsets()))}, - {fd_->recordingState(output())}, - ("ops.grouped_mm"), - serde::RecordType::Ternary_TV, - static_cast< - TensorView* (*)(TensorView*, TensorView*, TensorView*)>( - [](TensorView* matrix1, - TensorView* matrix2, - TensorView* offsets) { - ScaledTensorView scaled_out = - grouped_mm(matrix1, matrix2, offsets); - return scaled_out.tv; - }))); - } else { - fd_->defineRecord(new ScaledGroupedMmaOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(gmm_op->matrix1())), - fd_->recordingState(map_val_to_fd_index_.at(gmm_op->matrix2())), - fd_->recordingState(map_val_to_fd_index_.at(gmm_op->offsets())), - gmm_op->hasScale() - ? fd_->recordingState(map_val_to_fd_index_.at(gmm_op->scale1())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - gmm_op->hasScale() - ? fd_->recordingState(map_val_to_fd_index_.at(gmm_op->scale2())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - gmm_op->hasAlpha() - ? fd_->recordingState(map_val_to_fd_index_.at(gmm_op->alpha())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - gmm_op->hasBias() - ? fd_->recordingState(map_val_to_fd_index_.at(gmm_op->bias())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - gmm_op->hasBeta() - ? fd_->recordingState(map_val_to_fd_index_.at(gmm_op->beta())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None)}, - {fd_->recordingState(output()), - out_block_scale_tv != nullptr - ? fd_->recordingState( - map_val_to_fd_index_.at(out_block_scale_tv)) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - out_gamma_tv != nullptr - ? fd_->recordingState(map_val_to_fd_index_.at(out_gamma_tv)) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None)}, - std::get(out_tv->dtype().type), - out_block_scale_size, - out_block_scale_dtype, - out_gamma)); - } - } - - // Map ScaledMmaOp to python frontend - void handle(const ScaledMmaOp* smm_op) final { - int64_t out_block_scale_size = 0; - PrimDataType out_block_scale_dtype = DataType::BFloat16; - bool out_gamma = false; - - TensorView* out_tv = smm_op->out(); - TensorView* out_block_scale_tv = smm_op->outScale(); - if (out_block_scale_tv != nullptr) { - Tensor output_block_scale = fd_->defineTensor( - TensorDomain::noReductions(out_block_scale_tv->getLogicalDomain()) - .size()); - map_val_to_fd_index_.emplace(out_block_scale_tv, output_block_scale()); - auto block_size_extent = out_block_scale_tv->axis(-1)->extent(); - NVF_CHECK( - block_size_extent->isConstInt(), - "Block size extent needs to be a constant integer"); - out_block_scale_size = block_size_extent->evaluate().as(); - out_block_scale_dtype = - std::get(out_block_scale_tv->dtype().type); - } - - TensorView* out_gamma_tv = smm_op->outGamma(); - if (out_gamma_tv != nullptr) { - Tensor output_gamma = fd_->defineTensor( - TensorDomain::noReductions(out_gamma_tv->getLogicalDomain()).size()); - map_val_to_fd_index_.emplace(out_gamma_tv, output_gamma()); - out_gamma = true; - } - - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - fd_->defineRecord(new ScaledMmaOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(smm_op->matrix1())), - fd_->recordingState(map_val_to_fd_index_.at(smm_op->matrix2())), - fd_->recordingState(map_val_to_fd_index_.at(smm_op->scale1())), - fd_->recordingState(map_val_to_fd_index_.at(smm_op->scale2())), - smm_op->hasAlpha() - ? fd_->recordingState(map_val_to_fd_index_.at(smm_op->alpha())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - smm_op->hasBias() - ? fd_->recordingState(map_val_to_fd_index_.at(smm_op->bias())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - smm_op->hasBeta() - ? fd_->recordingState(map_val_to_fd_index_.at(smm_op->beta())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None)}, - {fd_->recordingState(output()), - out_block_scale_tv != nullptr - ? fd_->recordingState(map_val_to_fd_index_.at(out_block_scale_tv)) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None), - out_gamma_tv != nullptr - ? fd_->recordingState(map_val_to_fd_index_.at(out_gamma_tv)) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None)}, - std::get(out_tv->dtype().type), - out_block_scale_size, - out_block_scale_dtype, - out_gamma)); - } - - // Map TopKOp to python frontend - void handle(const TopKOp* topkop) final { - // Create outputs for this RecordFunctor - std::vector fd_outputs; - fd_outputs.reserve(topkop->outputs().size()); - std::transform( - topkop->outputs().begin(), - topkop->outputs().end(), - std::back_inserter(fd_outputs), - [&](Val* v) { - NVF_ERROR(v->isA()); - Tensor output = fd_->defineTensor(v->as()->nDims()); - map_val_to_fd_index_.emplace(v, output()); - return fd_->recordingState(output()); - }); - - fd_->defineRecord(new TopKOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(topkop->in())), - fd_->recordingState(map_val_to_fd_index_.at(topkop->k()))}, - fd_outputs, - topkop->dim(), - topkop->isLargest(), - topkop->isSorted())); - } - - void handle(const ScanOp* scan_op) final { - auto out_tv = scan_op->out()->as(); - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - NVF_ERROR( - scan_op->opType() == BinaryOpType::Add, - "Only cumsum (BinaryOpType::Add) is supported for ScanOp."); - - fd_->defineRecord(new ScanOpRecord( - {fd_->recordingState( - map_val_to_fd_index_.at(scan_op->in()->as()))}, - {fd_->recordingState(output())}, - ("ops.cumsum"), - serde::RecordType::ScanOpCumsum, - static_cast(cumsum), - scan_op->dim(), - BinaryOpType::Add)); - } - - // Map GatherOp to python frontend - void handle(const GatherOp* gop) final { - auto* out_tv = gop->output(0)->as(); - Tensor output = fd_->defineTensor(out_tv->nDims()); - map_val_to_fd_index_.emplace(out_tv, output()); - - fd_->defineRecord(new GatherOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(gop->lookupTv())), - fd_->recordingState(map_val_to_fd_index_.at(gop->indexTv()))}, - {fd_->recordingState(output())}, - gop->dim())); - } - - // Map MatmulOp to TensorView-Only OpRecord - void handle(const MatmulOp* matmul_op) final { - Tensor output = - fd_->defineTensor(matmul_op->out()->as()->nDims()); - map_val_to_fd_index_.emplace(matmul_op->out(), output()); - - fd_->defineRecord(new OpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(matmul_op->inA())), - fd_->recordingState(map_val_to_fd_index_.at(matmul_op->inB()))}, - {fd_->recordingState(output())}, - ("ops.matmul"), - serde::RecordType::Binary_TV, - static_cast(matmul))); - } - - // Map SdpaFwdOp to SdpaFwdOpRecord - void handle(const SdpaFwdOp* sdpa_fwd_op) final { - // Create outputs for this RecordFunctor - std::vector fd_outputs; - fd_outputs.reserve(sdpa_fwd_op->outputs().size()); - std::transform( - sdpa_fwd_op->outputs().begin(), - sdpa_fwd_op->outputs().end(), - std::back_inserter(fd_outputs), - [&](Val* v) { - NVF_ERROR(v->isA()); - Tensor output = fd_->defineTensor(v->as()->nDims()); - map_val_to_fd_index_.emplace(v, output()); - return fd_->recordingState(output()); - }); - - State bias_state = (sdpa_fwd_op->bias() != nullptr) - ? fd_->recordingState(map_val_to_fd_index_.at(sdpa_fwd_op->bias())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - State mask_state = (sdpa_fwd_op->mask() != nullptr) - ? fd_->recordingState(map_val_to_fd_index_.at(sdpa_fwd_op->mask())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - State dropout_p_state = (sdpa_fwd_op->dropout_p() != nullptr) - ? fd_->recordingState(map_val_to_fd_index_.at(sdpa_fwd_op->dropout_p())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - State is_causal_state = (sdpa_fwd_op->is_causal() != nullptr) - ? fd_->recordingState(map_val_to_fd_index_.at(sdpa_fwd_op->is_causal())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - State scale_state = (sdpa_fwd_op->scale() != nullptr) - ? fd_->recordingState(map_val_to_fd_index_.at(sdpa_fwd_op->scale())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - fd_->defineRecord(new SdpaFwdOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(sdpa_fwd_op->query())), - fd_->recordingState(map_val_to_fd_index_.at(sdpa_fwd_op->key())), - fd_->recordingState(map_val_to_fd_index_.at(sdpa_fwd_op->value())), - bias_state, - mask_state, - dropout_p_state, - is_causal_state, - scale_state}, - fd_outputs)); - } - - // Map SdpaBwdOp to SdpaBwdOpRecord - void handle(const SdpaBwdOp* sdpa_bwd_op) final { - // Create outputs for this RecordFunctor - std::vector fd_outputs; - fd_outputs.reserve(sdpa_bwd_op->outputs().size()); - std::transform( - sdpa_bwd_op->outputs().begin(), - sdpa_bwd_op->outputs().end(), - std::back_inserter(fd_outputs), - [&](Val* v) { - NVF_ERROR(v->isA()); - Tensor output = fd_->defineTensor(v->as()->nDims()); - map_val_to_fd_index_.emplace(v, output()); - return fd_->recordingState(output()); - }); - - State dropout_p_state = (sdpa_bwd_op->dropout_p() != nullptr) - ? fd_->recordingState(map_val_to_fd_index_.at(sdpa_bwd_op->dropout_p())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - State is_causal_state = (sdpa_bwd_op->is_causal() != nullptr) - ? fd_->recordingState(map_val_to_fd_index_.at(sdpa_bwd_op->is_causal())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - State scale_state = (sdpa_bwd_op->scale() != nullptr) - ? fd_->recordingState(map_val_to_fd_index_.at(sdpa_bwd_op->scale())) - : State(/*_index=*/0, /*_stype=*/serde::StateType::None); - - fd_->defineRecord(new SdpaBwdOpRecord( - {fd_->recordingState(map_val_to_fd_index_.at(sdpa_bwd_op->grad_attn())), - fd_->recordingState(map_val_to_fd_index_.at(sdpa_bwd_op->query())), - fd_->recordingState(map_val_to_fd_index_.at(sdpa_bwd_op->key())), - fd_->recordingState(map_val_to_fd_index_.at(sdpa_bwd_op->value())), - fd_->recordingState(map_val_to_fd_index_.at(sdpa_bwd_op->attn_out())), - fd_->recordingState(map_val_to_fd_index_.at(sdpa_bwd_op->logsumexp())), - dropout_p_state, - is_causal_state, - fd_->recordingState( - map_val_to_fd_index_.at(sdpa_bwd_op->philox_seed())), - fd_->recordingState( - map_val_to_fd_index_.at(sdpa_bwd_op->philox_offset())), - scale_state}, - fd_outputs)); - } - - private: - //! The reference CPP fusion to be translated. - Fusion* fusion_ = nullptr; - //! The blank FusionDefinition that receives the RecordFunctors for - //! translated CPP values and expressions. - FusionDefinition* fd_ = nullptr; - //! Map nvfuser Val to FusionDefinition index. - std::unordered_map map_val_to_fd_index_; -}; - -} // namespace - -std::unordered_map translate( - Fusion* fusion, - FusionDefinition* fd) { - return FusionTranslator::translate(fusion, fd); -} - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/translation.h b/python/python_frontend/translation.h deleted file mode 100644 index b7378c43566..00000000000 --- a/python/python_frontend/translation.h +++ /dev/null @@ -1,20 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once -#include -#include -#include - -namespace nvfuser::python_frontend { - -// Translate a CPP Fusion into a Python FusionDefinition. -NVF_API std::unordered_map translate( - Fusion* fusion, - FusionDefinition* fd); - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/translation_utils.cpp b/python/python_frontend/translation_utils.cpp deleted file mode 100644 index 3d5f7b3a392..00000000000 --- a/python/python_frontend/translation_utils.cpp +++ /dev/null @@ -1,80 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -// -#include -#include - -namespace nvfuser::python_frontend { - -#define GET_FUNCTION_TERNARY_SPECIALIZATION_DEFINITION( \ - ResultType, InType1, InType2, InType3) \ - template <> \ - std::function \ - getFunction(const TernaryOp* top) { \ - auto wrap_function = [](ResultType (*fn)(InType1, InType2, InType3)) { \ - return fn; \ - }; \ - \ - switch (top->getTernaryOpType()) { \ - case TernaryOpType::Clamp: \ - return wrap_function(clamp); \ - break; \ - case TernaryOpType::Lerp: \ - return wrap_function(lerp); \ - break; \ - case TernaryOpType::Threshold: \ - return wrap_function(threshold); \ - break; \ - case TernaryOpType::Where: \ - return wrap_function(where); \ - break; \ - default: \ - NVF_CHECK( \ - false, \ - "Unexpected operator type: ", \ - top->getTernaryOpType(), \ - " in ", \ - top->toString()); \ - } \ - } - -// Fully specialized template functions to create std::function for TernaryOp. -GET_FUNCTION_TERNARY_SPECIALIZATION_DEFINITION( - TensorView*, - TensorView*, - Val*, - Val*) -GET_FUNCTION_TERNARY_SPECIALIZATION_DEFINITION(Val*, Val*, Val*, Val*) - -serde::RecordType getSerdeType(const ReductionOp* rop) { - switch (rop->getReductionOpType()) { - case BinaryOpType::Add: - return serde::RecordType::ReductionSum; - break; - case BinaryOpType::Mul: - return serde::RecordType::ReductionProd; - break; - case BinaryOpType::FMax: - case BinaryOpType::Max: - return serde::RecordType::ReductionMax; - break; - case BinaryOpType::FMin: - case BinaryOpType::Min: - return serde::RecordType::ReductionMin; - break; - default: - NVF_CHECK( - false, - "Unexpected reduction operator type: ", - rop->getReductionOpType(), - " in ", - rop->toString()); - } -} - -} // namespace nvfuser::python_frontend diff --git a/python/python_frontend/translation_utils.h b/python/python_frontend/translation_utils.h deleted file mode 100644 index 9dddb31d8a6..00000000000 --- a/python/python_frontend/translation_utils.h +++ /dev/null @@ -1,300 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once -#include -#include - -namespace nvfuser::python_frontend { - -// Get std::function for UnaryOp -template -std::function getFunction(const UnaryOp* uop) { - auto wrap_function = [](ResultType (*fn)(ArgTypes...)) { return fn; }; - - switch (uop->getUnaryOpType()) { - case UnaryOpType::Abs: - return wrap_function(abs); - case UnaryOpType::Acos: - return wrap_function(acos); - case UnaryOpType::Acosh: - return wrap_function(acosh); - case UnaryOpType::Asin: - return wrap_function(asin); - case UnaryOpType::Asinh: - return wrap_function(asinh); - case UnaryOpType::Atan: - return wrap_function(atan); - case UnaryOpType::Atanh: - return wrap_function(atanh); - case UnaryOpType::Ceil: - return wrap_function(ceil); - case UnaryOpType::Cos: - return wrap_function(cos); - case UnaryOpType::Cosh: - return wrap_function(cosh); - case UnaryOpType::Exp: - return wrap_function(exp); - case UnaryOpType::Exp2: - return wrap_function(exp2); - case UnaryOpType::Expm1: - return wrap_function(expm1); - case UnaryOpType::Erf: - return wrap_function(erf); - case UnaryOpType::Erfc: - return wrap_function(erfc); - case UnaryOpType::Erfinv: - return wrap_function(erfinv); - case UnaryOpType::Erfcinv: - return wrap_function(erfcinv); - case UnaryOpType::Floor: - return wrap_function(floor); - case UnaryOpType::Frac: - return wrap_function(frac); - case UnaryOpType::Lgamma: - return wrap_function(lgamma); - case UnaryOpType::Log: - return wrap_function(log); - case UnaryOpType::Log10: - return wrap_function(log10); - case UnaryOpType::Log1p: - return wrap_function(log1p); - case UnaryOpType::Log2: - return wrap_function(log2); - case UnaryOpType::Neg: - return wrap_function(neg); - case UnaryOpType::LogicalNot: - return wrap_function(logical_not); - case UnaryOpType::BitwiseNot: - return wrap_function(bitwise_not); - case UnaryOpType::Reciprocal: - return wrap_function(reciprocal); - case UnaryOpType::Relu: - return wrap_function(relu); - case UnaryOpType::Rsqrt: - return wrap_function(rsqrt); - case UnaryOpType::Round: - return wrap_function(round); - case UnaryOpType::Sigmoid: - return wrap_function(sigmoid); - case UnaryOpType::Signbit: - return wrap_function(signbit); - case UnaryOpType::Silu: - return wrap_function(silu); - case UnaryOpType::Sin: - return wrap_function(sin); - case UnaryOpType::Sinh: - return wrap_function(sinh); - case UnaryOpType::Sqrt: - return wrap_function(sqrt); - case UnaryOpType::Tan: - return wrap_function(tan); - case UnaryOpType::Tanh: - return wrap_function(tanh); - case UnaryOpType::Trunc: - return wrap_function(trunc); - case UnaryOpType::IsFinite: - return wrap_function(isfinite); - case UnaryOpType::IsInf: - return wrap_function(isinf); - case UnaryOpType::IsNan: - return wrap_function(isnan); - case UnaryOpType::IsNegInf: - return wrap_function(isneginf); - case UnaryOpType::IsPosInf: - return wrap_function(isposinf); - case UnaryOpType::IsReal: - return wrap_function(isreal); - case UnaryOpType::Real: - return wrap_function(real); - case UnaryOpType::Imag: - return wrap_function(imag); - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - uop->getUnaryOpType(), - " in ", - uop->toString()); - } -} - -// Get std::function for BinaryOp -template -std::function getFunction(const BinaryOp* bop) { - auto wrap_function = [](ResultType (*fn)(ArgTypes...)) { return fn; }; - - switch (bop->getBinaryOpType()) { - case BinaryOpType::Add: - return wrap_function(add); - break; - case BinaryOpType::Atan2: - return wrap_function(atan2); - break; - case BinaryOpType::Div: - return wrap_function(div); - break; - case BinaryOpType::Fmod: - return wrap_function(fmod); - break; - case BinaryOpType::Mul: - return wrap_function(mul); - break; - case BinaryOpType::Nextafter: - return wrap_function(nextafter); - break; - case BinaryOpType::Pow: - return wrap_function(pow); - break; - case BinaryOpType::Remainder: - return wrap_function(remainder); - break; - case BinaryOpType::Sub: - return wrap_function(sub); - break; - case BinaryOpType::Mod: - return wrap_function(mod); - break; - case BinaryOpType::Eq: - return wrap_function(eq); - break; - case BinaryOpType::NE: - return wrap_function(ne); - break; - case BinaryOpType::GT: - return wrap_function(gt); - break; - case BinaryOpType::GE: - return wrap_function(ge); - break; - case BinaryOpType::LT: - return wrap_function(lt); - break; - case BinaryOpType::LE: - return wrap_function(le); - break; - case BinaryOpType::BitwiseAnd: - return wrap_function(bitwise_and); - break; - case BinaryOpType::BitwiseOr: - return wrap_function(bitwise_or); - break; - case BinaryOpType::BitwiseXor: - return wrap_function(bitwise_xor); - break; - case BinaryOpType::LogicalAnd: - return wrap_function(logical_and); - break; - case BinaryOpType::LogicalOr: - return wrap_function(logical_or); - break; - case BinaryOpType::Lshift: - return wrap_function(bitwise_left_shift); - break; - case BinaryOpType::Rshift: - return wrap_function(bitwise_right_shift); - break; - case BinaryOpType::Gcd: - return wrap_function(gcd); - break; - case BinaryOpType::FMin: - case BinaryOpType::Min: - return wrap_function(minimum); - break; - case BinaryOpType::FMax: - case BinaryOpType::Max: - return wrap_function(maximum); - break; - case BinaryOpType::CeilDiv: - return wrap_function(ceilDiv); - break; - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - bop->getBinaryOpType(), - " in ", - bop->toString()); - } -} - -// Get std::function for TernaryOp -template -std::function getFunction(const TernaryOp* top) { - auto wrap_function = [](ResultType (*fn)(ArgTypes...)) { return fn; }; - - // clamp and threshold define a subset of TernaryOp configurations, so they - // are handled in a separate template specialization. - switch (top->getTernaryOpType()) { - case TernaryOpType::Lerp: - return wrap_function(lerp); - break; - case TernaryOpType::Where: - return wrap_function(where); - break; - case TernaryOpType::Threshold: - case TernaryOpType::Clamp: - NVF_CHECK( - false, - "Invalid function arguments for operator type", - top->getTernaryOpType(), - " in ", - top->toString()); - default: - NVF_CHECK( - false, - "Unexpected operator type: ", - top->getTernaryOpType(), - " in ", - top->toString()); - } -} - -// Fully specialized template functions to create std::function for TernaryOp. -template <> -std::function getFunction< - TensorView*, - TensorView*, - Val*, - Val*>(const TernaryOp* top); - -template <> -std::function getFunction( - const TernaryOp* top); - -// Get std::function for ReductionOp -template -std::function getFunction(const ReductionOp* rop) { - switch (rop->getReductionOpType()) { - case BinaryOpType::Add: - return sum; - break; - case BinaryOpType::Mul: - return prod; - break; - case BinaryOpType::FMax: - case BinaryOpType::Max: - return max; - break; - case BinaryOpType::FMin: - case BinaryOpType::Min: - return min; - break; - default: - NVF_CHECK( - false, - "Unexpected reduction operator type: ", - rop->getReductionOpType(), - " in ", - rop->toString()); - } -} - -// Get serde record type for ReductionOp -serde::RecordType getSerdeType(const ReductionOp* rop); - -} // namespace nvfuser::python_frontend diff --git a/python/utils.py b/python/utils.py index 272d347c23e..09067ec48a8 100644 --- a/python/utils.py +++ b/python/utils.py @@ -310,10 +310,7 @@ def copy_library(self, ext, library_name): self.copy_file(libnvfuser_path, install_dst) def build_extension(self, ext): - if ext.name == "nvfuser._C": - self.copy_library(ext, "libnvfuser") - self.copy_shared_library("libnvfuser_codegen.so") - elif ext.name == "nvfuser_direct._C_DIRECT": + if ext.name == "nvfuser_direct._C_DIRECT": self.copy_library(ext, "libnvfuser_direct") self.copy_shared_library("libnvfuser_codegen.so") else: @@ -550,8 +547,7 @@ def run(config, version_tag, relative_path): from setuptools import Extension, setup, find_packages # NOTE(crcrpar): Deliberately build basically two dynamic libraries here so that they can - # be treated as "nvfuser_package_data". This function call will put the two of "nvfuser" and - # "nvfuser_codegen" into "./nvfuser/lib", and the former will be "nvfuser._C". + # be treated as "nvfuser_package_data". if config.build_setup: cmake(config, relative_path) if not config.cmake_only: @@ -593,7 +589,6 @@ def run(config, version_tag, relative_path): description="A Fusion Code Generator for NVIDIA GPUs (commonly known as 'nvFuser')", packages=find_packages(), ext_modules=[ - Extension(name="nvfuser._C", sources=[]), Extension(name="nvfuser_direct._C_DIRECT", sources=[]), ], license_files=("LICENSE",), diff --git a/tests/python/direct/test_import.py b/tests/python/direct/test_import.py index d8507772df8..ee02ac97fc2 100644 --- a/tests/python/direct/test_import.py +++ b/tests/python/direct/test_import.py @@ -9,20 +9,3 @@ def test_import_correct(): import nvfuser_direct # noqa: F401 except Exception as e: raise RuntimeError("Failed to import nvfuser_direct.") - - -def test_import_conflict_direct_then_nvfuser(): - import warnings - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - import nvfuser_direct # noqa: F401 - import nvfuser # noqa: F401 - - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert ( - "Be careful! You've imported nvfuser when the nvfuser_direct module is already imported." - in str(w[-1].message) - ) diff --git a/tests/python/utils/__init__.py b/tests/python/utils/__init__.py deleted file mode 100644 index 862248c0d26..00000000000 --- a/tests/python/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -from . import utils # noqa: F401,F403 -from .utils import * # noqa: F401,F403 diff --git a/tests/python/utils/utils.py b/tests/python/utils/utils.py deleted file mode 100644 index e8ce6eaebe5..00000000000 --- a/tests/python/utils/utils.py +++ /dev/null @@ -1,358 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# Owner(s): ["module: nvfuser"] - -import os -from copy import deepcopy -from typing import Callable, Optional -import tempfile -import torch -import pytest -from contextlib import contextmanager -from torch.testing import make_tensor -from torch.testing._internal.common_utils import TestCase -from looseversion import LooseVersion - -# flake8 complains about DataType being unused in this file but it is necessary -# to run captured fusion definition. -# flake8: noqa -from nvfuser import FusionCache, FusionDefinition, DataType, clone, Tensor - - -def is_pre_volta(): - prop = torch.cuda.get_device_properties(torch.cuda.current_device()) - return prop.major < 7 - - -def is_pre_ampere(): - prop = torch.cuda.get_device_properties(torch.cuda.current_device()) - return prop.major < 8 - - -def is_pre_hopper(): - prop = torch.cuda.get_device_properties(torch.cuda.current_device()) - return prop.major < 9 - - -def is_pre_blackwell(): - prop = torch.cuda.get_device_properties(torch.cuda.current_device()) - return prop.major < 10 - - -def verify_stride_order(output_strides, stride_order): - sorted_stride = list(output_strides) - rank = len(output_strides) - for idx, axis in enumerate(stride_order): - sorted_stride[rank - 1 - axis] = output_strides[idx] - assert sorted(sorted_stride, reverse=True) == sorted_stride - - -# torch.allclose does not work with fp8 datatype, so cast to fp64. -# However, casting complex values to real discards the imaginary -# part, so skip complex dtypes. -def compare_nvfuser_correctness(outputs, reference_outputs): - for idx, ref_out in enumerate(reference_outputs): - if not ref_out.dtype.is_complex: - ref_out = ref_out.to(torch.float64) - if not outputs[idx].dtype.is_complex: - outputs[idx] = outputs[idx].to(torch.float64) - match = torch.allclose(ref_out, outputs[idx], equal_nan=True) - if not match: - return False - return True - - -# Get string representation for FusionDefinition -# Run captured python definition -# Check that the result of captured python definition matches original results -def check_captured_python_definition(reference_outputs, fd, inputs, device=None): - import re - - try: - fd_str = fd.__repr__() - func_name = re.findall("(nvfuser_fusion_id\\d+)", fd_str.split("\n")[1])[0] - exec(fd_str) - - # Execute the python definition that was captured - with FusionDefinition() as fd_cap: - eval(func_name)(fd_cap) - - torch.manual_seed(0) - captured_outputs = fd_cap.execute(inputs, device=device) - return compare_nvfuser_correctness(captured_outputs, reference_outputs) - except Exception as err: - print("\nException For Printed FusionDefinition:") - print( - "(A failure here suggests a mismatch in functionality between the original definition and the printed definition.)" - ) - print(fd_str) - raise err - - -# Run original FusionDefinition -# Clone FusionDefinition -# Apply segmentation if it supported for this FusionDefinition -# Run cloned python definition -# Check that the result of cloned python definition matches original results -def check_cpp_translation( - reference_outputs, fd, inputs, supports_segmentation, device=None -): - try: - torch.manual_seed(0) - - # Clone - cloned_fd = FusionDefinition() - clone(fd, cloned_fd) - - # Segment - if supports_segmentation: - cloned_fd.segment(inputs) - - # Run - cloned_outputs = cloned_fd.execute(inputs, device=device) - return compare_nvfuser_correctness(cloned_outputs, reference_outputs) - except Exception as err: - print("\nException For CPP Translation:") - print( - "(A failure here suggests a mismatch in functionality between the original and cloned definitions.)" - ) - print("Does FusionDefinition supports segmentation?\t", supports_segmentation) - print(fd._repro_error_str("executing", inputs)) - raise err - - -# This DEBUG_SERDE environment flag is used to debug serialization failures. -# -# If DEBUG_SERDE=debug -# 1) It disables automatically saving FusionCache upon program exit. Therefore, -# it has to be a global flag not per-test. -# -# 2) It resets the FusionCache after each test, which is useful for isolating -# failures. Note, some failures only occur when running multiple tests -# together and accumulating fusions in the cache. -# -# 3) It keeps the temporary files that are created during serde_check. -# Normally, these files are deleted after each test. -# -# DEBUG_SERDE=disable -# 1) It disables the @nvfusertest_serde_check decorator. This disables checking -# that serde round-trips preserve the definition during testing. -env_var_debug_serde = os.getenv("DEBUG_SERDE", "").lower() -debug_serde: bool = env_var_debug_serde == "debug" -disable_serde: bool = env_var_debug_serde == "disable" -del env_var_debug_serde - - -# The pytest framework and test_python_frontend.py use different arguments for -# testing, so we need specific `serde_check` decorators for both frameworks. -# basic_serde_check is the common part between them. It serializes the cache, -# deletes it, and then deserialized to recreate the cache. -def basic_serde_check(): - # If DEBUG_SERDE is enabled, the temporary file is not deleted - # automatically - with tempfile.NamedTemporaryFile(delete=(not debug_serde)) as tmp: - try: - # Serialize FusionCache - fc = FusionCache.get() - fc.serialize(tmp.name) - - FusionCache.reset() - - # Get new FusionCache because the previous one was destroyed by - # the reset call. - fc = FusionCache.get() - assert fc.num_fusions() == 0 - fc.deserialize(tmp.name) - except Exception as e: - if debug_serde: - raise RuntimeError( - f"***** {tmp.name} contains the serialized binary for this failure." - ) - else: - raise RuntimeError( - "***** Use DEBUG_SERDE=debug to debug serialization failure." - ) - - -# Enable automatic serialization upon program exit and test deserializing the -# default workspace. NOTE: Serializing error test cases corrupts the serialized -# binary. Call FusionCache.reset() to clear the cache after running an error -# test in `test_python_frontend.py'. -def atexit_serde_check(): - if disable_serde: - # Ignore FusionCache and automatic serialization if serde check is - # disabled - return - - from nvfuser import FusionCache - - if not debug_serde: - from nvfuser import enable_automatic_serialization - - # Turn on default serialization upon program exit - enable_automatic_serialization() - - # Automatically load common workplace - fc = FusionCache.get() - # Clear FusionCache because the tests expect a new fusion to be generated. - FusionCache.reset() - - -def nvfusertest_serde_check(test_fn: Callable): - """ - A decorator to verify that serialization works with the given exec_nvfuser - function. Currently, it uses serialization to rebuild the FusionCache - structure. - """ - if disable_serde: - - def inner_fn(*args, **kwargs): - # Remove skip_serde_check if it was given - kwargs.pop("skip_serde_check", None) - return test_fn(*args, **kwargs) - - return inner_fn - - def inner_fn(*args, **kwargs): - self, fusion_func, inputs = args - - # NOTE: For debug purposes, clear FusionCache before running first test - # so the behavior is more deterministic (PR #1848). - is_new_fusion_expected = kwargs.get("new_fusion_expected", True) - if debug_serde and is_new_fusion_expected: - FusionCache.reset() - assert FusionCache.get().num_fusions() == 0 - - # skip_serde_check is only used by the decorator so remove it before - # running test_fn - skip_serde_check = kwargs.pop("skip_serde_check", False) - if skip_serde_check: - return test_fn(self, fusion_func, inputs, **kwargs) - - # Run test to populate FusionCache. Deep copy inputs for this run but - # not the final run. When a fusion output aliases an input, it will - # change the input value for subsequent function calls. Therefore, only - # the final run should take the original tensors and potentially update - # their values. - inputs_copy = deepcopy(inputs) - test_fn(self, fusion_func, inputs_copy, **kwargs) - - # Serialize and Deserialize FusionCache - basic_serde_check() - - # Run test with repopulated FusionCache - kwargs["new_fusion_expected"] = False - return test_fn(self, fusion_func, inputs, **kwargs) - - return inner_fn - - -UPDATED_SDPA = LooseVersion(torch.__version__) >= LooseVersion("2.7.0") - - -def define_sdpa_rng_state(fd: FusionDefinition) -> tuple[Tensor, Tensor]: - dtype = DataType.UInt64 if UPDATED_SDPA else DataType.Int - is_cpu = False if UPDATED_SDPA else True - philox_shape = [2] if UPDATED_SDPA else [] - philox_seed = fd.define_tensor( - shape=philox_shape, - dtype=dtype, - is_cpu=is_cpu, - ) - philox_offset = fd.define_tensor( - shape=[], - dtype=dtype, - is_cpu=is_cpu, - ) - return philox_seed, philox_offset - - -def create_sdpa_rng_tensors() -> tuple[torch.Tensor, torch.Tensor]: - dtype = torch.uint64 if UPDATED_SDPA else torch.int64 - device = "cuda" if UPDATED_SDPA else "cpu" - philox_shape = (2,) if UPDATED_SDPA else () - philox_seed = torch.testing.make_tensor(philox_shape, device=device, dtype=dtype) - philox_offset = torch.testing.make_tensor((), device=device, dtype=dtype) - return philox_seed, philox_offset - - -""" -Base class for any test class that needs to verify serialization -and run captured string representations of FusionDefinition. -""" - - -class NVFuserTest(TestCase): - @classmethod - def setup_class(cls): - """ - Setup is run once at the class level, before running any tests of the class. - `atexit_serde_check` enables automatic serialization at the end of the test suite. - """ - os.environ["NVIDIA_TF32_OVERRIDE"] = "0" - - atexit_serde_check() - - # Helper function to verify the nvfuser output and make sure the string - # definition based on the FusionDefinition is executable and matches the - # original definition - @nvfusertest_serde_check - def exec_nvfuser( - self, - fusion_func, - inputs, - *, - _enable_options=[], - _disable_options=[], - new_fusion_expected=True, - device=None, - is_clonable=True, - supports_segmentation=True, - ): - fc = FusionCache.get() - before_fusions = fc.num_fusions() - # Copy inputs because aliased outputs can modify inputs when running - # FusionDefinition - inputs_captured = deepcopy(inputs) - if is_clonable: - inputs_cloned = deepcopy(inputs) - - # Execute a fusion function and capture the string python definition - with FusionDefinition() as fd: - fusion_func(fd) - torch.manual_seed(0) - out = fd.execute( - inputs, - device=device, - _enable_options=_enable_options, - _disable_options=_disable_options, - ) - - self.assertTrue( - check_captured_python_definition(out, fd, inputs_captured, device) - ) - if not disable_serde: - self.assertEqual( - fc.num_fusions() - before_fusions, int(new_fusion_expected) - ) - - if is_clonable: - self.assertTrue( - check_cpp_translation(out, fd, inputs_cloned, supports_segmentation) - ) - return out, fd - - -@contextmanager -def set_env(**environ): - """ - Override environment variable - """ - old_environ = dict(os.environ) - os.environ.update(environ) - try: - yield - finally: - os.environ.clear() - os.environ.update(old_environ) diff --git a/tools/env-config/env_options.yaml b/tools/env-config/env_options.yaml index 749530f73bf..3c3958ef48e 100644 --- a/tools/env-config/env_options.yaml +++ b/tools/env-config/env_options.yaml @@ -371,18 +371,6 @@ options: category: dump env_var: NVFUSER_DUMP - - name: python_definition_segments - description: Python Frontend Fusion Definition of segments - type: bool - category: dump - env_var: NVFUSER_DUMP - - - name: python_frontend_debug - description: Python Frontend debug information - type: bool - category: dump - env_var: NVFUSER_DUMP - - name: transform_propagator description: Print propagation path and replay result type: bool From 5107f269ecdc730bf15083d22f4f7677ad3b1a0c Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Tue, 3 Feb 2026 12:54:38 -0800 Subject: [PATCH 2/3] bump version --- python/version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/version.txt b/python/version.txt index e7ccda1a357..d4ca806c1a4 100644 --- a/python/version.txt +++ b/python/version.txt @@ -1 +1 @@ -0.2.35 +0.2.36 From aa333dadbae0c1c0867a51a5b0ff9daae27a00f4 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Wed, 4 Feb 2026 14:59:13 -0800 Subject: [PATCH 3/3] Update tools/check_symbol_visibility.sh --- tools/check_symbol_visibility.sh | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tools/check_symbol_visibility.sh b/tools/check_symbol_visibility.sh index 9e369ae58a6..1f38318cb99 100755 --- a/tools/check_symbol_visibility.sh +++ b/tools/check_symbol_visibility.sh @@ -49,20 +49,13 @@ CUTLASS_LIB="$BUILD_DIR/libnvf_cutlass.so" check_file_exists "$NVFUSER_CODEGEN_LIB" || exit 1 # Find Python extension files -NVFUSER_EXT=$(find_python_extensions "$PYTHON_DIR/nvfuser" "_C*.so") NVFUSER_DIRECT_EXT=$(find_python_extensions "$PYTHON_DIR/nvfuser_direct" "_C_DIRECT*.so") -if [ -z "$NVFUSER_EXT" ]; then - echo "ERROR: nvfuser Python extension (_C*.so) not found in $PYTHON_DIR/nvfuser" - exit 1 -fi - if [ -z "$NVFUSER_DIRECT_EXT" ]; then echo "ERROR: nvfuser_direct Python extension (_C_DIRECT*.so) not found in $PYTHON_DIR/nvfuser_direct" exit 1 fi -echo "Found nvfuser extension: $NVFUSER_EXT" echo "Found nvfuser_direct extension: $NVFUSER_DIRECT_EXT" echo "" @@ -129,15 +122,13 @@ check_extension_symbols() { } # Check both extensions -NVFUSER_OK=0 NVFUSER_DIRECT_OK=0 -check_extension_symbols "$NVFUSER_EXT" "nvfuser" || NVFUSER_OK=1 check_extension_symbols "$NVFUSER_DIRECT_EXT" "nvfuser_direct" || NVFUSER_DIRECT_OK=1 # 4. Final results echo "=== FINAL RESULTS ===" -if [ $NVFUSER_OK -eq 0 ] && [ $NVFUSER_DIRECT_OK -eq 0 ]; then +if [ $NVFUSER_DIRECT_OK -eq 0 ]; then echo "✅ SUCCESS: All Python extensions have properly exported symbols" echo "" echo "Cleaning up temporary files..." @@ -146,9 +137,6 @@ if [ $NVFUSER_OK -eq 0 ] && [ $NVFUSER_DIRECT_OK -eq 0 ]; then else echo "❌ FAILURE: Missing symbols detected" echo "" - if [ $NVFUSER_OK -ne 0 ]; then - echo "- nvfuser extension has missing symbols (see $TEMP_DIR/nvfuser_missing_symbols.txt)" - fi if [ $NVFUSER_DIRECT_OK -ne 0 ]; then echo "- nvfuser_direct extension has missing symbols (see $TEMP_DIR/nvfuser_direct_missing_symbols.txt)" fi @@ -165,9 +153,6 @@ else echo "- exported_symbols.txt: Combined exported symbols from all libraries" echo "- nvfuser_undefined_symbols.txt: Undefined symbols from nvfuser extension" echo "- nvfuser_direct_undefined_symbols.txt: Undefined symbols from nvfuser_direct extension" - if [ $NVFUSER_OK -ne 0 ]; then - echo "- nvfuser_missing_symbols.txt: Missing symbols from nvfuser extension" - fi if [ $NVFUSER_DIRECT_OK -ne 0 ]; then echo "- nvfuser_direct_missing_symbols.txt: Missing symbols from nvfuser_direct extension" fi