diff --git a/clang/include/clang/Sema/SemaSYCL.h b/clang/include/clang/Sema/SemaSYCL.h
index 6a16be6b94818..6bec0ad1f87bd 100644
--- a/clang/include/clang/Sema/SemaSYCL.h
+++ b/clang/include/clang/Sema/SemaSYCL.h
@@ -318,6 +318,7 @@ class SemaSYCL : public SemaBase {
   void ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, MangleContext &MC);
   void SetSYCLKernelNames();
   void MarkDevices();
+  void ProcessFreeFunction(FunctionDecl *FD);
 
   /// Get the number of fields or captures within the parsed type.
   ExprResult ActOnSYCLBuiltinNumFieldsExpr(ParsedType PT);
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8b2f27bd4eca3..e3b827b78286c 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -16651,6 +16651,13 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body,
   if (FD && !FD->isDeleted())
     checkTypeSupport(FD->getType(), FD->getLocation(), FD);
 
+  // Handle free functions. FD may be null here (see the check just above).
+  if (LangOpts.SYCLIsDevice && FD && FD->hasAttr<SYCLDeviceAttr>() && Body &&
+      (FD->getTemplatedKind() == FunctionDecl::TK_NonTemplate ||
+       FD->getTemplatedKind() ==
+           FunctionDecl::TK_FunctionTemplateSpecialization))
+    SYCL().ProcessFreeFunction(FD);
+
   return dcl;
 }
 
diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp
index d7142dd2abba0..1b08bcdc4265a 100644
--- a/clang/lib/Sema/SemaSYCL.cpp
+++ b/clang/lib/Sema/SemaSYCL.cpp
@@ -949,6 +949,33 @@ class KernelBodyTransform : public TreeTransform<KernelBodyTransform> {
   Sema &SemaRef;
 };
 
+/// Creates a kernel parameter descriptor
+/// \param Src field declaration to construct name from
+/// \param Ty the desired parameter type
+/// \return the constructed descriptor
+static ParamDesc makeParamDesc(const FieldDecl *Src, QualType Ty) {
+  ASTContext &Ctx = Src->getASTContext();
+  std::string Name = (Twine("_arg_") + Src->getName()).str();
+  return std::make_tuple(Ty, &Ctx.Idents.get(Name),
+                         Ctx.getTrivialTypeSourceInfo(Ty));
+}
+static ParamDesc makeParamDesc(const ParmVarDecl *Src, QualType Ty) {
+  ASTContext &Ctx = Src->getASTContext();
+  std::string Name = (Twine("__arg_") + Src->getName()).str();
+  return std::make_tuple(Ty, &Ctx.Idents.get(Name),
+                         Ctx.getTrivialTypeSourceInfo(Ty));
+}
+
+static ParamDesc makeParamDesc(ASTContext &Ctx, StringRef Name, QualType Ty) {
+  return std::make_tuple(Ty, &Ctx.Idents.get(Name),
+                         Ctx.getTrivialTypeSourceInfo(Ty));
+}
+
+static void unsupportedFreeFunctionParamType() {
+  llvm::report_fatal_error("Only scalars and pointers are permitted as "
+                           "free function parameters");
+}
+
 class MarkWIScopeFnVisitor : public RecursiveASTVisitor<MarkWIScopeFnVisitor> {
 public:
   MarkWIScopeFnVisitor(ASTContext &Ctx) : Ctx(Ctx) {}
@@ -1038,22 +1065,6 @@ static QualType GetSYCLKernelObjectType(const FunctionDecl *KernelCaller) {
   return KernelParamTy.getUnqualifiedType();
 }
 
-/// Creates a kernel parameter descriptor
-/// \param Src field declaration to construct name from
-/// \param Ty the desired parameter type
-/// \return the constructed descriptor
-static ParamDesc makeParamDesc(const FieldDecl *Src, QualType Ty) {
-  ASTContext &Ctx = Src->getASTContext();
-  std::string Name = (Twine("_arg_") + Src->getName()).str();
-  return std::make_tuple(Ty, &Ctx.Idents.get(Name),
-                         Ctx.getTrivialTypeSourceInfo(Ty));
-}
-
-static ParamDesc makeParamDesc(ASTContext &Ctx, StringRef Name, QualType Ty) {
-  return std::make_tuple(Ty, &Ctx.Idents.get(Name),
-                         Ctx.getTrivialTypeSourceInfo(Ty));
-}
-
 /// \return the target of given SYCL accessor type
 static target getAccessTarget(QualType FieldTy,
                               const ClassTemplateSpecializationDecl *AccTy) {
@@ -1064,6 +1075,63 @@ static target getAccessTarget(QualType FieldTy,
       AccTy->getTemplateArgs()[3].getAsIntegral().getExtValue());
 }
 
+// FIXME: Free functions must have void return type, be declared at file scope,
+// outside any namespaces, and with the SYCL_DEVICE attribute. If the
+// SYCL_DEVICE attribute is not specified this function is not entered since the
+// possibility of the function being a free function is ruled out already.
+static bool isFreeFunction(SemaSYCL &SemaSYCLRef, const FunctionDecl *FD) {
+  for (auto *IRAttr : FD->specific_attrs<SYCLAddIRAttributesFunctionAttr>()) {
+    SmallVector<std::pair<std::string, std::string>, 4> NameValuePairs =
+        IRAttr->getAttributeNameValuePairs(SemaSYCLRef.getASTContext());
+    for (const auto &NameValuePair : NameValuePairs) {
+      if (NameValuePair.first != "sycl-nd-range-kernel" &&
+          NameValuePair.first != "sycl-single-task-kernel")
+        continue;
+      // report_fatal_error() does not return, so reaching the return below
+      // means the kernel-marked function has the required void return type.
+      if (!FD->getReturnType()->isVoidType())
+        llvm::report_fatal_error(
+            "Only functions at file scope with void return "
+            "type are permitted as free functions");
+      return true;
+    }
+  }
+  return false;
+}
+
+// Creates a name for the free function kernel function.
+// Consider a free function named "MyFunction". The normal device function will
+// be given its mangled name, say "_Z10MyFunctionIiEvPT_S0_". The corresponding
+// kernel function for this free function will be named
+// "_Z24__sycl_kernel_MyFunctionIiEvPT_S0_". This is the mangled name of a
+// fictitious function that has the same template and function parameters as the
+// original free function but with identifier prefixed with __sycl_kernel_.
+// We generate this name by starting with the mangled name of the free function
+// and adjusting it textually to simulate the __sycl_kernel_ prefix.
+// Because free functions are allowed only at file scope and cannot be within
+// namespaces the mangled name has the format _Z<length><identifier>... where
+// length is the identifier's length. The text manipulation inserts the prefix
+// __sycl_kernel_ and adjusts the length, leaving the rest of the name as-is.
+static std::pair<std::string, std::string> constructFreeFunctionKernelName(
+    SemaSYCL &SemaSYCLRef, const FunctionDecl *FreeFunc, MangleContext &MC) {
+  SmallString<256> Result;
+  llvm::raw_svector_ostream Out(Result);
+  MC.mangleName(FreeFunc, Out);
+  std::string MangledName(Out.str());
+  // Without a "<length><identifier>" part (e.g. an extern "C" name) there is
+  // nothing to rewrite; report it rather than let substr()/stoul() throw.
+  size_t StartNums = MangledName.find_first_of("0123456789");
+  size_t EndNums = MangledName.find_first_not_of("0123456789", StartNums);
+  if (StartNums == std::string::npos || EndNums == std::string::npos)
+    llvm::report_fatal_error("Cannot form free function kernel name");
+  const std::string Prefix = "__sycl_kernel_";
+  std::string Digits = MangledName.substr(StartNums, EndNums - StartNums);
+  std::string NewName = MangledName.substr(0, StartNums) +
+                        std::to_string(Prefix.size() + std::stoul(Digits)) +
+                        Prefix + MangledName.substr(EndNums);
+  return {NewName, NewName}; // The stable name is the same as the kernel name.
+}
+
 // The first template argument to the kernel caller function is used to identify
 // the kernel itself.
 static QualType calculateKernelNameType(ASTContext &Ctx,
@@ -1201,6 +1269,23 @@ class KernelObjVisitor {
   //     return Handlers.f(FD, FDTy);                                           \
   //   })...)
 
+  // This enables handler execution only when previous Handlers succeed.
+  template <typename... Tn>
+  bool handleParam(ParmVarDecl *PD, QualType PDTy, Tn &&...tn) {
+    bool result = true;
+    (void)std::initializer_list<int>{(result = result && tn(PD, PDTy), 0)...};
+    return result;
+  }
+
+  // This definition using std::bind is necessary because of a gcc 7.x bug.
+#define KP_FOR_EACH(FUNC, Item, Qt)                                            \
+  handleParam(                                                                 \
+      Item, Qt,                                                                \
+      std::bind(static_cast<bool (std::decay_t<decltype(Handlers)>::*)(        \
+                    bind_param_t<decltype(Item)>, QualType)>(                  \
+                    &std::decay_t<decltype(Handlers)>::FUNC),                  \
+                std::ref(Handlers), _1, _2)...)
+
   // Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter
   // the Wrapper structure that we're currently visiting. 
Owner is the parent // type (which doesn't exist in cases where it is a FieldDecl in the @@ -1343,6 +1428,27 @@ class KernelObjVisitor { KF_FOR_EACH(handleOtherType, Field, FieldTy); } + template + void visitParam(ParmVarDecl *Param, QualType ParamTy, + HandlerTys &...Handlers) { + if (isSyclSpecialType(ParamTy, SemaSYCLRef)) + KP_FOR_EACH(handleOtherType, Param, ParamTy); + else if (ParamTy->isStructureOrClassType()) + KP_FOR_EACH(handleOtherType, Param, ParamTy); + else if (ParamTy->isUnionType()) + KP_FOR_EACH(handleOtherType, Param, ParamTy); + else if (ParamTy->isReferenceType()) + KP_FOR_EACH(handleOtherType, Param, ParamTy); + else if (ParamTy->isPointerType()) + KP_FOR_EACH(handlePointerType, Param, ParamTy); + else if (ParamTy->isArrayType()) + KP_FOR_EACH(handleOtherType, Param, ParamTy); + else if (ParamTy->isScalarType()) + KP_FOR_EACH(handleScalarType, Param, ParamTy); + else + KP_FOR_EACH(handleOtherType, Param, ParamTy); + } + public: KernelObjVisitor(SemaSYCL &S) : SemaSYCLRef(S) {} @@ -1364,7 +1470,17 @@ class KernelObjVisitor { void visitArray(const CXXRecordDecl *Owner, FieldDecl *Field, QualType ArrayTy, HandlerTys &...Handlers); + // A visitor function that dispatches to functions as defined in + // SyclKernelFieldHandler by iterating over a free function parameter list. 
+ template + void VisitFunctionParameters(FunctionDecl *FreeFunc, + HandlerTys &...Handlers) { + for (ParmVarDecl *Param : FreeFunc->parameters()) + visitParam(Param, Param->getType(), Handlers...); + } + #undef KF_FOR_EACH +#undef KP_FOR_EACH }; // A base type that the SYCL OpenCL Kernel construction task uses to implement @@ -1388,15 +1504,23 @@ class SyclKernelFieldHandlerBase { return true; } virtual bool handleSyclSpecialType(FieldDecl *, QualType) { return true; } + virtual bool handleSyclSpecialType(ParmVarDecl *, QualType) { return true; } virtual bool handleStructType(FieldDecl *, QualType) { return true; } + virtual bool handleStructType(ParmVarDecl *, QualType) { return true; } virtual bool handleUnionType(FieldDecl *, QualType) { return true; } + virtual bool handleUnionType(ParmVarDecl *, QualType) { return true; } virtual bool handleReferenceType(FieldDecl *, QualType) { return true; } + virtual bool handleReferenceType(ParmVarDecl *, QualType) { return true; } virtual bool handlePointerType(FieldDecl *, QualType) { return true; } + virtual bool handlePointerType(ParmVarDecl *, QualType) { return true; } virtual bool handleArrayType(FieldDecl *, QualType) { return true; } + virtual bool handleArrayType(ParmVarDecl *, QualType) { return true; } virtual bool handleScalarType(FieldDecl *, QualType) { return true; } + virtual bool handleScalarType(ParmVarDecl *, QualType) { return true; } // Most handlers shouldn't be handling this, just the field checker. virtual bool handleOtherType(FieldDecl *, QualType) { return true; } + virtual bool handleOtherType(ParmVarDecl *, QualType) { return true; } // Handle a simple struct that doesn't need to be decomposed, only called on // handlers with VisitInsideSimpleContainers as false. 
Replaces @@ -1405,6 +1529,12 @@ class SyclKernelFieldHandlerBase { QualType) { return true; } + + virtual bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *, + QualType) { + return true; + } + virtual bool handleNonDecompStruct(const CXXRecordDecl *, const CXXBaseSpecifier &, QualType) { return true; @@ -1425,6 +1555,12 @@ class SyclKernelFieldHandlerBase { virtual bool leaveStruct(const CXXRecordDecl *, FieldDecl *, QualType) { return true; } + virtual bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) { + return true; + } + virtual bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) { + return true; + } virtual bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &, QualType) { return true; @@ -1435,6 +1571,8 @@ class SyclKernelFieldHandlerBase { } virtual bool enterUnion(const CXXRecordDecl *, FieldDecl *) { return true; } virtual bool leaveUnion(const CXXRecordDecl *, FieldDecl *) { return true; } + virtual bool enterUnion(const CXXRecordDecl *, ParmVarDecl *) { return true; } + virtual bool leaveUnion(const CXXRecordDecl *, ParmVarDecl *) { return true; } // The following are used for stepping through array elements. 
virtual bool enterArray(FieldDecl *, QualType ArrayTy, QualType ElementTy) { @@ -1443,6 +1581,12 @@ class SyclKernelFieldHandlerBase { virtual bool leaveArray(FieldDecl *, QualType ArrayTy, QualType ElementTy) { return true; } + virtual bool enterArray(ParmVarDecl *, QualType ArrayTy, QualType ElementTy) { + return true; + } + virtual bool leaveArray(ParmVarDecl *, QualType ArrayTy, QualType ElementTy) { + return true; + } virtual bool nextElement(QualType, uint64_t) { return true; } @@ -1741,6 +1885,12 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { return isValid(); } + bool handleReferenceType(ParmVarDecl *PD, QualType ParamTy) final { + Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy; + IsInvalid = true; + return isValid(); + } + bool handleStructType(FieldDecl *FD, QualType FieldTy) final { CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); assert(RD && "Not a RecordDecl inside the handler for struct type"); @@ -1754,6 +1904,12 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { return isValid(); } + bool handleStructType(ParmVarDecl *PD, QualType ParamTy) final { + Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy; + IsInvalid = true; + return isValid(); + } + bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS, QualType FieldTy) final { IsInvalid |= checkSyclSpecialType(FieldTy, BS.getBeginLoc()); @@ -1765,11 +1921,23 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { return isValid(); } + bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { + Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy; + IsInvalid = true; + return isValid(); + } + bool handleArrayType(FieldDecl *FD, QualType FieldTy) final { IsInvalid |= checkNotCopyableToKernel(FD, FieldTy); return isValid(); } + bool handleArrayType(ParmVarDecl *PD, QualType ParamTy) final { + Diag.Report(PD->getLocation(), 
diag::err_bad_kernel_param_type) << ParamTy; + IsInvalid = true; + return isValid(); + } + bool handlePointerType(FieldDecl *FD, QualType FieldTy) final { while (FieldTy->isAnyPointerType()) { FieldTy = QualType{FieldTy->getPointeeOrArrayElementType(), 0}; @@ -1782,12 +1950,30 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { return isValid(); } + bool handlePointerType(ParmVarDecl *PD, QualType ParamTy) final { + while (ParamTy->isAnyPointerType()) { + ParamTy = QualType{ParamTy->getPointeeOrArrayElementType(), 0}; + if (ParamTy->isVariableArrayType()) { + Diag.Report(PD->getLocation(), diag::err_vla_unsupported) << 0; + IsInvalid = true; + break; + } + } + return isValid(); + } + bool handleOtherType(FieldDecl *FD, QualType FieldTy) final { Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) << FieldTy; IsInvalid = true; return isValid(); } + bool handleOtherType(ParmVarDecl *PD, QualType ParamTy) final { + Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy; + IsInvalid = true; + return isValid(); + } + bool enterStruct(const CXXRecordDecl *, FieldDecl *, QualType) final { ++StructFieldDepth; return true; @@ -1798,6 +1984,18 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { return true; } + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &BS, QualType FieldTy) final { ++StructBaseDepth; @@ -1839,15 +2037,33 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler { return true; } + bool enterUnion(const CXXRecordDecl *, ParmVarDecl *) override { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool leaveUnion(const CXXRecordDecl *RD, FieldDecl *FD) override { --UnionCount; 
return true; } + bool leaveUnion(const CXXRecordDecl *, ParmVarDecl *) override { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleSyclSpecialType(FieldDecl *FD, QualType FieldTy) final { return checkType(FD->getLocation(), FieldTy); } + bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS, QualType FieldTy) final { return checkType(BS.getBeginLoc(), FieldTy); @@ -1883,17 +2099,35 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { return true; } + bool handleSyclSpecialType(ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handlePointerType(FieldDecl *, QualType) final { PointerStack.back() = true; return true; } + bool handlePointerType(ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool enterStruct(const CXXRecordDecl *, FieldDecl *, QualType) final { CollectionStack.push_back(false); PointerStack.push_back(false); return true; } + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool leaveStruct(const CXXRecordDecl *, FieldDecl *, QualType Ty) final { // If a record needs to be decomposed, it is marked with // SYCLRequiresDecompositionAttr. 
Else if a record contains @@ -1916,6 +2150,13 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { return true; } + bool leaveStruct(const CXXRecordDecl *RD, ParmVarDecl *PD, + QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &, QualType) final { CollectionStack.push_back(false); @@ -1952,6 +2193,12 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { return true; } + bool enterArray(ParmVarDecl *, QualType ArrayTy, QualType ElementTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool leaveArray(FieldDecl *FD, QualType ArrayTy, QualType ElementTy) final { // If an array needs to be decomposed, it is marked with // SYCLRequiresDecompositionAttr. Else if the array is an array of pointers @@ -1974,6 +2221,12 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { } return true; } + + bool leaveArray(ParmVarDecl *PD, QualType ArrayTy, QualType ElementTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } }; static QualType ModifyAddressSpace(SemaSYCL &SemaSYCLRef, QualType Ty) { @@ -2085,6 +2338,13 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { return true; } + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, + QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool leaveStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { CXXRecordDecl *ModifiedRD = getGeneratedNewRecord(Ty->getAsCXXRecordDecl()); @@ -2099,6 +2359,13 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { return true; } + bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *PD, + QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &, QualType Ty) final { createNewType(Ty->getAsCXXRecordDecl()); @@ -2139,6 
+2406,12 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { return true; } + bool leaveArray(ParmVarDecl *PD, QualType ArrayTy, QualType ET) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handlePointerType(FieldDecl *FD, QualType FieldTy) final { QualType ModifiedPointerType = ModifyAddressSpace(SemaSYCLRef, FieldTy); if (!isArrayElement(FD, FieldTy)) @@ -2150,21 +2423,46 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { return true; } + bool handlePointerType(ParmVarDecl *PD, QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleScalarType(FieldDecl *FD, QualType FieldTy) final { addField(FD, FieldTy); return true; } + bool handleScalarType(ParmVarDecl *PD, QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { return handleScalarType(FD, FieldTy); } + bool handleUnionType(ParmVarDecl *PD, QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleNonDecompStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { addField(FD, Ty); return true; } + bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *PD, + QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleNonDecompStruct(const CXXRecordDecl *Parent, const CXXBaseSpecifier &BS, QualType Ty) final { createBaseSpecifier(Parent, Ty->getAsCXXRecordDecl(), BS); @@ -2193,7 +2491,8 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler { // A type to Create and own the FunctionDecl for the kernel. class SyclKernelDeclCreator : public SyclKernelFieldHandler { - FunctionDecl *KernelDecl; + bool IsFreeFunction = false; + FunctionDecl *KernelDecl = nullptr; llvm::SmallVector Params; Sema::ContextRAII FuncContext; // Holds the last handled field's first parameter. 
This doesn't store an @@ -2207,6 +2506,11 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { addParam(newParamDesc, FieldTy); } + void addParam(const ParmVarDecl *PD, QualType ParamTy) { + ParamDesc newParamDesc = makeParamDesc(PD, ParamTy); + addParam(newParamDesc, ParamTy); + } + void addParam(const CXXBaseSpecifier &BS, QualType FieldTy) { // TODO: There is no name for the base available, but duplicate names are // seemingly already possible, so we'll give them all the same name for now. @@ -2407,16 +2711,17 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { public: static constexpr const bool VisitInsideSimpleContainers = false; SyclKernelDeclCreator(SemaSYCL &S, SourceLocation Loc, bool IsInline, - bool IsSIMDKernel, FunctionDecl *SYCLKernel) - : SyclKernelFieldHandler(S), + bool IsSIMDKernel, bool IsFreeFunction, + FunctionDecl *SYCLKernel) + : SyclKernelFieldHandler(S), IsFreeFunction(IsFreeFunction), KernelDecl( createKernelDecl(S.getASTContext(), Loc, IsInline, IsSIMDKernel)), FuncContext(SemaSYCLRef.SemaRef, KernelDecl) { S.addSyclOpenCLKernel(SYCLKernel, KernelDecl); - - if (const auto *AddIRAttrFunc = - SYCLKernel->getAttr()) - KernelDecl->addAttr(AddIRAttrFunc->clone(SemaSYCLRef.getASTContext())); + for (const auto *IRAttr : + SYCLKernel->specific_attrs()) { + KernelDecl->addAttr(IRAttr->clone(SemaSYCLRef.getASTContext())); + } } ~SyclKernelDeclCreator() { @@ -2447,11 +2752,23 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return true; } + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool leaveStruct(const CXXRecordDecl *, FieldDecl *, QualType) final { --StructDepth; return true; } + bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &BS, QualType FieldTy) 
final { ++StructDepth; @@ -2498,6 +2815,12 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return handleSpecialType(FD, FieldTy); } + bool handleSyclSpecialType(ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + RecordDecl *wrapField(FieldDecl *Field, QualType FieldTy) { RecordDecl *WrapperClass = SemaSYCLRef.getASTContext().buildImplicitRecord("__wrapper_class"); @@ -2531,6 +2854,12 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return true; } + bool handlePointerType(ParmVarDecl *PD, QualType ParamTy) final { + QualType ModTy = ModifyAddressSpace(SemaSYCLRef, ParamTy); + addParam(PD, ModTy); + return true; + } + bool handleSimpleArrayType(FieldDecl *FD, QualType FieldTy) final { QualType ArrayTy = FieldTy; @@ -2549,6 +2878,11 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return true; } + bool handleScalarType(ParmVarDecl *PD, QualType ParamTy) final { + addParam(PD, ParamTy); + return true; + } + bool handleNonDecompStruct(const CXXRecordDecl *RD, FieldDecl *FD, QualType Ty) final { // This is a field which should not be decomposed. @@ -2563,6 +2897,13 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return true; } + bool handleNonDecompStruct(const CXXRecordDecl *RD, ParmVarDecl *PD, + QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleNonDecompStruct(const CXXRecordDecl *Base, const CXXBaseSpecifier &BS, QualType Ty) final { // This is a base class which should not be decomposed. @@ -2581,6 +2922,12 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return handleScalarType(FD, FieldTy); } + bool handleUnionType(ParmVarDecl *PD, QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + // Generate kernel argument to initialize specialization constants. 
void handleSyclKernelHandlerType() { ASTContext &Context = SemaSYCLRef.getASTContext(); @@ -2690,6 +3037,8 @@ class ESIMDKernelDiagnostics : public SyclKernelFieldHandler { QualType FieldTy) final { return handleSpecialType(FieldTy); } + + using SyclKernelFieldHandler::handleSyclSpecialType; }; class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler { @@ -2730,6 +3079,12 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler { return handleSpecialType(FieldTy); } + bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS, QualType FieldTy) final { return handleSpecialType(FieldTy); @@ -2740,11 +3095,21 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler { return true; } + bool handlePointerType(ParmVarDecl *PD, QualType ParamTy) final { + addParam(ParamTy); + return true; + } + bool handleScalarType(FieldDecl *FD, QualType FieldTy) final { addParam(FieldTy); return true; } + bool handleScalarType(ParmVarDecl *PD, QualType ParamTy) final { + addParam(ParamTy); + return true; + } + bool handleSimpleArrayType(FieldDecl *FD, QualType FieldTy) final { addParam(FieldTy); return true; @@ -2756,6 +3121,13 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler { return true; } + bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *, + QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleNonDecompStruct(const CXXRecordDecl *Base, const CXXBaseSpecifier &BS, QualType Ty) final { addParam(Ty); @@ -2765,6 +3137,12 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler { bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { return handleScalarType(FD, FieldTy); } + + bool handleUnionType(ParmVarDecl *PD, QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } 
}; std::string getKernelArgDesc(StringRef KernelArgDescription) { @@ -2842,6 +3220,7 @@ class SyclOptReportCreator : public SyclKernelFieldHandler { SourceLocation Loc) : SyclKernelFieldHandler(S), DC(DC), KernelInvocationLoc(Loc) {} + using SyclKernelFieldHandler::handleSyclSpecialType; bool handleSyclSpecialType(FieldDecl *FD, QualType FieldTy) final { for (const auto *Param : DC.getParamVarDeclsForCurrentField()) addParam(FD, Param->getType(), FieldTy.getAsString()); @@ -2864,6 +3243,7 @@ class SyclOptReportCreator : public SyclKernelFieldHandler { return true; } + using SyclKernelFieldHandler::handlePointerType; bool handlePointerType(FieldDecl *FD, QualType FieldTy) final { std::string KernelArgDescription = ""; bool IsCompilerGeneratedType = false; @@ -2882,11 +3262,13 @@ class SyclOptReportCreator : public SyclKernelFieldHandler { return true; } + using SyclKernelFieldHandler::handleScalarType; bool handleScalarType(FieldDecl *FD, QualType FieldTy) final { addParam(FD, FieldTy); return true; } + using SyclKernelFieldHandler::handleSimpleArrayType; bool handleSimpleArrayType(FieldDecl *FD, QualType FieldTy) final { // Simple arrays are always wrapped. 
for (const auto *Param : DC.getParamVarDeclsForCurrentField()) @@ -2894,6 +3276,7 @@ class SyclOptReportCreator : public SyclKernelFieldHandler { return true; } + using SyclKernelFieldHandler::handleNonDecompStruct; bool handleNonDecompStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { CXXRecordDecl *RD = Ty->getAsCXXRecordDecl(); @@ -2917,6 +3300,7 @@ class SyclOptReportCreator : public SyclKernelFieldHandler { return true; } + using SyclKernelFieldHandler::handleUnionType; bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { return handleScalarType(FD, FieldTy); } @@ -3674,6 +4058,248 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { removeFieldMemberExpr(FD, ArrayType); return true; } + using SyclKernelFieldHandler::enterArray; + using SyclKernelFieldHandler::enterStruct; + using SyclKernelFieldHandler::handleNonDecompStruct; + using SyclKernelFieldHandler::handlePointerType; + using SyclKernelFieldHandler::handleScalarType; + using SyclKernelFieldHandler::handleSyclSpecialType; + using SyclKernelFieldHandler::handleUnionType; + using SyclKernelFieldHandler::leaveArray; + using SyclKernelFieldHandler::leaveStruct; +}; + +class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { + SyclKernelDeclCreator &DeclCreator; + llvm::SmallVector BodyStmts; + FunctionDecl *FreeFunc = nullptr; + SourceLocation FreeFunctionSrcLoc; // Free function source location. + llvm::SmallVector ArgExprs; + + // Creates a DeclRefExpr to the ParmVar that represents the current free + // function parameter. 
+ Expr *createParamReferenceExpr() { + ParmVarDecl *FreeFunctionParameter = + DeclCreator.getParamVarDeclsForCurrentField()[0]; + + QualType FreeFunctionParamType = FreeFunctionParameter->getOriginalType(); + Expr *DRE = SemaSYCLRef.SemaRef.BuildDeclRefExpr( + FreeFunctionParameter, FreeFunctionParamType, VK_LValue, + FreeFunctionSrcLoc); + DRE = SemaSYCLRef.SemaRef.DefaultLvalueConversion(DRE).get(); + return DRE; + } + + // Creates a DeclRefExpr to the ParmVar that represents the current pointer + // parameter. + Expr *createPointerParamReferenceExpr(QualType PointerTy, bool Wrapped) { + ParmVarDecl *FreeFunctionParameter = + DeclCreator.getParamVarDeclsForCurrentField()[0]; + + QualType FreeFunctionParamType = FreeFunctionParameter->getOriginalType(); + Expr *DRE = SemaSYCLRef.SemaRef.BuildDeclRefExpr( + FreeFunctionParameter, FreeFunctionParamType, VK_LValue, + FreeFunctionSrcLoc); + DRE = SemaSYCLRef.SemaRef.DefaultLvalueConversion(DRE).get(); + + if (PointerTy->getPointeeType().getAddressSpace() != + FreeFunctionParamType->getPointeeType().getAddressSpace()) + DRE = ImplicitCastExpr::Create(SemaSYCLRef.getASTContext(), PointerTy, + CK_AddressSpaceConversion, DRE, nullptr, + VK_PRValue, FPOptionsOverride()); + return DRE; + } + + // For a free function such as: + // void f(int i, int* p, struct Simple S) { ... } + // + // Keep the function as-is for the version callable from device code. + // void f(int i, int *p, struct Simple S) { ... 
} + // + // For the host-callable kernel function generate this: + // void __sycl_kernel_f(int __arg_i, int* __arg_p, struct Simple __arg_S) + // { + // f(__arg_i, __arg_p, __arg_S); + // } + CompoundStmt *createFreeFunctionKernelBody() { + SemaSYCLRef.SemaRef.PushFunctionScope(); + Expr *Fn = SemaSYCLRef.SemaRef.BuildDeclRefExpr( + FreeFunc, FreeFunc->getType(), VK_LValue, FreeFunctionSrcLoc); + ASTContext &Context = SemaSYCLRef.getASTContext(); + QualType ResultTy = FreeFunc->getReturnType(); + ExprValueKind VK = Expr::getValueKindForType(ResultTy); + ResultTy = ResultTy.getNonLValueExprType(Context); + Fn = ImplicitCastExpr::Create(Context, + Context.getPointerType(FreeFunc->getType()), + CK_FunctionToPointerDecay, Fn, nullptr, + VK_PRValue, FPOptionsOverride()); + auto CallExpr = CallExpr::Create(Context, Fn, ArgExprs, ResultTy, VK, + FreeFunctionSrcLoc, FPOptionsOverride()); + BodyStmts.push_back(CallExpr); + return CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {}, + {}); + } + +public: + static constexpr const bool VisitInsideSimpleContainers = false; + + FreeFunctionKernelBodyCreator(SemaSYCL &S, SyclKernelDeclCreator &DC, + FunctionDecl *FF) + : SyclKernelFieldHandler(S), DeclCreator(DC), FreeFunc(FF), + FreeFunctionSrcLoc(FF->getLocation()) {} + + ~FreeFunctionKernelBodyCreator() { + CompoundStmt *KernelBody = createFreeFunctionKernelBody(); + DeclCreator.setBody(KernelBody); + } + + bool handleSyclSpecialType(FieldDecl *FD, QualType Ty) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool handleSyclSpecialType(ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS, + QualType Ty) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool handlePointerType(FieldDecl *FD, QualType FieldTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } 
+ + bool handlePointerType(ParmVarDecl *PD, QualType ParamTy) final { + Expr *PointerRef = createPointerParamReferenceExpr(ParamTy, false); + ArgExprs.push_back(PointerRef); + return true; + } + + bool handleSimpleArrayType(FieldDecl *FD, QualType FieldTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool handleNonDecompStruct(const CXXRecordDecl *, FieldDecl *FD, + QualType Ty) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *, + QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool handleNonDecompStruct(const CXXRecordDecl *RD, + const CXXBaseSpecifier &BS, QualType Ty) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool handleScalarType(FieldDecl *FD, QualType FieldTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool handleScalarType(ParmVarDecl *, QualType) final { + Expr *ParamRef = createParamReferenceExpr(); + ArgExprs.push_back(ParamRef); + return true; + } + + bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool handleUnionType(ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool enterStruct(const CXXRecordDecl *RD, FieldDecl *FD, QualType Ty) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool leaveStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool enterStruct(const CXXRecordDecl *RD, 
const CXXBaseSpecifier &BS, + QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool leaveStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS, + QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool enterArray(FieldDecl *FD, QualType ArrayType, + QualType ElementType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool enterArray(ParmVarDecl *PD, QualType ArrayType, + QualType ElementType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool leaveArray(FieldDecl *FD, QualType ArrayType, + QualType ElementType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + + bool leaveArray(ParmVarDecl *PD, QualType ArrayType, + QualType ElementType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } }; // Kernels are only the unnamed-lambda feature if the feature is enabled, AND @@ -3682,6 +4308,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { static bool IsSYCLUnnamedKernel(SemaSYCL &SemaSYCLRef, const FunctionDecl *FD) { if (!SemaSYCLRef.getLangOpts().SYCLUnnamedLambda) return false; + QualType FunctorTy = GetSYCLKernelObjectType(FD); QualType TmplArgTy = calculateKernelNameType(SemaSYCLRef.getASTContext(), FD); return SemaSYCLRef.getASTContext().hasSameType(FunctorTy, TmplArgTy); @@ -3699,6 +4326,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { ? 0 : SemaSYCLRef.getASTContext().getFieldOffset(FD) / 8; } + // For free functions each parameter is stand-alone, so offsets within a + // lambda/function object are not relevant. Therefore offsetOf will always be + // 0. 
+ int64_t offsetOf(const ParmVarDecl *, QualType) const { return 0; } int64_t offsetOf(const CXXRecordDecl *RD, const CXXRecordDecl *Base) const { const ASTRecordLayout &Layout = @@ -3710,10 +4341,22 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { SYCLIntegrationHeader::kernel_param_kind_t Kind) { addParam(ArgTy, Kind, offsetOf(FD, ArgTy)); } - void addParam(QualType ArgTy, SYCLIntegrationHeader::kernel_param_kind_t Kind, + + // For free functions we increment the current offset as each parameter is + // added. + void addParam(const ParmVarDecl *PD, QualType ParamTy, + SYCLIntegrationHeader::kernel_param_kind_t Kind) { + addParam(ParamTy, Kind, offsetOf(PD, ParamTy)); + CurOffset += + SemaSYCLRef.getASTContext().getTypeSizeInChars(ParamTy).getQuantity(); + } + + void addParam(QualType ParamTy, + SYCLIntegrationHeader::kernel_param_kind_t Kind, uint64_t OffsetAdj) { uint64_t Size; - Size = SemaSYCLRef.getASTContext().getTypeSizeInChars(ArgTy).getQuantity(); + Size = + SemaSYCLRef.getASTContext().getTypeSizeInChars(ParamTy).getQuantity(); Header.addParamDesc(Kind, static_cast(Size), static_cast(CurOffset + OffsetAdj)); } @@ -3734,6 +4377,14 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { IsSYCLUnnamedKernel(S, KernelFunc), ObjSize); } + SyclKernelIntHeaderCreator(SemaSYCL &S, SYCLIntegrationHeader &H, + QualType NameType, FunctionDecl *FreeFunc) + : SyclKernelFieldHandler(S), Header(H) { + Header.startKernel(FreeFunc, NameType, FreeFunc->getLocation(), + false /*IsESIMD*/, true /*IsSYCLUnnamedKernel*/, + 0 /*ObjSize*/); + } + bool handleSyclSpecialType(const CXXRecordDecl *RD, const CXXBaseSpecifier &BC, QualType FieldTy) final { @@ -3787,6 +4438,12 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return true; } + bool handleSyclSpecialType(ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handlePointerType(FieldDecl *FD, QualType FieldTy) final { 
addParam(FD, FieldTy, ((StructDepth) ? SYCLIntegrationHeader::kind_std_layout @@ -3794,11 +4451,21 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return true; } + bool handlePointerType(ParmVarDecl *PD, QualType ParamTy) final { + addParam(PD, ParamTy, SYCLIntegrationHeader::kind_pointer); + return true; + } + bool handleScalarType(FieldDecl *FD, QualType FieldTy) final { addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout); return true; } + bool handleScalarType(ParmVarDecl *PD, QualType ParamTy) final { + addParam(PD, ParamTy, SYCLIntegrationHeader::kind_std_layout); + return true; + } + bool handleSimpleArrayType(FieldDecl *FD, QualType FieldTy) final { // Arrays are always wrapped inside of structs, so just treat it as a simple // struct. @@ -3812,6 +4479,13 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return true; } + bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *PD, + QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool handleNonDecompStruct(const CXXRecordDecl *Base, const CXXBaseSpecifier &, QualType Ty) final { addParam(Ty, SYCLIntegrationHeader::kind_std_layout, @@ -3823,6 +4497,12 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return handleScalarType(FD, FieldTy); } + bool handleUnionType(ParmVarDecl *PD, QualType ParamTy) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + void handleSyclKernelHandlerType(QualType Ty) { // The compiler generated kernel argument used to initialize SYCL 2020 // specialization constants, `specialization_constants_buffer`, should @@ -3840,12 +4520,24 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return true; } + bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool leaveStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { 
--StructDepth; CurOffset -= offsetOf(FD, Ty); return true; } + bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS, QualType) final { CurOffset += offsetOf(RD, BS.getType()->getAsCXXRecordDecl()); @@ -3863,6 +4555,12 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return true; } + bool enterArray(ParmVarDecl *PD, QualType ArrayTy, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + bool nextElement(QualType ET, uint64_t Index) final { int64_t Size = SemaSYCLRef.getASTContext().getTypeSizeInChars(ET).getQuantity(); @@ -3876,6 +4574,12 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return true; } + bool leaveArray(ParmVarDecl *PD, QualType ArrayTy, QualType) final { + // TODO + unsupportedFreeFunctionParamType(); + return true; + } + using SyclKernelFieldHandler::enterStruct; using SyclKernelFieldHandler::leaveStruct; }; @@ -4202,10 +4906,17 @@ void SemaSYCL::SetSYCLKernelNames() { for (const std::pair &Pair : SyclKernelsToOpenCLKernels) { std::string CalculatedName, StableName; - std::tie(CalculatedName, StableName) = - constructKernelName(*this, Pair.first, *MangleCtx); - StringRef KernelName( - IsSYCLUnnamedKernel(*this, Pair.first) ? StableName : CalculatedName); + StringRef KernelName; + if (isFreeFunction(*this, Pair.first)) { + std::tie(CalculatedName, StableName) = + constructFreeFunctionKernelName(*this, Pair.first, *MangleCtx); + KernelName = CalculatedName; + } else { + std::tie(CalculatedName, StableName) = + constructKernelName(*this, Pair.first, *MangleCtx); + KernelName = + IsSYCLUnnamedKernel(*this, Pair.first) ? 
StableName : CalculatedName; + } getSyclIntegrationHeader().updateKernelNames(Pair.first, KernelName, StableName); @@ -4282,7 +4993,7 @@ void SemaSYCL::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, SyclKernelDeclCreator kernel_decl(*this, KernelObj->getLocation(), KernelCallerFunc->isInlined(), IsSIMDKernel, - KernelCallerFunc); + false /*IsFreeFunction*/, KernelCallerFunc); SyclKernelBodyCreator kernel_body(*this, kernel_decl, KernelObj, KernelCallerFunc, IsSIMDKernel, CallOperator); @@ -4323,6 +5034,26 @@ void SemaSYCL::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, } } +void ConstructFreeFunctionKernel(SemaSYCL &SemaSYCLRef, FunctionDecl *FD) { + SyclKernelArgsSizeChecker argsSizeChecker(SemaSYCLRef, FD->getLocation(), + false /*IsSIMDKernel*/); + SyclKernelDeclCreator kernel_decl(SemaSYCLRef, FD->getLocation(), + FD->isInlined(), false /*IsSIMDKernel */, + true /*IsFreeFunction*/, FD); + + FreeFunctionKernelBodyCreator kernel_body(SemaSYCLRef, kernel_decl, FD); + + SyclKernelIntHeaderCreator int_header( + SemaSYCLRef, SemaSYCLRef.getSyclIntegrationHeader(), FD->getType(), FD); + + SyclKernelIntFooterCreator int_footer(SemaSYCLRef, + SemaSYCLRef.getSyclIntegrationFooter()); + KernelObjVisitor Visitor{SemaSYCLRef}; + + Visitor.VisitFunctionParameters(FD, argsSizeChecker, kernel_decl, kernel_body, + int_header, int_footer); +}; + // Figure out the sub-group for the this function. First we check the // attributes, then the global settings. static std::pair @@ -4605,6 +5336,28 @@ void SemaSYCL::MarkDevices() { } } +void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) { + if (isFreeFunction(*this, FD)) { + SyclKernelFieldChecker FieldChecker(*this); + SyclKernelUnionChecker UnionChecker(*this); + + KernelObjVisitor Visitor{*this}; + + DiagnosingSYCLKernel = true; + + // Check parameters of free function. 
+ Visitor.VisitFunctionParameters(FD, FieldChecker, UnionChecker); + + DiagnosingSYCLKernel = false; + + // Ignore the free function if any of the checkers fail validation. + if (!FieldChecker.isValid() || !UnionChecker.isValid()) + return; + + ConstructFreeFunctionKernel(*this, FD); + } +} + // ----------------------------------------------------------------------------- // SYCL device specific diagnostics implementation // ----------------------------------------------------------------------------- diff --git a/clang/test/CodeGenSYCL/free_function_int_header.cpp b/clang/test/CodeGenSYCL/free_function_int_header.cpp new file mode 100755 index 0000000000000..9cc154086f6ef --- /dev/null +++ b/clang/test/CodeGenSYCL/free_function_int_header.cpp @@ -0,0 +1,85 @@ +// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s +// RUN: FileCheck -input-file=%t.h %s +// +// This test checks integration header contents for free functions with scalar +// and pointer parameters. + +#include "mock_properties.hpp" +#include "sycl.hpp" + +// First overload of function ff_2. +__attribute__((sycl_device)) +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", + 2)]] void +ff_2(int *ptr, int start, int end) { + for (int i = start; i <= end; i++) + ptr[i] = start + 66; +} + +// Second overload of function ff_2. +__attribute__((sycl_device)) +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", + 2)]] void + ff_2(int* ptr, int start, int end, int value) { + for (int i = start; i <= end; i++) + ptr[i] = start + value; +} + +// Templated definition of function ff_3. +template +__attribute__((sycl_device)) +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void +ff_3(T *ptr, T start, T end) { + for (int i = start; i <= end; i++) + ptr[i] = start; +} + +// Explicit instantiation of ff_3 with int type. 
+template void ff_3(int *ptr, int start, int end); + +// Explicit instantiation of ff_3 with float type. +template void ff_3(float* ptr, float start, float end); + +// Specialization of ff_3 with double type. +template <> void ff_3(double *ptr, double start, double end) { + for (int i = start; i <= end; i++) + ptr[i] = end; +} + +// CHECK: const char* const kernel_names[] = { +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piii +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piiii +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_3IiEvPT_S0_S0_ +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_3IfEvPT_S0_S0_ +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_3IdEvPT_S0_S0_ +// CHECK-NEXT: }; + +// CHECK: const kernel_param_desc_t kernel_signatures[] = { +// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piii +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, + +// CHECK: {{.*}}__sycl_kernel_ff_2Piiii +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 16 }, + +// CHECK: {{.*}}__sycl_kernel_ff_3IiEvPT_S0_S0_ +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, + +// CHECK: {{.*}}__sycl_kernel_ff_3IfEvPT_S0_S0_ +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 }, + +// CHECK: {{.*}}__sycl_kernel_ff_3IdEvPT_S0_S0_ +// CHECK-NEXT: { kernel_param_kind_t::kind_pointer, 8, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 8, 8 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 8, 16 }, + +// CHECK: { 
kernel_param_kind_t::kind_invalid, -987654321, -987654321 }, +// CHECK-NEXT: }; \ No newline at end of file diff --git a/clang/test/SemaSYCL/free_function_kernel_params.cpp b/clang/test/SemaSYCL/free_function_kernel_params.cpp new file mode 100755 index 0000000000000..6f73d6f172aa6 --- /dev/null +++ b/clang/test/SemaSYCL/free_function_kernel_params.cpp @@ -0,0 +1,57 @@ +// RUN: %clang_cc1 -internal-isystem %S/Inputs -fsycl-is-device -ast-dump \ +// RUN: %s -o - | FileCheck %s +// This test checks parameter rewriting for free functions with parameters +// of type scalar and pointer. + +#include "sycl.hpp" + +__attribute__((sycl_device)) +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] +void ff_2(int *ptr, int start, int end) { + for (int i = start; i <= end; i++) + ptr[i] = start; +} +// CHECK: FunctionDecl {{.*}}__sycl_kernel_{{.*}} 'void (__global int *, int, int)' +// CHECK-NEXT: ParmVarDecl {{.*}} __arg_ptr '__global int *' +// CHECK-NEXT: ParmVarDecl {{.*}} __arg_start 'int' +// CHECK-NEXT: ParmVarDecl {{.*}} __arg_end 'int' +// CHECK-NEXT: CompoundStmt +// CHECK-NEXT: CallExpr {{.*}} 'void' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(int *, int, int)' +// CHECK-NEXT: DeclRefExpr {{.*}} 'void (int *, int, int)' lvalue Function {{.*}} 'ff_2' 'void (int *, int, int)' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int *' +// CHECK-NEXT: ImplicitCastExpr {{.*}} '__global int *' +// CHECK-NEXT: DeclRefExpr {{.*}} '__global int *' lvalue ParmVar {{.*}} '__arg_ptr' '__global int *' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} '__arg_start' 'int' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} '__arg_end' 'int' + + +// Templated free function definition. 
+template +__attribute__((sycl_device)) +[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]] + void ff_3(T* ptr, T start, int end) { + for (int i = start; i <= end; i++) + ptr[i] = start; +} + +// Explicit instantiation with "int*" +template void ff_3(int* ptr, int start, int end); + +// CHECK: FunctionDecl {{.*}}__sycl_kernel_{{.*}} 'void (__global int *, int, int)' +// CHECK-NEXT: ParmVarDecl {{.*}} __arg_ptr '__global int *' +// CHECK-NEXT: ParmVarDecl {{.*}} __arg_start 'int' +// CHECK-NEXT: ParmVarDecl {{.*}} __arg_end 'int' +// CHECK-NEXT: CompoundStmt +// CHECK-NEXT: CallExpr {{.*}} 'void' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(int *, int, int)' +// CHECK-NEXT: DeclRefExpr {{.*}} 'void (int *, int, int)' lvalue Function {{.*}} 'ff_3' 'void (int *, int, int)' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int *' +// CHECK-NEXT: ImplicitCastExpr {{.*}} '__global int *' +// CHECK-NEXT: DeclRefExpr {{.*}} '__global int *' lvalue ParmVar {{.*}} '__arg_ptr' '__global int *' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} '__arg_start' 'int' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} '__arg_end' 'int' diff --git a/sycl/include/sycl/ext/oneapi/kernel_properties/properties.hpp b/sycl/include/sycl/ext/oneapi/kernel_properties/properties.hpp index 2643fa75fadb1..e46ab88c43172 100644 --- a/sycl/include/sycl/ext/oneapi/kernel_properties/properties.hpp +++ b/sycl/include/sycl/ext/oneapi/kernel_properties/properties.hpp @@ -61,6 +61,16 @@ struct device_has_key std::integral_constant...>; }; +struct nd_range_kernel_key { + template + using value_t = + property_value>; +}; + +struct single_task_kernel_key { + using value_t = property_value; +}; + template struct property_value, std::integral_constant...> { @@ -113,6 +123,21 @@ struct property_value value{Aspects...}; }; +template +struct property_value> { + 
static_assert( + Dims >= 1 && Dims <= 3, + "nd_range_kernel_key property must use dimension of 1, 2 or 3."); + + using key_t = nd_range_kernel_key; + using value_t = int; + static constexpr int dimensions = Dims; +}; + +template <> struct property_value { + using key_t = single_task_kernel_key; +}; + template inline constexpr work_group_size_key::value_t work_group_size; @@ -126,6 +151,11 @@ inline constexpr sub_group_size_key::value_t sub_group_size; template inline constexpr device_has_key::value_t device_has; +template +inline constexpr nd_range_kernel_key::value_t nd_range_kernel; + +inline constexpr single_task_kernel_key::value_t single_task_kernel; + struct work_group_progress_key : detail::compile_time_property_key { template struct is_property_key : std::true_type {}; template <> struct is_property_key : std::true_type {}; namespace detail { + template struct PropertyMetaInfo> { static constexpr const char *name = "sycl-work-group-size"; @@ -230,6 +261,15 @@ struct PropertyMetaInfo> { static constexpr const char *value = SizeListToStr(Aspects)...>::value; }; +template +struct PropertyMetaInfo> { + static constexpr const char *name = "sycl-nd-range-kernel"; + static constexpr int value = Dims; +}; +template <> struct PropertyMetaInfo { + static constexpr const char *name = "sycl-single-task-kernel"; + static constexpr int value = 0; +}; template struct HasKernelPropertiesGetMethod : std::false_type {}; @@ -251,7 +291,6 @@ struct HasKernelPropertiesGetMethod>>::name, \ sycl::ext::oneapi::experimental::detail::PropertyMetaInfo< \ diff --git a/sycl/include/sycl/ext/oneapi/properties/property.hpp b/sycl/include/sycl/ext/oneapi/properties/property.hpp index 5ca6a1d5d4944..89d7dd7852a8a 100644 --- a/sycl/include/sycl/ext/oneapi/properties/property.hpp +++ b/sycl/include/sycl/ext/oneapi/properties/property.hpp @@ -203,8 +203,10 @@ enum PropKind : uint32_t { WorkGroupProgress = 62, SubGroupProgress = 63, WorkItemProgress = 64, + NDRangeKernel = 65, + 
SingleTaskKernel = 66, // PropKindSize must always be the last value. - PropKindSize = 65, + PropKindSize = 67, }; struct property_key_base_tag {}; diff --git a/sycl/source/feature_test.hpp.in b/sycl/source/feature_test.hpp.in index 9a6b72b7d4a48..1a53c9abd0270 100644 --- a/sycl/source/feature_test.hpp.in +++ b/sycl/source/feature_test.hpp.in @@ -105,6 +105,7 @@ inline namespace _V1 { #define SYCL_EXT_INTEL_FPGA_TASK_SEQUENCE 1 #define SYCL_EXT_ONEAPI_PRIVATE_ALLOCA 1 #define SYCL_EXT_ONEAPI_FORWARD_PROGRESS 1 +#define SYCL_EXT_ONEAPI_FREE_FUNCTION_KERNELS 1 #ifndef __has_include #define __has_include(x) 0 diff --git a/sycl/test-e2e/KernelAndProgram/free_function_kernels.cpp b/sycl/test-e2e/KernelAndProgram/free_function_kernels.cpp new file mode 100755 index 0000000000000..d8478d3f56e35 --- /dev/null +++ b/sycl/test-e2e/KernelAndProgram/free_function_kernels.cpp @@ -0,0 +1,277 @@ +// REQUIRES: aspect-usm_shared_allocations +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// The name mangling for free function kernels currently does not work with PTX. +// UNSUPPORTED: cuda + +// This test tests free function kernel code generation and execution. + +#include +#include + +using namespace sycl; + +// Kernel finder. 
+class KernelFinder { + queue &Queue; + std::vector AllKernelIDs; + +public: + KernelFinder(queue &Q) : Queue(Q) { + // Obtain kernel bundle + kernel_bundle Bundle = + get_kernel_bundle(Queue.get_context()); + std::cout << "Bundle obtained\n"; + AllKernelIDs = get_kernel_ids(); + std::cout << "Number of kernels = " << AllKernelIDs.size() << std::endl; + for (auto K : AllKernelIDs) { + std::cout << "Kernel obtained: " << K.get_name() << std::endl; + } + std::cout << std::endl; + } + + kernel get_kernel(const char *name) { + kernel_bundle Bundle = + get_kernel_bundle(Queue.get_context()); + for (auto K : AllKernelIDs) { + auto Kname = K.get_name(); + if (strcmp(name, Kname) == 0) { + kernel Kernel = Bundle.get_kernel(K); + return Kernel; + } + } + std::cout << "No kernel named " << name << " found\n"; + exit(1); + } +}; + +void printUSM(int *usmPtr, int size) { + std::cout << "usmPtr[] = {"; + for (int i = 0; i < size; i++) { + std::cout << usmPtr[i] << ", "; + } + std::cout << "}\n"; +} + +bool checkUSM(int *usmPtr, int size, int *Result) { + bool Pass = true; + for (int i = 0; i < size; i++) { + if (usmPtr[i] != Result[i]) { + Pass = false; + break; + } + } + if (Pass) + return true; + + std::cout << "Expected = {"; + for (int i = 0; i < size; i++) { + std::cout << Result[i] << ", "; + } + std::cout << "}\n"; + std::cout << "Result = {"; + for (int i = 0; i < size; i++) { + std::cout << usmPtr[i] << ", "; + } + std::cout << "}\n"; + return false; +} + +SYCL_EXTERNAL +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( + (ext::oneapi::experimental::single_task_kernel)) +void ff_0(int *ptr, int start, int end) { + for (int i = start; i <= end; i++) + ptr[i] = start + end; +} + +bool test_0(queue Queue, KernelFinder &KF) { + constexpr int Range = 10; + int *usmPtr = malloc_shared(Range, Queue); + int start = 3; + int end = 5; + int Result[Range] = {0, 0, 0, 8, 8, 8, 0, 0, 0, 0}; + range<1> R1{Range}; + + memset(usmPtr, 0, Range * sizeof(int)); + Queue.submit([&](handler &Handler) { + 
Handler.single_task([=]() { + for (int i = start; i <= end; i++) + usmPtr[i] = start + end; + }); + }); + Queue.wait(); + bool PassA = checkUSM(usmPtr, Range, Result); + std::cout << "Test 0a: " << (PassA ? "PASS" : "FAIL") << std::endl; + + kernel Kernel = KF.get_kernel("_Z18__sycl_kernel_ff_0Piii"); + memset(usmPtr, 0, Range * sizeof(int)); + Queue.submit([&](handler &Handler) { + Handler.set_arg(0, usmPtr); + Handler.set_arg(1, start); + Handler.set_arg(2, end); + Handler.single_task(Kernel); + }); + Queue.wait(); + bool PassB = checkUSM(usmPtr, Range, Result); + std::cout << "Test 0b: " << (PassB ? "PASS" : "FAIL") << std::endl; + + free(usmPtr, Queue); + return PassA && PassB; +} + +// Overloaded free function definition. +SYCL_EXTERNAL +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( + (ext::oneapi::experimental::nd_range_kernel<1>)) +void ff_1(int *ptr, int start, int end) { + nd_item<1> Item = ext::oneapi::this_work_item::get_nd_item<1>(); + id<1> GId = Item.get_global_id(); + ptr[GId.get(0)] = GId.get(0) + start + end; +} + +bool test_1(queue Queue, KernelFinder &KF) { + constexpr int Range = 10; + int *usmPtr = malloc_shared(Range, Queue); + int start = 3; + int Result[Range] = {13, 14, 15, 16, 17, 18, 19, 20, 21, 22}; + nd_range<1> R1{{Range}, {1}}; + + memset(usmPtr, 0, Range * sizeof(int)); + Queue.submit([&](handler &Handler) { + Handler.parallel_for(R1, [=](nd_item<1> Item) { + id<1> GId = Item.get_global_id(); + usmPtr[GId.get(0)] = GId.get(0) + start + Range; + }); + }); + Queue.wait(); + bool PassA = checkUSM(usmPtr, Range, Result); + std::cout << "Test 1a: " << (PassA ? 
"PASS" : "FAIL") << std::endl; + + kernel Kernel = KF.get_kernel("_Z18__sycl_kernel_ff_1Piii"); + memset(usmPtr, 0, Range * sizeof(int)); + Queue.submit([&](handler &Handler) { + Handler.set_arg(0, usmPtr); + Handler.set_arg(1, start); + Handler.set_arg(2, Range); + Handler.parallel_for(R1, Kernel); + }); + Queue.wait(); + bool PassB = checkUSM(usmPtr, Range, Result); + std::cout << "Test 1b: " << (PassB ? "PASS" : "FAIL") << std::endl; + + free(usmPtr, Queue); + return PassA && PassB; +} + +// Overloaded free function definition. +SYCL_EXTERNAL +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( + (ext::oneapi::experimental::nd_range_kernel<2>)) +void ff_1(int *ptr, int start) { + int(&ptr2D)[4][4] = *reinterpret_cast(ptr); + nd_item<2> Item = ext::oneapi::this_work_item::get_nd_item<2>(); + id<2> GId = Item.get_global_id(); + id<2> LId = Item.get_local_id(); + ptr2D[GId.get(0)][GId.get(1)] = LId.get(0) + LId.get(1) + start; +} + +bool test_2(queue Queue, KernelFinder &KF) { + constexpr int Range = 16; + int *usmPtr = malloc_shared(Range, Queue); + int value = 55; + int Result[Range] = {55, 56, 55, 56, 56, 57, 56, 57, + 55, 56, 55, 56, 56, 57, 56, 57}; + nd_range<2> R2{range<2>{4, 4}, range<2>{2, 2}}; + + memset(usmPtr, 0, Range * sizeof(int)); + Queue.submit([&](handler &Handler) { + Handler.parallel_for(R2, [=](nd_item<2> Item) { + int(&ptr2D)[4][4] = *reinterpret_cast(usmPtr); + id<2> GId = Item.get_global_id(); + id<2> LId = Item.get_local_id(); + ptr2D[GId.get(0)][GId.get(1)] = LId.get(0) + LId.get(1) + value; + }); + }); + Queue.wait(); + bool PassA = checkUSM(usmPtr, Range, Result); + std::cout << "Test 2a: " << (PassA ? 
"PASS" : "FAIL") << std::endl; + + kernel Kernel = KF.get_kernel("_Z18__sycl_kernel_ff_1Pii"); + memset(usmPtr, 0, Range * sizeof(int)); + Queue.submit([&](handler &Handler) { + Handler.set_arg(0, usmPtr); + Handler.set_arg(1, value); + Handler.parallel_for(R2, Kernel); + }); + Queue.wait(); + bool PassB = checkUSM(usmPtr, Range, Result); + std::cout << "Test 2b: " << (PassB ? "PASS" : "FAIL") << std::endl; + + free(usmPtr, Queue); + return PassA && PassB; +} + +// Templated free function definition. +template +SYCL_EXTERNAL SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(( + ext::oneapi::experimental::nd_range_kernel<2>)) void ff_3(T *ptr, T start) { + int(&ptr2D)[4][4] = *reinterpret_cast(ptr); + nd_item<2> Item = ext::oneapi::this_work_item::get_nd_item<2>(); + id<2> GId = Item.get_global_id(); + id<2> LId = Item.get_local_id(); + ptr2D[GId.get(0)][GId.get(1)] = LId.get(0) + LId.get(1) + start; +} + +// Explicit instantiation with "int*". +template void ff_3(int *ptr, int start); + +bool test_3(queue Queue, KernelFinder &KF) { + constexpr int Range = 16; + int *usmPtr = malloc_shared(Range, Queue); + int value = 55; + int Result[Range] = {55, 56, 55, 56, 56, 57, 56, 57, + 55, 56, 55, 56, 56, 57, 56, 57}; + nd_range<2> R2{range<2>{4, 4}, range<2>{2, 2}}; + + memset(usmPtr, 0, Range * sizeof(int)); + Queue.submit([&](handler &Handler) { + Handler.parallel_for(R2, [=](nd_item<2> Item) { + int(&ptr2D)[4][4] = *reinterpret_cast(usmPtr); + id<2> GId = Item.get_global_id(); + id<2> LId = Item.get_local_id(); + ptr2D[GId.get(0)][GId.get(1)] = LId.get(0) + LId.get(1) + value; + }); + }); + Queue.wait(); + bool PassA = checkUSM(usmPtr, Range, Result); + std::cout << "Test 3a: " << (PassA ?
"PASS" : "FAIL") << std::endl; + + kernel Kernel = KF.get_kernel("_Z18__sycl_kernel_ff_3IiEvPT_S0_"); + memset(usmPtr, 0, Range * sizeof(int)); + Queue.submit([&](handler &Handler) { + Handler.set_arg(0, usmPtr); + Handler.set_arg(1, value); + Handler.parallel_for(R2, Kernel); + }); + Queue.wait(); + bool PassB = checkUSM(usmPtr, Range, Result); + std::cout << "Test 3b: " << (PassB ? "PASS" : "FAIL") << std::endl; + + free(usmPtr, Queue); + return PassA && PassB; +} + +int main() { + queue Queue; + KernelFinder KF(Queue); + + bool Pass = true; + Pass &= test_0(Queue, KF); + Pass &= test_1(Queue, KF); + Pass &= test_2(Queue, KF); + Pass &= test_3(Queue, KF); + + return Pass ? 0 : 1; +}