diff --git a/src/usparse/usparse_utils.hpp b/src/usparse/usparse_utils.hpp index 9432339..41b3308 100644 --- a/src/usparse/usparse_utils.hpp +++ b/src/usparse/usparse_utils.hpp @@ -21,11 +21,12 @@ } struct sync_handle { - static constexpr int SYNC_POLICY = 2; + static constexpr int SYNC_POLICY = 3; static constexpr int SYNC_POLICY_DEVICE_SYNC = 0; static constexpr int SYNC_POLICY_EVENT_SYNC = 1; static constexpr int SYNC_POLICY_MEMCOPY_SYNC = 2; + static constexpr int SYNC_POLICY_MEMSET_SYNC = 3; inline sync_handle() { if (SYNC_POLICY == SYNC_POLICY_DEVICE_SYNC) { @@ -34,6 +35,8 @@ struct sync_handle { hipEventCreate(&sync_kernel_event); } else if (SYNC_POLICY == SYNC_POLICY_MEMCOPY_SYNC) { mem_malloc(); + } else if (SYNC_POLICY == SYNC_POLICY_MEMSET_SYNC) { + HIP_CHECK(hipMalloc((int **)&dev_data, 1 * sizeof(int))); } } @@ -45,6 +48,8 @@ struct sync_handle { hipEventSynchronize(sync_kernel_event); } else if (SYNC_POLICY == SYNC_POLICY_MEMCOPY_SYNC) { mem_copy(); + } else if (SYNC_POLICY == SYNC_POLICY_MEMSET_SYNC) { + hipMemset(dev_data, 0, sizeof(int)); // TODO: HIP_CHECK } }