From 3e75f537da1ca1bfc886a839a3cbeb3476625d81 Mon Sep 17 00:00:00 2001 From: Martin Morrison-Grant Date: Wed, 4 Jun 2025 14:54:00 +0100 Subject: [PATCH 1/5] Refactor reference counting in UR across all adapters into a new common UR_ReferenceCounter class. --- .../source/adapters/cuda/adapter.cpp | 8 +-- .../source/adapters/cuda/adapter.hpp | 10 ++-- .../source/adapters/cuda/command_buffer.cpp | 8 +-- .../source/adapters/cuda/command_buffer.hpp | 10 ++-- .../source/adapters/cuda/context.cpp | 8 +-- .../source/adapters/cuda/context.hpp | 12 ++--- .../source/adapters/cuda/device.cpp | 2 +- .../source/adapters/cuda/device.hpp | 9 ++-- .../source/adapters/cuda/event.cpp | 10 ++-- .../source/adapters/cuda/event.hpp | 9 ++-- .../source/adapters/cuda/kernel.cpp | 12 +++-- .../source/adapters/cuda/kernel.hpp | 14 +++-- .../source/adapters/cuda/memory.cpp | 9 ++-- .../source/adapters/cuda/memory.hpp | 23 ++++---- .../source/adapters/cuda/physical_mem.cpp | 6 +-- .../source/adapters/cuda/physical_mem.hpp | 15 +++--- .../source/adapters/cuda/program.cpp | 11 ++-- .../source/adapters/cuda/program.hpp | 18 +++---- .../source/adapters/cuda/queue.cpp | 8 +-- .../source/adapters/cuda/sampler.cpp | 8 +-- .../source/adapters/cuda/sampler.hpp | 13 +++-- unified-runtime/source/adapters/cuda/usm.cpp | 6 +-- unified-runtime/source/adapters/cuda/usm.hpp | 11 ++-- .../source/adapters/hip/adapter.cpp | 8 +-- .../source/adapters/hip/adapter.hpp | 7 ++- .../source/adapters/hip/command_buffer.cpp | 8 +-- .../source/adapters/hip/command_buffer.hpp | 11 ++-- .../source/adapters/hip/context.cpp | 8 +-- .../source/adapters/hip/context.hpp | 12 ++--- .../source/adapters/hip/device.cpp | 2 +- .../source/adapters/hip/device.hpp | 9 ++-- unified-runtime/source/adapters/hip/event.cpp | 10 ++-- unified-runtime/source/adapters/hip/event.hpp | 9 ++-- .../source/adapters/hip/kernel.cpp | 12 +++-- .../source/adapters/hip/kernel.hpp | 12 ++--- .../source/adapters/hip/memory.hpp | 24 ++++----- 
.../source/adapters/hip/physical_mem.hpp | 9 ++-- .../source/adapters/hip/program.cpp | 11 ++-- .../source/adapters/hip/program.hpp | 14 +++-- unified-runtime/source/adapters/hip/queue.cpp | 9 ++-- .../source/adapters/hip/sampler.cpp | 8 +-- .../source/adapters/hip/sampler.hpp | 13 +++-- unified-runtime/source/adapters/hip/usm.cpp | 6 +-- unified-runtime/source/adapters/hip/usm.hpp | 12 ++--- .../source/adapters/level_zero/adapter.cpp | 10 ++-- .../source/adapters/level_zero/adapter.hpp | 13 +++-- .../adapters/level_zero/async_alloc.cpp | 2 +- .../adapters/level_zero/command_buffer.cpp | 10 ++-- .../adapters/level_zero/command_buffer.hpp | 7 ++- .../source/adapters/level_zero/common.hpp | 54 +++---------------- .../source/adapters/level_zero/context.cpp | 6 +-- .../source/adapters/level_zero/context.hpp | 5 ++ .../source/adapters/level_zero/device.cpp | 6 +-- .../source/adapters/level_zero/device.hpp | 6 +++ .../source/adapters/level_zero/event.cpp | 28 +++++----- .../source/adapters/level_zero/event.hpp | 42 +++++++++------ .../source/adapters/level_zero/kernel.cpp | 6 +-- .../source/adapters/level_zero/kernel.hpp | 9 +++- .../source/adapters/level_zero/memory.cpp | 10 ++-- .../source/adapters/level_zero/memory.hpp | 8 ++- .../adapters/level_zero/physical_mem.cpp | 6 +-- .../adapters/level_zero/physical_mem.hpp | 6 +++ .../source/adapters/level_zero/program.cpp | 8 +-- .../source/adapters/level_zero/program.hpp | 6 +++ .../source/adapters/level_zero/queue.cpp | 22 ++++---- .../source/adapters/level_zero/queue.hpp | 32 ++++++----- .../source/adapters/level_zero/sampler.cpp | 4 +- .../source/adapters/level_zero/sampler.hpp | 6 +++ .../source/adapters/level_zero/usm.cpp | 10 ++-- .../source/adapters/level_zero/usm.hpp | 8 ++- .../source/adapters/native_cpu/adapter.cpp | 13 +++-- .../source/adapters/native_cpu/common.hpp | 23 ++------ .../source/adapters/native_cpu/context.cpp | 8 +-- .../source/adapters/native_cpu/context.hpp | 7 ++- 
.../source/adapters/native_cpu/event.cpp | 8 +-- .../source/adapters/native_cpu/event.hpp | 13 +++-- .../source/adapters/native_cpu/kernel.cpp | 8 +-- .../source/adapters/native_cpu/kernel.hpp | 18 +++++-- .../source/adapters/native_cpu/memory.cpp | 6 ++- .../source/adapters/native_cpu/memory.hpp | 8 ++- .../source/adapters/native_cpu/program.cpp | 8 +-- .../source/adapters/native_cpu/program.hpp | 10 ++-- .../source/adapters/native_cpu/queue.cpp | 8 +-- .../source/adapters/native_cpu/queue.hpp | 11 +++- .../source/adapters/offload/adapter.cpp | 8 +-- .../source/adapters/offload/adapter.hpp | 8 ++- .../source/adapters/offload/context.cpp | 6 +-- .../source/adapters/offload/context.hpp | 8 ++- .../source/adapters/offload/kernel.cpp | 6 +-- .../source/adapters/offload/kernel.hpp | 8 ++- .../source/adapters/offload/program.cpp | 6 +-- .../source/adapters/offload/program.hpp | 8 ++- .../source/adapters/offload/queue.cpp | 6 +-- .../source/adapters/offload/queue.hpp | 6 +++ .../source/adapters/opencl/adapter.cpp | 8 +-- .../source/adapters/opencl/adapter.hpp | 8 ++- .../source/adapters/opencl/command_buffer.cpp | 6 +-- .../source/adapters/opencl/command_buffer.hpp | 12 ++--- .../source/adapters/opencl/context.cpp | 6 +-- .../source/adapters/opencl/context.hpp | 12 ++--- .../source/adapters/opencl/device.cpp | 6 +-- .../source/adapters/opencl/device.hpp | 10 ++-- .../source/adapters/opencl/event.cpp | 6 +-- .../source/adapters/opencl/event.hpp | 12 ++--- .../source/adapters/opencl/kernel.cpp | 6 +-- .../source/adapters/opencl/kernel.hpp | 12 ++--- .../source/adapters/opencl/memory.cpp | 6 +-- .../source/adapters/opencl/memory.hpp | 12 ++--- .../source/adapters/opencl/program.cpp | 6 +-- .../source/adapters/opencl/program.hpp | 12 ++--- .../source/adapters/opencl/queue.cpp | 6 +-- .../source/adapters/opencl/queue.hpp | 13 +++-- .../source/adapters/opencl/sampler.cpp | 6 +-- .../source/adapters/opencl/sampler.hpp | 10 ++-- .../source/common/cuda-hip/stream_queue.hpp | 12 
++--- .../source/common/ur_ref_counter.hpp | 28 ++++++++++ 116 files changed, 658 insertions(+), 560 deletions(-) create mode 100644 unified-runtime/source/common/ur_ref_counter.hpp diff --git a/unified-runtime/source/adapters/cuda/adapter.cpp b/unified-runtime/source/adapters/cuda/adapter.cpp index dca627c87fc19..78e4311432926 100644 --- a/unified-runtime/source/adapters/cuda/adapter.cpp +++ b/unified-runtime/source/adapters/cuda/adapter.cpp @@ -66,7 +66,7 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, std::call_once(InitFlag, [=]() { ur::cuda::adapter = new ur_adapter_handle_t_; }); - ur::cuda::adapter->RefCount++; + ur::cuda::adapter->getRefCounter().increment(); *phAdapters = ur::cuda::adapter; } @@ -78,13 +78,13 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - ur::cuda::adapter->RefCount++; + ur::cuda::adapter->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - if (--ur::cuda::adapter->RefCount == 0) { + if (ur::cuda::adapter->getRefCounter().decrement() == 0) { delete ur::cuda::adapter; } return UR_RESULT_SUCCESS; @@ -108,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_CUDA); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(ur::cuda::adapter->RefCount.load()); + return ReturnValue(ur::cuda::adapter->getRefCounter().getCount()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/cuda/adapter.hpp b/unified-runtime/source/adapters/cuda/adapter.hpp index 6ec9007bceacf..ff07eea7f20b0 100644 --- a/unified-runtime/source/adapters/cuda/adapter.hpp +++ b/unified-runtime/source/adapters/cuda/adapter.hpp @@ -11,25 +11,29 @@ #ifndef UR_CUDA_ADAPTER_HPP_INCLUDED #define UR_CUDA_ADAPTER_HPP_INCLUDED 
+#include "common/ur_ref_counter.hpp" #include "logger/ur_logger.hpp" #include "platform.hpp" #include "tracing.hpp" #include -#include #include struct ur_adapter_handle_t_ : ur::cuda::handle_base { - std::atomic RefCount = 0; struct cuda_tracing_context_t_ *TracingCtx = nullptr; logger::Logger &logger; std::unique_ptr Platform; ur_adapter_handle_t_(); ~ur_adapter_handle_t_(); ur_adapter_handle_t_(const ur_adapter_handle_t_ &) = delete; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; -// Keep the global namespace'd +// Keep the global namespace namespace ur::cuda { extern ur_adapter_handle_t adapter; } // namespace ur::cuda diff --git a/unified-runtime/source/adapters/cuda/command_buffer.cpp b/unified-runtime/source/adapters/cuda/command_buffer.cpp index 2399b8f81857c..42f534502ecbd 100644 --- a/unified-runtime/source/adapters/cuda/command_buffer.cpp +++ b/unified-runtime/source/adapters/cuda/command_buffer.cpp @@ -60,7 +60,7 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_( bool IsInOrder) : handle_base(), Context(Context), Device(Device), IsUpdatable(IsUpdatable), IsInOrder(IsInOrder), CudaGraph{nullptr}, CudaGraphExec{nullptr}, - RefCount{1}, NextSyncPoint{0} { + NextSyncPoint{0} { urContextRetain(Context); } @@ -380,13 +380,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp( UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - hCommandBuffer->incrementReferenceCount(); + hCommandBuffer->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - if (hCommandBuffer->decrementReferenceCount() == 0) { + if (hCommandBuffer->getRefCounter().decrement() == 0) { // Ref count has reached zero, release of created commands for (auto &Command : hCommandBuffer->CommandHandles) { 
commandHandleDestroy(Command); @@ -1476,7 +1476,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp( switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(hCommandBuffer->getReferenceCount()); + return ReturnValue(hCommandBuffer->getRefCounter().getCount()); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/cuda/command_buffer.hpp b/unified-runtime/source/adapters/cuda/command_buffer.hpp index e11b9ab74969a..1a9f40669ca73 100644 --- a/unified-runtime/source/adapters/cuda/command_buffer.hpp +++ b/unified-runtime/source/adapters/cuda/command_buffer.hpp @@ -12,8 +12,10 @@ #include #include +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include "logger/ur_logger.hpp" + #include #include #include @@ -173,9 +175,7 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base { return SyncPoint; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } // UR context associated with this command-buffer ur_context_handle_t Context; @@ -191,7 +191,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base { CUgraphExec CudaGraphExec = nullptr; // Atomic variable counting the number of reference to this command_buffer // using std::atomic prevents data race when incrementing/decrementing. - std::atomic_uint32_t RefCount; // Ordered map of sync_points to ur_events, so that we can find the last // node added to an in-order command-buffer. 
@@ -203,4 +202,7 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base { // Handles to individual commands in the command-buffer std::vector> CommandHandles; + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/cuda/context.cpp b/unified-runtime/source/adapters/cuda/context.cpp index 074bbeb440b2e..47a93ced4ad61 100644 --- a/unified-runtime/source/adapters/cuda/context.cpp +++ b/unified-runtime/source/adapters/cuda/context.cpp @@ -66,7 +66,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo( return ReturnValue(hContext->getDevices().data(), hContext->getDevices().size()); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(hContext->getReferenceCount()); + return ReturnValue(hContext->getRefCounter().getCount()); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // 2D USM memcpy is supported. return ReturnValue(true); @@ -83,7 +83,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo( UR_APIEXPORT ur_result_t UR_APICALL urContextRelease(ur_context_handle_t hContext) { - if (hContext->decrementReferenceCount() > 0) { + if (hContext->getRefCounter().decrement() > 0) { return UR_RESULT_SUCCESS; } hContext->invokeExtendedDeleters(); @@ -94,9 +94,9 @@ urContextRelease(ur_context_handle_t hContext) { UR_APIEXPORT ur_result_t UR_APICALL urContextRetain(ur_context_handle_t hContext) { - assert(hContext->getReferenceCount() > 0); + assert(hContext->getRefCounter().getCount() > 0); - hContext->incrementReferenceCount(); + hContext->getRefCounter().increment(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/cuda/context.hpp b/unified-runtime/source/adapters/cuda/context.hpp index 1a24d163c9a50..036c1c7e914ca 100644 --- a/unified-runtime/source/adapters/cuda/context.hpp +++ b/unified-runtime/source/adapters/cuda/context.hpp @@ -13,13 +13,13 @@ #include #include -#include #include #include #include #include "adapter.hpp" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include
"device.hpp" #include "umf_helpers.hpp" @@ -88,7 +88,6 @@ struct ur_context_handle_t_ : ur::cuda::handle_base { }; std::vector Devices; - std::atomic_uint32_t RefCount; // UMF CUDA memory provider and pool for the host memory // (UMF_MEMORY_TYPE_HOST) @@ -96,7 +95,7 @@ struct ur_context_handle_t_ : ur::cuda::handle_base { umf_memory_pool_handle_t MemoryPoolHost = nullptr; ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) - : handle_base(), Devices{Devs, Devs + NumDevices}, RefCount{1} { + : handle_base(), Devices{Devs, Devs + NumDevices} { // Create UMF CUDA memory provider for the host memory // (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because // it is guaranteed to exist). @@ -140,11 +139,7 @@ struct ur_context_handle_t_ : ur::cuda::handle_base { return std::distance(Devices.begin(), It); } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } void addPool(ur_usm_pool_handle_t Pool); @@ -156,6 +151,7 @@ struct ur_context_handle_t_ : ur::cuda::handle_base { std::mutex Mutex; std::vector ExtendedDeleters; std::set PoolHandles; + UR_ReferenceCounter RefCounter; }; namespace { diff --git a/unified-runtime/source/adapters/cuda/device.cpp b/unified-runtime/source/adapters/cuda/device.cpp index 2cb43ebc88356..aeb382dc0a8ff 100644 --- a/unified-runtime/source/adapters/cuda/device.cpp +++ b/unified-runtime/source/adapters/cuda/device.cpp @@ -593,7 +593,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue("CUDA"); } case UR_DEVICE_INFO_REFERENCE_COUNT: { - return ReturnValue(hDevice->getReferenceCount()); + return ReturnValue(hDevice->getRefCounter().getCount()); } case UR_DEVICE_INFO_VERSION: { std::stringstream SS; diff --git 
a/unified-runtime/source/adapters/cuda/device.hpp b/unified-runtime/source/adapters/cuda/device.hpp index 3a28d54b17d21..2d9876aa002be 100644 --- a/unified-runtime/source/adapters/cuda/device.hpp +++ b/unified-runtime/source/adapters/cuda/device.hpp @@ -15,6 +15,7 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" struct ur_device_handle_t_ : ur::cuda::handle_base { private: @@ -23,7 +24,6 @@ struct ur_device_handle_t_ : ur::cuda::handle_base { native_type CuDevice; CUcontext CuContext; CUevent EvBase; // CUDA event used as base counter - std::atomic_uint32_t RefCount; ur_platform_handle_t Platform; uint32_t DeviceIndex; @@ -42,7 +42,7 @@ struct ur_device_handle_t_ : ur::cuda::handle_base { ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase, ur_platform_handle_t platform, uint32_t DevIndex) : handle_base(), CuDevice(cuDevice), CuContext(cuContext), EvBase(evBase), - RefCount{1}, Platform(platform), DeviceIndex{DevIndex} { + Platform(platform), DeviceIndex{DevIndex} { UR_CHECK_ERROR(cuDeviceGetAttribute( &MaxRegsPerBlock, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, cuDevice)); @@ -136,7 +136,7 @@ struct ur_device_handle_t_ : ur::cuda::handle_base { CUcontext getNativeContext() const noexcept { return CuContext; }; - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } ur_platform_handle_t getPlatform() const noexcept { return Platform; }; @@ -178,6 +178,9 @@ struct ur_device_handle_t_ : ur::cuda::handle_base { // (UMF_MEMORY_TYPE_SHARED) umf_memory_provider_handle_t MemoryProviderShared; umf_memory_pool_handle_t MemoryPoolShared; + +private: + UR_ReferenceCounter RefCounter; }; int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute); diff --git a/unified-runtime/source/adapters/cuda/event.cpp b/unified-runtime/source/adapters/cuda/event.cpp index f9343a6b6f751..4cbf1328bc454 100644 --- 
a/unified-runtime/source/adapters/cuda/event.cpp +++ b/unified-runtime/source/adapters/cuda/event.cpp @@ -179,7 +179,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, case UR_EVENT_INFO_COMMAND_TYPE: return ReturnValue(hEvent->getCommandType()); case UR_EVENT_INFO_REFERENCE_COUNT: - return ReturnValue(hEvent->getReferenceCount()); + return ReturnValue(hEvent->getRefCounter().getCount()); case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: return ReturnValue(hEvent->getExecutionStatus()); case UR_EVENT_INFO_CONTEXT: @@ -248,9 +248,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) { } UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { - const auto RefCount = hEvent->incrementReferenceCount(); - - if (RefCount == 0) { + if (hEvent->getRefCounter().increment() == 0) { return UR_RESULT_ERROR_OUT_OF_RESOURCES; } @@ -260,12 +258,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - if (hEvent->getReferenceCount() == 0) { + if (hEvent->getRefCounter().getCount() == 0) { return UR_RESULT_ERROR_INVALID_EVENT; } // decrement ref count. If it is 0, delete the event. 
- if (hEvent->decrementReferenceCount() == 0) { + if (hEvent->getRefCounter().decrement() == 0) { std::unique_ptr event_ptr{hEvent}; ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT; try { diff --git a/unified-runtime/source/adapters/cuda/event.hpp b/unified-runtime/source/adapters/cuda/event.hpp index 92f74349f9b3e..23a6d4c17e1b1 100644 --- a/unified-runtime/source/adapters/cuda/event.hpp +++ b/unified-runtime/source/adapters/cuda/event.hpp @@ -13,6 +13,7 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "queue.hpp" /// UR Event mapping to CUevent @@ -59,16 +60,12 @@ struct ur_event_handle_t_ : ur::cuda::handle_base { ur_command_t getCommandType() const noexcept { return CommandType; } ur_context_handle_t getContext() const noexcept { return Context; }; uint32_t getEventID() const noexcept { return EventID; } - - // Reference counting. - uint32_t getReferenceCount() const noexcept { return RefCount; } - uint32_t incrementReferenceCount() { return ++RefCount; } - uint32_t decrementReferenceCount() { return --RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } private: ur_command_t CommandType; // The type of command associated with event. - std::atomic_uint32_t RefCount{1}; // Event reference count. + UR_ReferenceCounter RefCounter; bool HasOwnership{true}; // Signifies if event owns the native type. bool HasProfiling{false}; // Signifies if event has profiling information. 
diff --git a/unified-runtime/source/adapters/cuda/kernel.cpp b/unified-runtime/source/adapters/cuda/kernel.cpp index f296c74611462..d10a80aa79a8c 100644 --- a/unified-runtime/source/adapters/cuda/kernel.cpp +++ b/unified-runtime/source/adapters/cuda/kernel.cpp @@ -127,9 +127,10 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, } UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) { - UR_ASSERT(hKernel->getReferenceCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL); + UR_ASSERT(hKernel->getRefCounter().getCount() > 0u, + UR_RESULT_ERROR_INVALID_KERNEL); - hKernel->incrementReferenceCount(); + hKernel->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -137,10 +138,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(ur_kernel_handle_t hKernel) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - UR_ASSERT(hKernel->getReferenceCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL); + UR_ASSERT(hKernel->getRefCounter().getCount() != 0, + UR_RESULT_ERROR_INVALID_KERNEL); // decrement ref count. If it is 0, delete the program. - if (hKernel->decrementReferenceCount() == 0) { + if (hKernel->getRefCounter().decrement() == 0) { // no internal cuda resources to clean up. Just delete it. 
delete hKernel; return UR_RESULT_SUCCESS; @@ -248,7 +250,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, case UR_KERNEL_INFO_NUM_ARGS: return ReturnValue(hKernel->getNumArgs()); case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(hKernel->getReferenceCount()); + return ReturnValue(hKernel->getRefCounter().getCount()); case UR_KERNEL_INFO_CONTEXT: return ReturnValue(hKernel->getContext()); case UR_KERNEL_INFO_PROGRAM: diff --git a/unified-runtime/source/adapters/cuda/kernel.hpp b/unified-runtime/source/adapters/cuda/kernel.hpp index 6898527e8df30..34a6694b09d74 100644 --- a/unified-runtime/source/adapters/cuda/kernel.hpp +++ b/unified-runtime/source/adapters/cuda/kernel.hpp @@ -13,10 +13,10 @@ #include #include -#include #include #include +#include "common/ur_ref_counter.hpp" #include "program.hpp" /// Implementation of a UR Kernel for CUDA @@ -42,7 +42,6 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base { std::string Name; ur_context_handle_t Context; ur_program_handle_t Program; - std::atomic_uint32_t RefCount; static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u; size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions]; @@ -255,7 +254,7 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base { ur_context_handle_t Context) : handle_base(), Function{Func}, FunctionWithOffsetParam{FuncWithOffsetParam}, Name{Name}, - Context{Context}, Program{Program}, RefCount{1} { + Context{Context}, Program{Program} { urProgramRetain(Program); urContextRetain(Context); @@ -304,11 +303,7 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base { urContextRelease(Context); } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } native_type get() const noexcept { return Function; }; @@ -354,4 +349,7 @@ struct 
ur_kernel_handle_t_ : ur::cuda::handle_base { uint32_t getLocalSize() const noexcept { return Args.getLocalSize(); } size_t getRegsPerThread() const noexcept { return RegsPerThread; }; + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/cuda/memory.cpp b/unified-runtime/source/adapters/cuda/memory.cpp index d673ad06c09b9..4c4c0f49afc6b 100644 --- a/unified-runtime/source/adapters/cuda/memory.cpp +++ b/unified-runtime/source/adapters/cuda/memory.cpp @@ -78,8 +78,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate( } UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) { - UR_ASSERT(hMem->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT); - hMem->incrementReferenceCount(); + UR_ASSERT(hMem->getRefCounter().getCount() > 0, + UR_RESULT_ERROR_INVALID_MEM_OBJECT); + hMem->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -89,7 +90,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) { UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { try { // Do nothing if there are other references - if (hMem->decrementReferenceCount() > 0) { + if (hMem->getRefCounter().decrement() > 0) { return UR_RESULT_SUCCESS; } @@ -162,7 +163,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory, return ReturnValue(hMemory->getContext()); } case UR_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hMemory->getReferenceCount()); + return ReturnValue(hMemory->getRefCounter().getCount()); } default: diff --git a/unified-runtime/source/adapters/cuda/memory.hpp b/unified-runtime/source/adapters/cuda/memory.hpp index 92aeb5878b952..e492894da23e5 100644 --- a/unified-runtime/source/adapters/cuda/memory.hpp +++ b/unified-runtime/source/adapters/cuda/memory.hpp @@ -16,6 +16,7 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include "queue.hpp" @@ -314,9 +315,6 @@ struct ur_mem_handle_t_ : 
ur::cuda::handle_base { // Context where the memory object is accessible ur_context_handle_t Context; - /// Reference counting of the handler - std::atomic_uint32_t RefCount; - // Original mem flags passed ur_mem_flags_t MemFlags; @@ -345,7 +343,7 @@ struct ur_mem_handle_t_ : ur::cuda::handle_base { /// Constructs the UR mem handler for a non-typed allocation ("buffer") ur_mem_handle_t_(ur_context_handle_t Ctxt, ur_mem_flags_t MemFlags, BufferMem::AllocMode Mode, void *HostPtr, size_t Size) - : handle_base(), Context{Ctxt}, RefCount{1}, MemFlags{MemFlags}, + : handle_base(), Context{Ctxt}, MemFlags{MemFlags}, HaveMigratedToDeviceSinceLastWrite(Context->Devices.size(), false), Mem{std::in_place_type, Ctxt, this, Mode, HostPtr, Size} { urContextRetain(Context); @@ -353,9 +351,9 @@ struct ur_mem_handle_t_ : ur::cuda::handle_base { // Subbuffer constructor ur_mem_handle_t_(ur_mem_handle_t Parent, size_t SubBufferOffset) - : handle_base(), Context{Parent->Context}, RefCount{1}, - MemFlags{Parent->MemFlags}, HaveMigratedToDeviceSinceLastWrite( - Parent->Context->Devices.size(), false), + : handle_base(), Context{Parent->Context}, MemFlags{Parent->MemFlags}, + HaveMigratedToDeviceSinceLastWrite(Parent->Context->Devices.size(), + false), Mem{BufferMem{std::get(Parent->Mem)}} { auto &SubBuffer = std::get(Mem); SubBuffer.Parent = Parent; @@ -376,7 +374,7 @@ struct ur_mem_handle_t_ : ur::cuda::handle_base { ur_mem_handle_t_(ur_context_handle_t Ctxt, ur_mem_flags_t MemFlags, ur_image_format_t ImageFormat, ur_image_desc_t ImageDesc, void *HostPtr) - : handle_base(), Context{Ctxt}, RefCount{1}, MemFlags{MemFlags}, + : handle_base(), Context{Ctxt}, MemFlags{MemFlags}, HaveMigratedToDeviceSinceLastWrite(Context->Devices.size(), false), Mem{std::in_place_type, Ctxt, @@ -424,11 +422,7 @@ struct ur_mem_handle_t_ : ur::cuda::handle_base { ur_context_handle_t getContext() const noexcept { return Context; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - 
uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } void setLastQueueWritingToMemObj(ur_queue_handle_t WritingQueue) { urQueueRetain(WritingQueue); @@ -443,4 +437,7 @@ struct ur_mem_handle_t_ : ur::cuda::handle_base { Device == WritingQueue->getDevice(); } } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/cuda/physical_mem.cpp b/unified-runtime/source/adapters/cuda/physical_mem.cpp index 71bf596acb09b..48372d798d3a5 100644 --- a/unified-runtime/source/adapters/cuda/physical_mem.cpp +++ b/unified-runtime/source/adapters/cuda/physical_mem.cpp @@ -46,13 +46,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemCreate( UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemRetain(ur_physical_mem_handle_t hPhysicalMem) { - hPhysicalMem->incrementReferenceCount(); + hPhysicalMem->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemRelease(ur_physical_mem_handle_t hPhysicalMem) { - if (hPhysicalMem->decrementReferenceCount() > 0) + if (hPhysicalMem->getRefCounter().decrement() > 0) return UR_RESULT_SUCCESS; try { @@ -88,7 +88,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemGetInfo( return ReturnValue(hPhysicalMem->getProperties()); } case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hPhysicalMem->getReferenceCount()); + return ReturnValue(hPhysicalMem->getRefCounter().getCount()); } default: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; diff --git a/unified-runtime/source/adapters/cuda/physical_mem.hpp b/unified-runtime/source/adapters/cuda/physical_mem.hpp index d9abe587b1b5f..a81a0a1a68dce 100644 --- a/unified-runtime/source/adapters/cuda/physical_mem.hpp +++ b/unified-runtime/source/adapters/cuda/physical_mem.hpp @@ -14,6 +14,7 @@ #include #include "adapter.hpp" +#include "common/ur_ref_counter.hpp" 
#include "device.hpp" #include "platform.hpp" @@ -23,7 +24,6 @@ struct ur_physical_mem_handle_t_ : ur::cuda::handle_base { using native_type = CUmemGenericAllocationHandle; - std::atomic_uint32_t RefCount; native_type PhysicalMem; ur_context_handle_t_ *Context; ur_device_handle_t Device; @@ -33,8 +33,8 @@ struct ur_physical_mem_handle_t_ : ur::cuda::handle_base { ur_physical_mem_handle_t_(native_type PhysMem, ur_context_handle_t_ *Ctx, ur_device_handle_t Device, size_t Size, ur_physical_mem_properties_t Properties) - : handle_base(), RefCount(1), PhysicalMem(PhysMem), Context(Ctx), - Device(Device), Size(Size), Properties(Properties) { + : handle_base(), PhysicalMem(PhysMem), Context(Ctx), Device(Device), + Size(Size), Properties(Properties) { urContextRetain(Context); } @@ -46,15 +46,14 @@ struct ur_physical_mem_handle_t_ : ur::cuda::handle_base { ur_device_handle_t_ *getDevice() const noexcept { return Device; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } size_t getSize() const noexcept { return Size; } ur_physical_mem_properties_t getProperties() const noexcept { return Properties; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/cuda/program.cpp b/unified-runtime/source/adapters/cuda/program.cpp index 0600d371d5261..fb897578d99a4 100644 --- a/unified-runtime/source/adapters/cuda/program.cpp +++ b/unified-runtime/source/adapters/cuda/program.cpp @@ -350,7 +350,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, switch (propName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return ReturnValue(hProgram->getReferenceCount()); + return ReturnValue(hProgram->getRefCounter().getCount()); case UR_PROGRAM_INFO_CONTEXT: return ReturnValue(hProgram->Context); case 
UR_PROGRAM_INFO_NUM_DEVICES: @@ -383,8 +383,9 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { - UR_ASSERT(hProgram->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_PROGRAM); - hProgram->incrementReferenceCount(); + UR_ASSERT(hProgram->getRefCounter().getCount() > 0, + UR_RESULT_ERROR_INVALID_PROGRAM); + hProgram->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -395,11 +396,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramRelease(ur_program_handle_t hProgram) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - UR_ASSERT(hProgram->getReferenceCount() != 0, + UR_ASSERT(hProgram->getRefCounter().getCount() != 0, UR_RESULT_ERROR_INVALID_PROGRAM); // decrement ref count. If it is 0, delete the program. - if (hProgram->decrementReferenceCount() == 0) { + if (hProgram->getRefCounter().decrement() == 0) { std::unique_ptr ProgramPtr{hProgram}; try { ScopedContext Active(hProgram->getDevice()); diff --git a/unified-runtime/source/adapters/cuda/program.hpp b/unified-runtime/source/adapters/cuda/program.hpp index 7371283c274d1..c19e323057f08 100644 --- a/unified-runtime/source/adapters/cuda/program.hpp +++ b/unified-runtime/source/adapters/cuda/program.hpp @@ -12,9 +12,9 @@ #include #include -#include #include +#include "common/ur_ref_counter.hpp" #include "context.hpp" struct ur_program_handle_t_ : ur::cuda::handle_base { @@ -22,7 +22,6 @@ struct ur_program_handle_t_ : ur::cuda::handle_base { native_type Module; const char *Binary; size_t BinarySizeInBytes; - std::atomic_uint32_t RefCount; ur_context_handle_t Context; ur_device_handle_t Device; @@ -49,9 +48,9 @@ struct ur_program_handle_t_ : ur::cuda::handle_base { ur_program_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device) : handle_base(), Module{nullptr}, Binary{}, BinarySizeInBytes{0}, - RefCount{1}, Context{Context}, 
Device{Device}, - KernelReqdWorkGroupSizeMD{}, KernelMaxWorkGroupSizeMD{}, - KernelMaxLinearWorkGroupSizeMD{}, KernelReqdSubGroupSizeMD{} { + Context{Context}, Device{Device}, KernelReqdWorkGroupSizeMD{}, + KernelMaxWorkGroupSizeMD{}, KernelMaxLinearWorkGroupSizeMD{}, + KernelReqdSubGroupSizeMD{} { urContextRetain(Context); // When the log is queried we use strnlen(InfoLog), so it needs to be @@ -71,13 +70,12 @@ struct ur_program_handle_t_ : ur::cuda::handle_base { native_type get() const noexcept { return Module; }; - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } ur_result_t getGlobalVariablePointer(const char *name, CUdeviceptr *DeviceGlobal, size_t *DeviceGlobalSize); + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/cuda/queue.cpp b/unified-runtime/source/adapters/cuda/queue.cpp index 7c0f7b09f3a42..2b8cdbe5dbeea 100644 --- a/unified-runtime/source/adapters/cuda/queue.cpp +++ b/unified-runtime/source/adapters/cuda/queue.cpp @@ -107,14 +107,14 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice, } UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) { - assert(hQueue->getReferenceCount() > 0); + assert(hQueue->getRefCounter().getCount() > 0); - hQueue->incrementReferenceCount(); + hQueue->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) { - if (hQueue->decrementReferenceCount() > 0) { + if (hQueue->getRefCounter().decrement() > 0) { return UR_RESULT_SUCCESS; } @@ -229,7 +229,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hQueue->Device); case UR_QUEUE_INFO_REFERENCE_COUNT: - return 
ReturnValue(hQueue->getReferenceCount()); + return ReturnValue(hQueue->getRefCounter().getCount()); case UR_QUEUE_INFO_FLAGS: return ReturnValue(hQueue->URFlags); case UR_QUEUE_INFO_EMPTY: { diff --git a/unified-runtime/source/adapters/cuda/sampler.cpp b/unified-runtime/source/adapters/cuda/sampler.cpp index f17c94cfa1b07..ae5af284dd1fc 100644 --- a/unified-runtime/source/adapters/cuda/sampler.cpp +++ b/unified-runtime/source/adapters/cuda/sampler.cpp @@ -73,7 +73,7 @@ urSamplerGetInfo(ur_sampler_handle_t hSampler, ur_sampler_info_t propName, switch (propName) { case UR_SAMPLER_INFO_REFERENCE_COUNT: - return ReturnValue(hSampler->getReferenceCount()); + return ReturnValue(hSampler->getRefCounter().getCount()); case UR_SAMPLER_INFO_CONTEXT: return ReturnValue(hSampler->Context); case UR_SAMPLER_INFO_NORMALIZED_COORDS: { @@ -95,7 +95,7 @@ urSamplerGetInfo(ur_sampler_handle_t hSampler, ur_sampler_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urSamplerRetain(ur_sampler_handle_t hSampler) { - hSampler->incrementReferenceCount(); + hSampler->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -103,12 +103,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urSamplerRelease(ur_sampler_handle_t hSampler) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - if (hSampler->getReferenceCount() == 0) { + if (hSampler->getRefCounter().getCount() == 0) { return UR_RESULT_ERROR_INVALID_SAMPLER; } // decrement ref count. If it is 0, delete the sampler. 
- if (hSampler->decrementReferenceCount() == 0) { + if (hSampler->getRefCounter().decrement() == 0) { delete hSampler; } diff --git a/unified-runtime/source/adapters/cuda/sampler.hpp b/unified-runtime/source/adapters/cuda/sampler.hpp index e429439848e06..95dc0fde5d137 100644 --- a/unified-runtime/source/adapters/cuda/sampler.hpp +++ b/unified-runtime/source/adapters/cuda/sampler.hpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include /// Implementation of samplers for CUDA @@ -25,7 +26,6 @@ /// | 1 | filter mode /// | 0 | normalize coords struct ur_sampler_handle_t_ : ur::cuda::handle_base { - std::atomic_uint32_t RefCount; uint32_t Props; float MinMipmapLevelClamp; float MaxMipmapLevelClamp; @@ -33,14 +33,10 @@ struct ur_sampler_handle_t_ : ur::cuda::handle_base { ur_context_handle_t Context; ur_sampler_handle_t_(ur_context_handle_t Context) - : handle_base(), RefCount(1), Props(0), MinMipmapLevelClamp(0.0f), + : handle_base(), Props(0), MinMipmapLevelClamp(0.0f), MaxMipmapLevelClamp(0.0f), MaxAnisotropy(0.0f), Context(Context) {} - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } ur_bool_t isNormalizedCoords() const noexcept { return static_cast(Props & 0b1); @@ -67,4 +63,7 @@ struct ur_sampler_handle_t_ : ur::cuda::handle_base { return static_cast((Props >> 12) & 0b1); } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/cuda/usm.cpp b/unified-runtime/source/adapters/cuda/usm.cpp index 723abf4be16c7..ab653e71032fe 100644 --- a/unified-runtime/source/adapters/cuda/usm.cpp +++ b/unified-runtime/source/adapters/cuda/usm.cpp @@ -290,14 +290,14 @@ UR_APIEXPORT ur_result_t UR_APICALL 
urUSMPoolCreate( UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRetain( /// [in] pointer to USM memory pool ur_usm_pool_handle_t Pool) { - Pool->incrementReferenceCount(); + Pool->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRelease( /// [in] pointer to USM memory pool ur_usm_pool_handle_t Pool) { - if (Pool->decrementReferenceCount() > 0) { + if (Pool->getRefCounter().decrement() > 0) { return UR_RESULT_SUCCESS; } Pool->Context->removePool(Pool); @@ -320,7 +320,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolGetInfo( switch (propName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return ReturnValue(hPool->getReferenceCount()); + return ReturnValue(hPool->getRefCounter().getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(hPool->Context); diff --git a/unified-runtime/source/adapters/cuda/usm.hpp b/unified-runtime/source/adapters/cuda/usm.hpp index 27e1beb6b606c..1b4025bdd677a 100644 --- a/unified-runtime/source/adapters/cuda/usm.hpp +++ b/unified-runtime/source/adapters/cuda/usm.hpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include #include @@ -18,7 +19,6 @@ usm::DisjointPoolAllConfigs InitializeDisjointPoolConfig(); // A ur_usm_pool_handle_t can represent different types of memory pools. It may // sit on top of a UMF pool or a CUmemoryPool, but not both. 
struct ur_usm_pool_handle_t_ : ur::cuda::handle_base { - std::atomic_uint32_t RefCount = 1; ur_context_handle_t Context = nullptr; ur_device_handle_t Device = nullptr; @@ -44,17 +44,16 @@ struct ur_usm_pool_handle_t_ : ur::cuda::handle_base { ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device, CUmemoryPool CUmemPool); - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } bool hasUMFPool(umf_memory_pool_t *umf_pool); // To be used if ur_usm_pool_handle_t represents a CUmemoryPool. bool usesCudaPool() const { return CUmemPool != CUmemoryPool{0}; }; CUmemoryPool getCudaPool() { return CUmemPool; }; + +private: + UR_ReferenceCounter RefCounter; }; ur_result_t USMDeviceAllocImpl(void **ResultPtr, ur_context_handle_t Context, diff --git a/unified-runtime/source/adapters/hip/adapter.cpp b/unified-runtime/source/adapters/hip/adapter.cpp index 225a743bc4455..dfb690f60de78 100644 --- a/unified-runtime/source/adapters/hip/adapter.cpp +++ b/unified-runtime/source/adapters/hip/adapter.cpp @@ -58,7 +58,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( std::call_once(InitFlag, [=]() { ur::hip::adapter = new ur_adapter_handle_t_; }); - ur::hip::adapter->RefCount++; + ur::hip::adapter->getRefCounter().increment(); *phAdapters = ur::hip::adapter; } if (pNumAdapters) { @@ -69,7 +69,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - if (--ur::hip::adapter->RefCount == 0) { + if (ur::hip::adapter->getRefCounter().decrement() == 0) { delete ur::hip::adapter; } @@ -77,7 +77,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - 
ur::hip::adapter->RefCount++; + ur::hip::adapter->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -99,7 +99,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_HIP); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(ur::hip::adapter->RefCount.load()); + return ReturnValue(ur::hip::adapter->getRefCounter().getCount()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/hip/adapter.hpp b/unified-runtime/source/adapters/hip/adapter.hpp index ce054cfc27883..7c2256545a689 100644 --- a/unified-runtime/source/adapters/hip/adapter.hpp +++ b/unified-runtime/source/adapters/hip/adapter.hpp @@ -11,6 +11,7 @@ #ifndef UR_HIP_ADAPTER_HPP_INCLUDED #define UR_HIP_ADAPTER_HPP_INCLUDED +#include "common/ur_ref_counter.hpp" #include "logger/ur_logger.hpp" #include "platform.hpp" @@ -18,10 +19,14 @@ #include struct ur_adapter_handle_t_ : ur::hip::handle_base { - std::atomic RefCount = 0; logger::Logger &logger; std::unique_ptr Platform; ur_adapter_handle_t_(); + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; namespace ur::hip { diff --git a/unified-runtime/source/adapters/hip/command_buffer.cpp b/unified-runtime/source/adapters/hip/command_buffer.cpp index af058cdde4cf0..28488cb6b4007 100644 --- a/unified-runtime/source/adapters/hip/command_buffer.cpp +++ b/unified-runtime/source/adapters/hip/command_buffer.cpp @@ -26,7 +26,7 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_( bool IsInOrder) : handle_base(), Context(hContext), Device(hDevice), IsUpdatable(IsUpdatable), IsInOrder(IsInOrder), HIPGraph{nullptr}, - HIPGraphExec{nullptr}, RefCount{1}, NextSyncPoint{0} { + HIPGraphExec{nullptr}, NextSyncPoint{0} { urContextRetain(hContext); } @@ -266,13 +266,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp( 
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - hCommandBuffer->incrementReferenceCount(); + hCommandBuffer->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - if (hCommandBuffer->decrementReferenceCount() == 0) { + if (hCommandBuffer->getRefCounter().decrement() == 0) { delete hCommandBuffer; } return UR_RESULT_SUCCESS; @@ -1045,7 +1045,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp( switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(hCommandBuffer->getReferenceCount()); + return ReturnValue(hCommandBuffer->getRefCounter().getCount()); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/hip/command_buffer.hpp b/unified-runtime/source/adapters/hip/command_buffer.hpp index 3d0047adee013..18ce37885bd32 100644 --- a/unified-runtime/source/adapters/hip/command_buffer.hpp +++ b/unified-runtime/source/adapters/hip/command_buffer.hpp @@ -109,9 +109,8 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base { registerSyncPoint(SyncPoint, std::move(HIPNode)); return SyncPoint; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - uint32_t getReferenceCount() const noexcept { return RefCount; } + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } // UR context associated with this command-buffer ur_context_handle_t Context; @@ -125,9 +124,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base { hipGraph_t HIPGraph; // HIP Graph Exec handle hipGraphExec_t HIPGraphExec = nullptr; - // Atomic variable counting the number of reference to this command_buffer - // using 
std::atomic prevents data race when incrementing/decrementing. - std::atomic_uint32_t RefCount; // Ordered map of sync_points to ur_events std::map SyncPoints; @@ -138,4 +134,7 @@ struct ur_exp_command_buffer_handle_t_ : ur::hip::handle_base { // Handles to individual commands in the command-buffer std::vector> CommandHandles; + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/hip/context.cpp b/unified-runtime/source/adapters/hip/context.cpp index ae3c039a0abfe..2663840c5314f 100644 --- a/unified-runtime/source/adapters/hip/context.cpp +++ b/unified-runtime/source/adapters/hip/context.cpp @@ -68,7 +68,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, return ReturnValue(hContext->getDevices().data(), hContext->getDevices().size()); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(hContext->getReferenceCount()); + return ReturnValue(hContext->getRefCounter().getCount()); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // 2D USM memcpy is supported. 
return ReturnValue(true); @@ -85,7 +85,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urContextRelease(ur_context_handle_t hContext) { - if (hContext->decrementReferenceCount() == 0) { + if (hContext->getRefCounter().decrement() == 0) { hContext->invokeExtendedDeleters(); delete hContext; } @@ -94,9 +94,9 @@ urContextRelease(ur_context_handle_t hContext) { UR_APIEXPORT ur_result_t UR_APICALL urContextRetain(ur_context_handle_t hContext) { - assert(hContext->getReferenceCount() > 0); + assert(hContext->getRefCounter().getCount() > 0); - hContext->incrementReferenceCount(); + hContext->getRefCounter().increment(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/hip/context.hpp b/unified-runtime/source/adapters/hip/context.hpp index 3c011cec43a1b..8ce847e8d1a76 100644 --- a/unified-runtime/source/adapters/hip/context.hpp +++ b/unified-runtime/source/adapters/hip/context.hpp @@ -13,6 +13,7 @@ #include "adapter.hpp" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "device.hpp" #include "platform.hpp" @@ -88,10 +89,8 @@ struct ur_context_handle_t_ : ur::hip::handle_base { std::vector Devices; - std::atomic_uint32_t RefCount; - ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) - : handle_base(), Devices{Devs, Devs + NumDevices}, RefCount{1} { + : handle_base(), Devices{Devs, Devs + NumDevices} { UR_CHECK_ERROR(urAdapterRetain(ur::hip::adapter)); }; @@ -125,11 +124,7 @@ struct ur_context_handle_t_ : ur::hip::handle_base { return std::distance(Devices.begin(), It); } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } void addPool(ur_usm_pool_handle_t Pool); @@ -141,4 +136,5 @@ struct ur_context_handle_t_ : 
ur::hip::handle_base { std::mutex Mutex; std::vector ExtendedDeleters; std::set PoolHandles; + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/hip/device.cpp b/unified-runtime/source/adapters/hip/device.cpp index f991dbf7db416..06b224dddfe00 100644 --- a/unified-runtime/source/adapters/hip/device.cpp +++ b/unified-runtime/source/adapters/hip/device.cpp @@ -477,7 +477,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue("HIP"); } case UR_DEVICE_INFO_REFERENCE_COUNT: { - return ReturnValue(hDevice->getReferenceCount()); + return ReturnValue(hDevice->getRefCounter().getCount()); } case UR_DEVICE_INFO_VERSION: { std::stringstream S; diff --git a/unified-runtime/source/adapters/hip/device.hpp b/unified-runtime/source/adapters/hip/device.hpp index f03fcdd8463b0..93954c3414fd1 100644 --- a/unified-runtime/source/adapters/hip/device.hpp +++ b/unified-runtime/source/adapters/hip/device.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include @@ -22,7 +23,7 @@ struct ur_device_handle_t_ : ur::hip::handle_base { using native_type = hipDevice_t; native_type HIPDevice; - std::atomic_uint32_t RefCount; + UR_ReferenceCounter RefCounter; ur_platform_handle_t Platform; hipEvent_t EvBase; // HIP event used as base counter uint32_t DeviceIndex; @@ -38,8 +39,8 @@ struct ur_device_handle_t_ : ur::hip::handle_base { public: ur_device_handle_t_(native_type HipDevice, hipEvent_t EvBase, ur_platform_handle_t Platform, uint32_t DeviceIndex) - : handle_base(), HIPDevice(HipDevice), RefCount{1}, Platform(Platform), - EvBase(EvBase), DeviceIndex(DeviceIndex) { + : handle_base(), HIPDevice(HipDevice), Platform(Platform), EvBase(EvBase), + DeviceIndex(DeviceIndex) { UR_CHECK_ERROR(hipDeviceGetAttribute( &MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice)); @@ -99,7 +100,7 @@ struct ur_device_handle_t_ : ur::hip::handle_base { native_type get() const 
noexcept { return HIPDevice; }; - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } ur_platform_handle_t getPlatform() const noexcept { return Platform; }; diff --git a/unified-runtime/source/adapters/hip/event.cpp b/unified-runtime/source/adapters/hip/event.cpp index f7da65d6fc993..3c6920b74ecc4 100644 --- a/unified-runtime/source/adapters/hip/event.cpp +++ b/unified-runtime/source/adapters/hip/event.cpp @@ -189,7 +189,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, case UR_EVENT_INFO_COMMAND_TYPE: return ReturnValue(hEvent->getCommandType()); case UR_EVENT_INFO_REFERENCE_COUNT: - return ReturnValue(hEvent->getReferenceCount()); + return ReturnValue(hEvent->getRefCounter().getCount()); case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: { try { return ReturnValue(hEvent->getExecutionStatus()); @@ -245,9 +245,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventSetCallback(ur_event_handle_t, } UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { - const auto RefCount = hEvent->incrementReferenceCount(); - - if (RefCount == 0) { + if (hEvent->getRefCounter().increment() == 0) { return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; } @@ -257,12 +255,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - if (hEvent->getReferenceCount() == 0) { + if (hEvent->getRefCounter().getCount() == 0) { return UR_RESULT_ERROR_INVALID_EVENT; } // decrement ref count. If it is 0, delete the event. 
- if (hEvent->decrementReferenceCount() == 0) { + if (hEvent->getRefCounter().decrement() == 0) { std::unique_ptr event_ptr{hEvent}; ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT; try { diff --git a/unified-runtime/source/adapters/hip/event.hpp b/unified-runtime/source/adapters/hip/event.hpp index 63fe5ca273449..6b0c1fb40f740 100644 --- a/unified-runtime/source/adapters/hip/event.hpp +++ b/unified-runtime/source/adapters/hip/event.hpp @@ -11,6 +11,7 @@ #include "common.hpp" #include "queue.hpp" +#include "common/ur_ref_counter.hpp" /// UR Event mapping to hipEvent_t struct ur_event_handle_t_ : ur::hip::handle_base { @@ -56,16 +57,12 @@ struct ur_event_handle_t_ : ur::hip::handle_base { ur_command_t getCommandType() const noexcept { return CommandType; } ur_context_handle_t getContext() const noexcept { return Context; }; uint32_t getEventId() const noexcept { return EventId; } - - // Reference counting. - uint32_t getReferenceCount() const noexcept { return RefCount; } - uint32_t incrementReferenceCount() { return ++RefCount; } - uint32_t decrementReferenceCount() { return --RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } private: ur_command_t CommandType; // The type of command associated with event. - std::atomic_uint32_t RefCount{1}; // Event reference count. + UR_ReferenceCounter RefCounter; bool HasOwnership{true}; // Signifies if event owns the native type. bool HasProfiling{false}; // Signifies if event has profiling information. 
diff --git a/unified-runtime/source/adapters/hip/kernel.cpp b/unified-runtime/source/adapters/hip/kernel.cpp index 39cddecd1efd5..243b7b7b25a4d 100644 --- a/unified-runtime/source/adapters/hip/kernel.cpp +++ b/unified-runtime/source/adapters/hip/kernel.cpp @@ -127,9 +127,10 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, } UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) { - UR_ASSERT(hKernel->getReferenceCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL); + UR_ASSERT(hKernel->getRefCounter().getCount() > 0u, + UR_RESULT_ERROR_INVALID_KERNEL); - hKernel->incrementReferenceCount(); + hKernel->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -137,10 +138,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(ur_kernel_handle_t hKernel) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - UR_ASSERT(hKernel->getReferenceCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL); + UR_ASSERT(hKernel->getRefCounter().getCount() != 0, + UR_RESULT_ERROR_INVALID_KERNEL); // decrement ref count. If it is 0, delete the program. - if (hKernel->decrementReferenceCount() == 0) { + if (hKernel->getRefCounter().decrement() == 0) { // no internal cuda resources to clean up. Just delete it. 
delete hKernel; return UR_RESULT_SUCCESS; @@ -201,7 +203,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, case UR_KERNEL_INFO_NUM_ARGS: return ReturnValue(hKernel->getNumArgs()); case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(hKernel->getReferenceCount()); + return ReturnValue(hKernel->getRefCounter().getCount()); case UR_KERNEL_INFO_CONTEXT: return ReturnValue(hKernel->getContext()); case UR_KERNEL_INFO_PROGRAM: diff --git a/unified-runtime/source/adapters/hip/kernel.hpp b/unified-runtime/source/adapters/hip/kernel.hpp index 569a2dc8c0bf4..fa8f1ef6f2c2b 100644 --- a/unified-runtime/source/adapters/hip/kernel.hpp +++ b/unified-runtime/source/adapters/hip/kernel.hpp @@ -12,10 +12,10 @@ #include #include -#include #include #include +#include "common/ur_ref_counter.hpp" #include "program.hpp" /// Implementation of a UR Kernel for HIP @@ -41,7 +41,6 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base { std::string Name; ur_context_handle_t Context; ur_program_handle_t Program; - std::atomic_uint32_t RefCount; static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u; size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions]; @@ -267,11 +266,7 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base { ur_program_handle_t getProgram() const noexcept { return Program; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } native_type get() const noexcept { return Function; }; @@ -310,4 +305,7 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base { } uint32_t getLocalSize() const noexcept { return Args.getLocalSize(); } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/hip/memory.hpp b/unified-runtime/source/adapters/hip/memory.hpp index 
b2367edf8f4b7..587109353ec27 100644 --- a/unified-runtime/source/adapters/hip/memory.hpp +++ b/unified-runtime/source/adapters/hip/memory.hpp @@ -10,6 +10,8 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" + #include "context.hpp" #include "event.hpp" #include @@ -316,9 +318,6 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { // Context where the memory object is accessible ur_context Context; - /// Reference counting of the handler - std::atomic_uint32_t RefCount; - // Original mem flags passed ur_mem_flags_t MemFlags; @@ -347,7 +346,7 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { /// Constructs the UR mem handler for a non-typed allocation ("buffer") ur_mem_handle_t_(ur_context_handle_t Ctxt, ur_mem_flags_t MemFlags, BufferMem::AllocMode Mode, void *HostPtr, size_t Size) - : Context{Ctxt}, RefCount{1}, MemFlags{MemFlags}, + : Context{Ctxt}, MemFlags{MemFlags}, HaveMigratedToDeviceSinceLastWrite(Context->Devices.size(), false), Mem{std::in_place_type, Ctxt, this, Mode, HostPtr, Size} { urContextRetain(Context); @@ -355,9 +354,9 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { // Subbuffer constructor ur_mem_handle_t_(ur_mem Parent, size_t SubBufferOffset) - : handle_base(), Context{Parent->Context}, RefCount{1}, - MemFlags{Parent->MemFlags}, HaveMigratedToDeviceSinceLastWrite( - Parent->Context->Devices.size(), false), + : handle_base(), Context{Parent->Context}, MemFlags{Parent->MemFlags}, + HaveMigratedToDeviceSinceLastWrite(Parent->Context->Devices.size(), + false), Mem{BufferMem{std::get(Parent->Mem)}} { auto &SubBuffer = std::get(Mem); SubBuffer.Parent = Parent; @@ -378,7 +377,7 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { ur_mem_handle_t_(ur_context Ctxt, ur_mem_flags_t MemFlags, ur_image_format_t ImageFormat, ur_image_desc_t ImageDesc, void *HostPtr) - : Context{Ctxt}, RefCount{1}, MemFlags{MemFlags}, + : Context{Ctxt}, MemFlags{MemFlags}, HaveMigratedToDeviceSinceLastWrite(Context->Devices.size(), false), 
Mem{std::in_place_type, Ctxt, @@ -419,11 +418,7 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { ur_context getContext() const noexcept { return Context; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } void setLastQueueWritingToMemObj(ur_queue_handle_t WritingQueue) { if (LastQueueWritingToMemObj != nullptr) { @@ -436,4 +431,7 @@ struct ur_mem_handle_t_ : ur::hip::handle_base { Device == WritingQueue->getDevice(); } } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/hip/physical_mem.hpp b/unified-runtime/source/adapters/hip/physical_mem.hpp index 47342ae206510..0ecefd55f58f6 100644 --- a/unified-runtime/source/adapters/hip/physical_mem.hpp +++ b/unified-runtime/source/adapters/hip/physical_mem.hpp @@ -10,6 +10,8 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" + #include "device.hpp" #include "platform.hpp" @@ -22,9 +24,8 @@ struct ur_physical_mem_handle_t_ : ur::hip::handle_base { ur_physical_mem_handle_t_() : handle_base(), RefCount(1) {} - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } - uint32_t getReferenceCount() const noexcept { return RefCount; } +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/hip/program.cpp b/unified-runtime/source/adapters/hip/program.cpp index 94e2f46440e96..3073343320cad 100644 --- a/unified-runtime/source/adapters/hip/program.cpp +++ b/unified-runtime/source/adapters/hip/program.cpp @@ -385,7 +385,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, switch (propName) { case 
UR_PROGRAM_INFO_REFERENCE_COUNT: - return ReturnValue(hProgram->getReferenceCount()); + return ReturnValue(hProgram->getRefCounter().getCount()); case UR_PROGRAM_INFO_CONTEXT: return ReturnValue(hProgram->Context); case UR_PROGRAM_INFO_NUM_DEVICES: @@ -418,8 +418,9 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { - UR_ASSERT(hProgram->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_PROGRAM); - hProgram->incrementReferenceCount(); + UR_ASSERT(hProgram->getRefCounter().getCount() > 0, + UR_RESULT_ERROR_INVALID_PROGRAM); + hProgram->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -430,11 +431,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramRelease(ur_program_handle_t hProgram) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - UR_ASSERT(hProgram->getReferenceCount() != 0, + UR_ASSERT(hProgram->getRefCounter().getCount() != 0, UR_RESULT_ERROR_INVALID_PROGRAM); // decrement ref count. If it is 0, delete the program. 
- if (hProgram->decrementReferenceCount() == 0) { + if (hProgram->getRefCounter().decrement() == 0) { std::unique_ptr ProgramPtr{hProgram}; try { ScopedDevice Active(hProgram->getDevice()); diff --git a/unified-runtime/source/adapters/hip/program.hpp b/unified-runtime/source/adapters/hip/program.hpp index c94818c6c43ab..57c4895289e62 100644 --- a/unified-runtime/source/adapters/hip/program.hpp +++ b/unified-runtime/source/adapters/hip/program.hpp @@ -11,9 +11,9 @@ #include -#include #include +#include "common/ur_ref_counter.hpp" #include "context.hpp" /// Implementation of UR Program on HIP Module object @@ -22,7 +22,6 @@ struct ur_program_handle_t_ : ur::hip::handle_base { native_type Module; const char *Binary; size_t BinarySizeInBytes; - std::atomic_uint32_t RefCount; ur_context_handle_t Context; ur_device_handle_t Device; std::string ExecutableCache; @@ -49,7 +48,7 @@ struct ur_program_handle_t_ : ur::hip::handle_base { ur_program_handle_t_(ur_context_handle_t Ctxt, ur_device_handle_t Device) : handle_base(), Module{nullptr}, Binary{}, BinarySizeInBytes{0}, - RefCount{1}, Context{Ctxt}, Device{Device}, KernelReqdWorkGroupSizeMD{}, + Context{Ctxt}, Device{Device}, KernelReqdWorkGroupSizeMD{}, KernelReqdSubGroupSizeMD{} { urContextRetain(Context); @@ -71,13 +70,12 @@ struct ur_program_handle_t_ : ur::hip::handle_base { native_type get() const noexcept { return Module; }; - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } ur_result_t getGlobalVariablePointer(const char *name, hipDeviceptr_t *DeviceGlobal, size_t *DeviceGlobalSize); + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/hip/queue.cpp b/unified-runtime/source/adapters/hip/queue.cpp index cca434b8f6ab0..634c3094e1334 100644 --- 
a/unified-runtime/source/adapters/hip/queue.cpp +++ b/unified-runtime/source/adapters/hip/queue.cpp @@ -107,7 +107,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hQueue->Device); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(hQueue->getReferenceCount()); + return ReturnValue(hQueue->getRefCounter().getCount()); case UR_QUEUE_INFO_FLAGS: return ReturnValue(hQueue->URFlags); case UR_QUEUE_INFO_EMPTY: { @@ -135,14 +135,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, } UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) { - UR_ASSERT(hQueue->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_QUEUE); + UR_ASSERT(hQueue->getRefCounter().getCount() > 0, + UR_RESULT_ERROR_INVALID_QUEUE); - hQueue->incrementReferenceCount(); + hQueue->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) { - if (hQueue->decrementReferenceCount() > 0) { + if (hQueue->getRefCounter().decrement() > 0) { return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/hip/sampler.cpp b/unified-runtime/source/adapters/hip/sampler.cpp index addcdb031402e..35c0fc7d6f131 100644 --- a/unified-runtime/source/adapters/hip/sampler.cpp +++ b/unified-runtime/source/adapters/hip/sampler.cpp @@ -44,7 +44,7 @@ ur_result_t urSamplerGetInfo(ur_sampler_handle_t hSampler, switch (propName) { case UR_SAMPLER_INFO_REFERENCE_COUNT: - return ReturnValue(hSampler->getReferenceCount()); + return ReturnValue(hSampler->getRefCounter().getCount()); case UR_SAMPLER_INFO_CONTEXT: return ReturnValue(hSampler->Context); case UR_SAMPLER_INFO_NORMALIZED_COORDS: { @@ -67,19 +67,19 @@ ur_result_t urSamplerGetInfo(ur_sampler_handle_t hSampler, } ur_result_t urSamplerRetain(ur_sampler_handle_t hSampler) { - hSampler->incrementReferenceCount(); + hSampler->getRefCounter().increment(); return 
UR_RESULT_SUCCESS; } ur_result_t urSamplerRelease(ur_sampler_handle_t hSampler) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - if (hSampler->getReferenceCount() == 0) { + if (hSampler->getRefCounter().getCount() == 0) { return UR_RESULT_ERROR_INVALID_SAMPLER; } // decrement ref count. If it is 0, delete the sampler. - if (hSampler->decrementReferenceCount() == 0) { + if (hSampler->getRefCounter().decrement() == 0) { delete hSampler; } diff --git a/unified-runtime/source/adapters/hip/sampler.hpp b/unified-runtime/source/adapters/hip/sampler.hpp index 1a1defea851ed..cf00ee8ddceb9 100644 --- a/unified-runtime/source/adapters/hip/sampler.hpp +++ b/unified-runtime/source/adapters/hip/sampler.hpp @@ -10,6 +10,7 @@ #include +#include "common/ur_ref_counter.hpp" #include "context.hpp" /// Implementation of samplers for HIP @@ -26,7 +27,6 @@ /// | 1 | filter mode /// | 0 | normalize coords struct ur_sampler_handle_t_ : ur::hip::handle_base { - std::atomic_uint32_t RefCount; uint32_t Props; float MinMipmapLevelClamp; float MaxMipmapLevelClamp; @@ -34,13 +34,9 @@ struct ur_sampler_handle_t_ : ur::hip::handle_base { ur_context_handle_t Context; ur_sampler_handle_t_(ur_context_handle_t Context) - : handle_base(), RefCount(1), Props(0), Context(Context) {} + : handle_base(), Props(0), Context(Context) {} - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } ur_bool_t isNormalizedCoords() const noexcept { return static_cast(Props & 0b1); @@ -69,4 +65,7 @@ struct ur_sampler_handle_t_ : ur::hip::handle_base { return static_cast((Props >> 12) & 0b1); } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/hip/usm.cpp b/unified-runtime/source/adapters/hip/usm.cpp index 
1945f6c24e055..2d9ba783e7a03 100644 --- a/unified-runtime/source/adapters/hip/usm.cpp +++ b/unified-runtime/source/adapters/hip/usm.cpp @@ -439,14 +439,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate( UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRetain( /// [in] pointer to USM memory pool ur_usm_pool_handle_t Pool) { - Pool->incrementReferenceCount(); + Pool->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRelease( /// [in] pointer to USM memory pool ur_usm_pool_handle_t Pool) { - if (Pool->decrementReferenceCount() > 0) { + if (Pool->getRefCounter().decrement() > 0) { return UR_RESULT_SUCCESS; } Pool->Context->removePool(Pool); @@ -469,7 +469,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolGetInfo( switch (propName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return ReturnValue(hPool->getReferenceCount()); + return ReturnValue(hPool->getRefCounter().getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(hPool->Context); diff --git a/unified-runtime/source/adapters/hip/usm.hpp b/unified-runtime/source/adapters/hip/usm.hpp index 8a3a36ffcaaba..a9a7ba914b8b5 100644 --- a/unified-runtime/source/adapters/hip/usm.hpp +++ b/unified-runtime/source/adapters/hip/usm.hpp @@ -9,6 +9,7 @@ //===-----------------------------------------------------------------===// #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include #include @@ -16,8 +17,6 @@ usm::DisjointPoolAllConfigs InitializeDisjointPoolConfig(); struct ur_usm_pool_handle_t_ : ur::hip::handle_base { - std::atomic_uint32_t RefCount = 1; - ur_context_handle_t Context = nullptr; usm::DisjointPoolAllConfigs DisjointPoolConfigs = @@ -30,13 +29,12 @@ struct ur_usm_pool_handle_t_ : ur::hip::handle_base { ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_usm_pool_desc_t *PoolDesc); - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t 
getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } bool hasUMFPool(umf_memory_pool_t *umf_pool); + +private: + UR_ReferenceCounter RefCounter; }; // Implements memory allocation via driver API for USM allocator interface diff --git a/unified-runtime/source/adapters/level_zero/adapter.cpp b/unified-runtime/source/adapters/level_zero/adapter.cpp index c3a8a8baf5b6a..b6dfdeb872e40 100644 --- a/unified-runtime/source/adapters/level_zero/adapter.cpp +++ b/unified-runtime/source/adapters/level_zero/adapter.cpp @@ -668,7 +668,7 @@ ur_result_t urAdapterGet( if (NumEntries > 0 && Adapters) { if (GlobalAdapter) { std::lock_guard Lock{GlobalAdapter->Mutex}; - if (GlobalAdapter->RefCount++ == 0) { + if (GlobalAdapter->getRefCounter().increment() == 0) { adapterStateInit(); } } else { @@ -677,7 +677,7 @@ ur_result_t urAdapterGet( // cleanup. GlobalAdapter = new ur_adapter_handle_t_(); std::lock_guard Lock{GlobalAdapter->Mutex}; - if (GlobalAdapter->RefCount++ == 0) { + if (GlobalAdapter->getRefCounter().increment() == 0) { adapterStateInit(); } std::atexit(globalAdapterOnDemandCleanup); @@ -696,7 +696,7 @@ ur_result_t urAdapterRelease(ur_adapter_handle_t) { // Check first if the Adapter pointer is valid if (GlobalAdapter) { std::lock_guard Lock{GlobalAdapter->Mutex}; - if (--GlobalAdapter->RefCount == 0) { + if (GlobalAdapter->getRefCounter().decrement() == 0) { auto result = adapterStateTeardown(); #ifdef UR_STATIC_LEVEL_ZERO // Given static linking of the L0 Loader, we must delay the loader's @@ -713,7 +713,7 @@ ur_result_t urAdapterRelease(ur_adapter_handle_t) { ur_result_t urAdapterRetain(ur_adapter_handle_t) { if (GlobalAdapter) { std::lock_guard Lock{GlobalAdapter->Mutex}; - GlobalAdapter->RefCount++; + GlobalAdapter->getRefCounter().increment(); } return UR_RESULT_SUCCESS; @@ -743,7 +743,7 @@ ur_result_t urAdapterGetInfo(ur_adapter_handle_t, ur_adapter_info_t PropName, case 
UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_LEVEL_ZERO); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(GlobalAdapter->RefCount.load()); + return ReturnValue(GlobalAdapter->getRefCounter().getCount()); case UR_ADAPTER_INFO_VERSION: { #ifdef UR_ADAPTER_LEVEL_ZERO_V2 uint32_t adapterVersion = 2; diff --git a/unified-runtime/source/adapters/level_zero/adapter.hpp b/unified-runtime/source/adapters/level_zero/adapter.hpp index cf96e672c56e4..eac6de82333ee 100644 --- a/unified-runtime/source/adapters/level_zero/adapter.hpp +++ b/unified-runtime/source/adapters/level_zero/adapter.hpp @@ -9,13 +9,14 @@ //===----------------------------------------------------------------------===// #pragma once +#include +#include + +#include "common/ur_ref_counter.hpp" #include "logger/ur_logger.hpp" #include "ur_interface_loader.hpp" -#include #include #include -#include -#include #include #include #include @@ -27,7 +28,6 @@ class ur_legacy_sink; struct ur_adapter_handle_t_ : ur::handle_base { ur_adapter_handle_t_(); - std::atomic RefCount = 0; std::mutex Mutex; zes_pfnDriverGetDeviceByUuidExp_t getDeviceByUUIdFunctionPtr = nullptr; @@ -47,6 +47,11 @@ struct ur_adapter_handle_t_ : ur::handle_base { ZeCache> PlatformCache; logger::Logger &logger; HMODULE processHandle = nullptr; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; extern ur_adapter_handle_t_ *GlobalAdapter; diff --git a/unified-runtime/source/adapters/level_zero/async_alloc.cpp b/unified-runtime/source/adapters/level_zero/async_alloc.cpp index 204b43c3bcc79..774d959a48e5d 100644 --- a/unified-runtime/source/adapters/level_zero/async_alloc.cpp +++ b/unified-runtime/source/adapters/level_zero/async_alloc.cpp @@ -247,7 +247,7 @@ ur_result_t urEnqueueUSMFreeExp( } size_t size = umfPoolMallocUsableSize(hPool, Mem); - (*Event)->RefCount.increment(); + (*Event)->getRefCounter().increment(); usmPool->AsyncPool.insert(Mem, size, *Event, 
Queue); // Signal that USM free event was finished diff --git a/unified-runtime/source/adapters/level_zero/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/command_buffer.cpp index 020afb90564ff..5e871b71806e2 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.cpp @@ -823,13 +823,13 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device, ur_result_t urCommandBufferRetainExp(ur_exp_command_buffer_handle_t CommandBuffer) { - CommandBuffer->RefCount.increment(); + CommandBuffer->getRefCounter().increment(); return UR_RESULT_SUCCESS; } ur_result_t urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t CommandBuffer) { - if (!CommandBuffer->RefCount.decrementAndTest()) + if (CommandBuffer->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; CommandBuffer->cleanupCommandBufferResources(); @@ -1643,7 +1643,7 @@ ur_result_t enqueueImmediateAppendPath( if (CommandBuffer->CurrentSubmissionEvent) { UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent)); } - (*Event)->RefCount.increment(); + (*Event)->getRefCounter().increment(); CommandBuffer->CurrentSubmissionEvent = *Event; UR_CALL(Queue->executeCommandList(CommandListHelper, false, false)); @@ -1726,7 +1726,7 @@ ur_result_t enqueueWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer, if (CommandBuffer->CurrentSubmissionEvent) { UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent)); } - (*Event)->RefCount.increment(); + (*Event)->getRefCounter().increment(); CommandBuffer->CurrentSubmissionEvent = *Event; UR_CALL(Queue->executeCommandList(SignalCommandList, false /*IsBlocking*/, @@ -1850,7 +1850,7 @@ urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer, switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hCommandBuffer->RefCount.load()}); + return 
ReturnValue(uint32_t{hCommandBuffer->getRefCounter().getCount()}); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/level_zero/command_buffer.hpp b/unified-runtime/source/adapters/level_zero/command_buffer.hpp index f7b62a9c8dd1e..f6ad7ae3b9a59 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.hpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.hpp @@ -16,7 +16,7 @@ #include #include "common.hpp" - +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include "kernel.hpp" #include "queue.hpp" @@ -149,4 +149,9 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object { // Track handle objects to free when command-buffer is destroyed. std::vector> CommandHandles; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/level_zero/common.hpp b/unified-runtime/source/adapters/level_zero/common.hpp index 33a1072e217a9..b918b995be227 100644 --- a/unified-runtime/source/adapters/level_zero/common.hpp +++ b/unified-runtime/source/adapters/level_zero/common.hpp @@ -34,6 +34,7 @@ #include #include +#include "common/ur_ref_counter.hpp" #include "logger/ur_logger.hpp" #include "ur_interface_loader.hpp" @@ -213,55 +214,9 @@ void zeParseError(ze_result_t ZeError, const char *&ErrorString); #define ZE_CALL_NOCHECK_NAME(ZeName, ZeArgs, callName) \ ZeCall().doCall(ZeName ZeArgs, callName, #ZeArgs, false) -// This wrapper around std::atomic is created to limit operations with reference -// counter and to make allowed operations more transparent in terms of -// thread-safety in the plugin. increment() and load() operations do not need a -// mutex guard around them since the underlying data is already atomic. 
-// decrementAndTest() method is used to guard a code which needs to be -// executed when object's ref count becomes zero after release. This method also -// doesn't need a mutex guard because decrement operation is atomic and only one -// thread can reach ref count equal to zero, i.e. only a single thread can pass -// through this check. -struct ReferenceCounter { - ReferenceCounter() : RefCount{1} {} - - // Reset the counter to the initial value. - void reset() { RefCount = 1; } - - // Used when retaining an object. - void increment() { RefCount++; } - - // Supposed to be used in ur*GetInfo* methods where ref count value is - // requested. - uint32_t load() { return RefCount.load(); } - - // This method allows to guard a code which needs to be executed when object's - // ref count becomes zero after release. It is important to notice that only a - // single thread can pass through this check. This is true because of several - // reasons: - // 1. Decrement operation is executed atomically. - // 2. It is not allowed to retain an object after its refcount reaches zero. - // 3. It is not allowed to release an object more times than the value of - // the ref count. - // 2. and 3. basically means that we can't use an object at all as soon as its - // refcount reaches zero. Using this check guarantees that code for deleting - // an object and releasing its resources is executed once by a single thread - // and we don't need to use any mutexes to guard access to this object in the - // scope after this check. Of course if we access another objects in this code - // (not the one which is being deleted) then access to these objects must be - // guarded, for example with a mutex. - bool decrementAndTest() { return --RefCount == 0; } - -private: - std::atomic RefCount; -}; - // Base class to store common data struct ur_object : ur::handle_base { - ur_object() : handle_base(), RefCount{} {} - - // Must be atomic to prevent data race when incrementing/decrementing. 
- ReferenceCounter RefCount; + ur_object() : handle_base() {} // This mutex protects accesses to all the non-const member variables. // Exclusive access is required to modify any of these members. @@ -296,6 +251,11 @@ struct MemAllocRecord : ur_object { // TODO: this should go away when memory isolation issue is fixed in the Level // Zero runtime. ur_context_handle_t Context; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; extern usm::DisjointPoolAllConfigs DisjointPoolConfigInstance; diff --git a/unified-runtime/source/adapters/level_zero/context.cpp b/unified-runtime/source/adapters/level_zero/context.cpp index 3209b8b789155..796715aca3e44 100644 --- a/unified-runtime/source/adapters/level_zero/context.cpp +++ b/unified-runtime/source/adapters/level_zero/context.cpp @@ -61,7 +61,7 @@ ur_result_t urContextRetain( /// [in] handle of the context to get a reference of. ur_context_handle_t Context) { - Context->RefCount.increment(); + Context->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -113,7 +113,7 @@ ur_result_t urContextGetInfo( case UR_CONTEXT_INFO_NUM_DEVICES: return ReturnValue(uint32_t(Context->Devices.size())); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Context->RefCount.load()}); + return ReturnValue(uint32_t{Context->getRefCounter().getCount()}); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // 2D USM memcpy is supported. return ReturnValue(uint8_t{UseMemcpy2DOperations}); @@ -251,7 +251,7 @@ ur_device_handle_t ur_context_handle_t_::getRootDevice() const { // from the list of tracked contexts. 
ur_result_t ContextReleaseHelper(ur_context_handle_t Context) { - if (!Context->RefCount.decrementAndTest()) + if (Context->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; if (IndirectAccessTrackingEnabled) { diff --git a/unified-runtime/source/adapters/level_zero/context.hpp b/unified-runtime/source/adapters/level_zero/context.hpp index 86e0ea27b5c3e..ab1e52599a602 100644 --- a/unified-runtime/source/adapters/level_zero/context.hpp +++ b/unified-runtime/source/adapters/level_zero/context.hpp @@ -23,6 +23,7 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "queue.hpp" #include "usm.hpp" @@ -358,6 +359,8 @@ struct ur_context_handle_t_ : ur_object { // Get handle to the L0 context ze_context_handle_t getZeHandle() const; + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: enum EventFlags { EVENT_FLAG_HOST_VISIBLE = UR_BIT(0), @@ -404,6 +407,8 @@ struct ur_context_handle_t_ : ur_object { return &EventCaches[index]; } + + UR_ReferenceCounter RefCounter; }; // Helper function to release the context, a caller must lock the platform-level diff --git a/unified-runtime/source/adapters/level_zero/device.cpp b/unified-runtime/source/adapters/level_zero/device.cpp index e54caf59f3b07..bf10e06c6f794 100644 --- a/unified-runtime/source/adapters/level_zero/device.cpp +++ b/unified-runtime/source/adapters/level_zero/device.cpp @@ -449,7 +449,7 @@ ur_result_t urDeviceGetInfo( return ReturnValue((uint32_t)Device->SubDevices.size()); } case UR_DEVICE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Device->RefCount.load()}); + return ReturnValue(uint32_t{Device->getRefCounter().getCount()}); case UR_DEVICE_INFO_SUPPORTED_PARTITIONS: { // SYCL spec says: if this SYCL device cannot be partitioned into at least // two sub devices then the returned vector must be empty. 
@@ -1588,7 +1588,7 @@ ur_result_t urDeviceGetGlobalTimestamps( ur_result_t urDeviceRetain(ur_device_handle_t Device) { // The root-device ref-count remains unchanged (always 1). if (Device->isSubDevice()) { - Device->RefCount.increment(); + Device->getRefCounter().increment(); } return UR_RESULT_SUCCESS; } @@ -1596,7 +1596,7 @@ ur_result_t urDeviceRetain(ur_device_handle_t Device) { ur_result_t urDeviceRelease(ur_device_handle_t Device) { // Root devices are destroyed during the piTearDown process. if (Device->isSubDevice()) { - if (Device->RefCount.decrementAndTest()) { + if (Device->getRefCounter().decrement() == 0) { delete Device; } } diff --git a/unified-runtime/source/adapters/level_zero/device.hpp b/unified-runtime/source/adapters/level_zero/device.hpp index 1ca19ed80cd4f..d567b82f73d18 100644 --- a/unified-runtime/source/adapters/level_zero/device.hpp +++ b/unified-runtime/source/adapters/level_zero/device.hpp @@ -20,6 +20,7 @@ #include "adapters/level_zero/platform.hpp" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include #include #include @@ -212,6 +213,8 @@ struct ur_device_handle_t_ : ur_object { return ValidBits == 64 ? ~0ULL : (1ULL << ValidBits) - 1ULL; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + // Cache of the immutable device properties. 
ZeCache> ZeDeviceProperties; ZeCache> ZeDeviceComputeProperties; @@ -238,6 +241,9 @@ struct ur_device_handle_t_ : ur_object { // unique ephemeral identifer of the device in the adapter std::optional Id; + +private: + UR_ReferenceCounter RefCounter; }; inline std::vector diff --git a/unified-runtime/source/adapters/level_zero/event.cpp b/unified-runtime/source/adapters/level_zero/event.cpp index f06cae5ec0cb3..1eeae7c19bf8e 100644 --- a/unified-runtime/source/adapters/level_zero/event.cpp +++ b/unified-runtime/source/adapters/level_zero/event.cpp @@ -505,7 +505,7 @@ ur_result_t urEventGetInfo( return ReturnValue(Result); } case UR_EVENT_INFO_REFERENCE_COUNT: { - return ReturnValue(Event->RefCount.load()); + return ReturnValue(Event->getRefCounter().getCount()); } default: UR_LOG(ERR, "Unsupported ParamName in urEventGetInfo: ParamName={}(0x{})", @@ -873,8 +873,8 @@ urEventWait(uint32_t NumEvents, ur_result_t /// [in] handle of the event object urEventRetain(/** [in] handle of the event object */ ur_event_handle_t Event) { - Event->RefCountExternal++; - Event->RefCount.increment(); + Event->getRefCounterExternal().increment(); + Event->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -882,7 +882,7 @@ urEventRetain(/** [in] handle of the event object */ ur_event_handle_t Event) { ur_result_t urEventRelease(/** [in] handle of the event object */ ur_event_handle_t Event) { - Event->RefCountExternal--; + Event->getRefCounterExternal().decrement(); bool isEventsWaitCompleted = (Event->CommandType == UR_COMMAND_EVENTS_WAIT || Event->CommandType == UR_COMMAND_EVENTS_WAIT_WITH_BARRIER) && @@ -941,7 +941,7 @@ ur_result_t urExtEventCreate( false /*CounterBasedEventEnabled*/, false /*ForceDisableProfiling*/, false)); - (*Event)->RefCountExternal++; + (*Event)->getRefCounterExternal().increment(); if (!(*Event)->CounterBasedEventsEnabled) ZE2UR_CALL(zeEventHostSignal, ((*Event)->ZeEvent)); return UR_RESULT_SUCCESS; @@ -963,7 +963,7 @@ ur_result_t 
urEventCreateWithNativeHandle( false /*CounterBasedEventEnabled*/, false /*ForceDisableProfiling*/, false)); - (*Event)->RefCountExternal++; + (*Event)->getRefCounterExternal().increment(); if (!(*Event)->CounterBasedEventsEnabled) ZE2UR_CALL(zeEventHostSignal, ((*Event)->ZeEvent)); return UR_RESULT_SUCCESS; @@ -975,7 +975,7 @@ ur_result_t urEventCreateWithNativeHandle( UREvent = new ur_event_handle_t_(ZeEvent, nullptr /* ZeEventPool */, Context, UR_EXT_COMMAND_TYPE_USER, Properties->isNativeHandleOwned); - UREvent->RefCountExternal++; + UREvent->getRefCounterExternal().increment(); } catch (const std::bad_alloc &) { return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; @@ -1088,7 +1088,7 @@ ur_event_handle_t_::~ur_event_handle_t_() { ur_result_t urEventReleaseInternal(ur_event_handle_t Event, bool *isEventDeleted) { - if (!Event->RefCount.decrementAndTest()) + if (Event->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; if (Event->OriginAllocEvent) { @@ -1428,8 +1428,8 @@ ur_result_t ur_event_handle_t_::reset() { CommandData = nullptr; CommandType = UR_EXT_COMMAND_TYPE_USER; WaitList = {}; - RefCountExternal = 0; - RefCount.reset(); + RefCounterExternal.reset(); + RefCounter.reset(); CommandList = std::nullopt; completionBatch = std::nullopt; OriginAllocEvent = nullptr; @@ -1524,7 +1524,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList( std::shared_lock Lock(CurQueue->LastCommandEvent->Mutex); this->ZeEventList[0] = CurQueue->LastCommandEvent->ZeEvent; this->UrEventList[0] = CurQueue->LastCommandEvent; - this->UrEventList[0]->RefCount.increment(); + this->UrEventList[0]->getRefCounter().increment(); TmpListLength = 1; } else if (EventListLength > 0) { this->ZeEventList = new ze_event_handle_t[EventListLength]; @@ -1660,7 +1660,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList( IsInternal, IsMultiDevice)); MultiDeviceZeEvent = MultiDeviceEvent->ZeEvent; const auto &ZeCommandList = CommandList->first; - 
EventList[I]->RefCount.increment(); + EventList[I]->getRefCounter().increment(); // Append a Barrier to wait on the original event while signalling the // new multi device event. @@ -1676,11 +1676,11 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList( this->ZeEventList[TmpListLength] = MultiDeviceZeEvent; this->UrEventList[TmpListLength] = MultiDeviceEvent; - this->UrEventList[TmpListLength]->RefCount.increment(); + this->UrEventList[TmpListLength]->getRefCounter().increment(); } else { this->ZeEventList[TmpListLength] = EventList[I]->ZeEvent; this->UrEventList[TmpListLength] = EventList[I]; - this->UrEventList[TmpListLength]->RefCount.increment(); + this->UrEventList[TmpListLength]->getRefCounter().increment(); } if (QueueLock.has_value()) { diff --git a/unified-runtime/source/adapters/level_zero/event.hpp b/unified-runtime/source/adapters/level_zero/event.hpp index 13b36bcdfbe94..df56289c8bdef 100644 --- a/unified-runtime/source/adapters/level_zero/event.hpp +++ b/unified-runtime/source/adapters/level_zero/event.hpp @@ -25,6 +25,8 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" + #include "queue.hpp" #include "ur_api.h" @@ -220,23 +222,7 @@ struct ur_event_handle_t_ : ur_object { uint64_t RecordEventStartTimestamp = 0; uint64_t RecordEventEndTimestamp = 0; - // Besides each PI object keeping a total reference count in - // ur_object::RefCount we keep special track of the event *external* - // references. This way we are able to tell when the event is not referenced - // externally anymore, i.e. it can't be passed as a dependency event to - // piEnqueue* functions and explicitly waited meaning that we can do some - // optimizations: - // 1. 
For in-order queues we can reset and reuse event even if it was not yet - // completed by submitting a reset command to the queue (since there are no - // external references, we know that nobody can wait this event somewhere in - // parallel thread or pass it as a dependency which may lead to hang) - // 2. We can avoid creating host proxy event. - // This counter doesn't track the lifetime of an event object. Even if it - // reaches zero an event object may not be destroyed and can be used - // internally in the plugin. - std::atomic RefCountExternal{0}; - - bool hasExternalRefs() { return RefCountExternal != 0; } + bool hasExternalRefs() { return RefCounterExternal.getCount() != 0; } // Reset ur_event_handle_t object. ur_result_t reset(); @@ -262,6 +248,28 @@ struct ur_event_handle_t_ : ur_object { // Used only for asynchronous allocations. This is the event originally used // on async free to indicate when the allocation can be used again. ur_event_handle_t OriginAllocEvent = nullptr; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + UR_ReferenceCounter &getRefCounterExternal() noexcept { return RefCounterExternal; } + +private: + UR_ReferenceCounter RefCounter; + + // Besides each PI object keeping a total reference count in + // ur_object::RefCount we keep special track of the event *external* + // references. This way we are able to tell when the event is not referenced + // externally anymore, i.e. it can't be passed as a dependency event to + // piEnqueue* functions and explicitly waited meaning that we can do some + // optimizations: + // 1. For in-order queues we can reset and reuse event even if it was not yet + // completed by submitting a reset command to the queue (since there are no + // external references, we know that nobody can wait this event somewhere in + // parallel thread or pass it as a dependency which may lead to hang) + // 2. We can avoid creating host proxy event. 
+ // This counter doesn't track the lifetime of an event object. Even if it + // reaches zero an event object may not be destroyed and can be used + // internally in the plugin. + UR_ReferenceCounter RefCounterExternal; }; // Helper function to implement zeHostSynchronize. diff --git a/unified-runtime/source/adapters/level_zero/kernel.cpp b/unified-runtime/source/adapters/level_zero/kernel.cpp index b6f56cf26f0e8..ac652679cfdce 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/kernel.cpp @@ -787,7 +787,7 @@ ur_result_t urKernelGetInfo( } } case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Kernel->RefCount.load()}); + return ReturnValue(uint32_t{Kernel->getRefCounter().getCount()}); case UR_KERNEL_INFO_ATTRIBUTES: try { uint32_t Size; @@ -938,7 +938,7 @@ ur_result_t urKernelGetSubGroupInfo( ur_result_t urKernelRetain( /// [in] handle for the Kernel to retain ur_kernel_handle_t Kernel) { - Kernel->RefCount.increment(); + Kernel->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -946,7 +946,7 @@ ur_result_t urKernelRetain( ur_result_t urKernelRelease( /// [in] handle for the Kernel to release ur_kernel_handle_t Kernel) { - if (!Kernel->RefCount.decrementAndTest()) + if (Kernel->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; auto KernelProgram = Kernel->Program; diff --git a/unified-runtime/source/adapters/level_zero/kernel.hpp b/unified-runtime/source/adapters/level_zero/kernel.hpp index 7f80348cda31f..05458272e59d5 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/kernel.hpp @@ -9,9 +9,11 @@ //===----------------------------------------------------------------------===// #pragma once +#include + #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "memory.hpp" -#include struct ur_kernel_handle_t_ : ur_object { ur_kernel_handle_t_(bool OwnZeHandle, ur_program_handle_t Program) @@ 
-106,6 +108,11 @@ struct ur_kernel_handle_t_ : ur_object { // Cache of the kernel properties. ZeCache> ZeKernelProperties; ZeCache ZeKernelName; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; ur_result_t getZeKernel(ze_device_handle_t hDevice, ur_kernel_handle_t hKernel, diff --git a/unified-runtime/source/adapters/level_zero/memory.cpp b/unified-runtime/source/adapters/level_zero/memory.cpp index 0f6bb37dde904..f9b1020d1b35e 100644 --- a/unified-runtime/source/adapters/level_zero/memory.cpp +++ b/unified-runtime/source/adapters/level_zero/memory.cpp @@ -1052,7 +1052,7 @@ ur_result_t urEnqueueMemBufferMap( // Add the event to the command list. CommandList->second.append(reinterpret_cast(*Event)); - (*Event)->RefCount.increment(); + (*Event)->getRefCounter().increment(); const auto &ZeCommandList = CommandList->first; const auto &WaitList = (*Event)->WaitList; @@ -1183,7 +1183,7 @@ ur_result_t urEnqueueMemUnmap( nullptr /*ForcedCmdQueue*/)); CommandList->second.append(reinterpret_cast(*Event)); - (*Event)->RefCount.increment(); + (*Event)->getRefCounter().increment(); const auto &ZeCommandList = CommandList->first; @@ -1635,14 +1635,14 @@ ur_result_t urMemBufferCreate( ur_result_t urMemRetain( /// [in] handle of the memory object to get access ur_mem_handle_t Mem) { - Mem->RefCount.increment(); + Mem->getRefCounter().increment(); return UR_RESULT_SUCCESS; } ur_result_t urMemRelease( /// [in] handle of the memory object to release ur_mem_handle_t Mem) { - if (!Mem->RefCount.decrementAndTest()) + if (Mem->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; if (Mem->isImage()) { @@ -1848,7 +1848,7 @@ ur_result_t urMemGetInfo( return ReturnValue(size_t{Buffer->Size}); } case UR_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(Buffer->RefCount.load()); + return ReturnValue(Buffer->getRefCounter().getCount()); } default: { return UR_RESULT_ERROR_INVALID_ENUMERATION; diff --git
a/unified-runtime/source/adapters/level_zero/memory.hpp b/unified-runtime/source/adapters/level_zero/memory.hpp index 715b5b51870c1..dbc343547919c 100644 --- a/unified-runtime/source/adapters/level_zero/memory.hpp +++ b/unified-runtime/source/adapters/level_zero/memory.hpp @@ -19,6 +19,7 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include "event.hpp" #include "program.hpp" @@ -90,6 +91,8 @@ struct ur_mem_handle_t_ : ur_object { // Method to get type of the derived object (image or buffer) bool isImage() const { return mem_type == mem_type_t::image; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + protected: ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context) : UrContext{Context}, UrDevice{nullptr}, mem_type(type) {} @@ -101,6 +104,9 @@ struct ur_mem_handle_t_ : ur_object { // Since the destructor isn't virtual, callers must destruct it via ur_buffer // or ur_image ~ur_mem_handle_t_() {}; + +private: + UR_ReferenceCounter RefCounter; }; struct ur_buffer final : ur_mem_handle_t_ { @@ -116,7 +122,7 @@ struct ur_buffer final : ur_mem_handle_t_ { : ur_mem_handle_t_(mem_type_t::buffer, Parent->UrContext), Size(Size), SubBuffer{{Parent, Origin}} { // Retain the Parent Buffer due to the Creation of the SubBuffer. 
- Parent->RefCount.increment(); + Parent->getRefCounter().increment(); } // Interop-buffer constructor diff --git a/unified-runtime/source/adapters/level_zero/physical_mem.cpp b/unified-runtime/source/adapters/level_zero/physical_mem.cpp index 5d4d0acce0eb3..eb7714eb18bcc 100644 --- a/unified-runtime/source/adapters/level_zero/physical_mem.cpp +++ b/unified-runtime/source/adapters/level_zero/physical_mem.cpp @@ -42,12 +42,12 @@ ur_result_t urPhysicalMemCreate( } ur_result_t urPhysicalMemRetain(ur_physical_mem_handle_t hPhysicalMem) { - hPhysicalMem->RefCount.increment(); + hPhysicalMem->getRefCounter().increment(); return UR_RESULT_SUCCESS; } ur_result_t urPhysicalMemRelease(ur_physical_mem_handle_t hPhysicalMem) { - if (!hPhysicalMem->RefCount.decrementAndTest()) + if (hPhysicalMem->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; if (checkL0LoaderTeardown()) { @@ -68,7 +68,7 @@ ur_result_t urPhysicalMemGetInfo(ur_physical_mem_handle_t hPhysicalMem, switch (propName) { case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hPhysicalMem->RefCount.load()); + return ReturnValue(hPhysicalMem->getRefCounter().getCount()); } default: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; diff --git a/unified-runtime/source/adapters/level_zero/physical_mem.hpp b/unified-runtime/source/adapters/level_zero/physical_mem.hpp index 6ce630bcc5e1f..801996b2935e7 100644 --- a/unified-runtime/source/adapters/level_zero/physical_mem.hpp +++ b/unified-runtime/source/adapters/level_zero/physical_mem.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" struct ur_physical_mem_handle_t_ : ur_object { ur_physical_mem_handle_t_(ze_physical_mem_handle_t ZePhysicalMem, @@ -21,4 +22,9 @@ struct ur_physical_mem_handle_t_ : ur_object { // Keeps the PI context of this memory handle.
ur_context_handle_t Context; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index 497e3057b7b9b..1f3520b0f6f45 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -558,14 +558,14 @@ ur_result_t urProgramLinkExp( ur_result_t urProgramRetain( /// [in] handle for the Program to retain ur_program_handle_t Program) { - Program->RefCount.increment(); + Program->getRefCounter().increment(); return UR_RESULT_SUCCESS; } ur_result_t urProgramRelease( /// [in] handle for the Program to release ur_program_handle_t Program) { - if (!Program->RefCount.decrementAndTest()) + if (Program->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; delete Program; @@ -708,7 +708,7 @@ ur_result_t urProgramGetInfo( switch (PropName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Program->RefCount.load()}); + return ReturnValue(uint32_t{Program->getRefCounter().getCount()}); case UR_PROGRAM_INFO_CONTEXT: return ReturnValue(Program->Context); case UR_PROGRAM_INFO_NUM_DEVICES: @@ -1115,7 +1115,7 @@ void ur_program_handle_t_::ur_release_program_resources(bool deletion) { // must be destroyed before the Module can be destroyed. So, be sure // to destroy build log before destroying the module.
if (!deletion) { - if (!RefCount.decrementAndTest()) { + if (RefCounter.decrement() != 0) { return; } } diff --git a/unified-runtime/source/adapters/level_zero/program.hpp b/unified-runtime/source/adapters/level_zero/program.hpp index 789daf052ba0c..250c4456dc868 100644 --- a/unified-runtime/source/adapters/level_zero/program.hpp +++ b/unified-runtime/source/adapters/level_zero/program.hpp @@ -10,6 +10,8 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" + #include "device.hpp" struct ur_program_handle_t_ : ur_object { @@ -226,6 +228,8 @@ struct ur_program_handle_t_ : ur_object { // UR_PROGRAM_INFO_BINARY_SIZES. const std::vector AssociatedDevices; + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: struct DeviceData { // Log from the result of building the program for the device using @@ -264,4 +268,6 @@ struct ur_program_handle_t_ : ur_object { // handle from the program. // TODO: Currently interoparability UR API does not support multiple devices.
ze_module_handle_t InteropZeModule = nullptr; + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/level_zero/queue.cpp b/unified-runtime/source/adapters/level_zero/queue.cpp index 2cd607ccf93d7..ab7b9cf2b9461 100644 --- a/unified-runtime/source/adapters/level_zero/queue.cpp +++ b/unified-runtime/source/adapters/level_zero/queue.cpp @@ -369,7 +369,7 @@ ur_result_t urQueueGetInfo( case UR_QUEUE_INFO_DEVICE: return ReturnValue(Queue->Device); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Queue->RefCount.load()}); + return ReturnValue(uint32_t{Queue->getRefCounter().getCount()}); case UR_QUEUE_INFO_FLAGS: return ReturnValue(Queue->Properties); case UR_QUEUE_INFO_SIZE: @@ -591,9 +591,9 @@ ur_result_t urQueueRetain( ur_queue_handle_t Queue) { { std::scoped_lock Lock(Queue->Mutex); - Queue->RefCountExternal++; + Queue->getRefCounterExternal().increment(); } - Queue->RefCount.increment(); + Queue->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -607,12 +607,12 @@ ur_result_t urQueueRelease( { std::scoped_lock Lock(Queue->Mutex); - if ((--Queue->RefCountExternal) != 0) { + if ((Queue->getRefCounterExternal().decrement()) != 0) { // When an External Reference exists one still needs to decrement the // internal reference count. When the External Reference count == 0, then // cleanup of the queue begins and the final decrement of the internal // reference count is completed. 
- static_cast(Queue->RefCount.decrementAndTest()); + static_cast(Queue->getRefCounter().decrement() == 0); return UR_RESULT_SUCCESS; } @@ -1389,7 +1389,7 @@ ur_queue_handle_t_::executeCommandList(ur_command_list_ptr_t CommandList, if (!Event->HostVisibleEvent) { Event->HostVisibleEvent = reinterpret_cast(HostVisibleEvent); - HostVisibleEvent->RefCount.increment(); + HostVisibleEvent->getRefCounter().increment(); } } @@ -1550,7 +1550,7 @@ ur_result_t ur_queue_handle_t_::addEventToQueueCache(ur_event_handle_t Event) { } void ur_queue_handle_t_::active_barriers::add(ur_event_handle_t &Event) { - Event->RefCount.increment(); + Event->getRefCounter().increment(); Events.push_back(Event); } @@ -1588,7 +1588,7 @@ void ur_queue_handle_t_::clearEndTimeRecordings() { } ur_result_t urQueueReleaseInternal(ur_queue_handle_t Queue) { - if (!Queue->RefCount.decrementAndTest()) + if (!Queue->getRefCounter().decrement() == 0) return UR_RESULT_SUCCESS; for (auto &Cache : Queue->EventCaches) { @@ -1921,7 +1921,7 @@ ur_result_t createEventAndAssociateQueue(ur_queue_handle_t Queue, // Append this Event to the CommandList, if any if (CommandList != Queue->CommandListMap.end()) { CommandList->second.append(*Event); - (*Event)->RefCount.increment(); + (*Event)->getRefCounter().increment(); } // We need to increment the reference counter here to avoid ur_queue_handle_t @@ -1929,7 +1929,7 @@ ur_result_t createEventAndAssociateQueue(ur_queue_handle_t Queue, // urEventRelease requires access to the associated ur_queue_handle_t. // In urEventRelease, the reference counter of the Queue is decremented // to release it. - Queue->RefCount.increment(); + Queue->getRefCounter().increment(); // SYCL RT does not track completion of the events, so it could // release a PI event as soon as that's not being waited in the app. @@ -1961,7 +1961,7 @@ void ur_queue_handle_t_::CaptureIndirectAccesses() { // SubmissionsCount turns to 0. 
We don't want to know how many times // allocation was retained by each submission. if (Pair.second) - Elem.second.RefCount.increment(); + Elem.second.getRefCounter().increment(); } } Kernel->SubmissionsCount++; diff --git a/unified-runtime/source/adapters/level_zero/queue.hpp b/unified-runtime/source/adapters/level_zero/queue.hpp index 405929c8f0f0e..664d17497a01d 100644 --- a/unified-runtime/source/adapters/level_zero/queue.hpp +++ b/unified-runtime/source/adapters/level_zero/queue.hpp @@ -25,6 +25,8 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" + #include "device.hpp" extern "C" { @@ -419,18 +421,6 @@ struct ur_queue_handle_t_ : ur_object { // list is needed for a command. active_barriers ActiveBarriers; - // Besides each PI object keeping a total reference count in - // ur_object::RefCount we keep special track of the queue *external* - // references. This way we are able to tell when the queue is being finished - // externally, and can wait for internal references to complete, and do proper - // cleanup of the queue. - // This counter doesn't track the lifetime of a queue object, it only tracks - // the number of external references. I.e. even if it reaches zero a queue - // object may not be destroyed and can be used internally in the plugin. - // That's why we intentionally don't use atomic type for this counter to - // enforce guarding with a mutex all the work involving this counter. - uint32_t RefCountExternal{1}; - // Indicates that the queue is healthy and all operations on it are OK. bool Healthy{true}; @@ -692,6 +682,24 @@ struct ur_queue_handle_t_ : ur_object { // Pointer to the unified handle. 
ur_queue_handle_t_ *UnifiedHandle; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + UR_ReferenceCounter &getRefCounterExternal() noexcept { return RefCounterExternal; } + +private: + UR_ReferenceCounter RefCounter; + + // Besides each PI object keeping a total reference count in + // ur_object::RefCount we keep special track of the queue *external* + // references. This way we are able to tell when the queue is being finished + // externally, and can wait for internal references to complete, and do proper + // cleanup of the queue. + // This counter doesn't track the lifetime of a queue object, it only tracks + // the number of external references. I.e. even if it reaches zero a queue + // object may not be destroyed and can be used internally in the plugin. + // That's why we intentionally don't use atomic type for this counter to + // enforce guarding with a mutex all the work involving this counter. + UR_ReferenceCounter RefCounterExternal; }; // This helper function creates a ur_event_handle_t and associate a diff --git a/unified-runtime/source/adapters/level_zero/sampler.cpp b/unified-runtime/source/adapters/level_zero/sampler.cpp index 4f6f5760faada..8272dc0ff05fb 100644 --- a/unified-runtime/source/adapters/level_zero/sampler.cpp +++ b/unified-runtime/source/adapters/level_zero/sampler.cpp @@ -124,14 +124,14 @@ ur_result_t urSamplerCreate( ur_result_t urSamplerRetain( /// [in] handle of the sampler object to get access ur_sampler_handle_t Sampler) { - Sampler->RefCount.increment(); + Sampler->getRefCounter().increment(); return UR_RESULT_SUCCESS; } ur_result_t urSamplerRelease( /// [in] handle of the sampler object to release ur_sampler_handle_t Sampler) { - if (!Sampler->RefCount.decrementAndTest()) + if (Sampler->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; if (checkL0LoaderTeardown()) { diff --git a/unified-runtime/source/adapters/level_zero/sampler.hpp b/unified-runtime/source/adapters/level_zero/sampler.hpp
index 9a834a05215d9..15ce988ac280e 100644 --- a/unified-runtime/source/adapters/level_zero/sampler.hpp +++ b/unified-runtime/source/adapters/level_zero/sampler.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" struct ur_sampler_handle_t_ : ur_object { ur_sampler_handle_t_(ze_sampler_handle_t Sampler) : ZeSampler{Sampler} {} @@ -18,6 +19,11 @@ struct ur_sampler_handle_t_ : ur_object { ze_sampler_handle_t ZeSampler; ZeStruct ZeSamplerDesc; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; // Construct ZE sampler desc from UR sampler desc. diff --git a/unified-runtime/source/adapters/level_zero/usm.cpp b/unified-runtime/source/adapters/level_zero/usm.cpp index c6abed7ccabb6..83a4bdeac0070 100644 --- a/unified-runtime/source/adapters/level_zero/usm.cpp +++ b/unified-runtime/source/adapters/level_zero/usm.cpp @@ -523,14 +523,14 @@ ur_result_t urUSMPoolCreate( ur_result_t /// [in] pointer to USM memory pool urUSMPoolRetain(ur_usm_pool_handle_t Pool) { - Pool->RefCount.increment(); + Pool->getRefCounter().increment(); return UR_RESULT_SUCCESS; } ur_result_t /// [in] pointer to USM memory pool urUSMPoolRelease(ur_usm_pool_handle_t Pool) { - if (Pool->RefCount.decrementAndTest()) { + if (Pool->getRefCounter().decrement() == 0) { std::shared_lock ContextLock(Pool->Context->Mutex); Pool->Context->UsmPoolHandles.remove(Pool); delete Pool; @@ -553,7 +553,7 @@ ur_result_t urUSMPoolGetInfo( switch (PropName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return ReturnValue(Pool->RefCount.load()); + return ReturnValue(Pool->getRefCounter().getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(Pool->Context); @@ -1239,7 +1239,7 @@ ur_result_t ZeMemFreeHelper(ur_context_handle_t Context, void *Ptr) { if (It == std::end(Context->MemAllocs)) { die("All memory allocations must be tracked!"); } - if (!It->second.RefCount.decrementAndTest()) { + if 
(It->second.getRefCounter().decrement() != 0) { // Memory can't be deallocated yet. return UR_RESULT_SUCCESS; } @@ -1286,7 +1286,7 @@ ur_result_t USMFreeHelper(ur_context_handle_t Context, void *Ptr, if (It == std::end(Context->MemAllocs)) { die("All memory allocations must be tracked!"); } - if (!It->second.RefCount.decrementAndTest()) { + if (It->second.getRefCounter().decrement() != 0) { // Memory can't be deallocated yet. return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/level_zero/usm.hpp b/unified-runtime/source/adapters/level_zero/usm.hpp index b29ea29a7914d..abd2e6e509f3d 100644 --- a/unified-runtime/source/adapters/level_zero/usm.hpp +++ b/unified-runtime/source/adapters/level_zero/usm.hpp @@ -9,13 +9,14 @@ //===----------------------------------------------------------------------===// #pragma once -#include "common.hpp" +#include +#include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "enqueued_pool.hpp" #include "event.hpp" #include "ur_api.h" #include "ur_pool_manager.hpp" -#include #include usm::DisjointPoolAllConfigs InitializeDisjointPoolConfig(); @@ -53,9 +54,12 @@ struct ur_usm_pool_handle_t_ : ur_object { ur_context_handle_t Context; + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: UsmPool *getPool(const usm::pool_descriptor &Desc); usm::pool_manager PoolManager; + UR_ReferenceCounter RefCounter; }; // Exception type to pass allocation errors diff --git a/unified-runtime/source/adapters/native_cpu/adapter.cpp b/unified-runtime/source/adapters/native_cpu/adapter.cpp index 3fd6d4256825b..470f227302086 100644 --- a/unified-runtime/source/adapters/native_cpu/adapter.cpp +++ b/unified-runtime/source/adapters/native_cpu/adapter.cpp @@ -13,14 +13,17 @@ #include "ur_api.h" struct ur_adapter_handle_t_ : ur::native_cpu::handle_base { - std::atomic RefCount = 0; logger::Logger &logger = logger::get_logger("native_cpu"); + UR_ReferenceCounter &getRefCounter() noexcept { return
RefCounter; } + +private: + UR_ReferenceCounter RefCounter; } Adapter; UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (phAdapters) { - Adapter.RefCount++; + Adapter.getRefCounter().increment(); *phAdapters = &Adapter; } if (pNumAdapters) { @@ -30,12 +33,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - Adapter.RefCount--; + Adapter.getRefCounter().decrement(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - Adapter.RefCount++; + Adapter.getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -57,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_NATIVE_CPU); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(Adapter.RefCount.load()); + return ReturnValue(Adapter.getRefCounter().getCount()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/native_cpu/common.hpp b/unified-runtime/source/adapters/native_cpu/common.hpp index 3eab5a54a16e9..21539d2ad24ac 100644 --- a/unified-runtime/source/adapters/native_cpu/common.hpp +++ b/unified-runtime/source/adapters/native_cpu/common.hpp @@ -10,9 +10,11 @@ #pragma once +#include + +#include "common/ur_ref_counter.hpp" #include "logger/ur_logger.hpp" #include "ur/ur.hpp" -#include constexpr size_t MaxMessageSize = 256; @@ -44,25 +46,6 @@ struct ddi_getter { using handle_base = ur::handle_base; } // namespace ur::native_cpu -// Todo: replace this with a common helper once it is available -struct RefCounted : ur::native_cpu::handle_base { - std::atomic_uint32_t _refCount; - uint32_t incrementReferenceCount() { return ++_refCount; } - uint32_t decrementReferenceCount() { return --_refCount; } - RefCounted() : handle_base(), _refCount{1} {} - 
uint32_t getReferenceCount() const { return _refCount; } -}; - -// Base class to store common data -struct ur_object : RefCounted { - ur_shared_mutex Mutex; -}; - -template inline void decrementOrDelete(T *refC) { - if (refC->decrementReferenceCount() == 0) - delete refC; -} - inline uint64_t get_timestamp() { return std::chrono::duration_cast( std::chrono::high_resolution_clock::now().time_since_epoch()) diff --git a/unified-runtime/source/adapters/native_cpu/context.cpp b/unified-runtime/source/adapters/native_cpu/context.cpp index 5b7e8fc839884..5b3d4c71bcc94 100644 --- a/unified-runtime/source/adapters/native_cpu/context.cpp +++ b/unified-runtime/source/adapters/native_cpu/context.cpp @@ -30,13 +30,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate( UR_APIEXPORT ur_result_t UR_APICALL urContextRetain(ur_context_handle_t hContext) { - hContext->incrementReferenceCount(); + hContext->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urContextRelease(ur_context_handle_t hContext) { - decrementOrDelete(hContext); + if (hContext->getRefCounter().decrement() == 0) { + delete hContext; + } return UR_RESULT_SUCCESS; } @@ -51,7 +53,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, case UR_CONTEXT_INFO_DEVICES: return returnValue(hContext->_device); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return returnValue(uint32_t{hContext->getReferenceCount()}); + return returnValue(uint32_t{hContext->getRefCounter().getCount()}); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: return returnValue(true); case UR_CONTEXT_INFO_USM_FILL2D_SUPPORT: diff --git a/unified-runtime/source/adapters/native_cpu/context.hpp b/unified-runtime/source/adapters/native_cpu/context.hpp index b9d2d22dd1565..15e1d4d476efb 100644 --- a/unified-runtime/source/adapters/native_cpu/context.hpp +++ b/unified-runtime/source/adapters/native_cpu/context.hpp @@ -15,6 +15,7 @@ #include #include "common.hpp" +#include 
"common/ur_ref_counter.hpp" #include "device.hpp" #include "ur/ur.hpp" @@ -83,7 +84,7 @@ static usm_alloc_info get_alloc_info(void *ptr) { } // namespace native_cpu -struct ur_context_handle_t_ : RefCounted { +struct ur_context_handle_t_ { ur_context_handle_t_(ur_device_handle_t_ *phDevices) : _device{phDevices} {} ur_device_handle_t _device; @@ -135,7 +136,11 @@ struct ur_context_handle_t_ : RefCounted { return ptr; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: std::mutex alloc_mutex; std::set allocations; + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/native_cpu/event.cpp b/unified-runtime/source/adapters/native_cpu/event.cpp index 91b8fb302eb18..f9a88d66a8782 100644 --- a/unified-runtime/source/adapters/native_cpu/event.cpp +++ b/unified-runtime/source/adapters/native_cpu/event.cpp @@ -28,7 +28,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, case UR_EVENT_INFO_COMMAND_TYPE: return ReturnValue(hEvent->getCommandType()); case UR_EVENT_INFO_REFERENCE_COUNT: - return ReturnValue(hEvent->getReferenceCount()); + return ReturnValue(hEvent->getRefCounter().getCount()); case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: return ReturnValue(hEvent->getExecutionStatus()); case UR_EVENT_INFO_CONTEXT: @@ -69,12 +69,14 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) { } UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { - hEvent->incrementReferenceCount(); + hEvent->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { - decrementOrDelete(hEvent); + if (hEvent->getRefCounter().decrement() == 0) { + delete hEvent; + } return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/native_cpu/event.hpp b/unified-runtime/source/adapters/native_cpu/event.hpp index 479c671b38cd1..5d5f6861540fd 100644 --- 
a/unified-runtime/source/adapters/native_cpu/event.hpp +++ b/unified-runtime/source/adapters/native_cpu/event.hpp @@ -8,14 +8,17 @@ // //===----------------------------------------------------------------------===// #pragma once -#include "common.hpp" -#include "ur_api.h" + #include #include #include #include -struct ur_event_handle_t_ : RefCounted { +#include "common.hpp" +#include "common/ur_ref_counter.hpp" +#include "ur_api.h" + +struct ur_event_handle_t_ { ur_event_handle_t_(ur_queue_handle_t queue, ur_command_t command_type); @@ -55,6 +58,8 @@ struct ur_event_handle_t_ : RefCounted { uint64_t get_end_timestamp() const { return timestamp_end; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: ur_queue_handle_t queue; ur_context_handle_t context; @@ -65,4 +70,6 @@ struct ur_event_handle_t_ : RefCounted { std::packaged_task callback; uint64_t timestamp_start = 0; uint64_t timestamp_end = 0; + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/native_cpu/kernel.cpp b/unified-runtime/source/adapters/native_cpu/kernel.cpp index 500b2c6bcd8a5..80c60573a1aaf 100644 --- a/unified-runtime/source/adapters/native_cpu/kernel.cpp +++ b/unified-runtime/source/adapters/native_cpu/kernel.cpp @@ -95,7 +95,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, case UR_KERNEL_INFO_FUNCTION_NAME: return ReturnValue(hKernel->_name); case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hKernel->getReferenceCount()}); + return ReturnValue(uint32_t{hKernel->getRefCounter().getCount()}); case UR_KERNEL_INFO_ATTRIBUTES: return ReturnValue(""); case UR_KERNEL_INFO_SPILL_MEM_SIZE: @@ -194,13 +194,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetSubGroupInfo( } UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) { - hKernel->incrementReferenceCount(); + hKernel->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t 
UR_APICALL urKernelRelease(ur_kernel_handle_t hKernel) { - decrementOrDelete(hKernel); + if (hKernel->getRefCounter().decrement() == 0) { + delete hKernel; + } return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/native_cpu/kernel.hpp b/unified-runtime/source/adapters/native_cpu/kernel.hpp index 8c77e855d9796..eafe2ae861c31 100644 --- a/unified-runtime/source/adapters/native_cpu/kernel.hpp +++ b/unified-runtime/source/adapters/native_cpu/kernel.hpp @@ -8,13 +8,15 @@ #pragma once +#include +#include + #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "memory.hpp" #include "nativecpu_state.hpp" #include "program.hpp" -#include #include -#include using nativecpu_kernel_t = void(void *const *, native_cpu::state *); using nativecpu_ptr_t = nativecpu_kernel_t *; @@ -27,7 +29,7 @@ struct local_arg_info_t { : argIndex(argIndex), argSize(argSize) {} }; -struct ur_kernel_handle_t_ : RefCounted { +struct ur_kernel_handle_t_ { ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name, nativecpu_task_t subhandler) @@ -193,14 +195,18 @@ struct ur_kernel_handle_t_ : RefCounted { void addPtrArg(void *Ptr, size_t Index) { Args.addPtrArg(Index, Ptr); } void addArgReference(ur_mem_handle_t Arg) { - Arg->incrementReferenceCount(); + Arg->getRefCounter().increment(); ReferencedArgs.push_back(Arg); } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: void removeArgReferences() { for (auto arg : ReferencedArgs) - decrementOrDelete(arg); + if (arg->getRefCounter().decrement() == 0) { + delete arg; + } } void takeArgReferences(const ur_kernel_handle_t_ &other) { for (auto arg : other.ReferencedArgs) @@ -214,4 +220,6 @@ struct ur_kernel_handle_t_ : RefCounted { std::optional MaxWGSize = std::nullopt; std::optional MaxLinearWGSize = std::nullopt; std::vector ReferencedArgs; + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/native_cpu/memory.cpp 
b/unified-runtime/source/adapters/native_cpu/memory.cpp index 67eb95f3d9523..f9669b611d21d 100644 --- a/unified-runtime/source/adapters/native_cpu/memory.cpp +++ b/unified-runtime/source/adapters/native_cpu/memory.cpp @@ -64,7 +64,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t /*hMem*/) { UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { UR_ASSERT(hMem, UR_RESULT_ERROR_INVALID_NULL_HANDLE); - decrementOrDelete(hMem); + if (hMem->getRefCounter().decrement() == 0) { + delete hMem; + } return UR_RESULT_SUCCESS; } @@ -78,7 +80,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition( !(static_cast(hBuffer))->isSubBuffer(), UR_RESULT_ERROR_INVALID_MEM_OBJECT); - std::shared_lock Guard(hBuffer->Mutex); + std::shared_lock Guard(hBuffer->getMutex()); if (flags != UR_MEM_FLAG_READ_WRITE) { die("urMemBufferPartition: NativeCPU implements only read-write buffer," diff --git a/unified-runtime/source/adapters/native_cpu/memory.hpp b/unified-runtime/source/adapters/native_cpu/memory.hpp index d5415e82e94bd..4df558a8e7eb9 100644 --- a/unified-runtime/source/adapters/native_cpu/memory.hpp +++ b/unified-runtime/source/adapters/native_cpu/memory.hpp @@ -10,14 +10,13 @@ #pragma once -#include #include #include #include "common.hpp" #include "context.hpp" -struct ur_mem_handle_t_ : ur_object { +struct ur_mem_handle_t_ { ur_mem_handle_t_(size_t Size, bool _IsImage) : _mem{static_cast(malloc(Size))}, _ownsMem{true}, IsImage{_IsImage} {} @@ -44,8 +43,13 @@ struct ur_mem_handle_t_ : ur_object { char *_mem; bool _ownsMem; + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + ur_shared_mutex &getMutex() noexcept { return Mutex; } + private: const bool IsImage; + UR_ReferenceCounter RefCounter; + ur_shared_mutex Mutex; }; struct ur_buffer final : ur_mem_handle_t_ { diff --git a/unified-runtime/source/adapters/native_cpu/program.cpp b/unified-runtime/source/adapters/native_cpu/program.cpp index fee72f8a6bc3c..d0a8b2483813e 
100644 --- a/unified-runtime/source/adapters/native_cpu/program.cpp +++ b/unified-runtime/source/adapters/native_cpu/program.cpp @@ -171,13 +171,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp( UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { - hProgram->incrementReferenceCount(); + hProgram->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urProgramRelease(ur_program_handle_t hProgram) { - decrementOrDelete(hProgram); + if (hProgram->getRefCounter().decrement() == 0) { + delete hProgram; + } return UR_RESULT_SUCCESS; } @@ -205,7 +207,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, switch (propName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return returnValue(hProgram->getReferenceCount()); + return returnValue(hProgram->getRefCounter().getCount()); case UR_PROGRAM_INFO_CONTEXT: return returnValue(nullptr); case UR_PROGRAM_INFO_NUM_DEVICES: diff --git a/unified-runtime/source/adapters/native_cpu/program.hpp b/unified-runtime/source/adapters/native_cpu/program.hpp index d58412751e8f2..5f138c57477a4 100644 --- a/unified-runtime/source/adapters/native_cpu/program.hpp +++ b/unified-runtime/source/adapters/native_cpu/program.hpp @@ -12,6 +12,7 @@ #include +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include @@ -21,12 +22,10 @@ namespace native_cpu { using WGSize_t = std::array; } -struct ur_program_handle_t_ : RefCounted { +struct ur_program_handle_t_ { ur_program_handle_t_(ur_context_handle_t ctx, const unsigned char *pBinary) : _ctx{ctx}, _ptr{pBinary} {} - uint32_t getReferenceCount() const noexcept { return _refCount; } - ur_context_handle_t _ctx; const unsigned char *_ptr; struct _compare { @@ -41,6 +40,11 @@ struct ur_program_handle_t_ : RefCounted { std::unordered_map KernelMaxWorkGroupSizeMD; std::unordered_map KernelMaxLinearWorkGroupSizeMD; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: 
+ UR_ReferenceCounter RefCounter; }; // The nativecpu_entry struct is also defined as LLVM-IR in the diff --git a/unified-runtime/source/adapters/native_cpu/queue.cpp b/unified-runtime/source/adapters/native_cpu/queue.cpp index 5de7037519490..4a7b32afed772 100644 --- a/unified-runtime/source/adapters/native_cpu/queue.cpp +++ b/unified-runtime/source/adapters/native_cpu/queue.cpp @@ -28,7 +28,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hQueue->getDevice()); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(hQueue->getReferenceCount()); + return ReturnValue(hQueue->getRefCounter().getCount()); case UR_QUEUE_INFO_EMPTY: return ReturnValue(hQueue->isEmpty()); default: @@ -48,13 +48,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreate( } UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) { - hQueue->incrementReferenceCount(); + hQueue->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) { - decrementOrDelete(hQueue); + if (hQueue->getRefCounter().decrement() == 0) { + delete hQueue; + } return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/native_cpu/queue.hpp b/unified-runtime/source/adapters/native_cpu/queue.hpp index 1369b49a10984..81f741f40641e 100644 --- a/unified-runtime/source/adapters/native_cpu/queue.hpp +++ b/unified-runtime/source/adapters/native_cpu/queue.hpp @@ -8,12 +8,15 @@ // //===----------------------------------------------------------------------===// #pragma once + +#include + #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "event.hpp" #include "ur_api.h" -#include -struct ur_queue_handle_t_ : RefCounted { +struct ur_queue_handle_t_ { ur_queue_handle_t_(ur_device_handle_t device, ur_context_handle_t context, const ur_queue_properties_t *pProps) : device(device), context(context), @@ -52,10 +55,14 @@ struct 
ur_queue_handle_t_ : RefCounted { return events.size() == 0; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: ur_device_handle_t device; ur_context_handle_t context; std::set events; const bool inOrder; const bool profilingEnabled; + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/offload/adapter.cpp b/unified-runtime/source/adapters/offload/adapter.cpp index 6eb9cd7239eaf..c084691a99fa2 100644 --- a/unified-runtime/source/adapters/offload/adapter.cpp +++ b/unified-runtime/source/adapters/offload/adapter.cpp @@ -65,7 +65,7 @@ ur_result_t ur_adapter_handle_t_::init() { UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (phAdapters) { - if (++Adapter.RefCount == 1) { + if (Adapter.getRefCounter().increment() == 1) { Adapter.init(); } *phAdapters = &Adapter; @@ -77,7 +77,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - if (--Adapter.RefCount == 0) { + if (Adapter.getRefCounter().decrement() == 0) { // This can crash when tracing is enabled. 
// olShutDown(); }; @@ -85,7 +85,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - Adapter.RefCount++; + Adapter.getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -100,7 +100,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_OFFLOAD); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(Adapter.RefCount.load()); + return ReturnValue(Adapter.getRefCounter().getCount()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(1); default: diff --git a/unified-runtime/source/adapters/offload/adapter.hpp b/unified-runtime/source/adapters/offload/adapter.hpp index b85995b0f6a08..f01934c4e1246 100644 --- a/unified-runtime/source/adapters/offload/adapter.hpp +++ b/unified-runtime/source/adapters/offload/adapter.hpp @@ -10,23 +10,27 @@ #pragma once -#include #include #include #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "logger/ur_logger.hpp" #include "platform.hpp" struct ur_adapter_handle_t_ : ur::offload::handle_base { - std::atomic_uint32_t RefCount = 0; logger::Logger &Logger = logger::get_logger("offload"); ol_device_handle_t HostDevice = nullptr; std::vector Platforms; ur_result_t init(); + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; extern ur_adapter_handle_t_ Adapter; diff --git a/unified-runtime/source/adapters/offload/context.cpp b/unified-runtime/source/adapters/offload/context.cpp index 2dcbcd4da82f5..ca0735be70549 100644 --- a/unified-runtime/source/adapters/offload/context.cpp +++ b/unified-runtime/source/adapters/offload/context.cpp @@ -34,7 +34,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, case UR_CONTEXT_INFO_DEVICES: return ReturnValue(&hContext->Device, 1); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return 
ReturnValue(hContext->RefCount.load()); + return ReturnValue(hContext->getRefCounter().getCount()); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: case UR_CONTEXT_INFO_USM_FILL2D_SUPPORT: return ReturnValue(false); @@ -47,13 +47,13 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urContextRetain(ur_context_handle_t hContext) { - hContext->RefCount++; + hContext->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urContextRelease(ur_context_handle_t hContext) { - if (--hContext->RefCount == 0) { + if (hContext->getRefCounter().decrement() == 0) { delete hContext; } return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/offload/context.hpp b/unified-runtime/source/adapters/offload/context.hpp index 64727ce3338bb..6291662859baf 100644 --- a/unified-runtime/source/adapters/offload/context.hpp +++ b/unified-runtime/source/adapters/offload/context.hpp @@ -11,11 +11,12 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include #include #include -struct ur_context_handle_t_ : RefCounted { +struct ur_context_handle_t_ { ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} { urDeviceRetain(Device); } @@ -23,4 +24,9 @@ struct ur_context_handle_t_ : RefCounted { ur_device_handle_t Device; std::unordered_map AllocTypeMap; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/offload/kernel.cpp b/unified-runtime/source/adapters/offload/kernel.cpp index 12bfe0478130a..75267476f4227 100644 --- a/unified-runtime/source/adapters/offload/kernel.cpp +++ b/unified-runtime/source/adapters/offload/kernel.cpp @@ -42,7 +42,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, switch (propName) { case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(hKernel->RefCount.load()); + return 
ReturnValue(hKernel->getRefCounter().getCount()); default: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; } @@ -51,13 +51,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, } UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) { - hKernel->RefCount++; + hKernel->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(ur_kernel_handle_t hKernel) { - if (--hKernel->RefCount == 0) { + if (hKernel->getRefCounter().decrement() == 0) { delete hKernel; } return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/offload/kernel.hpp b/unified-runtime/source/adapters/offload/kernel.hpp index dea7e25d9da9e..b3b19e59cdd95 100644 --- a/unified-runtime/source/adapters/offload/kernel.hpp +++ b/unified-runtime/source/adapters/offload/kernel.hpp @@ -18,8 +18,9 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" -struct ur_kernel_handle_t_ : RefCounted { +struct ur_kernel_handle_t_ { // Simplified version of the CUDA adapter's argument implementation struct OffloadKernelArguments { @@ -59,4 +60,9 @@ struct ur_kernel_handle_t_ : RefCounted { ol_kernel_handle_t OffloadKernel; OffloadKernelArguments Args{}; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/offload/program.cpp b/unified-runtime/source/adapters/offload/program.cpp index 5e33f49286b80..c97e29e990772 100644 --- a/unified-runtime/source/adapters/offload/program.cpp +++ b/unified-runtime/source/adapters/offload/program.cpp @@ -221,7 +221,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, switch (propName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return ReturnValue(hProgram->RefCount.load()); + return ReturnValue(hProgram->getRefCounter().getCount()); default: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; } @@ -231,13 +231,13 @@ 
urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { - hProgram->RefCount++; + hProgram->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urProgramRelease(ur_program_handle_t hProgram) { - if (--hProgram->RefCount == 0) { + if (hProgram->getRefCounter().decrement() == 0) { auto Res = olDestroyProgram(hProgram->OffloadProgram); if (Res) { return offloadResultToUR(Res); diff --git a/unified-runtime/source/adapters/offload/program.hpp b/unified-runtime/source/adapters/offload/program.hpp index 1d0263aad2998..f325b2fc9d637 100644 --- a/unified-runtime/source/adapters/offload/program.hpp +++ b/unified-runtime/source/adapters/offload/program.hpp @@ -14,7 +14,13 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" -struct ur_program_handle_t_ : RefCounted { +struct ur_program_handle_t_ { ol_program_handle_t OffloadProgram; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/offload/queue.cpp b/unified-runtime/source/adapters/offload/queue.cpp index 0e12c6206dd84..72c506c0227e2 100644 --- a/unified-runtime/source/adapters/offload/queue.cpp +++ b/unified-runtime/source/adapters/offload/queue.cpp @@ -46,7 +46,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, switch (propName) { case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(hQueue->RefCount.load()); + return ReturnValue(hQueue->getRefCounter().getCount()); default: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; } @@ -55,12 +55,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, } UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) { - hQueue->RefCount++; + hQueue->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t 
UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) { - if (--hQueue->RefCount == 0) { + if (hQueue->getRefCounter().decrement() == 0) { auto Res = olDestroyQueue(hQueue->OffloadQueue); if (Res) { return offloadResultToUR(Res); diff --git a/unified-runtime/source/adapters/offload/queue.hpp b/unified-runtime/source/adapters/offload/queue.hpp index 6afe4bf15098e..cb9dba20daba2 100644 --- a/unified-runtime/source/adapters/offload/queue.hpp +++ b/unified-runtime/source/adapters/offload/queue.hpp @@ -14,8 +14,14 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" -struct ur_queue_handle_t_ : RefCounted { +struct ur_queue_handle_t_ { ol_queue_handle_t OffloadQueue; ol_device_handle_t OffloadDevice; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/opencl/adapter.cpp b/unified-runtime/source/adapters/opencl/adapter.cpp index 797e1c8582ef8..8d334ce2d1eab 100644 --- a/unified-runtime/source/adapters/opencl/adapter.cpp +++ b/unified-runtime/source/adapters/opencl/adapter.cpp @@ -78,7 +78,7 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, } auto &adapter = *phAdapters; - adapter->RefCount++; + adapter->getRefCounter().increment(); } if (pNumAdapters) { @@ -90,13 +90,13 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t hAdapter) { - ++hAdapter->RefCount; + hAdapter->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t hAdapter) { - if (--hAdapter->RefCount == 0) { + if (hAdapter->getRefCounter().decrement() == 0) { delete hAdapter; } return UR_RESULT_SUCCESS; @@ -119,7 +119,7 @@ urAdapterGetInfo(ur_adapter_handle_t hAdapter, ur_adapter_info_t propName, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_OPENCL); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return
ReturnValue(hAdapter->RefCount.load()); + return ReturnValue(hAdapter->getRefCounter().getCount()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/opencl/adapter.hpp b/unified-runtime/source/adapters/opencl/adapter.hpp index 1a83963343c9c..108e842cc1cb6 100644 --- a/unified-runtime/source/adapters/opencl/adapter.hpp +++ b/unified-runtime/source/adapters/opencl/adapter.hpp @@ -15,7 +15,8 @@ #include "CL/cl.h" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "logger/ur_logger.hpp" struct ur_adapter_handle_t_ : ur::opencl::handle_base { ur_adapter_handle_t_(); @@ -24,7 +25,6 @@ struct ur_adapter_handle_t_ : ur::opencl::handle_base { ur_adapter_handle_t_(ur_adapter_handle_t_ &) = delete; ur_adapter_handle_t_ &operator=(const ur_adapter_handle_t_ &) = delete; - std::atomic RefCount = 0; logger::Logger &log = logger::get_logger("opencl"); cl_ext::ExtFuncPtrCacheT fnCache{}; @@ -37,6 +37,11 @@ struct ur_adapter_handle_t_ : ur::opencl::handle_base { #define CL_CORE_FUNCTION(FUNC) decltype(::FUNC) *FUNC = nullptr; #include "core_functions.def" #undef CL_CORE_FUNCTION + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; namespace ur { diff --git a/unified-runtime/source/adapters/opencl/command_buffer.cpp b/unified-runtime/source/adapters/opencl/command_buffer.cpp index e048b2d22175c..a5c91780dcb38 100644 --- a/unified-runtime/source/adapters/opencl/command_buffer.cpp +++ b/unified-runtime/source/adapters/opencl/command_buffer.cpp @@ -108,13 +108,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp( UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - hCommandBuffer->incrementReferenceCount(); + hCommandBuffer->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - if (hCommandBuffer->decrementReferenceCount() == 0) { + if (hCommandBuffer->getRefCounter().decrement() == 0) { delete hCommandBuffer; } @@ -783,7 +783,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp( switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(hCommandBuffer->getReferenceCount()); + return ReturnValue(hCommandBuffer->getRefCounter().getCount()); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/opencl/command_buffer.hpp b/unified-runtime/source/adapters/opencl/command_buffer.hpp index e7b1df48aaddd..c369b00ed2607 100644 --- a/unified-runtime/source/adapters/opencl/command_buffer.hpp +++ b/unified-runtime/source/adapters/opencl/command_buffer.hpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include #include @@ -54,8 +55,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::opencl::handle_base { /// List of commands in the command-buffer. 
std::vector> CommandHandles; - /// Object reference count - std::atomic_uint32_t RefCount; /// Track last submission of the command-buffer cl_event LastSubmission; @@ -67,11 +66,12 @@ struct ur_exp_command_buffer_handle_t_ : ur::opencl::handle_base { : handle_base(), hInternalQueue(hQueue), hContext(hContext), hDevice(hDevice), CLCommandBuffer(CLCommandBuffer), IsUpdatable(IsUpdatable), IsInOrder(IsInOrder), IsFinalized(false), - RefCount(0), LastSubmission(nullptr) {} + LastSubmission(nullptr) {} ~ur_exp_command_buffer_handle_t_(); - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/opencl/context.cpp b/unified-runtime/source/adapters/opencl/context.cpp index 9f91fabba87c4..232353643d83b 100644 --- a/unified-runtime/source/adapters/opencl/context.cpp +++ b/unified-runtime/source/adapters/opencl/context.cpp @@ -108,7 +108,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, return ReturnValue(&hContext->Devices[0], hContext->DeviceCount); } case UR_CONTEXT_INFO_REFERENCE_COUNT: { - return ReturnValue(hContext->getReferenceCount()); + return ReturnValue(hContext->getRefCounter().getCount()); } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; @@ -120,7 +120,7 @@ urContextRelease(ur_context_handle_t hContext) { static std::mutex contextReleaseMutex; std::lock_guard lock(contextReleaseMutex); - if (hContext->decrementReferenceCount() == 0) { + if (hContext->getRefCounter().decrement() == 0) { delete hContext; } @@ -129,7 +129,7 @@ urContextRelease(ur_context_handle_t hContext) { UR_APIEXPORT ur_result_t UR_APICALL urContextRetain(ur_context_handle_t hContext) { - hContext->incrementReferenceCount(); + 
hContext->getRefCounter().increment(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/opencl/context.hpp b/unified-runtime/source/adapters/opencl/context.hpp index db35222488e67..150e342a1444a 100644 --- a/unified-runtime/source/adapters/opencl/context.hpp +++ b/unified-runtime/source/adapters/opencl/context.hpp @@ -11,6 +11,7 @@ #include "adapter.hpp" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "device.hpp" #include @@ -20,7 +21,6 @@ struct ur_context_handle_t_ : ur::opencl::handle_base { native_type CLContext; std::vector Devices; uint32_t DeviceCount; - std::atomic RefCount = 0; bool IsNativeHandleOwned = true; ur_context_handle_t_(native_type Ctx, uint32_t DevCount, @@ -33,14 +33,9 @@ struct ur_context_handle_t_ : ur::opencl::handle_base { // The context retains a reference to the adapter so it can clear the // function ptr cache on destruction urAdapterRetain(ur::cl::getAdapter()); - RefCount = 1; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } static ur_result_t makeWithNative(native_type Ctx, uint32_t DevCount, const ur_device_handle_t *phDevices, @@ -60,4 +55,7 @@ struct ur_context_handle_t_ : ur::opencl::handle_base { clReleaseContext(CLContext); } } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/opencl/device.cpp b/unified-runtime/source/adapters/opencl/device.cpp index 1dbde401219ac..d68f76ab316ad 100644 --- a/unified-runtime/source/adapters/opencl/device.cpp +++ b/unified-runtime/source/adapters/opencl/device.cpp @@ -1019,7 +1019,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return UR_RESULT_SUCCESS; } case UR_DEVICE_INFO_REFERENCE_COUNT: { - return 
ReturnValue(hDevice->getReferenceCount()); + return ReturnValue(hDevice->getRefCounter().getCount()); } case UR_DEVICE_INFO_PARTITION_MAX_SUB_DEVICES: { CL_RETURN_ON_FAILURE(clGetDeviceInfo(hDevice->CLDevice, @@ -1567,7 +1567,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDevicePartition( // Root devices ref count are unchanged through out the program lifetime. UR_APIEXPORT ur_result_t UR_APICALL urDeviceRetain(ur_device_handle_t hDevice) { if (hDevice->ParentDevice) { - hDevice->incrementReferenceCount(); + hDevice->getRefCounter().increment(); } return UR_RESULT_SUCCESS; @@ -1577,7 +1577,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceRetain(ur_device_handle_t hDevice) { UR_APIEXPORT ur_result_t UR_APICALL urDeviceRelease(ur_device_handle_t hDevice) { if (hDevice->ParentDevice) { - if (hDevice->decrementReferenceCount() == 0) { + if (hDevice->getRefCounter().decrement() == 0) { delete hDevice; } } diff --git a/unified-runtime/source/adapters/opencl/device.hpp b/unified-runtime/source/adapters/opencl/device.hpp index 1c100535c6643..56df610b4d19b 100644 --- a/unified-runtime/source/adapters/opencl/device.hpp +++ b/unified-runtime/source/adapters/opencl/device.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "device.hpp" #include "platform.hpp" @@ -51,11 +52,7 @@ struct ur_device_handle_t_ : ur::opencl::handle_base { } } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } ur_result_t getDeviceVersion(oclv::OpenCLVersion &Version) { size_t DevVerSize = 0; @@ -114,4 +111,7 @@ struct ur_device_handle_t_ : ur::opencl::handle_base { return UR_RESULT_SUCCESS; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/opencl/event.cpp 
b/unified-runtime/source/adapters/opencl/event.cpp index bb13f297b60bf..237623dac516d 100644 --- a/unified-runtime/source/adapters/opencl/event.cpp +++ b/unified-runtime/source/adapters/opencl/event.cpp @@ -136,14 +136,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetNativeHandle( } UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { - if (hEvent->decrementReferenceCount() == 0) { + if (hEvent->getRefCounter().decrement() == 0) { delete hEvent; } return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { - hEvent->incrementReferenceCount(); + hEvent->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -188,7 +188,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, return ReturnValue(hEvent->Queue); } case UR_EVENT_INFO_REFERENCE_COUNT: { - return ReturnValue(hEvent->getReferenceCount()); + return ReturnValue(hEvent->getRefCounter().getCount()); } default: { size_t CheckPropSize = 0; diff --git a/unified-runtime/source/adapters/opencl/event.hpp b/unified-runtime/source/adapters/opencl/event.hpp index dacddf3d8e1f6..def71ae561cd1 100644 --- a/unified-runtime/source/adapters/opencl/event.hpp +++ b/unified-runtime/source/adapters/opencl/event.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "queue.hpp" #include @@ -19,13 +20,11 @@ struct ur_event_handle_t_ : ur::opencl::handle_base { native_type CLEvent; ur_context_handle_t Context; ur_queue_handle_t Queue; - std::atomic RefCount = 0; bool IsNativeHandleOwned = true; ur_event_handle_t_(native_type Event, ur_context_handle_t Ctx, ur_queue_handle_t Queue) : handle_base(), CLEvent(Event), Context(Ctx), Queue(Queue) { - RefCount = 1; urContextRetain(Context); if (Queue) { urQueueRetain(Queue); @@ -42,11 +41,7 @@ struct ur_event_handle_t_ : ur::opencl::handle_base { } } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t 
decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } ur_result_t ensureQueue() { if (!Queue) { @@ -60,6 +55,9 @@ struct ur_event_handle_t_ : ur::opencl::handle_base { return UR_RESULT_SUCCESS; } + +private: + UR_ReferenceCounter RefCounter; }; inline cl_event *ifUrEvent(ur_event_handle_t *ReturnedEvent, cl_event &Event) { diff --git a/unified-runtime/source/adapters/opencl/kernel.cpp b/unified-runtime/source/adapters/opencl/kernel.cpp index f0c22a99749a9..c9062a7bb3a84 100644 --- a/unified-runtime/source/adapters/opencl/kernel.cpp +++ b/unified-runtime/source/adapters/opencl/kernel.cpp @@ -152,7 +152,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, return ReturnValue(hKernel->Context); } case UR_KERNEL_INFO_REFERENCE_COUNT: { - return ReturnValue(hKernel->getReferenceCount()); + return ReturnValue(hKernel->getRefCounter().getCount()); } default: { size_t CheckPropSize = 0; @@ -343,13 +343,13 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, } UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) { - hKernel->incrementReferenceCount(); + hKernel->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(ur_kernel_handle_t hKernel) { - if (hKernel->decrementReferenceCount() == 0) { + if (hKernel->getRefCounter().decrement() == 0) { delete hKernel; } return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/opencl/kernel.hpp b/unified-runtime/source/adapters/opencl/kernel.hpp index ef73e47c1e319..2f14668405e91 100644 --- a/unified-runtime/source/adapters/opencl/kernel.hpp +++ b/unified-runtime/source/adapters/opencl/kernel.hpp @@ -11,6 +11,7 @@ #include "adapter.hpp" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include "program.hpp" 
@@ -21,14 +22,12 @@ struct ur_kernel_handle_t_ : ur::opencl::handle_base { native_type CLKernel; ur_program_handle_t Program; ur_context_handle_t Context; - std::atomic RefCount = 0; bool IsNativeHandleOwned = true; clSetKernelArgMemPointerINTEL_fn clSetKernelArgMemPointerINTEL = nullptr; ur_kernel_handle_t_(native_type Kernel, ur_program_handle_t Program, ur_context_handle_t Context) : handle_base(), CLKernel(Kernel), Program(Program), Context(Context) { - RefCount = 1; urProgramRetain(Program); urContextRetain(Context); @@ -46,14 +45,13 @@ struct ur_kernel_handle_t_ : ur::opencl::handle_base { } } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } static ur_result_t makeWithNative(native_type NativeKernel, ur_program_handle_t Program, ur_context_handle_t Context, ur_kernel_handle_t &Kernel); + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/opencl/memory.cpp b/unified-runtime/source/adapters/opencl/memory.cpp index 62698e6105520..1212b8e729a93 100644 --- a/unified-runtime/source/adapters/opencl/memory.cpp +++ b/unified-runtime/source/adapters/opencl/memory.cpp @@ -521,7 +521,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory, return ReturnValue(hMemory->Context); } case UR_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hMemory->getReferenceCount()); + return ReturnValue(hMemory->getRefCounter().getCount()); } default: { size_t CheckPropSize = 0; @@ -569,12 +569,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageGetInfo(ur_mem_handle_t hMemory, } UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) { - hMem->incrementReferenceCount(); + hMem->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL 
urMemRelease(ur_mem_handle_t hMem) { - if (hMem->decrementReferenceCount() == 0) { + if (hMem->getRefCounter().decrement() == 0) { delete hMem; } return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/opencl/memory.hpp b/unified-runtime/source/adapters/opencl/memory.hpp index a0f8410e3df03..592f3941c8ac6 100644 --- a/unified-runtime/source/adapters/opencl/memory.hpp +++ b/unified-runtime/source/adapters/opencl/memory.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include @@ -18,12 +19,10 @@ struct ur_mem_handle_t_ : ur::opencl::handle_base { using native_type = cl_mem; native_type CLMemory; ur_context_handle_t Context; - std::atomic RefCount = 0; bool IsNativeHandleOwned = true; ur_mem_handle_t_(native_type Mem, ur_context_handle_t Ctx) : handle_base(), CLMemory(Mem), Context(Ctx) { - RefCount = 1; urContextRetain(Context); } @@ -34,13 +33,12 @@ struct ur_mem_handle_t_ : ur::opencl::handle_base { } } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } static ur_result_t makeWithNative(native_type NativeMem, ur_context_handle_t Ctx, ur_mem_handle_t &Mem); + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/opencl/program.cpp b/unified-runtime/source/adapters/opencl/program.cpp index 1c3a5e45b3bd5..715015a96940c 100644 --- a/unified-runtime/source/adapters/opencl/program.cpp +++ b/unified-runtime/source/adapters/opencl/program.cpp @@ -223,7 +223,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, return ReturnValue(hProgram->Devices.data(), hProgram->NumDevices); } case UR_PROGRAM_INFO_REFERENCE_COUNT: { - return ReturnValue(hProgram->getReferenceCount()); + return 
ReturnValue(hProgram->getRefCounter().getCount()); } default: { size_t CheckPropSize = 0; @@ -383,13 +383,13 @@ urProgramGetBuildInfo(ur_program_handle_t hProgram, ur_device_handle_t hDevice, UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { - hProgram->incrementReferenceCount(); + hProgram->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urProgramRelease(ur_program_handle_t hProgram) { - if (hProgram->decrementReferenceCount() == 0) { + if (hProgram->getRefCounter().decrement() == 0) { delete hProgram; } return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/opencl/program.hpp b/unified-runtime/source/adapters/opencl/program.hpp index 69b3430d2bc3a..8eea84b7cafbb 100644 --- a/unified-runtime/source/adapters/opencl/program.hpp +++ b/unified-runtime/source/adapters/opencl/program.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include @@ -18,7 +19,6 @@ struct ur_program_handle_t_ : ur::opencl::handle_base { using native_type = cl_program; native_type CLProgram; ur_context_handle_t Context; - std::atomic RefCount = 0; bool IsNativeHandleOwned = true; uint32_t NumDevices = 0; std::vector Devices; @@ -26,7 +26,6 @@ struct ur_program_handle_t_ : ur::opencl::handle_base { ur_program_handle_t_(native_type Prog, ur_context_handle_t Ctx, uint32_t NumDevices, ur_device_handle_t *Devs) : handle_base(), CLProgram(Prog), Context(Ctx), NumDevices(NumDevices) { - RefCount = 1; urContextRetain(Context); for (uint32_t i = 0; i < NumDevices; i++) { Devices.push_back(Devs[i]); @@ -40,13 +39,12 @@ struct ur_program_handle_t_ : ur::opencl::handle_base { } } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return 
RefCounter; } static ur_result_t makeWithNative(native_type NativeProg, ur_context_handle_t Context, ur_program_handle_t &Program); + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/opencl/queue.cpp b/unified-runtime/source/adapters/opencl/queue.cpp index 040186e701647..4c63fca65c27b 100644 --- a/unified-runtime/source/adapters/opencl/queue.cpp +++ b/unified-runtime/source/adapters/opencl/queue.cpp @@ -234,7 +234,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, return ReturnValue(mapCLQueuePropsToUR(QueueProperties)); } case UR_QUEUE_INFO_REFERENCE_COUNT: { - return ReturnValue(hQueue->getReferenceCount()); + return ReturnValue(hQueue->getRefCounter().getCount()); } default: { size_t CheckPropSize = 0; @@ -289,12 +289,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueFlush(ur_queue_handle_t hQueue) { } UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) { - hQueue->incrementReferenceCount(); + hQueue->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) { - if (hQueue->decrementReferenceCount() == 0) { + if (hQueue->getRefCounter().decrement() == 0) { delete hQueue; } return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/opencl/queue.hpp b/unified-runtime/source/adapters/opencl/queue.hpp index 5fd429d7b0518..0a7a087891695 100644 --- a/unified-runtime/source/adapters/opencl/queue.hpp +++ b/unified-runtime/source/adapters/opencl/queue.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include "device.hpp" @@ -22,7 +23,6 @@ struct ur_queue_handle_t_ : ur::opencl::handle_base { ur_device_handle_t Device; // Used to keep a handle to the default queue alive if it is different std::optional DeviceDefault = std::nullopt; - std::atomic RefCount = 0; bool IsNativeHandleOwned = true; // Used to implement 
UR_QUEUE_INFO_EMPTY query bool IsInOrder; @@ -32,7 +32,7 @@ struct ur_queue_handle_t_ : ur::opencl::handle_base { ur_device_handle_t Dev, bool InOrder) : handle_base(), CLQueue(Queue), Context(Ctx), Device(Dev), IsInOrder(InOrder) { - RefCount = 1; + urDeviceRetain(Device); urContextRetain(Context); } @@ -53,11 +53,7 @@ struct ur_queue_handle_t_ : ur::opencl::handle_base { } } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } // Stores last event for in-order queues. Has no effect if queue is Out Of // Order. The last event is used to implement UR_QUEUE_INFO_EMPTY query. @@ -74,4 +70,7 @@ struct ur_queue_handle_t_ : ur::opencl::handle_base { } return UR_RESULT_SUCCESS; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/opencl/sampler.cpp b/unified-runtime/source/adapters/opencl/sampler.cpp index 69f3f5167a986..b011c108ada45 100644 --- a/unified-runtime/source/adapters/opencl/sampler.cpp +++ b/unified-runtime/source/adapters/opencl/sampler.cpp @@ -175,7 +175,7 @@ urSamplerGetInfo(ur_sampler_handle_t hSampler, ur_sampler_info_t propName, return ReturnValue(hSampler->Context); } case UR_SAMPLER_INFO_REFERENCE_COUNT: { - return ReturnValue(hSampler->getReferenceCount()); + return ReturnValue(hSampler->getRefCounter().getCount()); } // ur_bool_t have a size of uint8_t, but cl_bool size have the size of // uint32_t so this adjust UR_SAMPLER_INFO_NORMALIZED_COORDS info to map @@ -221,13 +221,13 @@ urSamplerGetInfo(ur_sampler_handle_t hSampler, ur_sampler_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urSamplerRetain(ur_sampler_handle_t hSampler) { - hSampler->incrementReferenceCount(); + hSampler->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL 
urSamplerRelease(ur_sampler_handle_t hSampler) { - if (hSampler->decrementReferenceCount() == 0) { + if (hSampler->getRefCounter().decrement() == 0) { delete hSampler; } return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/opencl/sampler.hpp b/unified-runtime/source/adapters/opencl/sampler.hpp index b661fe195bcb2..d368a32814e40 100644 --- a/unified-runtime/source/adapters/opencl/sampler.hpp +++ b/unified-runtime/source/adapters/opencl/sampler.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include @@ -17,12 +18,10 @@ struct ur_sampler_handle_t_ : ur::opencl::handle_base { using native_type = cl_sampler; native_type CLSampler; ur_context_handle_t Context; - std::atomic RefCount = 0; bool IsNativeHandleOwned = false; ur_sampler_handle_t_(native_type Sampler, ur_context_handle_t Ctx) : handle_base(), CLSampler(Sampler), Context(Ctx) { - RefCount = 1; urContextRetain(Context); } @@ -33,9 +32,8 @@ struct ur_sampler_handle_t_ : ur::opencl::handle_base { } } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/common/cuda-hip/stream_queue.hpp b/unified-runtime/source/common/cuda-hip/stream_queue.hpp index 0ead67e1d8729..67e6469ebc153 100644 --- a/unified-runtime/source/common/cuda-hip/stream_queue.hpp +++ b/unified-runtime/source/common/cuda-hip/stream_queue.hpp @@ -15,6 +15,8 @@ #include #include +#include "common/ur_ref_count.hpp" + using ur_stream_guard = std::unique_lock; /// Generic implementation of an out-of-order UR queue based on in-order @@ -44,7 +46,6 @@ struct stream_queue_t { std::vector TransferAppliedBarrier; ur_context_handle_t_ *Context; ur_device_handle_t_ *Device; - std::atomic_uint32_t 
RefCount{1}; std::atomic_uint32_t EventCount{0}; std::atomic_uint32_t ComputeStreamIndex{0}; std::atomic_uint32_t TransferStreamIndex{0}; @@ -344,11 +345,7 @@ struct stream_queue_t { ur_context_handle_t_ *getContext() const { return Context; }; - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } uint32_t getNextEventId() noexcept { return ++EventCount; } @@ -383,4 +380,7 @@ struct stream_queue_t { native_type getStream() { return q->getThreadLocalStream(); } ~interop_guard() { q->getThreadLocalStream() = native_type{0}; } }; + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/common/ur_ref_counter.hpp b/unified-runtime/source/common/ur_ref_counter.hpp new file mode 100644 index 0000000000000..9ea947ef9f65e --- /dev/null +++ b/unified-runtime/source/common/ur_ref_counter.hpp @@ -0,0 +1,28 @@ +/* + * + * Copyright (C) 2025 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM + * Exceptions. See LICENSE.TXT + * + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + */ +#ifndef UR_REF_COUNTER_HPP +#define UR_REF_COUNTER_HPP 1 + +#include +#include + +class UR_ReferenceCounter { +public: + uint32_t getCount() const noexcept { return Count.load(); } + uint32_t increment() { return ++Count; } + uint32_t decrement() { return --Count; } + void reset() { Count = 1; } + +private: + std::atomic_uint32_t Count{1}; +}; + +#endif // UR_REF_COUNTER_HPP From f47f9b52a1c0c8b9e19854173d89f022a5caadfa Mon Sep 17 00:00:00 2001 From: Martin Morrison-Grant Date: Thu, 5 Jun 2025 10:45:10 +0100 Subject: [PATCH 2/5] Remove mutex from OpenCL urContextRelease - not required. 
--- unified-runtime/source/adapters/hip/event.hpp | 2 +- unified-runtime/source/adapters/level_zero/event.hpp | 4 +++- unified-runtime/source/adapters/level_zero/queue.hpp | 4 +++- unified-runtime/source/common/ur_ref_counter.hpp | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/unified-runtime/source/adapters/hip/event.hpp b/unified-runtime/source/adapters/hip/event.hpp index 6b0c1fb40f740..f0f4c262c3707 100644 --- a/unified-runtime/source/adapters/hip/event.hpp +++ b/unified-runtime/source/adapters/hip/event.hpp @@ -10,8 +10,8 @@ #pragma once #include "common.hpp" -#include "queue.hpp" #include "common/ur_ref_counter.hpp" +#include "queue.hpp" /// UR Event mapping to hipEvent_t struct ur_event_handle_t_ : ur::hip::handle_base { diff --git a/unified-runtime/source/adapters/level_zero/event.hpp b/unified-runtime/source/adapters/level_zero/event.hpp index df56289c8bdef..e0dc1b0d86aa2 100644 --- a/unified-runtime/source/adapters/level_zero/event.hpp +++ b/unified-runtime/source/adapters/level_zero/event.hpp @@ -250,7 +250,9 @@ struct ur_event_handle_t_ : ur_object { ur_event_handle_t OriginAllocEvent = nullptr; UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } - UR_ReferenceCounter &getRefCounterExternal() noexcept { return RefCounterExternal; } + UR_ReferenceCounter &getRefCounterExternal() noexcept { + return RefCounterExternal; + } private: UR_ReferenceCounter RefCounter; diff --git a/unified-runtime/source/adapters/level_zero/queue.hpp b/unified-runtime/source/adapters/level_zero/queue.hpp index 664d17497a01d..467a4aab7baeb 100644 --- a/unified-runtime/source/adapters/level_zero/queue.hpp +++ b/unified-runtime/source/adapters/level_zero/queue.hpp @@ -684,7 +684,9 @@ struct ur_queue_handle_t_ : ur_object { ur_queue_handle_t_ *UnifiedHandle; UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } - UR_ReferenceCounter &getRefCounterExternal() noexcept { return RefCounterExternal; } + UR_ReferenceCounter 
&getRefCounterExternal() noexcept { + return RefCounterExternal; + } private: UR_ReferenceCounter RefCounter; diff --git a/unified-runtime/source/common/ur_ref_counter.hpp b/unified-runtime/source/common/ur_ref_counter.hpp index 9ea947ef9f65e..fbb47642fcc5f 100644 --- a/unified-runtime/source/common/ur_ref_counter.hpp +++ b/unified-runtime/source/common/ur_ref_counter.hpp @@ -11,8 +11,8 @@ #ifndef UR_REF_COUNTER_HPP #define UR_REF_COUNTER_HPP 1 -#include #include +#include class UR_ReferenceCounter { public: From 9ce413ceab1d7ced61c580a9bde0a2285f7b4114 Mon Sep 17 00:00:00 2001 From: Martin Morrison-Grant Date: Thu, 5 Jun 2025 12:18:21 +0100 Subject: [PATCH 3/5] fix include and remove cl context mutex --- unified-runtime/source/adapters/opencl/context.cpp | 3 --- unified-runtime/source/common/cuda-hip/stream_queue.hpp | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/unified-runtime/source/adapters/opencl/context.cpp b/unified-runtime/source/adapters/opencl/context.cpp index 232353643d83b..a9ade3c142743 100644 --- a/unified-runtime/source/adapters/opencl/context.cpp +++ b/unified-runtime/source/adapters/opencl/context.cpp @@ -117,9 +117,6 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urContextRelease(ur_context_handle_t hContext) { - static std::mutex contextReleaseMutex; - - std::lock_guard lock(contextReleaseMutex); if (hContext->getRefCounter().decrement() == 0) { delete hContext; } diff --git a/unified-runtime/source/common/cuda-hip/stream_queue.hpp b/unified-runtime/source/common/cuda-hip/stream_queue.hpp index 67e6469ebc153..6d431ffe55c34 100644 --- a/unified-runtime/source/common/cuda-hip/stream_queue.hpp +++ b/unified-runtime/source/common/cuda-hip/stream_queue.hpp @@ -15,7 +15,7 @@ #include #include -#include "common/ur_ref_count.hpp" +#include "common/ur_ref_counter.hpp" using ur_stream_guard = std::unique_lock; From b3fbf7262b51a4235a0282ece8309ad4b1a0b1ad Mon Sep 17 
00:00:00 2001 From: Martin Morrison-Grant Date: Thu, 5 Jun 2025 12:42:20 +0100 Subject: [PATCH 4/5] Use new ref count class in sanitizer layers. Fix some compile errors. --- .../source/adapters/hip/kernel.hpp | 2 +- .../source/adapters/hip/physical_mem.hpp | 4 +--- .../source/adapters/level_zero/event.hpp | 6 +++--- .../source/adapters/level_zero/queue.hpp | 4 ++-- .../source/adapters/offload/common.hpp | 4 ---- .../source/adapters/offload/event.cpp | 6 +++--- .../source/adapters/offload/event.hpp | 8 +++++++- .../source/adapters/opencl/device.hpp | 2 -- .../layers/sanitizer/asan/asan_buffer.hpp | 6 +++++- .../loader/layers/sanitizer/asan/asan_ddi.cpp | 16 +++++++-------- .../sanitizer/asan/asan_interceptor.hpp | 20 +++++++++++++++---- .../layers/sanitizer/msan/msan_buffer.hpp | 8 ++++++-- .../loader/layers/sanitizer/msan/msan_ddi.cpp | 17 ++++++++-------- .../sanitizer/msan/msan_interceptor.hpp | 19 +++++++++++++++--- .../layers/sanitizer/tsan/tsan_buffer.hpp | 8 ++++++-- .../loader/layers/sanitizer/tsan/tsan_ddi.cpp | 12 +++++------ .../sanitizer/tsan/tsan_interceptor.hpp | 15 +++++++++++--- 17 files changed, 101 insertions(+), 56 deletions(-) diff --git a/unified-runtime/source/adapters/hip/kernel.hpp b/unified-runtime/source/adapters/hip/kernel.hpp index fa8f1ef6f2c2b..826fabad82ef7 100644 --- a/unified-runtime/source/adapters/hip/kernel.hpp +++ b/unified-runtime/source/adapters/hip/kernel.hpp @@ -237,7 +237,7 @@ struct ur_kernel_handle_t_ : ur::hip::handle_base { ur_context_handle_t Ctxt) : handle_base(), Function{Func}, FunctionWithOffsetParam{FuncWithOffsetParam}, Name{Name}, Context{Ctxt}, - Program{Program}, RefCount{1} { + Program{Program} { urProgramRetain(Program); urContextRetain(Context); diff --git a/unified-runtime/source/adapters/hip/physical_mem.hpp b/unified-runtime/source/adapters/hip/physical_mem.hpp index 0ecefd55f58f6..8e39849405e0e 100644 --- a/unified-runtime/source/adapters/hip/physical_mem.hpp +++ 
b/unified-runtime/source/adapters/hip/physical_mem.hpp @@ -20,9 +20,7 @@ /// TODO: Implement. /// struct ur_physical_mem_handle_t_ : ur::hip::handle_base { - std::atomic_uint32_t RefCount; - - ur_physical_mem_handle_t_() : handle_base(), RefCount(1) {} + ur_physical_mem_handle_t_() : handle_base() {} UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } diff --git a/unified-runtime/source/adapters/level_zero/event.hpp b/unified-runtime/source/adapters/level_zero/event.hpp index e0dc1b0d86aa2..2ba4913f1ae7d 100644 --- a/unified-runtime/source/adapters/level_zero/event.hpp +++ b/unified-runtime/source/adapters/level_zero/event.hpp @@ -257,11 +257,11 @@ struct ur_event_handle_t_ : ur_object { private: UR_ReferenceCounter RefCounter; - // Besides each PI object keeping a total reference count in - // ur_object::RefCount we keep special track of the event *external* + // Besides each UR object keeping a total reference count in + // RefCounter we keep special track of the event *external* // references. This way we are able to tell when the event is not referenced // externally anymore, i.e. it can't be passed as a dependency event to - // piEnqueue* functions and explicitly waited meaning that we can do some + // urEnqueue* functions and explicitly waited meaning that we can do some // optimizations: // 1. 
For in-order queues we can reset and reuse event even if it was not yet // completed by submitting a reset command to the queue (since there are no diff --git a/unified-runtime/source/adapters/level_zero/queue.hpp b/unified-runtime/source/adapters/level_zero/queue.hpp index 467a4aab7baeb..fe62321e9d328 100644 --- a/unified-runtime/source/adapters/level_zero/queue.hpp +++ b/unified-runtime/source/adapters/level_zero/queue.hpp @@ -691,8 +691,8 @@ struct ur_queue_handle_t_ : ur_object { private: UR_ReferenceCounter RefCounter; - // Besides each PI object keeping a total reference count in - // ur_object::RefCount we keep special track of the queue *external* + // Besides each UR object keeping a total reference count in + // RefCounter we keep special track of the queue *external* // references. This way we are able to tell when the queue is being finished // externally, and can wait for internal references to complete, and do proper // cleanup of the queue. diff --git a/unified-runtime/source/adapters/offload/common.hpp b/unified-runtime/source/adapters/offload/common.hpp index e24bab3fef5ef..fdcb1f12cd216 100644 --- a/unified-runtime/source/adapters/offload/common.hpp +++ b/unified-runtime/source/adapters/offload/common.hpp @@ -19,7 +19,3 @@ struct ddi_getter { }; using handle_base = ur::handle_base; } // namespace ur::offload - -struct RefCounted : ur::offload::handle_base { - std::atomic_uint32_t RefCount = 1; -}; diff --git a/unified-runtime/source/adapters/offload/event.cpp b/unified-runtime/source/adapters/offload/event.cpp index 5668ea8b50e63..6f9f28612f2df 100644 --- a/unified-runtime/source/adapters/offload/event.cpp +++ b/unified-runtime/source/adapters/offload/event.cpp @@ -23,7 +23,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hKernel, switch (propName) { case UR_EVENT_INFO_REFERENCE_COUNT: - return ReturnValue(hKernel->RefCount.load()); + return ReturnValue(hKernel->getRefCounter().getCount()); default: return 
UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; } @@ -51,13 +51,13 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) { } UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { - hEvent->RefCount++; + hEvent->getRefCounter().increment(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { - if (--hEvent->RefCount == 0) { + if (hEvent->getRefCounter().decrement() == 0) { // There's a small bug in olDestroyEvent that will crash. Leak the event // in the meantime. // auto Res = olDestroyEvent(hEvent->OffloadEvent); diff --git a/unified-runtime/source/adapters/offload/event.hpp b/unified-runtime/source/adapters/offload/event.hpp index 16e0dc649d2ef..c2d98a4b4d48f 100644 --- a/unified-runtime/source/adapters/offload/event.hpp +++ b/unified-runtime/source/adapters/offload/event.hpp @@ -14,7 +14,13 @@ #include #include "common.hpp" +#include "common/ur_ref_counter.hpp" -struct ur_event_handle_t_ : RefCounted { +struct ur_event_handle_t_ { ol_event_handle_t OffloadEvent; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/opencl/device.hpp b/unified-runtime/source/adapters/opencl/device.hpp index 56df610b4d19b..562ef820c1085 100644 --- a/unified-runtime/source/adapters/opencl/device.hpp +++ b/unified-runtime/source/adapters/opencl/device.hpp @@ -20,13 +20,11 @@ struct ur_device_handle_t_ : ur::opencl::handle_base { ur_platform_handle_t Platform; cl_device_type Type = 0; ur_device_handle_t ParentDevice = nullptr; - std::atomic RefCount = 0; bool IsNativeHandleOwned = true; ur_device_handle_t_(native_type Dev, ur_platform_handle_t Plat, ur_device_handle_t Parent) : handle_base(), CLDevice(Dev), Platform(Plat), ParentDevice(Parent) { - RefCount = 1; if (Parent) { Type = Parent->Type; [[maybe_unused]] auto Res = clRetainDevice(CLDevice); diff --git 
a/unified-runtime/source/loader/layers/sanitizer/asan/asan_buffer.hpp b/unified-runtime/source/loader/layers/sanitizer/asan/asan_buffer.hpp index 113454638e9dd..168bac8826459 100644 --- a/unified-runtime/source/loader/layers/sanitizer/asan/asan_buffer.hpp +++ b/unified-runtime/source/loader/layers/sanitizer/asan/asan_buffer.hpp @@ -17,6 +17,7 @@ #include #include +#include "common/ur_ref_counter.hpp" #include "ur/ur.hpp" namespace ur_sanitizer_layer { @@ -68,9 +69,12 @@ struct MemBuffer { std::optional SubBuffer; - std::atomic RefCount = 1; + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } ur_shared_mutex Mutex; + +private: + UR_ReferenceCounter RefCounter; }; ur_result_t EnqueueMemCopyRectHelper( diff --git a/unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp b/unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp index 0b4a64c38a549..0cea171d947c9 100644 --- a/unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp +++ b/unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp @@ -301,7 +301,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( auto ProgramInfo = getAsanInterceptor()->getProgramInfo(hProgram); if (ProgramInfo != nullptr) { - ProgramInfo->RefCount++; + ProgramInfo->getRefCounter().increment(); } return UR_RESULT_SUCCESS; @@ -454,7 +454,7 @@ ur_result_t UR_APICALL urProgramRelease( UR_CALL(pfnProgramRelease(hProgram)); auto ProgramInfo = getAsanInterceptor()->getProgramInfo(hProgram); - if (ProgramInfo != nullptr && --ProgramInfo->RefCount == 0) { + if (ProgramInfo != nullptr && ProgramInfo->getRefCounter().decrement() == 0) { UR_CALL(getAsanInterceptor()->unregisterProgram(hProgram)); UR_CALL(getAsanInterceptor()->eraseProgram(hProgram)); } @@ -608,7 +608,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( auto ContextInfo = getAsanInterceptor()->getContextInfo(hContext); UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - ContextInfo->RefCount++; + 
ContextInfo->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -630,7 +630,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( auto ContextInfo = getAsanInterceptor()->getContextInfo(hContext); UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - if (--ContextInfo->RefCount == 0) { + if (ContextInfo->getRefCounter().decrement() == 0) { UR_CALL(getAsanInterceptor()->eraseContext(hContext)); } @@ -750,7 +750,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( UR_LOG_L(getContext()->logger, DEBUG, "==== urMemRetain"); if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hMem)) { - MemBuffer->RefCount++; + MemBuffer->getRefCounter().increment(); } else { UR_CALL(pfnRetain(hMem)); } @@ -772,7 +772,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( UR_LOG_L(getContext()->logger, DEBUG, "==== urMemRelease"); if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hMem)) { - if (--MemBuffer->RefCount != 0) { + if (MemBuffer->getRefCounter().decrement() != 0) { return UR_RESULT_SUCCESS; } UR_CALL(MemBuffer->free()); @@ -1425,7 +1425,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( UR_CALL(pfnRetain(hKernel)); auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel); - KernelInfo.RefCount++; + KernelInfo.getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -1444,7 +1444,7 @@ __urdlllocal ur_result_t urKernelRelease( UR_LOG_L(getContext()->logger, DEBUG, "==== urKernelRelease"); auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel); - if (--KernelInfo.RefCount == 0) { + if (KernelInfo.getRefCounter().decrement() == 0) { UR_CALL(getAsanInterceptor()->eraseKernelInfo(hKernel)); } UR_CALL(pfnRelease(hKernel)); diff --git a/unified-runtime/source/loader/layers/sanitizer/asan/asan_interceptor.hpp b/unified-runtime/source/loader/layers/sanitizer/asan/asan_interceptor.hpp index c3f854d317e56..b26e5076a2923 100644 --- a/unified-runtime/source/loader/layers/sanitizer/asan/asan_interceptor.hpp +++ 
b/unified-runtime/source/loader/layers/sanitizer/asan/asan_interceptor.hpp @@ -21,6 +21,7 @@ #include "sanitizer_common/sanitizer_common.hpp" #include "sanitizer_common/sanitizer_options.hpp" #include "ur_sanitizer_layer.hpp" +#include "common/ur_ref_counter.hpp" #include #include @@ -82,7 +83,6 @@ struct QueueInfo { struct KernelInfo { ur_kernel_handle_t Handle; - std::atomic RefCount = 1; // sanitized kernel bool IsInstrumented = false; @@ -107,11 +107,15 @@ struct KernelInfo { getContext()->urDdiTable.Kernel.pfnRelease(Handle); assert(Result == UR_RESULT_SUCCESS); } + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; struct ProgramInfo { ur_program_handle_t Handle; - std::atomic RefCount = 1; // Program is built only once, so we don't need to lock it std::unordered_set> AllocInfoForGlobals; @@ -130,6 +134,11 @@ struct ProgramInfo { } bool isKernelInstrumented(ur_kernel_handle_t Kernel) const; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; struct ContextInfo { @@ -138,8 +147,6 @@ struct ContextInfo { ur_usm_pool_handle_t USMPool{}; std::once_flag PoolInit; - std::atomic RefCount = 1; - std::vector DeviceList; std::unordered_map AllocInfosMap; @@ -163,6 +170,11 @@ struct ContextInfo { } ur_usm_pool_handle_t getUSMPool(); + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; struct AsanRuntimeDataWrapper { diff --git a/unified-runtime/source/loader/layers/sanitizer/msan/msan_buffer.hpp b/unified-runtime/source/loader/layers/sanitizer/msan/msan_buffer.hpp index 9b23ee6e07200..ed58ec75ddf02 100644 --- a/unified-runtime/source/loader/layers/sanitizer/msan/msan_buffer.hpp +++ b/unified-runtime/source/loader/layers/sanitizer/msan/msan_buffer.hpp @@ -17,6 +17,7 @@ #include #include +#include "common/ur_ref_counter.hpp" #include "ur/ur.hpp" namespace 
ur_sanitizer_layer { @@ -68,9 +69,12 @@ struct MemBuffer { std::optional SubBuffer; - std::atomic RefCount = 1; - ur_shared_mutex Mutex; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; ur_result_t EnqueueMemCopyRectHelper( diff --git a/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp b/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp index 13aa868cbf0f0..b3515047f8bd1 100644 --- a/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp +++ b/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp @@ -12,6 +12,7 @@ */ #include "msan_ddi.hpp" +#include "common/ur_ref_counter.hpp" #include "msan_interceptor.hpp" #include "sanitizer_common/sanitizer_utils.hpp" #include "ur_sanitizer_layer.hpp" @@ -248,7 +249,7 @@ urProgramRetain(ur_program_handle_t auto ProgramInfo = getMsanInterceptor()->getProgramInfo(hProgram); UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - ProgramInfo->RefCount++; + ProgramInfo->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -381,7 +382,7 @@ ur_result_t urProgramRelease( auto ProgramInfo = getMsanInterceptor()->getProgramInfo(hProgram); UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - if (--ProgramInfo->RefCount == 0) { + if (ProgramInfo->getRefCounter().decrement() == 0) { UR_CALL(getMsanInterceptor()->unregisterProgram(hProgram)); UR_CALL(getMsanInterceptor()->eraseProgram(hProgram)); } @@ -512,7 +513,7 @@ ur_result_t urContextRetain( auto ContextInfo = getMsanInterceptor()->getContextInfo(hContext); UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - ContextInfo->RefCount++; + ContextInfo->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -530,7 +531,7 @@ ur_result_t urContextRelease( auto ContextInfo = getMsanInterceptor()->getContextInfo(hContext); UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - if (--ContextInfo->RefCount 
== 0) { + if (ContextInfo->getRefCounter().decrement() == 0) { UR_CALL(getMsanInterceptor()->eraseContext(hContext)); } @@ -642,7 +643,7 @@ ur_result_t urMemRetain( UR_LOG_L(getContext()->logger, DEBUG, "==== urMemRetain"); if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hMem)) { - MemBuffer->RefCount++; + MemBuffer->getRefCounter().increment(); } else { UR_CALL(pfnRetain(hMem)); } @@ -660,7 +661,7 @@ ur_result_t urMemRelease( UR_LOG_L(getContext()->logger, DEBUG, "==== urMemRelease"); if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hMem)) { - if (--MemBuffer->RefCount != 0) { + if (MemBuffer->getRefCounter().decrement() != 0) { return UR_RESULT_SUCCESS; } UR_CALL(MemBuffer->free()); @@ -1331,7 +1332,7 @@ ur_result_t urKernelRetain( UR_CALL(pfnRetain(hKernel)); auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel); - KernelInfo.RefCount++; + KernelInfo.getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -1346,7 +1347,7 @@ ur_result_t urKernelRelease( UR_LOG_L(getContext()->logger, DEBUG, "==== urKernelRelease"); auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel); - if (--KernelInfo.RefCount == 0) { + if (KernelInfo.getRefCounter().decrement() == 0) { UR_CALL(getMsanInterceptor()->eraseKernelInfo(hKernel)); } UR_CALL(pfnRelease(hKernel)); diff --git a/unified-runtime/source/loader/layers/sanitizer/msan/msan_interceptor.hpp b/unified-runtime/source/loader/layers/sanitizer/msan/msan_interceptor.hpp index b7c5d274ae76a..35ef9a137b352 100644 --- a/unified-runtime/source/loader/layers/sanitizer/msan/msan_interceptor.hpp +++ b/unified-runtime/source/loader/layers/sanitizer/msan/msan_interceptor.hpp @@ -20,6 +20,7 @@ #include "sanitizer_common/sanitizer_common.hpp" #include "sanitizer_common/sanitizer_options.hpp" #include "ur_sanitizer_layer.hpp" +#include "common/ur_ref_counter.hpp" #include #include @@ -75,7 +76,6 @@ struct QueueInfo { struct KernelInfo { ur_kernel_handle_t Handle; - std::atomic RefCount = 
1; // sanitized kernel bool IsInstrumented = false; @@ -102,11 +102,15 @@ struct KernelInfo { getContext()->urDdiTable.Kernel.pfnRelease(Handle); assert(Result == UR_RESULT_SUCCESS); } + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; struct ProgramInfo { ur_program_handle_t Handle; - std::atomic RefCount = 1; struct KernelMetada { bool CheckLocals; @@ -130,12 +134,16 @@ struct ProgramInfo { bool isKernelInstrumented(ur_kernel_handle_t Kernel) const; const KernelMetada &getKernelMetadata(ur_kernel_handle_t Kernel) const; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; struct ContextInfo { ur_context_handle_t Handle; size_t CleanShadowSize = 1024; - std::atomic RefCount = 1; std::vector DeviceList; @@ -146,6 +154,11 @@ struct ContextInfo { } ~ContextInfo(); + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; struct MsanRuntimeDataWrapper { diff --git a/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_buffer.hpp b/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_buffer.hpp index 9a9f642f7ae2f..2b184033fba7a 100644 --- a/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_buffer.hpp +++ b/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_buffer.hpp @@ -17,6 +17,7 @@ #include #include +#include "common/ur_ref_counter.hpp" #include "ur/ur.hpp" namespace ur_sanitizer_layer { @@ -68,9 +69,12 @@ struct MemBuffer { std::optional SubBuffer; - std::atomic RefCount = 1; - ur_shared_mutex Mutex; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; ur_result_t EnqueueMemCopyRectHelper( diff --git a/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_ddi.cpp b/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_ddi.cpp index f3802f652d614..251e621122e07 100644 
--- a/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_ddi.cpp +++ b/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_ddi.cpp @@ -103,7 +103,7 @@ ur_result_t urContextRetain( UR_LOG_L(getContext()->logger, ERR, "Invalid context"); return UR_RESULT_ERROR_INVALID_CONTEXT; } - ContextInfo->RefCount++; + ContextInfo->getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -123,7 +123,7 @@ ur_result_t urContextRelease( return UR_RESULT_ERROR_INVALID_CONTEXT; } - if (--ContextInfo->RefCount == 0) { + if (ContextInfo->getRefCounter().decrement() == 0) { UR_CALL(getTsanInterceptor()->eraseContext(hContext)); } @@ -290,7 +290,7 @@ ur_result_t urMemRetain( UR_LOG_L(getContext()->logger, DEBUG, "==== urMemRetain"); if (auto MemBuffer = getTsanInterceptor()->getMemBuffer(hMem)) { - MemBuffer->RefCount++; + MemBuffer->getRefCounter().increment(); } else { UR_CALL(getContext()->urDdiTable.Mem.pfnRetain(hMem)); } @@ -306,7 +306,7 @@ ur_result_t urMemRelease( UR_LOG_L(getContext()->logger, DEBUG, "==== urMemRelease"); if (auto MemBuffer = getTsanInterceptor()->getMemBuffer(hMem)) { - if (--MemBuffer->RefCount != 0) { + if (MemBuffer->getRefCounter().decrement() != 0) { return UR_RESULT_SUCCESS; } UR_CALL(MemBuffer->free()); @@ -965,7 +965,7 @@ ur_result_t urKernelRetain( UR_CALL(getContext()->urDdiTable.Kernel.pfnRetain(hKernel)); auto &KernelInfo = getTsanInterceptor()->getKernelInfo(hKernel); - KernelInfo.RefCount++; + KernelInfo.getRefCounter().increment(); return UR_RESULT_SUCCESS; } @@ -980,7 +980,7 @@ ur_result_t urKernelRelease( UR_LOG_L(getContext()->logger, DEBUG, "==== urKernelRelease"); auto &KernelInfo = getTsanInterceptor()->getKernelInfo(hKernel); - if (--KernelInfo.RefCount == 0) { + if (KernelInfo.getRefCounter().decrement() == 0) { UR_CALL(getTsanInterceptor()->eraseKernel(hKernel)); } UR_CALL(pfnRelease(hKernel)); diff --git a/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_interceptor.hpp 
b/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_interceptor.hpp index c98df6fb59550..8ba65e659d515 100644 --- a/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_interceptor.hpp +++ b/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_interceptor.hpp @@ -20,6 +20,8 @@ #include "tsan_shadow.hpp" #include "ur_sanitizer_layer.hpp" +#include "common/ur_ref_counter.hpp" + namespace ur_sanitizer_layer { namespace tsan { @@ -44,8 +46,6 @@ struct DeviceInfo { struct ContextInfo { ur_context_handle_t Handle; - std::atomic RefCount = 1; - std::vector DeviceList; ur_shared_mutex AllocInfosMapMutex; @@ -69,6 +69,11 @@ struct ContextInfo { ContextInfo &operator=(const ContextInfo &) = delete; void insertAllocInfo(ur_device_handle_t Device, TsanAllocInfo AI); + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; struct DeviceGlobalInfo { @@ -78,7 +83,6 @@ struct DeviceGlobalInfo { struct KernelInfo { ur_kernel_handle_t Handle = nullptr; - std::atomic RefCount = 1; // lock this mutex if following fields are accessed ur_shared_mutex Mutex; @@ -101,6 +105,11 @@ struct KernelInfo { KernelInfo(const KernelInfo &) = delete; KernelInfo &operator=(const KernelInfo &) = delete; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; struct TsanRuntimeDataWrapper { From dc52dce9f817f538c409c429b4e34927f060cf7a Mon Sep 17 00:00:00 2001 From: Martin Morrison-Grant Date: Thu, 5 Jun 2025 14:34:41 +0100 Subject: [PATCH 5/5] Add to L0v2 and fix compile errors. 
--- unified-runtime/source/adapters/hip/adapter.cpp | 2 +- unified-runtime/source/adapters/hip/memory.cpp | 9 +++++---- unified-runtime/source/adapters/hip/program.cpp | 2 +- .../adapters/level_zero/v2/command_buffer.cpp | 6 +++--- .../adapters/level_zero/v2/command_buffer.hpp | 5 +++++ .../source/adapters/level_zero/v2/context.cpp | 6 +++--- .../source/adapters/level_zero/v2/context.hpp | 6 ++++++ .../source/adapters/level_zero/v2/event.cpp | 6 +++--- .../source/adapters/level_zero/v2/event.hpp | 5 +++++ .../source/adapters/level_zero/v2/event_pool.cpp | 4 ++-- .../source/adapters/level_zero/v2/kernel.cpp | 6 +++--- .../source/adapters/level_zero/v2/kernel.hpp | 5 +++++ .../source/adapters/level_zero/v2/memory.cpp | 6 +++--- .../source/adapters/level_zero/v2/memory.hpp | 13 +++++-------- .../source/adapters/level_zero/v2/queue_handle.hpp | 4 ++-- .../level_zero/v2/queue_immediate_in_order.cpp | 6 +++--- .../level_zero/v2/queue_immediate_in_order.hpp | 6 ++++++ .../source/adapters/level_zero/v2/usm.cpp | 6 +++--- .../source/adapters/level_zero/v2/usm.hpp | 5 +++++ .../layers/sanitizer/asan/asan_interceptor.hpp | 2 +- .../layers/sanitizer/msan/msan_interceptor.hpp | 2 +- 21 files changed, 71 insertions(+), 41 deletions(-) diff --git a/unified-runtime/source/adapters/hip/adapter.cpp b/unified-runtime/source/adapters/hip/adapter.cpp index dfb690f60de78..2350e45bc5ece 100644 --- a/unified-runtime/source/adapters/hip/adapter.cpp +++ b/unified-runtime/source/adapters/hip/adapter.cpp @@ -69,7 +69,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - if (--ur::hip::adapter->getRefCounter().decrement() == 0) { + if (ur::hip::adapter->getRefCounter().decrement() == 0) { delete ur::hip::adapter; } diff --git a/unified-runtime/source/adapters/hip/memory.cpp b/unified-runtime/source/adapters/hip/memory.cpp index d1e8789a41927..f13c191b78b08 100644 --- 
a/unified-runtime/source/adapters/hip/memory.cpp +++ b/unified-runtime/source/adapters/hip/memory.cpp @@ -63,7 +63,7 @@ checkSupportedImageChannelType(ur_image_channel_type_t ImageChannelType) { UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { try { // Do nothing if there are other references - if (hMem->decrementReferenceCount() > 0) { + if (hMem->getRefCounter().decrement() > 0) { return UR_RESULT_SUCCESS; } @@ -259,7 +259,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory, return ReturnValue(hMemory->getContext()); } case UR_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hMemory->getReferenceCount()); + return ReturnValue(hMemory->getRefCounter().getCount()); } default: @@ -439,8 +439,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageGetInfo(ur_mem_handle_t hMemory, } UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) { - UR_ASSERT(hMem->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT); - hMem->incrementReferenceCount(); + UR_ASSERT(hMem->getRefCounter().getCount() > 0, + UR_RESULT_ERROR_INVALID_MEM_OBJECT); + hMem->getRefCounter().increment(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/hip/program.cpp b/unified-runtime/source/adapters/hip/program.cpp index 3073343320cad..17596e268dd62 100644 --- a/unified-runtime/source/adapters/hip/program.cpp +++ b/unified-runtime/source/adapters/hip/program.cpp @@ -385,7 +385,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, switch (propName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return ReturnValue(hProgram-- > getRefCounter().getCount()); + return ReturnValue(hProgram->getRefCounter().getCount()); case UR_PROGRAM_INFO_CONTEXT: return ReturnValue(hProgram->Context); case UR_PROGRAM_INFO_NUM_DEVICES: diff --git a/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp index 4281f5e280326..77d55b3ae1f8d 
100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp @@ -237,7 +237,7 @@ urCommandBufferCreateExp(ur_context_handle_t context, ur_device_handle_t device, ur_result_t urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) try { - hCommandBuffer->RefCount.increment(); + hCommandBuffer->getRefCounter().increment(); return UR_RESULT_SUCCESS; } catch (...) { return exceptionToResult(std::current_exception()); @@ -245,7 +245,7 @@ urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) try { ur_result_t urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) try { - if (!hCommandBuffer->RefCount.decrementAndTest()) + if (hCommandBuffer->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; delete hCommandBuffer; @@ -683,7 +683,7 @@ urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer, switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hCommandBuffer->RefCount.load()}); + return ReturnValue(uint32_t{hCommandBuffer->getRefCounter().getCount()}); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp b/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp index 155c8c3b4a3a6..9612123d466ce 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp @@ -12,6 +12,7 @@ #include "../helpers/mutable_helpers.hpp" #include "command_list_manager.hpp" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include "kernel.hpp" #include "lockable.hpp" @@ -59,6 +60,8 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object { const ur_exp_command_buffer_sync_point_t
*pSyncPointWaitList, uint32_t numSyncPointsInWaitList); + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: // Stores all sync points that are created by the command buffer. std::vector syncPoints; @@ -77,4 +80,6 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object { bool isFinalized = false; ur_event_handle_t currentExecution = nullptr; + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/level_zero/v2/context.cpp b/unified-runtime/source/adapters/level_zero/v2/context.cpp index 85774c7897198..1ba311f9c3519 100644 --- a/unified-runtime/source/adapters/level_zero/v2/context.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/context.cpp @@ -80,12 +80,12 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext, defaultUSMPool(this, nullptr), asyncPool(this, nullptr) {} ur_result_t ur_context_handle_t_::retain() { - RefCount.increment(); + RefCounter.increment(); return UR_RESULT_SUCCESS; } ur_result_t ur_context_handle_t_::release() { - if (!RefCount.decrementAndTest()) + if (RefCounter.decrement() != 0) return UR_RESULT_SUCCESS; delete this; @@ -191,7 +191,7 @@ ur_result_t urContextGetInfo(ur_context_handle_t hContext, case UR_CONTEXT_INFO_NUM_DEVICES: return ReturnValue(uint32_t(hContext->getDevices().size())); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hContext->RefCount.load()}); + return ReturnValue(uint32_t{hContext->getRefCounter().getCount()}); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // TODO: this is currently not implemented return ReturnValue(uint8_t{false}); diff --git a/unified-runtime/source/adapters/level_zero/v2/context.hpp b/unified-runtime/source/adapters/level_zero/v2/context.hpp index 25e2ba8e0aa5c..c5fcf360028f8 100644 --- a/unified-runtime/source/adapters/level_zero/v2/context.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/context.hpp @@ -14,6 +14,8 @@ #include "command_list_cache.hpp" #include "common.hpp" +#include
"common/ur_ref_counter.hpp" + #include "event_pool_cache.hpp" #include "usm.hpp" @@ -53,6 +55,8 @@ struct ur_context_handle_t_ : ur_object { // For that the Device or its root devices need to be in the context. bool isValidDevice(ur_device_handle_t Device) const; + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: const v2::raii::ze_context_handle_t hContext; const std::vector hDevices; @@ -69,4 +73,6 @@ struct ur_context_handle_t_ : ur_object { ur_usm_pool_handle_t_ defaultUSMPool; ur_usm_pool_handle_t_ asyncPool; + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/level_zero/v2/event.cpp b/unified-runtime/source/adapters/level_zero/v2/event.cpp index 0b054589c6b81..48634c8b8508f 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/event.cpp @@ -158,12 +158,12 @@ ze_event_handle_t ur_event_handle_t_::getZeEvent() const { } ur_result_t ur_event_handle_t_::retain() { - RefCount.increment(); + RefCounter.increment(); return UR_RESULT_SUCCESS; } ur_result_t ur_event_handle_t_::release() { - if (!RefCount.decrementAndTest()) + if (RefCounter.decrement() != 0) return UR_RESULT_SUCCESS; if (event_pool) { @@ -256,7 +256,7 @@ ur_result_t urEventGetInfo(ur_event_handle_t hEvent, ur_event_info_t propName, } } case UR_EVENT_INFO_REFERENCE_COUNT: { - return returnValue(hEvent->RefCount.load()); + return returnValue(hEvent->getRefCounter().getCount()); } case UR_EVENT_INFO_COMMAND_QUEUE: { auto urQueueHandle = reinterpret_cast(hEvent->getQueue()) - diff --git a/unified-runtime/source/adapters/level_zero/v2/event.hpp b/unified-runtime/source/adapters/level_zero/v2/event.hpp index 6ed0ebccbc561..13487f7a5d1de 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/event.hpp @@ -17,6 +17,7 @@ #include "adapters/level_zero/v2/queue_api.hpp" #include "common.hpp" +#include
"common/ur_ref_counter.hpp" #include "event_provider.hpp" namespace v2 { @@ -111,10 +112,14 @@ struct ur_event_handle_t_ : ur_object { uint64_t getEventStartTimestmap() const; uint64_t getEventEndTimestamp(); + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: ur_event_handle_t_(ur_context_handle_t hContext, event_variant hZeEvent, v2::event_flags_t flags, v2::event_pool *pool); + UR_ReferenceCounter RefCounter; + protected: ur_context_handle_t hContext; diff --git a/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp b/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp index d9639a1a6dcb4..077c84ccbce3a 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp @@ -51,8 +51,8 @@ void event_pool::free(ur_event_handle_t event) { freelist.push_back(event); // The event is still in the pool, so we need to increment the refcount - assert(event->RefCount.load() == 0); - event->RefCount.increment(); + assert(event->getRefCounter().getCount() == 0); + event->getRefCounter().increment(); } event_provider *event_pool::getProvider() const { return provider.get(); } diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp index a2189b57536e8..75591e4880919 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp @@ -97,7 +97,7 @@ ur_kernel_handle_t_::ur_kernel_handle_t_( } ur_result_t ur_kernel_handle_t_::release() { - if (!RefCount.decrementAndTest()) + if (RefCounter.decrement() != 0) return UR_RESULT_SUCCESS; // manually release kernels to allow errors to be propagated @@ -370,7 +370,7 @@ urKernelCreateWithNativeHandle(ur_native_handle_t hNativeKernel, ur_result_t urKernelRetain( /// [in] handle for the Kernel to retain ur_kernel_handle_t hKernel) try { - hKernel->RefCount.increment(); +
hKernel->getRefCounter().increment(); return UR_RESULT_SUCCESS; } catch (...) { return exceptionToResult(std::current_exception()); @@ -634,7 +634,7 @@ ur_result_t urKernelGetInfo(ur_kernel_handle_t hKernel, spills.size()); } case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hKernel->RefCount.load()}); + return ReturnValue(uint32_t{hKernel->getRefCounter().getCount()}); case UR_KERNEL_INFO_ATTRIBUTES: { auto attributes = hKernel->getSourceAttributes(); return ReturnValue(static_cast(attributes.data())); diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp index 0cabb888ac3be..fdf1b4590a995 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp @@ -13,6 +13,7 @@ #include "../program.hpp" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "memory.hpp" struct ur_single_device_kernel_t { @@ -91,6 +92,8 @@ struct ur_kernel_handle_t_ : ur_object { ze_command_list_handle_t cmdList, wait_list_view &waitListView); + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: // Keep the program of the kernel. 
const ur_program_handle_t hProgram; @@ -116,4 +119,6 @@ struct ur_kernel_handle_t_ : ur_object { // pointer to any non-null kernel in deviceKernels ur_single_device_kernel_t *nonEmptyKernel; + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.cpp b/unified-runtime/source/adapters/level_zero/v2/memory.cpp index b1f3829dd6967..e1c4c06a6ada1 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.cpp @@ -671,7 +671,7 @@ ur_result_t urMemGetInfo(ur_mem_handle_t hMem, ur_mem_info_t propName, return returnValue(size_t{hMem->getBuffer()->getSize()}); } case UR_MEM_INFO_REFERENCE_COUNT: { - return returnValue(hMem->getObject()->RefCount.load()); + return returnValue(hMem->getRefCounter().getCount()); } default: { return UR_RESULT_ERROR_INVALID_ENUMERATION; @@ -684,14 +684,14 @@ ur_result_t urMemGetInfo(ur_mem_handle_t hMem, ur_mem_info_t propName, } ur_result_t urMemRetain(ur_mem_handle_t hMem) try { - hMem->getObject()->RefCount.increment(); + hMem->getRefCounter().increment(); return UR_RESULT_SUCCESS; } catch (...) 
{ return exceptionToResult(std::current_exception()); } ur_result_t urMemRelease(ur_mem_handle_t hMem) try { - if (!hMem->getObject()->RefCount.decrementAndTest()) + if (hMem->getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; delete hMem; diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.hpp b/unified-runtime/source/adapters/level_zero/v2/memory.hpp index 0aadd6e85e2a8..7d0d76cfd91a7 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.hpp @@ -19,6 +19,7 @@ #include "../image_common.hpp" #include "command_list_manager.hpp" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "lockable.hpp" using usm_unique_ptr_t = std::unique_ptr>; @@ -280,16 +281,10 @@ struct ur_mem_handle_t_ : ur::handle_base { mem); } - ur_object *getObject() { - return std::visit( - [](auto &&arg) -> ur_object * { - return static_cast(&arg); - }, - mem); - } - bool isImage() const { return std::holds_alternative(mem); } + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: template ur_mem_handle_t_(std::in_place_type_t, Args &&...args) @@ -300,4 +295,6 @@ struct ur_mem_handle_t_ : ur::handle_base { ur_discrete_buffer_handle_t, ur_shared_buffer_handle_t, ur_mem_sub_buffer_t, ur_mem_image_t> mem; + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp b/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp index 75bf4a16faf61..a73f6706da7d9 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp @@ -48,7 +48,7 @@ struct ur_queue_handle_t_ : ur::handle_base { ur_result_t queueRetain() { return std::visit( [](auto &q) { - q.RefCount.increment(); + q.getRefCounter().increment(); return UR_RESULT_SUCCESS; }, queue_data); @@ -57,7 +57,7 @@ struct ur_queue_handle_t_ : ur::handle_base { ur_result_t
queueRelease() { return std::visit( [queueHandle = this](auto &q) { - if (!q.RefCount.decrementAndTest()) + if (q.getRefCounter().decrement() != 0) return UR_RESULT_SUCCESS; delete queueHandle; return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp index 2b0be62dda74f..65094132086d7 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp @@ -113,7 +113,7 @@ ur_queue_immediate_in_order_t::queueGetInfo(ur_queue_info_t propName, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hDevice); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{RefCount.load()}); + return ReturnValue(uint32_t{RefCounter.getCount()}); case UR_QUEUE_INFO_FLAGS: return ReturnValue(flags); case UR_QUEUE_INFO_SIZE: @@ -173,7 +173,7 @@ ur_result_t ur_queue_immediate_in_order_t::queueFinish() { void ur_queue_immediate_in_order_t::recordSubmittedKernel( ur_kernel_handle_t hKernel) { submittedKernels.push_back(hKernel); - hKernel->RefCount.increment(); + hKernel->getRefCounter().increment(); } ur_result_t ur_queue_immediate_in_order_t::queueFlush() { @@ -852,7 +852,7 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueUSMFreeExp( if (internalEvent == nullptr) { // When the output event is used instead of an internal event, we need to // increment the refcount.
- (*phEvent)->RefCount.increment(); + (*phEvent)->getRefCounter().increment(); } if (numWaitEvents > 0) { diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp index 2e3ae8c59caa1..413ebbef52e18 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp @@ -11,6 +11,7 @@ #include "../common.hpp" #include "../device.hpp" +#include "common/ur_ref_counter.hpp" #include "context.hpp" #include "event.hpp" @@ -284,6 +285,11 @@ struct ur_queue_immediate_in_order_t : ur_object, public ur_queue_t_ { const ur_exp_enqueue_native_command_properties_t *, uint32_t, const ur_event_handle_t *, ur_event_handle_t *) override; + + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + +private: + UR_ReferenceCounter RefCounter; }; } // namespace v2 diff --git a/unified-runtime/source/adapters/level_zero/v2/usm.cpp b/unified-runtime/source/adapters/level_zero/v2/usm.cpp index ec246a94cc06b..b5e44ea042260 100644 --- a/unified-runtime/source/adapters/level_zero/v2/usm.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/usm.cpp @@ -313,7 +313,7 @@ ur_result_t urUSMPoolCreate( ur_result_t /// [in] pointer to USM memory pool urUSMPoolRetain(ur_usm_pool_handle_t hPool) try { - hPool->RefCount.increment(); + hPool->getRefCounter().increment(); return UR_RESULT_SUCCESS; } catch (umf_result_t e) { return umf::umf2urResult(e); @@ -324,7 +324,7 @@ urUSMPoolRetain(ur_usm_pool_handle_t hPool) try { ur_result_t /// [in] pointer to USM memory pool urUSMPoolRelease(ur_usm_pool_handle_t hPool) try { - if (hPool->RefCount.decrementAndTest()) { + if (hPool->getRefCounter().decrement() == 0) { delete hPool; } return UR_RESULT_SUCCESS; @@ -349,7 +349,7 @@ ur_result_t urUSMPoolGetInfo( switch (propName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return 
ReturnValue(hPool->RefCount.load()); + return ReturnValue(hPool->getRefCounter().getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(hPool->getContextHandle()); diff --git a/unified-runtime/source/adapters/level_zero/v2/usm.hpp b/unified-runtime/source/adapters/level_zero/v2/usm.hpp index 515712ff22b6f..964494f4c1a91 100644 --- a/unified-runtime/source/adapters/level_zero/v2/usm.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/usm.hpp @@ -14,6 +14,7 @@ #include "../enqueued_pool.hpp" #include "common.hpp" +#include "common/ur_ref_counter.hpp" #include "event.hpp" #include "ur_pool_manager.hpp" @@ -47,9 +48,13 @@ struct ur_usm_pool_handle_t_ : ur_object { void cleanupPools(); void cleanupPoolsForQueue(void *hQueue); + UR_ReferenceCounter &getRefCounter() noexcept { return RefCounter; } + private: ur_context_handle_t hContext; usm::pool_manager poolManager; UsmPool *getPool(const usm::pool_descriptor &desc); + + UR_ReferenceCounter RefCounter; }; diff --git a/unified-runtime/source/loader/layers/sanitizer/asan/asan_interceptor.hpp b/unified-runtime/source/loader/layers/sanitizer/asan/asan_interceptor.hpp index b26e5076a2923..56ff5bd8be775 100644 --- a/unified-runtime/source/loader/layers/sanitizer/asan/asan_interceptor.hpp +++ b/unified-runtime/source/loader/layers/sanitizer/asan/asan_interceptor.hpp @@ -18,10 +18,10 @@ #include "asan_libdevice.hpp" #include "asan_shadow.hpp" #include "asan_statistics.hpp" +#include "common/ur_ref_counter.hpp" #include "sanitizer_common/sanitizer_common.hpp" #include "sanitizer_common/sanitizer_options.hpp" #include "ur_sanitizer_layer.hpp" -#include "common/ur_ref_counter.hpp" #include #include diff --git a/unified-runtime/source/loader/layers/sanitizer/msan/msan_interceptor.hpp b/unified-runtime/source/loader/layers/sanitizer/msan/msan_interceptor.hpp index 35ef9a137b352..9883bfc4d69ce 100644 --- a/unified-runtime/source/loader/layers/sanitizer/msan/msan_interceptor.hpp +++ 
b/unified-runtime/source/loader/layers/sanitizer/msan/msan_interceptor.hpp @@ -13,6 +13,7 @@ #pragma once +#include "common/ur_ref_counter.hpp" #include "msan_allocator.hpp" #include "msan_buffer.hpp" #include "msan_libdevice.hpp" @@ -20,7 +21,6 @@ #include "sanitizer_common/sanitizer_common.hpp" #include "sanitizer_common/sanitizer_options.hpp" #include "ur_sanitizer_layer.hpp" -#include "common/ur_ref_counter.hpp" #include #include