From 9f88a3f1d9535154a5f6479f3ab8cf46cdebef2a Mon Sep 17 00:00:00 2001
From: Li Jin <dragon-fly@qq.com>
Date: Tue, 28 Jan 2020 03:06:59 +0800
Subject: fix some special cases for backcall.

---
 src/MoonP/moon_compiler.cpp | 157 ++++++++++++++++++++++++++------------------
 1 file changed, 93 insertions(+), 64 deletions(-)

(limited to 'src')

diff --git a/src/MoonP/moon_compiler.cpp b/src/MoonP/moon_compiler.cpp
index 0028f22..58b3bb0 100644
--- a/src/MoonP/moon_compiler.cpp
+++ b/src/MoonP/moon_compiler.cpp
@@ -337,7 +337,7 @@ private:
 		return _converter.to_bytes(std::wstring(begin, end));
 	}
 
-	Value_t* singleValueFrom(ast_node* item) {
+	Value_t* singleValueFrom(ast_node* item) const {
 		Exp_t* exp = nullptr;
 		switch (item->getId()) {
 			case "Exp"_id:
@@ -491,7 +491,7 @@ private:
 		return Empty;
 	}
 
-	bool isAssignable(const node_container& chainItems) {
+	bool isAssignable(const node_container& chainItems) const {
 		if (chainItems.size() == 1) {
 			 auto firstItem = chainItems.back();
 			 if (auto callable = ast_cast<Callable_t>(firstItem)) {
@@ -514,7 +514,7 @@ private:
 		return false;
 	}
 
-	bool isAssignable(Exp_t* exp) {
+	bool isAssignable(Exp_t* exp) const {
 		if (auto value = singleValueFrom(exp)) {
 			auto item = value->item.get();
 			switch (item->getId()) {
@@ -522,7 +522,7 @@ private:
 					return true;
 				case "SimpleValue"_id: {
 					auto simpleValue = static_cast<SimpleValue_t*>(item);
-					if (simpleValue->value.is<TableLit_t>()) {
+					if (simpleValue->value. is<TableLit_t>()) {
 						return true;
 					}
 					return false;
@@ -536,14 +536,14 @@ private:
 		return false;
 	}
 
-	bool isAssignable(Assignable_t* assignable) {
+	bool isAssignable(Assignable_t* assignable) const {
 		if (auto assignableChain = ast_cast<AssignableChain_t>(assignable->item)) {
 			return isAssignable(assignableChain->items.objects());
 		}
 		return true;
 	}
 
-	void checkAssignable(ExpList_t* expList) {
+	void checkAssignable(ExpList_t* expList) const {
 		for (auto exp_ : expList->exprs.objects()) {
 			Exp_t* exp = static_cast<Exp_t*>(exp_);
 			if (!isAssignable(exp)) {
@@ -552,7 +552,22 @@ private:
 		}
 	}
 
-	std::string debugInfo(std::string_view msg, const input_range* loc) {
+	bool isPureBackcall(Exp_t* exp) const {
+		if (exp->opValues.empty()) {
+			return false;
+		}
+		bool backcall = true;
+		for (auto _opValue : exp->opValues.objects()) {
+			auto opValue = static_cast<exp_op_value_t*>(_opValue);
+			if (!opValue->op.is<BackcallOperator_t>()) {
+				backcall = false;
+				break;
+			}
+		}
+		return backcall;
+	}
+
+	std::string debugInfo(std::string_view msg, const input_range* loc) const {
 		const int ASCII = 255;
 		int length = loc->m_begin.m_line;
 		auto begin = _input.begin();
@@ -569,7 +584,7 @@ private:
 				count++;
 			}
 		}
-		auto line = _converter.to_bytes(std::wstring(begin, end));
+		auto line = Converter{}.to_bytes(std::wstring(begin, end));
 		int oldCol = loc->m_begin.m_col;
 		int col = std::max(0, oldCol - 1);
 		auto it = begin;
@@ -741,20 +756,9 @@ private:
 						}
 					} else if (expList->exprs.size() == 1){
 						auto exp = static_cast<Exp_t*>(expList->exprs.back());
-						if (exp->opValues.size() > 0) {
-							bool backcall = true;
-							for (auto _opValue : exp->opValues.objects()) {
-								auto opValue = static_cast<exp_op_value_t*>(_opValue);
-								if (!opValue->op.is<BackcallOperator_t>()) {
-									backcall = false;
-									break;
-								}
-							}
-							if (backcall) {
-								transformExp(exp, out);
-								out.back().append(nll(exp));
-								break;
-							}
+						if (isPureBackcall(exp)) {
+							transformExp(exp, out, ExpUsage::Common);
+							break;
 						}
 					}
 					throw std::logic_error(debugInfo("Expression list must appear at the end of body block."sv, expList));
@@ -1046,6 +1050,11 @@ private:
 					break;
 			}
 		}
+		if (isPureBackcall(exp)) {
+			auto expList = assignment->expList.get();
+			transformExp(exp, out, ExpUsage::Assignment, expList);
+			return;
+		}
 		BLOCK_END
 		auto info = extractDestructureInfo(assignment);
 		if (info.first.empty()) {
@@ -1116,7 +1125,7 @@ private:
 			case "If"_id: transformIf(static_cast<If_t*>(value), out, ExpUsage::Closure); break;
 			case "Switch"_id: transformSwitchClosure(static_cast<Switch_t*>(value), out); break;
 			case "TableBlock"_id: transformTableBlock(static_cast<TableBlock_t*>(value), out); break;
-			case "Exp"_id: transformExp(static_cast<Exp_t*>(value), out); break;
+			case "Exp"_id: transformExp(static_cast<Exp_t*>(value), out, ExpUsage::Closure); break;
 			default: break;
 		}
 	}
@@ -1162,7 +1171,7 @@ private:
 						bool isVariable = !varName.empty();
 						if (!isVariable) {
 							str_list temp;
-							transformExp(exp, temp);
+							transformExp(exp, temp, ExpUsage::Closure);
 							varName = std::move(temp.back());
 						}
 						_config.lintGlobalVariable = lintGlobal;
@@ -1211,7 +1220,7 @@ private:
 							bool isVariable = !varName.empty();
 							if (!isVariable) {
 								str_list temp;
-								transformExp(exp, temp);
+								transformExp(exp, temp, ExpUsage::Closure);
 								varName = std::move(temp.back());
 							}
 							_config.lintGlobalVariable = lintGlobal;
@@ -1342,7 +1351,7 @@ private:
 				transformValue(leftValue, temp);
 				auto left = std::move(temp.back());
 				temp.pop_back();
-				transformExp(update->value, temp);
+				transformExp(update->value, temp, ExpUsage::Closure);
 				auto right = std::move(temp.back());
 				temp.pop_back();
 				if (!singleValueFrom(update->value)) {
@@ -1543,13 +1552,13 @@ private:
 					if (auto value = singleValueFrom(condition)) {
 						transformValue(value, tmp);
 					} else {
-						transformExp(condition, tmp);
+						transformExp(condition, tmp, ExpUsage::Closure);
 						tmp.back() = s("("sv) + tmp.back() + s(")"sv);
 					}
 					tmp.back().insert(0, s("not "sv));
 					unless = false;
 				} else {
-					transformExp(condition, tmp);
+					transformExp(condition, tmp, ExpUsage::Closure);
 				}
 				_buf << indent();
 				if (pair != ifCondPairs.front()) {
@@ -1596,7 +1605,7 @@ private:
 	void transformExpList(ExpList_t* expList, str_list& out) {
 		str_list temp;
 		for (auto exp : expList->exprs.objects()) {
-			transformExp(static_cast<Exp_t*>(exp), temp);
+			transformExp(static_cast<Exp_t*>(exp), temp, ExpUsage::Closure);
 		}
 		out.push_back(join(temp, ", "sv));
 	}
@@ -1604,12 +1613,12 @@ private:
 	void transformExpListLow(ExpListLow_t* expListLow, str_list& out) {
 		str_list temp;
 		for (auto exp : expListLow->exprs.objects()) {
-			transformExp(static_cast<Exp_t*>(exp), temp);
+			transformExp(static_cast<Exp_t*>(exp), temp, ExpUsage::Closure);
 		}
 		out.push_back(join(temp, ", "sv));
 	}
 
-	void transformExp(Exp_t* exp, str_list& out) {
+	void transformExp(Exp_t* exp, str_list& out, ExpUsage usage, ExpList_t* assignList = nullptr) {
 		auto x = exp;
 		const auto& opValues = exp->opValues.objects();
 		for (auto it = opValues.begin(); it != opValues.end(); ++it) {
@@ -1643,7 +1652,20 @@ private:
 						value->item.set(chainValue);
 						newExp->value.set(value);
 					}
-					transformExp(newExp, out);
+					if (newExp->opValues.size() == 0) {
+						if (usage == ExpUsage::Assignment) {
+							auto assign = x->new_ptr<Assign_t>();
+							assign->values.push_back(newExp);
+							auto assignment = x->new_ptr<ExpListAssign_t>();
+							assignment->expList.set(assignList);
+							assignment->action.set(assign);
+							transformAssignment(assignment, out);
+						} else {
+							transformChainValue(chainValue, out, usage);
+						}
+					} else {
+						transformExp(newExp, out, usage, assignList);
+					}
 					return;
 				} else {
 					throw std::logic_error(debugInfo("Backcall operator must be followed by chain value."sv, opValue->value));
@@ -1711,7 +1733,7 @@ private:
 
 	void transformParens(Parens_t* parans, str_list& out) {
 		str_list temp;
-		transformExp(parans->expr, temp);
+		transformExp(parans->expr, temp, ExpUsage::Closure);
 		out.push_back(s("("sv) + temp.front() + s(")"sv));
 	}
 
@@ -1963,6 +1985,13 @@ private:
 
 	void transformReturn(Return_t* returnNode, str_list& out) {
 		if (auto valueList = returnNode->valueList.get()) {
+			if (valueList->exprs.size() == 1) {
+				auto exp = static_cast<Exp_t*>(valueList->exprs.back());
+				if (isPureBackcall(exp)) {
+					transformExp(exp, out, ExpUsage::Return);
+					return;
+				}
+			}
 			if (auto singleValue = singleValueFrom(valueList)) {
 				if (auto simpleValue = singleValue->item.as<SimpleValue_t>()) {
 					auto value = simpleValue->value.get();
@@ -2640,7 +2669,7 @@ private:
 					temp.back() = s("("sv) + temp.back() + s(")"sv);
 					break;
 				case "Exp"_id:
-					transformExp(static_cast<Exp_t*>(item), temp);
+					transformExp(static_cast<Exp_t*>(item), temp, ExpUsage::Closure);
 					temp.back() = s("["sv) + temp.back() + s("]"sv);
 					break;
 				case "InvokeArgs"_id: transformInvokeArgs(static_cast<InvokeArgs_t*>(item), temp); break;
@@ -2700,7 +2729,7 @@ private:
 		str_list temp;
 		for (auto arg : invoke->args.objects()) {
 			switch (arg->getId()) {
-				case "Exp"_id: transformExp(static_cast<Exp_t*>(arg), temp); break;
+				case "Exp"_id: transformExp(static_cast<Exp_t*>(arg), temp, ExpUsage::Closure); break;
 				case "SingleString"_id: transformSingleString(static_cast<SingleString_t*>(arg), temp); break;
 				case "DoubleString"_id: transformDoubleString(static_cast<DoubleString_t*>(arg), temp); break;
 				case "LuaString"_id: transformLuaString(static_cast<LuaString_t*>(arg), temp); break;
@@ -2713,7 +2742,7 @@ private:
 	void transform_unary_exp(unary_exp_t* unary_exp, str_list& out) {
 		std::string op = toString(unary_exp->m_begin.m_it, unary_exp->item->m_begin.m_it);
 		str_list temp{op + (op == "not"sv ? s(" "sv) : Empty)};
-		transformExp(unary_exp->item, temp);
+		transformExp(unary_exp->item, temp, ExpUsage::Closure);
 		out.push_back(join(temp));
 	}
 
@@ -2742,7 +2771,7 @@ private:
 					transformCompFor(static_cast<CompFor_t*>(item), temp);
 					break;
 				case "Exp"_id:
-					transformExp(static_cast<Exp_t*>(item), temp);
+					transformExp(static_cast<Exp_t*>(item), temp, ExpUsage::Closure);
 					temp.back() = indent() + s("if "sv) + temp.back() + s(" then"sv) + nll(item);
 					pushScope();
 					break;
@@ -2794,7 +2823,7 @@ private:
 					transformCompFor(static_cast<CompFor_t*>(item), temp);
 					break;
 				case "Exp"_id:
-					transformExp(static_cast<Exp_t*>(item), temp);
+					transformExp(static_cast<Exp_t*>(item), temp, ExpUsage::Closure);
 					temp.back() = indent() + s("if "sv) + temp.back() + s(" then"sv) + nll(item);
 					pushScope();
 					break;
@@ -2908,19 +2937,19 @@ private:
 				}
 				std::string startValue("1"sv);
 				if (auto exp = slice->startValue.as<Exp_t>()) {
-					transformExp(exp, temp);
+					transformExp(exp, temp, ExpUsage::Closure);
 					startValue = temp.back();
 					temp.pop_back();
 				}
 				std::string stopValue;
 				if (auto exp = slice->stopValue.as<Exp_t>()) {
-					transformExp(exp, temp);
+					transformExp(exp, temp, ExpUsage::Closure);
 					stopValue = temp.back();
 					temp.pop_back();
 				}
 				std::string stepValue;
 				if (auto exp = slice->stepValue.as<Exp_t>()) {
-					transformExp(exp, temp);
+					transformExp(exp, temp, ExpUsage::Closure);
 					stepValue = temp.back();
 					temp.pop_back();
 				}
@@ -2957,7 +2986,7 @@ private:
 					varBefore.push_back(listVar);
 				}
 				if (!endWithSlice) {
-					transformExp(star_exp->value, temp);
+					transformExp(star_exp->value, temp, ExpUsage::Closure);
 					if (newListVal) _buf << indent() << "local "sv << listVar << " = "sv << temp.back() << nll(nameList);
 					_buf << indent() << "for "sv << indexVar << " = 1, #"sv << listVar << " do"sv << nlr(loopTarget);
 					_buf << indent(1) << "local "sv << join(vars) << " = "sv << listVar << "["sv << indexVar << "]"sv << nll(nameList);
@@ -2966,7 +2995,7 @@ private:
 				break;
 			}
 			case "Exp"_id:
-				transformExp(static_cast<Exp_t*>(loopTarget), temp);
+				transformExp(static_cast<Exp_t*>(loopTarget), temp, ExpUsage::Closure);
 				_buf << indent() << "for "sv << join(vars, ", "sv) << " in "sv << temp.back() << " do"sv << nlr(loopTarget);
 				out.push_back(clearBuf());
 				break;
@@ -3010,7 +3039,7 @@ private:
 		str_list temp;
 		for (auto arg : invokeArgs->args.objects()) {
 			switch (arg->getId()) {
-				case "Exp"_id: transformExp(static_cast<Exp_t*>(arg), temp); break;
+				case "Exp"_id: transformExp(static_cast<Exp_t*>(arg), temp, ExpUsage::Closure); break;
 				case "TableBlock"_id: transformTableBlock(static_cast<TableBlock_t*>(arg), temp); break;
 				default: break;
 			}
@@ -3021,10 +3050,10 @@ private:
 	void transformForHead(For_t* forNode, str_list& out) {
 		str_list temp;
 		std::string varName = toString(forNode->varName);
-		transformExp(forNode->startValue, temp);
-		transformExp(forNode->stopValue, temp);
+		transformExp(forNode->startValue, temp, ExpUsage::Closure);
+		transformExp(forNode->stopValue, temp, ExpUsage::Closure);
 		if (forNode->stepValue) {
-			transformExp(forNode->stepValue->value, temp);
+			transformExp(forNode->stepValue->value, temp, ExpUsage::Closure);
 		} else {
 			temp.emplace_back();
 		}
@@ -3229,7 +3258,7 @@ private:
 				break;
 			}
 			case "Exp"_id:
-				transformExp(static_cast<Exp_t*>(key), temp);
+				transformExp(static_cast<Exp_t*>(key), temp, ExpUsage::Closure);
 				temp.back() = s("["sv) + temp.back() + s("]"sv);
 				break;
 			case "DoubleString"_id:
@@ -3243,7 +3272,7 @@ private:
 		}
 		auto value = pair->value.get();
 		switch (value->getId()) {
-			case "Exp"_id: transformExp(static_cast<Exp_t*>(value), temp); break;
+			case "Exp"_id: transformExp(static_cast<Exp_t*>(value), temp, ExpUsage::Closure); break;
 			case "TableBlock"_id: transformTableBlock(static_cast<TableBlock_t*>(value), temp); break;
 			default: break;
 		}
@@ -3259,7 +3288,7 @@ private:
 		}
 	}
 
-	void replace(std::string& str, std::string_view from, std::string_view to) {
+	void replace(std::string& str, std::string_view from, std::string_view to) const {
 		size_t start_pos = 0;
 		while((start_pos = str.find(from, start_pos)) != std::string::npos) {
 			str.replace(start_pos, from.size(), to);
@@ -3295,7 +3324,7 @@ private:
 					break;
 				}
 				case "Exp"_id:
-					transformExp(static_cast<Exp_t*>(content), temp);
+					transformExp(static_cast<Exp_t*>(content), temp, ExpUsage::Closure);
 					temp.back() = s("tostring("sv) + temp.back() + s(")"sv);
 					break;
 				default: break;
@@ -3428,7 +3457,7 @@ private:
 		if (extend) {
 			parentVar = getUnusedName("_parent_"sv);
 			addToScope(parentVar);
-			transformExp(extend, temp);
+			transformExp(extend, temp, ExpUsage::Closure);
 			parent = temp.back();
 			temp.pop_back();
 			temp.push_back(indent() + s("local "sv) + parentVar + s(" = "sv) + parent + nll(classDecl));
@@ -3904,7 +3933,7 @@ private:
 		pushScope();
 		for (auto pair : pairs) {
 			switch (pair->getId()) {
-				case "Exp"_id: transformExp(static_cast<Exp_t*>(pair), temp); break;
+				case "Exp"_id: transformExp(static_cast<Exp_t*>(pair), temp, ExpUsage::Closure); break;
 				case "variable_pair"_id: transform_variable_pair(static_cast<variable_pair_t*>(pair), temp); break;
 				case "normal_pair"_id: transform_normal_pair(static_cast<normal_pair_t*>(pair), temp); break;
 			}
@@ -3943,15 +3972,15 @@ private:
 					transformCompFor(static_cast<CompFor_t*>(item), temp);
 					break;
 				case "Exp"_id:
-					transformExp(static_cast<Exp_t*>(item), temp);
+					transformExp(static_cast<Exp_t*>(item), temp, ExpUsage::Closure);
 					temp.back() = indent() + s("if "sv) + temp.back() + s(" then"sv) + nll(item);
 					pushScope();
 					break;
 			}
 		}
-		transformExp(comp->key, kv);
+		transformExp(comp->key, kv, ExpUsage::Closure);
 		if (comp->value) {
-			transformExp(comp->value->value, kv);
+			transformExp(comp->value->value, kv, ExpUsage::Closure);
 		}
 		for (size_t i = 0; i < compInner->items.objects().size(); ++i) {
 			popScope();
@@ -4004,10 +4033,10 @@ private:
 	void transformCompFor(CompFor_t* comp, str_list& out) {
 		str_list temp;
 		std::string varName = toString(comp->varName);
-		transformExp(comp->startValue, temp);
-		transformExp(comp->stopValue, temp);
+		transformExp(comp->startValue, temp, ExpUsage::Closure);
+		transformExp(comp->stopValue, temp, ExpUsage::Closure);
 		if (comp->stepValue) {
-			transformExp(comp->stepValue->value, temp);
+			transformExp(comp->stepValue->value, temp, ExpUsage::Closure);
 		} else {
 			temp.emplace_back();
 		}
@@ -4200,7 +4229,7 @@ private:
 		addToScope(lenVar);
 		temp.push_back(indent() + s("local "sv) + accumVar + s(" = { }"sv) + nll(whileNode));
 		temp.push_back(indent() + s("local "sv) + lenVar + s(" = 1"sv) + nll(whileNode));
-		transformExp(whileNode->condition, temp);
+		transformExp(whileNode->condition, temp, ExpUsage::Closure);
 		temp.back() = indent() + s("while "sv) + temp.back() + s(" do"sv) + nll(whileNode);
 		pushScope();
 		auto assignLeft = toAst<ExpList_t>(accumVar + s("["sv) + lenVar + s("]"sv), ExpList, x);
@@ -4237,7 +4266,7 @@ private:
 		addToScope(lenVar);
 		temp.push_back(indent() + s("local "sv) + accumVar + s(" = { }"sv) + nll(whileNode));
 		temp.push_back(indent() + s("local "sv) + lenVar + s(" = 1"sv) + nll(whileNode));
-		transformExp(whileNode->condition, temp);
+		transformExp(whileNode->condition, temp, ExpUsage::Closure);
 		temp.back() = indent() + s("while "sv) + temp.back() + s(" do"sv) + nll(whileNode);
 		pushScope();
 		auto assignLeft = toAst<ExpList_t>(accumVar + s("["sv) + lenVar + s("]"sv), ExpList, x);
@@ -4255,7 +4284,7 @@ private:
 	void transformWhile(While_t* whileNode, str_list& out) {
 		str_list temp;
 		pushScope();
-		transformExp(whileNode->condition, temp);
+		transformExp(whileNode->condition, temp, ExpUsage::Closure);
 		transformLoopBody(whileNode->body, temp, Empty);
 		popScope();
 		_buf << indent() << "while "sv << temp.front() << " do"sv << nll(whileNode);
@@ -4280,7 +4309,7 @@ private:
 		if (objVar.empty()) {
 			objVar = getUnusedName("_exp_"sv);
 			addToScope(objVar);
-			transformExp(switchNode->target, temp);
+			transformExp(switchNode->target, temp, ExpUsage::Closure);
 			_buf << indent() << "local "sv << objVar << " = "sv << temp.back() << nll(switchNode);
 			temp.back() = clearBuf();
 		}
@@ -4292,7 +4321,7 @@ private:
 			const auto& exprs = branch->valueList->exprs.objects();
 			for (auto exp_ : exprs) {
 				auto exp = static_cast<Exp_t*>(exp_);
-				transformExp(exp, tmp);
+				transformExp(exp, tmp, ExpUsage::Closure);
 				if (!singleValueFrom(exp)) {
 					tmp.back() = s("("sv) + tmp.back() + s(")"sv);
 				}
-- 
cgit v1.2.3-55-g6feb