Skip to content

Commit 3da0e4f

Browse files
committed
add recursive function for nested dict
1 parent 1af3d80 commit 3da0e4f

File tree

2 files changed

+129
-64
lines changed

2 files changed

+129
-64
lines changed

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4096,6 +4096,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
40964096
}
40974097

40984098
void visit_Assignment(const ASR::Assignment_t &x) {
4099+
// std::cout << "in visit_Assignment" << std::endl;
40994100
if (compiler_options.emit_debug_info) debug_emit_loc(x);
41004101
if( x.m_overloaded ) {
41014102
this->visit_stmt(*x.m_overloaded);
@@ -4475,6 +4476,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
44754476
target_type, false, false);
44764477
}
44774478
} else if( ASR::is_a<ASR::DictItem_t>(*x.m_target) ) {
4479+
// std::cout << "is DictItem_t, with write" << std::endl;
44784480
ASR::DictItem_t* dict_item_t = ASR::down_cast<ASR::DictItem_t>(x.m_target);
44794481
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
44804482
ASRUtils::expr_type(dict_item_t->m_a));

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 127 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3749,6 +3749,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
37493749
bool visit_SubscriptIndices(AST::expr_t* m_slice, Vec<ASR::array_index_t>& args,
37503750
ASR::expr_t* value, ASR::ttype_t* type, bool& is_item,
37513751
const Location& loc) {
3752+
// std::cout << "in visit_SubscriptIndices" << std::endl;
37523753
ASR::array_index_t ai;
37533754
ai.loc = loc;
37543755
ai.m_left = nullptr;
@@ -3837,6 +3838,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
38373838
ASRUtils::type_to_str_python(ASRUtils::expr_type(index)) + "'",
38383839
index->base.loc);
38393840
}
3841+
// std::cout << "make_DictItem_t" << std::endl;
38403842
tmp = make_DictItem_t(al, loc, value, index, nullptr,
38413843
ASR::down_cast<ASR::Dict_t>(type)->m_value_type, nullptr);
38423844
return false;
@@ -5027,7 +5029,105 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
50275029
}
50285030
}
50295031

5032+
bool visit_SubscriptUtil(const AST::Subscript_t &x, const AST::Assign_t &assign_node,
5033+
ASR::expr_t *tmp_value, int32_t recursion_level) {
5034+
// std::cout << "in visit_SubscriptUtil" << std::endl;
5035+
if (AST::is_a<AST::Name_t>(*x.m_value)) {
5036+
// std::cout << "is a Name" << std::endl;
5037+
std::string name = AST::down_cast<AST::Name_t>(x.m_value)->m_id;
5038+
ASR::symbol_t *s = current_scope->resolve_symbol(name);
5039+
if (!s) {
5040+
throw SemanticError("Variable: '" + name + "' is not declared",
5041+
x.base.base.loc);
5042+
}
5043+
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(s);
5044+
ASR::ttype_t *type = v->m_type;
5045+
if (ASR::is_a<ASR::Dict_t>(*type)) {
5046+
this->visit_expr(*x.m_slice);
5047+
ASR::expr_t *key = ASRUtils::EXPR(tmp);
5048+
ASR::expr_t* se = ASR::down_cast<ASR::expr_t>(
5049+
ASR::make_Var_t(al, x.base.base.loc, s));
5050+
if( recursion_level == 0 ) {
5051+
// dict insert case;
5052+
ASR::ttype_t *key_type = ASR::down_cast<ASR::Dict_t>(type)->m_key_type;
5053+
ASR::ttype_t *value_type = ASR::down_cast<ASR::Dict_t>(type)->m_value_type;
5054+
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(key), key_type)) {
5055+
std::string ktype = ASRUtils::type_to_str_python(ASRUtils::expr_type(key));
5056+
std::string totype = ASRUtils::type_to_str_python(key_type);
5057+
diag.add(diag::Diagnostic(
5058+
"Type mismatch in dictionary key, the types must be compatible",
5059+
diag::Level::Error, diag::Stage::Semantic, {
5060+
diag::Label("type mismatch (found: '" + ktype + "', expected: '" + totype + "')",
5061+
{key->base.loc})
5062+
})
5063+
);
5064+
throw SemanticAbort();
5065+
}
5066+
// // std::cout << "*x.m_value type code " << ASRUtils::get_type_code(ASRUtils::expr_type(x.m_value)) << std::endl;
5067+
if (tmp_value == nullptr) {
5068+
if (AST::is_a<AST::List_t>(*assign_node.m_value)) {
5069+
LCOMPILERS_ASSERT(AST::down_cast<AST::List_t>(assign_node.m_value)->n_elts == 0);
5070+
Vec<ASR::expr_t*> list_ele;
5071+
list_ele.reserve(al, 1);
5072+
tmp_value = ASRUtils::EXPR(ASR::make_ListConstant_t(al, assign_node.base.base.loc,
5073+
list_ele.p, list_ele.size(), value_type));
5074+
} else if (AST::is_a<AST::Dict_t>(*assign_node.m_value)) {
5075+
LCOMPILERS_ASSERT(AST::down_cast<AST::Dict_t>(assign_node.m_value)->n_keys == 0);
5076+
Vec<ASR::expr_t*> dict_ele;
5077+
dict_ele.reserve(al, 1);
5078+
tmp_value = ASRUtils::EXPR(ASR::make_DictConstant_t(al, assign_node.base.base.loc,
5079+
dict_ele.p, dict_ele.size(), dict_ele.p, dict_ele.size(), value_type));
5080+
}
5081+
// std::cout << "is value dict?; tmp_value is still nullptr " << AST::is_a<AST::Dict_t>(*assign_node.m_value) << " " << (tmp_value == nullptr) << std::endl;
5082+
}
5083+
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(tmp_value), value_type)) {
5084+
std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(tmp_value));
5085+
std::string totype = ASRUtils::type_to_str_python(value_type);
5086+
diag.add(diag::Diagnostic(
5087+
"Type mismatch in dictionary value, the types must be compatible",
5088+
diag::Level::Error, diag::Stage::Semantic, {
5089+
diag::Label("type mismatch (found: '" + vtype + "', expected: '" + totype + "')",
5090+
{tmp_value->base.loc})
5091+
})
5092+
);
5093+
throw SemanticAbort();
5094+
}
5095+
tmp = nullptr;
5096+
// std::cout << "make_DictInsert_t" << std::endl;
5097+
tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, se, key, tmp_value));
5098+
}
5099+
else {
5100+
tmp = make_DictItem_t(al, x.base.base.loc, se, key, nullptr,
5101+
ASR::down_cast<ASR::Dict_t>(type)->m_value_type, nullptr);
5102+
}
5103+
return true;
5104+
} else if (ASRUtils::is_immutable(type)) {
5105+
throw SemanticError("'" + ASRUtils::type_to_str_python(type) + "' object does not support"
5106+
" item assignment", x.base.base.loc);
5107+
}
5108+
} else if( AST::is_a<AST::Subscript_t>(*x.m_value) ) {
5109+
// std::cout << "is a Subscript" << std::endl;
5110+
AST::Subscript_t *sb = AST::down_cast<AST::Subscript_t>(x.m_value);
5111+
bool return_val = visit_SubscriptUtil(*sb, assign_node, tmp_value, recursion_level + 1);
5112+
if( return_val && tmp ) {
5113+
ASR::expr_t *dict = ASRUtils::EXPR(tmp);
5114+
this->visit_expr(*x.m_slice);
5115+
ASR::expr_t *key = ASRUtils::EXPR(tmp);
5116+
if( recursion_level == 0 ) {
5117+
tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, dict, key, tmp_value));
5118+
}
5119+
else {
5120+
tmp = make_DictItem_t(al, x.base.base.loc, dict, key, nullptr,
5121+
ASR::down_cast<ASR::Dict_t>(ASRUtils::expr_type(dict))->m_value_type, nullptr);
5122+
}
5123+
}
5124+
return return_val;
5125+
}
5126+
return false;
5127+
}
5128+
50305129
void visit_Assign(const AST::Assign_t &x) {
5130+
// std::cout << "in visit_Assign" << std::endl;
50315131
ASR::expr_t *target, *assign_value = nullptr, *tmp_value;
50325132
bool is_c_p_pointer_call_copy = is_c_p_pointer_call;
50335133
is_c_p_pointer_call = false;
@@ -5057,65 +5157,13 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
50575157
assign_value = ASRUtils::EXPR(tmp);
50585158
}
50595159
for (size_t i=0; i<x.n_targets; i++) {
5160+
// std::cout << "visit_Assign_i=" << i << std::endl;
50605161
tmp_value = assign_value;
50615162
check_is_assign_to_input_param(x.m_targets[i]);
50625163
if (AST::is_a<AST::Subscript_t>(*x.m_targets[i])) {
50635164
AST::Subscript_t *sb = AST::down_cast<AST::Subscript_t>(x.m_targets[i]);
5064-
if (AST::is_a<AST::Name_t>(*sb->m_value)) {
5065-
std::string name = AST::down_cast<AST::Name_t>(sb->m_value)->m_id;
5066-
ASR::symbol_t *s = current_scope->resolve_symbol(name);
5067-
if (!s) {
5068-
throw SemanticError("Variable: '" + name + "' is not declared",
5069-
x.base.base.loc);
5070-
}
5071-
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(s);
5072-
ASR::ttype_t *type = v->m_type;
5073-
if (ASR::is_a<ASR::Dict_t>(*type)) {
5074-
// dict insert case;
5075-
this->visit_expr(*sb->m_slice);
5076-
ASR::expr_t *key = ASRUtils::EXPR(tmp);
5077-
ASR::ttype_t *key_type = ASR::down_cast<ASR::Dict_t>(type)->m_key_type;
5078-
ASR::ttype_t *value_type = ASR::down_cast<ASR::Dict_t>(type)->m_value_type;
5079-
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(key), key_type)) {
5080-
std::string ktype = ASRUtils::type_to_str_python(ASRUtils::expr_type(key));
5081-
std::string totype = ASRUtils::type_to_str_python(key_type);
5082-
diag.add(diag::Diagnostic(
5083-
"Type mismatch in dictionary key, the types must be compatible",
5084-
diag::Level::Error, diag::Stage::Semantic, {
5085-
diag::Label("type mismatch (found: '" + ktype + "', expected: '" + totype + "')",
5086-
{key->base.loc})
5087-
})
5088-
);
5089-
throw SemanticAbort();
5090-
}
5091-
if (tmp_value == nullptr && AST::is_a<AST::List_t>(*x.m_value)) {
5092-
LCOMPILERS_ASSERT(AST::down_cast<AST::List_t>(x.m_value)->n_elts == 0);
5093-
Vec<ASR::expr_t*> list_ele;
5094-
list_ele.reserve(al, 1);
5095-
tmp_value = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, list_ele.p,
5096-
list_ele.size(), value_type));
5097-
}
5098-
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(tmp_value), value_type)) {
5099-
std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(tmp_value));
5100-
std::string totype = ASRUtils::type_to_str_python(value_type);
5101-
diag.add(diag::Diagnostic(
5102-
"Type mismatch in dictionary value, the types must be compatible",
5103-
diag::Level::Error, diag::Stage::Semantic, {
5104-
diag::Label("type mismatch (found: '" + vtype + "', expected: '" + totype + "')",
5105-
{tmp_value->base.loc})
5106-
})
5107-
);
5108-
throw SemanticAbort();
5109-
}
5110-
ASR::expr_t* se = ASR::down_cast<ASR::expr_t>(
5111-
ASR::make_Var_t(al, x.base.base.loc, s));
5112-
tmp = nullptr;
5113-
tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, se, key, tmp_value));
5114-
continue;
5115-
} else if (ASRUtils::is_immutable(type)) {
5116-
throw SemanticError("'" + ASRUtils::type_to_str_python(type) + "' object does not support"
5117-
" item assignment", x.base.base.loc);
5118-
}
5165+
if( visit_SubscriptUtil(*sb, x, tmp_value, 0) ) {
5166+
continue;
51195167
}
51205168
} else if (AST::is_a<AST::Attribute_t>(*x.m_targets[i])) {
51215169
AST::Attribute_t *attr = AST::down_cast<AST::Attribute_t>(x.m_targets[i]);
@@ -5133,15 +5181,24 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
51335181
}
51345182
}
51355183
}
5184+
// std::cout << "visiting target from here..." << std::endl;
51365185
this->visit_expr(*x.m_targets[i]);
51375186
target = ASRUtils::EXPR(tmp);
51385187
ASR::ttype_t *target_type = ASRUtils::expr_type(target);
5139-
if (tmp_value == nullptr && AST::is_a<AST::List_t>(*x.m_value)) {
5140-
LCOMPILERS_ASSERT(AST::down_cast<AST::List_t>(x.m_value)->n_elts == 0);
5141-
Vec<ASR::expr_t*> list_ele;
5142-
list_ele.reserve(al, 1);
5143-
tmp_value = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, list_ele.p,
5144-
list_ele.size(), target_type));
5188+
if (tmp_value == nullptr) {
5189+
if (AST::is_a<AST::List_t>(*x.m_value)) {
5190+
LCOMPILERS_ASSERT(AST::down_cast<AST::List_t>(x.m_value)->n_elts == 0);
5191+
Vec<ASR::expr_t*> list_ele;
5192+
list_ele.reserve(al, 1);
5193+
tmp_value = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, list_ele.p,
5194+
list_ele.size(), target_type));
5195+
} else if (AST::is_a<AST::Dict_t>(*x.m_value)) {
5196+
LCOMPILERS_ASSERT(AST::down_cast<AST::Dict_t>(x.m_value)->n_keys == 0);
5197+
Vec<ASR::expr_t*> dict_ele;
5198+
dict_ele.reserve(al, 1);
5199+
tmp_value = ASRUtils::EXPR(ASR::make_DictConstant_t(al, x.base.base.loc, dict_ele.p,
5200+
dict_ele.size(), dict_ele.p, dict_ele.size(), target_type));
5201+
}
51455202
}
51465203
if (tmp_value == nullptr && ASR::is_a<ASR::Var_t>(*target)) {
51475204
ASR::Var_t *var_tar = ASR::down_cast<ASR::Var_t>(target);
@@ -6023,9 +6080,15 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
60236080

60246081
void visit_Dict(const AST::Dict_t &x) {
60256082
LCOMPILERS_ASSERT(x.n_keys == x.n_values);
6026-
if( x.n_keys == 0 && ann_assign_target_type != nullptr ) {
6027-
tmp = ASR::make_DictConstant_t(al, x.base.base.loc, nullptr, 0,
6028-
nullptr, 0, ann_assign_target_type);
6083+
if( x.n_keys == 0 ) {
6084+
if( ann_assign_target_type != nullptr ) {
6085+
tmp = ASR::make_DictConstant_t(al, x.base.base.loc, nullptr, 0,
6086+
nullptr, 0, ann_assign_target_type);
6087+
}
6088+
else {
6089+
// std::cout << "here, ann_assign_target_type is nullptr" << std::endl;
6090+
tmp = nullptr;
6091+
}
60296092
return ;
60306093
}
60316094
Vec<ASR::expr_t*> keys;

0 commit comments

Comments
 (0)