[MLIR][OpenMP] Introduce overlapped record type map support #119588

Open
wants to merge 1 commit into
base: users/agozillo/declare-target-to-1
130 changes: 77 additions & 53 deletions flang/test/Integration/OpenMP/map-types-and-sizes.f90

Large diffs are not rendered by default.

260 changes: 203 additions & 57 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2979,39 +2979,61 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
return std::distance(mapData.MapClause.begin(), res);
}

static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
bool first) {
ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// Only 1 member has been mapped, we can return it.
if (indexAttr.size() == 1)
return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
static void sortMapIndices(llvm::SmallVector<size_t> &indices,
Review comment (Member):
General nit for changes in this file: There's a using namespace mlir, so we can remove mlir::. Same for llvm:: cast-style functions, which are present in the mlir namespace as well.

Review comment (Member):
Suggested change
- static void sortMapIndices(llvm::SmallVector<size_t> &indices,
+ static void sortMapIndices(llvm::SmallVectorImpl<size_t> &indices,

mlir::omp::MapInfoOp mapInfo,
bool ascending = true) {
Review comment (Member):
It seems a bit overkill to introduce this argument and allow sorting the list in reverse order just so that we can get the first or the last element in getFirstOrLastMappedMemberPtr. Wouldn't it be simpler to just update the mapInfo.getMembers()[indices.front()].getDefiningOp()); expression to take indices.front() or indices.back() based on the first argument?
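
A minimal sketch of this suggestion, using only the standard library: sort once in ascending order and pick the end of the list from the `first` flag. `IndexPath` and `firstOrLastMember` are hypothetical stand-ins for the members index attribute and `getFirstOrLastMappedMemberPtr`, not code from the patch. Note that the back of an ascending sort may differ from the front of the patch's descending sort when one index path is a prefix of another, because the patch's tie-break always places the shorter path first.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the members index attribute: one index path per
// mapped member, e.g. {0, 1} is member 1 nested inside member 0.
using IndexPath = std::vector<long>;

// Returns the position (into `paths`) of the first or last member in
// ascending index-path order. The sort is always ascending; `first` only
// selects which end of the sorted list is returned.
std::size_t firstOrLastMember(const std::vector<IndexPath> &paths, bool first) {
  assert(!paths.empty() && "expected at least one mapped member");
  std::vector<std::size_t> indices(paths.size());
  std::iota(indices.begin(), indices.end(), 0);
  // operator< on std::vector is lexicographic and orders a path before any
  // longer path that extends it, so a "parent" sorts before its children.
  std::sort(indices.begin(), indices.end(),
            [&](std::size_t a, std::size_t b) { return paths[a] < paths[b]; });
  return first ? indices.front() : indices.back();
}
```

With this shape the `ascending` parameter on the sort helper would no longer be needed.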

mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
if (indexAttr.empty() || indexAttr.size() == 1 || indices.empty() ||
indices.size() == 1)
return;
Review comment (Member) on lines +2986 to +2988:
Nit: I think this isn't necessary. std::sort, on which llvm::sort seems to be based, already returns early in these cases.
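
A small sketch of that point, assuming plain `std::sort` rather than `llvm::sort`: counting comparator invocations shows that no comparisons are needed for an empty or single-element range. `sortAndCountComparisons` is a hypothetical helper written for this illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sorts `values` in ascending order and returns how many times the
// comparator was invoked.
std::size_t sortAndCountComparisons(std::vector<std::size_t> &values) {
  std::size_t calls = 0;
  std::sort(values.begin(), values.end(),
            [&](std::size_t a, std::size_t b) {
              ++calls;
              return a < b;
            });
  return calls;
}
```

Calling it on an empty vector or on a vector with one element leaves the vector untouched, which is why an explicit early return before the sort adds little.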


llvm::SmallVector<size_t> indices(indexAttr.size());
std::iota(indices.begin(), indices.end(), 0);
llvm::sort(
indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);

size_t smallestMember = memberIndicesA.size() < memberIndicesB.size()
? memberIndicesA.size()
: memberIndicesB.size();

llvm::sort(indices.begin(), indices.end(),
[&](const size_t a, const size_t b) {
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
for (size_t i = 0; i < smallestMember; ++i) {
Review comment (Member) on lines +2995 to +2999:
Nit: llvm::zip already iterates as long as both ranges have elements, so it stops at the shortest. I think it's better to use it in this case.
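
The same idea can be sketched without LLVM headers. Here the four-iterator overload of `std::mismatch` plays the role of `llvm::zip`, since it also stops at the end of the shorter range. `IndexPath` and `pathLess` are hypothetical names for this illustration, and the comparator mirrors only the ascending case.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one entry of the members index attribute.
using IndexPath = std::vector<long>;

// Compares two index paths element by element over their common prefix only.
// If the common prefix is identical, the shorter path (the "parent") is
// ordered first.
bool pathLess(const IndexPath &a, const IndexPath &b) {
  // Stops at the first differing element or at the end of the shorter path,
  // so no explicit "smallest size" bound is needed.
  auto its = std::mismatch(a.begin(), a.end(), b.begin(), b.end());
  if (its.first != a.end() && its.second != b.end())
    return *its.first < *its.second;
  return a.size() < b.size();
}
```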

int64_t aIndex =
mlir::cast<mlir::IntegerAttr>(memberIndicesA.getValue()[i])
.getInt();
int64_t bIndex =
mlir::cast<mlir::IntegerAttr>(memberIndicesB.getValue()[i])
.getInt();

if (aIndex == bIndex)
continue;
if (aIndex == bIndex)
continue;

if (aIndex < bIndex)
return first;
if (aIndex < bIndex)
return ascending;

if (aIndex > bIndex)
return !first;
}
if (aIndex > bIndex)
return !ascending;
}

// Iterated the up until the end of the smallest member and
// they were found to be equal up to that point, so select
// the member with the lowest index count, so the "parent"
return memberIndicesA.size() < memberIndicesB.size();
});
// Iterated up until the end of the smallest member and
// they were found to be equal up to that point, so select
// the member with the lowest index count, so the "parent"
return memberIndicesA.size() < memberIndicesB.size();
});
}

static mlir::omp::MapInfoOp
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// Only 1 member has been mapped, we can return it.
if (indexAttr.size() == 1)
if (auto mapOp =
dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()))
Review comment (Member) on lines +3029 to +3030:
Let me know if I understood this wrong, but it seems like there is nothing preventing the llvm::cast call at the end of this function from triggering an assert if a single member was mapped that wasn't defined by an omp.map.info.

I don't know whether this function can be expected to return null, in which case we could replace the cast below with a dyn_cast, or if this check here should be replaced with return cast<omp::MapInfoOp>(...).

return mapOp;

llvm::SmallVector<size_t> indices;
indices.resize(indexAttr.size());
Review comment (Member) on lines +3033 to +3034:
Suggested change
- llvm::SmallVector<size_t> indices;
- indices.resize(indexAttr.size());
+ llvm::SmallVector<size_t> indices(indexAttr.size());

std::iota(indices.begin(), indices.end(), 0);
sortMapIndices(indices, mapInfo, first);

return llvm::cast<omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
@@ -3110,6 +3132,91 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
return idx;
}

// Gathers members that are overlapping in the parent, excluding members that
// themselves overlap, keeping the top-most (closest to parents level) map.
static void getOverlappedMembers(llvm::SmallVector<size_t> &overlapMapDataIdxs,
Review comment (Member):
Suggested change
- static void getOverlappedMembers(llvm::SmallVector<size_t> &overlapMapDataIdxs,
+ static void getOverlappedMembers(llvm::SmallVectorImpl<size_t> &overlapMapDataIdxs,

MapInfoData &mapData,
omp::MapInfoOp parentOp) {
// No members mapped, no overlaps.
if (parentOp.getMembers().empty())
return;

// Single member, we can insert and return early.
if (parentOp.getMembers().size() == 1) {
overlapMapDataIdxs.push_back(0);
return;
}

// 1) collect list of top-level overlapping members from MemberOp
llvm::SmallVector<std::pair<int, mlir::ArrayAttr>> memberByIndex;
mlir::ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
memberByIndex.push_back(
std::make_pair(memIndex, mlir::cast<mlir::ArrayAttr>(indicesAttr)));

// Sort the smallest first (higher up the parent -> member chain), so that
// when we remove members, we remove as much as we can in the initial
// iterations, shortening the number of passes required.
llvm::sort(memberByIndex.begin(), memberByIndex.end(),
[&](auto a, auto b) { return a.second.size() < b.second.size(); });

auto getAsIntegers = [](mlir::ArrayAttr values) {
llvm::SmallVector<int64_t> ints;
ints.reserve(values.size());
llvm::transform(values, std::back_inserter(ints),
[](mlir::Attribute value) {
return mlir::cast<mlir::IntegerAttr>(value).getInt();
});
return ints;
};

// Remove elements from the vector if there is a parent element that
// supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
// [0,2].. etc.
for (auto v : make_early_inc_range(memberByIndex)) {
auto vArr = getAsIntegers(v.second);
memberByIndex.erase(
Review comment (Member):
Do we know for sure this always works? Reading the documentation for make_early_inc_range, my understanding is that we're allowed to mutate the underlying range as long as we don't invalidate the next iterator. But, if we try to delete elements which could be anywhere in the range, it seems possible that we would end up doing just that.

Maybe it would be safer to just create an integer set of to-be-skipped elements and only add to overlapMapDataIdxs elements in memberByIndex which are not part of that set.
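
A minimal sketch of the skip-set alternative, written against the standard library only. `IndexPath`, `isStrictPrefix`, and `topLevelMembers` are hypothetical names, and the sketch assumes that a member is dropped only when another member's index path is a strict prefix of its own (so duplicate paths are kept, unlike the `>=` check in the patch). The size check runs before `std::equal`, so the comparison never reads past the end of the shorter path.

```cpp
#include <algorithm>
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical stand-in for one entry of the members index attribute.
using IndexPath = std::vector<long>;

// True if `parent` is a strict prefix of `child`, i.e. mapping `parent`
// already covers `child`.
bool isStrictPrefix(const IndexPath &parent, const IndexPath &child) {
  return parent.size() < child.size() &&
         std::equal(parent.begin(), parent.end(), child.begin());
}

// Returns the positions of members that are not covered by another mapped
// member. Nothing is erased while iterating: covered positions are recorded
// in a set and filtered out in a second pass.
std::vector<std::size_t> topLevelMembers(const std::vector<IndexPath> &paths) {
  std::set<std::size_t> skipped;
  for (std::size_t i = 0; i < paths.size(); ++i)
    for (std::size_t j = 0; j < paths.size(); ++j)
      if (i != j && isStrictPrefix(paths[i], paths[j]))
        skipped.insert(j);

  std::vector<std::size_t> kept;
  for (std::size_t i = 0; i < paths.size(); ++i)
    if (skipped.count(i) == 0)
      kept.push_back(i);
  return kept;
}
```

For example, with paths `{0,1}`, `{0}`, `{1,2}`, `{0,2}`, `{2}`, member `{0}` covers `{0,1}` and `{0,2}`, so the remaining positions are 1, 2, and 4. The quadratic scan is a simplification; the pre-sort by path length in the patch is not needed for correctness in this form.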

std::remove_if(memberByIndex.begin(), memberByIndex.end(),
[&](auto x) {
if (v == x)
return false;

auto xArr = getAsIntegers(x.second);
return std::equal(vArr.begin(), vArr.end(),
xArr.begin()) &&
xArr.size() >= vArr.size();
}),
memberByIndex.end());
}

// Collect the indices from mapData that we need, as we technically need the
// base pointer etc. info, which is stored in there and primarily accessible
// via index at the moment.
for (auto v : memberByIndex)
overlapMapDataIdxs.push_back(v.first);
}

// The intent is to verify if the mapped data being passed is a
// pointer -> pointee that requires special handling in certain cases,
// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
//
// There may be a better way to verify this, but unfortunately with
// opaque pointers we lose the ability to easily check if something is
// a pointer whilst maintaining access to the underlying type.
static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
// If we have a varPtrPtr field assigned then the underlying type is a pointer
if (mapOp.getVarPtrPtr())
return true;

// If the map data is declare target with a link clause, then it's represented
// as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
// no relation to pointers.
if (isDeclareTargetLink(mapOp.getVarPtr()))
return true;

return false;
}

// This creates two insertions into the MapInfosTy data structure for the
// "parent" of a set of members, (usually a container e.g.
// class/structure/derived type) when subsequent members have also been
@@ -3150,7 +3257,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
// runtime information on the dynamically allocated data).
auto parentClause =
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

llvm::Value *lowAddr, *highAddr;
if (!parentClause.getPartialMap()) {
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
@@ -3197,37 +3303,77 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
// what we support as expected.
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
}
return memberOfFlag;
}

// The intent is to verify if the mapped data being passed is a
// pointer -> pointee that requires special handling in certain cases,
// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
//
// There may be a better way to verify this, but unfortunately with
// opaque pointers we lose the ability to easily check if something is
// a pointer whilst maintaining access to the underlying type.
static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
// If we have a varPtrPtr field assigned then the underlying type is a pointer
if (mapOp.getVarPtrPtr())
return true;

// If the map data is declare target with a link clause, then it's represented
// as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
// no relation to pointers.
if (isDeclareTargetLink(mapOp.getVarPtr()))
return true;
if (targetDirective == TargetDirective::TargetUpdate) {
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
mapData.DevicePointers[mapDataIndex]);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(
mapData.BasePointers[mapDataIndex]);
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
} else {
llvm::SmallVector<size_t> overlapIdxs;
// Find all of the members that "overlap", i.e. occlude other members that
// were mapped alongside the parent, e.g. member [0], occludes
Review comment (Member):
Nit: This comment seems to be incomplete.

getOverlappedMembers(overlapIdxs, mapData, parentClause);
// We need to make sure the overlapped members are sorted in order of
// lowest address to highest address
Review comment (Member):
Suggested change
- // lowest address to highest address
+ // lowest address to highest address.

sortMapIndices(overlapIdxs, parentClause);

lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
builder.getPtrTy());
highAddr = builder.CreatePointerCast(
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
mapData.Pointers[mapDataIndex], 1),
builder.getPtrTy());

// TODO: We may want to skip arrays/array sections in this as Clang does
// so it appears to be an optimisation rather than a neccessity though,
Review comment (Member):
Suggested change
- // so it appears to be an optimisation rather than a neccessity though,
+ // so. It appears to be an optimisation rather than a necessity though,

// but this requires further investigation. However, we would have to make
// sure to not exclude maps with bounds that ARE pointers, as these are
// processed as seperate components, i.e. pointer + data.
Review comment (Member):
Suggested change
- // processed as seperate components, i.e. pointer + data.
+ // processed as separate components, i.e. pointer + data.

for (auto v : overlapIdxs) {
auto mapDataOverlapIdx = getMapDataMemberIdx(
mapData,
cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
mapData.DevicePointers[mapDataOverlapIdx]);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(
mapData.BasePointers[mapDataIndex]);
combinedInfo.Pointers.emplace_back(lowAddr);
combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
builder.CreatePtrDiff(builder.getInt8Ty(),
mapData.OriginalValue[mapDataOverlapIdx],
lowAddr),
builder.getInt64Ty(), /*isSigned=*/true));
lowAddr = builder.CreateConstGEP1_32(
checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
mapData.MapClause[mapDataOverlapIdx]))
? builder.getPtrTy()
: mapData.BaseType[mapDataOverlapIdx],
mapData.BasePointers[mapDataOverlapIdx], 1);
}

return false;
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
mapData.DevicePointers[mapDataIndex]);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(
mapData.BasePointers[mapDataIndex]);
combinedInfo.Pointers.emplace_back(lowAddr);
combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
builder.getInt64Ty(), true));
}
}
return memberOfFlag;
}
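
The non-TargetUpdate branch above walks the overlapped members from lowest to highest address and emits one map entry for each stretch of the parent that lies between them, plus a final entry up to the end of the parent. A simplified sketch of that arithmetic, assuming plain byte offsets instead of LLVM values: `Range` and `parentGaps` are hypothetical names, and members are assumed to be sorted and non-overlapping.

```cpp
#include <cstddef>
#include <vector>

// Half-open byte range [begin, end) relative to the start of the parent.
struct Range {
  std::size_t begin;
  std::size_t end;
};

// Splits [0, parentSize) into the stretches not covered by `members`.
// One range is produced per member (possibly empty) plus one trailing range,
// mirroring the one-entry-per-overlap structure of the loop above.
std::vector<Range> parentGaps(std::size_t parentSize,
                              const std::vector<Range> &members) {
  std::vector<Range> gaps;
  std::size_t low = 0; // plays the role of lowAddr
  for (const Range &member : members) {
    gaps.push_back({low, member.begin}); // size = member start - lowAddr
    low = member.end;                    // continue just past the member
  }
  gaps.push_back({low, parentSize}); // size = highAddr - lowAddr
  return gaps;
}
```

For a 24-byte parent with one overlapped member occupying bytes 8 to 16, this yields the ranges [0, 8) and [16, 24).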

// This function is intended to add explicit mappings of members
20 changes: 10 additions & 10 deletions mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir
@@ -67,18 +67,18 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a

// CHECK: define void @mix_use_device_ptr_and_addr_and_map_(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) {
// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
// CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_0_GEP]], align 8
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
// CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8
// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8
// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8

// CHECK: call void @__tgt_target_data_begin_mapper({{.*}})
// CHECK: %[[LOAD_BASEPTR_0:.*]] = load ptr, ptr %[[BASEPTR_0_GEP]], align 8
// store ptr %[[LOAD_BASEPTR_0]], ptr %[[ALLOCA]], align 8
// CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8
// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8
// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8
// CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0
// CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0
// CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4
@@ -93,17 +93,17 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a

// CHECK: define void @mix_use_device_ptr_and_addr_and_map_2(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) {
// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
// CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_1_GEP]], align 8
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
// CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8
// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8
// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8
// CHECK: call void @__tgt_target_data_begin_mapper({{.*}})
// CHECK: %[[LOAD_BASEPTR_1:.*]] = load ptr, ptr %[[BASEPTR_1_GEP]], align 8
// store ptr %[[LOAD_BASEPTR_1]], ptr %[[ALLOCA]], align 8
// CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8
// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8
// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8
// CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0
// CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0
// CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4