Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions python/cuda/bench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,21 @@ def _get_cuda_major_version():
State = _nvbench_module.State
register = _nvbench_module.register
run_all_benchmarks = _nvbench_module.run_all_benchmarks
test_cpp_exception = _nvbench_module.test_cpp_exception
test_py_exception = _nvbench_module.test_py_exception
_test_cpp_exception = _nvbench_module._test_cpp_exception
_test_py_exception = _nvbench_module._test_py_exception

# Expose the module as _nvbench for backward compatibility (e.g., for tests)
_nvbench = _nvbench_module

# Set module of exposed objects
Benchmark.__module__ = __name__
CudaStream.__module__ = __name__
Launch.__module__ = __name__
NVBenchRuntimeError.__module__ = __name__
State.__module__ = __name__
register.__module__ = __name__
run_all_benchmarks.__module__ = __name__

# Clean up internal symbols
del (
_nvbench_module,
Expand Down
20 changes: 10 additions & 10 deletions python/src/py_nvbench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ static void def_class_CudaStream(py::module_ m)
// nvbench::cuda_stream::get_stream

static constexpr const char *class_CudaStream_doc = R"XXX(
Represents CUDA stream
Represents CUDA stream

Note
----
Expand Down Expand Up @@ -321,7 +321,7 @@ void def_class_Launch(py::module_ m)
// nvbench::launch::get_stream -> nvbench::cuda_stream

static constexpr const char *class_Launch_doc = R"XXXX(
Configuration object for function launch.
Configuration object for function launch.

Note
----
Expand Down Expand Up @@ -363,13 +363,13 @@ static void def_class_Benchmark(py::module_ m)
// nvbench::benchmark_base::set_min_samples

static constexpr const char *class_Benchmark_doc = R"XXXX(
Represents NVBench benchmark.
Represents NVBench benchmark.

Note
----
The class is not user-constructible.

Use `~register` function to create Benchmark and register
Use `register` function to create Benchmark and register
it with NVBench.
)XXXX";
auto py_benchmark_cls = py::class_<nvbench::benchmark_base>(m, "Benchmark", class_Benchmark_doc);
Expand Down Expand Up @@ -691,7 +691,7 @@ void def_class_State(py::module_ m)

using state_ref_t = std::reference_wrapper<nvbench::state>;
static constexpr const char *class_State_doc = R"XXXX(
Represent benchmark configuration state.
Represents benchmark configuration state.

Note
----
Expand Down Expand Up @@ -736,7 +736,7 @@ Get device_id of the device from this configuration
return std::ref(state.get_cuda_stream());
};
static constexpr const char *method_get_stream_doc = R"XXXX(
Get `~CudaStream` object from this configuration"
Get `CudaStream` object from this configuration
)XXXX";
pystate_cls.def("get_stream",
method_get_stream_impl,
Expand Down Expand Up @@ -1014,10 +1014,10 @@ Use argument True to disable use of blocking kernel by NVBench"
}
};
static constexpr const char *method_exec_doc = R"XXXX(
Execute callable running the benchmark.
Execute callable running the benchmark.

The callable may be executed multiple times. The callable
will be passed `~Launch` object argument.
will be passed `Launch` object argument.

Parameters
----------
Expand Down Expand Up @@ -1194,8 +1194,8 @@ Register benchmark function of type Callable[[nvbench.State], None]
py::arg("argv") = py::list());

// Testing utilities
m.def("test_cpp_exception", []() { throw nvbench_run_error("Test"); });
m.def("test_py_exception", []() {
m.def("_test_cpp_exception", []() { throw nvbench_run_error("Test"); });
m.def("_test_py_exception", []() {
py::set_error(exc_storage.get_stored(), "Test");
throw py::error_already_set();
});
Expand Down
4 changes: 2 additions & 2 deletions python/test/test_cuda_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

def test_cpp_exception():
with pytest.raises(RuntimeError, match="Test"):
bench._nvbench.test_cpp_exception()
bench._nvbench._test_cpp_exception()


def test_py_exception():
with pytest.raises(bench.NVBenchRuntimeError, match="Test"):
bench._nvbench.test_py_exception()
bench._nvbench._test_py_exception()


@pytest.mark.parametrize(
Expand Down
Loading