Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions integration_tests/test_set_len.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
def test_set():
s: set[i32]
s = {1, 2, 22, 2, -1, 1}
assert len(s) == 4
s2: set[str]
s2 = {'a', 'b', 'cd', 'b', 'abc', 'a'}
assert len(s2) == 4

test_set()
test_set()
38 changes: 32 additions & 6 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
std::unique_ptr<LLVMTuple> tuple_api;
std::unique_ptr<LLVMDictInterface> dict_api_lp;
std::unique_ptr<LLVMDictInterface> dict_api_sc;
std::unique_ptr<LLVMSetInterface> set_api; // linear probing
std::unique_ptr<LLVMSetInterface> set_api_lp;
std::unique_ptr<LLVMSetInterface> set_api_sc;
std::unique_ptr<LLVMArrUtils::Descriptor> arr_descr;

ASRToLLVMVisitor(Allocator &al, llvm::LLVMContext &context, std::string infile,
Expand All @@ -200,18 +201,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tuple_api(std::make_unique<LLVMTuple>(context, llvm_utils.get(), builder.get())),
dict_api_lp(std::make_unique<LLVMDictOptimizedLinearProbing>(context, llvm_utils.get(), builder.get())),
dict_api_sc(std::make_unique<LLVMDictSeparateChaining>(context, llvm_utils.get(), builder.get())),
set_api(std::make_unique<LLVMSetLinearProbing>(context, llvm_utils.get(), builder.get())),
set_api_lp(std::make_unique<LLVMSetLinearProbing>(context, llvm_utils.get(), builder.get())),
set_api_sc(std::make_unique<LLVMSetSeparateChaining>(context, llvm_utils.get(), builder.get())),
arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context,
builder.get(), llvm_utils.get(),
LLVMArrUtils::DESCR_TYPE::_SimpleCMODescriptor))
{
llvm_utils->tuple_api = tuple_api.get();
llvm_utils->list_api = list_api.get();
llvm_utils->dict_api = nullptr;
llvm_utils->set_api = set_api.get();
llvm_utils->set_api = nullptr;
llvm_utils->arr_api = arr_descr.get();
llvm_utils->dict_api_lp = dict_api_lp.get();
llvm_utils->dict_api_sc = dict_api_sc.get();
llvm_utils->set_api_lp = set_api_lp.get();
llvm_utils->set_api_sc = set_api_sc.get();
}

llvm::Value* CreateLoad(llvm::Value *x) {
Expand Down Expand Up @@ -1152,12 +1156,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Type* const_set_type = llvm_utils->get_set_type(x.m_type, module.get());
llvm::Value* const_set = builder->CreateAlloca(const_set_type, nullptr, "const_set");
ASR::Set_t* x_set = ASR::down_cast<ASR::Set_t>(x.m_type);
llvm_utils->set_set_api(x_set);
std::string el_type_code = ASRUtils::get_type_code(x_set->m_type);
llvm_utils->set_api->set_init(el_type_code, const_set, module.get(), x.n_elements);
int64_t ptr_loads_el = !LLVM::is_llvm_struct(x_set->m_type);
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = ptr_loads_el;
for( size_t i = 0; i < x.n_elements; i++ ) {
ptr_loads = ptr_loads_el;
visit_expr_wrapper(x.m_elements[i], true);
llvm::Value* element = tmp;
llvm_utils->set_api->write_item(const_set, element, module.get(),
Expand Down Expand Up @@ -1516,6 +1521,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr(*x.m_arg);
ptr_loads = ptr_loads_copy;
llvm::Value* pset = tmp;
ASR::Set_t* x_set = ASR::down_cast<ASR::Set_t>(ASRUtils::expr_type(x.m_arg));
llvm_utils->set_set_api(x_set);
tmp = llvm_utils->set_api->len(pset);
}

Expand Down Expand Up @@ -1724,6 +1731,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}

void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
ASR::Set_t* set_type = ASR::down_cast<ASR::Set_t>(
ASRUtils::expr_type(m_arg));
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
Expand All @@ -1734,10 +1743,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr_wrapper(m_ele, true);
ptr_loads = ptr_loads_copy;
llvm::Value *el = tmp;
set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx);
llvm_utils->set_set_api(set_type);
llvm_utils->set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx);
}

void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
ASR::Set_t* set_type = ASR::down_cast<ASR::Set_t>(
ASRUtils::expr_type(m_arg));
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
Expand All @@ -1748,7 +1760,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr_wrapper(m_ele, true);
ptr_loads = ptr_loads_copy;
llvm::Value *el = tmp;
set_api->remove_item(pset, el, *module, asr_el_type);
llvm_utils->set_set_api(set_type);
llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type);
}

void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
Expand Down Expand Up @@ -2773,6 +2786,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
bool is_dict_present_copy_sc = dict_api_sc->is_dict_present();
dict_api_lp->set_is_dict_present(false);
dict_api_sc->set_is_dict_present(false);
bool is_set_present_copy_lp = set_api_lp->is_set_present();
bool is_set_present_copy_sc = set_api_sc->is_set_present();
set_api_lp->set_is_set_present(false);
set_api_sc->set_is_set_present(false);
llvm_goto_targets.clear();
// Generate code for nested subroutines and functions first:
for (auto &item : x.m_symtab->get_scope()) {
Expand Down Expand Up @@ -2832,6 +2849,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
builder->CreateRet(ret_val2);
dict_api_lp->set_is_dict_present(is_dict_present_copy_lp);
dict_api_sc->set_is_dict_present(is_dict_present_copy_sc);
set_api_lp->set_is_set_present(is_set_present_copy_lp);
set_api_sc->set_is_set_present(is_set_present_copy_sc);

// Finalize the debug info.
if (compiler_options.emit_debug_info) DBuilder->finalize();
Expand Down Expand Up @@ -3323,6 +3342,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
bool is_dict_present_copy_sc = dict_api_sc->is_dict_present();
dict_api_lp->set_is_dict_present(false);
dict_api_sc->set_is_dict_present(false);
bool is_set_present_copy_lp = set_api_lp->is_set_present();
bool is_set_present_copy_sc = set_api_sc->is_set_present();
set_api_lp->set_is_set_present(false);
set_api_sc->set_is_set_present(false);
llvm_goto_targets.clear();
instantiate_function(x);
if (ASRUtils::get_FunctionType(x)->m_deftype == ASR::deftypeType::Interface) {
Expand All @@ -3335,6 +3358,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
parent_function = nullptr;
dict_api_lp->set_is_dict_present(is_dict_present_copy_lp);
dict_api_sc->set_is_dict_present(is_dict_present_copy_sc);
set_api_lp->set_is_set_present(is_set_present_copy_lp);
set_api_sc->set_is_set_present(is_set_present_copy_sc);

// Finalize the debug info.
if (compiler_options.emit_debug_info) DBuilder->finalize();
Expand Down Expand Up @@ -4187,6 +4212,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Value* target_set = tmp;
ptr_loads = ptr_loads_copy;
ASR::Set_t* value_set_type = ASR::down_cast<ASR::Set_t>(asr_value_type);
llvm_utils->set_set_api(value_set_type);
llvm_utils->set_api->set_deepcopy(value_set, target_set,
value_set_type, module.get(), name2memidx);
return ;
Expand Down
Loading