@@ -362,6 +362,7 @@ static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,
362
362
363
363
memh -> uct [md_index ] = NULL ;
364
364
}
365
+ memh -> md_map &= ~md_map ;
365
366
366
367
ucs_assert (comp .count == 1 );
367
368
}
@@ -473,7 +474,8 @@ ucp_memh_register_internal(ucp_context_h context, ucp_mem_h memh,
473
474
uct_flags |= UCT_MD_MEM_FLAG_NONBLOCK ;
474
475
}
475
476
476
- reg_params .flags = uct_flags ;
477
+ /* When adding registrations, existing access flags must be supported */
478
+ reg_params .flags = uct_flags | memh -> uct_flags ;
477
479
reg_params .dmabuf_fd = UCT_DMABUF_FD_INVALID ;
478
480
reg_params .dmabuf_offset = 0 ;
479
481
@@ -576,27 +578,38 @@ static size_t ucp_memh_size(ucp_context_h context)
576
578
return sizeof (ucp_mem_t ) + (sizeof (uct_mem_h ) * context -> num_mds );
577
579
}
578
580
579
- static void ucp_memh_set (ucp_mem_h memh , ucp_context_h context , void * address ,
580
- size_t length , ucs_memory_type_t mem_type ,
581
- uint8_t memh_flags , uct_alloc_method_t method )
581
+ static void ucp_memh_set_uct_flags (ucp_mem_h memh , unsigned uct_flags )
582
+ {
583
+ /* When changing memh->uct_flags, must not have any existing registrations,
584
+ since those may not support the new flags */
585
+ ucs_assertv (memh -> md_map == 0 ,
586
+ "memh=%p memh->md_map=0x%" PRIx64
587
+ " memh->uct_flags=0x%x uct_flags=0x%x" ,
588
+ memh , memh -> md_map , memh -> uct_flags , uct_flags );
589
+ memh -> uct_flags = uct_flags & UCP_MM_UCT_ACCESS_MASK ;
590
+ }
591
+
592
+ static void ucp_memh_init (ucp_mem_h memh , ucp_context_h context ,
593
+ uint8_t memh_flags , unsigned uct_flags ,
594
+ uct_alloc_method_t method , ucs_memory_type_t mem_type )
582
595
{
583
596
ucp_memory_info_t info ;
584
597
585
- ucp_memory_detect (context , address , length , & info );
586
- memh -> super .super .start = (uintptr_t )address ;
587
- memh -> super .super .end = (uintptr_t )address + length ;
588
- memh -> flags = memh_flags ;
598
+ ucp_memory_detect (context , ucp_memh_address (memh ), ucp_memh_length (memh ),
599
+ & info );
600
+ ucp_memh_set_uct_flags (memh , uct_flags );
589
601
memh -> context = context ;
602
+ memh -> flags = memh_flags ;
603
+ memh -> alloc_md_index = UCP_NULL_RESOURCE ;
604
+ memh -> alloc_method = method ;
590
605
memh -> mem_type = mem_type ;
591
606
memh -> sys_dev = info .sys_dev ;
592
- memh -> alloc_method = method ;
593
- memh -> alloc_md_index = UCP_NULL_RESOURCE ;
594
607
}
595
608
596
609
static ucs_status_t
597
610
ucp_memh_create (ucp_context_h context , void * address , size_t length ,
598
611
ucs_memory_type_t mem_type , uct_alloc_method_t method ,
599
- uint8_t memh_flags , ucp_mem_h * memh_p )
612
+ uint8_t memh_flags , unsigned uct_flags , ucp_mem_h * memh_p )
600
613
{
601
614
ucp_mem_h memh ;
602
615
@@ -605,7 +618,9 @@ ucp_memh_create(ucp_context_h context, void *address, size_t length,
605
618
return UCS_ERR_NO_MEMORY ;
606
619
}
607
620
608
- ucp_memh_set (memh , context , address , length , mem_type , memh_flags , method );
621
+ memh -> super .super .start = (uintptr_t )address ;
622
+ memh -> super .super .end = (uintptr_t )address + length ;
623
+ ucp_memh_init (memh , context , memh_flags , uct_flags , method , mem_type );
609
624
610
625
* memh_p = memh ;
611
626
return UCS_OK ;
@@ -658,13 +673,14 @@ static ucp_md_index_t ucp_mem_get_md_index(ucp_context_h context,
658
673
659
674
static ucs_status_t ucp_memh_create_from_mem (ucp_context_h context ,
660
675
const uct_allocated_memory_t * mem ,
676
+ unsigned uct_flags ,
661
677
ucp_mem_h * memh_p )
662
678
{
663
679
ucs_status_t status ;
664
680
ucp_mem_h memh ;
665
681
666
682
status = ucp_memh_create (context , mem -> address , mem -> length , mem -> mem_type ,
667
- mem -> method , 0 , & memh );
683
+ mem -> method , 0 , uct_flags , & memh );
668
684
if (status != UCS_OK ) {
669
685
return status ;
670
686
}
@@ -787,21 +803,30 @@ ucs_status_t ucp_memh_get_slow(ucp_context_h context, void *address,
787
803
UCP_THREAD_CS_ENTER (& context -> mt_lock );
788
804
if (context -> rcache == NULL ) {
789
805
status = ucp_memh_create (context , reg_address , reg_length , mem_type ,
790
- UCT_ALLOC_METHOD_LAST , 0 , & memh );
806
+ UCT_ALLOC_METHOD_LAST , 0 , uct_flags , & memh );
807
+ if (status != UCS_OK ) {
808
+ goto out ;
809
+ }
791
810
} else {
792
811
status = ucp_memh_rcache_get (context -> rcache , reg_address , reg_length ,
793
812
reg_align , mem_type , reg_md_map , uct_flags ,
794
813
alloc_name , & memh );
814
+ if (status != UCS_OK ) {
815
+ goto out ;
816
+ }
817
+
818
+ if (!ucs_test_all_flags (memh -> uct_flags ,
819
+ uct_flags & UCP_MM_UCT_ACCESS_MASK )) {
820
+ reg_md_map |= memh -> md_map ; /* Re-register previous MDs */
821
+ ucp_memh_dereg (context , memh , memh -> md_map );
822
+ ucp_memh_set_uct_flags (memh , uct_flags );
823
+ }
795
824
796
825
ucs_assert (memh -> mem_type == mem_type );
797
826
ucs_assert (ucs_padding ((intptr_t )ucp_memh_address (memh ), reg_align ) == 0 );
798
827
ucs_assert (ucs_padding (ucp_memh_length (memh ), reg_align ) == 0 );
799
828
}
800
829
801
- if (status != UCS_OK ) {
802
- goto out ;
803
- }
804
-
805
830
ucs_trace (
806
831
"memh_get_slow: %s address %p/%p length %zu/%zu %s md_map %" PRIx64
807
832
" flags 0x%x" ,
@@ -847,7 +872,7 @@ ucp_memh_alloc(ucp_context_h context, void *address, size_t length,
847
872
goto out ;
848
873
}
849
874
850
- status = ucp_memh_create_from_mem (context , & mem , & memh );
875
+ status = ucp_memh_create_from_mem (context , & mem , uct_flags , & memh );
851
876
if (status != UCS_OK ) {
852
877
goto err_dealloc ;
853
878
}
@@ -974,7 +999,7 @@ ucs_status_t ucp_mem_map(ucp_context_h context, const ucp_mem_map_params_t *para
974
999
alloc_name , & memh );
975
1000
} else {
976
1001
status = ucp_memh_create (context , address , length , mem_type ,
977
- UCT_ALLOC_METHOD_LAST , 0 , & memh );
1002
+ UCT_ALLOC_METHOD_LAST , 0 , uct_flags , & memh );
978
1003
if (status != UCS_OK ) {
979
1004
goto out ;
980
1005
}
@@ -1412,15 +1437,9 @@ ucp_mem_rcache_mem_reg_cb(void *ctx, ucs_rcache_t *rcache, void *arg,
1412
1437
ucp_context_h context = (ucp_context_h )ctx ;
1413
1438
ucp_mem_rcache_reg_ctx_t * reg_ctx = arg ;
1414
1439
ucp_mem_h memh = ucs_derived_of (rregion , ucp_mem_t );
1415
- ucp_memory_info_t info ;
1416
1440
1417
- ucp_memory_detect (context , (void * )memh -> super .super .start ,
1418
- memh -> super .super .end - memh -> super .super .start , & info );
1419
- memh -> context = context ;
1420
- memh -> alloc_md_index = UCP_NULL_RESOURCE ;
1421
- memh -> alloc_method = UCT_ALLOC_METHOD_LAST ;
1422
- memh -> mem_type = reg_ctx -> mem_type ;
1423
- memh -> sys_dev = info .sys_dev ;
1441
+ ucp_memh_init (memh , context , 0 , reg_ctx -> uct_flags , UCT_ALLOC_METHOD_LAST ,
1442
+ reg_ctx -> mem_type );
1424
1443
1425
1444
if (rcache_mem_reg_flags & UCS_RCACHE_MEM_REG_HIDE_ERRORS ) {
1426
1445
/* Hide errors during registration but fail if any memory domain failed
@@ -1730,7 +1749,7 @@ ucp_memh_import(ucp_context_h context, const void *export_mkey_buffer,
1730
1749
1731
1750
status = ucp_memh_create (context , unpacked_memh .address ,
1732
1751
unpacked_memh .length , unpacked_memh .mem_type ,
1733
- UCT_ALLOC_METHOD_LAST , 0 , & memh );
1752
+ UCT_ALLOC_METHOD_LAST , 0 , 0 , & memh );
1734
1753
if (status != UCS_OK ) {
1735
1754
goto out ;
1736
1755
}
0 commit comments