Skip to content

[SYCL] Add support for __registered_kernels__ #16485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -2141,6 +2141,22 @@ def SYCLAddIRAnnotationsMember : InheritableAttr {
let Documentation = [SYCLAddIRAnnotationsMemberDocs];
}

def SYCLRegisteredKernels : InheritableAttr {
let Spellings = [CXX11<"__sycl_detail__", "__registered_kernels__">];
let Args = [VariadicExprArgument<"Args">];
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
let Subjects = SubjectList<[Empty], ErrorDiag, "Translation Unit Scope">;
let AdditionalMembers = SYCLAddIRAttrCommonMembers.MemberCode;
let Documentation = [SYCLAddIRAnnotationsMemberDocs];
}

def SYCLRegisteredKernelName : InheritableAttr {
let Spellings = [];
let Subjects = SubjectList<[Function]>;
let Args = [StringArgument<"RegName">];
let Documentation = [InternalOnly];
}

def C11NoReturn : InheritableAttr {
let Spellings = [CustomKeyword<"_Noreturn">];
let Subjects = SubjectList<[Function], ErrorDiag>;
Expand Down
14 changes: 14 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12514,6 +12514,20 @@ def err_sycl_special_type_num_init_method : Error<
def warn_launch_bounds_is_cuda_specific : Warning<
"%0 attribute ignored, only applicable when targeting Nvidia devices">,
InGroup<IgnoredAttributes>;
def err_registered_kernels_num_of_args : Error<
"'__registered_kernels__' attribute must have at least one argument">;
def err_registered_kernels_init_list : Error<
"argument to the '__registered_kernels__' attribute must be an "
"initializer list expression">;
def err_registered_kernels_init_list_pair_values : Error<
"each initializer list argument to the '__registered_kernels__' attribute "
"must contain a pair of values">;
def err_registered_kernels_resolve_function : Error<
"unable to resolve free function kernel '%0'">;
def err_registered_kernels_name_already_registered : Error<
"free function kernel has already been registered with '%0'; cannot register with '%1'">;
def err_not_sycl_free_function : Error<
"attempting to register a function that is not a SYCL free function as '%0'">;

def warn_cuda_maxclusterrank_sm_90 : Warning<
"'maxclusterrank' requires sm_90 or higher, CUDA arch provided: %0, ignoring "
Expand Down
14 changes: 12 additions & 2 deletions clang/include/clang/Sema/SemaSYCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,9 @@ class SemaSYCL : public SemaBase {
// We need to store the list of the sycl_kernel functions and their associated
// generated OpenCL Kernels so we can go back and re-name these after the
// fact.
llvm::SmallVector<std::pair<const FunctionDecl *, FunctionDecl *>>
SyclKernelsToOpenCLKernels;
using KernelFDPairs =
llvm::SmallVector<std::pair<const FunctionDecl *, FunctionDecl *>>;
KernelFDPairs SyclKernelsToOpenCLKernels;

// Used to suppress diagnostics during kernel construction, since these were
// already emitted earlier. Diagnosing during Kernel emissions also skips the
Expand Down Expand Up @@ -296,11 +297,15 @@ class SemaSYCL : public SemaBase {
llvm::DenseSet<QualType> Visited,
ValueDecl *DeclToCheck);

const KernelFDPairs &getKernelFDPairs() { return SyclKernelsToOpenCLKernels; }

void addSyclOpenCLKernel(const FunctionDecl *SyclKernel,
FunctionDecl *OpenCLKernel) {
SyclKernelsToOpenCLKernels.emplace_back(SyclKernel, OpenCLKernel);
}

void constructFreeFunctionKernel(FunctionDecl *FD, StringRef NameStr = "");

void addSyclDeviceDecl(Decl *d) { SyclDeviceDecls.insert(d); }
llvm::SetVector<Decl *> &syclDeviceDecls() { return SyclDeviceDecls; }

Expand Down Expand Up @@ -480,6 +485,7 @@ class SemaSYCL : public SemaBase {
void handleSYCLIntelMaxWorkGroupsPerMultiprocessor(Decl *D,
const ParsedAttr &AL);
void handleSYCLScopeAttr(Decl *D, const ParsedAttr &AL);
void handleSYCLRegisteredKernels(Decl *D, const ParsedAttr &AL);

void checkSYCLAddIRAttributesFunctionAttrConflicts(Decl *D);

Expand Down Expand Up @@ -655,6 +661,10 @@ class SemaSYCL : public SemaBase {
void addIntelReqdSubGroupSizeAttr(Decl *D, const AttributeCommonInfo &CI,
Expr *E);
void handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL);

// Used to check whether the function represented by FD is a SYCL
// free function kernel or not.
bool isFreeFunction(const FunctionDecl *FD);
};

} // namespace clang
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,12 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,

llvm::LLVMContext &Context = getLLVMContext();

if (getLangOpts().SYCLIsDevice)
if (FD->hasAttr<SYCLRegisteredKernelNameAttr>())
CGM.SYCLAddRegKernelNamePairs(
FD->getAttr<SYCLRegisteredKernelNameAttr>()->getRegName(),
FD->getNameAsString());

if (FD->hasAttr<OpenCLKernelAttr>() || FD->hasAttr<CUDAGlobalAttr>())
CGM.GenKernelArgMetadata(Fn, FD, this);

Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,19 @@ void CodeGenModule::Release() {
AspectEnumValsMD->addOperand(
getAspectEnumValueMD(Context, TheModule.getContext(), ECD));
}

if (!SYCLRegKernelNames.empty()) {
std::vector<llvm::Metadata *> Nodes;
llvm::LLVMContext &Ctx = TheModule.getContext();
for (auto MDKernelNames : SYCLRegKernelNames) {
llvm::Metadata *Vals[] = {MDKernelNames.first, MDKernelNames.second};
Nodes.push_back(llvm::MDTuple::get(Ctx, Vals));
}

llvm::NamedMDNode *SYCLRegKernelsMD =
TheModule.getOrInsertNamedMetadata("sycl_registered_kernels");
SYCLRegKernelsMD->addOperand(llvm::MDNode::get(Ctx, Nodes));
}
}

// HLSL related end of code gen work items.
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ class CodeGenModule : public CodeGenTypeCache {
/// handled differently than regular annotations so they cannot share map.
llvm::DenseMap<unsigned, llvm::Constant *> SYCLAnnotationArgs;

typedef std::pair<llvm::Metadata *, llvm::Metadata *> MetadataPair;
SmallVector<MetadataPair, 4> SYCLRegKernelNames;

llvm::StringMap<llvm::GlobalVariable *> CFConstantStringMap;

llvm::DenseMap<llvm::Constant *, llvm::GlobalVariable *> ConstantStringMap;
Expand Down Expand Up @@ -1483,6 +1486,12 @@ class CodeGenModule : public CodeGenTypeCache {
llvm::Constant *EmitSYCLAnnotationArgs(
SmallVectorImpl<std::pair<std::string, std::string>> &Pairs);

void SYCLAddRegKernelNamePairs(StringRef First, StringRef Second) {
SYCLRegKernelNames.push_back(
std::make_pair(llvm::MDString::get(getLLVMContext(), First),
llvm::MDString::get(getLLVMContext(), Second)));
}

/// Add attributes from add_ir_attributes_global_variable on TND to GV.
void AddGlobalSYCLIRAttributes(llvm::GlobalVariable *GV,
const RecordDecl *RD);
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7424,6 +7424,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_SYCLAddIRAnnotationsMember:
S.SYCL().handleSYCLAddIRAnnotationsMemberAttr(D, AL);
break;
case ParsedAttr::AT_SYCLRegisteredKernels:
S.SYCL().handleSYCLRegisteredKernels(D, AL);
break;

// Swift attributes.
case ParsedAttr::AT_SwiftAsyncName:
Expand Down
78 changes: 60 additions & 18 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1148,10 +1148,10 @@ static target getAccessTarget(QualType FieldTy,

// FIXME: Free functions must have void return type and be declared at file
// scope, outside any namespaces.
static bool isFreeFunction(SemaSYCL &SemaSYCLRef, const FunctionDecl *FD) {
bool SemaSYCL::isFreeFunction(const FunctionDecl *FD) {
for (auto *IRAttr : FD->specific_attrs<SYCLAddIRAttributesFunctionAttr>()) {
SmallVector<std::pair<std::string, std::string>, 4> NameValuePairs =
IRAttr->getAttributeNameValuePairs(SemaSYCLRef.getASTContext());
IRAttr->getAttributeNameValuePairs(getASTContext());
for (const auto &NameValuePair : NameValuePairs) {
if (NameValuePair.first == "sycl-nd-range-kernel" ||
NameValuePair.first == "sycl-single-task-kernel") {
Expand Down Expand Up @@ -5291,7 +5291,7 @@ void SemaSYCL::SetSYCLKernelNames() {
SyclKernelsToOpenCLKernels) {
std::string CalculatedName, StableName;
StringRef KernelName;
if (isFreeFunction(*this, Pair.first)) {
if (isFreeFunction(Pair.first)) {
std::tie(CalculatedName, StableName) =
constructFreeFunctionKernelName(*this, Pair.first, *MangleCtx);
KernelName = CalculatedName;
Expand Down Expand Up @@ -5414,24 +5414,66 @@ void SemaSYCL::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
}
}

void ConstructFreeFunctionKernel(SemaSYCL &SemaSYCLRef, FunctionDecl *FD) {
SyclKernelArgsSizeChecker argsSizeChecker(SemaSYCLRef, FD->getLocation(),
static void addRegisteredKernelName(SemaSYCL &S, StringRef Str,
FunctionDecl *FD, SourceLocation Loc) {
if (!Str.empty())
FD->addAttr(SYCLRegisteredKernelNameAttr::CreateImplicit(S.getASTContext(),
Str, Loc));
}

static bool checkAndAddRegisteredKernelName(SemaSYCL &S, FunctionDecl *FD,
StringRef Str) {
using KernelPair = std::pair<const FunctionDecl *, FunctionDecl *>;
for (const KernelPair &Pair : S.getKernelFDPairs()) {
if (Pair.first == FD) {
// If the current list of free function entries already contains this
// free function, apply the name Str as an attribute. But if it already
// has an attribute name, issue a diagnostic instead.
if (!Str.empty()) {
if (!Pair.second->hasAttr<SYCLRegisteredKernelNameAttr>())
addRegisteredKernelName(S, Str, Pair.second, FD->getLocation());
else
S.Diag(FD->getLocation(),
diag::err_registered_kernels_name_already_registered)
<< Pair.second->getAttr<SYCLRegisteredKernelNameAttr>()
->getRegName()
<< Str;
}
// An empty name string implies a regular free kernel construction
// call, so simply return.
return false;
}
}
return true;
}

void SemaSYCL::constructFreeFunctionKernel(FunctionDecl *FD,
StringRef NameStr) {
if (!checkAndAddRegisteredKernelName(*this, FD, NameStr))
return;

SyclKernelArgsSizeChecker argsSizeChecker(*this, FD->getLocation(),
false /*IsSIMDKernel*/);
SyclKernelDeclCreator kernel_decl(SemaSYCLRef, FD->getLocation(),
FD->isInlined(), false /*IsSIMDKernel */,
FD);
SyclKernelDeclCreator kernel_decl(*this, FD->getLocation(), FD->isInlined(),
false /*IsSIMDKernel */, FD);

FreeFunctionKernelBodyCreator kernel_body(SemaSYCLRef, kernel_decl, FD);
FreeFunctionKernelBodyCreator kernel_body(*this, kernel_decl, FD);

SyclKernelIntHeaderCreator int_header(
SemaSYCLRef, SemaSYCLRef.getSyclIntegrationHeader(), FD->getType(), FD);
SyclKernelIntHeaderCreator int_header(*this, getSyclIntegrationHeader(),
FD->getType(), FD);

SyclKernelIntFooterCreator int_footer(SemaSYCLRef,
SemaSYCLRef.getSyclIntegrationFooter());
KernelObjVisitor Visitor{SemaSYCLRef};
SyclKernelIntFooterCreator int_footer(*this, getSyclIntegrationFooter());
KernelObjVisitor Visitor{*this};

Visitor.VisitFunctionParameters(FD, argsSizeChecker, kernel_decl, kernel_body,
int_header, int_footer);

assert(getKernelFDPairs().back().first == FD &&
"OpenCL Kernel not found for free function entry");
// Register the kernel name with the OpenCL kernel generated for the
// free function.
addRegisteredKernelName(*this, NameStr, getKernelFDPairs().back().second,
FD->getLocation());
}

// Figure out the sub-group for the this function. First we check the
Expand Down Expand Up @@ -5717,7 +5759,7 @@ void SemaSYCL::MarkDevices() {
}

void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
if (isFreeFunction(*this, FD)) {
if (isFreeFunction(FD)) {
SyclKernelDecompMarker DecompMarker(*this);
SyclKernelFieldChecker FieldChecker(*this);
SyclKernelUnionChecker UnionChecker(*this);
Expand All @@ -5736,7 +5778,7 @@ void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
if (!FieldChecker.isValid() || !UnionChecker.isValid())
return;

ConstructFreeFunctionKernel(*this, FD);
constructFreeFunctionKernel(FD);
}
}

Expand Down Expand Up @@ -6621,7 +6663,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
unsigned ShimCounter = 1;
int FreeFunctionCount = 0;
for (const KernelDesc &K : KernelDescs) {
if (!isFreeFunction(S, K.SyclKernel))
if (!S.isFreeFunction(K.SyclKernel))
continue;
++FreeFunctionCount;
// Generate forward declaration for free function.
Expand Down Expand Up @@ -6739,7 +6781,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
}
ShimCounter = 1;
for (const KernelDesc &K : KernelDescs) {
if (!isFreeFunction(S, K.SyclKernel))
if (!S.isFreeFunction(K.SyclKernel))
continue;

O << "\n// Definition of kernel_id of " << K.Name << "\n";
Expand Down
65 changes: 65 additions & 0 deletions clang/lib/Sema/SemaSYCLDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3162,3 +3162,68 @@ void SemaSYCL::checkSYCLAddIRAttributesFunctionAttrConflicts(Decl *D) {
Diag(Attr->getLoc(), diag::warn_sycl_old_and_new_kernel_attributes)
<< Attr;
}

void SemaSYCL::handleSYCLRegisteredKernels(Decl *D, const ParsedAttr &A) {
// Check for SYCL device compilation context.
if (!getLangOpts().SYCLIsDevice)
return;

unsigned NumArgs = A.getNumArgs();
// When declared, we expect at least one item in the list.
if (NumArgs == 0) {
Diag(A.getLoc(), diag::err_registered_kernels_num_of_args);
return;
}

// Traverse through the items in the list.
for (unsigned I = 0; I < NumArgs; I++) {
assert(A.isArgExpr(I) && "Expected expression argument");
// Each item in the list must be an initializer list expression.
Expr *ArgExpr = A.getArgAsExpr(I);
if (!isa<InitListExpr>(ArgExpr)) {
Diag(ArgExpr->getExprLoc(), diag::err_registered_kernels_init_list);
return;
}

auto *ArgListE = cast<InitListExpr>(ArgExpr);
unsigned NumInits = ArgListE->getNumInits();
// Each init-list expression must have a pair of values.
if (NumInits != 2) {
Diag(ArgExpr->getExprLoc(),
diag::err_registered_kernels_init_list_pair_values);
return;
}

// The first value of the pair must be a string.
Expr *FirstExpr = ArgListE->getInit(0);
StringRef CurStr;
SourceLocation Loc = FirstExpr->getExprLoc();
if (!SemaRef.checkStringLiteralArgumentAttr(A, FirstExpr, CurStr, &Loc))
return;

// Resolve the FunctionDecl from the second value of the pair.
Expr *SecondE = ArgListE->getInit(1);
FunctionDecl *FD = nullptr;
if (auto *ULE = dyn_cast<UnresolvedLookupExpr>(SecondE)) {
FD = SemaRef.ResolveSingleFunctionTemplateSpecialization(ULE, true);
Loc = ULE->getExprLoc();
} else {
SecondE = SecondE->IgnoreParenCasts();
if (auto *DRE = dyn_cast<DeclRefExpr>(SecondE))
FD = dyn_cast<FunctionDecl>(DRE->getDecl());
Loc = SecondE->getExprLoc();
}
// Issue a diagnostic if we are unable to resolve the FunctionDecl.
if (!FD) {
Diag(Loc, diag::err_registered_kernels_resolve_function) << CurStr;
return;
}
// Issue a diagnostic is the FunctionDecl is not a SYCL free function.
if (!isFreeFunction(FD)) {
Diag(FD->getLocation(), diag::err_not_sycl_free_function) << CurStr;
return;
}
// Construct a free function kernel.
constructFreeFunctionKernel(FD, CurStr);
}
}
Loading
Loading