diff --git a/python/cuda/bench/__init__.py b/python/cuda/bench/__init__.py index e6e7753c..4d2f4963 100644 --- a/python/cuda/bench/__init__.py +++ b/python/cuda/bench/__init__.py @@ -69,12 +69,21 @@ def _get_cuda_major_version(): State = _nvbench_module.State register = _nvbench_module.register run_all_benchmarks = _nvbench_module.run_all_benchmarks -test_cpp_exception = _nvbench_module.test_cpp_exception -test_py_exception = _nvbench_module.test_py_exception +_test_cpp_exception = _nvbench_module._test_cpp_exception +_test_py_exception = _nvbench_module._test_py_exception # Expose the module as _nvbench for backward compatibility (e.g., for tests) _nvbench = _nvbench_module +# Set module of exposed objects +Benchmark.__module__ = __name__ +CudaStream.__module__ = __name__ +Launch.__module__ = __name__ +NVBenchRuntimeError.__module__ = __name__ +State.__module__ = __name__ +register.__module__ = __name__ +run_all_benchmarks.__module__ = __name__ + # Clean up internal symbols del ( _nvbench_module, diff --git a/python/src/py_nvbench.cpp b/python/src/py_nvbench.cpp index 8ecac4a2..b1c00022 100644 --- a/python/src/py_nvbench.cpp +++ b/python/src/py_nvbench.cpp @@ -273,7 +273,7 @@ static void def_class_CudaStream(py::module_ m) // nvbench::cuda_stream::get_stream static constexpr const char *class_CudaStream_doc = R"XXX( -Represents CUDA stream + Represents CUDA stream Note ---- @@ -321,7 +321,7 @@ void def_class_Launch(py::module_ m) // nvbench::launch::get_stream -> nvbench::cuda_stream static constexpr const char *class_Launch_doc = R"XXXX( -Configuration object for function launch. + Configuration object for function launch. Note ---- @@ -363,13 +363,13 @@ static void def_class_Benchmark(py::module_ m) // nvbench::benchmark_base::set_min_samples static constexpr const char *class_Benchmark_doc = R"XXXX( -Represents NVBench benchmark. + Represents NVBench benchmark. Note ---- The class is not user-constructible. - Use `~register` function to create Benchmark and register + Use `register` function to create Benchmark and register it with NVBench. )XXXX"; auto py_benchmark_cls = py::class_(m, "Benchmark", class_Benchmark_doc); @@ -691,7 +691,7 @@ void def_class_State(py::module_ m) using state_ref_t = std::reference_wrapper; static constexpr const char *class_State_doc = R"XXXX( -Represent benchmark configuration state. + Represents benchmark configuration state. Note ---- @@ -736,7 +736,7 @@ Get device_id of the device from this configuration return std::ref(state.get_cuda_stream()); }; static constexpr const char *method_get_stream_doc = R"XXXX( -Get `~CudaStream` object from this configuration" +Get `CudaStream` object from this configuration )XXXX"; pystate_cls.def("get_stream", method_get_stream_impl, @@ -1014,10 +1014,10 @@ Use argument True to disable use of blocking kernel by NVBench" } }; static constexpr const char *method_exec_doc = R"XXXX( -Execute callable running the benchmark. + Execute callable running the benchmark. The callable may be executed multiple times. The callable - will be passed `~Launch` object argument. + will be passed `Launch` object argument. Parameters ---------- @@ -1194,8 +1194,8 @@ Register benchmark function of type Callable[[nvbench.State], None] py::arg("argv") = py::list()); // Testing utilities - m.def("test_cpp_exception", []() { throw nvbench_run_error("Test"); }); - m.def("test_py_exception", []() { + m.def("_test_cpp_exception", []() { throw nvbench_run_error("Test"); }); + m.def("_test_py_exception", []() { py::set_error(exc_storage.get_stored(), "Test"); throw py::error_already_set(); }); diff --git a/python/test/test_cuda_bench.py b/python/test/test_cuda_bench.py index 7d927e8f..b63d24d8 100644 --- a/python/test/test_cuda_bench.py +++ b/python/test/test_cuda_bench.py @@ -6,12 +6,12 @@ def test_cpp_exception(): with pytest.raises(RuntimeError, match="Test"): - bench._nvbench.test_cpp_exception() + bench._nvbench._test_cpp_exception() def test_py_exception(): with pytest.raises(bench.NVBenchRuntimeError, match="Test"): - bench._nvbench.test_py_exception() + bench._nvbench._test_py_exception() @pytest.mark.parametrize(