-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[MLIR][OpenMP] Introduce overlapped record type map support #119588
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/agozillo/declare-target-to-1
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -2979,39 +2979,61 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) { | |||||||
return std::distance(mapData.MapClause.begin(), res); | ||||||||
} | ||||||||
|
||||||||
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, | ||||||||
bool first) { | ||||||||
ArrayAttr indexAttr = mapInfo.getMembersIndexAttr(); | ||||||||
// Only 1 member has been mapped, we can return it. | ||||||||
if (indexAttr.size() == 1) | ||||||||
return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()); | ||||||||
static void sortMapIndices(llvm::SmallVector<size_t> &indices, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
mlir::omp::MapInfoOp mapInfo, | ||||||||
bool ascending = true) { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems a bit overkill to introduce this argument and allow sorting the list in reverse order just so that we can get the first or the last element in |
||||||||
mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr(); | ||||||||
if (indexAttr.empty() || indexAttr.size() == 1 || indices.empty() || | ||||||||
indices.size() == 1) | ||||||||
return; | ||||||||
Comment on lines
+2986
to
+2988
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: I think this isn't necessary. |
||||||||
|
||||||||
llvm::SmallVector<size_t> indices(indexAttr.size()); | ||||||||
std::iota(indices.begin(), indices.end(), 0); | ||||||||
llvm::sort( | ||||||||
indices.begin(), indices.end(), [&](const size_t a, const size_t b) { | ||||||||
auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]); | ||||||||
auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]); | ||||||||
|
||||||||
size_t smallestMember = memberIndicesA.size() < memberIndicesB.size() | ||||||||
? memberIndicesA.size() | ||||||||
: memberIndicesB.size(); | ||||||||
|
||||||||
llvm::sort(indices.begin(), indices.end(), | ||||||||
[&](const size_t a, const size_t b) { | ||||||||
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); | ||||||||
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); | ||||||||
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) { | ||||||||
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt(); | ||||||||
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt(); | ||||||||
for (size_t i = 0; i < smallestMember; ++i) { | ||||||||
Comment on lines
+2995
to
+2999
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: |
||||||||
int64_t aIndex = | ||||||||
mlir::cast<mlir::IntegerAttr>(memberIndicesA.getValue()[i]) | ||||||||
.getInt(); | ||||||||
int64_t bIndex = | ||||||||
mlir::cast<mlir::IntegerAttr>(memberIndicesB.getValue()[i]) | ||||||||
.getInt(); | ||||||||
|
||||||||
if (aIndex == bIndex) | ||||||||
continue; | ||||||||
if (aIndex == bIndex) | ||||||||
continue; | ||||||||
|
||||||||
if (aIndex < bIndex) | ||||||||
return first; | ||||||||
if (aIndex < bIndex) | ||||||||
return ascending; | ||||||||
|
||||||||
if (aIndex > bIndex) | ||||||||
return !first; | ||||||||
} | ||||||||
if (aIndex > bIndex) | ||||||||
return !ascending; | ||||||||
} | ||||||||
|
||||||||
// Iterated the up until the end of the smallest member and | ||||||||
// they were found to be equal up to that point, so select | ||||||||
// the member with the lowest index count, so the "parent" | ||||||||
return memberIndicesA.size() < memberIndicesB.size(); | ||||||||
}); | ||||||||
// Iterated up until the end of the smallest member and | ||||||||
// they were found to be equal up to that point, so select | ||||||||
// the member with the lowest index count, so the "parent" | ||||||||
return memberIndicesA.size() < memberIndicesB.size(); | ||||||||
}); | ||||||||
} | ||||||||
|
||||||||
static mlir::omp::MapInfoOp | ||||||||
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) { | ||||||||
mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr(); | ||||||||
// Only 1 member has been mapped, we can return it. | ||||||||
if (indexAttr.size() == 1) | ||||||||
if (auto mapOp = | ||||||||
dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp())) | ||||||||
Comment on lines
+3029
to
+3030
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me know if I understood this wrong, but it seems like there is nothing preventing the I don't know whether this function can be expected to return |
||||||||
return mapOp; | ||||||||
|
||||||||
llvm::SmallVector<size_t> indices; | ||||||||
indices.resize(indexAttr.size()); | ||||||||
Comment on lines
+3033
to
+3034
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
std::iota(indices.begin(), indices.end(), 0); | ||||||||
sortMapIndices(indices, mapInfo, first); | ||||||||
|
||||||||
return llvm::cast<omp::MapInfoOp>( | ||||||||
mapInfo.getMembers()[indices.front()].getDefiningOp()); | ||||||||
|
@@ -3110,6 +3132,91 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, | |||||||
return idx; | ||||||||
} | ||||||||
|
||||||||
// Gathers members that are overlapping in the parent, excluding members that | ||||||||
// themselves overlap, keeping the top-most (closest to parents level) map. | ||||||||
static void getOverlappedMembers(llvm::SmallVector<size_t> &overlapMapDataIdxs, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
MapInfoData &mapData, | ||||||||
omp::MapInfoOp parentOp) { | ||||||||
// No members mapped, no overlaps. | ||||||||
if (parentOp.getMembers().empty()) | ||||||||
return; | ||||||||
|
||||||||
// Single member, we can insert and return early. | ||||||||
if (parentOp.getMembers().size() == 1) { | ||||||||
overlapMapDataIdxs.push_back(0); | ||||||||
return; | ||||||||
} | ||||||||
|
||||||||
// 1) collect list of top-level overlapping members from MemberOp | ||||||||
llvm::SmallVector<std::pair<int, mlir::ArrayAttr>> memberByIndex; | ||||||||
mlir::ArrayAttr indexAttr = parentOp.getMembersIndexAttr(); | ||||||||
for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr)) | ||||||||
memberByIndex.push_back( | ||||||||
std::make_pair(memIndex, mlir::cast<mlir::ArrayAttr>(indicesAttr))); | ||||||||
|
||||||||
// Sort the smallest first (higher up the parent -> member chain), so that | ||||||||
// when we remove members, we remove as much as we can in the initial | ||||||||
// iterations, shortening the number of passes required. | ||||||||
llvm::sort(memberByIndex.begin(), memberByIndex.end(), | ||||||||
[&](auto a, auto b) { return a.second.size() < b.second.size(); }); | ||||||||
|
||||||||
auto getAsIntegers = [](mlir::ArrayAttr values) { | ||||||||
llvm::SmallVector<int64_t> ints; | ||||||||
ints.reserve(values.size()); | ||||||||
llvm::transform(values, std::back_inserter(ints), | ||||||||
[](mlir::Attribute value) { | ||||||||
return mlir::cast<mlir::IntegerAttr>(value).getInt(); | ||||||||
}); | ||||||||
return ints; | ||||||||
}; | ||||||||
|
||||||||
// Remove elements from the vector if there is a parent element that | ||||||||
// supersedes it. i.e. if member [0] is mapped, we can remove members [0,1], | ||||||||
// [0,2].. etc. | ||||||||
for (auto v : make_early_inc_range(memberByIndex)) { | ||||||||
auto vArr = getAsIntegers(v.second); | ||||||||
memberByIndex.erase( | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we know for sure this always works? Reading the documentation for Maybe it would be safer to just create an integer set of to-be-skipped elements and only add to |
||||||||
std::remove_if(memberByIndex.begin(), memberByIndex.end(), | ||||||||
[&](auto x) { | ||||||||
if (v == x) | ||||||||
return false; | ||||||||
|
||||||||
auto xArr = getAsIntegers(x.second); | ||||||||
return std::equal(vArr.begin(), vArr.end(), | ||||||||
xArr.begin()) && | ||||||||
xArr.size() >= vArr.size(); | ||||||||
}), | ||||||||
memberByIndex.end()); | ||||||||
} | ||||||||
|
||||||||
// Collect the indices from mapData that we need, as we technically need the | ||||||||
// base pointer etc. info, which is stored in there and primarily accessible | ||||||||
// via index at the moment. | ||||||||
for (auto v : memberByIndex) | ||||||||
overlapMapDataIdxs.push_back(v.first); | ||||||||
} | ||||||||
|
||||||||
// The intent is to verify if the mapped data being passed is a | ||||||||
// pointer -> pointee that requires special handling in certain cases, | ||||||||
// e.g. applying the OMP_MAP_PTR_AND_OBJ map type. | ||||||||
// | ||||||||
// There may be a better way to verify this, but unfortunately with | ||||||||
// opaque pointers we lose the ability to easily check if something is | ||||||||
// a pointer whilst maintaining access to the underlying type. | ||||||||
static bool checkIfPointerMap(omp::MapInfoOp mapOp) { | ||||||||
// If we have a varPtrPtr field assigned then the underlying type is a pointer | ||||||||
if (mapOp.getVarPtrPtr()) | ||||||||
return true; | ||||||||
|
||||||||
// If the map data is declare target with a link clause, then it's represented | ||||||||
// as a pointer when we lower it to LLVM-IR even if at the MLIR level it has | ||||||||
// no relation to pointers. | ||||||||
if (isDeclareTargetLink(mapOp.getVarPtr())) | ||||||||
return true; | ||||||||
|
||||||||
return false; | ||||||||
} | ||||||||
|
||||||||
// This creates two insertions into the MapInfosTy data structure for the | ||||||||
// "parent" of a set of members, (usually a container e.g. | ||||||||
// class/structure/derived type) when subsequent members have also been | ||||||||
|
@@ -3150,7 +3257,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( | |||||||
// runtime information on the dynamically allocated data). | ||||||||
auto parentClause = | ||||||||
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]); | ||||||||
|
||||||||
llvm::Value *lowAddr, *highAddr; | ||||||||
if (!parentClause.getPartialMap()) { | ||||||||
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex], | ||||||||
|
@@ -3197,37 +3303,77 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( | |||||||
// what we support as expected. | ||||||||
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex]; | ||||||||
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); | ||||||||
combinedInfo.Types.emplace_back(mapFlag); | ||||||||
combinedInfo.DevicePointers.emplace_back( | ||||||||
llvm::OpenMPIRBuilder::DeviceInfoTy::None); | ||||||||
combinedInfo.Names.emplace_back(LLVM::createMappingInformation( | ||||||||
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); | ||||||||
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]); | ||||||||
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]); | ||||||||
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]); | ||||||||
} | ||||||||
return memberOfFlag; | ||||||||
} | ||||||||
|
||||||||
// The intent is to verify if the mapped data being passed is a | ||||||||
// pointer -> pointee that requires special handling in certain cases, | ||||||||
// e.g. applying the OMP_MAP_PTR_AND_OBJ map type. | ||||||||
// | ||||||||
// There may be a better way to verify this, but unfortunately with | ||||||||
// opaque pointers we lose the ability to easily check if something is | ||||||||
// a pointer whilst maintaining access to the underlying type. | ||||||||
static bool checkIfPointerMap(omp::MapInfoOp mapOp) { | ||||||||
// If we have a varPtrPtr field assigned then the underlying type is a pointer | ||||||||
if (mapOp.getVarPtrPtr()) | ||||||||
return true; | ||||||||
|
||||||||
// If the map data is declare target with a link clause, then it's represented | ||||||||
// as a pointer when we lower it to LLVM-IR even if at the MLIR level it has | ||||||||
// no relation to pointers. | ||||||||
if (isDeclareTargetLink(mapOp.getVarPtr())) | ||||||||
return true; | ||||||||
if (targetDirective == TargetDirective::TargetUpdate) { | ||||||||
combinedInfo.Types.emplace_back(mapFlag); | ||||||||
combinedInfo.DevicePointers.emplace_back( | ||||||||
mapData.DevicePointers[mapDataIndex]); | ||||||||
combinedInfo.Names.emplace_back(LLVM::createMappingInformation( | ||||||||
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); | ||||||||
combinedInfo.BasePointers.emplace_back( | ||||||||
mapData.BasePointers[mapDataIndex]); | ||||||||
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]); | ||||||||
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]); | ||||||||
} else { | ||||||||
llvm::SmallVector<size_t> overlapIdxs; | ||||||||
// Find all of the members that "overlap", i.e. occlude other members that | ||||||||
// were mapped alongside the parent, e.g. member [0], occludes | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: This comment seems to be incomplete. |
||||||||
getOverlappedMembers(overlapIdxs, mapData, parentClause); | ||||||||
// We need to make sure the overlapped members are sorted in order of | ||||||||
// lowest address to highest address | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
sortMapIndices(overlapIdxs, parentClause); | ||||||||
|
||||||||
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex], | ||||||||
builder.getPtrTy()); | ||||||||
highAddr = builder.CreatePointerCast( | ||||||||
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex], | ||||||||
mapData.Pointers[mapDataIndex], 1), | ||||||||
builder.getPtrTy()); | ||||||||
|
||||||||
// TODO: We may want to skip arrays/array sections in this as Clang does | ||||||||
// so it appears to be an optimisation rather than a neccessity though, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
// but this requires further investigation. However, we would have to make | ||||||||
// sure to not exclude maps with bounds that ARE pointers, as these are | ||||||||
// processed as seperate components, i.e. pointer + data. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
for (auto v : overlapIdxs) { | ||||||||
auto mapDataOverlapIdx = getMapDataMemberIdx( | ||||||||
mapData, | ||||||||
cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp())); | ||||||||
combinedInfo.Types.emplace_back(mapFlag); | ||||||||
combinedInfo.DevicePointers.emplace_back( | ||||||||
mapData.DevicePointers[mapDataOverlapIdx]); | ||||||||
combinedInfo.Names.emplace_back(LLVM::createMappingInformation( | ||||||||
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); | ||||||||
combinedInfo.BasePointers.emplace_back( | ||||||||
mapData.BasePointers[mapDataIndex]); | ||||||||
combinedInfo.Pointers.emplace_back(lowAddr); | ||||||||
combinedInfo.Sizes.emplace_back(builder.CreateIntCast( | ||||||||
builder.CreatePtrDiff(builder.getInt8Ty(), | ||||||||
mapData.OriginalValue[mapDataOverlapIdx], | ||||||||
lowAddr), | ||||||||
builder.getInt64Ty(), /*isSigned=*/true)); | ||||||||
lowAddr = builder.CreateConstGEP1_32( | ||||||||
checkIfPointerMap(llvm::cast<omp::MapInfoOp>( | ||||||||
mapData.MapClause[mapDataOverlapIdx])) | ||||||||
? builder.getPtrTy() | ||||||||
: mapData.BaseType[mapDataOverlapIdx], | ||||||||
mapData.BasePointers[mapDataOverlapIdx], 1); | ||||||||
} | ||||||||
|
||||||||
return false; | ||||||||
combinedInfo.Types.emplace_back(mapFlag); | ||||||||
combinedInfo.DevicePointers.emplace_back( | ||||||||
mapData.DevicePointers[mapDataIndex]); | ||||||||
combinedInfo.Names.emplace_back(LLVM::createMappingInformation( | ||||||||
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); | ||||||||
combinedInfo.BasePointers.emplace_back( | ||||||||
mapData.BasePointers[mapDataIndex]); | ||||||||
combinedInfo.Pointers.emplace_back(lowAddr); | ||||||||
combinedInfo.Sizes.emplace_back(builder.CreateIntCast( | ||||||||
builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr), | ||||||||
builder.getInt64Ty(), true)); | ||||||||
} | ||||||||
} | ||||||||
return memberOfFlag; | ||||||||
} | ||||||||
|
||||||||
// This function is intended to add explicit mappings of members | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
General nit for changes in this file: There's a
using namespace mlir
, so we can removemlir::
. Same forllvm::
cast-style functions, which are present in themlir
namespace as well.