diff --git a/offload/API/APIDefs.td b/offload/API/APIDefs.td new file mode 100644 index 0000000000000..d22da0282f8c3 --- /dev/null +++ b/offload/API/APIDefs.td @@ -0,0 +1,183 @@ +//===-- APIDefs.td - Base definitions for Offload tablegen -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the class definitions used to implement the Offload API, +// as well as helper functions used to help populate relevant records. +// See offload/API/README.md for more detailed documentation. +// +//===----------------------------------------------------------------------===// + + +// Parameter flags +defvar PARAM_IN = 0x1; +defvar PARAM_OUT = 0x2; +defvar PARAM_OPTIONAL = 0x4; + +// Prefix for API naming. This could be hard-coded in the future when a value +// is agreed upon. +defvar PREFIX = "OL"; +defvar prefix = !tolower(PREFIX); + +// Does the type end with '_handle_t'? +class IsHandleType { + // size("_handle_t") == 9 + bit ret = !if(!lt(!size(Type), 9), 0, + !ne(!find(Type, "_handle_t", !sub(!size(Type), 9)), -1)); +} + +// Does the type end with '*'? +class IsPointerType { + bit ret = !ne(!find(Type, "*", !sub(!size(Type), 1)), -1); +} + +class Param Flags = 0> { + string type = Type; + string name = Name; + string desc = Desc; + bits<3> flags = Flags; + bit IsHandle = IsHandleType.ret; + bit IsPointer = IsPointerType.ret; +} + +class Return Conditions = []> { + string value = Value; + list conditions = Conditions; +} + +class ShouldCheckHandle { + bit ret = !and(P.IsHandle, !eq(!and(PARAM_OPTIONAL, P.flags), 0)); +} + +class ShouldCheckPointer { + bit ret = !and(P.IsPointer, !eq(!and(PARAM_OPTIONAL, P.flags), 0)); +} + +// For a list of returns that contains a specific return code, find and append +// new conditions to that return +class AppendConditionsToReturn Returns, string ReturnValue, + list Conditions> { + list ret = + !foreach(Ret, Returns, + !if(!eq(Ret.value, ReturnValue), + Return, Ret)); +} + +// Add null handle checks to a function's return value descriptions +class AddHandleChecksToReturns Params, list Returns> { + list handle_params = + !foreach(P, Params, !if(ShouldCheckHandle

.ret, P.name, "")); + list handle_params_filt = + !filter(param, handle_params, !ne(param, "")); + list handle_param_conds = + !foreach(handle, handle_params_filt, "`NULL == "#handle#"`"); + + // Does the list of returns already contain ERROR_INVALID_NULL_HANDLE? + bit returns_has_inv_handle = !foldl( + 0, Returns, HasErr, Ret, + !or(HasErr, !eq(Ret.value, PREFIX#"_RESULT_ERROR_INVALID_NULL_HANDLE"))); + + list returns_out = !if(returns_has_inv_handle, + AppendConditionsToReturn.ret, + !listconcat(Returns, [Return]) + ); +} + +// Add null pointer checks to a function's return value descriptions +class AddPointerChecksToReturns Params, list Returns> { + list ptr_params = + !foreach(P, Params, !if(ShouldCheckPointer

.ret, P.name, "")); + list ptr_params_filt = !filter(param, ptr_params, !ne(param, "")); + list ptr_param_conds = + !foreach(ptr, ptr_params_filt, "`NULL == "#ptr#"`"); + + // Does the list of returns already contain ERROR_INVALID_NULL_POINTER? + bit returns_has_inv_ptr = !foldl( + 0, Returns, HasErr, Ret, + !or(HasErr, !eq(Ret.value, PREFIX#"_RESULT_ERROR_INVALID_NULL_POINTER"))); + list returns_out = !if(returns_has_inv_ptr, + AppendConditionsToReturn.ret, + !listconcat(Returns, [Return]) + ); +} + +defvar DefaultReturns = [Return, + Return, + Return, + Return]; + +class APIObject { + string name; + string desc; +} + +class Function : APIObject { + string api_class = Class; + list params; + list returns; + list details = []; + list analogues = []; + + list returns_with_def = !listconcat(DefaultReturns, returns); + list all_returns = AddPointerChecksToReturns.returns_out>.returns_out; +} + +class Etor { + string name = Name; + string desc = Desc; +} + +class Enum : APIObject { + // This refers to whether the enumerator descriptions specify a return + // type for functions where this enum may be used as an input type. + // The format is "[$x_some_return_t] Description text" + // (TODO: This is lifted from UR, is it relevant?) + bit is_typed = 0; + + list etors = []; +} + +class StructMember { + string type = Type; + string name = Name; + string desc = Desc; +} + +defvar DefaultPropStructMembers = + [StructMember, + StructMember<"void*", "pNext", "pointer to extension-specific structure">]; + +class StructHasInheritedMembers { + bit ret = !or(!eq(BaseClass, prefix#"_base_properties_t"), + !eq(BaseClass, prefix#"_base_desc_t")); +} + +class Struct : APIObject { + string base_class = ""; + list members; + list all_members = + !if(StructHasInheritedMembers.ret, + DefaultPropStructMembers, [])#members; +} + +class Typedef : APIObject { string value; } + +class FptrTypedef : APIObject { + list params; + list returns; +} + +class Macro : APIObject { + string value; + + string condition; + string alt_value; +} + +class Handle : APIObject; diff --git a/offload/API/Example.td b/offload/API/Example.td new file mode 100644 index 0000000000000..b78ce63506605 --- /dev/null +++ b/offload/API/Example.td @@ -0,0 +1,490 @@ +//===-- Example.td - Example definitions for Offload -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file serves as an example for the Offload tablegen framework. +// It is NOT an actual representation of the API. It is based off a random +// selection of features from Unified Runtime. +// +//===----------------------------------------------------------------------===// + + +def : Macro { + let name = "OL_MAKE_VERSION( _major, _minor )"; + let desc = "Generates generic API versions"; + let value = "(( _major << 16 )|( _minor & 0x0000ffff))"; +} + +def : Macro { + let name = "OL_MAJOR_VERSION( _ver )"; + let desc = "Extracts API major version"; + let value = "( _ver >> 16 )"; +} + +def : Macro { + let name = "OL_MINOR_VERSION( _ver )"; + let desc = "Extracts API minor version"; + let value = "( _ver & 0x0000ffff )"; +} + +def : Macro { + let name = "OL_APICALL"; + let desc = "Calling convention for all API functions"; + let condition = "defined(_WIN32)"; + let value = "__cdecl"; + let alt_value = ""; +} + +def : Macro { + let name = "OL_APIEXPORT"; + let desc = "Microsoft-specific dllexport storage-class attribute"; + let condition = "defined(_WIN32)"; + let value = "__declspec(dllexport)"; + let alt_value = ""; +} + +def : Macro { + let name = "OL_DLLEXPORT"; + let desc = "Microsoft-specific dllexport storage-class attribute"; + let condition = "defined(_WIN32)"; + let value = "__declspec(dllexport)"; +} + +def : Macro { + let name = "OL_DLLEXPORT"; + let desc = "GCC-specific dllexport storage-class attribute"; + let condition = "__GNUC__ >= 4"; + let value = "__attribute__ ((visibility (\"default\")))"; + let alt_value = ""; +} + +def : Typedef { + let name = "ol_bool_t"; + let value = "uint8_t"; + let desc = "compiler-independent type"; +} + +def : Handle { + let name = "ol_loader_config_handle_t"; + let desc = "Handle of a loader config object"; +} + +def : Handle { + let name = "ol_adapter_handle_t"; + let desc = "Handle of an adapter instance"; +} + +def : Handle { + let name = "ol_platform_handle_t"; + let desc = "Handle of a platform instance"; +} + +def : Handle { + let name = "ol_device_handle_t"; + let desc = "Handle of platform's device object"; +} + +def : Handle { + let name = "ol_context_handle_t"; + let desc = "Handle of context object"; +} + +def : Handle { + let name = "ol_event_handle_t"; + let desc = "Handle of event object"; +} + +def : Handle { + let name = "ol_program_handle_t"; + let desc = "Handle of Program object"; +} + +def : Handle { + let name = "ol_kernel_handle_t"; + let desc = "Handle of program's Kernel object"; +} + +def : Handle { + let name = "ol_queue_handle_t"; + let desc = "Handle of a queue object"; +} + +def : Handle { + let name = "ol_native_handle_t"; + let desc = "Handle of a native object"; +} + +def : Handle { + let name = "ol_sampler_handle_t"; + let desc = "Handle of a Sampler object"; +} + +def : Handle { + let name = "ol_mem_handle_t"; + let desc = "Handle of memory object which can either be buffer or image"; +} + +def : Handle { + let name = "ol_physical_mem_handle_t"; + let desc = "Handle of physical memory object"; +} + +def : Macro { + let name = "OL_BIT( _i )"; + let desc = "Generic macro for enumerator bit masks"; + let value = "( 1 << _i )"; +} + +def : Enum { + let name = "ol_result_t"; + let desc = "Defines Return/Error codes"; + let etors =[ + Etor<"SUCCESS", "Success">, + Etor<"ERROR_INVALID_OPERATION", "Invalid operation">, + Etor<"ERROR_INVALID_QUEUE_PROPERTIES", "Invalid queue properties">, + Etor<"ERROR_INVALID_QUEUE", "Invalid queue">, + Etor<"ERROR_INVALID_VALUE", "Invalid Value">, + Etor<"ERROR_INVALID_CONTEXT", "Invalid context">, + Etor<"ERROR_INVALID_PLATFORM", "Invalid platform">, + Etor<"ERROR_INVALID_BINARY", "Invalid binary">, + Etor<"ERROR_INVALID_PROGRAM", "Invalid program">, + Etor<"ERROR_INVALID_SAMPLER", "Invalid sampler">, + Etor<"ERROR_INVALID_BUFFER_SIZE", "Invalid buffer size">, + Etor<"ERROR_INVALID_MEM_OBJECT", "Invalid memory object">, + Etor<"ERROR_INVALID_EVENT", "Invalid event">, + Etor<"ERROR_INVALID_EVENT_WAIT_LIST", "Returned when the event wait list or the events in the wait list are invalid.">, + Etor<"ERROR_MISALIGNED_SUB_BUFFER_OFFSET", "Misaligned sub buffer offset">, + Etor<"ERROR_INVALID_WORK_GROUP_SIZE", "Invalid work group size">, + Etor<"ERROR_COMPILER_NOT_AVAILABLE", "Compiler not available">, + Etor<"ERROR_PROFILING_INFO_NOT_AVAILABLE", "Profiling info not available">, + Etor<"ERROR_DEVICE_NOT_FOUND", "Device not found">, + Etor<"ERROR_INVALID_DEVICE", "Invalid device">, + Etor<"ERROR_DEVICE_LOST", "Device hung, reset, was removed, or adapter update occurred">, + Etor<"ERROR_DEVICE_REQUIRES_RESET", "Device requires a reset">, + Etor<"ERROR_DEVICE_IN_LOW_POWER_STATE", "Device currently in low power state">, + Etor<"ERROR_DEVICE_PARTITION_FAILED", "Device partitioning failed">, + Etor<"ERROR_INVALID_DEVICE_PARTITION_COUNT", "Invalid counts provided with OL_DEVICE_PARTITION_BY_COUNTS">, + Etor<"ERROR_INVALID_WORK_ITEM_SIZE", "Invalid work item size">, + Etor<"ERROR_INVALID_WORK_DIMENSION", "Invalid work dimension">, + Etor<"ERROR_INVALID_KERNEL_ARGS", "Invalid kernel args">, + Etor<"ERROR_INVALID_KERNEL", "Invalid kernel">, + Etor<"ERROR_INVALID_KERNEL_NAME", "[Validation] kernel name is not found in the program">, + Etor<"ERROR_INVALID_KERNEL_ARGUMENT_INDEX", "[Validation] kernel argument index is not valid for kernel">, + Etor<"ERROR_INVALID_KERNEL_ARGUMENT_SIZE", "[Validation] kernel argument size does not match kernel">, + Etor<"ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE", "[Validation] value of kernel attribute is not valid for the kernel or device">, + Etor<"ERROR_INVALID_IMAGE_SIZE", "Invalid image size">, + Etor<"ERROR_INVALID_IMAGE_FORMAT_DESCRIPTOR", "Invalid image format descriptor">, + Etor<"ERROR_IMAGE_FORMAT_NOT_SUPPORTED", "Image format not supported">, + Etor<"ERROR_MEM_OBJECT_ALLOCATION_FAILURE", "Memory object allocation failure">, + Etor<"ERROR_INVALID_PROGRAM_EXECUTABLE", "Program object parameter is invalid.">, + Etor<"ERROR_UNINITIALIZED", "[Validation] adapter is not initialized or specific entry-point is not implemented">, + Etor<"ERROR_OUT_OF_HOST_MEMORY", "Insufficient host memory to satisfy call">, + Etor<"ERROR_OUT_OF_DEVICE_MEMORY", "Insufficient device memory to satisfy call">, + Etor<"ERROR_OUT_OF_RESOURCES", "Out of resources">, + Etor<"ERROR_PROGRAM_BUILD_FAILURE", "Error occurred when building program, see build log for details">, + Etor<"ERROR_PROGRAM_LINK_FAILURE", "Error occurred when linking programs, see build log for details">, + Etor<"ERROR_UNSUPPORTED_VERSION", "[Validation] generic error code for unsupported versions">, + Etor<"ERROR_UNSUPPORTED_FEATURE", "[Validation] generic error code for unsupported features">, + Etor<"ERROR_INVALID_ARGUMENT", "[Validation] generic error code for invalid arguments">, + Etor<"ERROR_INVALID_NULL_HANDLE", "[Validation] handle argument is not valid">, + Etor<"ERROR_HANDLE_OBJECT_IN_USE", "[Validation] object pointed to by handle still in-use by device">, + Etor<"ERROR_INVALID_NULL_POINTER", "[Validation] pointer argument may not be nullptr">, + Etor<"ERROR_INVALID_SIZE", "[Validation] invalid size or dimensions (e.g., must not be zero, or is out of bounds)">, + Etor<"ERROR_UNSUPPORTED_SIZE", "[Validation] size argument is not supported by the device (e.g., too large)">, + Etor<"ERROR_UNSUPPORTED_ALIGNMENT", "[Validation] alignment argument is not supported by the device (e.g., too small)">, + Etor<"ERROR_INVALID_SYNCHRONIZATION_OBJECT", "[Validation] synchronization object in invalid state">, + Etor<"ERROR_INVALID_ENUMERATION", "[Validation] enumerator argument is not valid">, + Etor<"ERROR_UNSUPPORTED_ENUMERATION", "[Validation] enumerator argument is not supported by the device">, + Etor<"ERROR_UNSUPPORTED_IMAGE_FORMAT", "[Validation] image format is not supported by the device">, + Etor<"ERROR_INVALID_NATIVE_BINARY", "[Validation] native binary is not supported by the device">, + Etor<"ERROR_INVALID_GLOBAL_NAME", "[Validation] global variable is not found in the program">, + Etor<"ERROR_INVALID_FUNCTION_NAME", "[Validation] function name is not found in the program">, + Etor<"ERROR_INVALID_GROUP_SIZE_DIMENSION", "[Validation] group size dimension is not valid for the kernel or device">, + Etor<"ERROR_INVALID_GLOBAL_WIDTH_DIMENSION", "[Validation] global width dimension is not valid for the kernel or device">, + Etor<"ERROR_PROGRAM_UNLINKED", "[Validation] compiled program or program with imports needs to be linked before kernels can be created from it.">, + Etor<"ERROR_OVERLAPPING_REGIONS", "[Validation] copy operations do not support overlapping regions of memory">, + Etor<"ERROR_INVALID_HOST_PTR", "Invalid host pointer">, + Etor<"ERROR_INVALID_USM_SIZE", "Invalid USM size">, + Etor<"ERROR_OBJECT_ALLOCATION_FAILURE", "Objection allocation failure">, + Etor<"ERROR_ADAPTER_SPECIFIC", "An adapter specific warning/error has been reported and can be retrieved via the urPlatformGetLastError entry point.">, + Etor<"ERROR_LAYER_NOT_PRESENT", "A requested layer was not found by the loader.">, + Etor<"ERROR_IN_EVENT_LIST_EXEC_STATUS", "An event in the provided wait list has OL_EVENT_STATUS_ERROR.">, + Etor<"ERROR_UNKNOWN", "Unknown or internal error"> + ]; +} + +def : Struct { + let name = "ol_base_properties_t"; + let desc = "Base for all properties types"; + let members = [ + StructMember<"ol_structure_type_t", "stype", "[in] type of this structure">, + StructMember<"void*", "pNext", "[in,out][optional] pointer to extension-specific structure"> + ]; +} + +def : Struct { + let name = "ol_base_desc_t"; + let desc = "Base for all descriptor types"; + let members = [ + StructMember<"ol_structure_type_t", "stype", "[in] type of this structure">, + StructMember<"const void*", "pNext", "[in][optional] pointer to extension-specific structure"> + ]; +} + +def : Struct { + let name = "ol_rect_offset_t"; + let desc = "3D offset argument passed to buffer rect operations"; + let members = [ + StructMember<"uint64_t", "x", "[in] x offset (bytes)">, + StructMember<"uint64_t", "y", "[in] y offset (scalar)">, + StructMember<"uint64_t", "z", "[in] z offset (scalar)"> + ]; +} + +def : Struct { + let name = "ol_rect_region_t"; + let desc = "3D region argument passed to buffer rect operations"; + let members = [ + StructMember<"uint64_t", "width", "[in] width (bytes)">, + StructMember<"uint64_t", "height", "[in] height (scalar)">, + StructMember<"uint64_t", "depth", "[in] scalar (scalar)"> + ]; +} + +def : Enum { + let name = "ol_queue_info_t"; + let desc = "Query queue info"; + let is_typed = 1; + let etors =[ + Etor<"CONTEXT", "[ol_context_handle_t] context associated with this queue.">, + Etor<"DEVICE", "[ol_device_handle_t] device associated with this queue.">, + Etor<"DEVICE_DEFAULT", "[ol_queue_handle_t] the current default queue of the underlying device.">, + Etor<"FLAGS", "[ol_queue_flags_t] the properties associated with ol_queue_properties_t::flags.">, + Etor<"REFERENCE_COUNT", [{[uint32_t] Reference count of the queue object. +The reference count returned should be considered immediately stale. +It is unsuitable for general use in applications. This feature is provided for identifying memory leaks.}]>, + Etor<"SIZE", "[uint32_t] The size of the queue">, + Etor<"EMPTY", "[ol_bool_t] return true if the queue was empty at the time of the query"> + ]; +} + +def : Enum { + let name = "ol_queue_flags_t"; + let desc = "Queue property flags"; + let etors =[ + Etor<"OUT_OF_ORDER_EXEC_MODE_ENABLE", "Enable/disable out of order execution">, + Etor<"PROFILING_ENABLE", "Enable/disable profiling">, + Etor<"ON_DEVICE", "Is a device queue">, + Etor<"ON_DEVICE_DEFAULT", "Is the default queue for a device">, + Etor<"DISCARD_EVENTS", "Events will be discarded">, + Etor<"PRIORITY_LOW", "Low priority queue">, + Etor<"PRIORITY_HIGH", "High priority queue">, + Etor<"SUBMISSION_BATCHED", "Hint: enqueue and submit in a batch later. No change in queue semantics. Implementation chooses submission mode.">, + Etor<"SUBMISSION_IMMEDIATE", "Hint: enqueue and submit immediately. No change in queue semantics. Implementation chooses submission mode.">, + Etor<"USE_DEFAULT_STREAM", "Use the default stream. Only meaningful for CUDA. Other platforms may ignore this flag.">, + Etor<"SYNC_WITH_DEFAULT_STREAM", "Synchronize with the default stream. Only meaningful for CUDA. Other platforms may ignore this flag."> + ]; +} + +def : Function<"olQueue"> { + let name = "GetInfo"; + let desc = "Query information about a command queue"; + let params = [ + Param<"ol_queue_handle_t", "hQueue", "handle of the queue object", PARAM_IN>, + Param<"ol_queue_info_t", "propName", "name of the queue property to query", PARAM_IN>, + Param<"size_t", "propSize", "size in bytes of the queue property value provided", PARAM_IN>, + Param<"void*", "pPropValue", "[typename(propName, propSize)] value of the queue property", !or(PARAM_OUT, PARAM_OPTIONAL)>, + Param<"size_t*", "pPropSizeRet", "size in bytes returned in queue property value", !or(PARAM_OUT, PARAM_OPTIONAL)> + ]; + let returns = [ + Return<"OL_RESULT_ERROR_UNSUPPORTED_ENUMERATION", [ + "If `propName` is not supported by the adapter." + ]>, + Return<"OL_RESULT_ERROR_INVALID_SIZE", [ + "`propSize == 0 && pPropValue != NULL`", + "If `propSize` is less than the real number of bytes needed to return the info." + ]>, + Return<"OL_RESULT_ERROR_INVALID_NULL_POINTER", [ + "`propSize != 0 && pPropValue == NULL`", + "`pPropValue == NULL && pPropSizeRet == NULL`" + ]>, + Return<"OL_RESULT_ERROR_INVALID_QUEUE">, + Return<"OL_RESULT_ERROR_OUT_OF_HOST_MEMORY">, + Return<"OL_RESULT_ERROR_OUT_OF_RESOURCES"> + ]; +} + +def : Struct { + let name = "ol_queue_properties_t"; + let desc = "Queue creation properties"; + let base_class = "ol_base_properties_t"; + let members = [ + StructMember<"ol_queue_flags_t", "flags", "[in] Bitfield of queue creation flags"> + ]; +} + +def : Struct { + let name = "ol_queue_index_properties_t"; + let desc = "Queue index creation properties"; + let base_class = "ol_base_properties_t"; + let members = [ + StructMember<"uint32_t", "computeIndex", "[in] Specifies the compute index as described in the sycl_ext_intel_queue_index extension."> + ]; +} + +def : Function<"olQueue"> { + let name = "Create"; + let desc = "Create a command queue for a device in a context"; + let details = [ + "See also ol_queue_index_properties_t." + ]; + let params = [ + Param<"ol_context_handle_t", "hContext", "handle of the context object", PARAM_IN>, + Param<"ol_device_handle_t", "hDevice", "handle of the device object", PARAM_IN>, + Param<"const ol_queue_properties_t*", "pProperties", "pointer to queue creation properties.", !or(PARAM_IN, PARAM_OPTIONAL)>, + Param<"ol_queue_handle_t*", "phQueue", "pointer to handle of queue object created", PARAM_OUT> + ]; + let returns = [ + Return<"OL_RESULT_ERROR_INVALID_CONTEXT">, + Return<"OL_RESULT_ERROR_INVALID_DEVICE">, + Return<"OL_RESULT_ERROR_INVALID_QUEUE_PROPERTIES", [ + "`pProperties != NULL && pProperties->flags & OL_QUEUE_FLAG_PRIORITY_HIGH && pProperties->flags & OL_QUEUE_FLAG_PRIORITY_LOW`", + "`pProperties != NULL && pProperties->flags & OL_QUEUE_FLAG_SUBMISSION_BATCHED && pProperties->flags & OL_QUEUE_FLAG_SUBMISSION_IMMEDIATE`" + ]>, + Return<"OL_RESULT_ERROR_OUT_OF_HOST_MEMORY">, + Return<"OL_RESULT_ERROR_OUT_OF_RESOURCES"> + ]; +} + +def : Function<"olQueue"> { + let name = "Retain"; + let desc = "Get a reference to the command queue handle. Increment the command queue's reference count"; + let details = [ + "Useful in library function to retain access to the command queue after the caller released the queue." + ]; + let params = [ + Param<"ol_queue_handle_t", "hQueue", "handle of the queue object to get access", PARAM_IN> + ]; + let returns = [ + Return<"OL_RESULT_ERROR_INVALID_QUEUE">, + Return<"OL_RESULT_ERROR_OUT_OF_HOST_MEMORY">, + Return<"OL_RESULT_ERROR_OUT_OF_RESOURCES"> + ]; +} + +def : Function<"olQueue"> { + let name = "Release"; + let desc = "Decrement the command queue's reference count and delete the command queue if the reference count becomes zero."; + let details = [ + "After the command queue reference count becomes zero and all queued commands in the queue have finished, the queue is deleted.", + "It also performs an implicit flush to issue all previously queued commands in the queue." + ]; + let params = [ + Param<"ol_queue_handle_t", "hQueue", "handle of the queue object to release", PARAM_IN> + ]; + let returns = [ + Return<"OL_RESULT_ERROR_INVALID_QUEUE">, + Return<"OL_RESULT_ERROR_OUT_OF_HOST_MEMORY">, + Return<"OL_RESULT_ERROR_OUT_OF_RESOURCES"> + ]; +} + +def : Struct { + let name = "ol_queue_native_desc_t"; + let desc = "Descriptor for olQueueGetNativeHandle and olQueueCreateWithNativeHandle."; + let base_class = "ol_base_desc_t"; + let members = [ + StructMember<"void*", "pNativeData", "[in][optional] Adapter-specific metadata needed to create the handle."> + ]; +} + +def : Function<"olQueue"> { + let name = "GetNativeHandle"; + let desc = "Return queue native queue handle."; + let details = [ + "Retrieved native handle can be used for direct interaction with the native platform driver.", + "Use interoperability queue extensions to convert native handle to native type.", + "The application may call this function from simultaneous threads for the same context.", + "The implementation of this function should be thread-safe." + ]; + let params = [ + Param<"ol_queue_handle_t", "hQueue", "handle of the queue.", PARAM_IN>, + Param<"ol_queue_native_desc_t*", "pDesc", "pointer to native descriptor", !or(PARAM_IN, PARAM_OPTIONAL)>, + Param<"ol_native_handle_t*", "phNativeQueue", "a pointer to the native handle of the queue.", PARAM_OUT> + ]; + let returns = [ + Return<"OL_RESULT_ERROR_UNSUPPORTED_FEATURE", [ + "If the adapter has no underlying equivalent handle." + ]> + ]; +} + +def : Struct { + let name = "ol_queue_native_properties_t"; + let desc = "Properties for for olQueueCreateWithNativeHandle."; + let base_class = "ol_base_properties_t"; + let members = [ + StructMember<"bool", "isNativeHandleOwned", [{[in] Indicates UR owns the native handle or if it came from an interoperability +operation in the application that asked to not transfer the ownership to +the unified-runtime.}]> + ]; +} + +def : Function<"olQueue"> { + let name = "CreateWithNativeHandle"; + let desc = "Create runtime queue object from native queue handle."; + let details = [ + "Creates runtime queue handle from native driver queue handle.", + "The application may call this function from simultaneous threads for the same context.", + "The implementation of this function should be thread-safe." + ]; + let params = [ + Param<"ol_native_handle_t", "hNativeQueue", "[nocheck] the native handle of the queue.", PARAM_IN>, + Param<"ol_context_handle_t", "hContext", "handle of the context object", PARAM_IN>, + Param<"ol_device_handle_t", "hDevice", "handle of the device object", PARAM_IN>, + Param<"const ol_queue_native_properties_t*", "pProperties", "pointer to native queue properties struct", !or(PARAM_IN, PARAM_OPTIONAL)>, + Param<"ol_queue_handle_t*", "phQueue", "pointer to the handle of the queue object created.", PARAM_OUT> + ]; + let returns = [ + Return<"OL_RESULT_ERROR_UNSUPPORTED_FEATURE", [ + "If the adapter has no underlying equivalent handle." + ]> + ]; +} + +def : Function<"olQueue"> { + let name = "Finish"; + let desc = "Blocks until all previously issued commands to the command queue are finished."; + let details = [ + "Blocks until all previously issued commands to the command queue are issued and completed.", + "olQueueFinish does not return until all enqueued commands have been processed and finished.", + "olQueueFinish acts as a synchronization point." + ]; + let params = [ + Param<"ol_queue_handle_t", "hQueue", "handle of the queue to be finished.", PARAM_IN> + ]; + let returns = [ + Return<"OL_RESULT_ERROR_INVALID_QUEUE">, + Return<"OL_RESULT_ERROR_OUT_OF_HOST_MEMORY"> + ]; +} + +def : Function<"olQueue"> { + let name = "Flush"; + let desc = "Issues all previously enqueued commands in a command queue to the device."; + let details = [ + "Guarantees that all enqueued commands will be issued to the appropriate device.", + "There is no guarantee that they will be completed after olQueueFlush returns." + ]; + let params = [ + Param<"ol_queue_handle_t", "hQueue", "handle of the queue to be flushed.", PARAM_IN> + ]; + let returns = [ + Return<"OL_RESULT_ERROR_INVALID_QUEUE">, + Return<"OL_RESULT_ERROR_OUT_OF_HOST_MEMORY"> + ]; +} diff --git a/offload/API/OffloadAPI.td b/offload/API/OffloadAPI.td new file mode 100644 index 0000000000000..d76dbb966c1f7 --- /dev/null +++ b/offload/API/OffloadAPI.td @@ -0,0 +1,13 @@ +//===-- OffloadAPI.td - Root tablegen file for Offload -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// Always include this file first +include "APIDefs.td" + +// Add API definition files here +include "Example.td" diff --git a/offload/API/README.md b/offload/API/README.md new file mode 100644 index 0000000000000..4bc69b4740e78 --- /dev/null +++ b/offload/API/README.md @@ -0,0 +1,103 @@ +# Offload API definitions + +**Note**: This is a work-in-progress. The intention is for this to serve as a +starting off point for design discussion. It is loosely based on equivalent +tooling in Unified Runtime. + +The Tablegen files in this directory are used to define the Offload API. They +are used with the `offload-tblgen` tool to generate API headers and (stub) +validation code. There are plans to add support for tracing, printing (e.g. +adding `operator<<(std::ostream)` defs to API structs, enums, etc), and test +generation. + +The root file is `OffloadAPI.td` - additional `.td` files can be included in +this file to add them to the API. + +## API Objects +The API consists of a number of objects, which always have a *name* field and +*description* field, and are one of the following types: + +### Function +Represents an API entry point function. Has a list of returns and parameters. +Also has fields for details (representing a bullet-point list of +information about the function that would otherwise be too detailed for the +description), and analogues (equivalent functions in other APIs). + +#### Parameter +Represents a parameter to a function, has *type*, *name*, and *desc* fields. +Also has a *flags* field containing flags representing whether the parameter is +in, out, or optional. + +The *type* field is used to infer if the parameter is a pointer or handle type. +A *handle* type is a pointer to an opaque struct, used to abstract over +plugin-specific implementation details. + +#### Return +A return represents a possible return code from the function, and optionally a +list of conditions in which this value may be returned. The conditions list is +not expected to be exhaustive. A condition is considered free-form text, but +if it is wrapped in \`backticks\` then it is treated as literal code +representing an error condition (e.g. `someParam < 1`). These conditions are +used to automatically create validation checks by the `offload-tblgen` +validation generator. + +Returns are automatically generated for functions with pointer or handle +parameters, so API authors do not need to exhaustively add null checks for +these types of parameters. All functions also get a number of default return +values automatically. + + +### Struct +Represents a struct. Contains a list of members, which each have a *type*, +*name*, and *desc*. + +Also optionally takes a *base_class* field. If this is either of the special +`ol_base_properties_t` or `ol_base_desc_t` structs, then the struct will inherit +members from those structs. The generated struct does **not** use actual C++ +inheritance, but instead explicitly has those members copied in, which preserves +compatibility with C. + +### Enum +Represents a C-style enum. Contains a list of `etor` values. + +All enums automatically get a `_FORCE_UINT32 = 0x7fffffff` value, +which forces the underlying type to be uint32. + +### Handle +Represents a pointer to an opaque struct, as described in the Parameter section. +It does not take any extra fields. + +### Typedef +Represents a typedef, contains only a *value* field. + +### Macro +Represents a C preprocessor `#define`. Contains a *value* field. Optionally +takes a *condition* field, which allows the macro to be conditionally defined, +and an *alt_value* field, which represents the value if the condition is false. + +Macro arguments are presented in the *name* field (e.g. name = `mymacro(arg)`). + +While there may seem little point generating a macro from tablegen, doing this +allows the entire source of the header file to be generated from the tablegen +files, rather than requiring a mix of C source and tablegen. + +## Generation + +### API header +``` +./offload-tblgen -I /offload/API /offload/API/OffloadAPI.td --gen-api +``` +The comments in the generated header are in Doxygen format, although +generating documentation from them hasn't been tested yet. + +### Validation functions +``` +./offload-tblgen -I /offload/API /offload/API/OffloadAPI.td --gen-validation +``` +The functions are partially stubbed and are designed to be used in conjunction +with code that can track live handle references, etc. See the equivalent code +in Unified Runtime for an idea of how this might work. + +### Future Tablegen backends +`RecordTypes.hpp` contains wrappers for all of the API object types, which will +allow more backends to be easily added in future. diff --git a/offload/tools/CMakeLists.txt b/offload/tools/CMakeLists.txt index a850647fbd58e..ba3b9dcb2b89d 100644 --- a/offload/tools/CMakeLists.txt +++ b/offload/tools/CMakeLists.txt @@ -26,3 +26,4 @@ endmacro() add_subdirectory(deviceinfo) add_subdirectory(kernelreplay) +add_subdirectory(offload-tblgen) diff --git a/offload/tools/offload-tblgen/APIGen.cpp b/offload/tools/offload-tblgen/APIGen.cpp new file mode 100644 index 0000000000000..02741c633fd0e --- /dev/null +++ b/offload/tools/offload-tblgen/APIGen.cpp @@ -0,0 +1,169 @@ +//===- offload-tblgen/APIGen.cpp - Tablegen backend for Offload header ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a Tablegen backend that produces the contents of the Offload API +// header. The generated comments are Doxygen compatible. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +#include "GenCommon.hpp" +#include "RecordTypes.hpp" + +using namespace llvm; +using namespace offload::tblgen; + +// Produce a possibly multi-line comment from the input string +static std::string MakeComment(StringRef in) { + std::string out = ""; + size_t LineStart = 0; + size_t LineBreak = 0; + while (LineBreak < in.size()) { + LineBreak = in.find_first_of("\n", LineStart); + if (LineBreak - LineStart <= 1) { + break; + } + out += std::string("\t///< ") + + in.substr(LineStart, LineBreak - LineStart).str() + "\n"; + LineStart = LineBreak + 1; + } + + return out; +} + +static void ProcessHandle(const HandleRec &H, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", H.getDesc()); + OS << formatv("typedef struct {0}_ *{0};\n", H.getName()); +} + +static void ProcessTypedef(const TypedefRec &T, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", T.getDesc()); + OS << formatv("typedef {0} {1};\n", T.getValue(), T.getName()); +} + +static void ProcessMacro(const MacroRec &M, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("#ifndef {0}\n", M.getName()); + if (auto Condition = M.getCondition()) { + OS << formatv("#if {0}\n", *Condition); + } + OS << "/// @brief " << M.getDesc() << "\n"; + OS << formatv("#define {0} {1}\n", M.getName(), M.getValue()); + if (auto AltValue = M.getAltValue()) { + OS << "#else\n"; + OS << formatv("#define {0} {1}\n", M.getName(), *AltValue); + } + if (auto Condition = M.getCondition()) { + OS << formatv("#endif // {0}\n", *Condition); + } + OS << formatv("#endif // {0}\n", M.getName()); +} + +static void ProcessFunction(const FunctionRec &F, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", F.getDesc()); + OS << CommentsBreak; + + OS << "/// @details\n"; + for (auto &Detail : F.getDetails()) { + OS << formatv("/// - {0}\n", Detail); + } + OS << CommentsBreak; + + // Emit analogue remarks + auto Analogues = F.getAnalogues(); + if (!Analogues.empty()) { + OS << "/// @remarks\n/// _Analogues_\n"; + for (auto &Analogue : Analogues) { + OS << formatv("/// - **{0}**\n", Analogue); + } + OS << CommentsBreak; + } + + OS << "/// @returns\n"; + auto Returns = F.getReturns(); + for (auto &Ret : Returns) { + OS << formatv("/// - ::{0}\n", Ret.getValue()); + auto RetConditions = Ret.getConditions(); + for (auto &RetCondition : RetConditions) { + OS << formatv("/// + {0}\n", RetCondition); + } + } + + OS << formatv("{0}_APIEXPORT {1}_result_t {0}_APICALL ", PrefixUpper, + PrefixLower); + OS << F.getFullName(); + OS << "(\n"; + auto Params = F.getParams(); + for (auto &Param : Params) { + OS << " " << Param.getType() << " " << Param.getName(); + if (Param != Params.back()) { + OS << ", "; + } else { + OS << " "; + } + OS << MakeParamComment(Param) << "\n"; + } + OS << ");\n\n"; +} + +static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", Enum.getDesc()); + OS << formatv("typedef enum {0} {{\n", Enum.getName()); + + uint32_t EtorVal = 0; + for (const auto &EnumVal : Enum.getValues()) { + auto Desc = MakeComment(EnumVal.getDesc()); + OS << formatv(" {0}_{1} = {2}, {3}", Enum.getEnumValNamePrefix(), + EnumVal.getName(), EtorVal++, Desc); + } + + // Add force uint32 val + OS << formatv( + " /// @cond\n {0}_FORCE_UINT32 = 0x7fffffff\n /// @endcond\n\n", + Enum.getEnumValNamePrefix()); + + OS << formatv("} {0};\n", Enum.getName()); +} + +static void ProcessStruct(const StructRec &Struct, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", Struct.getDesc()); + OS << formatv("typedef struct {0} {{\n", Struct.getName()); + + for (const auto &Member : Struct.getMembers()) { + OS << formatv(" {0} {1}; {2}", Member.getType(), Member.getName(), + MakeComment(Member.getDesc())); + } + + OS << formatv("} {0};\n\n", Struct.getName()); +} + +void EmitOffloadAPI(RecordKeeper &Records, raw_ostream &OS) { + for (auto *R : Records.getAllDerivedDefinitions("APIObject")) { + if (R->isSubClassOf("Macro")) { + ProcessMacro(MacroRec{R}, OS); + } else if (R->isSubClassOf("Typedef")) { + ProcessTypedef(TypedefRec{R}, OS); + } else if (R->isSubClassOf("Handle")) { + ProcessHandle(HandleRec{R}, OS); + } else if (R->isSubClassOf("Function")) { + ProcessFunction(FunctionRec{R}, OS); + } else if (R->isSubClassOf("Enum")) { + ProcessEnum(EnumRec{R}, OS); + } else if (R->isSubClassOf("Struct")) { + ProcessStruct(StructRec{R}, OS); + } + } +} diff --git a/offload/tools/offload-tblgen/CMakeLists.txt b/offload/tools/offload-tblgen/CMakeLists.txt new file mode 100644 index 0000000000000..753857f5dfac7 --- /dev/null +++ b/offload/tools/offload-tblgen/CMakeLists.txt @@ -0,0 +1,20 @@ +##===----------------------------------------------------------------------===## +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +##===----------------------------------------------------------------------===## +include(TableGen) + +libomptarget_say("Building the offload-tblgen tool") + +add_tablegen(offload-tblgen OPENMP + EXPORT OPENMP + APIGen.cpp + GenCommon.hpp + Generators.hpp + offload-tblgen.cpp + RecordTypes.hpp + ValidationGen.cpp + ) diff --git a/offload/tools/offload-tblgen/GenCommon.hpp b/offload/tools/offload-tblgen/GenCommon.hpp new file mode 100644 index 0000000000000..bb45fa8e6f94e --- /dev/null +++ b/offload/tools/offload-tblgen/GenCommon.hpp @@ -0,0 +1,28 @@ +//===- offload-tblgen/GenCommon.cpp - Common defs for Offload generators --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "RecordTypes.hpp" +#include "llvm/Support/FormatVariadic.h" + +constexpr auto CommentsHeader = R"( +/////////////////////////////////////////////////////////////////////////////// +)"; + +constexpr auto CommentsBreak = "///\n"; + +constexpr auto PrefixLower = "ol"; +constexpr auto PrefixUpper = "OL"; + +static std::string +MakeParamComment(const llvm::offload::tblgen::ParamRec &Param) { + return llvm::formatv("///< {0}{1}{2} {3}", (Param.isIn() ? "[in]" : ""), + (Param.isOut() ? "[out]" : ""), + (Param.isOpt() ? "[optional]" : ""), Param.getDesc()); +} diff --git a/offload/tools/offload-tblgen/Generators.hpp b/offload/tools/offload-tblgen/Generators.hpp new file mode 100644 index 0000000000000..c28ef8f0cc031 --- /dev/null +++ b/offload/tools/offload-tblgen/Generators.hpp @@ -0,0 +1,14 @@ +//===- offload-tblgen/Generators.hpp - Offload generator declarations -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "llvm/TableGen/Record.h" + +void EmitOffloadAPI(llvm::RecordKeeper &Records, llvm::raw_ostream &OS); +void EmitOffloadValidation(llvm::RecordKeeper &Records, llvm::raw_ostream &OS); diff --git a/offload/tools/offload-tblgen/RecordTypes.hpp b/offload/tools/offload-tblgen/RecordTypes.hpp new file mode 100644 index 0000000000000..cf1abf739fab2 --- /dev/null +++ b/offload/tools/offload-tblgen/RecordTypes.hpp @@ -0,0 +1,194 @@ +//===- offload-tblgen/RecordTypes.cpp - Offload record type wrappers -----===-// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include "llvm/TableGen/Record.h" + +namespace llvm { +namespace offload { +namespace tblgen { + +class HandleRec { +public: + explicit HandleRec(Record *rec) : rec(rec) {} + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + +private: + Record *rec; +}; + +class MacroRec { +public: + explicit MacroRec(Record *rec) : rec(rec) {} + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + + std::optional getCondition() const { + return rec->getValueAsOptionalString("condition"); + } + StringRef getValue() const { return rec->getValueAsString("value"); } + std::optional getAltValue() const { + return rec->getValueAsOptionalString("alt_value"); + } + +private: + Record *rec; +}; + +class TypedefRec { +public: + explicit TypedefRec(Record *rec) : rec(rec) {} + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + StringRef getValue() const { return rec->getValueAsString("value"); } + +private: + Record *rec; +}; + +class EnumValueRec { +public: + explicit EnumValueRec(Record *rec) : rec(rec) {} + std::string getName() const { return rec->getValueAsString("name").upper(); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + +private: + Record *rec; +}; + +class EnumRec { +public: + explicit EnumRec(Record *rec) : rec(rec) { + for (auto *Val : rec->getValueAsListOfDefs("etors")) { + vals.emplace_back(EnumValueRec{Val}); + } + } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + const std::vector &getValues() const { return vals; } + + std::string getEnumValNamePrefix() const { + return StringRef(getName().str().substr(0, getName().str().length() - 2)) + .upper(); + } + +private: + Record *rec; + std::vector vals; +}; + +class StructMemberRec { +public: + explicit StructMemberRec(Record *rec) : rec(rec) {} + StringRef getType() const { return rec->getValueAsString("type"); } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + +private: + Record *rec; +}; + +class StructRec { +public: + explicit StructRec(Record *rec) : rec(rec) { + for (auto *Member : rec->getValueAsListOfDefs("all_members")) { + members.emplace_back(StructMemberRec(Member)); + } + } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + std::optional getBaseClass() const { + return rec->getValueAsOptionalString("base_class"); + } + const std::vector &getMembers() const { return members; } + +private: + Record *rec; + std::vector members; +}; + +class ParamRec { +public: + explicit ParamRec(Record *rec) : rec(rec) { + flags = rec->getValueAsBitsInit("flags"); + } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getType() const { return rec->getValueAsString("type"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + bool isIn() const { return dyn_cast(flags->getBit(0))->getValue(); } + bool isOut() const { return dyn_cast(flags->getBit(1))->getValue(); } + bool isOpt() const { return dyn_cast(flags->getBit(2))->getValue(); } + + Record *getRec() const { return rec; } + + // Needed to check whether we're at the back of a vector of params + bool operator!=(const ParamRec &p) const { return rec != p.getRec(); } + +private: + Record *rec; + BitsInit *flags; +}; + +class ReturnRec { +public: + ReturnRec(Record *rec) : rec(rec) {} + StringRef getValue() const { return rec->getValueAsString("value"); } + std::vector getConditions() const { + return rec->getValueAsListOfStrings("conditions"); + } + +private: + Record *rec; +}; + +class FunctionRec { +public: + FunctionRec(Record *rec) : rec(rec) { + for (auto &Ret : rec->getValueAsListOfDefs("all_returns")) + rets.emplace_back(Ret); + for (auto &Param : rec->getValueAsListOfDefs("params")) + params.emplace_back(Param); + } + + std::string getFullName() const { + return rec->getValueAsString("api_class").str() + + rec->getValueAsString("name").str(); + } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getClass() const { return rec->getValueAsString("api_class"); } + const std::vector &getReturns() const { return rets; } + const std::vector &getParams() const { return params; } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + std::vector getDetails() const { + return rec->getValueAsListOfStrings("details"); + } + std::vector getAnalogues() const { + return rec->getValueAsListOfStrings("analogues"); + } + + bool modifiesRefCount() const { + auto Name = rec->getValueAsString("name"); + auto Class = rec->getValueAsString("api_class"); + return (Name == "Create") || (Name == "Retain") || (Name == "Release") || + (Name == "Get" && Class == "Adapter"); + } + +private: + std::vector rets; + std::vector params; + + Record *rec; +}; + +} // namespace tblgen +} // namespace offload +} // namespace llvm diff --git a/offload/tools/offload-tblgen/ValidationGen.cpp b/offload/tools/offload-tblgen/ValidationGen.cpp new file mode 100644 index 0000000000000..00b9410efc96f --- /dev/null +++ b/offload/tools/offload-tblgen/ValidationGen.cpp @@ -0,0 +1,134 @@ +//===- offload-tblgen/APIGen.cpp - Tablegen backend for Offload validation ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a Tablegen backend that produces validation functions for the Offload +// API entry point functions. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Record.h" + +#include "GenCommon.hpp" +#include "RecordTypes.hpp" + +using namespace llvm; +using namespace offload::tblgen; + +static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief Intercept function for {0}\n", F.getFullName()); + // Emit preamble + OS << formatv("{0}_result_t {1}_APICALL val_{2}(\n", PrefixLower, PrefixUpper, + F.getFullName()); + // Emit arguments + std::string ParamNameList = ""; + for (auto &Param : F.getParams()) { + OS << " " << Param.getType() << " " << Param.getName(); + if (Param != F.getParams().back()) { + OS << ", "; + } else { + OS << " "; + } + OS << MakeParamComment(Param) << "\n"; + ParamNameList += Param.getName().str() + ", "; + } + OS << ") {\n"; + + OS << " if (true /*enableParameterValidation*/) {\n"; + + // Emit validation checks + for (const auto &Return : F.getReturns()) { + for (auto &Condition : Return.getConditions()) { + if (Condition.starts_with("`") && Condition.ends_with("`")) { + auto ConditionString = Condition.substr(1, Condition.size() - 2); + OS << formatv(" if ({0}) {{\n", ConditionString); + OS << formatv(" return {0};\n", Return.getValue()); + OS << " }\n\n"; + } + } + } + OS << " }\n\n"; + + auto LifetimeTodoComment = + R"( // TODO: Implement. `refCountContext` is some global object that tracks known + // live handle objects, and logs related errors. + // In UR this is implemented as an unordered_map of handles to structs + // containing the reference count, amongst other details. In this case, a + // handle is invalid if it does not exist in the map. + +)"; + bool EmittedTodo = false; + + // Emit handle lifetime checks + for (auto &Param : F.getParams()) { + if (Param.getType().ends_with("handle_t")) { + // Only add this comment once per function to keep the code size down + if (!EmittedTodo) { + OS << LifetimeTodoComment; + EmittedTodo = true; + } + OS << formatv(" if (true /* enableLifeTimeValidation && " + "!refCountContext.isReferenceValid({0}) */) {{\n", + Param.getName()); + OS << formatv(" // refCountContext.logInvalidReference({0});\n", + Param.getName()); + OS << " }\n\n"; + } + } + + // Perform actual function call + ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); + OS << formatv(" {0}_result_t result = {1}({2});\n\n", PrefixLower, + F.getFullName(), ParamNameList); + + // Handle reference counting for cases where the function modifies the ref + // count of a handle + // * `Create` - initialize a reference count + // * `Retain` - increment a reference count + // * `Release` - decerement a reference count + if (F.modifiesRefCount()) { + OS << formatv(" if ( /*context.enableLeakChecking &&*/ result == " + "{0}_RESULT_SUCCESS) {\n", + PrefixUpper); + + // The refcount context optionally takes a bool specifying whether the + // handle being tracked is an adapter handle, as they are counted + // differently. + // TODO: This behavior is lifted from UR. Offload will likely be different. + auto AdapterHandleArg = (F.getClass() == "Adapter") ? "true" : "false"; + + if (F.getName() == "Create") { + // We only expect one handle output for these types of functions, but loop + // over all params just in case + for (auto &Param : F.getParams()) { + if (Param.isOut()) { + OS << formatv(" // refCountContext.createRefCount(*{0});\n", + Param.getName()); + } + } + // Retain and release functions only have 1 parameter + } else if (F.getName() == "Retain") { + OS << formatv(" // refCountContext.incrementRefCount({0}, {1});\n", + F.getParams().at(0).getName(), AdapterHandleArg); + } else { + OS << formatv(" // refCountContext.decrementRefCount({0}, {1});\n", + F.getParams().at(0).getName(), AdapterHandleArg); + } + OS << " }\n"; + } + + OS << " return result;\n"; + OS << "}\n"; +} + +void EmitOffloadValidation(RecordKeeper &Records, raw_ostream &OS) { + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + EmitValidationFunc(FunctionRec{R}, OS); + } +} diff --git a/offload/tools/offload-tblgen/offload-tblgen.cpp b/offload/tools/offload-tblgen/offload-tblgen.cpp new file mode 100644 index 0000000000000..e3f2590760df7 --- /dev/null +++ b/offload/tools/offload-tblgen/offload-tblgen.cpp @@ -0,0 +1,74 @@ +//===- offload-tblgen/offload-tblgen.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a Tablegen tool that produces source files for the Offload project. +// See offload/API/README.md for more information. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" + +#include "Generators.hpp" + +namespace llvm { +namespace offload { +namespace tblgen { + +enum ActionType { PrintRecords, DumpJSON, GenAPI, GenValidation }; + +namespace { +cl::opt Action( + cl::desc("Action to perform:"), + cl::values( + clEnumValN(PrintRecords, "print-records", + "Print all records to stdout (default)"), + clEnumValN(DumpJSON, "dump-json", + "Dump all records as machine-readable JSON"), + clEnumValN(GenAPI, "gen-api", "Generate Offload API header contents"), + clEnumValN(GenValidation, "gen-validation", + "Generate Offload entry point validation functions"))); +} + +static bool OffloadTableGenMain(raw_ostream &OS, RecordKeeper &Records) { + switch (Action) { + case PrintRecords: + OS << Records; + break; + case DumpJSON: + EmitJSON(Records, OS); + break; + case GenAPI: + EmitOffloadAPI(Records, OS); + break; + case GenValidation: + EmitOffloadValidation(Records, OS); + break; + default: + break; + } + + return false; +} + +int OffloadTblgenMain(int argc, char **argv) { + InitLLVM y(argc, argv); + cl::ParseCommandLineOptions(argc, argv); + return TableGenMain(argv[0], &OffloadTableGenMain); + ; +} +} // namespace tblgen +} // namespace offload +} // namespace llvm + +using namespace llvm; +using namespace offload::tblgen; + +int main(int argc, char **argv) { return OffloadTblgenMain(argc, argv); }