Skip to content

Commit d4bb985

Browse files
authored
Merge pull request #9771 from yosefe/topic/ucp-mm-check-registration-flags-after-taking
UCP/MM: Check registration flags after taking memh from rcache
2 parents 34d9a6f + 6c849cf commit d4bb985

File tree

4 files changed

+132
-30
lines changed

4 files changed

+132
-30
lines changed

src/ucp/core/ucp_mm.c

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,
362362

363363
memh->uct[md_index] = NULL;
364364
}
365+
memh->md_map &= ~md_map;
365366

366367
ucs_assert(comp.count == 1);
367368
}
@@ -473,7 +474,8 @@ ucp_memh_register_internal(ucp_context_h context, ucp_mem_h memh,
473474
uct_flags |= UCT_MD_MEM_FLAG_NONBLOCK;
474475
}
475476

476-
reg_params.flags = uct_flags;
477+
/* When adding registrations, existing access flags must be supported */
478+
reg_params.flags = uct_flags | memh->uct_flags;
477479
reg_params.dmabuf_fd = UCT_DMABUF_FD_INVALID;
478480
reg_params.dmabuf_offset = 0;
479481

@@ -576,27 +578,38 @@ static size_t ucp_memh_size(ucp_context_h context)
576578
return sizeof(ucp_mem_t) + (sizeof(uct_mem_h) * context->num_mds);
577579
}
578580

579-
static void ucp_memh_set(ucp_mem_h memh, ucp_context_h context, void* address,
580-
size_t length, ucs_memory_type_t mem_type,
581-
uint8_t memh_flags, uct_alloc_method_t method)
581+
static void ucp_memh_set_uct_flags(ucp_mem_h memh, unsigned uct_flags)
582+
{
583+
/* When changing memh->uct_flags, must not have any existing registrations,
584+
since those may not support the new flags */
585+
ucs_assertv(memh->md_map == 0,
586+
"memh=%p memh->md_map=0x%" PRIx64
587+
" memh->uct_flags=0x%x uct_flags=0x%x",
588+
memh, memh->md_map, memh->uct_flags, uct_flags);
589+
memh->uct_flags = uct_flags & UCP_MM_UCT_ACCESS_MASK;
590+
}
591+
592+
static void ucp_memh_init(ucp_mem_h memh, ucp_context_h context,
593+
uint8_t memh_flags, unsigned uct_flags,
594+
uct_alloc_method_t method, ucs_memory_type_t mem_type)
582595
{
583596
ucp_memory_info_t info;
584597

585-
ucp_memory_detect(context, address, length, &info);
586-
memh->super.super.start = (uintptr_t)address;
587-
memh->super.super.end = (uintptr_t)address + length;
588-
memh->flags = memh_flags;
598+
ucp_memory_detect(context, ucp_memh_address(memh), ucp_memh_length(memh),
599+
&info);
600+
ucp_memh_set_uct_flags(memh, uct_flags);
589601
memh->context = context;
602+
memh->flags = memh_flags;
603+
memh->alloc_md_index = UCP_NULL_RESOURCE;
604+
memh->alloc_method = method;
590605
memh->mem_type = mem_type;
591606
memh->sys_dev = info.sys_dev;
592-
memh->alloc_method = method;
593-
memh->alloc_md_index = UCP_NULL_RESOURCE;
594607
}
595608

596609
static ucs_status_t
597610
ucp_memh_create(ucp_context_h context, void *address, size_t length,
598611
ucs_memory_type_t mem_type, uct_alloc_method_t method,
599-
uint8_t memh_flags, ucp_mem_h *memh_p)
612+
uint8_t memh_flags, unsigned uct_flags, ucp_mem_h *memh_p)
600613
{
601614
ucp_mem_h memh;
602615

@@ -605,7 +618,9 @@ ucp_memh_create(ucp_context_h context, void *address, size_t length,
605618
return UCS_ERR_NO_MEMORY;
606619
}
607620

608-
ucp_memh_set(memh, context, address, length, mem_type, memh_flags, method);
621+
memh->super.super.start = (uintptr_t)address;
622+
memh->super.super.end = (uintptr_t)address + length;
623+
ucp_memh_init(memh, context, memh_flags, uct_flags, method, mem_type);
609624

610625
*memh_p = memh;
611626
return UCS_OK;
@@ -658,13 +673,14 @@ static ucp_md_index_t ucp_mem_get_md_index(ucp_context_h context,
658673

659674
static ucs_status_t ucp_memh_create_from_mem(ucp_context_h context,
660675
const uct_allocated_memory_t *mem,
676+
unsigned uct_flags,
661677
ucp_mem_h *memh_p)
662678
{
663679
ucs_status_t status;
664680
ucp_mem_h memh;
665681

666682
status = ucp_memh_create(context, mem->address, mem->length, mem->mem_type,
667-
mem->method, 0, &memh);
683+
mem->method, 0, uct_flags, &memh);
668684
if (status != UCS_OK) {
669685
return status;
670686
}
@@ -787,21 +803,30 @@ ucs_status_t ucp_memh_get_slow(ucp_context_h context, void *address,
787803
UCP_THREAD_CS_ENTER(&context->mt_lock);
788804
if (context->rcache == NULL) {
789805
status = ucp_memh_create(context, reg_address, reg_length, mem_type,
790-
UCT_ALLOC_METHOD_LAST, 0, &memh);
806+
UCT_ALLOC_METHOD_LAST, 0, uct_flags, &memh);
807+
if (status != UCS_OK) {
808+
goto out;
809+
}
791810
} else {
792811
status = ucp_memh_rcache_get(context->rcache, reg_address, reg_length,
793812
reg_align, mem_type, reg_md_map, uct_flags,
794813
alloc_name, &memh);
814+
if (status != UCS_OK) {
815+
goto out;
816+
}
817+
818+
if (!ucs_test_all_flags(memh->uct_flags,
819+
uct_flags & UCP_MM_UCT_ACCESS_MASK)) {
820+
reg_md_map |= memh->md_map; /* Re-register previous MDs */
821+
ucp_memh_dereg(context, memh, memh->md_map);
822+
ucp_memh_set_uct_flags(memh, uct_flags);
823+
}
795824

796825
ucs_assert(memh->mem_type == mem_type);
797826
ucs_assert(ucs_padding((intptr_t)ucp_memh_address(memh), reg_align) == 0);
798827
ucs_assert(ucs_padding(ucp_memh_length(memh), reg_align) == 0);
799828
}
800829

801-
if (status != UCS_OK) {
802-
goto out;
803-
}
804-
805830
ucs_trace(
806831
"memh_get_slow: %s address %p/%p length %zu/%zu %s md_map %" PRIx64
807832
" flags 0x%x",
@@ -847,7 +872,7 @@ ucp_memh_alloc(ucp_context_h context, void *address, size_t length,
847872
goto out;
848873
}
849874

850-
status = ucp_memh_create_from_mem(context, &mem, &memh);
875+
status = ucp_memh_create_from_mem(context, &mem, uct_flags, &memh);
851876
if (status != UCS_OK) {
852877
goto err_dealloc;
853878
}
@@ -974,7 +999,7 @@ ucs_status_t ucp_mem_map(ucp_context_h context, const ucp_mem_map_params_t *para
974999
alloc_name, &memh);
9751000
} else {
9761001
status = ucp_memh_create(context, address, length, mem_type,
977-
UCT_ALLOC_METHOD_LAST, 0, &memh);
1002+
UCT_ALLOC_METHOD_LAST, 0, uct_flags, &memh);
9781003
if (status != UCS_OK) {
9791004
goto out;
9801005
}
@@ -1412,15 +1437,9 @@ ucp_mem_rcache_mem_reg_cb(void *ctx, ucs_rcache_t *rcache, void *arg,
14121437
ucp_context_h context = (ucp_context_h)ctx;
14131438
ucp_mem_rcache_reg_ctx_t *reg_ctx = arg;
14141439
ucp_mem_h memh = ucs_derived_of(rregion, ucp_mem_t);
1415-
ucp_memory_info_t info;
14161440

1417-
ucp_memory_detect(context, (void*)memh->super.super.start,
1418-
memh->super.super.end - memh->super.super.start, &info);
1419-
memh->context = context;
1420-
memh->alloc_md_index = UCP_NULL_RESOURCE;
1421-
memh->alloc_method = UCT_ALLOC_METHOD_LAST;
1422-
memh->mem_type = reg_ctx->mem_type;
1423-
memh->sys_dev = info.sys_dev;
1441+
ucp_memh_init(memh, context, 0, reg_ctx->uct_flags, UCT_ALLOC_METHOD_LAST,
1442+
reg_ctx->mem_type);
14241443

14251444
if (rcache_mem_reg_flags & UCS_RCACHE_MEM_REG_HIDE_ERRORS) {
14261445
/* Hide errors during registration but fail if any memory domain failed
@@ -1730,7 +1749,7 @@ ucp_memh_import(ucp_context_h context, const void *export_mkey_buffer,
17301749

17311750
status = ucp_memh_create(context, unpacked_memh.address,
17321751
unpacked_memh.length, unpacked_memh.mem_type,
1733-
UCT_ALLOC_METHOD_LAST, 0, &memh);
1752+
UCT_ALLOC_METHOD_LAST, 0, 0, &memh);
17341753
if (status != UCS_OK) {
17351754
goto out;
17361755
}

src/ucp/core/ucp_mm.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
#define UCP_RCACHE_LOOKUP_FUNC ucs_linear_func_make(50.0e-9, 0)
2424

2525

26+
/* Mask of UCT memory flags that must be present when reusing an
27+
existing region */
28+
#define UCP_MM_UCT_ACCESS_MASK UCT_MD_MEM_ACCESS_ALL
29+
30+
2631
/**
2732
* Memory handle flags.
2833
*/
@@ -51,6 +56,7 @@ enum {
5156
typedef struct ucp_mem {
5257
ucs_rcache_region_t super;
5358
uint8_t flags; /* Memory handle flags */
59+
unsigned uct_flags; /* UCT memory registration flags */
5460
ucp_context_h context; /* UCP context that owns a memory handle */
5561
uct_alloc_method_t alloc_method; /* Method used to allocate the memory */
5662
ucs_sys_device_t sys_dev; /* System device index */

src/ucp/core/ucp_mm.inl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ ucp_memh_get(ucp_context_h context, void *address, size_t length,
5757
}
5858

5959
memh = ucs_derived_of(rregion, ucp_mem_t);
60-
if (ucs_likely(ucs_test_all_flags(memh->md_map, reg_md_map))) {
60+
if (ucs_likely(ucs_test_all_flags(memh->md_map, reg_md_map)) &&
61+
ucs_likely(ucs_test_all_flags(
62+
memh->uct_flags,
63+
uct_flags & UCP_MM_UCT_ACCESS_MASK))) {
6164
ucp_memh_rcache_print(memh, address, length);
6265
*memh_p = memh;
6366
UCP_THREAD_CS_EXIT(&context->mt_lock);

test/gtest/ucp/test_ucp_mmap.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,80 @@ UCS_TEST_P(test_ucp_mmap, fixed) {
905905

906906
UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_mmap)
907907

908+
class test_ucp_mmap_atomic : public test_ucp_mmap {
909+
public:
910+
static void get_test_variants(std::vector<ucp_test_variant> &variants)
911+
{
912+
test_ucp_mmap::get_test_variants(variants,
913+
UCP_FEATURE_TAG | UCP_FEATURE_AMO64);
914+
}
915+
};
916+
917+
/* Use a buffer for send/recv, and then reuse it for atomic operations */
918+
UCS_TEST_P(test_ucp_mmap_atomic, reuse_buffer)
919+
{
920+
mem_buffer sbuf(UCS_MBYTE, UCS_MEMORY_TYPE_HOST, 1);
921+
mem_buffer rbuf(UCS_MBYTE, UCS_MEMORY_TYPE_HOST);
922+
923+
/* Send/receive from buffers to trigger adding them to registration cache */
924+
{
925+
static constexpr uint64_t TAG = 0xdeadbeef;
926+
ucp_request_param_t param;
927+
928+
param.op_attr_mask = 0;
929+
auto sreq = ucp_tag_send_nbx(sender().ep(), sbuf.ptr(), sbuf.size(),
930+
TAG, &param);
931+
auto rreq = ucp_tag_recv_nbx(receiver().worker(), rbuf.ptr(),
932+
rbuf.size(), TAG, 0, &param);
933+
934+
ASSERT_UCS_OK(requests_wait({sreq, rreq}));
935+
}
936+
937+
/* Map the receive buffer for atomic operations */
938+
ucp_mem_h memh;
939+
ucp_rkey_h rkey;
940+
{
941+
ucp_mem_map_params_t params;
942+
params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
943+
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
944+
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
945+
params.address = rbuf.ptr();
946+
params.length = rbuf.size();
947+
params.flags = mem_map_flags();
948+
949+
ASSERT_UCS_OK(ucp_mem_map(receiver().ucph(), &params, &memh));
950+
951+
void *rkey_buffer;
952+
size_t rkey_size;
953+
ASSERT_UCS_OK(ucp_rkey_pack(receiver().ucph(), memh, &rkey_buffer,
954+
&rkey_size));
955+
ASSERT_UCS_OK(ucp_ep_rkey_unpack(sender().ep(), rkey_buffer, &rkey));
956+
957+
ucp_rkey_buffer_release(rkey_buffer);
958+
}
959+
960+
/* Perform atomic operation */
961+
{
962+
uint64_t value = 1;
963+
ucp_request_param_t param;
964+
965+
param.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE;
966+
param.datatype = ucp_dt_make_contig(sizeof(value));
967+
auto sreq = ucp_atomic_op_nbx(sender().ep(), UCP_ATOMIC_OP_ADD, &value,
968+
1, (uintptr_t)rbuf.ptr(), rkey, &param);
969+
970+
param.op_attr_mask = 0;
971+
auto freq = ucp_ep_flush_nbx(sender().ep(), &param);
972+
973+
ASSERT_UCS_OK(requests_wait({sreq, freq}));
974+
}
975+
976+
/* Unmap the buffer */
977+
ucp_rkey_destroy(rkey);
978+
ASSERT_UCS_OK(ucp_mem_unmap(receiver().ucph(), memh));
979+
}
980+
981+
UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_mmap_atomic)
908982

909983
class test_ucp_rkey_compare : public test_ucp_mmap {
910984
public:

0 commit comments

Comments
 (0)