From 1c6a9651beffd9cbbb3641179f3a738d5555d3c9 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 23 Oct 2020 17:32:48 +0800 Subject: make teal-macro look better. --- spec/inputs/macro-teal.mp | 34 +- spec/inputs/teal-lang.mp | 22 +- spec/lib/tl.lua | 6801 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 6844 insertions(+), 13 deletions(-) create mode 100644 spec/lib/tl.lua diff --git a/spec/inputs/macro-teal.mp b/spec/inputs/macro-teal.mp index 20444e1..9ce1bcd 100644 --- a/spec/inputs/macro-teal.mp +++ b/spec/inputs/macro-teal.mp @@ -2,6 +2,7 @@ $ -> import "moonp" as {:options} if options.tl_enabled options.target_extension = "tl" + package.path ..= "?.lua;./spec/lib/?.lua" macro expr to_lua = (codes)-> "require('moonp').to_lua(#{codes}, reserve_line_number:false, same_module:true)" @@ -9,19 +10,31 @@ macro expr to_lua = (codes)-> macro expr trim = (name)-> "if result = #{name}\\match '[\\'\"](.*)[\\'\"]' then result else #{name}" -export macro text var = (name, type, value = nil)-> +export macro text local = (decl, value = nil)-> import "moonp" as {options:{:tl_enabled}} + name, type = ($trim decl)\match "(.-):(.*)" + if not (name and type) + error "invalid local varaible declaration for \"#{decl}\"" value = $to_lua(value)\gsub "^return ", "" if tl_enabled "local #{name}:#{$trim type} = #{value}", {name} else "local #{name} = #{value}", {name} -export macro text def = (name, type, value)-> +export macro text function = (decl, value)-> import "moonp" as {options:{:tl_enabled}} + import "tl" + decl = $trim decl + name, type = decl\match "(.-)(%(.*)" + if not (name and type) + error "invalid function declaration for \"#{decl}\"" + tokens = tl.lex "function #{decl}" + _, node = tl.parse_program tokens,{},"macro-function" + args = table.concat [arg.tk for arg in *node[1].args],", " + value = "(#{args})#{value}" if tl_enabled - value = $to_lua(value)\match "function%(.*%)(.*)end" - "local function #{name}#{$trim type}\n#{value}\nend", {name} + value = $to_lua(value)\match "function%([^\n]*%)(.*)end" + "local function #{name}#{type}\n#{value}\nend", {name} else value = $to_lua(value)\gsub "^return ", "" "local #{name} = #{value}", {name} @@ -35,11 +48,20 @@ end", {name} else "local #{name} = {}", {name} -export macro text field = (tab, sym, func, type, value)-> +export macro text method = (decl, value)-> import "moonp" as {options:{:tl_enabled}} + import "tl" + decl = $trim decl + tab, sym, func, type = decl\match "(.-)([%.:])(.-)(%(.*)" + if not (tab and sym and func and type) + error "invalid method declaration for \"#{decl}\"" + tokens = tl.lex "function #{decl}" + _, node = tl.parse_program tokens,{},"macro-function" + args = table.concat [arg.tk for arg in *node[1].args],", " + value = "(#{args})->#{value\match "[%-=]>(.*)"}" if tl_enabled value = $to_lua(value)\match "^return function%(.-%)\n(.*)end" - "function #{tab}#{$trim sym}#{func}#{$trim type}\n#{value}\nend" + "function #{tab}#{sym}#{func}#{type}\n#{value}\nend" else value = $to_lua(value)\gsub "^return ", "" "#{tab}.#{func} = #{value}" diff --git a/spec/inputs/teal-lang.mp b/spec/inputs/teal-lang.mp index 3c9c79b..29769d5 100644 --- a/spec/inputs/teal-lang.mp +++ b/spec/inputs/teal-lang.mp @@ -3,10 +3,10 @@ $ -> import "macro-teal" as {$} -$var a, "{string:number}", {value:123} -$var b, "number", a.value +$local "a:{string:number}", {value:123} +$local "b:number", a.value -$def add, "(a:number,b:number):number", (a, b)-> a + b +$function "add(a:number, b:number):number", -> a + b s = add(a.value, b) print(s) @@ -16,17 +16,25 @@ $record Point, [[ y: number ]] -$field Point, '.', new, "(x: number, y: number):Point", (x, y)-> - $var point, "Point", setmetatable {}, __index: Point +$method "Point.new(x:number, y:number):Point", -> + $local "point:Point", setmetatable {}, __index: Point point.x = x or 0 point.y = y or 0 point -$field Point, ":", move, "(dx: number, dy: number)", (dx, dy)=> +$method "Point:move(dx:number, dy:number)", -> @x += dx @y += dy -$var p, "Point", Point.new 100, 100 +$local "p:Point", Point.new 100, 100 p\move 50, 50 +$function "filter(tab:{string}, handler:function(item:string):boolean):{string}", -> + [item for item in *tab when handler item] + +$function "cond(item:string):boolean", -> item ~= "a" + +res = filter {"a", "b", "c", "a"}, cond +for s in *res + print s diff --git a/spec/lib/tl.lua b/spec/lib/tl.lua new file mode 100644 index 0000000..aca748c --- /dev/null +++ b/spec/lib/tl.lua @@ -0,0 +1,6801 @@ +local _tl_compat53 = ((tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3) and require('compat53.module'); local assert = _tl_compat53 and _tl_compat53.assert or assert; local io = _tl_compat53 and _tl_compat53.io or io; local ipairs = _tl_compat53 and _tl_compat53.ipairs or ipairs; local load = _tl_compat53 and _tl_compat53.load or load; local math = _tl_compat53 and _tl_compat53.math or math; local os = _tl_compat53 and _tl_compat53.os or os; local package = _tl_compat53 and _tl_compat53.package or package; local pairs = _tl_compat53 and _tl_compat53.pairs or pairs; local string = _tl_compat53 and _tl_compat53.string or string; local table = _tl_compat53 and _tl_compat53.table or table; local _tl_table_unpack = unpack or table.unpack; local Env = {} + + + + + +local TypeCheckOptions = {} + + + + + + + +local LoadMode = {} + + + + +local LoadFunction = {} + +local tl = { + load = nil, + process = nil, + process_string = nil, + gen = nil, + type_check = nil, + init_env = nil, +} + + + + + + + +local inspect = function(x) + return tostring(x) +end + +local keywords = { + ["and"] = true, + ["break"] = true, + ["do"] = true, + ["else"] = true, + ["elseif"] = true, + ["end"] = true, + ["false"] = true, + ["for"] = true, + ["function"] = true, + ["goto"] = true, + ["if"] = true, + ["in"] = true, + ["local"] = true, + ["nil"] = true, + ["not"] = true, + ["or"] = true, + ["repeat"] = true, + ["return"] = true, + ["then"] = true, + ["true"] = true, + ["until"] = true, + ["while"] = true, + + +} + +local TokenKind = {} + + + + + + + + + + + + +local Token = {} + + + + + + + +local lex_word_start = {} +for c = string.byte("a"), string.byte("z") do + lex_word_start[string.char(c)] = true +end +for c = string.byte("A"), string.byte("Z") do + lex_word_start[string.char(c)] = true +end +lex_word_start["_"] = true + +local lex_word = {} +for c = string.byte("a"), string.byte("z") do + lex_word[string.char(c)] = true +end +for c = string.byte("A"), string.byte("Z") do + lex_word[string.char(c)] = true +end +for c = string.byte("0"), string.byte("9") do + lex_word[string.char(c)] = true +end +lex_word["_"] = true + +local lex_decimal_start = {} +for c = string.byte("1"), string.byte("9") do + lex_decimal_start[string.char(c)] = true +end + +local lex_decimals = {} +for c = string.byte("0"), string.byte("9") do + lex_decimals[string.char(c)] = true +end + +local lex_hexadecimals = {} +for c = string.byte("0"), string.byte("9") do + lex_hexadecimals[string.char(c)] = true +end +for c = string.byte("a"), string.byte("f") do + lex_hexadecimals[string.char(c)] = true +end +for c = string.byte("A"), string.byte("F") do + lex_hexadecimals[string.char(c)] = true +end + +local lex_char_symbols = {} +for _, c in ipairs({ "[", "]", "(", ")", "{", "}", ",", "#", "`", ";" }) do + lex_char_symbols[c] = true +end + +local lex_op_start = {} +for _, c in ipairs({ "+", "*", "/", "|", "&", "%", "^" }) do + lex_op_start[c] = true +end + +local lex_space = {} +for _, c in ipairs({ " ", "\t", "\v", "\n", "\r" }) do + lex_space[c] = true +end + +local LexState = {} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +function tl.lex(input) + local tokens = {} + + local state = "start" + local fwd = true + local y = 1 + local x = 0 + local i = 0 + local lc_open_lvl = 0 + local lc_close_lvl = 0 + local ls_open_lvl = 0 + local ls_close_lvl = 0 + local errs = {} + + local tx + local ty + local ti + local in_token = false + + local function begin_token() + tx = x + ty = y + ti = i + in_token = true + end + + local function end_token(kind, last, t) + local tk = t or input:sub(ti, last or i) or "" + if keywords[tk] then + kind = "keyword" + end + table.insert(tokens, { + x = tx, + y = ty, + i = ti, + tk = tk, + kind = kind, + }) + in_token = false + end + + local function drop_token() + in_token = false + end + + while i <= #input do + if fwd then + i = i + 1 + if i > #input then + break + end + end + + local c = input:sub(i, i) + + if fwd then + if c == "\n" then + y = y + 1 + x = 0 + else + x = x + 1 + end + else + fwd = true + end + + if state == "start" then + if input:sub(1, 2) == "#!" then + i = input:find("\n") + if not i then + break + end + c = "\n" + y = 2 + x = 0 + end + state = "any" + end + + if state == "any" then + if c == "-" then + state = "maybecomment" + begin_token() + elseif c == "." then + state = "maybedotdot" + begin_token() + elseif c == "\"" then + state = "dblquote_string" + begin_token() + elseif c == "'" then + state = "singlequote_string" + begin_token() + elseif lex_word_start[c] then + state = "identifier" + begin_token() + elseif c == "0" then + state = "decimal_or_hex" + begin_token() + elseif lex_decimal_start[c] then + state = "decimal_number" + begin_token() + elseif c == "<" then + state = "lt" + begin_token() + elseif c == ":" then + state = "colon" + begin_token() + elseif c == ">" then + state = "gt" + begin_token() + elseif c == "=" or c == "~" then + state = "maybeequals" + begin_token() + elseif c == "[" then + state = "maybelongstring" + begin_token() + elseif lex_char_symbols[c] then + begin_token() + end_token(c) + elseif lex_op_start[c] then + begin_token() + end_token("op") + elseif lex_space[c] then + + else + begin_token() + end_token("$invalid$") + table.insert(errs, tokens[#tokens]) + end + elseif state == "maybecomment" then + if c == "-" then + state = "maybecomment2" + else + end_token("op", nil, "-") + fwd = false + state = "any" + end + elseif state == "maybecomment2" then + if c == "[" then + state = "maybelongcomment" + else + fwd = false + state = "comment" + drop_token() + end + elseif state == "maybelongcomment" then + if c == "[" then + state = "longcomment" + elseif c == "=" then + lc_open_lvl = lc_open_lvl + 1 + else + fwd = false + state = "comment" + drop_token() + lc_open_lvl = 0 + end + elseif state == "longcomment" then + if c == "]" then + state = "maybelongcommentend" + end + elseif state == "maybelongcommentend" then + if c == "]" and lc_close_lvl == lc_open_lvl then + drop_token() + state = "any" + lc_open_lvl = 0 + lc_close_lvl = 0 + elseif c == "=" then + lc_close_lvl = lc_close_lvl + 1 + else + state = "longcomment" + lc_close_lvl = 0 + end + elseif state == "dblquote_string" then + if c == "\\" then + state = "escape_dblquote_string" + elseif c == "\"" then + end_token("string") + state = "any" + end + elseif state == "escape_dblquote_string" then + state = "dblquote_string" + elseif state == "singlequote_string" then + if c == "\\" then + state = "escape_singlequote_string" + elseif c == "'" then + end_token("string") + state = "any" + end + elseif state == "escape_singlequote_string" then + state = "singlequote_string" + elseif state == "maybeequals" then + if c == "=" then + end_token("op") + state = "any" + else + end_token("op", i - 1) + fwd = false + state = "any" + end + elseif state == "lt" then + if c == "=" or c == "<" then + end_token("op") + state = "any" + else + end_token("op", i - 1) + fwd = false + state = "any" + end + elseif state == "colon" then + if c == ":" then + end_token("::") + state = "any" + else + end_token(":", i - 1) + fwd = false + state = "any" + end + elseif state == "gt" then + if c == "=" or c == ">" then + end_token("op") + state = "any" + else + end_token("op", i - 1) + fwd = false + state = "any" + end + elseif state == "maybelongstring" then + if c == "[" then + state = "longstring" + elseif c == "=" then + ls_open_lvl = ls_open_lvl + 1 + else + end_token("[", i - 1) + fwd = false + state = "any" + ls_open_lvl = 0 + end + elseif state == "longstring" then + if c == "]" then + state = "maybelongstringend" + end + elseif state == "maybelongstringend" then + if c == "]" then + if ls_close_lvl == ls_open_lvl then + end_token("string") + state = "any" + ls_open_lvl = 0 + ls_close_lvl = 0 + end + elseif c == "=" then + ls_close_lvl = ls_close_lvl + 1 + else + state = "longstring" + ls_close_lvl = 0 + end + elseif state == "maybedotdot" then + if c == "." then + state = "maybedotdotdot" + elseif lex_decimals[c] then + state = "decimal_float" + else + end_token(".", i - 1) + fwd = false + state = "any" + end + elseif state == "maybedotdotdot" then + if c == "." then + end_token("...") + state = "any" + else + end_token("op", i - 1) + fwd = false + state = "any" + end + elseif state == "comment" then + if c == "\n" then + state = "any" + end + elseif state == "identifier" then + if not lex_word[c] then + end_token("identifier", i - 1) + fwd = false + state = "any" + end + elseif state == "decimal_or_hex" then + if c == "x" or c == "X" then + state = "hex_number" + elseif c == "e" or c == "E" then + state = "power_sign" + elseif lex_decimals[c] then + state = "decimal_number" + elseif c == "." then + state = "decimal_float" + else + end_token("number", i - 1) + fwd = false + state = "any" + end + elseif state == "hex_number" then + if c == "." then + state = "hex_float" + elseif c == "p" or c == "P" then + state = "power_sign" + elseif not lex_hexadecimals[c] then + end_token("number", i - 1) + fwd = false + state = "any" + end + elseif state == "hex_float" then + if c == "p" or c == "P" then + state = "power_sign" + elseif not lex_hexadecimals[c] then + end_token("number", i - 1) + fwd = false + state = "any" + end + elseif state == "decimal_number" then + if c == "." then + state = "decimal_float" + elseif c == "e" or c == "E" then + state = "power_sign" + elseif not lex_decimals[c] then + end_token("number", i - 1) + fwd = false + state = "any" + end + elseif state == "decimal_float" then + if c == "e" or c == "E" then + state = "power_sign" + elseif not lex_decimals[c] then + end_token("number", i - 1) + fwd = false + state = "any" + end + elseif state == "power_sign" then + if c == "-" or c == "+" then + state = "power" + elseif lex_decimals[c] then + state = "power" + else + end_token("$invalid$") + table.insert(errs, tokens[#tokens]) + state = "any" + end + elseif state == "power" then + if not lex_decimals[c] then + end_token("number", i - 1) + fwd = false + state = "any" + end + end + end + + local terminals = { + ["identifier"] = "identifier", + ["decimal_or_hex"] = "number", + ["decimal_number"] = "number", + ["decimal_float"] = "number", + ["hex_number"] = "number", + ["hex_float"] = "number", + ["power"] = "number", + } + + if in_token then + if terminals[state] then + end_token(terminals[state], i - 1) + else + drop_token() + end + end + + return tokens, (#errs > 0) and errs +end + + + + + +local add_space = { + ["word:keyword"] = true, + ["word:word"] = true, + ["word:string"] = true, + ["word:="] = true, + ["word:op"] = true, + + ["keyword:word"] = true, + ["keyword:keyword"] = true, + ["keyword:string"] = true, + ["keyword:number"] = true, + ["keyword:="] = true, + ["keyword:op"] = true, + ["keyword:{"] = true, + ["keyword:("] = true, + ["keyword:#"] = true, + + ["=:word"] = true, + ["=:keyword"] = true, + ["=:string"] = true, + ["=:number"] = true, + ["=:{"] = true, + ["=:("] = true, + ["op:("] = true, + ["op:{"] = true, + ["op:#"] = true, + + [",:word"] = true, + [",:keyword"] = true, + [",:string"] = true, + [",:{"] = true, + + ["):op"] = true, + ["):word"] = true, + ["):keyword"] = true, + + ["op:string"] = true, + ["op:number"] = true, + ["op:word"] = true, + ["op:keyword"] = true, + + ["]:word"] = true, + ["]:keyword"] = true, + ["]:="] = true, + ["]:op"] = true, + + ["string:op"] = true, + ["string:word"] = true, + ["string:keyword"] = true, + + ["number:word"] = true, + ["number:keyword"] = true, +} + +local should_unindent = { + ["end"] = true, + ["elseif"] = true, + ["else"] = true, + ["}"] = true, +} + +local should_indent = { + ["{"] = true, + ["for"] = true, + ["if"] = true, + ["while"] = true, + ["elseif"] = true, + ["else"] = true, + ["function"] = true, +} + +function tl.pretty_print_tokens(tokens) + local y = 1 + local out = {} + local indent = 0 + local newline = false + local kind = "" + for _, t in ipairs(tokens) do + while t.y > y do + table.insert(out, "\n") + y = y + 1 + newline = true + kind = "" + end + if should_unindent[t.tk] then + indent = indent - 1 + if indent < 0 then + indent = 0 + end + end + if newline then + for _ = 1, indent do + table.insert(out, " ") + end + newline = false + end + if should_indent[t.tk] then + indent = indent + 1 + end + if add_space[(kind or "") .. ":" .. t.kind] then + table.insert(out, " ") + end + table.insert(out, t.tk) + kind = t.kind or "" + end + return table.concat(out) +end + + + + + +local last_typeid = 0 + +local function new_typeid() + last_typeid = last_typeid + 1 + return last_typeid +end + +local ParseError = {} + + + + + + +local TypeName = {} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +local table_types = { + ["array"] = true, + ["map"] = true, + ["arrayrecord"] = true, + ["record"] = true, + ["emptytable"] = true, +} + +local Type = {} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +local Operator = {} + + + + + + + +local NodeKind = {} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +local FactType = {} + + + +local Fact = {} + + + + + +local KeyParsed = {} + + + + + +local Node = {} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +local function is_array_type(t) + return t.typename == "array" or t.typename == "arrayrecord" +end + +local function is_record_type(t) + return t.typename == "record" or t.typename == "arrayrecord" +end + +local function is_type(t) + return t.typename == "typetype" or t.typename == "nestedtype" +end + +local ParseState = {} + + + + + +local ParseTypeListMode = {} + + + + + +local parse_type_list +local parse_expression +local parse_statements +local parse_argument_list +local parse_argument_type_list +local parse_type +local parse_newtype + + +local function fail(ps, i, msg) + if not ps.tokens[i] then + local eof = ps.tokens[#ps.tokens] + table.insert(ps.errs, { y = eof.y, x = eof.x, msg = msg or "unexpected end of file" }) + return #ps.tokens + end + table.insert(ps.errs, { y = ps.tokens[i].y, x = ps.tokens[i].x, msg = msg or "syntax error" }) + return math.min(#ps.tokens, i + 1) +end + +local function verify_tk(ps, i, tk) + if ps.tokens[i].tk == tk then + return i + 1 + end + return fail(ps, i, "syntax error, expected '" .. tk .. "'") +end + +local function new_node(tokens, i, kind) + local t = tokens[i] + return { y = t.y, x = t.x, tk = t.tk, kind = kind or t.kind } +end + +local function a_type(t) + t.typeid = new_typeid() + return t +end + +local function new_type(ps, i, typename) + local token = ps.tokens[i] + return a_type({ + typename = assert(typename), + filename = ps.filename, + y = token.y, + x = token.x, + tk = token.tk, + }) +end + +local function verify_kind(ps, i, kind, node_kind) + if ps.tokens[i].kind == kind then + return i + 1, new_node(ps.tokens, i, node_kind) + end + return fail(ps, i, "syntax error, expected " .. kind) +end + +local is_newtype = { + ["enum"] = true, + ["record"] = true, +} + +local function parse_table_value(ps, i) + if is_newtype[ps.tokens[i].tk] then + return parse_newtype(ps, i) + else + local i, node, _ = parse_expression(ps, i) + return i, node + end +end + +local function parse_table_item(ps, i, n) + local node = new_node(ps.tokens, i, "table_item") + if ps.tokens[i].kind == "$EOF$" then + return fail(ps, i) + end + + if ps.tokens[i].tk == "[" then + node.key_parsed = "long" + i = i + 1 + i, node.key = parse_expression(ps, i) + i = verify_tk(ps, i, "]") + i = verify_tk(ps, i, "=") + i, node.value = parse_table_value(ps, i) + return i, node, n + elseif ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == "=" then + node.key_parsed = "short" + i, node.key = verify_kind(ps, i, "identifier", "string") + node.key.conststr = node.key.tk + node.key.tk = '"' .. node.key.tk .. '"' + i = verify_tk(ps, i, "=") + i, node.value = parse_table_value(ps, i) + return i, node, n + elseif ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == ":" then + node.key_parsed = "short" + local orig_i = i + local try_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + } + i, node.key = verify_kind(try_ps, i, "identifier", "string") + node.key.conststr = node.key.tk + node.key.tk = '"' .. node.key.tk .. '"' + i = verify_tk(try_ps, i, ":") + i, node.decltype = parse_type(try_ps, i) + if node.decltype and ps.tokens[i].tk == "=" then + i = verify_tk(try_ps, i, "=") + i, node.value = parse_table_value(try_ps, i) + if node.value then + for _, e in ipairs(try_ps.errs) do + table.insert(ps.errs, e) + end + return i, node, n + end + end + + node.decltype = nil + i = orig_i + end + + node.key = new_node(ps.tokens, i, "number") + node.key_parsed = "implicit" + node.key.constnum = n + node.key.tk = tostring(n) + i, node.value = parse_expression(ps, i) + return i, node, n + 1 +end + +local ParseItem = {} + +local SeparatorMode = {} + + + + +local function parse_list(ps, i, list, close, sep, parse_item) + local n = 1 + while ps.tokens[i].kind ~= "$EOF$" do + if close[ps.tokens[i].tk] then + (list).yend = ps.tokens[i].y + break + end + local item + i, item, n = parse_item(ps, i, n) + table.insert(list, item) + if ps.tokens[i].tk == "," then + i = i + 1 + if sep == "sep" and close[ps.tokens[i].tk] then + return fail(ps, i) + end + elseif sep == "term" and ps.tokens[i].tk == ";" then + i = i + 1 + elseif not close[ps.tokens[i].tk] then + return fail(ps, i) + end + end + return i, list +end + +local function parse_bracket_list(ps, i, list, open, close, sep, parse_item) + i = verify_tk(ps, i, open) + i = parse_list(ps, i, list, { [close] = true }, sep, parse_item) + i = verify_tk(ps, i, close) + return i, list +end + +local function parse_table_literal(ps, i) + local node = new_node(ps.tokens, i, "table_literal") + return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) +end + +local function parse_trying_list(ps, i, list, parse_item) + local try_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + } + local tryi, item = parse_item(try_ps, i) + if not item then + return i, list + end + for _, e in ipairs(try_ps.errs) do + table.insert(ps.errs, e) + end + i = tryi + table.insert(list, item) + if ps.tokens[i].tk == "," then + while ps.tokens[i].tk == "," do + i = i + 1 + i, item = parse_item(ps, i) + table.insert(list, item) + end + end + return i, list +end + +local function parse_typearg_type(ps, i) + local backtick = false + if ps.tokens[i].tk == "`" then + i = verify_tk(ps, i, "`") + backtick = true + end + i = verify_kind(ps, i, "identifier") + return i, a_type({ + y = ps.tokens[i - 2].y, + x = ps.tokens[i - 2].x, + typename = "typearg", + typearg = (backtick and "`" or "") .. ps.tokens[i - 1].tk, + }) +end + +local function parse_typevar_type(ps, i) + i = verify_tk(ps, i, "`") + i = verify_kind(ps, i, "identifier") + return i, a_type({ + y = ps.tokens[i - 2].y, + x = ps.tokens[i - 2].x, + typename = "typevar", + typevar = "`" .. ps.tokens[i - 1].tk, + }) +end + +local function parse_typearg_list(ps, i) + local typ = new_type(ps, i, "tuple") + return parse_bracket_list(ps, i, typ, "<", ">", "sep", parse_typearg_type) +end + +local function parse_typeval_list(ps, i) + local typ = new_type(ps, i, "tuple") + return parse_bracket_list(ps, i, typ, "<", ">", "sep", parse_type) +end + +local function parse_return_types(ps, i) + return parse_type_list(ps, i, "rets") +end + +local function parse_function_type(ps, i) + local node = new_type(ps, i, "function") + node.args = {} + node.rets = {} + i = i + 1 + if ps.tokens[i].tk == "<" then + i, node.typeargs = parse_typearg_list(ps, i) + end + if ps.tokens[i].tk == "(" then + i, node.args = parse_argument_type_list(ps, i) + i, node.rets = parse_return_types(ps, i) + else + node.args = { a_type({ typename = "any", is_va = true }) } + node.rets = { a_type({ typename = "any", is_va = true }) } + end + return i, node +end + +local function parse_base_type(ps, i) + if ps.tokens[i].tk == "string" or + ps.tokens[i].tk == "boolean" or + ps.tokens[i].tk == "nil" or + ps.tokens[i].tk == "number" or + ps.tokens[i].tk == "thread" then + local typ = new_type(ps, i, ps.tokens[i].tk) + typ.tk = nil + return i + 1, typ + elseif ps.tokens[i].tk == "table" then + local typ = new_type(ps, i, "map") + typ.keys = a_type({ typename = "any" }) + typ.values = a_type({ typename = "any" }) + return i + 1, typ + elseif ps.tokens[i].tk == "function" then + return parse_function_type(ps, i) + elseif ps.tokens[i].tk == "{" then + i = i + 1 + local decl = new_type(ps, i, "array") + local t + i, t = parse_type(ps, i) + if ps.tokens[i].tk == "}" then + decl.elements = t + decl.yend = ps.tokens[i].y + i = verify_tk(ps, i, "}") + elseif ps.tokens[i].tk == ":" then + decl.typename = "map" + i = i + 1 + decl.keys = t + i, decl.values = parse_type(ps, i) + decl.yend = ps.tokens[i].y + i = verify_tk(ps, i, "}") + end + return i, decl + elseif ps.tokens[i].tk == "`" then + return parse_typevar_type(ps, i) + elseif ps.tokens[i].kind == "identifier" then + local typ = new_type(ps, i, "nominal") + typ.names = { ps.tokens[i].tk } + i = i + 1 + while ps.tokens[i].tk == "." do + i = i + 1 + if ps.tokens[i].kind == "identifier" then + table.insert(typ.names, ps.tokens[i].tk) + i = i + 1 + else + return fail(ps, i, "syntax error, expected identifier") + end + end + + if ps.tokens[i].tk == "<" then + i, typ.typevals = parse_typeval_list(ps, i) + end + return i, typ + end + return fail(ps, i) +end + +parse_type = function(ps, i) + if ps.tokens[i].tk == "(" then + i = i + 1 + local t + i, t = parse_type(ps, i) + i = verify_tk(ps, i, ")") + return i, t + end + + local bt + local istart = i + i, bt = parse_base_type(ps, i) + if not bt then + return i + end + if ps.tokens[i].tk == "|" then + local u = new_type(ps, istart, "union") + u.types = { bt } + while ps.tokens[i].tk == "|" do + i = i + 1 + i, bt = parse_base_type(ps, i) + if not bt then + return i + end + table.insert(u.types, bt) + end + bt = u + end + return i, bt +end + +parse_type_list = function(ps, i, mode) + local list = new_type(ps, i, "tuple") + + local first_token = ps.tokens[i].tk + if mode == "rets" or mode == "decltype" then + if first_token == ":" then + i = i + 1 + else + return i, list + end + end + + local optional_paren = false + if ps.tokens[i].tk == "(" then + optional_paren = true + i = i + 1 + end + + local prev_i = i + i = parse_trying_list(ps, i, list, parse_type) + if i == prev_i and ps.tokens[i].tk ~= ")" then + fail(ps, i - 1, "expected a type list") + end + + if mode == "rets" and ps.tokens[i].tk == "..." then + i = i + 1 + local nrets = #list + if nrets > 0 then + list[nrets].is_va = true + else + return fail(ps, i, "unexpected '...'") + end + end + + if optional_paren then + i = verify_tk(ps, i, ")") + end + + return i, list +end + +local function parse_function_args_rets_body(ps, i, node) + if ps.tokens[i].tk == "<" then + i, node.typeargs = parse_typearg_list(ps, i) + end + i, node.args = parse_argument_list(ps, i) + i, node.rets = parse_return_types(ps, i) + i, node.body = parse_statements(ps, i) + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "end") + return i, node +end + +local function parse_function_value(ps, i) + local node = new_node(ps.tokens, i, "function") + i = verify_tk(ps, i, "function") + return parse_function_args_rets_body(ps, i, node) +end + +local function unquote(str) + local f = str:sub(1, 1) + if f == '"' or f == "'" then + return str:sub(2, -2) + end + f = str:match("^%[=*%[") + local l = #f + 1 + return str:sub(l, -l) +end + +local function parse_literal(ps, i) + if ps.tokens[i].tk == "{" then + return parse_table_literal(ps, i) + elseif ps.tokens[i].kind == "..." then + return verify_kind(ps, i, "...") + elseif ps.tokens[i].kind == "string" then + local tk = unquote(ps.tokens[i].tk) + local node + i, node = verify_kind(ps, i, "string") + node.conststr = tk + return i, node + elseif ps.tokens[i].kind == "identifier" then + return verify_kind(ps, i, "identifier", "variable") + elseif ps.tokens[i].kind == "number" then + local n = tonumber(ps.tokens[i].tk) + local node + i, node = verify_kind(ps, i, "number") + node.constnum = n + return i, node + elseif ps.tokens[i].tk == "true" then + return verify_kind(ps, i, "keyword", "boolean") + elseif ps.tokens[i].tk == "false" then + return verify_kind(ps, i, "keyword", "boolean") + elseif ps.tokens[i].tk == "nil" then + return verify_kind(ps, i, "keyword", "nil") + elseif ps.tokens[i].tk == "function" then + return parse_function_value(ps, i) + end + return fail(ps, i) +end + +do + local precedences = { + [1] = { + ["not"] = 11, + ["#"] = 11, + ["-"] = 11, + ["~"] = 11, + }, + [2] = { + ["or"] = 1, + ["and"] = 2, + ["is"] = 3, + ["<"] = 3, + [">"] = 3, + ["<="] = 3, + [">="] = 3, + ["~="] = 3, + ["=="] = 3, + ["|"] = 4, + ["~"] = 5, + ["&"] = 6, + ["<<"] = 7, + [">>"] = 7, + [".."] = 8, + ["+"] = 8, + ["-"] = 9, + ["*"] = 10, + ["/"] = 10, + ["//"] = 10, + ["%"] = 10, + ["^"] = 12, + ["as"] = 50, + ["@funcall"] = 100, + ["@index"] = 100, + ["."] = 100, + [":"] = 100, + }, + } + + local is_right_assoc = { + ["^"] = true, + [".."] = true, + } + + local function new_operator(tk, arity, op) + op = op or tk.tk + return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] } + end + + local E + + local function P(ps, i) + if ps.tokens[i].kind == "$EOF$" then + return i + end + local e1 + local t1 = ps.tokens[i] + if precedences[1][ps.tokens[i].tk] ~= nil then + local op = new_operator(ps.tokens[i], 1) + i = i + 1 + i, e1 = P(ps, i) + e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } + elseif ps.tokens[i].tk == "(" then + i = i + 1 + i, e1 = parse_expression(ps, i) + e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 } + i = verify_tk(ps, i, ")") + else + i, e1 = parse_literal(ps, i) + end + + while true do + if ps.tokens[i].kind == "string" or ps.tokens[i].kind == "{" then + local op = new_operator(ps.tokens[i], 2, "@funcall") + local args = new_node(ps.tokens, i, "expression_list") + local arg + if ps.tokens[i].kind == "string" then + arg = new_node(ps.tokens, i) + arg.conststr = unquote(ps.tokens[i].tk) + i = i + 1 + else + i, arg = parse_table_literal(ps, i) + end + table.insert(args, arg) + e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = args } + elseif ps.tokens[i].tk == "(" then + local op = new_operator(ps.tokens[i], 2, "@funcall") + + local args = new_node(ps.tokens, i, "expression_list") + i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) + + e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = args } + elseif ps.tokens[i].tk == "[" then + local op = new_operator(ps.tokens[i], 2, "@index") + + local idx + i = i + 1 + i, idx = parse_expression(ps, i) + i = verify_tk(ps, i, "]") + + e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = idx } + elseif ps.tokens[i].tk == "." or ps.tokens[i].tk == ":" then + local op = new_operator(ps.tokens[i], 2) + + local key + i = i + 1 + i, key = verify_kind(ps, i, "identifier") + + e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = key } + elseif ps.tokens[i].tk == "as" or ps.tokens[i].tk == "is" then + local op = new_operator(ps.tokens[i], 2, ps.tokens[i].tk) + + i = i + 1 + local cast = new_node(ps.tokens, i, "cast") + if ps.tokens[i].tk == "(" then + i, cast.casttype = parse_type_list(ps, i, "casttype") + else + i, cast.casttype = parse_type(ps, i) + end + e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } + else + break + end + end + + return i, e1 + end + + local function E(ps, i, lhs, min_precedence) + local lookahead = ps.tokens[i].tk + while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do + local t1 = ps.tokens[i] + local op = new_operator(t1, 2) + i = i + 1 + local rhs + i, rhs = P(ps, i) + lookahead = ps.tokens[i].tk + while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or + (is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do + i, rhs = E(ps, i, rhs, precedences[2][lookahead]) + lookahead = ps.tokens[i].tk + end + lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs } + end + return i, lhs + end + + parse_expression = function(ps, i) + local lhs + i, lhs = P(ps, i) + i, lhs = E(ps, i, lhs, 0) + if lhs then + return i, lhs, 0 + else + return fail(ps, i, "expected an expression") + end + end +end + +local function parse_variable_name(ps, i) + local is_const = false + local node + i, node = verify_kind(ps, i, "identifier") + if not node then + return i + end + if ps.tokens[i].tk == "<" then + i = i + 1 + local annotation + i, annotation = verify_kind(ps, i, "identifier") + if annotation and annotation.tk == "const" then + is_const = true + end + i = verify_tk(ps, i, ">") + end + node.is_const = is_const + return i, node +end + +local function parse_argument(ps, i) + local node + if ps.tokens[i].tk == "..." then + i, node = verify_kind(ps, i, "...") + else + i, node = verify_kind(ps, i, "identifier", "argument") + end + if ps.tokens[i].tk == ":" then + i = i + 1 + local decltype + + i, decltype = parse_type(ps, i) + + if node then + i, node.decltype = i, decltype + end + end + return i, node, 0 +end + +parse_argument_list = function(ps, i) + local node = new_node(ps.tokens, i, "argument_list") + return parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) +end + +local function parse_argument_type(ps, i) + local is_va = false + if ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == ":" then + i = i + 2 + elseif ps.tokens[i].tk == "..." then + if ps.tokens[i + 1].tk == ":" then + i = i + 2 + is_va = true + else + return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument") + end + end + + local i, typ = parse_type(ps, i) + if typ then + typ.is_va = is_va + end + + return i, typ, 0 +end + +parse_argument_type_list = function(ps, i) + local list = new_type(ps, i, "tuple") + return parse_bracket_list(ps, i, list, "(", ")", "sep", parse_argument_type) +end + +local function parse_local_function(ps, i) + local node = new_node(ps.tokens, i, "local_function") + i = verify_tk(ps, i, "local") + i = verify_tk(ps, i, "function") + i, node.name = verify_kind(ps, i, "identifier") + return parse_function_args_rets_body(ps, i, node) +end + +local function parse_function(ps, i) + local orig_i = i + local fn = new_node(ps.tokens, i, "global_function") + local node = fn + i = verify_tk(ps, i, "function") + local names = {} + i, names[1] = verify_kind(ps, i, "identifier", "variable") + while ps.tokens[i].tk == "." do + i = i + 1 + i, names[#names + 1] = verify_kind(ps, i, "identifier") + end + if ps.tokens[i].tk == ":" then + i = i + 1 + i, names[#names + 1] = verify_kind(ps, i, "identifier") + fn.is_method = true + end + + if #names > 1 then + fn.kind = "record_function" + local owner = names[1] + for i = 2, #names - 1 do + local dot = { y = names[i].y, x = names[i].x - 1, arity = 2, op = "." } + names[i].kind = "identifier" + local op = { y = names[i].y, x = names[i].x, kind = "op", op = dot, e1 = owner, e2 = names[i] } + owner = op + end + fn.fn_owner = owner + end + fn.name = names[#names] + + local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y + i = parse_function_args_rets_body(ps, i, fn) + if fn.is_method then + table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "variable" }) + end + + if not fn.name then + return orig_i + end + + return i, node +end + +local function parse_if(ps, i) + local node = new_node(ps.tokens, i, "if") + i = verify_tk(ps, i, "if") + i, node.exp = parse_expression(ps, i) + i = verify_tk(ps, i, "then") + i, node.thenpart = parse_statements(ps, i) + node.elseifs = {} + local n = 0 + while ps.tokens[i].tk == "elseif" do + n = n + 1 + local subnode = new_node(ps.tokens, i, "elseif") + subnode.parent_if = node + subnode.elseif_n = n + i = i + 1 + i, subnode.exp = parse_expression(ps, i) + i = verify_tk(ps, i, "then") + i, subnode.thenpart = parse_statements(ps, i) + table.insert(node.elseifs, subnode) + end + if ps.tokens[i].tk == "else" then + local subnode = new_node(ps.tokens, i, "else") + subnode.parent_if = node + i = i + 1 + i, subnode.elsepart = parse_statements(ps, i) + node.elsepart = subnode + end + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "end") + return i, node +end + +local function parse_while(ps, i) + local node = new_node(ps.tokens, i, "while") + i = verify_tk(ps, i, "while") + i, node.exp = parse_expression(ps, i) + i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "end") + return i, node +end + +local function parse_fornum(ps, i) + local node = new_node(ps.tokens, i, "fornum") + i = i + 1 + i, node.var = verify_kind(ps, i, "identifier") + i = verify_tk(ps, i, "=") + i, node.from = parse_expression(ps, i) + i = verify_tk(ps, i, ",") + i, node.to = parse_expression(ps, i) + if ps.tokens[i].tk == "," then + i = i + 1 + i, node.step = parse_expression(ps, i) + end + i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "end") + return i, node +end + +local function parse_forin(ps, i) + local node = new_node(ps.tokens, i, "forin") + i = i + 1 + node.vars = new_node(ps.tokens, i, "variables") + i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_variable_name) + i = verify_tk(ps, i, "in") + node.exps = new_node(ps.tokens, i, "expression_list") + i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) + if #node.exps < 1 then + return fail(ps, i, "missing iterator expression in generic for") + elseif #node.exps > 3 then + return fail(ps, i, "too many expressions in generic for") + end + i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "end") + return i, node +end + +local function parse_for(ps, i) + if ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "=" then + return parse_fornum(ps, i) + else + return parse_forin(ps, i) + end +end + +local function parse_repeat(ps, i) + local node = new_node(ps.tokens, i, "repeat") + i = verify_tk(ps, i, "repeat") + i, node.body = parse_statements(ps, i) + node.body.is_repeat = true + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "until") + i, node.exp = parse_expression(ps, i) + return i, node +end + +local function parse_do(ps, i) + local node = new_node(ps.tokens, i, "do") + i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "end") + return i, node +end + +local function parse_break(ps, i) + local node = new_node(ps.tokens, i, "break") + i = verify_tk(ps, i, "break") + return i, node +end + +local function parse_goto(ps, i) + local node = new_node(ps.tokens, i, "goto") + i = verify_tk(ps, i, "goto") + node.label = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") + return i, node +end + +local function parse_label(ps, i) + local node = new_node(ps.tokens, i, "label") + i = verify_tk(ps, i, "::") + node.label = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") + i = verify_tk(ps, i, "::") + return i, node +end + +local stop_statement_list = { + ["end"] = true, + ["else"] = true, + ["elseif"] = true, + ["until"] = true, +} + +local stop_return_list = { + [";"] = true, + ["$EOF$"] = true, +} + +for k, v in pairs(stop_statement_list) do + stop_return_list[k] = v +end + +local function parse_return(ps, i) + local node = new_node(ps.tokens, i, "return") + i = verify_tk(ps, i, "return") + node.exps = new_node(ps.tokens, i, "expression_list") + i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) + if ps.tokens[i].kind == ";" then + i = i + 1 + end + return i, node +end + +local function store_field_in_record(name, def, nt) + if def.fields[name] then + return false + end + def.fields[name] = nt.newtype + table.insert(def.field_order, name) + return true +end + +local ParseBody = {} + +local function parse_nested_type(ps, i, def, typename, parse_body) + i = i + 1 + + local v + i, v = verify_kind(ps, i, "identifier", "variable") + if not v then + return fail(ps, i, "expected a variable name") + end + + local nt = new_node(ps.tokens, i, "newtype") + nt.newtype = new_type(ps, i, "typetype") + local rdef = new_type(ps, i, typename) + local iok = parse_body(ps, i, rdef, nt) + if iok then + i = iok + nt.newtype.def = rdef + end + + local ok = store_field_in_record(v.tk, def, nt) + if not ok then + fail(ps, i, "attempt to redeclare field '" .. v.tk .. "' (only functions can be overloaded)") + end + return i +end + +local function parse_enum_body(ps, i, def, node) + def.enumset = {} + while not ((not ps.tokens[i]) or ps.tokens[i].tk == "end") do + local item + i, item = verify_kind(ps, i, "string", "enum_item") + if item then + table.insert(node, item) + def.enumset[unquote(item.tk)] = true + end + end + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "end") + return i, node +end + +local function parse_record_body(ps, i, def, node) + def.fields = {} + def.field_order = {} + if ps.tokens[i].tk == "<" then + i, def.typeargs = parse_typearg_list(ps, i) + end + while not ((not ps.tokens[i]) or ps.tokens[i].tk == "end") do + if ps.tokens[i].tk == "{" then + if def.typename == "arrayrecord" then + return fail(ps, i, "duplicated declaration of array element type in record") + end + i = i + 1 + local t + i, t = parse_type(ps, i) + if ps.tokens[i].tk == "}" then + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "}") + else + return fail(ps, i, "expected an array declaration") + end + def.typename = "arrayrecord" + def.elements = t + elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then + i = i + 1 + local v + i, v = verify_kind(ps, i, "identifier", "variable") + if not v then + return fail(ps, i, "expected a variable name") + end + i = verify_tk(ps, i, "=") + local nt + i, nt = parse_newtype(ps, i) + if not nt or not nt.newtype then + return fail(ps, i, "expected a type definition") + end + + local ok = store_field_in_record(v.tk, def, nt) + if not ok then + return fail(ps, i, "attempt to redeclare field '" .. v.tk .. "' (only functions can be overloaded)") + end + elseif ps.tokens[i].tk == "record" and ps.tokens[i + 1].tk ~= ":" then + i = parse_nested_type(ps, i, def, "record", parse_record_body) + elseif ps.tokens[i].tk == "enum" and ps.tokens[i + 1].tk ~= ":" then + i = parse_nested_type(ps, i, def, "enum", parse_enum_body) + else + local v + i, v = verify_kind(ps, i, "identifier", "variable") + local iv = i + if not v then + return fail(ps, i, "expected a variable name") + end + if ps.tokens[i].tk == ":" then + i = verify_tk(ps, i, ":") + local t + i, t = parse_type(ps, i) + if not t then + return fail(ps, i, "expected a type") + end + if not def.fields[v.tk] then + def.fields[v.tk] = t + table.insert(def.field_order, v.tk) + else + local prev_t = def.fields[v.tk] + if t.typename == "function" and prev_t.typename == "function" then + def.fields[v.tk] = new_type(ps, iv, "poly") + def.fields[v.tk].types = { prev_t, t } + elseif t.typename == "function" and prev_t.typename == "poly" then + table.insert(prev_t.types, t) + else + return fail(ps, i, "attempt to redeclare field '" .. v.tk .. "' (only functions can be overloaded)") + end + end + elseif ps.tokens[i].tk == "=" then + local next_word = ps.tokens[i + 1].tk + if next_word == "record" or next_word == "enum" then + return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'") + elseif next_word == "functiontype" then + return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...") + else + return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...") + end + end + end + end + node.yend = ps.tokens[i].y + i = verify_tk(ps, i, "end") + return i, node +end + +parse_newtype = function(ps, i) + local node = new_node(ps.tokens, i, "newtype") + node.newtype = new_type(ps, i, "typetype") + if ps.tokens[i].tk == "record" then + local def = new_type(ps, i, "record") + i = i + 1 + i = parse_record_body(ps, i, def, node) + node.newtype.def = def + return i, node + elseif ps.tokens[i].tk == "enum" then + local def = new_type(ps, i, "enum") + i = i + 1 + i = parse_enum_body(ps, i, def, node) + node.newtype.def = def + return i, node + else + i, node.newtype.def = parse_type(ps, i) + return i, node + end + return fail(ps, i) +end + +local function parse_call_or_assignment(ps, i) + local asgn = new_node(ps.tokens, i, "assignment") + + local tryi = i + asgn.vars = new_node(ps.tokens, i, "variables") + i = parse_trying_list(ps, i, asgn.vars, parse_expression) + if #asgn.vars < 1 then + return fail(ps, i) + end + local lhs = asgn.vars[1] + + if ps.tokens[i].tk == "=" then + asgn.exps = new_node(ps.tokens, i, "values") + repeat + i = i + 1 + local val + i, val = parse_expression(ps, i) + table.insert(asgn.exps, val) + until ps.tokens[i].tk ~= "," + return i, asgn + end + if #asgn.vars > 1 then + local err_ps = { + tokens = ps.tokens, + errs = {}, + } + local expi = parse_expression(err_ps, tryi) + return fail(ps, expi or i) + end + if lhs.op and lhs.op.op == "@funcall" and #asgn.vars == 1 then + return i, lhs + end + return fail(ps, i) +end + +local function parse_variable_declarations(ps, i, node_name) + local asgn = new_node(ps.tokens, i, node_name) + + asgn.vars = new_node(ps.tokens, i, "variables") + i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) + if #asgn.vars == 0 then + return fail(ps, i, "expected a local variable definition") + end + local lhs = asgn.vars[1] + + i, asgn.decltype = parse_type_list(ps, i, "decltype") + + if ps.tokens[i].tk == "=" then + + if ps.tokens[i + 1].tk == "record" or + ps.tokens[i + 1].tk == "enum" then + + local scope = node_name == "local_declaration" and "local" or "global" + fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. scope .. " " .. ps.tokens[i + 1].tk .. " " .. asgn.vars[1].tk .. "'") + elseif ps.tokens[i + 1].tk == "functiontype" then + local scope = node_name == "local_declaration" and "local" or "global" + fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...") + end + + asgn.exps = new_node(ps.tokens, i, "values") + local v = 1 + repeat + i = i + 1 + local val + i, val = parse_expression(ps, i) + table.insert(asgn.exps, val) + v = v + 1 + until ps.tokens[i].tk ~= "," + end + return i, asgn +end + +local function parse_type_declaration(ps, i, node_name) + i = i + 2 + + local asgn = new_node(ps.tokens, i, node_name) + i, asgn.var = parse_variable_name(ps, i) + if not asgn.var then + return fail(ps, i, "expected a type name") + end + i = verify_tk(ps, i, "=") + i, asgn.value = parse_newtype(ps, i) + if asgn.value then + asgn.value.newtype.def.names = { asgn.var.tk } + else + return i + end + + return i, asgn +end + +local ParseBody = {} + +local function parse_type_constructor(ps, i, node_name, type_name, parse_body) + local asgn = new_node(ps.tokens, i, node_name) + local nt = new_node(ps.tokens, i, "newtype") + asgn.value = nt + nt.newtype = new_type(ps, i, "typetype") + local def = new_type(ps, i, type_name) + nt.newtype.def = def + + i = i + 2 + + i, asgn.var = verify_kind(ps, i, "identifier") + if not asgn.var then + return fail(ps, i, "expected a type name") + end + nt.newtype.def.names = { asgn.var.tk } + + i = parse_body(ps, i, def, nt) + return i, asgn +end + +local function parse_statement(ps, i) + if ps.tokens[i].tk == "local" then + if ps.tokens[i + 1].tk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i, "local_type") + elseif ps.tokens[i + 1].tk == "function" then + return parse_local_function(ps, i) + elseif ps.tokens[i + 1].tk == "record" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "local_type", "record", parse_record_body) + elseif ps.tokens[i + 1].tk == "enum" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "local_type", "enum", parse_enum_body) + else + i = i + 1 + return parse_variable_declarations(ps, i, "local_declaration") + end + elseif ps.tokens[i].tk == "global" then + if ps.tokens[i + 1].tk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i, "global_type") + elseif ps.tokens[i + 1].tk == "record" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "global_type", "record", parse_record_body) + elseif ps.tokens[i + 1].tk == "enum" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "global_type", "enum", parse_enum_body) + elseif ps.tokens[i + 1].tk == "function" then + i = i + 1 + return parse_function(ps, i) + else + i = i + 1 + return parse_variable_declarations(ps, i, "global_declaration") + end + elseif ps.tokens[i].tk == "function" then + return parse_function(ps, i) + elseif ps.tokens[i].tk == "if" then + return parse_if(ps, i) + elseif ps.tokens[i].tk == "while" then + return parse_while(ps, i) + elseif ps.tokens[i].tk == "repeat" then + return parse_repeat(ps, i) + elseif ps.tokens[i].tk == "for" then + return parse_for(ps, i) + elseif ps.tokens[i].tk == "do" then + return parse_do(ps, i) + elseif ps.tokens[i].tk == "break" then + return parse_break(ps, i) + elseif ps.tokens[i].tk == "return" then + return parse_return(ps, i) + elseif ps.tokens[i].tk == "goto" then + return parse_goto(ps, i) + elseif ps.tokens[i].tk == "::" then + return parse_label(ps, i) + else + return parse_call_or_assignment(ps, i) + end +end + +parse_statements = function(ps, i, filename, toplevel) + local node = new_node(ps.tokens, i, "statements") + while true do + while ps.tokens[i].kind == ";" do + i = i + 1 + end + if ps.tokens[i].kind == "$EOF$" then + break + end + if (not toplevel) and stop_statement_list[ps.tokens[i].tk] then + break + end + local item + i, item = parse_statement(ps, i) + if filename then + for j = 1, #ps.errs do + if not ps.errs[j].filename then + ps.errs[j].filename = filename + end + end + end + if not item then + break + end + table.insert(node, item) + end + return i, node +end + +function tl.parse_program(tokens, errs, filename) + errs = errs or {} + local ps = { + tokens = tokens, + errs = errs, + filename = filename, + } + local last = ps.tokens[#ps.tokens] or { y = 1, x = 1, tk = "" } + table.insert(ps.tokens, { y = last.y, x = last.x + #last.tk, tk = "$EOF$", kind = "$EOF$" }) + return parse_statements(ps, 1, filename, true) +end + + + + + +local VisitorCallbacks = {} + + + + + + +local Visitor = {} + + + + +local function visit_before(ast, kind, visit) + assert(visit.cbs[kind], "no visitor for " .. (kind)) + if visit.cbs[kind].before then + visit.cbs[kind].before(ast) + end +end + +local function visit_after(ast, kind, visit, xs) + if visit.after and visit.after.before then + visit.after.before(ast, xs) + end + local ret + if visit.cbs[kind].after then + ret = visit.cbs[kind].after(ast, xs) + end + if visit.after and visit.after.after then + ret = visit.after.after(ast, xs, ret) + end + return ret +end + +local function recurse_type(ast, visit) + visit_before(ast, ast.typename, visit) + local xs = {} + + if ast.typeargs then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(child, visit)) + end + end + + for i, child in ipairs(ast) do + xs[i] = recurse_type(child, visit) + end + + if ast.types then + for i, child in ipairs(ast.types) do + table.insert(xs, recurse_type(child, visit)) + end + end + if ast.def then + table.insert(xs, recurse_type(ast.def, visit)) + end + if ast.keys then + table.insert(xs, recurse_type(ast.keys, visit)) + end + if ast.values then + table.insert(xs, recurse_type(ast.values, visit)) + end + if ast.elements then + table.insert(xs, recurse_type(ast.elements, visit)) + end + if ast.fields then + for _, child in pairs(ast.fields) do + table.insert(xs, recurse_type(child, visit)) + end + end + if ast.args then + for i, child in ipairs(ast.args) do + if i > 1 or not ast.is_method then + table.insert(xs, recurse_type(child, visit)) + end + end + end + if ast.rets then + for _, child in ipairs(ast.rets) do + table.insert(xs, recurse_type(child, visit)) + end + end + if ast.typevals then + for _, child in ipairs(ast.typevals) do + table.insert(xs, recurse_type(child, visit)) + end + end + if ast.ktype then + table.insert(xs, recurse_type(ast.ktype, visit)) + end + if ast.vtype then + table.insert(xs, recurse_type(ast.vtype, visit)) + end + + return visit_after(ast, ast.typename, visit, xs) +end + +local function recurse_node(ast, +visit_node, +visit_type) + if not ast then + + return + end + + visit_before(ast, ast.kind, visit_node) + local xs = {} + local cbs = visit_node.cbs[ast.kind] + if ast.kind == "statements" or + ast.kind == "variables" or + ast.kind == "values" or + ast.kind == "argument_list" or + ast.kind == "expression_list" or + ast.kind == "table_literal" then + for i, child in ipairs(ast) do + xs[i] = recurse_node(child, visit_node, visit_type) + end + elseif ast.kind == "local_declaration" or + ast.kind == "global_declaration" or + ast.kind == "assignment" then + xs[1] = recurse_node(ast.vars, visit_node, visit_type) + if ast.exps then + xs[2] = recurse_node(ast.exps, visit_node, visit_type) + end + if ast.decltype then + xs[3] = recurse_type(ast.decltype, visit_type) + end + elseif ast.kind == "local_type" or + ast.kind == "global_type" then + xs[1] = recurse_node(ast.var, visit_node, visit_type) + xs[2] = recurse_node(ast.value, visit_node, visit_type) + elseif ast.kind == "table_item" then + xs[1] = recurse_node(ast.key, visit_node, visit_type) + xs[2] = recurse_node(ast.value, visit_node, visit_type) + elseif ast.kind == "if" then + xs[1] = recurse_node(ast.exp, visit_node, visit_type) + if cbs.before_statements then + cbs.before_statements(ast, xs) + end + xs[2] = recurse_node(ast.thenpart, visit_node, visit_type) + for i, e in ipairs(ast.elseifs) do + table.insert(xs, recurse_node(e, visit_node, visit_type)) + end + if ast.elsepart then + table.insert(xs, recurse_node(ast.elsepart, visit_node, visit_type)) + end + elseif ast.kind == "while" then + xs[1] = recurse_node(ast.exp, visit_node, visit_type) + if cbs.before_statements then + cbs.before_statements(ast, xs) + end + xs[2] = recurse_node(ast.body, visit_node, visit_type) + elseif ast.kind == "repeat" then + xs[1] = recurse_node(ast.body, visit_node, visit_type) + xs[2] = recurse_node(ast.exp, visit_node, visit_type) + elseif ast.kind == "function" then + xs[1] = recurse_node(ast.args, visit_node, visit_type) + xs[2] = recurse_type(ast.rets, visit_type) + xs[3] = recurse_node(ast.body, visit_node, visit_type) + elseif ast.kind == "forin" then + xs[1] = recurse_node(ast.vars, visit_node, visit_type) + xs[2] = recurse_node(ast.exps, visit_node, visit_type) + if cbs.before_statements then + cbs.before_statements(ast) + end + xs[3] = recurse_node(ast.body, visit_node, visit_type) + elseif ast.kind == "fornum" then + xs[1] = recurse_node(ast.var, visit_node, visit_type) + xs[2] = recurse_node(ast.from, visit_node, visit_type) + xs[3] = recurse_node(ast.to, visit_node, visit_type) + xs[4] = ast.step and recurse_node(ast.step, visit_node, visit_type) + xs[5] = recurse_node(ast.body, visit_node, visit_type) + elseif ast.kind == "elseif" then + xs[1] = recurse_node(ast.exp, visit_node, visit_type) + if cbs.before_statements then + cbs.before_statements(ast, xs) + end + xs[2] = recurse_node(ast.thenpart, visit_node, visit_type) + elseif ast.kind == "else" then + xs[1] = recurse_node(ast.elsepart, visit_node, visit_type) + elseif ast.kind == "return" then + xs[1] = recurse_node(ast.exps, visit_node, visit_type) + elseif ast.kind == "do" then + xs[1] = recurse_node(ast.body, visit_node, visit_type) + elseif ast.kind == "cast" then + elseif ast.kind == "local_function" or + ast.kind == "global_function" then + xs[1] = recurse_node(ast.name, visit_node, visit_type) + xs[2] = recurse_node(ast.args, visit_node, visit_type) + xs[3] = recurse_type(ast.rets, visit_type) + xs[4] = recurse_node(ast.body, visit_node, visit_type) + elseif ast.kind == "record_function" then + xs[1] = recurse_node(ast.fn_owner, visit_node, visit_type) + xs[2] = recurse_node(ast.name, visit_node, visit_type) + xs[3] = recurse_node(ast.args, visit_node, visit_type) + xs[4] = recurse_type(ast.rets, visit_type) + if cbs.before_statements then + cbs.before_statements(ast, xs) + end + xs[5] = recurse_node(ast.body, visit_node, visit_type) + elseif ast.kind == "paren" then + xs[1] = recurse_node(ast.e1, visit_node, visit_type) + elseif ast.kind == "op" then + xs[1] = recurse_node(ast.e1, visit_node, visit_type) + local p1 = ast.e1.op and ast.e1.op.prec or nil + if ast.op.op == ":" and ast.e1.kind == "string" then + p1 = -999 + end + xs[2] = p1 + if ast.op.arity == 2 then + if cbs.before_e2 then + cbs.before_e2(ast, xs) + end + if ast.op.op == "is" or ast.op.op == "as" then + xs[3] = recurse_type(ast.e2.casttype, visit_type) + else + xs[3] = recurse_node(ast.e2, visit_node, visit_type) + end + xs[4] = (ast.e2.op and ast.e2.op.prec) + end + elseif ast.kind == "newtype" then + xs[1] = recurse_type(ast.newtype, visit_type) + elseif ast.kind == "variable" or + ast.kind == "argument" or + ast.kind == "identifier" or + ast.kind == "string" or + ast.kind == "number" or + ast.kind == "break" or + ast.kind == "goto" or + ast.kind == "label" or + ast.kind == "nil" or + ast.kind == "..." or + ast.kind == "boolean" then + if ast.decltype then + xs[1] = recurse_type(ast.decltype, visit_type) + end + else + if not ast.kind then + error("wat: " .. inspect(ast)) + end + error("unknown node kind " .. ast.kind) + end + return visit_after(ast, ast.kind, visit_node, xs) +end + + + + + +local tight_op = { + [1] = { + ["-"] = true, + ["~"] = true, + ["#"] = true, + }, + [2] = { + ["."] = true, + [":"] = true, + }, +} + +local spaced_op = { + [1] = { + ["not"] = true, + }, + [2] = { + ["or"] = true, + ["and"] = true, + ["<"] = true, + [">"] = true, + ["<="] = true, + [">="] = true, + ["~="] = true, + ["=="] = true, + ["|"] = true, + ["~"] = true, + ["&"] = true, + ["<<"] = true, + [">>"] = true, + [".."] = true, + ["+"] = true, + ["-"] = true, + ["*"] = true, + ["/"] = true, + ["//"] = true, + ["%"] = true, + ["^"] = true, + }, +} + +local PrettyPrintOpts = {} + + + + +local default_pretty_print_ast_opts = { + preserve_indent = true, + preserve_newlines = true, +} + +local fast_pretty_print_ast_opts = { + preserve_indent = false, + preserve_newlines = true, +} + +function tl.pretty_print_ast(ast, mode) + local indent = 0 + + local opts + if type(mode) == "table" then + opts = mode + elseif mode == true then + opts = fast_pretty_print_ast_opts + else + opts = default_pretty_print_ast_opts + end + + local Output = {} + + + + + + local function increment_indent() + indent = indent + 1 + end + + if not opts.preserve_indent then + increment_indent = nil + end + + local function add(out, s) + table.insert(out, s) + end + + local function add_string(out, s) + table.insert(out, s) + if string.find(s, "\n", 1, true) then + for nl in s:gmatch("\n") do + out.h = out.h + 1 + end + end + end + + local function add_child(out, child, space, indent) + if #child == 0 then + return + end + + if child.y < out.y then + out.y = child.y + end + + if child.y > out.y + out.h and opts.preserve_newlines then + local delta = child.y - (out.y + out.h) + out.h = out.h + delta + table.insert(out, ("\n"):rep(delta)) + else + if space then + table.insert(out, space) + indent = nil + end + end + if indent and opts.preserve_indent then + table.insert(out, (" "):rep(indent)) + end + table.insert(out, child) + out.h = out.h + child.h + end + + local function concat_output(out) + for i, s in ipairs(out) do + if type(s) == "table" then + out[i] = concat_output(s) + end + end + return table.concat(out) + end + + local function print_record_def(typ) + local out = { "{" } + for name, field in pairs(typ.fields) do + if field.typename == "typetype" and is_record_type(field.def) then + table.insert(out, name) + table.insert(out, " = ") + table.insert(out, print_record_def(field.def)) + table.insert(out, ", ") + end + end + table.insert(out, "}") + return table.concat(out) + end + + local visit_node = {} + + visit_node.cbs = { + ["statements"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + local space + for i, child in ipairs(children) do + add_child(out, children[i], space, indent) + space = "; " + end + return out + end, + }, + ["local_declaration"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "local") + add_child(out, children[1], " ") + if children[2] then + table.insert(out, " =") + add_child(out, children[2], " ") + end + return out + end, + }, + ["local_type"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "local") + add_child(out, children[1], " ") + table.insert(out, " =") + add_child(out, children[2], " ") + return out + end, + }, + ["global_type"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + add_child(out, children[1], " ") + table.insert(out, " =") + add_child(out, children[2], " ") + return out + end, + }, + ["global_declaration"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + if children[2] then + add_child(out, children[1]) + table.insert(out, " =") + add_child(out, children[2], " ") + end + return out + end, + }, + ["assignment"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + add_child(out, children[1]) + table.insert(out, " =") + add_child(out, children[2], " ") + return out + end, + }, + ["if"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "if") + add_child(out, children[1], " ") + table.insert(out, " then") + add_child(out, children[2], " ") + indent = indent - 1 + for i = 3, #children do + add_child(out, children[i], " ", indent) + end + add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) + return out + end, + }, + ["while"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "while") + add_child(out, children[1], " ") + table.insert(out, " do") + add_child(out, children[2], " ") + indent = indent - 1 + add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) + return out + end, + }, + ["repeat"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "repeat") + add_child(out, children[1], " ") + if opts.preserve_indent then + indent = indent - 1 + end + add_child(out, { y = node.yend, h = 0, [1] = "until " }, " ", indent) + add_child(out, children[2]) + return out + end, + }, + ["do"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "do") + add_child(out, children[1], " ") + indent = indent - 1 + add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) + return out + end, + }, + ["forin"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "for") + add_child(out, children[1], " ") + table.insert(out, " in") + add_child(out, children[2], " ") + table.insert(out, " do") + add_child(out, children[3], " ") + indent = indent - 1 + add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) + return out + end, + }, + ["fornum"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "for") + add_child(out, children[1], " ") + table.insert(out, " =") + add_child(out, children[2], " ") + table.insert(out, ",") + add_child(out, children[3], " ") + if children[4] then + table.insert(out, ",") + add_child(out, children[4], " ") + end + table.insert(out, " do") + add_child(out, children[5], " ") + indent = indent - 1 + add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) + return out + end, + }, + ["return"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "return") + if #children[1] > 0 then + add_child(out, children[1], " ") + end + return out + end, + }, + ["break"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "break") + return out + end, + }, + ["elseif"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "elseif") + add_child(out, children[1], " ") + table.insert(out, " then") + add_child(out, children[2], " ") + return out + end, + }, + ["else"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "else") + add_child(out, children[1], " ") + return out + end, + }, + ["variables"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + local space + for i, child in ipairs(children) do + if i > 1 then + table.insert(out, ",") + space = " " + end + add_child(out, child, space) + end + return out + end, + }, + ["table_literal"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + if #children == 0 then + indent = indent - 1 + table.insert(out, "{}") + return out + end + table.insert(out, "{") + local n = #children + for i, child in ipairs(children) do + add_child(out, child, " ", child.y ~= node.y and indent) + if i < n or node.yend ~= node.y then + table.insert(out, ",") + end + end + indent = indent - 1 + add_child(out, { y = node.yend, h = 0, [1] = "}" }, " ", indent) + return out + end, + }, + ["table_item"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + if node.key_parsed ~= "implicit" then + if node.key_parsed == "short" then + children[1][1] = children[1][1]:sub(2, -2) + add_child(out, children[1]) + table.insert(out, " = ") + else + table.insert(out, "[") + add_child(out, children[1]) + table.insert(out, "] = ") + end + end + add_child(out, children[2]) + return out + end, + }, + ["local_function"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "local function") + add_child(out, children[1], " ") + table.insert(out, "(") + add_child(out, children[2]) + table.insert(out, ")") + add_child(out, children[4], " ") + indent = indent - 1 + add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) + return out + end, + }, + ["global_function"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "function") + add_child(out, children[1], " ") + table.insert(out, "(") + add_child(out, children[2]) + table.insert(out, ")") + add_child(out, children[4], " ") + indent = indent - 1 + add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) + return out + end, + }, + ["record_function"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "function") + add_child(out, children[1], " ") + table.insert(out, node.is_method and ":" or ".") + add_child(out, children[2]) + table.insert(out, "(") + if node.is_method then + + table.remove(children[3], 1) + if children[3][1] == "," then + table.remove(children[3], 1) + table.remove(children[3], 1) + end + end + add_child(out, children[3]) + table.insert(out, ")") + add_child(out, children[5], " ") + indent = indent - 1 + add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) + return out + end, + }, + ["function"] = { + before = increment_indent, + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "function(") + add_child(out, children[1]) + table.insert(out, ")") + add_child(out, children[3], " ") + indent = indent - 1 + add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) + return out + end, + }, + ["cast"] = {}, + + ["paren"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "(") + add_child(out, children[1], "", indent) + table.insert(out, ")") + return out + end, + }, + ["op"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + if node.op.op == "@funcall" then + add_child(out, children[1], "", indent) + table.insert(out, "(") + add_child(out, children[3], "", indent) + table.insert(out, ")") + elseif node.op.op == "@index" then + add_child(out, children[1], "", indent) + table.insert(out, "[") + add_child(out, children[3], "", indent) + table.insert(out, "]") + elseif node.op.op == "as" then + add_child(out, children[1], "", indent) + elseif node.op.op == "is" then + table.insert(out, "type(") + add_child(out, children[1], "", indent) + table.insert(out, ") == \"") + add_child(out, children[3], "", indent) + table.insert(out, "\"") + elseif spaced_op[node.op.arity][node.op.op] or tight_op[node.op.arity][node.op.op] then + local space = spaced_op[node.op.arity][node.op.op] and " " or "" + if children[2] and node.op.prec > tonumber(children[2]) then + table.insert(children[1], 1, "(") + table.insert(children[1], ")") + end + if node.op.arity == 1 then + table.insert(out, node.op.op) + add_child(out, children[1], space, indent) + elseif node.op.arity == 2 then + add_child(out, children[1], "", indent) + if space == " " then + table.insert(out, " ") + end + table.insert(out, node.op.op) + if children[4] and node.op.prec > tonumber(children[4]) then + table.insert(children[3], 1, "(") + table.insert(children[3], ")") + end + add_child(out, children[3], space, indent) + end + else + error("unknown node op " .. node.op.op) + end + return out + end, + }, + ["variable"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + add_string(out, node.tk) + return out + end, + }, + ["newtype"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + if is_record_type(node.newtype.def) then + table.insert(out, print_record_def(node.newtype.def)) + else + table.insert(out, "{}") + end + return out + end, + }, + ["goto"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "goto ") + table.insert(out, node.label) + return out + end, + }, + ["label"] = { + after = function(node, children) + local out = { y = node.y, h = 0 } + table.insert(out, "::") + table.insert(out, node.label) + table.insert(out, "::") + return out + end, + }, + } + + local primitive = { + ["function"] = "function", + ["enum"] = "string", + ["boolean"] = "boolean", + ["string"] = "string", + ["nil"] = "nil", + ["number"] = "number", + ["thread"] = "thread", + } + + local visit_type = {} + visit_type.cbs = { + ["string"] = { + after = function(typ, children) + local out = { y = typ.y, h = 0 } + table.insert(out, primitive[typ.typename] or "table") + return out + end, + }, + } + visit_type.cbs["typetype"] = visit_type.cbs["string"] + visit_type.cbs["typevar"] = visit_type.cbs["string"] + visit_type.cbs["typearg"] = visit_type.cbs["string"] + visit_type.cbs["function"] = visit_type.cbs["string"] + visit_type.cbs["thread"] = visit_type.cbs["string"] + visit_type.cbs["array"] = visit_type.cbs["string"] + visit_type.cbs["map"] = visit_type.cbs["string"] + visit_type.cbs["arrayrecord"] = visit_type.cbs["string"] + visit_type.cbs["record"] = visit_type.cbs["string"] + visit_type.cbs["enum"] = visit_type.cbs["string"] + visit_type.cbs["boolean"] = visit_type.cbs["string"] + visit_type.cbs["nil"] = visit_type.cbs["string"] + visit_type.cbs["number"] = visit_type.cbs["string"] + visit_type.cbs["union"] = visit_type.cbs["string"] + visit_type.cbs["nominal"] = visit_type.cbs["string"] + visit_type.cbs["bad_nominal"] = visit_type.cbs["string"] + visit_type.cbs["emptytable"] = visit_type.cbs["string"] + visit_type.cbs["table_item"] = visit_type.cbs["string"] + visit_type.cbs["unknown_emptytable_value"] = visit_type.cbs["string"] + visit_type.cbs["tuple"] = visit_type.cbs["string"] + visit_type.cbs["poly"] = visit_type.cbs["string"] + visit_type.cbs["any"] = visit_type.cbs["string"] + visit_type.cbs["unknown"] = visit_type.cbs["string"] + visit_type.cbs["invalid"] = visit_type.cbs["string"] + visit_type.cbs["unresolved"] = visit_type.cbs["string"] + visit_type.cbs["none"] = visit_type.cbs["string"] + + visit_node.cbs["values"] = visit_node.cbs["variables"] + visit_node.cbs["expression_list"] = visit_node.cbs["variables"] + visit_node.cbs["argument_list"] = visit_node.cbs["variables"] + visit_node.cbs["identifier"] = visit_node.cbs["variable"] + visit_node.cbs["string"] = visit_node.cbs["variable"] + visit_node.cbs["number"] = visit_node.cbs["variable"] + visit_node.cbs["nil"] = visit_node.cbs["variable"] + visit_node.cbs["boolean"] = visit_node.cbs["variable"] + visit_node.cbs["..."] = visit_node.cbs["variable"] + visit_node.cbs["argument"] = visit_node.cbs["variable"] + + local out = recurse_node(ast, visit_node, visit_type) + local code + if opts.preserve_newlines then + code = { y = 1, h = 0 } + add_child(code, out) + else + code = out + end + return concat_output(code) +end + + + + + +local ANY = a_type({ typename = "any" }) +local NONE = a_type({ typename = "none" }) + +local NIL = a_type({ typename = "nil" }) +local NUMBER = a_type({ typename = "number" }) +local STRING = a_type({ typename = "string" }) +local OPT_NUMBER = a_type({ typename = "number" }) +local OPT_STRING = a_type({ typename = "string" }) +local VARARG_ANY = a_type({ typename = "any", is_va = true }) +local VARARG_STRING = a_type({ typename = "string", is_va = true }) +local VARARG_NUMBER = a_type({ typename = "number", is_va = true }) +local VARARG_UNKNOWN = a_type({ typename = "unknown", is_va = true }) +local VARARG_ALPHA = a_type({ typename = "typevar", typevar = "@a", is_va = true }) +local BOOLEAN = a_type({ typename = "boolean" }) +local ARG_ALPHA = a_type({ typename = "typearg", typearg = "@a" }) +local ARG_BETA = a_type({ typename = "typearg", typearg = "@b" }) +local ALPHA = a_type({ typename = "typevar", typevar = "@a" }) +local BETA = a_type({ typename = "typevar", typevar = "@b" }) +local ARRAY_OF_STRING = a_type({ typename = "array", elements = STRING }) +local ARRAY_OF_ALPHA = a_type({ typename = "array", elements = ALPHA }) +local MAP_OF_ALPHA_TO_BETA = a_type({ typename = "map", keys = ALPHA, values = BETA }) +local TABLE = a_type({ typename = "map", keys = ANY, values = ANY }) +local FUNCTION = a_type({ typename = "function", args = { a_type({ typename = "any", is_va = true }) }, rets = { a_type({ typename = "any", is_va = true }) } }) +local THREAD = a_type({ typename = "thread" }) +local INVALID = a_type({ typename = "invalid" }) +local UNKNOWN = a_type({ typename = "unknown" }) +local NOMINAL_FILE = a_type({ typename = "nominal", names = { "FILE" } }) +local NOMINAL_METATABLE = a_type({ typename = "nominal", names = { "METATABLE" } }) + +local OS_DATE_TABLE = a_type({ + typename = "record", + fields = { + ["year"] = NUMBER, + ["month"] = NUMBER, + ["day"] = NUMBER, + ["hour"] = NUMBER, + ["min"] = NUMBER, + ["sec"] = NUMBER, + ["wday"] = NUMBER, + ["yday"] = NUMBER, + ["isdst"] = BOOLEAN, + }, +}) + +local DEBUG_GETINFO_TABLE = a_type({ + typename = "record", + fields = { + ["name"] = STRING, + ["namewhat"] = STRING, + ["source"] = STRING, + ["short_src"] = STRING, + ["linedefined"] = NUMBER, + ["lastlinedefined"] = NUMBER, + ["what"] = STRING, + ["currentline"] = NUMBER, + ["istailcall"] = BOOLEAN, + ["nups"] = NUMBER, + ["nparams"] = NUMBER, + ["isvararg"] = BOOLEAN, + ["func"] = ANY, + ["activelines"] = a_type({ typename = "map", keys = NUMBER, values = BOOLEAN }), + }, +}) + +local numeric_binop = { + ["number"] = { + ["number"] = NUMBER, + }, +} + +local relational_binop = { + ["number"] = { + ["number"] = BOOLEAN, + }, + ["string"] = { + ["string"] = BOOLEAN, + }, + ["boolean"] = { + ["boolean"] = BOOLEAN, + }, +} + +local equality_binop = { + ["number"] = { + ["number"] = BOOLEAN, + ["nil"] = BOOLEAN, + }, + ["string"] = { + ["string"] = BOOLEAN, + ["nil"] = BOOLEAN, + }, + ["boolean"] = { + ["boolean"] = BOOLEAN, + ["nil"] = BOOLEAN, + }, + ["record"] = { + ["emptytable"] = BOOLEAN, + ["arrayrecord"] = BOOLEAN, + ["record"] = BOOLEAN, + ["nil"] = BOOLEAN, + }, + ["array"] = { + ["emptytable"] = BOOLEAN, + ["arrayrecord"] = BOOLEAN, + ["array"] = BOOLEAN, + ["nil"] = BOOLEAN, + }, + ["arrayrecord"] = { + ["emptytable"] = BOOLEAN, + ["arrayrecord"] = BOOLEAN, + ["record"] = BOOLEAN, + ["array"] = BOOLEAN, + ["nil"] = BOOLEAN, + }, + ["map"] = { + ["emptytable"] = BOOLEAN, + ["map"] = BOOLEAN, + ["nil"] = BOOLEAN, + }, + ["thread"] = { + ["thread"] = BOOLEAN, + ["nil"] = BOOLEAN, + }, +} + +local unop_types = { + ["#"] = { + ["arrayrecord"] = NUMBER, + ["string"] = NUMBER, + ["array"] = NUMBER, + ["map"] = NUMBER, + ["emptytable"] = NUMBER, + }, + ["-"] = { + ["number"] = NUMBER, + }, + ["not"] = { + ["string"] = BOOLEAN, + ["number"] = BOOLEAN, + ["boolean"] = BOOLEAN, + ["record"] = BOOLEAN, + ["arrayrecord"] = BOOLEAN, + ["array"] = BOOLEAN, + ["map"] = BOOLEAN, + ["emptytable"] = BOOLEAN, + ["thread"] = BOOLEAN, + }, +} + +local binop_types = { + ["+"] = numeric_binop, + ["-"] = { + ["number"] = { + ["number"] = NUMBER, + }, + }, + ["*"] = numeric_binop, + ["%"] = numeric_binop, + ["/"] = numeric_binop, + ["^"] = numeric_binop, + ["&"] = numeric_binop, + ["|"] = numeric_binop, + ["<<"] = numeric_binop, + [">>"] = numeric_binop, + ["=="] = equality_binop, + ["~="] = equality_binop, + ["<="] = relational_binop, + [">="] = relational_binop, + ["<"] = relational_binop, + [">"] = relational_binop, + ["or"] = { + ["boolean"] = { + ["boolean"] = BOOLEAN, + ["function"] = FUNCTION, + }, + ["number"] = { + ["number"] = NUMBER, + ["boolean"] = BOOLEAN, + }, + ["string"] = { + ["string"] = STRING, + ["boolean"] = BOOLEAN, + ["enum"] = STRING, + }, + ["function"] = { + ["function"] = FUNCTION, + ["boolean"] = BOOLEAN, + }, + ["array"] = { + ["boolean"] = BOOLEAN, + }, + ["record"] = { + ["boolean"] = BOOLEAN, + }, + ["arrayrecord"] = { + ["boolean"] = BOOLEAN, + }, + ["map"] = { + ["boolean"] = BOOLEAN, + }, + ["enum"] = { + ["string"] = STRING, + }, + ["thread"] = { + ["boolean"] = BOOLEAN, + }, + }, + [".."] = { + ["string"] = { + ["string"] = STRING, + ["enum"] = STRING, + ["number"] = STRING, + }, + ["number"] = { + ["number"] = STRING, + ["string"] = STRING, + ["enum"] = STRING, + }, + ["enum"] = { + ["number"] = STRING, + ["string"] = STRING, + ["enum"] = STRING, + }, + }, +} + +local show_type + +local function is_unknown(t) + return t.typename == "unknown" or + t.typename == "unknown_emptytable_value" +end + +local show_type + +local function show_type_base(t, seen) + + if seen[t] then + return "..." + end + seen[t] = true + + local function show(t) + return show_type(t, seen) + end + + if t.typename == "nominal" then + if t.typevals then + local out = { table.concat(t.names, "."), "<" } + local vals = {} + for _, v in ipairs(t.typevals) do + table.insert(vals, show(v)) + end + table.insert(out, table.concat(vals, ", ")) + table.insert(out, ">") + return table.concat(out) + else + return table.concat(t.names, ".") + end + elseif t.typename == "tuple" then + local out = {} + for _, v in ipairs(t) do + table.insert(out, show(v)) + end + return "(" .. table.concat(out, ", ") .. ")" + elseif t.typename == "poly" then + local out = {} + for _, v in ipairs(t.types) do + table.insert(out, show(v)) + end + return table.concat(out, " or ") + elseif t.typename == "union" then + local out = {} + for _, v in ipairs(t.types) do + table.insert(out, show(v)) + end + return table.concat(out, " | ") + elseif t.typename == "emptytable" then + return "{}" + elseif t.typename == "map" then + return "{" .. show(t.keys) .. " : " .. show(t.values) .. "}" + elseif t.typename == "array" then + return "{" .. show(t.elements) .. "}" + elseif t.typename == "enum" then + return t.names and table.concat(t.names, ".") or "enum" + elseif is_record_type(t) then + local out = {} + for _, k in ipairs(t.field_order) do + local v = t.fields[k] + table.insert(out, k .. ": " .. show(v)) + end + return "{" .. table.concat(out, ", ") .. "}" + elseif t.typename == "function" then + local out = {} + table.insert(out, "function(") + local args = {} + if t.is_method then + table.insert(args, "self") + end + for i, v in ipairs(t.args) do + if not t.is_method or i > 1 then + table.insert(args, show(v)) + end + end + table.insert(out, table.concat(args, ",")) + table.insert(out, ")") + if #t.rets > 0 then + table.insert(out, ":") + local rets = {} + for _, v in ipairs(t.rets) do + table.insert(rets, show(v)) + end + table.insert(out, table.concat(rets, ",")) + end + return table.concat(out) + elseif t.typename == "number" or + t.typename == "boolean" or + t.typename == "thread" then + return t.typename + elseif t.typename == "string" then + return t.typename .. + (t.tk and " " .. t.tk or "") + elseif t.typename == "typevar" then + return t.typevar + elseif t.typename == "typearg" then + return t.typearg + elseif is_unknown(t) then + return "" + elseif t.typename == "invalid" then + return "" + elseif t.typename == "any" then + return "" + elseif t.typename == "nil" then + return "nil" + elseif t.typename == "typetype" then + return "type " .. show(t.def) + elseif t.typename == "bad_nominal" then + return table.concat(t.names, ".") .. " (an unknown type)" + else + return inspect(t) + end +end + +show_type = function(t, seen) + local ret = show_type_base(t, seen or {}) + if t.inferred_at then + ret = ret .. " (inferred at " .. t.inferred_at_file .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ": )" + end + return ret +end + +local Error = {} + + + + + + +local Result = {} + + + + + + + + +local function search_for(module_name, suffix, path, tried) + for entry in path:gmatch("[^;]+") do + local slash_name = module_name:gsub("%.", "/") + local filename = entry:gsub("?", slash_name) + local tl_filename = filename:gsub("%.lua$", suffix) + local fd = io.open(tl_filename, "r") + if fd then + return tl_filename, fd, tried + end + table.insert(tried, "no file '" .. tl_filename .. "'") + end + return nil, nil, tried +end + +function tl.search_module(module_name, search_dtl) + local found + local tried = {} + local path = os.getenv("TL_PATH") or package.path + if search_dtl then + local found, fd, tried = search_for(module_name, ".d.tl", path, tried) + if found then + return found, fd + end + end + local found, fd, tried = search_for(module_name, ".tl", path, tried) + if found then + return found, fd + end + local found, fd, tried = search_for(module_name, ".lua", path, tried) + if found then + return found, fd + end + return nil, nil, tried +end + +local Variable = {} + + + + + + + +local function fill_field_order(t) + if t.typename == "record" then + t.field_order = {} + for k, v in pairs(t.fields) do + table.insert(t.field_order, k) + end + table.sort(t.field_order) + end +end + +local function require_module(module_name, lax, env, result) + local modules = env.modules + + if modules[module_name] then + return modules[module_name], true + end + modules[module_name] = UNKNOWN + + local found, fd, tried = tl.search_module(module_name, true) + if found and (lax or found:match("tl$")) then + fd:close() + local _result, err = tl.process(found, env, result) + assert(_result, err) + + if not _result.type then + _result.type = BOOLEAN + end + + modules[module_name] = _result.type + + return _result.type, true + end + + return UNKNOWN, found ~= nil +end + +local standard_library = { + ["..."] = a_type({ typename = "tuple", STRING, STRING, STRING, STRING, STRING }), + ["@return"] = a_type({ typename = "tuple", ANY }), + ["any"] = a_type({ typename = "typetype", def = ANY }), + ["arg"] = ARRAY_OF_STRING, + ["assert"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ALPHA }, rets = { ALPHA } }), + a_type({ typename = "function", typeargs = { ARG_ALPHA, ARG_BETA }, args = { ALPHA, BETA }, rets = { ALPHA } }), + }, + }), + ["collectgarbage"] = a_type({ typename = "function", args = { STRING }, rets = { a_type({ typename = "union", types = { BOOLEAN, NUMBER } }), NUMBER, NUMBER } }), + ["dofile"] = a_type({ typename = "function", args = { OPT_STRING }, rets = { VARARG_ANY } }), + ["error"] = a_type({ typename = "function", args = { STRING, NUMBER }, rets = {} }), + ["getmetatable"] = a_type({ typename = "function", args = { ANY }, rets = { NOMINAL_METATABLE } }), + ["ipairs"] = a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA }, rets = { + a_type({ typename = "function", args = {}, rets = { NUMBER, ALPHA } }), + }, }), + ["load"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { STRING }, rets = { FUNCTION, STRING } }), + a_type({ typename = "function", args = { STRING, STRING }, rets = { FUNCTION, STRING } }), + a_type({ typename = "function", args = { STRING, STRING, STRING }, rets = { FUNCTION, STRING } }), + a_type({ typename = "function", args = { STRING, STRING, STRING, TABLE }, rets = { FUNCTION, STRING } }), + }, + }), + ["loadfile"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = {}, rets = { FUNCTION, ANY } }), + a_type({ typename = "function", args = { STRING }, rets = { FUNCTION, ANY } }), + a_type({ typename = "function", args = { STRING, STRING }, rets = { FUNCTION, ANY } }), + a_type({ typename = "function", args = { STRING, STRING, TABLE }, rets = { FUNCTION, ANY } }), + }, + }), + ["next"] = a_type({ + typename = "poly", + types = { + a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { MAP_OF_ALPHA_TO_BETA }, rets = { ALPHA, BETA } }), + a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { MAP_OF_ALPHA_TO_BETA, ALPHA }, rets = { ALPHA, BETA } }), + a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ARRAY_OF_ALPHA }, rets = { NUMBER, ALPHA } }), + a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ARRAY_OF_ALPHA, ALPHA }, rets = { NUMBER, ALPHA } }), + }, + }), + ["pairs"] = a_type({ typename = "function", typeargs = { ARG_ALPHA, ARG_BETA }, args = { a_type({ typename = "map", keys = ALPHA, values = BETA }) }, rets = { + a_type({ typename = "function", args = {}, rets = { ALPHA, BETA } }), + }, }), + ["pcall"] = a_type({ typename = "function", args = { FUNCTION, VARARG_ANY }, rets = { BOOLEAN, ANY } }), + ["xpcall"] = a_type({ typename = "function", args = { FUNCTION, FUNCTION, VARARG_ANY }, rets = { BOOLEAN, ANY } }), + ["print"] = a_type({ typename = "function", args = { VARARG_ANY }, rets = {} }), + ["rawequal"] = a_type({ typename = "function", args = { ANY, ANY }, rets = { BOOLEAN } }), + ["rawget"] = a_type({ typename = "function", args = { TABLE, ANY }, rets = { ANY } }), + ["rawlen"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { TABLE }, rets = { NUMBER } }), + a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }), + }, + }), + ["rawset"] = a_type({ + typename = "poly", + types = { + a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { MAP_OF_ALPHA_TO_BETA, ALPHA, BETA }, rets = {} }), + a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ARRAY_OF_ALPHA, NUMBER, ALPHA }, rets = {} }), + a_type({ typename = "function", args = { TABLE, ANY, ANY }, rets = {} }), + }, + }), + ["require"] = a_type({ typename = "function", args = { STRING }, rets = {} }), + ["select"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { NUMBER, VARARG_ALPHA }, rets = { ALPHA } }), + a_type({ typename = "function", args = { NUMBER, VARARG_ANY }, rets = { ANY } }), + a_type({ typename = "function", args = { STRING, VARARG_ANY }, rets = { NUMBER } }), + }, + }), + ["setmetatable"] = a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ALPHA, NOMINAL_METATABLE }, rets = { ALPHA } }), + ["tonumber"] = a_type({ typename = "function", args = { ANY, NUMBER }, rets = { NUMBER } }), + ["tostring"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }), + ["type"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }), + ["FILE"] = a_type({ + typename = "typetype", + def = a_type({ + typename = "record", + fields = { + ["close"] = a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING } }), + ["flush"] = a_type({ typename = "function", args = { NOMINAL_FILE }, rets = {} }), + ["lines"] = a_type({ typename = "function", args = { NOMINAL_FILE, a_type({ typename = "union", types = { STRING, NUMBER }, is_va = true }) }, rets = { + a_type({ typename = "function", args = {}, rets = { VARARG_STRING } }), + }, }), + ["read"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { NOMINAL_FILE, STRING }, rets = { STRING, STRING } }), + a_type({ typename = "function", args = { NOMINAL_FILE, NUMBER }, rets = { STRING, STRING } }), + }, + }), + ["seek"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { NUMBER, STRING } }), + a_type({ typename = "function", args = { NOMINAL_FILE, STRING }, rets = { NUMBER, STRING } }), + a_type({ typename = "function", args = { NOMINAL_FILE, STRING, NUMBER }, rets = { NUMBER, STRING } }), + }, + }), + ["setvbuf"] = a_type({ typename = "function", args = { NOMINAL_FILE, STRING, OPT_NUMBER }, rets = {} }), + ["write"] = a_type({ typename = "function", args = { NOMINAL_FILE, VARARG_STRING }, rets = { NOMINAL_FILE, STRING } }), + + }, + }), + }), + ["METATABLE"] = a_type({ + typename = "typetype", + def = a_type({ + typename = "record", + fields = { + ["__call"] = FUNCTION, + ["__gc"] = a_type({ typename = "function", args = { ANY }, rets = {} }), + ["__index"] = ANY, + ["__len"] = a_type({ typename = "function", args = { ANY }, rets = { NUMBER } }), + ["__mode"] = a_type({ typename = "enum", enumset = { ["k"] = true, ["v"] = true, ["kv"] = true } }), + ["__newindex"] = ANY, + ["__pairs"] = a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { a_type({ typename = "map", keys = ALPHA, values = BETA }) }, rets = { + a_type({ typename = "function", args = {}, rets = { ALPHA, BETA } }), + }, }), + ["__tostring"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }), + ["__name"] = STRING, + + + ["__add"] = FUNCTION, + ["__sub"] = FUNCTION, + ["__mul"] = FUNCTION, + ["__div"] = FUNCTION, + ["__idiv"] = FUNCTION, + ["__mod"] = FUNCTION, + ["__pow"] = FUNCTION, + ["__unm"] = FUNCTION, + ["__band"] = FUNCTION, + ["__bor"] = FUNCTION, + ["__bxor"] = FUNCTION, + ["__bnot"] = FUNCTION, + ["__shl"] = FUNCTION, + ["__shr"] = FUNCTION, + ["__concat"] = FUNCTION, + ["__eq"] = FUNCTION, + ["__lt"] = FUNCTION, + ["__le"] = FUNCTION, + }, + }), + }), + ["coroutine"] = a_type({ + typename = "record", + fields = { + ["create"] = a_type({ typename = "function", args = { FUNCTION }, rets = { THREAD } }), + ["close"] = a_type({ typename = "function", args = { THREAD }, rets = { BOOLEAN, STRING } }), + ["isyieldable"] = a_type({ typename = "function", args = {}, rets = { BOOLEAN } }), + ["resume"] = a_type({ typename = "function", args = { THREAD, VARARG_ANY }, rets = { BOOLEAN, VARARG_ANY } }), + ["running"] = a_type({ typename = "function", args = {}, rets = { THREAD, BOOLEAN } }), + ["status"] = a_type({ typename = "function", args = { THREAD }, rets = { STRING } }), + ["wrap"] = a_type({ typename = "function", args = { FUNCTION }, rets = { FUNCTION } }), + ["yield"] = a_type({ typename = "function", args = { VARARG_ANY }, rets = { VARARG_ANY } }), + }, + }), + ["debug"] = a_type({ + typename = "record", + fields = { + ["traceback"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { THREAD, STRING, NUMBER }, rets = { STRING } }), + a_type({ typename = "function", args = { STRING, NUMBER }, rets = { STRING } }), + }, + }), + ["getinfo"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { ANY }, rets = { DEBUG_GETINFO_TABLE } }), + a_type({ typename = "function", args = { ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }), + a_type({ typename = "function", args = { ANY, ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }), + }, + }), + }, + }), + ["io"] = a_type({ + typename = "record", + fields = { + ["close"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = {}, rets = { BOOLEAN, STRING } }), + a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING } }), + }, + }), + ["flush"] = a_type({ typename = "function", args = {}, rets = {} }), + ["input"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = {}, rets = { NOMINAL_FILE } }), + a_type({ typename = "function", args = { STRING }, rets = { NOMINAL_FILE } }), + a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { NOMINAL_FILE } }), + }, + }), + ["lines"] = a_type({ typename = "function", args = { OPT_STRING, a_type({ typename = "union", types = { STRING, NUMBER }, is_va = true }) }, rets = { + a_type({ typename = "function", args = {}, rets = { VARARG_STRING } }), + }, }), + ["open"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { NOMINAL_FILE, STRING } }), + ["output"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = {}, rets = { NOMINAL_FILE } }), + a_type({ typename = "function", args = { STRING }, rets = { NOMINAL_FILE } }), + a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { NOMINAL_FILE } }), + }, + }), + ["popen"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { NOMINAL_FILE, STRING } }), + ["read"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { NOMINAL_FILE, STRING }, rets = { STRING, STRING } }), + a_type({ typename = "function", args = { NOMINAL_FILE, NUMBER }, rets = { STRING, STRING } }), + }, + }), + ["stderr"] = NOMINAL_FILE, + ["stdin"] = NOMINAL_FILE, + ["stdout"] = NOMINAL_FILE, + ["tmpfile"] = a_type({ typename = "function", args = {}, rets = { NOMINAL_FILE } }), + ["type"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }), + ["write"] = a_type({ typename = "function", args = { VARARG_STRING }, rets = { NOMINAL_FILE, STRING } }), + }, + }), + ["math"] = a_type({ + typename = "record", + fields = { + ["abs"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["acos"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["asin"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["atan"] = a_type({ + typename = "poly", + a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }), + }), + ["atan2"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["ceil"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["cos"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["cosh"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["deg"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["exp"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["floor"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["fmod"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["frexp"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER, NUMBER } }), + ["huge"] = NUMBER, + ["ldexp"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["log"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["log10"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["max"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { NUMBER } }), + ["maxinteger"] = NUMBER, + ["min"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { NUMBER } }), + ["mininteger"] = NUMBER, + ["modf"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER, NUMBER } }), + ["pi"] = NUMBER, + ["pow"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["rad"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["random"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["randomseed"] = a_type({ typename = "function", args = { NUMBER }, rets = {} }), + ["sin"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["sinh"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["sqrt"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["tan"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["tanh"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }), + ["tointeger"] = a_type({ typename = "function", args = { ANY }, rets = { NUMBER } }), + ["type"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }), + ["ult"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { BOOLEAN } }), + }, + }), + ["os"] = a_type({ + typename = "record", + fields = { + ["clock"] = a_type({ typename = "function", args = {}, rets = { NUMBER } }), + ["date"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = {}, rets = { STRING } }), + a_type({ typename = "function", args = { STRING, OPT_STRING }, rets = { a_type({ typename = "union", types = { STRING, OS_DATE_TABLE } }) } }), + }, + }), + ["difftime"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["execute"] = a_type({ typename = "function", args = { STRING }, rets = { BOOLEAN, STRING, NUMBER } }), + ["exit"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { NUMBER, BOOLEAN }, rets = {} }), + a_type({ typename = "function", args = { BOOLEAN, BOOLEAN }, rets = {} }), + }, + }), + ["getenv"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }), + ["remove"] = a_type({ typename = "function", args = { STRING }, rets = { BOOLEAN, STRING } }), + ["rename"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { BOOLEAN, STRING } }), + ["setlocale"] = a_type({ typename = "function", args = { STRING, OPT_STRING }, rets = { STRING } }), + ["time"] = a_type({ typename = "function", args = {}, rets = { NUMBER } }), + ["tmpname"] = a_type({ typename = "function", args = {}, rets = { STRING } }), + }, + }), + ["package"] = a_type({ + typename = "record", + fields = { + ["config"] = STRING, + ["cpath"] = STRING, + ["loaded"] = a_type({ + typename = "map", + keys = STRING, + values = ANY, + }), + ["loaders"] = a_type({ + typename = "array", + elements = a_type({ typename = "function", args = { STRING }, rets = { ANY } }), + }), + ["loadlib"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { FUNCTION } }), + ["path"] = STRING, + ["preload"] = TABLE, + ["searchers"] = a_type({ + typename = "array", + elements = a_type({ typename = "function", args = { STRING }, rets = { ANY } }), + }), + ["searchpath"] = a_type({ typename = "function", args = { STRING, STRING, OPT_STRING, OPT_STRING }, rets = { STRING, STRING } }), + }, + }), + ["string"] = a_type({ + typename = "record", + fields = { + ["byte"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }), + a_type({ typename = "function", args = { STRING, NUMBER }, rets = { NUMBER } }), + a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { VARARG_NUMBER } }), + }, + }), + ["char"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { STRING } }), + ["dump"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { FUNCTION }, rets = { STRING } }), + a_type({ typename = "function", args = { FUNCTION, BOOLEAN }, rets = { STRING } }), + }, + }), + ["find"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { STRING, STRING }, rets = { NUMBER, NUMBER, VARARG_STRING } }), + a_type({ typename = "function", args = { STRING, STRING, NUMBER }, rets = { NUMBER, NUMBER, VARARG_STRING } }), + a_type({ typename = "function", args = { STRING, STRING, NUMBER, BOOLEAN }, rets = { NUMBER, NUMBER, VARARG_STRING } }), + + }, + }), + ["format"] = a_type({ typename = "function", args = { STRING, VARARG_ANY }, rets = { STRING } }), + ["gmatch"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { + a_type({ typename = "function", args = {}, rets = { STRING } }), + }, }), + ["gsub"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", args = { STRING, STRING, STRING, NUMBER }, rets = { STRING, NUMBER } }), + a_type({ typename = "function", args = { STRING, STRING, a_type({ typename = "map", keys = STRING, values = STRING }), NUMBER }, rets = { STRING, NUMBER } }), + a_type({ typename = "function", args = { STRING, STRING, a_type({ typename = "function", args = { VARARG_STRING }, rets = { STRING } }) }, rets = { STRING, NUMBER } }), + + }, + }), + ["len"] = a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }), + ["lower"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }), + ["match"] = a_type({ typename = "function", args = { STRING, STRING, NUMBER }, rets = { VARARG_STRING } }), + ["pack"] = a_type({ typename = "function", args = { STRING, VARARG_ANY }, rets = { STRING } }), + ["packsize"] = a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }), + ["rep"] = a_type({ typename = "function", args = { STRING, NUMBER }, rets = { STRING } }), + ["reverse"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }), + ["sub"] = a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { STRING } }), + ["unpack"] = a_type({ typename = "function", args = { STRING, STRING, OPT_NUMBER }, rets = { VARARG_ANY } }), + ["upper"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }), + }, + }), + ["table"] = a_type({ + typename = "record", + fields = { + ["concat"] = a_type({ typename = "function", args = { ARRAY_OF_STRING, OPT_STRING, OPT_NUMBER, OPT_NUMBER }, rets = { STRING } }), + ["insert"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, ALPHA }, rets = {} }), + a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, ALPHA }, rets = {} }), + }, + }), + ["move"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, NUMBER, NUMBER }, rets = { ARRAY_OF_ALPHA } }), + a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, NUMBER, NUMBER, ARRAY_OF_ALPHA }, rets = { ARRAY_OF_ALPHA } }), + }, + }), + ["pack"] = a_type({ typename = "function", args = { VARARG_ANY }, rets = { TABLE } }), + ["remove"] = a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, OPT_NUMBER }, rets = { ALPHA } }), + ["sort"] = a_type({ + typename = "poly", + types = { + a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA }, rets = {} }), + a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, a_type({ typename = "function", args = { ALPHA, ALPHA }, rets = { BOOLEAN } }) }, rets = {} }), + }, + }), + ["unpack"] = a_type({ + typename = "function", + needs_compat53 = true, + typeargs = { ARG_ALPHA }, + args = { ARRAY_OF_ALPHA, NUMBER, NUMBER }, + rets = { VARARG_ALPHA }, + }), + }, + }), + ["utf8"] = a_type({ + typename = "record", + fields = { + ["char"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { STRING } }), + ["charpattern"] = STRING, + ["codepoint"] = a_type({ typename = "function", args = { STRING, OPT_NUMBER, OPT_NUMBER }, rets = { VARARG_NUMBER } }), + ["codes"] = a_type({ typename = "function", args = { STRING }, rets = { + a_type({ typename = "function", args = {}, rets = { NUMBER, STRING } }), + }, }), + ["len"] = a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { NUMBER } }), + ["offset"] = a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { NUMBER } }), + }, + }), +} + +for _, t in pairs(standard_library) do + fill_field_order(t) + if t.typename == "typetype" then + fill_field_order(t.def) + end +end +fill_field_order(OS_DATE_TABLE) +fill_field_order(DEBUG_GETINFO_TABLE) + +NOMINAL_FILE.found = standard_library["FILE"] +NOMINAL_METATABLE.found = standard_library["METATABLE"] + +local compat53_code_cache = {} + +local function add_compat53_entries(program, used_set) + if not next(used_set) then + return + end + + local used_list = {} + for name, _ in pairs(used_set) do + table.insert(used_list, name) + end + table.sort(used_list) + + local compat53_loaded = false + + local n = 1 + local function load_code(name, text) + local code = compat53_code_cache[name] + if not code then + local tokens = tl.lex(text) + local _ + _, code = tl.parse_program(tokens, {}, "@internal") + tl.type_check(code, { lax = false, skip_compat53 = true }) + code = code[1] + compat53_code_cache[name] = code + end + table.insert(program, n, code) + n = n + 1 + end + + for i, name in ipairs(used_list) do + local mod, fn = name:match("([^.]*)%.(.*)") + local errs = {} + local text + local code = compat53_code_cache[name] + if not code then + + if name == "table.unpack" then + load_code(name, "local _tl_table_unpack = unpack or table.unpack") + else + if not compat53_loaded then + load_code("compat53", "local _tl_compat53 = ((tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3) and require('compat53.module')") + compat53_loaded = true + end + load_code(name, (("local $NAME = _tl_compat53 and _tl_compat53.$NAME or $NAME"):gsub("$NAME", name))) + end + end + end + program.y = 1 +end + +local function get_stdlib_compat53(lax) + if lax then + return { + ["utf8"] = true, + } + else + return { + ["io"] = true, + ["math"] = true, + ["string"] = true, + ["table"] = true, + ["utf8"] = true, + ["coroutine"] = true, + ["os"] = true, + ["package"] = true, + ["debug"] = true, + ["load"] = true, + ["loadfile"] = true, + ["assert"] = true, + ["pairs"] = true, + ["ipairs"] = true, + ["pcall"] = true, + ["xpcall"] = true, + ["rawlen"] = true, + } + end +end + +local function init_globals(lax) + local globals = {} + local stdlib_compat53 = get_stdlib_compat53(lax) + + for name, typ in pairs(standard_library) do + globals[name] = { t = typ, needs_compat53 = stdlib_compat53[name], is_const = true } + end + + + + + globals["@is_va"] = { t = VARARG_ANY } + + return globals +end + +function tl.init_env(lax, skip_compat53) + local env = { + modules = {}, + globals = init_globals(lax), + skip_compat53 = skip_compat53, + } + + + for name, var in pairs(standard_library) do + if var.typename == "record" then + env.modules[name] = var + end + end + + return env +end + +function tl.type_check(ast, opts) + opts = opts or {} + opts.env = opts.env or tl.init_env(opts.lax, opts.skip_compat53) + local lax = opts.lax + local filename = opts.filename + local result = opts.result or { + syntax_errors = {}, + type_errors = {}, + unknowns = {}, + } + + local stdlib_compat53 = get_stdlib_compat53(lax) + + local st = { opts.env.globals } + + local all_needs_compat53 = {} + + local errors = result.type_errors or {} + local unknowns = result.unknowns or {} + local module_type + + local function find_var(name) + if name == "_G" then + + local globals = {} + for k, v in pairs(st[1]) do + if k:sub(1, 1) ~= "@" then + globals[k] = v.t + end + end + local field_order = {} + for k, _ in pairs(globals) do + table.insert(field_order, k) + end + return a_type({ + typename = "record", + field_order = field_order, + fields = globals, + }), false + end + for i = #st, 1, -1 do + local scope = st[i] + if scope[name] then + if i == 1 and scope[name].needs_compat53 then + all_needs_compat53[name] = true + end + local typ = scope[name].t + + return typ, scope[name].is_const + end + end + end + + local function resolve_typevars(t, seen) + seen = seen or {} + if seen[t] then + return seen[t] + end + + local orig_t = t + local clear_tk = false + if t.typename == "typevar" then + local tv = find_var(t.typevar) + if tv then + t = tv + clear_tk = true + else + t = UNKNOWN + end + end + + local copy = {} + seen[orig_t] = copy + + for k, v in pairs(t) do + local cp = copy + if type(v) == "table" then + cp[k] = resolve_typevars(v, seen) + else + cp[k] = v + end + end + + if clear_tk then + copy.tk = nil + end + + return copy + end + + local function find_type(names, accept_typearg) + local typ = find_var(names[1]) + if not typ then + return nil + end + for i = 2, #names do + local nested = typ.fields or (typ.def and typ.def.fields) + if nested then + typ = nested[names[i]] + if typ == nil then + return nil + end + else + break + end + end + if typ then + if accept_typearg and typ.typename == "typearg" then + return typ + end + if is_type(typ) then + return typ + end + end + return nil + end + + local function infer_var(emptytable, t, node) + local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") + local nst = is_global and 1 or #st + for i = nst, 1, -1 do + local scope = st[i] + if scope[emptytable.assigned_to] then + scope[emptytable.assigned_to] = { + t = t, + is_const = false, + } + t.inferred_at = node + t.inferred_at_file = filename + end + end + end + + local function find_global(name) + local scope = st[1] + if scope[name] then + return scope[name].t, scope[name].is_const + end + end + + local function resolve_tuple(t) + if t.typename == "tuple" then + t = t[1] + end + if t == nil then + return NIL + end + return t + end + + local function error_in_type(where, msg, ...) + local n = select("#", ...) + if n > 0 then + local showt = {} + for i = 1, n do + local t = select(i, ...) + if t.typename == "invalid" then + return nil + end + showt[i] = show_type(t) + end + msg = msg:format(_tl_table_unpack(showt)) + end + + return { + y = where.y, + x = where.x, + msg = msg, + filename = where.filename or filename, + } + end + + local function type_error(t, msg, ...) + local e = error_in_type(t, msg, ...) + if e then + table.insert(errors, e) + return true + else + return false + end + end + + local function node_error(node, msg, ...) + local ok = type_error(node, msg, ...) + node.type = INVALID + return node.type + end + + local function terr(t, s, ...) + return { error_in_type(t, s, ...) } + end + + local function add_unknown(node, name) + table.insert(unknowns, { y = node.y, x = node.x, msg = name, filename = filename }) + end + + local function add_var(node, var, valtype, is_const, is_narrowing) + if lax and node and is_unknown(valtype) and (var ~= "self" and var ~= "...") then + add_unknown(node, var) + end + if st[#st][var] and is_narrowing then + if not st[#st][var].is_narrowed then + st[#st][var].narrowed_from = st[#st][var].t + end + st[#st][var].is_narrowed = true + st[#st][var].t = valtype + else + st[#st][var] = { t = valtype, is_const = is_const, is_narrowed = is_narrowing } + end + end + + local CompareTypes = {} + + local function compare_typevars(t1, t2, comp) + local tv1 = find_var(t1.typevar) + local tv2 = find_var(t2.typevar) + if t1.typevar == t2.typevar then + local has_t1 = not not tv1 + local has_t2 = not not tv2 + if has_t1 == has_t2 then + return true + end + end + local function cmp(k, v, a, b) + if find_var(k) then + return comp(a, b) + else + add_var(nil, k, resolve_typevars(v)) + return true + end + end + if t2.typename == "typevar" then + return cmp(t2.typevar, t1, t1, tv2) + else + return cmp(t1.typevar, t2, tv1, t2) + end + end + + local function add_errs_prefixing(src, dst, prefix, node) + if not src then + return + end + for i, err in ipairs(src) do + err.msg = prefix .. err.msg + + + if node and node.y and ( + (err.filename ~= filename) or + (not err.y) or + (node.y > err.y or (node.y == err.y and node.x > err.x))) then + + err.y = node.y + err.x = node.x + err.filename = filename + end + + table.insert(dst, err) + end + end + + local is_a + + local TypeGetter = {} + + local function match_record_fields(t1, t2, cmp) + cmp = cmp or is_a + local fielderrs = {} + for _, k in ipairs(t1.field_order) do + local f = t1.fields[k] + local t2k = t2(k) + if t2k == nil then + if not lax then + table.insert(fielderrs, error_in_type(f, "unknown field " .. k)) + end + else + local match, errs = is_a(f, t2k) + add_errs_prefixing(errs, fielderrs, "record field doesn't match: " .. k .. ": ") + end + end + if #fielderrs > 0 then + return false, fielderrs + end + return true + end + + local function match_fields_to_record(t1, t2, cmp) + return match_record_fields(t1, function(k) return t2.fields[k] end, cmp) + end + + local function match_fields_to_map(t1, t2) + if not match_record_fields(t1, function(_) return t2.values end) then + return false, { error_in_type(t1, "not all fields have type %s", t2.values) } + end + return true + end + + local function arg_check(cmp, a, b, at, n, errs) + local matches, match_errs = cmp(a, b) + if not matches then + add_errs_prefixing(match_errs, errs, "argument " .. n .. ": ", at) + return false + end + return true + end + + local same_type + + local function has_all_types_of(t1s, t2s) + for _, t1 in ipairs(t1s) do + local found = false + for _, t2 in ipairs(t2s) do + if is_a(t2, t1) then + found = true + break + end + end + if not found then + return false + end + end + return true + end + + local function any_errors(all_errs) + if #all_errs == 0 then + return true + else + return false, all_errs + end + end + + local function are_same_nominals(t1, t2) + local same_names + if t1.found and t2.found then + same_names = t1.found.typeid == t2.found.typeid + else + local ft1 = t1.found or find_type(t1.names) + local ft2 = t2.found or find_type(t2.names) + if ft1 and ft2 then + same_names = ft1.typeid == ft2.typeid + else + if not ft1 then + type_error(t1, "unknown type %s", t1) + end + if not ft2 then + type_error(t2, "unknown type %s", t2) + end + return false, {} + end + end + + if same_names then + if t1.typevals == nil and t2.typevals == nil then + return true + elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + local all_errs = {} + for i = 1, #t1.typevals do + local ok, errs = same_type(t2.typevals[i], t1.typevals[i]) + add_errs_prefixing(errs, all_errs, "type parameter <" .. show_type(t1.typevals[i]) .. ">: ", t1) + end + if #all_errs == 0 then + return true + else + return false, all_errs + end + end + else + return false, terr(t1, "%s is not a %s", t1, t2) + end + end + + same_type = function(t1, t2) + assert(type(t1) == "table") + assert(type(t2) == "table") + + if t1.typename == "typevar" or t2.typename == "typevar" then + return compare_typevars(t1, t2, same_type) + end + + if t1.typename ~= t2.typename then + return false, terr(t1, "got %s, expected %s", t1, t2) + end + if t1.typename == "array" then + return same_type(t1.elements, t2.elements) + elseif t1.typename == "map" then + local all_errs = {} + local k_ok, k_errs = same_type(t1.keys, t2.keys) + if not k_ok then + add_errs_prefixing(k_errs, all_errs, "keys", t1) + end + local v_ok, v_errs = same_type(t1.values, t2.values) + if not v_ok then + add_errs_prefixing(v_errs, all_errs, "values", t1) + end + return any_errors(all_errs) + elseif t1.typename == "union" then + if has_all_types_of(t1.types, t2.types) and + has_all_types_of(t2.types, t1.types) then + return true + else + return false, terr(t1, "got %s, expected %s", t1, t2) + end + elseif t1.typename == "nominal" then + return are_same_nominals(t1, t2) + elseif t1.typename == "record" then + return match_fields_to_record(t1, t2, same_type) + elseif t1.typename == "function" then + if #t1.args ~= #t2.args then + return false, terr(t1, "different number of input arguments: got " .. #t1.args .. ", expected " .. #t2.args) + end + if #t1.rets ~= #t2.rets then + return false, terr(t1, "different number of return values: got " .. #t1.args .. ", expected " .. #t2.args) + end + local all_errs = {} + for i = 1, #t1.args do + arg_check(same_type, t1.args[i], t2.args[i], t1, i, all_errs) + end + for i = 1, #t1.rets do + local ok, errs = same_type(t1.rets[i], t2.rets[i]) + add_errs_prefixing(errs, all_errs, "return " .. i, t1) + end + return any_errors(all_errs) + elseif t1.typename == "arrayrecord" then + local ok, errs = same_type(t1.elements, t2.elements) + if not ok then + return ok, errs + end + return match_fields_to_record(t1, t2, same_type) + end + return true + end + + local function a_union(types) + local ts = {} + local stack = {} + local i = 1 + while types[i] or stack[1] do + local t + if stack[1] then + t = table.remove(stack) + else + t = types[i] + i = i + 1 + end + if t.typename == "union" then + for _, s in ipairs(t.types) do + table.insert(stack, s) + end + else + table.insert(ts, t) + end + end + return a_type({ + typename = "union", + types = ts, + }) + end + + local function is_vararg(t) + return t.args and #t.args > 0 and t.args[#t.args].is_va + end + + local function combine_errs(...) + local errs + for i = 1, select("#", ...) do + local e = select(i, ...) + if e then + errs = errs or {} + for _, err in ipairs(e) do + table.insert(errs, err) + end + end + end + if not errs then + return true + else + return false, errs + end + end + + local resolve_unary = nil + + local function is_known_table_type(t) + return (t.typename == "array" or t.typename == "map" or t.typename == "record" or t.typename == "arrayrecord") + end + + is_a = function(t1, t2, for_equality) + assert(type(t1) == "table") + assert(type(t2) == "table") + + if lax and (is_unknown(t1) or is_unknown(t2)) then + return true + end + + if t1.typename == "nil" then + return true + end + + if t2.typename ~= "tuple" then + t1 = resolve_tuple(t1) + end + if t2.typename == "tuple" and t1.typename ~= "tuple" then + t1 = a_type({ + typename = "tuple", + [1] = t1, + }) + end + + if t1.typename == "typevar" or t2.typename == "typevar" then + return compare_typevars(t1, t2, is_a) + end + + if t2.typename == "any" then + return true + elseif t2.typename == "poly" then + for _, t in ipairs(t2.types) do + if is_a(t1, t, for_equality) then + return true + end + end + return false, terr(t1, "cannot match against any alternatives of the polymorphic type") + elseif t1.typename == "union" and t2.typename == "union" then + if has_all_types_of(t1.types, t2.types) then + return true + else + return false, terr(t1, "got %s, expected %s", t1, t2) + end + elseif t2.typename == "union" then + for _, t in ipairs(t2.types) do + if is_a(t1, t, for_equality) then + return true + end + end + elseif t1.typename == "poly" then + for _, t in ipairs(t1.types) do + if is_a(t, t2, for_equality) then + return true + end + end + return false, terr(t1, "cannot match against any alternatives of the polymorphic type") + elseif t1.typename == "nominal" and t2.typename == "nominal" and #t2.names == 1 and t2.names[1] == "any" then + return true + elseif t1.typename == "nominal" and t2.typename == "nominal" then + return are_same_nominals(t1, t2) + elseif t1.typename == "enum" and t2.typename == "string" then + local ok + if for_equality then + ok = t2.tk and t1.enumset[unquote(t2.tk)] + else + ok = true + end + if ok then + return true + else + return false, terr(t1, "enum is incompatible with %s", t2) + end + elseif t1.typename == "string" and t2.typename == "enum" then + local ok = t1.tk and t2.enumset[unquote(t1.tk)] + if ok then + return true + else + if t1.tk then + return false, terr(t1, "%s is not a member of %s", t1, t2) + else + return false, terr(t1, "string is not a %s", t2) + end + end + elseif t1.typename == "nominal" or t2.typename == "nominal" then + local t1u = resolve_unary(t1) + local t2u = resolve_unary(t2) + local ok, errs = is_a(t1u, t2u, for_equality) + if errs and #errs == 1 then + if errs[1].msg:match("^got ") then + + + errs = terr(t1, "got %s, expected %s", t1, t2) + end + end + return ok, errs + elseif t1.typename == "emptytable" and is_known_table_type(t2) then + return true + elseif t2.typename == "array" then + if is_array_type(t1) then + if is_a(t1.elements, t2.elements) then + return true + end + elseif t1.typename == "map" then + local _, errs_keys = is_a(t1.keys, NUMBER) + local _, errs_values = is_a(t1.values, t2.elements) + return combine_errs(errs_keys, errs_values) + end + elseif t2.typename == "record" then + if is_record_type(t1) then + return match_fields_to_record(t1, t2) + elseif t1.typename == "typetype" and t1.def.typename == "record" then + return is_a(t1.def, t2, for_equality) + end + elseif t2.typename == "arrayrecord" then + if t1.typename == "array" then + return is_a(t1.elements, t2.elements) + elseif t1.typename == "record" then + return match_fields_to_record(t1, t2) + elseif t1.typename == "arrayrecord" then + if not is_a(t1.elements, t2.elements) then + return false, terr(t1, "array parts have incompatible element types") + end + return match_fields_to_record(t1, t2) + end + elseif t2.typename == "map" then + if t1.typename == "map" then + local _, errs_keys = is_a(t1.keys, t2.keys) + local _, errs_values = is_a(t2.values, t1.values) + if t2.values.typename == "any" then + errs_values = {} + end + return combine_errs(errs_keys, errs_values) + elseif t1.typename == "array" then + local _, errs_keys = is_a(NUMBER, t2.keys) + local _, errs_values = is_a(t1.elements, t2.values) + return combine_errs(errs_keys, errs_values) + elseif is_record_type(t1) then + if not is_a(t2.keys, STRING) then + return false, terr(t1, "can't match a record to a map with non-string keys") + end + if t2.keys.typename == "enum" then + for _, k in ipairs(t1.field_order) do + if not t2.keys.enumset[k] then + return false, terr(t1, "key is not an enum value: " .. k) + end + end + end + return match_fields_to_map(t1, t2) + end + elseif t1.typename == "function" and t2.typename == "function" then + local all_errs = {} + if (not is_vararg(t2)) and #t1.args > #t2.args then + t1.args.typename = "tuple" + t2.args.typename = "tuple" + table.insert(all_errs, error_in_type(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args)) + else + for i = (t1.is_method and 2 or 1), #t1.args do + arg_check(is_a, t1.args[i], t2.args[i] or ANY, nil, i, all_errs) + end + end + local diff_by_va = #t2.rets - #t1.rets == 1 and t2.rets[#t2.rets].is_va + if #t1.rets < #t2.rets and not diff_by_va then + t1.rets.typename = "tuple" + t2.rets.typename = "tuple" + table.insert(all_errs, error_in_type(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets)) + else + local nrets = #t2.rets + if diff_by_va then + nrets = nrets - 1 + end + for i = 1, nrets do + local ok, errs = is_a(t1.rets[i], t2.rets[i]) + add_errs_prefixing(errs, all_errs, "return " .. i .. ": ") + end + end + if #all_errs == 0 then + return true + else + return false, all_errs + end + elseif lax and ((not for_equality) and t2.typename == "boolean") then + + return true + elseif t1.typename == t2.typename then + return true + end + + return false, terr(t1, "got %s, expected %s", t1, t2) + end + + local function assert_is_a(node, t1, t2, context, name) + t1 = resolve_tuple(t1) + t2 = resolve_tuple(t2) + if lax and (is_unknown(t1) or is_unknown(t2)) then + return + end + + if t2.typename == "unknown_emptytable_value" then + if same_type(t2.emptytable_type.keys, NUMBER) then + infer_var(t2.emptytable_type, a_type({ typename = "array", elements = t1 }), node) + else + infer_var(t2.emptytable_type, a_type({ typename = "map", keys = t2.emptytable_type.keys, values = t1 }), node) + end + return + elseif t2.typename == "emptytable" then + if is_known_table_type(t1) then + infer_var(t2, t1, node) + elseif t1.typename ~= "emptytable" then + node_error(node, "in " .. context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) + end + return + end + + local match, match_errs = is_a(t1, t2) + add_errs_prefixing(match_errs, errors, "in " .. context .. ": " .. (name and (name .. ": ") or ""), node) + end + + local function close_types(vars) + for name, var in pairs(vars) do + if var.t.typename == "typetype" then + var.t.closed = true + end + end + end + + local function begin_scope() + table.insert(st, {}) + end + + local function end_scope() + local unresolved = st[#st]["@unresolved"] + if unresolved then + local upper = st[#st - 1]["@unresolved"] + if upper then + for name, nodes in pairs(unresolved.t.labels) do + for _, node in ipairs(nodes) do + upper.t.labels[name] = upper.t.labels[name] or {} + table.insert(upper.t.labels[name], node) + end + end + for name, types in pairs(unresolved.t.nominals) do + for _, typ in ipairs(types) do + upper.t.nominals[name] = upper.t.nominals[name] or {} + table.insert(upper.t.nominals[name], typ) + end + end + else + st[#st - 1]["@unresolved"] = unresolved + end + end + close_types(st[#st]) + table.remove(st) + end + + local type_check_function_call + do + local function try_match_func_args(node, f, args, is_method, argdelta) + local ok = true + local errs = {} + + if is_method then + argdelta = -1 + elseif not argdelta then + argdelta = 0 + end + + if f.is_method and not is_method and not (args[1] and is_a(args[1], f.args[1])) then + table.insert(errs, { y = node.y, x = node.x, msg = "invoked method as a regular function: use ':' instead of '.'", filename = filename }) + return nil, errs + end + + local va = is_vararg(f) + local nargs = va and + math.max(#args, #f.args) or + math.min(#args, #f.args) + + for a = 1, nargs do + local arg = args[a] + local farg = f.args[a] or (va and f.args[#f.args]) + if arg == nil then + if farg.is_va then + break + end + else + local at = node.e2 and node.e2[a] or node + if not arg_check(is_a, arg, farg, at, (a + argdelta), errs) then + ok = false + break + end + end + end + if ok == true then + f.rets.typename = "tuple" + + + for a = 1, #args do + local arg = args[a] + local farg = f.args[a] or (va and f.args[#f.args]) + if arg.typename == "emptytable" then + infer_var(arg, resolve_typevars(farg), node.e2[a]) + end + end + + return resolve_typevars(f.rets) + end + return nil, errs + end + + local function revert_typeargs(func) + if func.typeargs then + for _, arg in ipairs(func.typeargs) do + if st[#st][arg.typearg] then + st[#st][arg.typearg] = nil + end + end + end + end + + local function remove_sorted_duplicates(t) + local prev = nil + for i = #t, 1, -1 do + if t[i] == prev then + table.remove(t, i) + else + prev = t[i] + end + end + end + + local function check_call(node, func, args, is_method, argdelta) + assert(type(func) == "table") + assert(type(args) == "table") + + if lax and is_unknown(func) then + func = a_type({ typename = "function", args = { VARARG_UNKNOWN }, rets = { VARARG_UNKNOWN } }) + end + + func = resolve_unary(func) + + args = args or {} + local poly = func.typename == "poly" and func or { types = { func } } + local first_errs + local expects = {} + + local tried = {} + for i, f in ipairs(poly.types) do + if not tried[i] then + if f.typename ~= "function" then + if lax and is_unknown(f) then + return UNKNOWN + end + return node_error(node, "not a function: %s", f) + end + table.insert(expects, tostring(#f.args or 0)) + local va = is_vararg(f) + if #args == (#f.args or 0) or (va and #args > #f.args) then + tried[i] = true + local matched, errs = try_match_func_args(node, f, args, is_method, argdelta) + if matched then + return matched + else + revert_typeargs(f) + end + first_errs = first_errs or errs + end + end + end + + for i, f in ipairs(poly.types) do + if not tried[i] then + tried[i] = true + if #args < (#f.args or 0) then + tried[i] = true + local matched, errs = try_match_func_args(node, f, args, is_method, argdelta) + if matched then + return matched + else + revert_typeargs(f) + end + first_errs = first_errs or errs + end + end + end + + for i, f in ipairs(poly.types) do + if not tried[i] then + if is_vararg(f) and #args > (#f.args or 0) then + tried[i] = true + local matched, errs = try_match_func_args(node, f, args, is_method, argdelta) + if matched then + return matched + else + revert_typeargs(f) + end + first_errs = first_errs or errs + end + end + end + + if not first_errs then + table.sort(expects) + remove_sorted_duplicates(expects) + node_error(node, "wrong number of arguments (given " .. #args .. ", expects " .. table.concat(expects, " or ") .. ")") + else + for _, err in ipairs(first_errs) do + table.insert(errors, err) + end + end + + poly.types[1].rets.typename = "tuple" + return resolve_typevars(poly.types[1].rets) + end + + type_check_function_call = function(node, func, args, is_method, argdelta) + begin_scope() + local ret = check_call(node, func, args, is_method, argdelta) + end_scope() + return ret + end + end + + local unknown_dots = {} + + local function add_unknown_dot(node, name) + if not unknown_dots[name] then + unknown_dots[name] = true + add_unknown(node, name) + end + end + + local function get_self_type(t) + if t.typename == "typetype" then + return t.def + else + return t + end + end + + local function match_record_key(node, tbl, key, orig_tbl) + assert(type(tbl) == "table") + assert(type(key) == "table") + + tbl = resolve_unary(tbl) + local type_description = tbl.typename + if tbl.typename == "string" or tbl.typename == "enum" then + tbl = find_var("string") + end + + if lax and (is_unknown(tbl) or tbl.typename == "typevar") then + if node.e1.kind == "variable" and node.op.op ~= "@funcall" then + add_unknown_dot(node, node.e1.tk .. "." .. key.tk) + end + return UNKNOWN + end + + tbl = get_self_type(tbl) + + if tbl.typename == "emptytable" then + elseif is_record_type(tbl) then + assert(tbl.fields, "record has no fields!?") + + if key.kind == "string" or key.kind == "identifier" then + if tbl.fields[key.tk] then + return tbl.fields[key.tk] + end + end + else + if is_unknown(tbl) then + if not lax then + node_error(node, "cannot index a value of unknown type") + end + else + node_error(node, "cannot index something that is not a record: %s", tbl) + end + return INVALID + end + + if lax then + if node.e1.kind == "variable" and node.op.op ~= "@funcall" then + add_unknown_dot(node, node.e1.tk .. "." .. key.tk) + end + return UNKNOWN + end + + local description + if node.e1.kind == "variable" then + description = type_description .. " '" .. node.e1.tk .. "' of type " .. show_type(resolve_tuple(orig_tbl)) + else + description = "type " .. show_type(resolve_tuple(orig_tbl)) + end + + return node_error(key, "invalid key '" .. key.tk .. "' in " .. description) + end + + local function widen_in_scope(scope, var) + if scope[var].is_narrowed then + if scope[var].narrowed_from then + scope[var].t = scope[var].narrowed_from + scope[var].narrowed_from = nil + scope[var].is_narrowed = false + else + scope[var] = nil + end + return true + end + return false + end + + local function widen_back_var(var) + local widened = false + for i = #st, 1, -1 do + if st[i][var] then + if widen_in_scope(st[i], var) then + widened = true + else + break + end + end + end + return widened + end + + local function widen_all_unions() + for i = #st, 1, -1 do + for var, _ in pairs(st[i]) do + widen_in_scope(st[i], var) + end + end + end + + local function add_global(node, var, valtype, is_const) + if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then + add_unknown(node, var) + end + st[1][var] = { t = valtype, is_const = is_const } + end + + local check_typevars + + local function check_all_typevars(node, ts) + if ts ~= nil then + for _, arg in ipairs(ts) do + check_typevars(node, arg) + end + end + end + + check_typevars = function(node, t) + if t == nil then + return + end + if t.typename == "typevar" then + if not find_var(t.typevar) then + node_error(node, "unknown type variable " .. t.typevar) + end + return + end + check_typevars(node, t.elements) + check_typevars(node, t.keys) + check_typevars(node, t.values) + check_all_typevars(node, t.typeargs) + check_all_typevars(node, t.args) + check_all_typevars(node, t.rets) + end + + local function get_rets(rets) + if lax and (#rets == 0) then + return { a_type({ typename = "unknown", is_va = true }) } + end + return rets + end + + local function begin_function_scope(node, recurse) + begin_scope() + local args = {} + if node.typeargs then + for i, arg in ipairs(node.typeargs) do + add_var(nil, arg.typearg, arg) + end + end + local is_va = false + for i, arg in ipairs(node.args) do + local t = arg.decltype + if not t then + t = a_type({ typename = "unknown" }) + end + if arg.tk == "..." then + is_va = true + t.is_va = true + if i ~= #node.args then + node_error(node, "'...' can only be last argument") + end + end + check_typevars(arg, t) + table.insert(args, t) + add_var(arg, arg.tk, t) + end + + add_var(nil, "@is_va", is_va and VARARG_ANY or NIL) + + add_var(nil, "@return", node.rets or a_type({ typename = "tuple" })) + if recurse then + add_var(nil, node.name.tk, a_type({ + typename = "function", + args = args, + rets = get_rets(node.rets), + })) + end + end + + local function fail_unresolved() + local unresolved = st[#st]["@unresolved"] + if unresolved then + st[#st]["@unresolved"] = nil + for name, nodes in pairs(unresolved.t.labels) do + for _, node in ipairs(nodes) do + node_error(node, "no visible label '" .. name .. "' for goto") + end + end + for name, types in pairs(unresolved.t.nominals) do + for _, typ in ipairs(types) do + assert(typ.x) + assert(typ.y) + type_error(typ, "unknown type %s", typ) + end + end + end + end + + local function end_function_scope() + fail_unresolved() + end_scope() + end + + local function match_typevals(t, def) + if t.typevals and def.typeargs then + if #t.typevals ~= #def.typeargs then + type_error(t, "mismatch in number of type arguments") + return nil + end + + begin_scope() + for i, tt in ipairs(t.typevals) do + add_var(nil, def.typeargs[i].typearg, tt) + end + local ret = resolve_typevars(def) + end_scope() + return ret + elseif t.typevals then + type_error(t, "spurious type arguments") + return nil + elseif def.typeargs then + type_error(t, "missing type arguments in %s", def) + return nil + else + return def + end + end + + local function resolve_nominal(t) + if t.resolved then + return t.resolved + end + + local resolved + + local typetype = t.found or find_type(t.names) + if not typetype then + type_error(t, "unknown type %s", t) + elseif is_type(typetype) then + resolved = match_typevals(t, typetype.def) + else + type_error(t, table.concat(t.names, ".") .. " is not a type") + end + + if not resolved then + resolved = a_type({ typename = "bad_nominal", names = t.names }) + end + + t.found = typetype + t.resolved = resolved + return resolved + end + + resolve_unary = function(t) + t = resolve_tuple(t) + if t.typename == "nominal" then + return resolve_nominal(t) + end + return t + end + + local function flatten_list(list) + local exps = {} + for i = 1, #list - 1 do + table.insert(exps, resolve_unary(list[i])) + end + if #list > 0 then + local last = list[#list] + if last.typename == "tuple" then + for _, val in ipairs(last) do + table.insert(exps, val) + end + else + table.insert(exps, last) + end + end + return exps + end + + local function get_assignment_values(vals, wanted) + local ret = {} + if vals == nil then + return ret + end + + for i = 1, #vals - 1 do + ret[i] = vals[i] + end + local last = vals[#vals] + + if last.typename == "tuple" then + for _, v in ipairs(last) do + table.insert(ret, v) + end + + elseif last.is_va and #ret < wanted then + while #ret < wanted do + table.insert(ret, last) + end + + else + table.insert(ret, last) + end + return ret + end + + local function match_all_record_field_names(node, a, field_names, errmsg) + local t + for _, k in ipairs(field_names) do + local f = a.fields[k] + if not t then + t = f + else + if not same_type(f, t) then + t = nil + break + end + end + end + if t then + return t + else + return node_error(node, errmsg) + end + end + + local function type_check_index(node, idxnode, a, b) + local orig_a = a + local orig_b = b + a = resolve_unary(a) + b = resolve_unary(b) + + if is_array_type(a) and is_a(b, NUMBER) then + return a.elements + elseif a.typename == "emptytable" then + if a.keys == nil then + a.keys = b + a.keys_inferred_at = node + a.keys_inferred_at_file = filename + else + if not is_a(b, a.keys) then + local inferred = " (type of keys inferred at " .. a.keys_inferred_at_file .. ":" .. a.keys_inferred_at.y .. ":" .. a.keys_inferred_at.x .. ": )" + return node_error(idxnode, "inconsistent index type: %s, expected %s" .. inferred, b, a.keys) + end + end + return a_type({ y = node.y, x = node.x, typename = "unknown_emptytable_value", emptytable_type = a }) + elseif a.typename == "map" then + if is_a(b, a.keys) then + return a.values + else + return node_error(idxnode, "wrong index type: %s, expected %s", orig_b, a.keys) + end + elseif node.e2.kind == "string" or node.e2.kind == "enum_item" then + return match_record_key(node, a, { y = node.e2.y, x = node.e2.x, kind = "string", tk = assert(node.e2.conststr) }, orig_a) + elseif is_record_type(a) and b.typename == "enum" then + local field_names = {} + for k, _ in pairs(b.enumset) do + table.insert(field_names, k) + end + table.sort(field_names) + for _, k in ipairs(field_names) do + if not a.fields[k] then + return node_error(idxnode, "enum value '" .. k .. "' is not a field in %s", a) + end + end + return match_all_record_field_names(idxnode, a, field_names, +"cannot index, not all enum values map to record fields of the same type") + elseif lax and is_unknown(a) then + return UNKNOWN + else + if is_a(b, STRING) then + return node_error(idxnode, "cannot index object of type %s with a string, consider using an enum", orig_a) + end + return node_error(idxnode, "cannot index object of type %s with %s", orig_a, orig_b) + end + end + + local function expand_type(where, old, new) + if not old then + return new + else + if not is_a(new, old) then + if old.typename == "map" and is_record_type(new) then + if old.keys.typename == "string" then + for _, ftype in pairs(new.fields) do + old.values = expand_type(where, old.values, ftype) + end + else + node_error(where, "cannot determine table literal type") + end + elseif is_record_type(old) and is_record_type(new) then + old.typename = "map" + old.keys = STRING + for _, ftype in pairs(old.fields) do + if not old.values then + old.values = ftype + else + old.values = expand_type(where, old.values, ftype) + end + end + for _, ftype in pairs(new.fields) do + if not old.values then + new.values = ftype + else + new.values = expand_type(where, old.values, ftype) + end + end + old.fields = nil + old.field_order = nil + elseif old.typename == "union" then + new.tk = nil + table.insert(old.types, new) + else + old.tk = nil + new.tk = nil + return a_union({ old, new }) + end + end + end + return old + end + + local function find_in_scope(exp) + if exp.kind == "variable" then + local t = find_var(exp.tk) + if t.def then + if not t.def.closed and not t.closed then + return t.def + end + end + if not t.closed then + return t + end + elseif exp.kind == "op" and exp.op.op == "." then + local t = find_in_scope(exp.e1) + if not t then + return nil + end + while exp.e2.kind == "op" and exp.e2.op.op == "." do + t = t.fields[exp.e2.e1.tk] + if not t then + return nil + end + exp = exp.e2 + end + t = t.fields[exp.e2.tk] + return t + end + end + + local facts_and + local facts_or + local facts_not + do + local function join_facts(fss) + local vars = {} + + for _, fs in ipairs(fss) do + for _, f in ipairs(fs) do + if not vars[f.var] then + vars[f.var] = {} + end + table.insert(vars[f.var], f) + end + end + return vars + end + + local function intersect(xs, ys, same) + local rs = {} + for i = #xs, 1, -1 do + local x = xs[i] + for _, y in ipairs(ys) do + if same(x, y) then + table.insert(rs, x) + break + end + end + end + return rs + end + + local function same_type_for_intersect(t, u) + return (same_type(t, u)) + end + + local function intersect_facts(fs, errnode) + local all_is = true + local types = {} + for i, f in ipairs(fs) do + if f.fact ~= "is" then + all_is = false + break + end + if f.typ.typename == "union" then + if i == 1 then + types = f.typ.types + else + types = intersect(types, f.typ.types, same_type_for_intersect) + end + else + if i == 1 then + types = { f.typ } + else + types = intersect(types, { f.typ }, same_type_for_intersect) + end + end + end + + if #types == 0 then + node_error(errnode, "branch is always false") + return false + end + + if all_is then + if #types == 1 then + return true, types[1] + else + return true, a_union(types) + end + else + return false + end + end + + local function sum_facts(fs) + local all_is = true + local types = {} + for _, f in ipairs(fs) do + if f.fact ~= "is" then + all_is = false + break + end + table.insert(types, f.typ) + end + + if all_is then + if #types == 1 then + return true, types[1] + else + return true, a_union(types) + end + else + return false + end + end + + local function subtract_types(u1, u2, errt) + local types = {} + for _, rt in ipairs(u1.types or { u1 }) do + local not_present = true + for _, ft in ipairs(u2.types or { u2 }) do + if same_type(rt, ft) then + not_present = false + break + end + end + if not_present then + table.insert(types, rt) + end + end + + if #types == 0 then + type_error(errt, "branch is always false") + return INVALID + end + + if #types == 1 then + return types[1] + else + return a_union(types) + end + end + + facts_and = function(f1, f2, errnode) + if not f1 then + return f2 + end + if not f2 then + return f1 + end + + local out = {} + for v, fs in pairs(join_facts({ f1, f2 })) do + local ok, u = intersect_facts(fs, errnode) + + if ok then + table.insert(out, { fact = "is", var = v, typ = u }) + else + + for _, f in ipairs(fs) do + table.insert(out, f) + end + end + end + return out + end + + facts_or = function(f1, f2) + if not f1 or not f2 then + return nil + end + + local out = {} + for v, fs in pairs(join_facts({ f1, f2 })) do + local ok, u = sum_facts(fs) + if ok then + table.insert(out, { fact = "is", var = v, typ = u }) + else + + for _, f in ipairs(fs) do + table.insert(out, f) + end + end + end + return out + end + + facts_not = function(f1) + if not f1 then + return nil + end + + local out = {} + for v, fs in pairs(join_facts({ f1 })) do + local realtype = find_var(v) + if realtype then + local ok, u = sum_facts(fs) + if ok then + local not_typ = subtract_types(realtype, u, fs[1].typ) + table.insert(out, { fact = "is", var = v, typ = not_typ }) + end + end + end + return out + end + end + + local function apply_facts(where, facts) + if not facts then + return + end + for _, f in ipairs(facts) do + if f.fact == "is" then + local t = resolve_typevars(f.typ) + t.inferred_at = where + t.inferred_at_file = filename + add_var(nil, f.var, t, nil, true) + end + end + end + + local function dismiss_unresolved(name) + local unresolved = st[#st]["@unresolved"] + if unresolved then + if unresolved.t.nominals[name] then + for _, t in ipairs(unresolved.t.nominals[name]) do + resolve_nominal(t) + end + end + unresolved.t.nominals[name] = nil + end + end + + local function type_check_funcall(node, a, b, argdelta) + argdelta = argdelta or 0 + if node.e1.tk == "rawget" then + if #b == 2 then + local b1 = resolve_unary(b[1]) + local b2 = resolve_unary(b[2]) + local knode = node.e2[2] + if is_record_type(b1) and knode.conststr then + return match_record_key(node, b1, { y = knode.y, x = knode.x, kind = "string", tk = assert(knode.conststr) }, b1) + else + return type_check_index(node, knode, b1, b2) + end + else + node_error(node, "rawget expects two arguments") + end + elseif node.e1.tk == "print_type" then + print(show_type(b)) + return BOOLEAN + elseif node.e1.tk == "require" then + if #b == 1 then + if node.e2[1].kind == "string" then + local module_name = assert(node.e2[1].conststr) + local t, found = require_module(module_name, lax, opts.env, result) + if not found then + node_error(node, "module not found: '" .. module_name .. "'") + elseif not lax and is_unknown(t) then + node_error(node, "no type information for required module: '" .. module_name .. "'") + end + return t + else + node_error(node, "don't know how to resolve a dynamic require") + end + else + node_error(node, "require expects one literal argument") + end + elseif node.e1.tk == "pcall" then + local ftype = table.remove(b, 1) + local fe2 = {} + for i = 2, #node.e2 do + table.insert(fe2, node.e2[i]) + end + local fnode = { + y = node.y, + x = node.x, + typename = "op", + op = { op = "@funcall" }, + e1 = node.e2[1], + e2 = fe2, + } + local rets = type_check_funcall(fnode, ftype, b, argdelta + 1) + if rets.typename ~= "tuple" then + rets = a_type({ typename = "tuple", rets }) + end + table.insert(rets, 1, BOOLEAN) + return rets + elseif node.e1.op and node.e1.op.op == ":" then + local func = node.e1.type + if func.typename == "function" or func.typename == "poly" then + table.insert(b, 1, node.e1.e1.type) + return type_check_function_call(node, func, b, true) + else + if lax and (is_unknown(func)) then + if node.e1.e1.kind == "variable" then + add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) + end + return VARARG_UNKNOWN + else + return INVALID + end + end + else + return type_check_function_call(node, a, b, false, argdelta) + end + return UNKNOWN + end + + local visit_node = {} + + visit_node.cbs = { + ["statements"] = { + before = function() + begin_scope() + end, + after = function(node, children) + + if #st == 2 then + fail_unresolved() + end + + if not node.is_repeat then + end_scope() + end + + node.type = NONE + end, + }, + ["local_type"] = { + before = function(node) + add_var(node.var, node.var.tk, node.value.newtype, node.var.is_const) + end, + after = function(node, children) + dismiss_unresolved(node.var.tk) + node.type = NONE + end, + }, + ["global_type"] = { + before = function(node) + add_global(node.var, node.var.tk, node.value.newtype, node.var.is_const) + end, + after = function(node, children) + local existing, existing_is_const = find_global(node.var.tk) + local var = node.var + if existing then + if existing_is_const == true and not var.is_const then + node_error(var, "global was previously declared as : " .. var.tk) + end + if existing_is_const == false and var.is_const then + node_error(var, "global was previously declared as not : " .. var.tk) + end + if not same_type(existing, node.value.newtype) then + node_error(var, "cannot redeclare global with a different type: previous type of " .. var.tk .. " is %s", existing) + end + end + dismiss_unresolved(var.tk) + node.type = NONE + end, + }, + ["local_declaration"] = { + after = function(node, children) + local vals = get_assignment_values(children[2], #node.vars) + for i, var in ipairs(node.vars) do + local decltype = node.decltype and node.decltype[i] + local infertype = vals and vals[i] + if lax and infertype and infertype.typename == "nil" then + infertype = nil + end + if decltype and infertype then + assert_is_a(node.vars[i], infertype, decltype, "local declaration", var.tk) + end + local t = decltype or infertype + if t == nil then + t = a_type({ typename = "unknown" }) + if not lax then + if node.exps then + node_error(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. var.tk .. "'") + else + node_error(node.vars[i], "variable '" .. var.tk .. "' has no type or initial value") + end + end + elseif t.typename == "emptytable" then + t.declared_at = node + t.assigned_to = var.tk + end + assert(var) + add_var(var, var.tk, t, var.is_const) + + dismiss_unresolved(var.tk) + end + node.type = NONE + end, + }, + ["global_declaration"] = { + after = function(node, children) + local vals = get_assignment_values(children[2], #node.vars) + for i, var in ipairs(node.vars) do + local decltype = node.decltype and node.decltype[i] + local infertype = vals and vals[i] + if lax and infertype and infertype.typename == "nil" then + infertype = nil + end + if decltype and infertype then + assert_is_a(node.vars[i], infertype, decltype, "global declaration", var.tk) + end + local t = decltype or infertype + local existing, existing_is_const = find_global(var.tk) + if existing then + if infertype and existing_is_const then + node_error(var, "cannot reassign to global: " .. var.tk) + end + if existing_is_const == true and not var.is_const then + node_error(var, "global was previously declared as : " .. var.tk) + end + if existing_is_const == false and var.is_const then + node_error(var, "global was previously declared as not : " .. var.tk) + end + if not same_type(existing, t) then + node_error(var, "cannot redeclare global with a different type: previous type of " .. var.tk .. " is %s", existing) + end + else + if t == nil then + t = a_type({ typename = "unknown" }) + elseif t.typename == "emptytable" then + t.declared_at = node + t.assigned_to = var.tk + end + add_global(var, var.tk, t, var.is_const) + + dismiss_unresolved(var.tk) + end + end + node.type = NONE + end, + }, + ["assignment"] = { + after = function(node, children) + local vals = get_assignment_values(children[2], #children[1]) + local exps = flatten_list(vals) + for i, vartype in ipairs(children[1]) do + local varnode = node.vars[i] + if varnode.is_const then + node_error(varnode, "cannot assign to variable") + end + if varnode.kind == "variable" then + if widen_back_var(varnode.tk) then + vartype = find_var(varnode.tk) + end + end + if vartype then + local val = exps[i] + if resolve_unary(vartype).typename == "typetype" then + node_error(varnode, "cannot reassign a type") + elseif val then + assert_is_a(varnode, val, vartype, "assignment") + if varnode.kind == "variable" and vartype.typename == "union" then + + add_var(varnode, varnode.tk, val, false, true) + end + else + node_error(varnode, "variable is not being assigned a value") + end + else + node_error(varnode, "unknown variable") + end + end + node.type = NONE + end, + }, + ["do"] = { + after = function(node, children) + node.type = NONE + end, + }, + ["if"] = { + before_statements = function(node) + begin_scope() + apply_facts(node.exp, node.exp.facts) + end, + after = function(node, children) + end_scope() + node.type = NONE + end, + }, + ["elseif"] = { + before = function(node) + end_scope() + begin_scope() + end, + before_statements = function(node) + local f = facts_not(node.parent_if.exp.facts) + for e = 1, node.elseif_n - 1 do + f = facts_and(f, facts_not(node.parent_if.elseifs[e].exp.facts), node) + end + f = facts_and(f, node.exp.facts, node) + apply_facts(node.exp, f) + end, + after = function(node, children) + node.type = NONE + end, + }, + ["else"] = { + before = function(node) + end_scope() + begin_scope() + local f = facts_not(node.parent_if.exp.facts) + for _, elseifnode in ipairs(node.parent_if.elseifs) do + f = facts_and(f, facts_not(elseifnode.exp.facts), node) + end + apply_facts(node, f) + end, + after = function(node, children) + node.type = NONE + end, + }, + ["while"] = { + before = function() + + widen_all_unions() + end, + before_statements = function(node) + begin_scope() + apply_facts(node.exp, node.exp.facts) + end, + after = function(node, children) + end_scope() + node.type = NONE + end, + }, + ["label"] = { + before = function(node) + + widen_all_unions() + local label_id = "::" .. node.label .. "::" + if st[#st][label_id] then + node_error(node, "label '" .. node.label .. "' already defined at " .. filename) + end + local unresolved = st[#st]["@unresolved"] + if unresolved then + unresolved.t.labels[node.label] = nil + end + node.type = a_type({ y = node.y, x = node.x, typename = "none" }) + add_var(node, label_id, node.type) + end, + }, + ["goto"] = { + after = function(node, children) + if not find_var("::" .. node.label .. "::") then + local unresolved = st[#st]["@unresolved"] and st[#st]["@unresolved"].t + if not unresolved then + unresolved = { typename = "unresolved", labels = {}, nominals = {} } + add_var(node, "@unresolved", unresolved) + end + unresolved.labels[node.label] = unresolved.labels[node.label] or {} + table.insert(unresolved.labels[node.label], node) + end + node.type = NONE + end, + }, + ["repeat"] = { + before = function() + + widen_all_unions() + end, + after = function(node, children) + + end_scope() + node.type = NONE + end, + }, + ["forin"] = { + before = function() + begin_scope() + end, + before_statements = function(node) + local exp1 = node.exps[1] + local exp1type = resolve_tuple(exp1.type) + if exp1type.typename == "function" then + + if exp1.op and exp1.op.op == "@funcall" then + local t = resolve_unary(exp1.e2.type) + if exp1.e1.tk == "pairs" and not (t.typename == "map" or t.typename == "record") then + if not (lax and is_unknown(t)) then + node_error(exp1, "attempting pairs loop on something that's not a map or record: %s", exp1.e2.type) + end + elseif exp1.e1.tk == "ipairs" and not is_array_type(t) then + if not (lax and (is_unknown(t) or t.typename == "emptytable")) then + node_error(exp1, "attempting ipairs loop on something that's not an array: %s", exp1.e2.type) + end + end + end + local last + for i, v in ipairs(node.vars) do + local r = exp1type.rets[i] + if not r then + if last and last.is_va then + r = last + else + r = UNKNOWN + end + end + add_var(v, v.tk, r) + last = r + end + else + if not (lax and is_unknown(exp1type)) then + node_error(exp1, "expression in for loop does not return an iterator") + end + end + end, + after = function(node, children) + end_scope() + node.type = NONE + end, + }, + ["fornum"] = { + before = function(node) + begin_scope() + add_var(nil, node.var.tk, NUMBER) + end, + after = function(node, children) + end_scope() + node.type = NONE + end, + }, + ["return"] = { + after = function(node, children) + local rets = assert(find_var("@return")) + local nrets = #rets + local vatype + if nrets > 0 then + vatype = rets[nrets].is_va and rets[nrets] + end + + if #children[1] > nrets and (not lax) and not vatype then + rets.typename = "tuple" + children[1].typename = "tuple" + node_error(node, "excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1]) + end + + for i = 1, #children[1] do + local expected = rets[i] or vatype + if expected then + expected = resolve_unary(expected) + local where = (node.exps[i] and node.exps[i].x) and + node.exps[i] or + node.exps + assert(where and where.x) + assert_is_a(where, children[1][i], expected, "return value") + end + end + + + if #st == 2 then + module_type = resolve_unary(children[1]) + end + + node.type = NONE + end, + }, + ["variables"] = { + after = function(node, children) + node.type = children + + + local n = #children + if n > 0 and children[n].typename == "tuple" then + local tuple = children[n] + for i, c in ipairs(tuple) do + children[n + i - 1] = c + end + end + + node.type.typename = "tuple" + end, + }, + ["table_literal"] = { + after = function(node, children) + node.type = a_type({ + y = node.y, + x = node.x, + typename = "emptytable", + }) + local is_record = false + local is_array = false + local is_map = false + for i, child in ipairs(children) do + assert(child.typename == "table_item") + if child.kname then + is_record = true + if not node.type.fields then + node.type.fields = {} + node.type.field_order = {} + end + node.type.fields[child.kname] = child.vtype + table.insert(node.type.field_order, child.kname) + elseif child.ktype.typename == "number" then + is_array = true + if i == #children and node[i].key_parsed == "implicit" and child.vtype.typename == "tuple" then + + for _, c in ipairs(child.vtype) do + node.type.elements = expand_type(node, node.type.elements, c) + end + else + node.type.elements = expand_type(node, node.type.elements, child.vtype) + end + if not node.type.elements then + node_error(node, "cannot determine type of array elements") + is_array = false + end + else + is_map = true + node.type.keys = expand_type(node, node.type.keys, child.ktype) + node.type.values = expand_type(node, node.type.values, child.vtype) + end + end + if is_array and is_map then + node_error(node, "cannot determine type of table literal") + elseif is_record and is_array then + node.type.typename = "arrayrecord" + elseif is_record and is_map then + if node.type.keys.typename == "string" then + node.type.typename = "map" + for _, ftype in pairs(node.type.fields) do + node.type.values = expand_type(node, node.type.values, ftype) + end + node.type.fields = nil + node.type.field_order = nil + else + node_error(node, "cannot determine type of table literal") + end + elseif is_array then + node.type.typename = "array" + elseif is_record then + node.type.typename = "record" + elseif is_map then + node.type.typename = "map" + end + end, + }, + ["table_item"] = { + after = function(node, children) + local kname = node.key.conststr + local ktype = children[1] + local vtype = children[2] + if node.decltype then + vtype = node.decltype + assert_is_a(node.value, children[2], node.decltype, "table item") + end + node.type = a_type({ + y = node.y, + x = node.x, + typename = "table_item", + kname = kname, + ktype = ktype, + vtype = vtype, + }) + end, + }, + ["local_function"] = { + before = function(node) + begin_function_scope(node, true) + end, + after = function(node, children) + end_function_scope() + local rets = get_rets(children[3]) + + add_var(nil, node.name.tk, a_type({ + typename = "function", + args = children[2], + rets = rets, + })) + node.type = NONE + end, + }, + ["global_function"] = { + before = function(node) + begin_function_scope(node, true) + end, + after = function(node, children) + end_function_scope() + add_global(nil, node.name.tk, a_type({ + typename = "function", + args = children[2], + rets = get_rets(children[3]), + })) + node.type = NONE + end, + }, + ["record_function"] = { + before = function(node) + begin_function_scope(node) + end, + before_statements = function(node, children) + if node.is_method then + local rtype = get_self_type(children[1]) + children[3][1] = rtype + add_var(nil, "self", rtype) + end + + local rtype = resolve_unary(get_self_type(children[1])) + if rtype.typename == "emptytable" then + rtype.typename = "record" + end + if is_record_type(rtype) then + local fn_type = a_type({ + y = node.y, + x = node.x, + typename = "function", + is_method = node.is_method, + args = children[3], + rets = get_rets(children[4]), + }) + + local ok = false + if lax then + ok = true + elseif rtype.fields and rtype.fields[node.name.tk] and is_a(fn_type, rtype.fields[node.name.tk]) then + ok = true + elseif find_in_scope(node.fn_owner) == rtype then + ok = true + end + + if ok then + rtype.fields = rtype.fields or {} + rtype.field_order = rtype.field_order or {} + rtype.fields[node.name.tk] = fn_type + table.insert(rtype.field_order, node.name.tk) + else + local name = tl.pretty_print_ast(node.fn_owner, { preserve_indent = true, preserve_newlines = false }) + node_error(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. name .. "' was originally declared") + end + else + if (not lax) or (rtype.typename ~= "unknown") then + node_error(node, "not a module: %s", rtype) + end + end + end, + after = function(node, children) + end_function_scope() + + node.type = NONE + end, + }, + ["function"] = { + before = function(node) + begin_function_scope(node) + end, + after = function(node, children) + end_function_scope() + + + node.type = a_type({ + y = node.y, + x = node.x, + typename = "function", + args = children[1], + rets = children[2], + }) + end, + }, + ["cast"] = { + after = function(node, children) + node.type = node.casttype + end, + }, + ["paren"] = { + after = function(node, children) + node.type = resolve_unary(children[1]) + end, + }, + ["op"] = { + before = function(node) + begin_scope() + end, + before_e2 = function(node) + if node.op.op == "and" then + apply_facts(node, node.e1.facts) + elseif node.op.op == "or" then + apply_facts(node, facts_not(node.e1.facts)) + end + end, + after = function(node, children) + end_scope() + + local a = children[1] + local b = children[3] + + local orig_a = a + local orig_b = b + local ua = a and resolve_unary(a) + local ub = b and resolve_unary(b) + if node.op.op == "@funcall" then + node.type = type_check_funcall(node, a, b) + elseif node.op.op == "@index" then + node.type = type_check_index(node, node.e2, a, b) + elseif node.op.op == "as" then + node.type = b + elseif node.op.op == "is" then + if node.e1.kind == "variable" then + node.facts = { { fact = "is", var = node.e1.tk, typ = b } } + else + node_error(node, "can only use 'is' on variables") + end + node.type = BOOLEAN + elseif node.op.op == "." then + a = ua + if a.typename == "map" then + if is_a(a.keys, STRING) or is_a(a.keys, ANY) then + node.type = a.values + else + node_error(node, "cannot use . index, expects keys of type %s", a.keys) + end + else + node.type = match_record_key(node, a, { y = node.e2.y, x = node.e2.x, kind = "string", tk = node.e2.tk }, orig_a) + if node.type.needs_compat53 and not opts.skip_compat53 then + local key = node.e1.tk .. "." .. node.e2.tk + node.kind = "variable" + node.tk = "_tl_" .. node.e1.tk .. "_" .. node.e2.tk + all_needs_compat53[key] = true + end + end + elseif node.op.op == ":" then + node.type = match_record_key(node, node.e1.type, node.e2, orig_a) + elseif node.op.op == "not" then + node.facts = facts_not(node.e1.facts) + node.type = BOOLEAN + elseif node.op.op == "and" then + node.facts = facts_and(node.e1.facts, node.e2.facts, node) + node.type = resolve_tuple(b) + elseif node.op.op == "or" and b.typename == "emptytable" then + node.facts = nil + node.type = resolve_tuple(a) + elseif node.op.op == "or" and same_type(ua, ub) then + node.facts = facts_or(node.e1.facts, node.e2.facts) + node.type = resolve_tuple(a) + elseif node.op.op == "or" and b.typename == "nil" then + node.facts = nil + node.type = resolve_tuple(a) + elseif node.op.op == "or" and + ((ua.typename == "enum" and ub.typename == "string" and is_a(ub, ua)) or + (ua.typename == "string" and ub.typename == "enum" and is_a(ua, ub))) then + node.facts = nil + node.type = (ua.typename == "enum" and ua or ub) + elseif node.op.op == "or" and + (a.typename == "nominal" or a.typename == "map") and + is_record_type(b) and + is_a(b, a) then + node.facts = nil + node.type = resolve_tuple(a) + elseif node.op.op == "==" or node.op.op == "~=" then + if is_a(a, b, true) or is_a(b, a, true) then + node.type = BOOLEAN + else + if lax and (is_unknown(a) or is_unknown(b)) then + node.type = UNKNOWN + else + node_error(node, "types are not comparable for equality: %s and %s", a, b) + end + end + elseif node.op.arity == 1 and unop_types[node.op.op] then + a = ua + local types_op = unop_types[node.op.op] + node.type = types_op[a.typename] + if not node.type then + if lax and is_unknown(a) then + node.type = UNKNOWN + else + node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", orig_a) + end + end + elseif node.op.arity == 2 and binop_types[node.op.op] then + if node.op.op == "or" then + node.facts = facts_or(node.e1.facts, node.e2.facts) + end + + a = ua + b = ub + local types_op = binop_types[node.op.op] + node.type = types_op[a.typename] and types_op[a.typename][b.typename] + if not node.type then + if lax and (is_unknown(a) or is_unknown(b)) then + node.type = UNKNOWN + else + node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", orig_a, orig_b) + end + end + else + error("unknown node op " .. node.op.op) + end + end, + }, + ["variable"] = { + after = function(node, children) + if node.tk == "..." then + local va_sentinel = find_var("@is_va") + if not va_sentinel or va_sentinel.typename == "nil" then + node.type = UNKNOWN + node_error(node, "cannot use '...' outside a vararg function") + end + end + + node.type, node.is_const = find_var(node.tk) + if node.type == nil then + node.type = a_type({ typename = "unknown" }) + if lax then + add_unknown(node, node.tk) + else + node_error(node, "unknown variable: " .. node.tk) + end + end + end, + }, + ["identifier"] = { + after = function(node, children) + node.type = NONE + end, + }, + ["newtype"] = { + after = function(node, children) + node.type = node.newtype + end, + }, + } + + visit_node.cbs["break"] = visit_node.cbs["do"] + + visit_node.cbs["values"] = visit_node.cbs["variables"] + visit_node.cbs["expression_list"] = visit_node.cbs["variables"] + visit_node.cbs["argument_list"] = visit_node.cbs["variables"] + visit_node.cbs["argument"] = visit_node.cbs["variable"] + + visit_node.cbs["string"] = { + after = function(node, children) + node.type = a_type({ + y = node.y, + x = node.x, + typename = node.kind, + tk = node.tk, + }) + return node.type + end, + } + visit_node.cbs["number"] = visit_node.cbs["string"] + visit_node.cbs["nil"] = visit_node.cbs["string"] + visit_node.cbs["boolean"] = visit_node.cbs["string"] + visit_node.cbs["..."] = visit_node.cbs["variable"] + + visit_node.after = { + after = function(node, children) + assert(type(node.type) == "table", node.kind .. " did not produce a type") + assert(type(node.type.typename) == "string", node.kind .. " type does not have a typename") + return node.type + end, + } + + local visit_type = { + cbs = { + ["string"] = { + after = function(typ, children) + return typ + end, + }, + ["function"] = { + before = function(typ, children) + begin_scope() + end, + after = function(typ, children) + end_scope() + return typ + end, + }, + ["record"] = { + before = function(typ, children) + begin_scope() + for name, typ in pairs(typ.fields) do + if typ.typename == "typetype" then + typ.typename = "nestedtype" + add_var(nil, name, typ) + end + end + end, + after = function(typ, children) + end_scope() + for name, typ in pairs(typ.fields) do + if typ.typename == "nestedtype" then + typ.typename = "typetype" + end + end + return typ + end, + }, + ["typearg"] = { + after = function(typ, children) + add_var(nil, typ.typearg, a_type({ + y = typ.y, + x = typ.x, + typename = "typearg", + typearg = typ.typearg, + })) + return typ + end, + }, + ["nominal"] = { + after = function(typ, children) + local t = find_type(typ.names, true) + if t then + if t.typename == "typearg" then + + typ.names = nil + typ.typename = "typevar" + typ.typevar = t.typearg + else + typ.found = t + end + else + local name = typ.names[1] + local unresolved = find_var("@unresolved") + if not unresolved then + unresolved = { typename = "unresolved", labels = {}, nominals = {} } + add_var(nil, "@unresolved", unresolved) + end + unresolved.nominals[name] = unresolved.nominals[name] or {} + table.insert(unresolved.nominals[name], typ) + end + return typ + end, + }, + ["union"] = { + after = function(typ, children) + + + local n_table_types = 0 + local n_function_types = 0 + local n_string_enum = 0 + for _, t in ipairs(typ.types) do + t = resolve_unary(t) + if table_types[t.typename] then + n_table_types = n_table_types + 1 + if n_table_types > 1 then + type_error(typ, "cannot discriminate a union between multiple table types: %s", typ) + break + end + elseif t.typename == "function" then + n_function_types = n_function_types + 1 + if n_function_types > 1 then + type_error(typ, "cannot discriminate a union between multiple function types: %s", typ) + break + end + elseif t.typename == "string" or t.typename == "enum" then + n_string_enum = n_string_enum + 1 + if n_string_enum > 1 then + type_error(typ, "cannot discriminate a union between multiple string/enum types: %s", typ) + break + end + end + end + return typ + end, + }, + }, + after = { + after = function(typ, children, ret) + assert(type(ret) == "table", typ.typename .. " did not produce a type") + assert(type(ret.typename) == "string", "type node does not have a typename") + return ret + end, + }, + } + + visit_type.cbs["typetype"] = visit_type.cbs["string"] + visit_type.cbs["nestedtype"] = visit_type.cbs["string"] + visit_type.cbs["typevar"] = visit_type.cbs["string"] + visit_type.cbs["array"] = visit_type.cbs["string"] + visit_type.cbs["map"] = visit_type.cbs["string"] + visit_type.cbs["arrayrecord"] = visit_type.cbs["string"] + visit_type.cbs["enum"] = visit_type.cbs["string"] + visit_type.cbs["boolean"] = visit_type.cbs["string"] + visit_type.cbs["nil"] = visit_type.cbs["string"] + visit_type.cbs["number"] = visit_type.cbs["string"] + visit_type.cbs["thread"] = visit_type.cbs["string"] + visit_type.cbs["bad_nominal"] = visit_type.cbs["string"] + visit_type.cbs["emptytable"] = visit_type.cbs["string"] + visit_type.cbs["table_item"] = visit_type.cbs["string"] + visit_type.cbs["unknown_emptytable_value"] = visit_type.cbs["string"] + visit_type.cbs["tuple"] = visit_type.cbs["string"] + visit_type.cbs["poly"] = visit_type.cbs["string"] + visit_type.cbs["any"] = visit_type.cbs["string"] + visit_type.cbs["unknown"] = visit_type.cbs["string"] + visit_type.cbs["invalid"] = visit_type.cbs["string"] + visit_type.cbs["unresolved"] = visit_type.cbs["string"] + visit_type.cbs["none"] = visit_type.cbs["string"] + + recurse_node(ast, visit_node, visit_type) + + close_types(st[1]) + + local redundant = {} + local lastx, lasty = 0, 0 + table.sort(errors, function(a, b) + return ((a.filename and b.filename) and a.filename < b.filename) or + (a.filename == b.filename and ((a.y < b.y) or (a.y == b.y and a.x < b.x))) + end) + for i, err in ipairs(errors) do + if err.x == lastx and err.y == lasty then + table.insert(redundant, i) + end + lastx, lasty = err.x, err.y + end + for i = #redundant, 1, -1 do + table.remove(errors, redundant[i]) + end + + if not opts.skip_compat53 then + add_compat53_entries(ast, all_needs_compat53) + end + + return errors, unknowns, module_type +end + +function tl.process(filename, env, result, preload_modules) + local fd, err = io.open(filename, "r") + if not fd then + return nil, "could not open " .. filename .. ": " .. err + end + + local input, err = fd:read("*a") + fd:close() + if not input then + return nil, "could not read " .. filename .. ": " .. err + end + + local basename, extension = filename:match("(.*)%.([a-z]+)$") + extension = extension and extension:lower() + + local is_lua + if extension == "tl" then + is_lua = false + elseif extension == "lua" then + is_lua = true + else + is_lua = input:match("^#![^\n]*lua[^\n]*\n") + end + + result, err = tl.process_string(input, is_lua, env, result, preload_modules, filename) + + if err then + return nil, err + end + + return result +end + +function tl.process_string(input, is_lua, env, result, preload_modules, +filename) + + env = env or tl.init_env(is_lua) + result = result or { + syntax_errors = {}, + type_errors = {}, + unknowns = {}, + } + preload_modules = preload_modules or {} + filename = filename or "" + + local tokens, errs = tl.lex(input) + if errs then + for i, err in ipairs(errs) do + table.insert(result.syntax_errors, { + y = err.y, + x = err.x, + msg = "invalid token '" .. err.tk .. "'", + filename = filename, + }) + end + end + + local i, program = tl.parse_program(tokens, result.syntax_errors, filename) + if #result.syntax_errors > 0 then + return result + end + + + for _, name in ipairs(preload_modules) do + local module_type = require_module(name, is_lua, env, result) + + if module_type == UNKNOWN then + return nil, string.format("Error: could not preload module '%s'", name) + end + end + + local error, unknown + local opts = { + lax = is_lua, + filename = filename, + env = env, + result = result, + skip_compat53 = env.skip_compat53, + } + error, unknown, result.type = tl.type_check(program, opts) + + result.ast = program + result.env = env + + return result +end + +function tl.gen(input, env) + env = env or tl.init_env() + local result, err = tl.process_string(input, false, env) + + if err then + return nil, nil + end + + if not result.ast then + return nil, result + end + + return tl.pretty_print_ast(result.ast), result +end + +local function tl_package_loader(module_name) + local found_filename, fd, tried = tl.search_module(module_name, false) + if found_filename then + local input = fd:read("*a") + fd:close() + local errs = {} + local _, program = tl.parse_program(tl.lex(input), errs, module_name) + if #errs > 0 then + error(module_name .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg) + end + local code = tl.pretty_print_ast(program, true) + local chunk, err = load(code, module_name, "t") + if chunk then + return function() + local ret = chunk() + package.loaded[module_name] = ret + return ret + end + else + error("Internal Compiler Error: Teal generator produced invalid Lua. Please report a bug at https://github.com/teal-language/tl") + end + end + return table.concat(tried, "\n\t") +end + +function tl.loader() + if package.searchers then + table.insert(package.searchers, 2, tl_package_loader) + else + table.insert(package.loaders, 2, tl_package_loader) + end +end + +function tl.load(input, chunkname, mode, env) + local tokens = tl.lex(input) + local errs = {} + local i, program = tl.parse_program(tokens, errs, chunkname) + if #errs > 0 then + return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg + end + local code = tl.pretty_print_ast(program, true) + return load(code, chunkname, mode, env) +end + +return tl -- cgit v1.2.3-55-g6feb