Skip to content

Commit 24e4c25

Browse files
committed
add for SC
1 parent 30b2cf6 commit 24e4c25

File tree

4 files changed

+135
-22
lines changed

4 files changed

+135
-22
lines changed

integration_tests/test_dict_keys_values.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def test_dict_keys_values():
88
v1_copy: list[i32] = []
99
i: i32
1010
j: i32
11+
s: str
1112
key_count: i32
1213

1314
for i in range(105, 115):
@@ -27,4 +28,27 @@ def test_dict_keys_values():
2728
assert v1_copy[j] == d1[i]
2829
assert key_count == 1
2930

31+
d2: dict[str, str] = {}
32+
k2: list[str]
33+
k2_copy: list[str] = []
34+
v2: list[str]
35+
v2_copy: list[str] = []
36+
37+
for i in range(105, 115):
38+
d2[str(i)] = str(i + 1)
39+
k2 = d2.keys()
40+
for s in k2:
41+
k2_copy.append(s)
42+
v2 = d2.values()
43+
for s in v2:
44+
v2_copy.append(s)
45+
assert len(k2) == 10
46+
for i in range(105, 115):
47+
key_count = 0
48+
for j in range(len(k2)):
49+
if k2_copy[j] == str(i):
50+
key_count += 1
51+
assert v2_copy[j] == d2[str(i)]
52+
assert key_count == 1
53+
3054
test_dict_keys_values()

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,7 +1684,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16841684
tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx);
16851685
}
16861686

1687-
void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value, const Location &loc) {
1687+
void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value) {
16881688
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
16891689
ASRUtils::expr_type(m_arg));
16901690
ASR::ttype_t* el_type = key_or_value == 0 ?
@@ -1695,11 +1695,6 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16951695
this->visit_expr(*m_arg);
16961696
llvm::Value* pdict = tmp;
16971697

1698-
llvm_utils->set_dict_api(dict_type);
1699-
if(llvm_utils->dict_api == dict_api_sc.get()) {
1700-
throw CodeGenError("dict.keys and dict.values are only implemented "
1701-
"for linear probing for now", loc);
1702-
}
17031698
ptr_loads = ptr_loads_copy;
17041699

17051700
bool is_array_type_local = false, is_malloc_array_type_local = false;
@@ -1725,7 +1720,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17251720
"keys_list" : "values_list");
17261721
list_api->list_init(type_code, el_list, *module, 0, 0);
17271722

1728-
llvm_utils->dict_api->get_elements_list(pdict, el_list, el_type, *module,
1723+
llvm_utils->set_dict_api(dict_type);
1724+
llvm_utils->dict_api->get_elements_list(pdict, el_list, dict_type->m_key_type,
1725+
dict_type->m_value_type, *module,
17291726
name2memidx, key_or_value);
17301727
tmp = el_list;
17311728
}
@@ -1802,11 +1799,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
18021799
break;
18031800
}
18041801
case ASRUtils::IntrinsicFunctions::DictKeys: {
1805-
generate_DictElems(x.m_args[0], 0, x.base.base.loc);
1802+
generate_DictElems(x.m_args[0], 0);
18061803
break;
18071804
}
18081805
case ASRUtils::IntrinsicFunctions::DictValues: {
1809-
generate_DictElems(x.m_args[0], 1, x.base.base.loc);
1806+
generate_DictElems(x.m_args[0], 1);
18101807
break;
18111808
}
18121809
case ASRUtils::IntrinsicFunctions::SetAdd: {

src/libasr/codegen/llvm_utils.cpp

Lines changed: 99 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3872,34 +3872,38 @@ namespace LCompilers {
38723872
}
38733873

38743874
void LLVMDict::get_elements_list(llvm::Value* dict,
3875-
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
3875+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
3876+
ASR::ttype_t* value_asr_type, llvm::Module& module,
38763877
std::map<std::string, std::map<std::string, int>>& name2memidx,
38773878
bool key_or_value) {
38783879

38793880
/**
38803881
* C++ equivalent:
3881-
*
3882+
*
3883+
* // key_or_value = 0 for keys, 1 for values
3884+
*
38823885
* idx = 0;
3883-
*
3886+
*
38843887
* while( capacity > idx ) {
38853888
* el = key_or_value_list[idx];
38863889
* key_mask_value = key_mask[idx];
3887-
*
3890+
*
38883891
* is_key_skip = key_mask_value == 3; // tombstone
38893892
* is_key_set = key_mask_value != 0;
38903893
* add_el = is_key_set && !is_key_skip;
38913894
* if( add_el ) {
38923895
* elements_list.append(el);
38933896
* }
3894-
*
3897+
*
38953898
* idx++;
38963899
* }
3897-
*
3900+
*
38983901
*/
38993902

39003903
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
39013904
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
39023905
llvm::Value* el_list = key_or_value == 0 ? get_key_list(dict) : get_value_list(dict);
3906+
ASR::ttype_t* el_asr_type = key_or_value == 0 ? key_asr_type : value_asr_type;
39033907
if( !are_iterators_set ) {
39043908
idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
39053909
}
@@ -3949,10 +3953,95 @@ namespace LCompilers {
39493953
llvm_utils->start_new_block(loopend);
39503954
}
39513955

3952-
void LLVMDictSeparateChaining::get_elements_list(llvm::Value* /*dict*/,
3953-
llvm::Value* /*elements_list*/, ASR::ttype_t* /*el_asr_type*/, llvm::Module& /*module*/,
3954-
std::map<std::string, std::map<std::string, int>>& /*name2memidx*/,
3955-
bool /*key_or_value*/) {}
3956+
void LLVMDictSeparateChaining::get_elements_list(llvm::Value* dict,
3957+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
3958+
ASR::ttype_t* value_asr_type, llvm::Module& module,
3959+
std::map<std::string, std::map<std::string, int>>& name2memidx,
3960+
bool key_or_value) {
3961+
if( !are_iterators_set ) {
3962+
idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
3963+
chain_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr);
3964+
}
3965+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
3966+
llvm::APInt(32, 0)), idx_ptr);
3967+
3968+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
3969+
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
3970+
llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict));
3971+
llvm::Type* kv_pair_type = get_key_value_pair_type(key_asr_type, value_asr_type);
3972+
ASR::ttype_t* el_asr_type = key_or_value == 0 ? key_asr_type : value_asr_type;
3973+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
3974+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
3975+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
3976+
3977+
// head
3978+
llvm_utils->start_new_block(loophead);
3979+
{
3980+
llvm::Value *cond = builder->CreateICmpSGT(
3981+
capacity,
3982+
LLVM::CreateLoad(*builder, idx_ptr));
3983+
builder->CreateCondBr(cond, loopbody, loopend);
3984+
}
3985+
3986+
// body
3987+
llvm_utils->start_new_block(loopbody);
3988+
{
3989+
llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr);
3990+
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
3991+
llvm_utils->create_ptr_gep(key_mask, idx));
3992+
llvm::Value* is_key_set = builder->CreateICmpEQ(key_mask_value,
3993+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
3994+
3995+
llvm_utils->create_if_else(is_key_set, [&]() {
3996+
llvm::Value* dict_i = llvm_utils->create_ptr_gep(key_value_pairs, idx);
3997+
llvm::Value* kv_ll_i8 = builder->CreateBitCast(dict_i, llvm::Type::getInt8PtrTy(context));
3998+
LLVM::CreateStore(*builder, kv_ll_i8, chain_itr);
3999+
4000+
llvm::BasicBlock *loop2head = llvm::BasicBlock::Create(context, "loop2.head");
4001+
llvm::BasicBlock *loop2body = llvm::BasicBlock::Create(context, "loop2.body");
4002+
llvm::BasicBlock *loop2end = llvm::BasicBlock::Create(context, "loop2.end");
4003+
4004+
// head
4005+
llvm_utils->start_new_block(loop2head);
4006+
{
4007+
llvm::Value *cond = builder->CreateICmpNE(
4008+
LLVM::CreateLoad(*builder, chain_itr),
4009+
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))
4010+
);
4011+
builder->CreateCondBr(cond, loop2body, loop2end);
4012+
}
4013+
4014+
// body
4015+
llvm_utils->start_new_block(loop2body);
4016+
{
4017+
llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
4018+
llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_pair_type->getPointerTo());
4019+
llvm::Value* kv_el = llvm_utils->create_gep(kv_struct, key_or_value);
4020+
if( !LLVM::is_llvm_struct(el_asr_type) ) {
4021+
kv_el = LLVM::CreateLoad(*builder, kv_el);
4022+
}
4023+
llvm_utils->list_api->append(elements_list, kv_el,
4024+
el_asr_type, &module, name2memidx);
4025+
llvm::Value* next_kv_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 2));
4026+
LLVM::CreateStore(*builder, next_kv_struct, chain_itr);
4027+
}
4028+
4029+
builder->CreateBr(loop2head);
4030+
4031+
// end
4032+
llvm_utils->start_new_block(loop2end);
4033+
}, [=]() {
4034+
});
4035+
llvm::Value* tmp = builder->CreateAdd(idx,
4036+
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
4037+
LLVM::CreateStore(*builder, tmp, idx_ptr);
4038+
}
4039+
4040+
builder->CreateBr(loophead);
4041+
4042+
// end
4043+
llvm_utils->start_new_block(loopend);
4044+
}
39564045

39574046
llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
39584047
bool enable_bounds_checking,

src/libasr/codegen/llvm_utils.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,8 @@ namespace LCompilers {
623623

624624
virtual
625625
void get_elements_list(llvm::Value* dict,
626-
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
626+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
627+
ASR::ttype_t* value_asr_type, llvm::Module& module,
627628
std::map<std::string, std::map<std::string, int>>& name2memidx,
628629
bool key_or_value) = 0;
629630

@@ -720,7 +721,8 @@ namespace LCompilers {
720721
llvm::Value* len(llvm::Value* dict);
721722

722723
void get_elements_list(llvm::Value* dict,
723-
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
724+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
725+
ASR::ttype_t* value_asr_type, llvm::Module& module,
724726
std::map<std::string, std::map<std::string, int>>& name2memidx,
725727
bool key_or_value);
726728

@@ -872,7 +874,8 @@ namespace LCompilers {
872874
llvm::Value* len(llvm::Value* dict);
873875

874876
void get_elements_list(llvm::Value* dict,
875-
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
877+
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
878+
ASR::ttype_t* value_asr_type, llvm::Module& module,
876879
std::map<std::string, std::map<std::string, int>>& name2memidx,
877880
bool key_or_value);
878881

0 commit comments

Comments
 (0)