summaryrefslogtreecommitdiff
path: root/spec
diff options
context:
space:
mode:
authorLi Jin <dragon-fly@qq.com>2020-10-23 17:32:48 +0800
committerLi Jin <dragon-fly@qq.com>2020-10-23 17:32:48 +0800
commit1c6a9651beffd9cbbb3641179f3a738d5555d3c9 (patch)
tree31f442c3685bf19e851f62c4b2957d445e313eac /spec
parenta51a728d847e790329e41c75928a81630200b63f (diff)
downloadyuescript-1c6a9651beffd9cbbb3641179f3a738d5555d3c9.tar.gz
yuescript-1c6a9651beffd9cbbb3641179f3a738d5555d3c9.tar.bz2
yuescript-1c6a9651beffd9cbbb3641179f3a738d5555d3c9.zip
make teal-macro look better.
Diffstat (limited to 'spec')
-rw-r--r--spec/inputs/macro-teal.mp34
-rw-r--r--spec/inputs/teal-lang.mp22
-rw-r--r--spec/lib/tl.lua6801
3 files changed, 6844 insertions, 13 deletions
diff --git a/spec/inputs/macro-teal.mp b/spec/inputs/macro-teal.mp
index 20444e1..9ce1bcd 100644
--- a/spec/inputs/macro-teal.mp
+++ b/spec/inputs/macro-teal.mp
@@ -2,6 +2,7 @@ $ ->
2 import "moonp" as {:options} 2 import "moonp" as {:options}
3 if options.tl_enabled 3 if options.tl_enabled
4 options.target_extension = "tl" 4 options.target_extension = "tl"
5 package.path ..= "?.lua;./spec/lib/?.lua"
5 6
6macro expr to_lua = (codes)-> 7macro expr to_lua = (codes)->
7 "require('moonp').to_lua(#{codes}, reserve_line_number:false, same_module:true)" 8 "require('moonp').to_lua(#{codes}, reserve_line_number:false, same_module:true)"
@@ -9,19 +10,31 @@ macro expr to_lua = (codes)->
9macro expr trim = (name)-> 10macro expr trim = (name)->
10 "if result = #{name}\\match '[\\'\"](.*)[\\'\"]' then result else #{name}" 11 "if result = #{name}\\match '[\\'\"](.*)[\\'\"]' then result else #{name}"
11 12
12export macro text var = (name, type, value = nil)-> 13export macro text local = (decl, value = nil)->
13 import "moonp" as {options:{:tl_enabled}} 14 import "moonp" as {options:{:tl_enabled}}
15 name, type = ($trim decl)\match "(.-):(.*)"
16 if not (name and type)
17 error "invalid local varaible declaration for \"#{decl}\""
14 value = $to_lua(value)\gsub "^return ", "" 18 value = $to_lua(value)\gsub "^return ", ""
15 if tl_enabled 19 if tl_enabled
16 "local #{name}:#{$trim type} = #{value}", {name} 20 "local #{name}:#{$trim type} = #{value}", {name}
17 else 21 else
18 "local #{name} = #{value}", {name} 22 "local #{name} = #{value}", {name}
19 23
20export macro text def = (name, type, value)-> 24export macro text function = (decl, value)->
21 import "moonp" as {options:{:tl_enabled}} 25 import "moonp" as {options:{:tl_enabled}}
26 import "tl"
27 decl = $trim decl
28 name, type = decl\match "(.-)(%(.*)"
29 if not (name and type)
30 error "invalid function declaration for \"#{decl}\""
31 tokens = tl.lex "function #{decl}"
32 _, node = tl.parse_program tokens,{},"macro-function"
33 args = table.concat [arg.tk for arg in *node[1].args],", "
34 value = "(#{args})#{value}"
22 if tl_enabled 35 if tl_enabled
23 value = $to_lua(value)\match "function%(.*%)(.*)end" 36 value = $to_lua(value)\match "function%([^\n]*%)(.*)end"
24 "local function #{name}#{$trim type}\n#{value}\nend", {name} 37 "local function #{name}#{type}\n#{value}\nend", {name}
25 else 38 else
26 value = $to_lua(value)\gsub "^return ", "" 39 value = $to_lua(value)\gsub "^return ", ""
27 "local #{name} = #{value}", {name} 40 "local #{name} = #{value}", {name}
@@ -35,11 +48,20 @@ end", {name}
35 else 48 else
36 "local #{name} = {}", {name} 49 "local #{name} = {}", {name}
37 50
38export macro text field = (tab, sym, func, type, value)-> 51export macro text method = (decl, value)->
39 import "moonp" as {options:{:tl_enabled}} 52 import "moonp" as {options:{:tl_enabled}}
53 import "tl"
54 decl = $trim decl
55 tab, sym, func, type = decl\match "(.-)([%.:])(.-)(%(.*)"
56 if not (tab and sym and func and type)
57 error "invalid method declaration for \"#{decl}\""
58 tokens = tl.lex "function #{decl}"
59 _, node = tl.parse_program tokens,{},"macro-function"
60 args = table.concat [arg.tk for arg in *node[1].args],", "
61 value = "(#{args})->#{value\match "[%-=]>(.*)"}"
40 if tl_enabled 62 if tl_enabled
41 value = $to_lua(value)\match "^return function%(.-%)\n(.*)end" 63 value = $to_lua(value)\match "^return function%(.-%)\n(.*)end"
42 "function #{tab}#{$trim sym}#{func}#{$trim type}\n#{value}\nend" 64 "function #{tab}#{sym}#{func}#{type}\n#{value}\nend"
43 else 65 else
44 value = $to_lua(value)\gsub "^return ", "" 66 value = $to_lua(value)\gsub "^return ", ""
45 "#{tab}.#{func} = #{value}" 67 "#{tab}.#{func} = #{value}"
diff --git a/spec/inputs/teal-lang.mp b/spec/inputs/teal-lang.mp
index 3c9c79b..29769d5 100644
--- a/spec/inputs/teal-lang.mp
+++ b/spec/inputs/teal-lang.mp
@@ -3,10 +3,10 @@ $ ->
3 3
4import "macro-teal" as {$} 4import "macro-teal" as {$}
5 5
6$var a, "{string:number}", {value:123} 6$local "a:{string:number}", {value:123}
7$var b, "number", a.value 7$local "b:number", a.value
8 8
9$def add, "(a:number,b:number):number", (a, b)-> a + b 9$function "add(a:number, b:number):number", -> a + b
10 10
11s = add(a.value, b) 11s = add(a.value, b)
12print(s) 12print(s)
@@ -16,17 +16,25 @@ $record Point, [[
16 y: number 16 y: number
17]] 17]]
18 18
19$field Point, '.', new, "(x: number, y: number):Point", (x, y)-> 19$method "Point.new(x:number, y:number):Point", ->
20 $var point, "Point", setmetatable {}, __index: Point 20 $local "point:Point", setmetatable {}, __index: Point
21 point.x = x or 0 21 point.x = x or 0
22 point.y = y or 0 22 point.y = y or 0
23 point 23 point
24 24
25$field Point, ":", move, "(dx: number, dy: number)", (dx, dy)=> 25$method "Point:move(dx:number, dy:number)", ->
26 @x += dx 26 @x += dx
27 @y += dy 27 @y += dy
28 28
29$var p, "Point", Point.new 100, 100 29$local "p:Point", Point.new 100, 100
30 30
31p\move 50, 50 31p\move 50, 50
32 32
33$function "filter(tab:{string}, handler:function(item:string):boolean):{string}", ->
34 [item for item in *tab when handler item]
35
36$function "cond(item:string):boolean", -> item ~= "a"
37
38res = filter {"a", "b", "c", "a"}, cond
39for s in *res
40 print s
diff --git a/spec/lib/tl.lua b/spec/lib/tl.lua
new file mode 100644
index 0000000..aca748c
--- /dev/null
+++ b/spec/lib/tl.lua
@@ -0,0 +1,6801 @@
1local _tl_compat53 = ((tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3) and require('compat53.module'); local assert = _tl_compat53 and _tl_compat53.assert or assert; local io = _tl_compat53 and _tl_compat53.io or io; local ipairs = _tl_compat53 and _tl_compat53.ipairs or ipairs; local load = _tl_compat53 and _tl_compat53.load or load; local math = _tl_compat53 and _tl_compat53.math or math; local os = _tl_compat53 and _tl_compat53.os or os; local package = _tl_compat53 and _tl_compat53.package or package; local pairs = _tl_compat53 and _tl_compat53.pairs or pairs; local string = _tl_compat53 and _tl_compat53.string or string; local table = _tl_compat53 and _tl_compat53.table or table; local _tl_table_unpack = unpack or table.unpack; local Env = {}
2
3
4
5
6
7local TypeCheckOptions = {}
8
9
10
11
12
13
14
15local LoadMode = {}
16
17
18
19
20local LoadFunction = {}
21
22local tl = {
23 load = nil,
24 process = nil,
25 process_string = nil,
26 gen = nil,
27 type_check = nil,
28 init_env = nil,
29}
30
31
32
33
34
35
36
37local inspect = function(x)
38 return tostring(x)
39end
40
41local keywords = {
42 ["and"] = true,
43 ["break"] = true,
44 ["do"] = true,
45 ["else"] = true,
46 ["elseif"] = true,
47 ["end"] = true,
48 ["false"] = true,
49 ["for"] = true,
50 ["function"] = true,
51 ["goto"] = true,
52 ["if"] = true,
53 ["in"] = true,
54 ["local"] = true,
55 ["nil"] = true,
56 ["not"] = true,
57 ["or"] = true,
58 ["repeat"] = true,
59 ["return"] = true,
60 ["then"] = true,
61 ["true"] = true,
62 ["until"] = true,
63 ["while"] = true,
64
65
66}
67
68local TokenKind = {}
69
70
71
72
73
74
75
76
77
78
79
80
81local Token = {}
82
83
84
85
86
87
88
89local lex_word_start = {}
90for c = string.byte("a"), string.byte("z") do
91 lex_word_start[string.char(c)] = true
92end
93for c = string.byte("A"), string.byte("Z") do
94 lex_word_start[string.char(c)] = true
95end
96lex_word_start["_"] = true
97
98local lex_word = {}
99for c = string.byte("a"), string.byte("z") do
100 lex_word[string.char(c)] = true
101end
102for c = string.byte("A"), string.byte("Z") do
103 lex_word[string.char(c)] = true
104end
105for c = string.byte("0"), string.byte("9") do
106 lex_word[string.char(c)] = true
107end
108lex_word["_"] = true
109
110local lex_decimal_start = {}
111for c = string.byte("1"), string.byte("9") do
112 lex_decimal_start[string.char(c)] = true
113end
114
115local lex_decimals = {}
116for c = string.byte("0"), string.byte("9") do
117 lex_decimals[string.char(c)] = true
118end
119
120local lex_hexadecimals = {}
121for c = string.byte("0"), string.byte("9") do
122 lex_hexadecimals[string.char(c)] = true
123end
124for c = string.byte("a"), string.byte("f") do
125 lex_hexadecimals[string.char(c)] = true
126end
127for c = string.byte("A"), string.byte("F") do
128 lex_hexadecimals[string.char(c)] = true
129end
130
131local lex_char_symbols = {}
132for _, c in ipairs({ "[", "]", "(", ")", "{", "}", ",", "#", "`", ";" }) do
133 lex_char_symbols[c] = true
134end
135
136local lex_op_start = {}
137for _, c in ipairs({ "+", "*", "/", "|", "&", "%", "^" }) do
138 lex_op_start[c] = true
139end
140
141local lex_space = {}
142for _, c in ipairs({ " ", "\t", "\v", "\n", "\r" }) do
143 lex_space[c] = true
144end
145
146local LexState = {}
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179function tl.lex(input)
180 local tokens = {}
181
182 local state = "start"
183 local fwd = true
184 local y = 1
185 local x = 0
186 local i = 0
187 local lc_open_lvl = 0
188 local lc_close_lvl = 0
189 local ls_open_lvl = 0
190 local ls_close_lvl = 0
191 local errs = {}
192
193 local tx
194 local ty
195 local ti
196 local in_token = false
197
198 local function begin_token()
199 tx = x
200 ty = y
201 ti = i
202 in_token = true
203 end
204
205 local function end_token(kind, last, t)
206 local tk = t or input:sub(ti, last or i) or ""
207 if keywords[tk] then
208 kind = "keyword"
209 end
210 table.insert(tokens, {
211 x = tx,
212 y = ty,
213 i = ti,
214 tk = tk,
215 kind = kind,
216 })
217 in_token = false
218 end
219
220 local function drop_token()
221 in_token = false
222 end
223
224 while i <= #input do
225 if fwd then
226 i = i + 1
227 if i > #input then
228 break
229 end
230 end
231
232 local c = input:sub(i, i)
233
234 if fwd then
235 if c == "\n" then
236 y = y + 1
237 x = 0
238 else
239 x = x + 1
240 end
241 else
242 fwd = true
243 end
244
245 if state == "start" then
246 if input:sub(1, 2) == "#!" then
247 i = input:find("\n")
248 if not i then
249 break
250 end
251 c = "\n"
252 y = 2
253 x = 0
254 end
255 state = "any"
256 end
257
258 if state == "any" then
259 if c == "-" then
260 state = "maybecomment"
261 begin_token()
262 elseif c == "." then
263 state = "maybedotdot"
264 begin_token()
265 elseif c == "\"" then
266 state = "dblquote_string"
267 begin_token()
268 elseif c == "'" then
269 state = "singlequote_string"
270 begin_token()
271 elseif lex_word_start[c] then
272 state = "identifier"
273 begin_token()
274 elseif c == "0" then
275 state = "decimal_or_hex"
276 begin_token()
277 elseif lex_decimal_start[c] then
278 state = "decimal_number"
279 begin_token()
280 elseif c == "<" then
281 state = "lt"
282 begin_token()
283 elseif c == ":" then
284 state = "colon"
285 begin_token()
286 elseif c == ">" then
287 state = "gt"
288 begin_token()
289 elseif c == "=" or c == "~" then
290 state = "maybeequals"
291 begin_token()
292 elseif c == "[" then
293 state = "maybelongstring"
294 begin_token()
295 elseif lex_char_symbols[c] then
296 begin_token()
297 end_token(c)
298 elseif lex_op_start[c] then
299 begin_token()
300 end_token("op")
301 elseif lex_space[c] then
302
303 else
304 begin_token()
305 end_token("$invalid$")
306 table.insert(errs, tokens[#tokens])
307 end
308 elseif state == "maybecomment" then
309 if c == "-" then
310 state = "maybecomment2"
311 else
312 end_token("op", nil, "-")
313 fwd = false
314 state = "any"
315 end
316 elseif state == "maybecomment2" then
317 if c == "[" then
318 state = "maybelongcomment"
319 else
320 fwd = false
321 state = "comment"
322 drop_token()
323 end
324 elseif state == "maybelongcomment" then
325 if c == "[" then
326 state = "longcomment"
327 elseif c == "=" then
328 lc_open_lvl = lc_open_lvl + 1
329 else
330 fwd = false
331 state = "comment"
332 drop_token()
333 lc_open_lvl = 0
334 end
335 elseif state == "longcomment" then
336 if c == "]" then
337 state = "maybelongcommentend"
338 end
339 elseif state == "maybelongcommentend" then
340 if c == "]" and lc_close_lvl == lc_open_lvl then
341 drop_token()
342 state = "any"
343 lc_open_lvl = 0
344 lc_close_lvl = 0
345 elseif c == "=" then
346 lc_close_lvl = lc_close_lvl + 1
347 else
348 state = "longcomment"
349 lc_close_lvl = 0
350 end
351 elseif state == "dblquote_string" then
352 if c == "\\" then
353 state = "escape_dblquote_string"
354 elseif c == "\"" then
355 end_token("string")
356 state = "any"
357 end
358 elseif state == "escape_dblquote_string" then
359 state = "dblquote_string"
360 elseif state == "singlequote_string" then
361 if c == "\\" then
362 state = "escape_singlequote_string"
363 elseif c == "'" then
364 end_token("string")
365 state = "any"
366 end
367 elseif state == "escape_singlequote_string" then
368 state = "singlequote_string"
369 elseif state == "maybeequals" then
370 if c == "=" then
371 end_token("op")
372 state = "any"
373 else
374 end_token("op", i - 1)
375 fwd = false
376 state = "any"
377 end
378 elseif state == "lt" then
379 if c == "=" or c == "<" then
380 end_token("op")
381 state = "any"
382 else
383 end_token("op", i - 1)
384 fwd = false
385 state = "any"
386 end
387 elseif state == "colon" then
388 if c == ":" then
389 end_token("::")
390 state = "any"
391 else
392 end_token(":", i - 1)
393 fwd = false
394 state = "any"
395 end
396 elseif state == "gt" then
397 if c == "=" or c == ">" then
398 end_token("op")
399 state = "any"
400 else
401 end_token("op", i - 1)
402 fwd = false
403 state = "any"
404 end
405 elseif state == "maybelongstring" then
406 if c == "[" then
407 state = "longstring"
408 elseif c == "=" then
409 ls_open_lvl = ls_open_lvl + 1
410 else
411 end_token("[", i - 1)
412 fwd = false
413 state = "any"
414 ls_open_lvl = 0
415 end
416 elseif state == "longstring" then
417 if c == "]" then
418 state = "maybelongstringend"
419 end
420 elseif state == "maybelongstringend" then
421 if c == "]" then
422 if ls_close_lvl == ls_open_lvl then
423 end_token("string")
424 state = "any"
425 ls_open_lvl = 0
426 ls_close_lvl = 0
427 end
428 elseif c == "=" then
429 ls_close_lvl = ls_close_lvl + 1
430 else
431 state = "longstring"
432 ls_close_lvl = 0
433 end
434 elseif state == "maybedotdot" then
435 if c == "." then
436 state = "maybedotdotdot"
437 elseif lex_decimals[c] then
438 state = "decimal_float"
439 else
440 end_token(".", i - 1)
441 fwd = false
442 state = "any"
443 end
444 elseif state == "maybedotdotdot" then
445 if c == "." then
446 end_token("...")
447 state = "any"
448 else
449 end_token("op", i - 1)
450 fwd = false
451 state = "any"
452 end
453 elseif state == "comment" then
454 if c == "\n" then
455 state = "any"
456 end
457 elseif state == "identifier" then
458 if not lex_word[c] then
459 end_token("identifier", i - 1)
460 fwd = false
461 state = "any"
462 end
463 elseif state == "decimal_or_hex" then
464 if c == "x" or c == "X" then
465 state = "hex_number"
466 elseif c == "e" or c == "E" then
467 state = "power_sign"
468 elseif lex_decimals[c] then
469 state = "decimal_number"
470 elseif c == "." then
471 state = "decimal_float"
472 else
473 end_token("number", i - 1)
474 fwd = false
475 state = "any"
476 end
477 elseif state == "hex_number" then
478 if c == "." then
479 state = "hex_float"
480 elseif c == "p" or c == "P" then
481 state = "power_sign"
482 elseif not lex_hexadecimals[c] then
483 end_token("number", i - 1)
484 fwd = false
485 state = "any"
486 end
487 elseif state == "hex_float" then
488 if c == "p" or c == "P" then
489 state = "power_sign"
490 elseif not lex_hexadecimals[c] then
491 end_token("number", i - 1)
492 fwd = false
493 state = "any"
494 end
495 elseif state == "decimal_number" then
496 if c == "." then
497 state = "decimal_float"
498 elseif c == "e" or c == "E" then
499 state = "power_sign"
500 elseif not lex_decimals[c] then
501 end_token("number", i - 1)
502 fwd = false
503 state = "any"
504 end
505 elseif state == "decimal_float" then
506 if c == "e" or c == "E" then
507 state = "power_sign"
508 elseif not lex_decimals[c] then
509 end_token("number", i - 1)
510 fwd = false
511 state = "any"
512 end
513 elseif state == "power_sign" then
514 if c == "-" or c == "+" then
515 state = "power"
516 elseif lex_decimals[c] then
517 state = "power"
518 else
519 end_token("$invalid$")
520 table.insert(errs, tokens[#tokens])
521 state = "any"
522 end
523 elseif state == "power" then
524 if not lex_decimals[c] then
525 end_token("number", i - 1)
526 fwd = false
527 state = "any"
528 end
529 end
530 end
531
532 local terminals = {
533 ["identifier"] = "identifier",
534 ["decimal_or_hex"] = "number",
535 ["decimal_number"] = "number",
536 ["decimal_float"] = "number",
537 ["hex_number"] = "number",
538 ["hex_float"] = "number",
539 ["power"] = "number",
540 }
541
542 if in_token then
543 if terminals[state] then
544 end_token(terminals[state], i - 1)
545 else
546 drop_token()
547 end
548 end
549
550 return tokens, (#errs > 0) and errs
551end
552
553
554
555
556
557local add_space = {
558 ["word:keyword"] = true,
559 ["word:word"] = true,
560 ["word:string"] = true,
561 ["word:="] = true,
562 ["word:op"] = true,
563
564 ["keyword:word"] = true,
565 ["keyword:keyword"] = true,
566 ["keyword:string"] = true,
567 ["keyword:number"] = true,
568 ["keyword:="] = true,
569 ["keyword:op"] = true,
570 ["keyword:{"] = true,
571 ["keyword:("] = true,
572 ["keyword:#"] = true,
573
574 ["=:word"] = true,
575 ["=:keyword"] = true,
576 ["=:string"] = true,
577 ["=:number"] = true,
578 ["=:{"] = true,
579 ["=:("] = true,
580 ["op:("] = true,
581 ["op:{"] = true,
582 ["op:#"] = true,
583
584 [",:word"] = true,
585 [",:keyword"] = true,
586 [",:string"] = true,
587 [",:{"] = true,
588
589 ["):op"] = true,
590 ["):word"] = true,
591 ["):keyword"] = true,
592
593 ["op:string"] = true,
594 ["op:number"] = true,
595 ["op:word"] = true,
596 ["op:keyword"] = true,
597
598 ["]:word"] = true,
599 ["]:keyword"] = true,
600 ["]:="] = true,
601 ["]:op"] = true,
602
603 ["string:op"] = true,
604 ["string:word"] = true,
605 ["string:keyword"] = true,
606
607 ["number:word"] = true,
608 ["number:keyword"] = true,
609}
610
611local should_unindent = {
612 ["end"] = true,
613 ["elseif"] = true,
614 ["else"] = true,
615 ["}"] = true,
616}
617
618local should_indent = {
619 ["{"] = true,
620 ["for"] = true,
621 ["if"] = true,
622 ["while"] = true,
623 ["elseif"] = true,
624 ["else"] = true,
625 ["function"] = true,
626}
627
628function tl.pretty_print_tokens(tokens)
629 local y = 1
630 local out = {}
631 local indent = 0
632 local newline = false
633 local kind = ""
634 for _, t in ipairs(tokens) do
635 while t.y > y do
636 table.insert(out, "\n")
637 y = y + 1
638 newline = true
639 kind = ""
640 end
641 if should_unindent[t.tk] then
642 indent = indent - 1
643 if indent < 0 then
644 indent = 0
645 end
646 end
647 if newline then
648 for _ = 1, indent do
649 table.insert(out, " ")
650 end
651 newline = false
652 end
653 if should_indent[t.tk] then
654 indent = indent + 1
655 end
656 if add_space[(kind or "") .. ":" .. t.kind] then
657 table.insert(out, " ")
658 end
659 table.insert(out, t.tk)
660 kind = t.kind or ""
661 end
662 return table.concat(out)
663end
664
665
666
667
668
669local last_typeid = 0
670
671local function new_typeid()
672 last_typeid = last_typeid + 1
673 return last_typeid
674end
675
676local ParseError = {}
677
678
679
680
681
682
683local TypeName = {}
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714local table_types = {
715 ["array"] = true,
716 ["map"] = true,
717 ["arrayrecord"] = true,
718 ["record"] = true,
719 ["emptytable"] = true,
720}
721
722local Type = {}
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799local Operator = {}
800
801
802
803
804
805
806
807local NodeKind = {}
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851local FactType = {}
852
853
854
855local Fact = {}
856
857
858
859
860
861local KeyParsed = {}
862
863
864
865
866
867local Node = {}
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935local function is_array_type(t)
936 return t.typename == "array" or t.typename == "arrayrecord"
937end
938
939local function is_record_type(t)
940 return t.typename == "record" or t.typename == "arrayrecord"
941end
942
943local function is_type(t)
944 return t.typename == "typetype" or t.typename == "nestedtype"
945end
946
947local ParseState = {}
948
949
950
951
952
953local ParseTypeListMode = {}
954
955
956
957
958
959local parse_type_list
960local parse_expression
961local parse_statements
962local parse_argument_list
963local parse_argument_type_list
964local parse_type
965local parse_newtype
966
967
968local function fail(ps, i, msg)
969 if not ps.tokens[i] then
970 local eof = ps.tokens[#ps.tokens]
971 table.insert(ps.errs, { y = eof.y, x = eof.x, msg = msg or "unexpected end of file" })
972 return #ps.tokens
973 end
974 table.insert(ps.errs, { y = ps.tokens[i].y, x = ps.tokens[i].x, msg = msg or "syntax error" })
975 return math.min(#ps.tokens, i + 1)
976end
977
978local function verify_tk(ps, i, tk)
979 if ps.tokens[i].tk == tk then
980 return i + 1
981 end
982 return fail(ps, i, "syntax error, expected '" .. tk .. "'")
983end
984
985local function new_node(tokens, i, kind)
986 local t = tokens[i]
987 return { y = t.y, x = t.x, tk = t.tk, kind = kind or t.kind }
988end
989
990local function a_type(t)
991 t.typeid = new_typeid()
992 return t
993end
994
995local function new_type(ps, i, typename)
996 local token = ps.tokens[i]
997 return a_type({
998 typename = assert(typename),
999 filename = ps.filename,
1000 y = token.y,
1001 x = token.x,
1002 tk = token.tk,
1003 })
1004end
1005
1006local function verify_kind(ps, i, kind, node_kind)
1007 if ps.tokens[i].kind == kind then
1008 return i + 1, new_node(ps.tokens, i, node_kind)
1009 end
1010 return fail(ps, i, "syntax error, expected " .. kind)
1011end
1012
1013local is_newtype = {
1014 ["enum"] = true,
1015 ["record"] = true,
1016}
1017
1018local function parse_table_value(ps, i)
1019 if is_newtype[ps.tokens[i].tk] then
1020 return parse_newtype(ps, i)
1021 else
1022 local i, node, _ = parse_expression(ps, i)
1023 return i, node
1024 end
1025end
1026
1027local function parse_table_item(ps, i, n)
1028 local node = new_node(ps.tokens, i, "table_item")
1029 if ps.tokens[i].kind == "$EOF$" then
1030 return fail(ps, i)
1031 end
1032
1033 if ps.tokens[i].tk == "[" then
1034 node.key_parsed = "long"
1035 i = i + 1
1036 i, node.key = parse_expression(ps, i)
1037 i = verify_tk(ps, i, "]")
1038 i = verify_tk(ps, i, "=")
1039 i, node.value = parse_table_value(ps, i)
1040 return i, node, n
1041 elseif ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == "=" then
1042 node.key_parsed = "short"
1043 i, node.key = verify_kind(ps, i, "identifier", "string")
1044 node.key.conststr = node.key.tk
1045 node.key.tk = '"' .. node.key.tk .. '"'
1046 i = verify_tk(ps, i, "=")
1047 i, node.value = parse_table_value(ps, i)
1048 return i, node, n
1049 elseif ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == ":" then
1050 node.key_parsed = "short"
1051 local orig_i = i
1052 local try_ps = {
1053 filename = ps.filename,
1054 tokens = ps.tokens,
1055 errs = {},
1056 }
1057 i, node.key = verify_kind(try_ps, i, "identifier", "string")
1058 node.key.conststr = node.key.tk
1059 node.key.tk = '"' .. node.key.tk .. '"'
1060 i = verify_tk(try_ps, i, ":")
1061 i, node.decltype = parse_type(try_ps, i)
1062 if node.decltype and ps.tokens[i].tk == "=" then
1063 i = verify_tk(try_ps, i, "=")
1064 i, node.value = parse_table_value(try_ps, i)
1065 if node.value then
1066 for _, e in ipairs(try_ps.errs) do
1067 table.insert(ps.errs, e)
1068 end
1069 return i, node, n
1070 end
1071 end
1072
1073 node.decltype = nil
1074 i = orig_i
1075 end
1076
1077 node.key = new_node(ps.tokens, i, "number")
1078 node.key_parsed = "implicit"
1079 node.key.constnum = n
1080 node.key.tk = tostring(n)
1081 i, node.value = parse_expression(ps, i)
1082 return i, node, n + 1
1083end
1084
1085local ParseItem = {}
1086
1087local SeparatorMode = {}
1088
1089
1090
1091
1092local function parse_list(ps, i, list, close, sep, parse_item)
1093 local n = 1
1094 while ps.tokens[i].kind ~= "$EOF$" do
1095 if close[ps.tokens[i].tk] then
1096 (list).yend = ps.tokens[i].y
1097 break
1098 end
1099 local item
1100 i, item, n = parse_item(ps, i, n)
1101 table.insert(list, item)
1102 if ps.tokens[i].tk == "," then
1103 i = i + 1
1104 if sep == "sep" and close[ps.tokens[i].tk] then
1105 return fail(ps, i)
1106 end
1107 elseif sep == "term" and ps.tokens[i].tk == ";" then
1108 i = i + 1
1109 elseif not close[ps.tokens[i].tk] then
1110 return fail(ps, i)
1111 end
1112 end
1113 return i, list
1114end
1115
1116local function parse_bracket_list(ps, i, list, open, close, sep, parse_item)
1117 i = verify_tk(ps, i, open)
1118 i = parse_list(ps, i, list, { [close] = true }, sep, parse_item)
1119 i = verify_tk(ps, i, close)
1120 return i, list
1121end
1122
1123local function parse_table_literal(ps, i)
1124 local node = new_node(ps.tokens, i, "table_literal")
1125 return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item)
1126end
1127
1128local function parse_trying_list(ps, i, list, parse_item)
1129 local try_ps = {
1130 filename = ps.filename,
1131 tokens = ps.tokens,
1132 errs = {},
1133 }
1134 local tryi, item = parse_item(try_ps, i)
1135 if not item then
1136 return i, list
1137 end
1138 for _, e in ipairs(try_ps.errs) do
1139 table.insert(ps.errs, e)
1140 end
1141 i = tryi
1142 table.insert(list, item)
1143 if ps.tokens[i].tk == "," then
1144 while ps.tokens[i].tk == "," do
1145 i = i + 1
1146 i, item = parse_item(ps, i)
1147 table.insert(list, item)
1148 end
1149 end
1150 return i, list
1151end
1152
1153local function parse_typearg_type(ps, i)
1154 local backtick = false
1155 if ps.tokens[i].tk == "`" then
1156 i = verify_tk(ps, i, "`")
1157 backtick = true
1158 end
1159 i = verify_kind(ps, i, "identifier")
1160 return i, a_type({
1161 y = ps.tokens[i - 2].y,
1162 x = ps.tokens[i - 2].x,
1163 typename = "typearg",
1164 typearg = (backtick and "`" or "") .. ps.tokens[i - 1].tk,
1165 })
1166end
1167
1168local function parse_typevar_type(ps, i)
1169 i = verify_tk(ps, i, "`")
1170 i = verify_kind(ps, i, "identifier")
1171 return i, a_type({
1172 y = ps.tokens[i - 2].y,
1173 x = ps.tokens[i - 2].x,
1174 typename = "typevar",
1175 typevar = "`" .. ps.tokens[i - 1].tk,
1176 })
1177end
1178
1179local function parse_typearg_list(ps, i)
1180 local typ = new_type(ps, i, "tuple")
1181 return parse_bracket_list(ps, i, typ, "<", ">", "sep", parse_typearg_type)
1182end
1183
1184local function parse_typeval_list(ps, i)
1185 local typ = new_type(ps, i, "tuple")
1186 return parse_bracket_list(ps, i, typ, "<", ">", "sep", parse_type)
1187end
1188
1189local function parse_return_types(ps, i)
1190 return parse_type_list(ps, i, "rets")
1191end
1192
1193local function parse_function_type(ps, i)
1194 local node = new_type(ps, i, "function")
1195 node.args = {}
1196 node.rets = {}
1197 i = i + 1
1198 if ps.tokens[i].tk == "<" then
1199 i, node.typeargs = parse_typearg_list(ps, i)
1200 end
1201 if ps.tokens[i].tk == "(" then
1202 i, node.args = parse_argument_type_list(ps, i)
1203 i, node.rets = parse_return_types(ps, i)
1204 else
1205 node.args = { a_type({ typename = "any", is_va = true }) }
1206 node.rets = { a_type({ typename = "any", is_va = true }) }
1207 end
1208 return i, node
1209end
1210
1211local function parse_base_type(ps, i)
1212 if ps.tokens[i].tk == "string" or
1213 ps.tokens[i].tk == "boolean" or
1214 ps.tokens[i].tk == "nil" or
1215 ps.tokens[i].tk == "number" or
1216 ps.tokens[i].tk == "thread" then
1217 local typ = new_type(ps, i, ps.tokens[i].tk)
1218 typ.tk = nil
1219 return i + 1, typ
1220 elseif ps.tokens[i].tk == "table" then
1221 local typ = new_type(ps, i, "map")
1222 typ.keys = a_type({ typename = "any" })
1223 typ.values = a_type({ typename = "any" })
1224 return i + 1, typ
1225 elseif ps.tokens[i].tk == "function" then
1226 return parse_function_type(ps, i)
1227 elseif ps.tokens[i].tk == "{" then
1228 i = i + 1
1229 local decl = new_type(ps, i, "array")
1230 local t
1231 i, t = parse_type(ps, i)
1232 if ps.tokens[i].tk == "}" then
1233 decl.elements = t
1234 decl.yend = ps.tokens[i].y
1235 i = verify_tk(ps, i, "}")
1236 elseif ps.tokens[i].tk == ":" then
1237 decl.typename = "map"
1238 i = i + 1
1239 decl.keys = t
1240 i, decl.values = parse_type(ps, i)
1241 decl.yend = ps.tokens[i].y
1242 i = verify_tk(ps, i, "}")
1243 end
1244 return i, decl
1245 elseif ps.tokens[i].tk == "`" then
1246 return parse_typevar_type(ps, i)
1247 elseif ps.tokens[i].kind == "identifier" then
1248 local typ = new_type(ps, i, "nominal")
1249 typ.names = { ps.tokens[i].tk }
1250 i = i + 1
1251 while ps.tokens[i].tk == "." do
1252 i = i + 1
1253 if ps.tokens[i].kind == "identifier" then
1254 table.insert(typ.names, ps.tokens[i].tk)
1255 i = i + 1
1256 else
1257 return fail(ps, i, "syntax error, expected identifier")
1258 end
1259 end
1260
1261 if ps.tokens[i].tk == "<" then
1262 i, typ.typevals = parse_typeval_list(ps, i)
1263 end
1264 return i, typ
1265 end
1266 return fail(ps, i)
1267end
1268
1269parse_type = function(ps, i)
1270 if ps.tokens[i].tk == "(" then
1271 i = i + 1
1272 local t
1273 i, t = parse_type(ps, i)
1274 i = verify_tk(ps, i, ")")
1275 return i, t
1276 end
1277
1278 local bt
1279 local istart = i
1280 i, bt = parse_base_type(ps, i)
1281 if not bt then
1282 return i
1283 end
1284 if ps.tokens[i].tk == "|" then
1285 local u = new_type(ps, istart, "union")
1286 u.types = { bt }
1287 while ps.tokens[i].tk == "|" do
1288 i = i + 1
1289 i, bt = parse_base_type(ps, i)
1290 if not bt then
1291 return i
1292 end
1293 table.insert(u.types, bt)
1294 end
1295 bt = u
1296 end
1297 return i, bt
1298end
1299
1300parse_type_list = function(ps, i, mode)
1301 local list = new_type(ps, i, "tuple")
1302
1303 local first_token = ps.tokens[i].tk
1304 if mode == "rets" or mode == "decltype" then
1305 if first_token == ":" then
1306 i = i + 1
1307 else
1308 return i, list
1309 end
1310 end
1311
1312 local optional_paren = false
1313 if ps.tokens[i].tk == "(" then
1314 optional_paren = true
1315 i = i + 1
1316 end
1317
1318 local prev_i = i
1319 i = parse_trying_list(ps, i, list, parse_type)
1320 if i == prev_i and ps.tokens[i].tk ~= ")" then
1321 fail(ps, i - 1, "expected a type list")
1322 end
1323
1324 if mode == "rets" and ps.tokens[i].tk == "..." then
1325 i = i + 1
1326 local nrets = #list
1327 if nrets > 0 then
1328 list[nrets].is_va = true
1329 else
1330 return fail(ps, i, "unexpected '...'")
1331 end
1332 end
1333
1334 if optional_paren then
1335 i = verify_tk(ps, i, ")")
1336 end
1337
1338 return i, list
1339end
1340
1341local function parse_function_args_rets_body(ps, i, node)
1342 if ps.tokens[i].tk == "<" then
1343 i, node.typeargs = parse_typearg_list(ps, i)
1344 end
1345 i, node.args = parse_argument_list(ps, i)
1346 i, node.rets = parse_return_types(ps, i)
1347 i, node.body = parse_statements(ps, i)
1348 node.yend = ps.tokens[i].y
1349 i = verify_tk(ps, i, "end")
1350 return i, node
1351end
1352
1353local function parse_function_value(ps, i)
1354 local node = new_node(ps.tokens, i, "function")
1355 i = verify_tk(ps, i, "function")
1356 return parse_function_args_rets_body(ps, i, node)
1357end
1358
1359local function unquote(str)
1360 local f = str:sub(1, 1)
1361 if f == '"' or f == "'" then
1362 return str:sub(2, -2)
1363 end
1364 f = str:match("^%[=*%[")
1365 local l = #f + 1
1366 return str:sub(l, -l)
1367end
1368
1369local function parse_literal(ps, i)
1370 if ps.tokens[i].tk == "{" then
1371 return parse_table_literal(ps, i)
1372 elseif ps.tokens[i].kind == "..." then
1373 return verify_kind(ps, i, "...")
1374 elseif ps.tokens[i].kind == "string" then
1375 local tk = unquote(ps.tokens[i].tk)
1376 local node
1377 i, node = verify_kind(ps, i, "string")
1378 node.conststr = tk
1379 return i, node
1380 elseif ps.tokens[i].kind == "identifier" then
1381 return verify_kind(ps, i, "identifier", "variable")
1382 elseif ps.tokens[i].kind == "number" then
1383 local n = tonumber(ps.tokens[i].tk)
1384 local node
1385 i, node = verify_kind(ps, i, "number")
1386 node.constnum = n
1387 return i, node
1388 elseif ps.tokens[i].tk == "true" then
1389 return verify_kind(ps, i, "keyword", "boolean")
1390 elseif ps.tokens[i].tk == "false" then
1391 return verify_kind(ps, i, "keyword", "boolean")
1392 elseif ps.tokens[i].tk == "nil" then
1393 return verify_kind(ps, i, "keyword", "nil")
1394 elseif ps.tokens[i].tk == "function" then
1395 return parse_function_value(ps, i)
1396 end
1397 return fail(ps, i)
1398end
1399
1400do
1401 local precedences = {
1402 [1] = {
1403 ["not"] = 11,
1404 ["#"] = 11,
1405 ["-"] = 11,
1406 ["~"] = 11,
1407 },
1408 [2] = {
1409 ["or"] = 1,
1410 ["and"] = 2,
1411 ["is"] = 3,
1412 ["<"] = 3,
1413 [">"] = 3,
1414 ["<="] = 3,
1415 [">="] = 3,
1416 ["~="] = 3,
1417 ["=="] = 3,
1418 ["|"] = 4,
1419 ["~"] = 5,
1420 ["&"] = 6,
1421 ["<<"] = 7,
1422 [">>"] = 7,
1423 [".."] = 8,
1424 ["+"] = 8,
1425 ["-"] = 9,
1426 ["*"] = 10,
1427 ["/"] = 10,
1428 ["//"] = 10,
1429 ["%"] = 10,
1430 ["^"] = 12,
1431 ["as"] = 50,
1432 ["@funcall"] = 100,
1433 ["@index"] = 100,
1434 ["."] = 100,
1435 [":"] = 100,
1436 },
1437 }
1438
1439 local is_right_assoc = {
1440 ["^"] = true,
1441 [".."] = true,
1442 }
1443
1444 local function new_operator(tk, arity, op)
1445 op = op or tk.tk
1446 return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] }
1447 end
1448
1449 local E
1450
1451 local function P(ps, i)
1452 if ps.tokens[i].kind == "$EOF$" then
1453 return i
1454 end
1455 local e1
1456 local t1 = ps.tokens[i]
1457 if precedences[1][ps.tokens[i].tk] ~= nil then
1458 local op = new_operator(ps.tokens[i], 1)
1459 i = i + 1
1460 i, e1 = P(ps, i)
1461 e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 }
1462 elseif ps.tokens[i].tk == "(" then
1463 i = i + 1
1464 i, e1 = parse_expression(ps, i)
1465 e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 }
1466 i = verify_tk(ps, i, ")")
1467 else
1468 i, e1 = parse_literal(ps, i)
1469 end
1470
1471 while true do
1472 if ps.tokens[i].kind == "string" or ps.tokens[i].kind == "{" then
1473 local op = new_operator(ps.tokens[i], 2, "@funcall")
1474 local args = new_node(ps.tokens, i, "expression_list")
1475 local arg
1476 if ps.tokens[i].kind == "string" then
1477 arg = new_node(ps.tokens, i)
1478 arg.conststr = unquote(ps.tokens[i].tk)
1479 i = i + 1
1480 else
1481 i, arg = parse_table_literal(ps, i)
1482 end
1483 table.insert(args, arg)
1484 e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = args }
1485 elseif ps.tokens[i].tk == "(" then
1486 local op = new_operator(ps.tokens[i], 2, "@funcall")
1487
1488 local args = new_node(ps.tokens, i, "expression_list")
1489 i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression)
1490
1491 e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = args }
1492 elseif ps.tokens[i].tk == "[" then
1493 local op = new_operator(ps.tokens[i], 2, "@index")
1494
1495 local idx
1496 i = i + 1
1497 i, idx = parse_expression(ps, i)
1498 i = verify_tk(ps, i, "]")
1499
1500 e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = idx }
1501 elseif ps.tokens[i].tk == "." or ps.tokens[i].tk == ":" then
1502 local op = new_operator(ps.tokens[i], 2)
1503
1504 local key
1505 i = i + 1
1506 i, key = verify_kind(ps, i, "identifier")
1507
1508 e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = key }
1509 elseif ps.tokens[i].tk == "as" or ps.tokens[i].tk == "is" then
1510 local op = new_operator(ps.tokens[i], 2, ps.tokens[i].tk)
1511
1512 i = i + 1
1513 local cast = new_node(ps.tokens, i, "cast")
1514 if ps.tokens[i].tk == "(" then
1515 i, cast.casttype = parse_type_list(ps, i, "casttype")
1516 else
1517 i, cast.casttype = parse_type(ps, i)
1518 end
1519 e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr }
1520 else
1521 break
1522 end
1523 end
1524
1525 return i, e1
1526 end
1527
1528 local function E(ps, i, lhs, min_precedence)
1529 local lookahead = ps.tokens[i].tk
1530 while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do
1531 local t1 = ps.tokens[i]
1532 local op = new_operator(t1, 2)
1533 i = i + 1
1534 local rhs
1535 i, rhs = P(ps, i)
1536 lookahead = ps.tokens[i].tk
1537 while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or
1538 (is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do
1539 i, rhs = E(ps, i, rhs, precedences[2][lookahead])
1540 lookahead = ps.tokens[i].tk
1541 end
1542 lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs }
1543 end
1544 return i, lhs
1545 end
1546
1547 parse_expression = function(ps, i)
1548 local lhs
1549 i, lhs = P(ps, i)
1550 i, lhs = E(ps, i, lhs, 0)
1551 if lhs then
1552 return i, lhs, 0
1553 else
1554 return fail(ps, i, "expected an expression")
1555 end
1556 end
1557end
1558
1559local function parse_variable_name(ps, i)
1560 local is_const = false
1561 local node
1562 i, node = verify_kind(ps, i, "identifier")
1563 if not node then
1564 return i
1565 end
1566 if ps.tokens[i].tk == "<" then
1567 i = i + 1
1568 local annotation
1569 i, annotation = verify_kind(ps, i, "identifier")
1570 if annotation and annotation.tk == "const" then
1571 is_const = true
1572 end
1573 i = verify_tk(ps, i, ">")
1574 end
1575 node.is_const = is_const
1576 return i, node
1577end
1578
1579local function parse_argument(ps, i)
1580 local node
1581 if ps.tokens[i].tk == "..." then
1582 i, node = verify_kind(ps, i, "...")
1583 else
1584 i, node = verify_kind(ps, i, "identifier", "argument")
1585 end
1586 if ps.tokens[i].tk == ":" then
1587 i = i + 1
1588 local decltype
1589
1590 i, decltype = parse_type(ps, i)
1591
1592 if node then
1593 i, node.decltype = i, decltype
1594 end
1595 end
1596 return i, node, 0
1597end
1598
1599parse_argument_list = function(ps, i)
1600 local node = new_node(ps.tokens, i, "argument_list")
1601 return parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument)
1602end
1603
1604local function parse_argument_type(ps, i)
1605 local is_va = false
1606 if ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == ":" then
1607 i = i + 2
1608 elseif ps.tokens[i].tk == "..." then
1609 if ps.tokens[i + 1].tk == ":" then
1610 i = i + 2
1611 is_va = true
1612 else
1613 return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument")
1614 end
1615 end
1616
1617 local i, typ = parse_type(ps, i)
1618 if typ then
1619 typ.is_va = is_va
1620 end
1621
1622 return i, typ, 0
1623end
1624
1625parse_argument_type_list = function(ps, i)
1626 local list = new_type(ps, i, "tuple")
1627 return parse_bracket_list(ps, i, list, "(", ")", "sep", parse_argument_type)
1628end
1629
1630local function parse_local_function(ps, i)
1631 local node = new_node(ps.tokens, i, "local_function")
1632 i = verify_tk(ps, i, "local")
1633 i = verify_tk(ps, i, "function")
1634 i, node.name = verify_kind(ps, i, "identifier")
1635 return parse_function_args_rets_body(ps, i, node)
1636end
1637
1638local function parse_function(ps, i)
1639 local orig_i = i
1640 local fn = new_node(ps.tokens, i, "global_function")
1641 local node = fn
1642 i = verify_tk(ps, i, "function")
1643 local names = {}
1644 i, names[1] = verify_kind(ps, i, "identifier", "variable")
1645 while ps.tokens[i].tk == "." do
1646 i = i + 1
1647 i, names[#names + 1] = verify_kind(ps, i, "identifier")
1648 end
1649 if ps.tokens[i].tk == ":" then
1650 i = i + 1
1651 i, names[#names + 1] = verify_kind(ps, i, "identifier")
1652 fn.is_method = true
1653 end
1654
1655 if #names > 1 then
1656 fn.kind = "record_function"
1657 local owner = names[1]
1658 for i = 2, #names - 1 do
1659 local dot = { y = names[i].y, x = names[i].x - 1, arity = 2, op = "." }
1660 names[i].kind = "identifier"
1661 local op = { y = names[i].y, x = names[i].x, kind = "op", op = dot, e1 = owner, e2 = names[i] }
1662 owner = op
1663 end
1664 fn.fn_owner = owner
1665 end
1666 fn.name = names[#names]
1667
1668 local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y
1669 i = parse_function_args_rets_body(ps, i, fn)
1670 if fn.is_method then
1671 table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "variable" })
1672 end
1673
1674 if not fn.name then
1675 return orig_i
1676 end
1677
1678 return i, node
1679end
1680
1681local function parse_if(ps, i)
1682 local node = new_node(ps.tokens, i, "if")
1683 i = verify_tk(ps, i, "if")
1684 i, node.exp = parse_expression(ps, i)
1685 i = verify_tk(ps, i, "then")
1686 i, node.thenpart = parse_statements(ps, i)
1687 node.elseifs = {}
1688 local n = 0
1689 while ps.tokens[i].tk == "elseif" do
1690 n = n + 1
1691 local subnode = new_node(ps.tokens, i, "elseif")
1692 subnode.parent_if = node
1693 subnode.elseif_n = n
1694 i = i + 1
1695 i, subnode.exp = parse_expression(ps, i)
1696 i = verify_tk(ps, i, "then")
1697 i, subnode.thenpart = parse_statements(ps, i)
1698 table.insert(node.elseifs, subnode)
1699 end
1700 if ps.tokens[i].tk == "else" then
1701 local subnode = new_node(ps.tokens, i, "else")
1702 subnode.parent_if = node
1703 i = i + 1
1704 i, subnode.elsepart = parse_statements(ps, i)
1705 node.elsepart = subnode
1706 end
1707 node.yend = ps.tokens[i].y
1708 i = verify_tk(ps, i, "end")
1709 return i, node
1710end
1711
1712local function parse_while(ps, i)
1713 local node = new_node(ps.tokens, i, "while")
1714 i = verify_tk(ps, i, "while")
1715 i, node.exp = parse_expression(ps, i)
1716 i = verify_tk(ps, i, "do")
1717 i, node.body = parse_statements(ps, i)
1718 node.yend = ps.tokens[i].y
1719 i = verify_tk(ps, i, "end")
1720 return i, node
1721end
1722
1723local function parse_fornum(ps, i)
1724 local node = new_node(ps.tokens, i, "fornum")
1725 i = i + 1
1726 i, node.var = verify_kind(ps, i, "identifier")
1727 i = verify_tk(ps, i, "=")
1728 i, node.from = parse_expression(ps, i)
1729 i = verify_tk(ps, i, ",")
1730 i, node.to = parse_expression(ps, i)
1731 if ps.tokens[i].tk == "," then
1732 i = i + 1
1733 i, node.step = parse_expression(ps, i)
1734 end
1735 i = verify_tk(ps, i, "do")
1736 i, node.body = parse_statements(ps, i)
1737 node.yend = ps.tokens[i].y
1738 i = verify_tk(ps, i, "end")
1739 return i, node
1740end
1741
1742local function parse_forin(ps, i)
1743 local node = new_node(ps.tokens, i, "forin")
1744 i = i + 1
1745 node.vars = new_node(ps.tokens, i, "variables")
1746 i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_variable_name)
1747 i = verify_tk(ps, i, "in")
1748 node.exps = new_node(ps.tokens, i, "expression_list")
1749 i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression)
1750 if #node.exps < 1 then
1751 return fail(ps, i, "missing iterator expression in generic for")
1752 elseif #node.exps > 3 then
1753 return fail(ps, i, "too many expressions in generic for")
1754 end
1755 i = verify_tk(ps, i, "do")
1756 i, node.body = parse_statements(ps, i)
1757 node.yend = ps.tokens[i].y
1758 i = verify_tk(ps, i, "end")
1759 return i, node
1760end
1761
1762local function parse_for(ps, i)
1763 if ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "=" then
1764 return parse_fornum(ps, i)
1765 else
1766 return parse_forin(ps, i)
1767 end
1768end
1769
1770local function parse_repeat(ps, i)
1771 local node = new_node(ps.tokens, i, "repeat")
1772 i = verify_tk(ps, i, "repeat")
1773 i, node.body = parse_statements(ps, i)
1774 node.body.is_repeat = true
1775 node.yend = ps.tokens[i].y
1776 i = verify_tk(ps, i, "until")
1777 i, node.exp = parse_expression(ps, i)
1778 return i, node
1779end
1780
1781local function parse_do(ps, i)
1782 local node = new_node(ps.tokens, i, "do")
1783 i = verify_tk(ps, i, "do")
1784 i, node.body = parse_statements(ps, i)
1785 node.yend = ps.tokens[i].y
1786 i = verify_tk(ps, i, "end")
1787 return i, node
1788end
1789
1790local function parse_break(ps, i)
1791 local node = new_node(ps.tokens, i, "break")
1792 i = verify_tk(ps, i, "break")
1793 return i, node
1794end
1795
1796local function parse_goto(ps, i)
1797 local node = new_node(ps.tokens, i, "goto")
1798 i = verify_tk(ps, i, "goto")
1799 node.label = ps.tokens[i].tk
1800 i = verify_kind(ps, i, "identifier")
1801 return i, node
1802end
1803
1804local function parse_label(ps, i)
1805 local node = new_node(ps.tokens, i, "label")
1806 i = verify_tk(ps, i, "::")
1807 node.label = ps.tokens[i].tk
1808 i = verify_kind(ps, i, "identifier")
1809 i = verify_tk(ps, i, "::")
1810 return i, node
1811end
1812
1813local stop_statement_list = {
1814 ["end"] = true,
1815 ["else"] = true,
1816 ["elseif"] = true,
1817 ["until"] = true,
1818}
1819
1820local stop_return_list = {
1821 [";"] = true,
1822 ["$EOF$"] = true,
1823}
1824
1825for k, v in pairs(stop_statement_list) do
1826 stop_return_list[k] = v
1827end
1828
1829local function parse_return(ps, i)
1830 local node = new_node(ps.tokens, i, "return")
1831 i = verify_tk(ps, i, "return")
1832 node.exps = new_node(ps.tokens, i, "expression_list")
1833 i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression)
1834 if ps.tokens[i].kind == ";" then
1835 i = i + 1
1836 end
1837 return i, node
1838end
1839
1840local function store_field_in_record(name, def, nt)
1841 if def.fields[name] then
1842 return false
1843 end
1844 def.fields[name] = nt.newtype
1845 table.insert(def.field_order, name)
1846 return true
1847end
1848
1849local ParseBody = {}
1850
1851local function parse_nested_type(ps, i, def, typename, parse_body)
1852 i = i + 1
1853
1854 local v
1855 i, v = verify_kind(ps, i, "identifier", "variable")
1856 if not v then
1857 return fail(ps, i, "expected a variable name")
1858 end
1859
1860 local nt = new_node(ps.tokens, i, "newtype")
1861 nt.newtype = new_type(ps, i, "typetype")
1862 local rdef = new_type(ps, i, typename)
1863 local iok = parse_body(ps, i, rdef, nt)
1864 if iok then
1865 i = iok
1866 nt.newtype.def = rdef
1867 end
1868
1869 local ok = store_field_in_record(v.tk, def, nt)
1870 if not ok then
1871 fail(ps, i, "attempt to redeclare field '" .. v.tk .. "' (only functions can be overloaded)")
1872 end
1873 return i
1874end
1875
1876local function parse_enum_body(ps, i, def, node)
1877 def.enumset = {}
1878 while not ((not ps.tokens[i]) or ps.tokens[i].tk == "end") do
1879 local item
1880 i, item = verify_kind(ps, i, "string", "enum_item")
1881 if item then
1882 table.insert(node, item)
1883 def.enumset[unquote(item.tk)] = true
1884 end
1885 end
1886 node.yend = ps.tokens[i].y
1887 i = verify_tk(ps, i, "end")
1888 return i, node
1889end
1890
1891local function parse_record_body(ps, i, def, node)
1892 def.fields = {}
1893 def.field_order = {}
1894 if ps.tokens[i].tk == "<" then
1895 i, def.typeargs = parse_typearg_list(ps, i)
1896 end
1897 while not ((not ps.tokens[i]) or ps.tokens[i].tk == "end") do
1898 if ps.tokens[i].tk == "{" then
1899 if def.typename == "arrayrecord" then
1900 return fail(ps, i, "duplicated declaration of array element type in record")
1901 end
1902 i = i + 1
1903 local t
1904 i, t = parse_type(ps, i)
1905 if ps.tokens[i].tk == "}" then
1906 node.yend = ps.tokens[i].y
1907 i = verify_tk(ps, i, "}")
1908 else
1909 return fail(ps, i, "expected an array declaration")
1910 end
1911 def.typename = "arrayrecord"
1912 def.elements = t
1913 elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then
1914 i = i + 1
1915 local v
1916 i, v = verify_kind(ps, i, "identifier", "variable")
1917 if not v then
1918 return fail(ps, i, "expected a variable name")
1919 end
1920 i = verify_tk(ps, i, "=")
1921 local nt
1922 i, nt = parse_newtype(ps, i)
1923 if not nt or not nt.newtype then
1924 return fail(ps, i, "expected a type definition")
1925 end
1926
1927 local ok = store_field_in_record(v.tk, def, nt)
1928 if not ok then
1929 return fail(ps, i, "attempt to redeclare field '" .. v.tk .. "' (only functions can be overloaded)")
1930 end
1931 elseif ps.tokens[i].tk == "record" and ps.tokens[i + 1].tk ~= ":" then
1932 i = parse_nested_type(ps, i, def, "record", parse_record_body)
1933 elseif ps.tokens[i].tk == "enum" and ps.tokens[i + 1].tk ~= ":" then
1934 i = parse_nested_type(ps, i, def, "enum", parse_enum_body)
1935 else
1936 local v
1937 i, v = verify_kind(ps, i, "identifier", "variable")
1938 local iv = i
1939 if not v then
1940 return fail(ps, i, "expected a variable name")
1941 end
1942 if ps.tokens[i].tk == ":" then
1943 i = verify_tk(ps, i, ":")
1944 local t
1945 i, t = parse_type(ps, i)
1946 if not t then
1947 return fail(ps, i, "expected a type")
1948 end
1949 if not def.fields[v.tk] then
1950 def.fields[v.tk] = t
1951 table.insert(def.field_order, v.tk)
1952 else
1953 local prev_t = def.fields[v.tk]
1954 if t.typename == "function" and prev_t.typename == "function" then
1955 def.fields[v.tk] = new_type(ps, iv, "poly")
1956 def.fields[v.tk].types = { prev_t, t }
1957 elseif t.typename == "function" and prev_t.typename == "poly" then
1958 table.insert(prev_t.types, t)
1959 else
1960 return fail(ps, i, "attempt to redeclare field '" .. v.tk .. "' (only functions can be overloaded)")
1961 end
1962 end
1963 elseif ps.tokens[i].tk == "=" then
1964 local next_word = ps.tokens[i + 1].tk
1965 if next_word == "record" or next_word == "enum" then
1966 return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'")
1967 elseif next_word == "functiontype" then
1968 return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...")
1969 else
1970 return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...")
1971 end
1972 end
1973 end
1974 end
1975 node.yend = ps.tokens[i].y
1976 i = verify_tk(ps, i, "end")
1977 return i, node
1978end
1979
1980parse_newtype = function(ps, i)
1981 local node = new_node(ps.tokens, i, "newtype")
1982 node.newtype = new_type(ps, i, "typetype")
1983 if ps.tokens[i].tk == "record" then
1984 local def = new_type(ps, i, "record")
1985 i = i + 1
1986 i = parse_record_body(ps, i, def, node)
1987 node.newtype.def = def
1988 return i, node
1989 elseif ps.tokens[i].tk == "enum" then
1990 local def = new_type(ps, i, "enum")
1991 i = i + 1
1992 i = parse_enum_body(ps, i, def, node)
1993 node.newtype.def = def
1994 return i, node
1995 else
1996 i, node.newtype.def = parse_type(ps, i)
1997 return i, node
1998 end
1999 return fail(ps, i)
2000end
2001
2002local function parse_call_or_assignment(ps, i)
2003 local asgn = new_node(ps.tokens, i, "assignment")
2004
2005 local tryi = i
2006 asgn.vars = new_node(ps.tokens, i, "variables")
2007 i = parse_trying_list(ps, i, asgn.vars, parse_expression)
2008 if #asgn.vars < 1 then
2009 return fail(ps, i)
2010 end
2011 local lhs = asgn.vars[1]
2012
2013 if ps.tokens[i].tk == "=" then
2014 asgn.exps = new_node(ps.tokens, i, "values")
2015 repeat
2016 i = i + 1
2017 local val
2018 i, val = parse_expression(ps, i)
2019 table.insert(asgn.exps, val)
2020 until ps.tokens[i].tk ~= ","
2021 return i, asgn
2022 end
2023 if #asgn.vars > 1 then
2024 local err_ps = {
2025 tokens = ps.tokens,
2026 errs = {},
2027 }
2028 local expi = parse_expression(err_ps, tryi)
2029 return fail(ps, expi or i)
2030 end
2031 if lhs.op and lhs.op.op == "@funcall" and #asgn.vars == 1 then
2032 return i, lhs
2033 end
2034 return fail(ps, i)
2035end
2036
2037local function parse_variable_declarations(ps, i, node_name)
2038 local asgn = new_node(ps.tokens, i, node_name)
2039
2040 asgn.vars = new_node(ps.tokens, i, "variables")
2041 i = parse_trying_list(ps, i, asgn.vars, parse_variable_name)
2042 if #asgn.vars == 0 then
2043 return fail(ps, i, "expected a local variable definition")
2044 end
2045 local lhs = asgn.vars[1]
2046
2047 i, asgn.decltype = parse_type_list(ps, i, "decltype")
2048
2049 if ps.tokens[i].tk == "=" then
2050
2051 if ps.tokens[i + 1].tk == "record" or
2052 ps.tokens[i + 1].tk == "enum" then
2053
2054 local scope = node_name == "local_declaration" and "local" or "global"
2055 fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. scope .. " " .. ps.tokens[i + 1].tk .. " " .. asgn.vars[1].tk .. "'")
2056 elseif ps.tokens[i + 1].tk == "functiontype" then
2057 local scope = node_name == "local_declaration" and "local" or "global"
2058 fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...")
2059 end
2060
2061 asgn.exps = new_node(ps.tokens, i, "values")
2062 local v = 1
2063 repeat
2064 i = i + 1
2065 local val
2066 i, val = parse_expression(ps, i)
2067 table.insert(asgn.exps, val)
2068 v = v + 1
2069 until ps.tokens[i].tk ~= ","
2070 end
2071 return i, asgn
2072end
2073
2074local function parse_type_declaration(ps, i, node_name)
2075 i = i + 2
2076
2077 local asgn = new_node(ps.tokens, i, node_name)
2078 i, asgn.var = parse_variable_name(ps, i)
2079 if not asgn.var then
2080 return fail(ps, i, "expected a type name")
2081 end
2082 i = verify_tk(ps, i, "=")
2083 i, asgn.value = parse_newtype(ps, i)
2084 if asgn.value then
2085 asgn.value.newtype.def.names = { asgn.var.tk }
2086 else
2087 return i
2088 end
2089
2090 return i, asgn
2091end
2092
2093local ParseBody = {}
2094
2095local function parse_type_constructor(ps, i, node_name, type_name, parse_body)
2096 local asgn = new_node(ps.tokens, i, node_name)
2097 local nt = new_node(ps.tokens, i, "newtype")
2098 asgn.value = nt
2099 nt.newtype = new_type(ps, i, "typetype")
2100 local def = new_type(ps, i, type_name)
2101 nt.newtype.def = def
2102
2103 i = i + 2
2104
2105 i, asgn.var = verify_kind(ps, i, "identifier")
2106 if not asgn.var then
2107 return fail(ps, i, "expected a type name")
2108 end
2109 nt.newtype.def.names = { asgn.var.tk }
2110
2111 i = parse_body(ps, i, def, nt)
2112 return i, asgn
2113end
2114
2115local function parse_statement(ps, i)
2116 if ps.tokens[i].tk == "local" then
2117 if ps.tokens[i + 1].tk == "type" and ps.tokens[i + 2].kind == "identifier" then
2118 return parse_type_declaration(ps, i, "local_type")
2119 elseif ps.tokens[i + 1].tk == "function" then
2120 return parse_local_function(ps, i)
2121 elseif ps.tokens[i + 1].tk == "record" and ps.tokens[i + 2].kind == "identifier" then
2122 return parse_type_constructor(ps, i, "local_type", "record", parse_record_body)
2123 elseif ps.tokens[i + 1].tk == "enum" and ps.tokens[i + 2].kind == "identifier" then
2124 return parse_type_constructor(ps, i, "local_type", "enum", parse_enum_body)
2125 else
2126 i = i + 1
2127 return parse_variable_declarations(ps, i, "local_declaration")
2128 end
2129 elseif ps.tokens[i].tk == "global" then
2130 if ps.tokens[i + 1].tk == "type" and ps.tokens[i + 2].kind == "identifier" then
2131 return parse_type_declaration(ps, i, "global_type")
2132 elseif ps.tokens[i + 1].tk == "record" and ps.tokens[i + 2].kind == "identifier" then
2133 return parse_type_constructor(ps, i, "global_type", "record", parse_record_body)
2134 elseif ps.tokens[i + 1].tk == "enum" and ps.tokens[i + 2].kind == "identifier" then
2135 return parse_type_constructor(ps, i, "global_type", "enum", parse_enum_body)
2136 elseif ps.tokens[i + 1].tk == "function" then
2137 i = i + 1
2138 return parse_function(ps, i)
2139 else
2140 i = i + 1
2141 return parse_variable_declarations(ps, i, "global_declaration")
2142 end
2143 elseif ps.tokens[i].tk == "function" then
2144 return parse_function(ps, i)
2145 elseif ps.tokens[i].tk == "if" then
2146 return parse_if(ps, i)
2147 elseif ps.tokens[i].tk == "while" then
2148 return parse_while(ps, i)
2149 elseif ps.tokens[i].tk == "repeat" then
2150 return parse_repeat(ps, i)
2151 elseif ps.tokens[i].tk == "for" then
2152 return parse_for(ps, i)
2153 elseif ps.tokens[i].tk == "do" then
2154 return parse_do(ps, i)
2155 elseif ps.tokens[i].tk == "break" then
2156 return parse_break(ps, i)
2157 elseif ps.tokens[i].tk == "return" then
2158 return parse_return(ps, i)
2159 elseif ps.tokens[i].tk == "goto" then
2160 return parse_goto(ps, i)
2161 elseif ps.tokens[i].tk == "::" then
2162 return parse_label(ps, i)
2163 else
2164 return parse_call_or_assignment(ps, i)
2165 end
2166end
2167
2168parse_statements = function(ps, i, filename, toplevel)
2169 local node = new_node(ps.tokens, i, "statements")
2170 while true do
2171 while ps.tokens[i].kind == ";" do
2172 i = i + 1
2173 end
2174 if ps.tokens[i].kind == "$EOF$" then
2175 break
2176 end
2177 if (not toplevel) and stop_statement_list[ps.tokens[i].tk] then
2178 break
2179 end
2180 local item
2181 i, item = parse_statement(ps, i)
2182 if filename then
2183 for j = 1, #ps.errs do
2184 if not ps.errs[j].filename then
2185 ps.errs[j].filename = filename
2186 end
2187 end
2188 end
2189 if not item then
2190 break
2191 end
2192 table.insert(node, item)
2193 end
2194 return i, node
2195end
2196
2197function tl.parse_program(tokens, errs, filename)
2198 errs = errs or {}
2199 local ps = {
2200 tokens = tokens,
2201 errs = errs,
2202 filename = filename,
2203 }
2204 local last = ps.tokens[#ps.tokens] or { y = 1, x = 1, tk = "" }
2205 table.insert(ps.tokens, { y = last.y, x = last.x + #last.tk, tk = "$EOF$", kind = "$EOF$" })
2206 return parse_statements(ps, 1, filename, true)
2207end
2208
2209
2210
2211
2212
2213local VisitorCallbacks = {}
2214
2215
2216
2217
2218
2219
2220local Visitor = {}
2221
2222
2223
2224
2225local function visit_before(ast, kind, visit)
2226 assert(visit.cbs[kind], "no visitor for " .. (kind))
2227 if visit.cbs[kind].before then
2228 visit.cbs[kind].before(ast)
2229 end
2230end
2231
2232local function visit_after(ast, kind, visit, xs)
2233 if visit.after and visit.after.before then
2234 visit.after.before(ast, xs)
2235 end
2236 local ret
2237 if visit.cbs[kind].after then
2238 ret = visit.cbs[kind].after(ast, xs)
2239 end
2240 if visit.after and visit.after.after then
2241 ret = visit.after.after(ast, xs, ret)
2242 end
2243 return ret
2244end
2245
2246local function recurse_type(ast, visit)
2247 visit_before(ast, ast.typename, visit)
2248 local xs = {}
2249
2250 if ast.typeargs then
2251 for _, child in ipairs(ast.typeargs) do
2252 table.insert(xs, recurse_type(child, visit))
2253 end
2254 end
2255
2256 for i, child in ipairs(ast) do
2257 xs[i] = recurse_type(child, visit)
2258 end
2259
2260 if ast.types then
2261 for i, child in ipairs(ast.types) do
2262 table.insert(xs, recurse_type(child, visit))
2263 end
2264 end
2265 if ast.def then
2266 table.insert(xs, recurse_type(ast.def, visit))
2267 end
2268 if ast.keys then
2269 table.insert(xs, recurse_type(ast.keys, visit))
2270 end
2271 if ast.values then
2272 table.insert(xs, recurse_type(ast.values, visit))
2273 end
2274 if ast.elements then
2275 table.insert(xs, recurse_type(ast.elements, visit))
2276 end
2277 if ast.fields then
2278 for _, child in pairs(ast.fields) do
2279 table.insert(xs, recurse_type(child, visit))
2280 end
2281 end
2282 if ast.args then
2283 for i, child in ipairs(ast.args) do
2284 if i > 1 or not ast.is_method then
2285 table.insert(xs, recurse_type(child, visit))
2286 end
2287 end
2288 end
2289 if ast.rets then
2290 for _, child in ipairs(ast.rets) do
2291 table.insert(xs, recurse_type(child, visit))
2292 end
2293 end
2294 if ast.typevals then
2295 for _, child in ipairs(ast.typevals) do
2296 table.insert(xs, recurse_type(child, visit))
2297 end
2298 end
2299 if ast.ktype then
2300 table.insert(xs, recurse_type(ast.ktype, visit))
2301 end
2302 if ast.vtype then
2303 table.insert(xs, recurse_type(ast.vtype, visit))
2304 end
2305
2306 return visit_after(ast, ast.typename, visit, xs)
2307end
2308
2309local function recurse_node(ast,
2310visit_node,
2311visit_type)
2312 if not ast then
2313
2314 return
2315 end
2316
2317 visit_before(ast, ast.kind, visit_node)
2318 local xs = {}
2319 local cbs = visit_node.cbs[ast.kind]
2320 if ast.kind == "statements" or
2321 ast.kind == "variables" or
2322 ast.kind == "values" or
2323 ast.kind == "argument_list" or
2324 ast.kind == "expression_list" or
2325 ast.kind == "table_literal" then
2326 for i, child in ipairs(ast) do
2327 xs[i] = recurse_node(child, visit_node, visit_type)
2328 end
2329 elseif ast.kind == "local_declaration" or
2330 ast.kind == "global_declaration" or
2331 ast.kind == "assignment" then
2332 xs[1] = recurse_node(ast.vars, visit_node, visit_type)
2333 if ast.exps then
2334 xs[2] = recurse_node(ast.exps, visit_node, visit_type)
2335 end
2336 if ast.decltype then
2337 xs[3] = recurse_type(ast.decltype, visit_type)
2338 end
2339 elseif ast.kind == "local_type" or
2340 ast.kind == "global_type" then
2341 xs[1] = recurse_node(ast.var, visit_node, visit_type)
2342 xs[2] = recurse_node(ast.value, visit_node, visit_type)
2343 elseif ast.kind == "table_item" then
2344 xs[1] = recurse_node(ast.key, visit_node, visit_type)
2345 xs[2] = recurse_node(ast.value, visit_node, visit_type)
2346 elseif ast.kind == "if" then
2347 xs[1] = recurse_node(ast.exp, visit_node, visit_type)
2348 if cbs.before_statements then
2349 cbs.before_statements(ast, xs)
2350 end
2351 xs[2] = recurse_node(ast.thenpart, visit_node, visit_type)
2352 for i, e in ipairs(ast.elseifs) do
2353 table.insert(xs, recurse_node(e, visit_node, visit_type))
2354 end
2355 if ast.elsepart then
2356 table.insert(xs, recurse_node(ast.elsepart, visit_node, visit_type))
2357 end
2358 elseif ast.kind == "while" then
2359 xs[1] = recurse_node(ast.exp, visit_node, visit_type)
2360 if cbs.before_statements then
2361 cbs.before_statements(ast, xs)
2362 end
2363 xs[2] = recurse_node(ast.body, visit_node, visit_type)
2364 elseif ast.kind == "repeat" then
2365 xs[1] = recurse_node(ast.body, visit_node, visit_type)
2366 xs[2] = recurse_node(ast.exp, visit_node, visit_type)
2367 elseif ast.kind == "function" then
2368 xs[1] = recurse_node(ast.args, visit_node, visit_type)
2369 xs[2] = recurse_type(ast.rets, visit_type)
2370 xs[3] = recurse_node(ast.body, visit_node, visit_type)
2371 elseif ast.kind == "forin" then
2372 xs[1] = recurse_node(ast.vars, visit_node, visit_type)
2373 xs[2] = recurse_node(ast.exps, visit_node, visit_type)
2374 if cbs.before_statements then
2375 cbs.before_statements(ast)
2376 end
2377 xs[3] = recurse_node(ast.body, visit_node, visit_type)
2378 elseif ast.kind == "fornum" then
2379 xs[1] = recurse_node(ast.var, visit_node, visit_type)
2380 xs[2] = recurse_node(ast.from, visit_node, visit_type)
2381 xs[3] = recurse_node(ast.to, visit_node, visit_type)
2382 xs[4] = ast.step and recurse_node(ast.step, visit_node, visit_type)
2383 xs[5] = recurse_node(ast.body, visit_node, visit_type)
2384 elseif ast.kind == "elseif" then
2385 xs[1] = recurse_node(ast.exp, visit_node, visit_type)
2386 if cbs.before_statements then
2387 cbs.before_statements(ast, xs)
2388 end
2389 xs[2] = recurse_node(ast.thenpart, visit_node, visit_type)
2390 elseif ast.kind == "else" then
2391 xs[1] = recurse_node(ast.elsepart, visit_node, visit_type)
2392 elseif ast.kind == "return" then
2393 xs[1] = recurse_node(ast.exps, visit_node, visit_type)
2394 elseif ast.kind == "do" then
2395 xs[1] = recurse_node(ast.body, visit_node, visit_type)
2396 elseif ast.kind == "cast" then
2397 elseif ast.kind == "local_function" or
2398 ast.kind == "global_function" then
2399 xs[1] = recurse_node(ast.name, visit_node, visit_type)
2400 xs[2] = recurse_node(ast.args, visit_node, visit_type)
2401 xs[3] = recurse_type(ast.rets, visit_type)
2402 xs[4] = recurse_node(ast.body, visit_node, visit_type)
2403 elseif ast.kind == "record_function" then
2404 xs[1] = recurse_node(ast.fn_owner, visit_node, visit_type)
2405 xs[2] = recurse_node(ast.name, visit_node, visit_type)
2406 xs[3] = recurse_node(ast.args, visit_node, visit_type)
2407 xs[4] = recurse_type(ast.rets, visit_type)
2408 if cbs.before_statements then
2409 cbs.before_statements(ast, xs)
2410 end
2411 xs[5] = recurse_node(ast.body, visit_node, visit_type)
2412 elseif ast.kind == "paren" then
2413 xs[1] = recurse_node(ast.e1, visit_node, visit_type)
2414 elseif ast.kind == "op" then
2415 xs[1] = recurse_node(ast.e1, visit_node, visit_type)
2416 local p1 = ast.e1.op and ast.e1.op.prec or nil
2417 if ast.op.op == ":" and ast.e1.kind == "string" then
2418 p1 = -999
2419 end
2420 xs[2] = p1
2421 if ast.op.arity == 2 then
2422 if cbs.before_e2 then
2423 cbs.before_e2(ast, xs)
2424 end
2425 if ast.op.op == "is" or ast.op.op == "as" then
2426 xs[3] = recurse_type(ast.e2.casttype, visit_type)
2427 else
2428 xs[3] = recurse_node(ast.e2, visit_node, visit_type)
2429 end
2430 xs[4] = (ast.e2.op and ast.e2.op.prec)
2431 end
2432 elseif ast.kind == "newtype" then
2433 xs[1] = recurse_type(ast.newtype, visit_type)
2434 elseif ast.kind == "variable" or
2435 ast.kind == "argument" or
2436 ast.kind == "identifier" or
2437 ast.kind == "string" or
2438 ast.kind == "number" or
2439 ast.kind == "break" or
2440 ast.kind == "goto" or
2441 ast.kind == "label" or
2442 ast.kind == "nil" or
2443 ast.kind == "..." or
2444 ast.kind == "boolean" then
2445 if ast.decltype then
2446 xs[1] = recurse_type(ast.decltype, visit_type)
2447 end
2448 else
2449 if not ast.kind then
2450 error("wat: " .. inspect(ast))
2451 end
2452 error("unknown node kind " .. ast.kind)
2453 end
2454 return visit_after(ast, ast.kind, visit_node, xs)
2455end
2456
2457
2458
2459
2460
2461local tight_op = {
2462 [1] = {
2463 ["-"] = true,
2464 ["~"] = true,
2465 ["#"] = true,
2466 },
2467 [2] = {
2468 ["."] = true,
2469 [":"] = true,
2470 },
2471}
2472
2473local spaced_op = {
2474 [1] = {
2475 ["not"] = true,
2476 },
2477 [2] = {
2478 ["or"] = true,
2479 ["and"] = true,
2480 ["<"] = true,
2481 [">"] = true,
2482 ["<="] = true,
2483 [">="] = true,
2484 ["~="] = true,
2485 ["=="] = true,
2486 ["|"] = true,
2487 ["~"] = true,
2488 ["&"] = true,
2489 ["<<"] = true,
2490 [">>"] = true,
2491 [".."] = true,
2492 ["+"] = true,
2493 ["-"] = true,
2494 ["*"] = true,
2495 ["/"] = true,
2496 ["//"] = true,
2497 ["%"] = true,
2498 ["^"] = true,
2499 },
2500}
2501
2502local PrettyPrintOpts = {}
2503
2504
2505
2506
2507local default_pretty_print_ast_opts = {
2508 preserve_indent = true,
2509 preserve_newlines = true,
2510}
2511
2512local fast_pretty_print_ast_opts = {
2513 preserve_indent = false,
2514 preserve_newlines = true,
2515}
2516
2517function tl.pretty_print_ast(ast, mode)
2518 local indent = 0
2519
2520 local opts
2521 if type(mode) == "table" then
2522 opts = mode
2523 elseif mode == true then
2524 opts = fast_pretty_print_ast_opts
2525 else
2526 opts = default_pretty_print_ast_opts
2527 end
2528
2529 local Output = {}
2530
2531
2532
2533
2534
2535 local function increment_indent()
2536 indent = indent + 1
2537 end
2538
2539 if not opts.preserve_indent then
2540 increment_indent = nil
2541 end
2542
2543 local function add(out, s)
2544 table.insert(out, s)
2545 end
2546
2547 local function add_string(out, s)
2548 table.insert(out, s)
2549 if string.find(s, "\n", 1, true) then
2550 for nl in s:gmatch("\n") do
2551 out.h = out.h + 1
2552 end
2553 end
2554 end
2555
2556 local function add_child(out, child, space, indent)
2557 if #child == 0 then
2558 return
2559 end
2560
2561 if child.y < out.y then
2562 out.y = child.y
2563 end
2564
2565 if child.y > out.y + out.h and opts.preserve_newlines then
2566 local delta = child.y - (out.y + out.h)
2567 out.h = out.h + delta
2568 table.insert(out, ("\n"):rep(delta))
2569 else
2570 if space then
2571 table.insert(out, space)
2572 indent = nil
2573 end
2574 end
2575 if indent and opts.preserve_indent then
2576 table.insert(out, (" "):rep(indent))
2577 end
2578 table.insert(out, child)
2579 out.h = out.h + child.h
2580 end
2581
2582 local function concat_output(out)
2583 for i, s in ipairs(out) do
2584 if type(s) == "table" then
2585 out[i] = concat_output(s)
2586 end
2587 end
2588 return table.concat(out)
2589 end
2590
2591 local function print_record_def(typ)
2592 local out = { "{" }
2593 for name, field in pairs(typ.fields) do
2594 if field.typename == "typetype" and is_record_type(field.def) then
2595 table.insert(out, name)
2596 table.insert(out, " = ")
2597 table.insert(out, print_record_def(field.def))
2598 table.insert(out, ", ")
2599 end
2600 end
2601 table.insert(out, "}")
2602 return table.concat(out)
2603 end
2604
2605 local visit_node = {}
2606
2607 visit_node.cbs = {
2608 ["statements"] = {
2609 after = function(node, children)
2610 local out = { y = node.y, h = 0 }
2611 local space
2612 for i, child in ipairs(children) do
2613 add_child(out, children[i], space, indent)
2614 space = "; "
2615 end
2616 return out
2617 end,
2618 },
2619 ["local_declaration"] = {
2620 after = function(node, children)
2621 local out = { y = node.y, h = 0 }
2622 table.insert(out, "local")
2623 add_child(out, children[1], " ")
2624 if children[2] then
2625 table.insert(out, " =")
2626 add_child(out, children[2], " ")
2627 end
2628 return out
2629 end,
2630 },
2631 ["local_type"] = {
2632 after = function(node, children)
2633 local out = { y = node.y, h = 0 }
2634 table.insert(out, "local")
2635 add_child(out, children[1], " ")
2636 table.insert(out, " =")
2637 add_child(out, children[2], " ")
2638 return out
2639 end,
2640 },
2641 ["global_type"] = {
2642 after = function(node, children)
2643 local out = { y = node.y, h = 0 }
2644 add_child(out, children[1], " ")
2645 table.insert(out, " =")
2646 add_child(out, children[2], " ")
2647 return out
2648 end,
2649 },
2650 ["global_declaration"] = {
2651 after = function(node, children)
2652 local out = { y = node.y, h = 0 }
2653 if children[2] then
2654 add_child(out, children[1])
2655 table.insert(out, " =")
2656 add_child(out, children[2], " ")
2657 end
2658 return out
2659 end,
2660 },
2661 ["assignment"] = {
2662 after = function(node, children)
2663 local out = { y = node.y, h = 0 }
2664 add_child(out, children[1])
2665 table.insert(out, " =")
2666 add_child(out, children[2], " ")
2667 return out
2668 end,
2669 },
2670 ["if"] = {
2671 before = increment_indent,
2672 after = function(node, children)
2673 local out = { y = node.y, h = 0 }
2674 table.insert(out, "if")
2675 add_child(out, children[1], " ")
2676 table.insert(out, " then")
2677 add_child(out, children[2], " ")
2678 indent = indent - 1
2679 for i = 3, #children do
2680 add_child(out, children[i], " ", indent)
2681 end
2682 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
2683 return out
2684 end,
2685 },
2686 ["while"] = {
2687 before = increment_indent,
2688 after = function(node, children)
2689 local out = { y = node.y, h = 0 }
2690 table.insert(out, "while")
2691 add_child(out, children[1], " ")
2692 table.insert(out, " do")
2693 add_child(out, children[2], " ")
2694 indent = indent - 1
2695 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
2696 return out
2697 end,
2698 },
2699 ["repeat"] = {
2700 before = increment_indent,
2701 after = function(node, children)
2702 local out = { y = node.y, h = 0 }
2703 table.insert(out, "repeat")
2704 add_child(out, children[1], " ")
2705 if opts.preserve_indent then
2706 indent = indent - 1
2707 end
2708 add_child(out, { y = node.yend, h = 0, [1] = "until " }, " ", indent)
2709 add_child(out, children[2])
2710 return out
2711 end,
2712 },
2713 ["do"] = {
2714 before = increment_indent,
2715 after = function(node, children)
2716 local out = { y = node.y, h = 0 }
2717 table.insert(out, "do")
2718 add_child(out, children[1], " ")
2719 indent = indent - 1
2720 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
2721 return out
2722 end,
2723 },
2724 ["forin"] = {
2725 before = increment_indent,
2726 after = function(node, children)
2727 local out = { y = node.y, h = 0 }
2728 table.insert(out, "for")
2729 add_child(out, children[1], " ")
2730 table.insert(out, " in")
2731 add_child(out, children[2], " ")
2732 table.insert(out, " do")
2733 add_child(out, children[3], " ")
2734 indent = indent - 1
2735 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
2736 return out
2737 end,
2738 },
2739 ["fornum"] = {
2740 before = increment_indent,
2741 after = function(node, children)
2742 local out = { y = node.y, h = 0 }
2743 table.insert(out, "for")
2744 add_child(out, children[1], " ")
2745 table.insert(out, " =")
2746 add_child(out, children[2], " ")
2747 table.insert(out, ",")
2748 add_child(out, children[3], " ")
2749 if children[4] then
2750 table.insert(out, ",")
2751 add_child(out, children[4], " ")
2752 end
2753 table.insert(out, " do")
2754 add_child(out, children[5], " ")
2755 indent = indent - 1
2756 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
2757 return out
2758 end,
2759 },
2760 ["return"] = {
2761 after = function(node, children)
2762 local out = { y = node.y, h = 0 }
2763 table.insert(out, "return")
2764 if #children[1] > 0 then
2765 add_child(out, children[1], " ")
2766 end
2767 return out
2768 end,
2769 },
2770 ["break"] = {
2771 after = function(node, children)
2772 local out = { y = node.y, h = 0 }
2773 table.insert(out, "break")
2774 return out
2775 end,
2776 },
2777 ["elseif"] = {
2778 after = function(node, children)
2779 local out = { y = node.y, h = 0 }
2780 table.insert(out, "elseif")
2781 add_child(out, children[1], " ")
2782 table.insert(out, " then")
2783 add_child(out, children[2], " ")
2784 return out
2785 end,
2786 },
2787 ["else"] = {
2788 after = function(node, children)
2789 local out = { y = node.y, h = 0 }
2790 table.insert(out, "else")
2791 add_child(out, children[1], " ")
2792 return out
2793 end,
2794 },
2795 ["variables"] = {
2796 after = function(node, children)
2797 local out = { y = node.y, h = 0 }
2798 local space
2799 for i, child in ipairs(children) do
2800 if i > 1 then
2801 table.insert(out, ",")
2802 space = " "
2803 end
2804 add_child(out, child, space)
2805 end
2806 return out
2807 end,
2808 },
2809 ["table_literal"] = {
2810 before = increment_indent,
2811 after = function(node, children)
2812 local out = { y = node.y, h = 0 }
2813 if #children == 0 then
2814 indent = indent - 1
2815 table.insert(out, "{}")
2816 return out
2817 end
2818 table.insert(out, "{")
2819 local n = #children
2820 for i, child in ipairs(children) do
2821 add_child(out, child, " ", child.y ~= node.y and indent)
2822 if i < n or node.yend ~= node.y then
2823 table.insert(out, ",")
2824 end
2825 end
2826 indent = indent - 1
2827 add_child(out, { y = node.yend, h = 0, [1] = "}" }, " ", indent)
2828 return out
2829 end,
2830 },
2831 ["table_item"] = {
2832 after = function(node, children)
2833 local out = { y = node.y, h = 0 }
2834 if node.key_parsed ~= "implicit" then
2835 if node.key_parsed == "short" then
2836 children[1][1] = children[1][1]:sub(2, -2)
2837 add_child(out, children[1])
2838 table.insert(out, " = ")
2839 else
2840 table.insert(out, "[")
2841 add_child(out, children[1])
2842 table.insert(out, "] = ")
2843 end
2844 end
2845 add_child(out, children[2])
2846 return out
2847 end,
2848 },
2849 ["local_function"] = {
2850 before = increment_indent,
2851 after = function(node, children)
2852 local out = { y = node.y, h = 0 }
2853 table.insert(out, "local function")
2854 add_child(out, children[1], " ")
2855 table.insert(out, "(")
2856 add_child(out, children[2])
2857 table.insert(out, ")")
2858 add_child(out, children[4], " ")
2859 indent = indent - 1
2860 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
2861 return out
2862 end,
2863 },
2864 ["global_function"] = {
2865 before = increment_indent,
2866 after = function(node, children)
2867 local out = { y = node.y, h = 0 }
2868 table.insert(out, "function")
2869 add_child(out, children[1], " ")
2870 table.insert(out, "(")
2871 add_child(out, children[2])
2872 table.insert(out, ")")
2873 add_child(out, children[4], " ")
2874 indent = indent - 1
2875 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
2876 return out
2877 end,
2878 },
2879 ["record_function"] = {
2880 before = increment_indent,
2881 after = function(node, children)
2882 local out = { y = node.y, h = 0 }
2883 table.insert(out, "function")
2884 add_child(out, children[1], " ")
2885 table.insert(out, node.is_method and ":" or ".")
2886 add_child(out, children[2])
2887 table.insert(out, "(")
2888 if node.is_method then
2889
2890 table.remove(children[3], 1)
2891 if children[3][1] == "," then
2892 table.remove(children[3], 1)
2893 table.remove(children[3], 1)
2894 end
2895 end
2896 add_child(out, children[3])
2897 table.insert(out, ")")
2898 add_child(out, children[5], " ")
2899 indent = indent - 1
2900 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
2901 return out
2902 end,
2903 },
2904 ["function"] = {
2905 before = increment_indent,
2906 after = function(node, children)
2907 local out = { y = node.y, h = 0 }
2908 table.insert(out, "function(")
2909 add_child(out, children[1])
2910 table.insert(out, ")")
2911 add_child(out, children[3], " ")
2912 indent = indent - 1
2913 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
2914 return out
2915 end,
2916 },
2917 ["cast"] = {},
2918
2919 ["paren"] = {
2920 after = function(node, children)
2921 local out = { y = node.y, h = 0 }
2922 table.insert(out, "(")
2923 add_child(out, children[1], "", indent)
2924 table.insert(out, ")")
2925 return out
2926 end,
2927 },
2928 ["op"] = {
2929 after = function(node, children)
2930 local out = { y = node.y, h = 0 }
2931 if node.op.op == "@funcall" then
2932 add_child(out, children[1], "", indent)
2933 table.insert(out, "(")
2934 add_child(out, children[3], "", indent)
2935 table.insert(out, ")")
2936 elseif node.op.op == "@index" then
2937 add_child(out, children[1], "", indent)
2938 table.insert(out, "[")
2939 add_child(out, children[3], "", indent)
2940 table.insert(out, "]")
2941 elseif node.op.op == "as" then
2942 add_child(out, children[1], "", indent)
2943 elseif node.op.op == "is" then
2944 table.insert(out, "type(")
2945 add_child(out, children[1], "", indent)
2946 table.insert(out, ") == \"")
2947 add_child(out, children[3], "", indent)
2948 table.insert(out, "\"")
2949 elseif spaced_op[node.op.arity][node.op.op] or tight_op[node.op.arity][node.op.op] then
2950 local space = spaced_op[node.op.arity][node.op.op] and " " or ""
2951 if children[2] and node.op.prec > tonumber(children[2]) then
2952 table.insert(children[1], 1, "(")
2953 table.insert(children[1], ")")
2954 end
2955 if node.op.arity == 1 then
2956 table.insert(out, node.op.op)
2957 add_child(out, children[1], space, indent)
2958 elseif node.op.arity == 2 then
2959 add_child(out, children[1], "", indent)
2960 if space == " " then
2961 table.insert(out, " ")
2962 end
2963 table.insert(out, node.op.op)
2964 if children[4] and node.op.prec > tonumber(children[4]) then
2965 table.insert(children[3], 1, "(")
2966 table.insert(children[3], ")")
2967 end
2968 add_child(out, children[3], space, indent)
2969 end
2970 else
2971 error("unknown node op " .. node.op.op)
2972 end
2973 return out
2974 end,
2975 },
2976 ["variable"] = {
2977 after = function(node, children)
2978 local out = { y = node.y, h = 0 }
2979 add_string(out, node.tk)
2980 return out
2981 end,
2982 },
2983 ["newtype"] = {
2984 after = function(node, children)
2985 local out = { y = node.y, h = 0 }
2986 if is_record_type(node.newtype.def) then
2987 table.insert(out, print_record_def(node.newtype.def))
2988 else
2989 table.insert(out, "{}")
2990 end
2991 return out
2992 end,
2993 },
2994 ["goto"] = {
2995 after = function(node, children)
2996 local out = { y = node.y, h = 0 }
2997 table.insert(out, "goto ")
2998 table.insert(out, node.label)
2999 return out
3000 end,
3001 },
3002 ["label"] = {
3003 after = function(node, children)
3004 local out = { y = node.y, h = 0 }
3005 table.insert(out, "::")
3006 table.insert(out, node.label)
3007 table.insert(out, "::")
3008 return out
3009 end,
3010 },
3011 }
3012
3013 local primitive = {
3014 ["function"] = "function",
3015 ["enum"] = "string",
3016 ["boolean"] = "boolean",
3017 ["string"] = "string",
3018 ["nil"] = "nil",
3019 ["number"] = "number",
3020 ["thread"] = "thread",
3021 }
3022
3023 local visit_type = {}
3024 visit_type.cbs = {
3025 ["string"] = {
3026 after = function(typ, children)
3027 local out = { y = typ.y, h = 0 }
3028 table.insert(out, primitive[typ.typename] or "table")
3029 return out
3030 end,
3031 },
3032 }
3033 visit_type.cbs["typetype"] = visit_type.cbs["string"]
3034 visit_type.cbs["typevar"] = visit_type.cbs["string"]
3035 visit_type.cbs["typearg"] = visit_type.cbs["string"]
3036 visit_type.cbs["function"] = visit_type.cbs["string"]
3037 visit_type.cbs["thread"] = visit_type.cbs["string"]
3038 visit_type.cbs["array"] = visit_type.cbs["string"]
3039 visit_type.cbs["map"] = visit_type.cbs["string"]
3040 visit_type.cbs["arrayrecord"] = visit_type.cbs["string"]
3041 visit_type.cbs["record"] = visit_type.cbs["string"]
3042 visit_type.cbs["enum"] = visit_type.cbs["string"]
3043 visit_type.cbs["boolean"] = visit_type.cbs["string"]
3044 visit_type.cbs["nil"] = visit_type.cbs["string"]
3045 visit_type.cbs["number"] = visit_type.cbs["string"]
3046 visit_type.cbs["union"] = visit_type.cbs["string"]
3047 visit_type.cbs["nominal"] = visit_type.cbs["string"]
3048 visit_type.cbs["bad_nominal"] = visit_type.cbs["string"]
3049 visit_type.cbs["emptytable"] = visit_type.cbs["string"]
3050 visit_type.cbs["table_item"] = visit_type.cbs["string"]
3051 visit_type.cbs["unknown_emptytable_value"] = visit_type.cbs["string"]
3052 visit_type.cbs["tuple"] = visit_type.cbs["string"]
3053 visit_type.cbs["poly"] = visit_type.cbs["string"]
3054 visit_type.cbs["any"] = visit_type.cbs["string"]
3055 visit_type.cbs["unknown"] = visit_type.cbs["string"]
3056 visit_type.cbs["invalid"] = visit_type.cbs["string"]
3057 visit_type.cbs["unresolved"] = visit_type.cbs["string"]
3058 visit_type.cbs["none"] = visit_type.cbs["string"]
3059
3060 visit_node.cbs["values"] = visit_node.cbs["variables"]
3061 visit_node.cbs["expression_list"] = visit_node.cbs["variables"]
3062 visit_node.cbs["argument_list"] = visit_node.cbs["variables"]
3063 visit_node.cbs["identifier"] = visit_node.cbs["variable"]
3064 visit_node.cbs["string"] = visit_node.cbs["variable"]
3065 visit_node.cbs["number"] = visit_node.cbs["variable"]
3066 visit_node.cbs["nil"] = visit_node.cbs["variable"]
3067 visit_node.cbs["boolean"] = visit_node.cbs["variable"]
3068 visit_node.cbs["..."] = visit_node.cbs["variable"]
3069 visit_node.cbs["argument"] = visit_node.cbs["variable"]
3070
3071 local out = recurse_node(ast, visit_node, visit_type)
3072 local code
3073 if opts.preserve_newlines then
3074 code = { y = 1, h = 0 }
3075 add_child(code, out)
3076 else
3077 code = out
3078 end
3079 return concat_output(code)
3080end
3081
3082
3083
3084
3085
3086local ANY = a_type({ typename = "any" })
3087local NONE = a_type({ typename = "none" })
3088
3089local NIL = a_type({ typename = "nil" })
3090local NUMBER = a_type({ typename = "number" })
3091local STRING = a_type({ typename = "string" })
3092local OPT_NUMBER = a_type({ typename = "number" })
3093local OPT_STRING = a_type({ typename = "string" })
3094local VARARG_ANY = a_type({ typename = "any", is_va = true })
3095local VARARG_STRING = a_type({ typename = "string", is_va = true })
3096local VARARG_NUMBER = a_type({ typename = "number", is_va = true })
3097local VARARG_UNKNOWN = a_type({ typename = "unknown", is_va = true })
3098local VARARG_ALPHA = a_type({ typename = "typevar", typevar = "@a", is_va = true })
3099local BOOLEAN = a_type({ typename = "boolean" })
3100local ARG_ALPHA = a_type({ typename = "typearg", typearg = "@a" })
3101local ARG_BETA = a_type({ typename = "typearg", typearg = "@b" })
3102local ALPHA = a_type({ typename = "typevar", typevar = "@a" })
3103local BETA = a_type({ typename = "typevar", typevar = "@b" })
3104local ARRAY_OF_STRING = a_type({ typename = "array", elements = STRING })
3105local ARRAY_OF_ALPHA = a_type({ typename = "array", elements = ALPHA })
3106local MAP_OF_ALPHA_TO_BETA = a_type({ typename = "map", keys = ALPHA, values = BETA })
3107local TABLE = a_type({ typename = "map", keys = ANY, values = ANY })
3108local FUNCTION = a_type({ typename = "function", args = { a_type({ typename = "any", is_va = true }) }, rets = { a_type({ typename = "any", is_va = true }) } })
3109local THREAD = a_type({ typename = "thread" })
3110local INVALID = a_type({ typename = "invalid" })
3111local UNKNOWN = a_type({ typename = "unknown" })
3112local NOMINAL_FILE = a_type({ typename = "nominal", names = { "FILE" } })
3113local NOMINAL_METATABLE = a_type({ typename = "nominal", names = { "METATABLE" } })
3114
3115local OS_DATE_TABLE = a_type({
3116 typename = "record",
3117 fields = {
3118 ["year"] = NUMBER,
3119 ["month"] = NUMBER,
3120 ["day"] = NUMBER,
3121 ["hour"] = NUMBER,
3122 ["min"] = NUMBER,
3123 ["sec"] = NUMBER,
3124 ["wday"] = NUMBER,
3125 ["yday"] = NUMBER,
3126 ["isdst"] = BOOLEAN,
3127 },
3128})
3129
3130local DEBUG_GETINFO_TABLE = a_type({
3131 typename = "record",
3132 fields = {
3133 ["name"] = STRING,
3134 ["namewhat"] = STRING,
3135 ["source"] = STRING,
3136 ["short_src"] = STRING,
3137 ["linedefined"] = NUMBER,
3138 ["lastlinedefined"] = NUMBER,
3139 ["what"] = STRING,
3140 ["currentline"] = NUMBER,
3141 ["istailcall"] = BOOLEAN,
3142 ["nups"] = NUMBER,
3143 ["nparams"] = NUMBER,
3144 ["isvararg"] = BOOLEAN,
3145 ["func"] = ANY,
3146 ["activelines"] = a_type({ typename = "map", keys = NUMBER, values = BOOLEAN }),
3147 },
3148})
3149
3150local numeric_binop = {
3151 ["number"] = {
3152 ["number"] = NUMBER,
3153 },
3154}
3155
3156local relational_binop = {
3157 ["number"] = {
3158 ["number"] = BOOLEAN,
3159 },
3160 ["string"] = {
3161 ["string"] = BOOLEAN,
3162 },
3163 ["boolean"] = {
3164 ["boolean"] = BOOLEAN,
3165 },
3166}
3167
3168local equality_binop = {
3169 ["number"] = {
3170 ["number"] = BOOLEAN,
3171 ["nil"] = BOOLEAN,
3172 },
3173 ["string"] = {
3174 ["string"] = BOOLEAN,
3175 ["nil"] = BOOLEAN,
3176 },
3177 ["boolean"] = {
3178 ["boolean"] = BOOLEAN,
3179 ["nil"] = BOOLEAN,
3180 },
3181 ["record"] = {
3182 ["emptytable"] = BOOLEAN,
3183 ["arrayrecord"] = BOOLEAN,
3184 ["record"] = BOOLEAN,
3185 ["nil"] = BOOLEAN,
3186 },
3187 ["array"] = {
3188 ["emptytable"] = BOOLEAN,
3189 ["arrayrecord"] = BOOLEAN,
3190 ["array"] = BOOLEAN,
3191 ["nil"] = BOOLEAN,
3192 },
3193 ["arrayrecord"] = {
3194 ["emptytable"] = BOOLEAN,
3195 ["arrayrecord"] = BOOLEAN,
3196 ["record"] = BOOLEAN,
3197 ["array"] = BOOLEAN,
3198 ["nil"] = BOOLEAN,
3199 },
3200 ["map"] = {
3201 ["emptytable"] = BOOLEAN,
3202 ["map"] = BOOLEAN,
3203 ["nil"] = BOOLEAN,
3204 },
3205 ["thread"] = {
3206 ["thread"] = BOOLEAN,
3207 ["nil"] = BOOLEAN,
3208 },
3209}
3210
3211local unop_types = {
3212 ["#"] = {
3213 ["arrayrecord"] = NUMBER,
3214 ["string"] = NUMBER,
3215 ["array"] = NUMBER,
3216 ["map"] = NUMBER,
3217 ["emptytable"] = NUMBER,
3218 },
3219 ["-"] = {
3220 ["number"] = NUMBER,
3221 },
3222 ["not"] = {
3223 ["string"] = BOOLEAN,
3224 ["number"] = BOOLEAN,
3225 ["boolean"] = BOOLEAN,
3226 ["record"] = BOOLEAN,
3227 ["arrayrecord"] = BOOLEAN,
3228 ["array"] = BOOLEAN,
3229 ["map"] = BOOLEAN,
3230 ["emptytable"] = BOOLEAN,
3231 ["thread"] = BOOLEAN,
3232 },
3233}
3234
3235local binop_types = {
3236 ["+"] = numeric_binop,
3237 ["-"] = {
3238 ["number"] = {
3239 ["number"] = NUMBER,
3240 },
3241 },
3242 ["*"] = numeric_binop,
3243 ["%"] = numeric_binop,
3244 ["/"] = numeric_binop,
3245 ["^"] = numeric_binop,
3246 ["&"] = numeric_binop,
3247 ["|"] = numeric_binop,
3248 ["<<"] = numeric_binop,
3249 [">>"] = numeric_binop,
3250 ["=="] = equality_binop,
3251 ["~="] = equality_binop,
3252 ["<="] = relational_binop,
3253 [">="] = relational_binop,
3254 ["<"] = relational_binop,
3255 [">"] = relational_binop,
3256 ["or"] = {
3257 ["boolean"] = {
3258 ["boolean"] = BOOLEAN,
3259 ["function"] = FUNCTION,
3260 },
3261 ["number"] = {
3262 ["number"] = NUMBER,
3263 ["boolean"] = BOOLEAN,
3264 },
3265 ["string"] = {
3266 ["string"] = STRING,
3267 ["boolean"] = BOOLEAN,
3268 ["enum"] = STRING,
3269 },
3270 ["function"] = {
3271 ["function"] = FUNCTION,
3272 ["boolean"] = BOOLEAN,
3273 },
3274 ["array"] = {
3275 ["boolean"] = BOOLEAN,
3276 },
3277 ["record"] = {
3278 ["boolean"] = BOOLEAN,
3279 },
3280 ["arrayrecord"] = {
3281 ["boolean"] = BOOLEAN,
3282 },
3283 ["map"] = {
3284 ["boolean"] = BOOLEAN,
3285 },
3286 ["enum"] = {
3287 ["string"] = STRING,
3288 },
3289 ["thread"] = {
3290 ["boolean"] = BOOLEAN,
3291 },
3292 },
3293 [".."] = {
3294 ["string"] = {
3295 ["string"] = STRING,
3296 ["enum"] = STRING,
3297 ["number"] = STRING,
3298 },
3299 ["number"] = {
3300 ["number"] = STRING,
3301 ["string"] = STRING,
3302 ["enum"] = STRING,
3303 },
3304 ["enum"] = {
3305 ["number"] = STRING,
3306 ["string"] = STRING,
3307 ["enum"] = STRING,
3308 },
3309 },
3310}
3311
3312local show_type
3313
3314local function is_unknown(t)
3315 return t.typename == "unknown" or
3316 t.typename == "unknown_emptytable_value"
3317end
3318
3319local show_type
3320
3321local function show_type_base(t, seen)
3322
3323 if seen[t] then
3324 return "..."
3325 end
3326 seen[t] = true
3327
3328 local function show(t)
3329 return show_type(t, seen)
3330 end
3331
3332 if t.typename == "nominal" then
3333 if t.typevals then
3334 local out = { table.concat(t.names, "."), "<" }
3335 local vals = {}
3336 for _, v in ipairs(t.typevals) do
3337 table.insert(vals, show(v))
3338 end
3339 table.insert(out, table.concat(vals, ", "))
3340 table.insert(out, ">")
3341 return table.concat(out)
3342 else
3343 return table.concat(t.names, ".")
3344 end
3345 elseif t.typename == "tuple" then
3346 local out = {}
3347 for _, v in ipairs(t) do
3348 table.insert(out, show(v))
3349 end
3350 return "(" .. table.concat(out, ", ") .. ")"
3351 elseif t.typename == "poly" then
3352 local out = {}
3353 for _, v in ipairs(t.types) do
3354 table.insert(out, show(v))
3355 end
3356 return table.concat(out, " or ")
3357 elseif t.typename == "union" then
3358 local out = {}
3359 for _, v in ipairs(t.types) do
3360 table.insert(out, show(v))
3361 end
3362 return table.concat(out, " | ")
3363 elseif t.typename == "emptytable" then
3364 return "{}"
3365 elseif t.typename == "map" then
3366 return "{" .. show(t.keys) .. " : " .. show(t.values) .. "}"
3367 elseif t.typename == "array" then
3368 return "{" .. show(t.elements) .. "}"
3369 elseif t.typename == "enum" then
3370 return t.names and table.concat(t.names, ".") or "enum"
3371 elseif is_record_type(t) then
3372 local out = {}
3373 for _, k in ipairs(t.field_order) do
3374 local v = t.fields[k]
3375 table.insert(out, k .. ": " .. show(v))
3376 end
3377 return "{" .. table.concat(out, ", ") .. "}"
3378 elseif t.typename == "function" then
3379 local out = {}
3380 table.insert(out, "function(")
3381 local args = {}
3382 if t.is_method then
3383 table.insert(args, "self")
3384 end
3385 for i, v in ipairs(t.args) do
3386 if not t.is_method or i > 1 then
3387 table.insert(args, show(v))
3388 end
3389 end
3390 table.insert(out, table.concat(args, ","))
3391 table.insert(out, ")")
3392 if #t.rets > 0 then
3393 table.insert(out, ":")
3394 local rets = {}
3395 for _, v in ipairs(t.rets) do
3396 table.insert(rets, show(v))
3397 end
3398 table.insert(out, table.concat(rets, ","))
3399 end
3400 return table.concat(out)
3401 elseif t.typename == "number" or
3402 t.typename == "boolean" or
3403 t.typename == "thread" then
3404 return t.typename
3405 elseif t.typename == "string" then
3406 return t.typename ..
3407 (t.tk and " " .. t.tk or "")
3408 elseif t.typename == "typevar" then
3409 return t.typevar
3410 elseif t.typename == "typearg" then
3411 return t.typearg
3412 elseif is_unknown(t) then
3413 return "<unknown type>"
3414 elseif t.typename == "invalid" then
3415 return "<invalid type>"
3416 elseif t.typename == "any" then
3417 return "<any type>"
3418 elseif t.typename == "nil" then
3419 return "nil"
3420 elseif t.typename == "typetype" then
3421 return "type " .. show(t.def)
3422 elseif t.typename == "bad_nominal" then
3423 return table.concat(t.names, ".") .. " (an unknown type)"
3424 else
3425 return inspect(t)
3426 end
3427end
3428
3429show_type = function(t, seen)
3430 local ret = show_type_base(t, seen or {})
3431 if t.inferred_at then
3432 ret = ret .. " (inferred at " .. t.inferred_at_file .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ": )"
3433 end
3434 return ret
3435end
3436
3437local Error = {}
3438
3439
3440
3441
3442
3443
3444local Result = {}
3445
3446
3447
3448
3449
3450
3451
3452
3453local function search_for(module_name, suffix, path, tried)
3454 for entry in path:gmatch("[^;]+") do
3455 local slash_name = module_name:gsub("%.", "/")
3456 local filename = entry:gsub("?", slash_name)
3457 local tl_filename = filename:gsub("%.lua$", suffix)
3458 local fd = io.open(tl_filename, "r")
3459 if fd then
3460 return tl_filename, fd, tried
3461 end
3462 table.insert(tried, "no file '" .. tl_filename .. "'")
3463 end
3464 return nil, nil, tried
3465end
3466
3467function tl.search_module(module_name, search_dtl)
3468 local found
3469 local tried = {}
3470 local path = os.getenv("TL_PATH") or package.path
3471 if search_dtl then
3472 local found, fd, tried = search_for(module_name, ".d.tl", path, tried)
3473 if found then
3474 return found, fd
3475 end
3476 end
3477 local found, fd, tried = search_for(module_name, ".tl", path, tried)
3478 if found then
3479 return found, fd
3480 end
3481 local found, fd, tried = search_for(module_name, ".lua", path, tried)
3482 if found then
3483 return found, fd
3484 end
3485 return nil, nil, tried
3486end
3487
3488local Variable = {}
3489
3490
3491
3492
3493
3494
3495
3496local function fill_field_order(t)
3497 if t.typename == "record" then
3498 t.field_order = {}
3499 for k, v in pairs(t.fields) do
3500 table.insert(t.field_order, k)
3501 end
3502 table.sort(t.field_order)
3503 end
3504end
3505
3506local function require_module(module_name, lax, env, result)
3507 local modules = env.modules
3508
3509 if modules[module_name] then
3510 return modules[module_name], true
3511 end
3512 modules[module_name] = UNKNOWN
3513
3514 local found, fd, tried = tl.search_module(module_name, true)
3515 if found and (lax or found:match("tl$")) then
3516 fd:close()
3517 local _result, err = tl.process(found, env, result)
3518 assert(_result, err)
3519
3520 if not _result.type then
3521 _result.type = BOOLEAN
3522 end
3523
3524 modules[module_name] = _result.type
3525
3526 return _result.type, true
3527 end
3528
3529 return UNKNOWN, found ~= nil
3530end
3531
3532local standard_library = {
3533 ["..."] = a_type({ typename = "tuple", STRING, STRING, STRING, STRING, STRING }),
3534 ["@return"] = a_type({ typename = "tuple", ANY }),
3535 ["any"] = a_type({ typename = "typetype", def = ANY }),
3536 ["arg"] = ARRAY_OF_STRING,
3537 ["assert"] = a_type({
3538 typename = "poly",
3539 types = {
3540 a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ALPHA }, rets = { ALPHA } }),
3541 a_type({ typename = "function", typeargs = { ARG_ALPHA, ARG_BETA }, args = { ALPHA, BETA }, rets = { ALPHA } }),
3542 },
3543 }),
3544 ["collectgarbage"] = a_type({ typename = "function", args = { STRING }, rets = { a_type({ typename = "union", types = { BOOLEAN, NUMBER } }), NUMBER, NUMBER } }),
3545 ["dofile"] = a_type({ typename = "function", args = { OPT_STRING }, rets = { VARARG_ANY } }),
3546 ["error"] = a_type({ typename = "function", args = { STRING, NUMBER }, rets = {} }),
3547 ["getmetatable"] = a_type({ typename = "function", args = { ANY }, rets = { NOMINAL_METATABLE } }),
3548 ["ipairs"] = a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA }, rets = {
3549 a_type({ typename = "function", args = {}, rets = { NUMBER, ALPHA } }),
3550 }, }),
3551 ["load"] = a_type({
3552 typename = "poly",
3553 types = {
3554 a_type({ typename = "function", args = { STRING }, rets = { FUNCTION, STRING } }),
3555 a_type({ typename = "function", args = { STRING, STRING }, rets = { FUNCTION, STRING } }),
3556 a_type({ typename = "function", args = { STRING, STRING, STRING }, rets = { FUNCTION, STRING } }),
3557 a_type({ typename = "function", args = { STRING, STRING, STRING, TABLE }, rets = { FUNCTION, STRING } }),
3558 },
3559 }),
3560 ["loadfile"] = a_type({
3561 typename = "poly",
3562 types = {
3563 a_type({ typename = "function", args = {}, rets = { FUNCTION, ANY } }),
3564 a_type({ typename = "function", args = { STRING }, rets = { FUNCTION, ANY } }),
3565 a_type({ typename = "function", args = { STRING, STRING }, rets = { FUNCTION, ANY } }),
3566 a_type({ typename = "function", args = { STRING, STRING, TABLE }, rets = { FUNCTION, ANY } }),
3567 },
3568 }),
3569 ["next"] = a_type({
3570 typename = "poly",
3571 types = {
3572 a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { MAP_OF_ALPHA_TO_BETA }, rets = { ALPHA, BETA } }),
3573 a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { MAP_OF_ALPHA_TO_BETA, ALPHA }, rets = { ALPHA, BETA } }),
3574 a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ARRAY_OF_ALPHA }, rets = { NUMBER, ALPHA } }),
3575 a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ARRAY_OF_ALPHA, ALPHA }, rets = { NUMBER, ALPHA } }),
3576 },
3577 }),
3578 ["pairs"] = a_type({ typename = "function", typeargs = { ARG_ALPHA, ARG_BETA }, args = { a_type({ typename = "map", keys = ALPHA, values = BETA }) }, rets = {
3579 a_type({ typename = "function", args = {}, rets = { ALPHA, BETA } }),
3580 }, }),
3581 ["pcall"] = a_type({ typename = "function", args = { FUNCTION, VARARG_ANY }, rets = { BOOLEAN, ANY } }),
3582 ["xpcall"] = a_type({ typename = "function", args = { FUNCTION, FUNCTION, VARARG_ANY }, rets = { BOOLEAN, ANY } }),
3583 ["print"] = a_type({ typename = "function", args = { VARARG_ANY }, rets = {} }),
3584 ["rawequal"] = a_type({ typename = "function", args = { ANY, ANY }, rets = { BOOLEAN } }),
3585 ["rawget"] = a_type({ typename = "function", args = { TABLE, ANY }, rets = { ANY } }),
3586 ["rawlen"] = a_type({
3587 typename = "poly",
3588 types = {
3589 a_type({ typename = "function", args = { TABLE }, rets = { NUMBER } }),
3590 a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }),
3591 },
3592 }),
3593 ["rawset"] = a_type({
3594 typename = "poly",
3595 types = {
3596 a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { MAP_OF_ALPHA_TO_BETA, ALPHA, BETA }, rets = {} }),
3597 a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ARRAY_OF_ALPHA, NUMBER, ALPHA }, rets = {} }),
3598 a_type({ typename = "function", args = { TABLE, ANY, ANY }, rets = {} }),
3599 },
3600 }),
3601 ["require"] = a_type({ typename = "function", args = { STRING }, rets = {} }),
3602 ["select"] = a_type({
3603 typename = "poly",
3604 types = {
3605 a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { NUMBER, VARARG_ALPHA }, rets = { ALPHA } }),
3606 a_type({ typename = "function", args = { NUMBER, VARARG_ANY }, rets = { ANY } }),
3607 a_type({ typename = "function", args = { STRING, VARARG_ANY }, rets = { NUMBER } }),
3608 },
3609 }),
3610 ["setmetatable"] = a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ALPHA, NOMINAL_METATABLE }, rets = { ALPHA } }),
3611 ["tonumber"] = a_type({ typename = "function", args = { ANY, NUMBER }, rets = { NUMBER } }),
3612 ["tostring"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
3613 ["type"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
3614 ["FILE"] = a_type({
3615 typename = "typetype",
3616 def = a_type({
3617 typename = "record",
3618 fields = {
3619 ["close"] = a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING } }),
3620 ["flush"] = a_type({ typename = "function", args = { NOMINAL_FILE }, rets = {} }),
3621 ["lines"] = a_type({ typename = "function", args = { NOMINAL_FILE, a_type({ typename = "union", types = { STRING, NUMBER }, is_va = true }) }, rets = {
3622 a_type({ typename = "function", args = {}, rets = { VARARG_STRING } }),
3623 }, }),
3624 ["read"] = a_type({
3625 typename = "poly",
3626 types = {
3627 a_type({ typename = "function", args = { NOMINAL_FILE, STRING }, rets = { STRING, STRING } }),
3628 a_type({ typename = "function", args = { NOMINAL_FILE, NUMBER }, rets = { STRING, STRING } }),
3629 },
3630 }),
3631 ["seek"] = a_type({
3632 typename = "poly",
3633 types = {
3634 a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { NUMBER, STRING } }),
3635 a_type({ typename = "function", args = { NOMINAL_FILE, STRING }, rets = { NUMBER, STRING } }),
3636 a_type({ typename = "function", args = { NOMINAL_FILE, STRING, NUMBER }, rets = { NUMBER, STRING } }),
3637 },
3638 }),
3639 ["setvbuf"] = a_type({ typename = "function", args = { NOMINAL_FILE, STRING, OPT_NUMBER }, rets = {} }),
3640 ["write"] = a_type({ typename = "function", args = { NOMINAL_FILE, VARARG_STRING }, rets = { NOMINAL_FILE, STRING } }),
3641
3642 },
3643 }),
3644 }),
3645 ["METATABLE"] = a_type({
3646 typename = "typetype",
3647 def = a_type({
3648 typename = "record",
3649 fields = {
3650 ["__call"] = FUNCTION,
3651 ["__gc"] = a_type({ typename = "function", args = { ANY }, rets = {} }),
3652 ["__index"] = ANY,
3653 ["__len"] = a_type({ typename = "function", args = { ANY }, rets = { NUMBER } }),
3654 ["__mode"] = a_type({ typename = "enum", enumset = { ["k"] = true, ["v"] = true, ["kv"] = true } }),
3655 ["__newindex"] = ANY,
3656 ["__pairs"] = a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { a_type({ typename = "map", keys = ALPHA, values = BETA }) }, rets = {
3657 a_type({ typename = "function", args = {}, rets = { ALPHA, BETA } }),
3658 }, }),
3659 ["__tostring"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
3660 ["__name"] = STRING,
3661
3662
3663 ["__add"] = FUNCTION,
3664 ["__sub"] = FUNCTION,
3665 ["__mul"] = FUNCTION,
3666 ["__div"] = FUNCTION,
3667 ["__idiv"] = FUNCTION,
3668 ["__mod"] = FUNCTION,
3669 ["__pow"] = FUNCTION,
3670 ["__unm"] = FUNCTION,
3671 ["__band"] = FUNCTION,
3672 ["__bor"] = FUNCTION,
3673 ["__bxor"] = FUNCTION,
3674 ["__bnot"] = FUNCTION,
3675 ["__shl"] = FUNCTION,
3676 ["__shr"] = FUNCTION,
3677 ["__concat"] = FUNCTION,
3678 ["__eq"] = FUNCTION,
3679 ["__lt"] = FUNCTION,
3680 ["__le"] = FUNCTION,
3681 },
3682 }),
3683 }),
3684 ["coroutine"] = a_type({
3685 typename = "record",
3686 fields = {
3687 ["create"] = a_type({ typename = "function", args = { FUNCTION }, rets = { THREAD } }),
3688 ["close"] = a_type({ typename = "function", args = { THREAD }, rets = { BOOLEAN, STRING } }),
3689 ["isyieldable"] = a_type({ typename = "function", args = {}, rets = { BOOLEAN } }),
3690 ["resume"] = a_type({ typename = "function", args = { THREAD, VARARG_ANY }, rets = { BOOLEAN, VARARG_ANY } }),
3691 ["running"] = a_type({ typename = "function", args = {}, rets = { THREAD, BOOLEAN } }),
3692 ["status"] = a_type({ typename = "function", args = { THREAD }, rets = { STRING } }),
3693 ["wrap"] = a_type({ typename = "function", args = { FUNCTION }, rets = { FUNCTION } }),
3694 ["yield"] = a_type({ typename = "function", args = { VARARG_ANY }, rets = { VARARG_ANY } }),
3695 },
3696 }),
3697 ["debug"] = a_type({
3698 typename = "record",
3699 fields = {
3700 ["traceback"] = a_type({
3701 typename = "poly",
3702 types = {
3703 a_type({ typename = "function", args = { THREAD, STRING, NUMBER }, rets = { STRING } }),
3704 a_type({ typename = "function", args = { STRING, NUMBER }, rets = { STRING } }),
3705 },
3706 }),
3707 ["getinfo"] = a_type({
3708 typename = "poly",
3709 types = {
3710 a_type({ typename = "function", args = { ANY }, rets = { DEBUG_GETINFO_TABLE } }),
3711 a_type({ typename = "function", args = { ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }),
3712 a_type({ typename = "function", args = { ANY, ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }),
3713 },
3714 }),
3715 },
3716 }),
3717 ["io"] = a_type({
3718 typename = "record",
3719 fields = {
3720 ["close"] = a_type({
3721 typename = "poly",
3722 types = {
3723 a_type({ typename = "function", args = {}, rets = { BOOLEAN, STRING } }),
3724 a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING } }),
3725 },
3726 }),
3727 ["flush"] = a_type({ typename = "function", args = {}, rets = {} }),
3728 ["input"] = a_type({
3729 typename = "poly",
3730 types = {
3731 a_type({ typename = "function", args = {}, rets = { NOMINAL_FILE } }),
3732 a_type({ typename = "function", args = { STRING }, rets = { NOMINAL_FILE } }),
3733 a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { NOMINAL_FILE } }),
3734 },
3735 }),
3736 ["lines"] = a_type({ typename = "function", args = { OPT_STRING, a_type({ typename = "union", types = { STRING, NUMBER }, is_va = true }) }, rets = {
3737 a_type({ typename = "function", args = {}, rets = { VARARG_STRING } }),
3738 }, }),
3739 ["open"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { NOMINAL_FILE, STRING } }),
3740 ["output"] = a_type({
3741 typename = "poly",
3742 types = {
3743 a_type({ typename = "function", args = {}, rets = { NOMINAL_FILE } }),
3744 a_type({ typename = "function", args = { STRING }, rets = { NOMINAL_FILE } }),
3745 a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { NOMINAL_FILE } }),
3746 },
3747 }),
3748 ["popen"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { NOMINAL_FILE, STRING } }),
3749 ["read"] = a_type({
3750 typename = "poly",
3751 types = {
3752 a_type({ typename = "function", args = { NOMINAL_FILE, STRING }, rets = { STRING, STRING } }),
3753 a_type({ typename = "function", args = { NOMINAL_FILE, NUMBER }, rets = { STRING, STRING } }),
3754 },
3755 }),
3756 ["stderr"] = NOMINAL_FILE,
3757 ["stdin"] = NOMINAL_FILE,
3758 ["stdout"] = NOMINAL_FILE,
3759 ["tmpfile"] = a_type({ typename = "function", args = {}, rets = { NOMINAL_FILE } }),
3760 ["type"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
3761 ["write"] = a_type({ typename = "function", args = { VARARG_STRING }, rets = { NOMINAL_FILE, STRING } }),
3762 },
3763 }),
3764 ["math"] = a_type({
3765 typename = "record",
3766 fields = {
3767 ["abs"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3768 ["acos"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3769 ["asin"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3770 ["atan"] = a_type({
3771 typename = "poly",
3772 a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3773 a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
3774 }),
3775 ["atan2"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
3776 ["ceil"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3777 ["cos"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3778 ["cosh"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3779 ["deg"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3780 ["exp"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3781 ["floor"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3782 ["fmod"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
3783 ["frexp"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER, NUMBER } }),
3784 ["huge"] = NUMBER,
3785 ["ldexp"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
3786 ["log"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3787 ["log10"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3788 ["max"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { NUMBER } }),
3789 ["maxinteger"] = NUMBER,
3790 ["min"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { NUMBER } }),
3791 ["mininteger"] = NUMBER,
3792 ["modf"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER, NUMBER } }),
3793 ["pi"] = NUMBER,
3794 ["pow"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
3795 ["rad"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3796 ["random"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
3797 ["randomseed"] = a_type({ typename = "function", args = { NUMBER }, rets = {} }),
3798 ["sin"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3799 ["sinh"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3800 ["sqrt"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3801 ["tan"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3802 ["tanh"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
3803 ["tointeger"] = a_type({ typename = "function", args = { ANY }, rets = { NUMBER } }),
3804 ["type"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
3805 ["ult"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { BOOLEAN } }),
3806 },
3807 }),
3808 ["os"] = a_type({
3809 typename = "record",
3810 fields = {
3811 ["clock"] = a_type({ typename = "function", args = {}, rets = { NUMBER } }),
3812 ["date"] = a_type({
3813 typename = "poly",
3814 types = {
3815 a_type({ typename = "function", args = {}, rets = { STRING } }),
3816 a_type({ typename = "function", args = { STRING, OPT_STRING }, rets = { a_type({ typename = "union", types = { STRING, OS_DATE_TABLE } }) } }),
3817 },
3818 }),
3819 ["difftime"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
3820 ["execute"] = a_type({ typename = "function", args = { STRING }, rets = { BOOLEAN, STRING, NUMBER } }),
3821 ["exit"] = a_type({
3822 typename = "poly",
3823 types = {
3824 a_type({ typename = "function", args = { NUMBER, BOOLEAN }, rets = {} }),
3825 a_type({ typename = "function", args = { BOOLEAN, BOOLEAN }, rets = {} }),
3826 },
3827 }),
3828 ["getenv"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }),
3829 ["remove"] = a_type({ typename = "function", args = { STRING }, rets = { BOOLEAN, STRING } }),
3830 ["rename"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { BOOLEAN, STRING } }),
3831 ["setlocale"] = a_type({ typename = "function", args = { STRING, OPT_STRING }, rets = { STRING } }),
3832 ["time"] = a_type({ typename = "function", args = {}, rets = { NUMBER } }),
3833 ["tmpname"] = a_type({ typename = "function", args = {}, rets = { STRING } }),
3834 },
3835 }),
3836 ["package"] = a_type({
3837 typename = "record",
3838 fields = {
3839 ["config"] = STRING,
3840 ["cpath"] = STRING,
3841 ["loaded"] = a_type({
3842 typename = "map",
3843 keys = STRING,
3844 values = ANY,
3845 }),
3846 ["loaders"] = a_type({
3847 typename = "array",
3848 elements = a_type({ typename = "function", args = { STRING }, rets = { ANY } }),
3849 }),
3850 ["loadlib"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { FUNCTION } }),
3851 ["path"] = STRING,
3852 ["preload"] = TABLE,
3853 ["searchers"] = a_type({
3854 typename = "array",
3855 elements = a_type({ typename = "function", args = { STRING }, rets = { ANY } }),
3856 }),
3857 ["searchpath"] = a_type({ typename = "function", args = { STRING, STRING, OPT_STRING, OPT_STRING }, rets = { STRING, STRING } }),
3858 },
3859 }),
3860 ["string"] = a_type({
3861 typename = "record",
3862 fields = {
3863 ["byte"] = a_type({
3864 typename = "poly",
3865 types = {
3866 a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }),
3867 a_type({ typename = "function", args = { STRING, NUMBER }, rets = { NUMBER } }),
3868 a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { VARARG_NUMBER } }),
3869 },
3870 }),
3871 ["char"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { STRING } }),
3872 ["dump"] = a_type({
3873 typename = "poly",
3874 types = {
3875 a_type({ typename = "function", args = { FUNCTION }, rets = { STRING } }),
3876 a_type({ typename = "function", args = { FUNCTION, BOOLEAN }, rets = { STRING } }),
3877 },
3878 }),
3879 ["find"] = a_type({
3880 typename = "poly",
3881 types = {
3882 a_type({ typename = "function", args = { STRING, STRING }, rets = { NUMBER, NUMBER, VARARG_STRING } }),
3883 a_type({ typename = "function", args = { STRING, STRING, NUMBER }, rets = { NUMBER, NUMBER, VARARG_STRING } }),
3884 a_type({ typename = "function", args = { STRING, STRING, NUMBER, BOOLEAN }, rets = { NUMBER, NUMBER, VARARG_STRING } }),
3885
3886 },
3887 }),
3888 ["format"] = a_type({ typename = "function", args = { STRING, VARARG_ANY }, rets = { STRING } }),
3889 ["gmatch"] = a_type({ typename = "function", args = { STRING, STRING }, rets = {
3890 a_type({ typename = "function", args = {}, rets = { STRING } }),
3891 }, }),
3892 ["gsub"] = a_type({
3893 typename = "poly",
3894 types = {
3895 a_type({ typename = "function", args = { STRING, STRING, STRING, NUMBER }, rets = { STRING, NUMBER } }),
3896 a_type({ typename = "function", args = { STRING, STRING, a_type({ typename = "map", keys = STRING, values = STRING }), NUMBER }, rets = { STRING, NUMBER } }),
3897 a_type({ typename = "function", args = { STRING, STRING, a_type({ typename = "function", args = { VARARG_STRING }, rets = { STRING } }) }, rets = { STRING, NUMBER } }),
3898
3899 },
3900 }),
3901 ["len"] = a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }),
3902 ["lower"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }),
3903 ["match"] = a_type({ typename = "function", args = { STRING, STRING, NUMBER }, rets = { VARARG_STRING } }),
3904 ["pack"] = a_type({ typename = "function", args = { STRING, VARARG_ANY }, rets = { STRING } }),
3905 ["packsize"] = a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }),
3906 ["rep"] = a_type({ typename = "function", args = { STRING, NUMBER }, rets = { STRING } }),
3907 ["reverse"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }),
3908 ["sub"] = a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { STRING } }),
3909 ["unpack"] = a_type({ typename = "function", args = { STRING, STRING, OPT_NUMBER }, rets = { VARARG_ANY } }),
3910 ["upper"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }),
3911 },
3912 }),
3913 ["table"] = a_type({
3914 typename = "record",
3915 fields = {
3916 ["concat"] = a_type({ typename = "function", args = { ARRAY_OF_STRING, OPT_STRING, OPT_NUMBER, OPT_NUMBER }, rets = { STRING } }),
3917 ["insert"] = a_type({
3918 typename = "poly",
3919 types = {
3920 a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, ALPHA }, rets = {} }),
3921 a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, ALPHA }, rets = {} }),
3922 },
3923 }),
3924 ["move"] = a_type({
3925 typename = "poly",
3926 types = {
3927 a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, NUMBER, NUMBER }, rets = { ARRAY_OF_ALPHA } }),
3928 a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, NUMBER, NUMBER, ARRAY_OF_ALPHA }, rets = { ARRAY_OF_ALPHA } }),
3929 },
3930 }),
3931 ["pack"] = a_type({ typename = "function", args = { VARARG_ANY }, rets = { TABLE } }),
3932 ["remove"] = a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, OPT_NUMBER }, rets = { ALPHA } }),
3933 ["sort"] = a_type({
3934 typename = "poly",
3935 types = {
3936 a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA }, rets = {} }),
3937 a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, a_type({ typename = "function", args = { ALPHA, ALPHA }, rets = { BOOLEAN } }) }, rets = {} }),
3938 },
3939 }),
3940 ["unpack"] = a_type({
3941 typename = "function",
3942 needs_compat53 = true,
3943 typeargs = { ARG_ALPHA },
3944 args = { ARRAY_OF_ALPHA, NUMBER, NUMBER },
3945 rets = { VARARG_ALPHA },
3946 }),
3947 },
3948 }),
3949 ["utf8"] = a_type({
3950 typename = "record",
3951 fields = {
3952 ["char"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { STRING } }),
3953 ["charpattern"] = STRING,
3954 ["codepoint"] = a_type({ typename = "function", args = { STRING, OPT_NUMBER, OPT_NUMBER }, rets = { VARARG_NUMBER } }),
3955 ["codes"] = a_type({ typename = "function", args = { STRING }, rets = {
3956 a_type({ typename = "function", args = {}, rets = { NUMBER, STRING } }),
3957 }, }),
3958 ["len"] = a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { NUMBER } }),
3959 ["offset"] = a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { NUMBER } }),
3960 },
3961 }),
3962}
3963
3964for _, t in pairs(standard_library) do
3965 fill_field_order(t)
3966 if t.typename == "typetype" then
3967 fill_field_order(t.def)
3968 end
3969end
3970fill_field_order(OS_DATE_TABLE)
3971fill_field_order(DEBUG_GETINFO_TABLE)
3972
3973NOMINAL_FILE.found = standard_library["FILE"]
3974NOMINAL_METATABLE.found = standard_library["METATABLE"]
3975
3976local compat53_code_cache = {}
3977
3978local function add_compat53_entries(program, used_set)
3979 if not next(used_set) then
3980 return
3981 end
3982
3983 local used_list = {}
3984 for name, _ in pairs(used_set) do
3985 table.insert(used_list, name)
3986 end
3987 table.sort(used_list)
3988
3989 local compat53_loaded = false
3990
3991 local n = 1
3992 local function load_code(name, text)
3993 local code = compat53_code_cache[name]
3994 if not code then
3995 local tokens = tl.lex(text)
3996 local _
3997 _, code = tl.parse_program(tokens, {}, "@internal")
3998 tl.type_check(code, { lax = false, skip_compat53 = true })
3999 code = code[1]
4000 compat53_code_cache[name] = code
4001 end
4002 table.insert(program, n, code)
4003 n = n + 1
4004 end
4005
4006 for i, name in ipairs(used_list) do
4007 local mod, fn = name:match("([^.]*)%.(.*)")
4008 local errs = {}
4009 local text
4010 local code = compat53_code_cache[name]
4011 if not code then
4012
4013 if name == "table.unpack" then
4014 load_code(name, "local _tl_table_unpack = unpack or table.unpack")
4015 else
4016 if not compat53_loaded then
4017 load_code("compat53", "local _tl_compat53 = ((tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3) and require('compat53.module')")
4018 compat53_loaded = true
4019 end
4020 load_code(name, (("local $NAME = _tl_compat53 and _tl_compat53.$NAME or $NAME"):gsub("$NAME", name)))
4021 end
4022 end
4023 end
4024 program.y = 1
4025end
4026
4027local function get_stdlib_compat53(lax)
4028 if lax then
4029 return {
4030 ["utf8"] = true,
4031 }
4032 else
4033 return {
4034 ["io"] = true,
4035 ["math"] = true,
4036 ["string"] = true,
4037 ["table"] = true,
4038 ["utf8"] = true,
4039 ["coroutine"] = true,
4040 ["os"] = true,
4041 ["package"] = true,
4042 ["debug"] = true,
4043 ["load"] = true,
4044 ["loadfile"] = true,
4045 ["assert"] = true,
4046 ["pairs"] = true,
4047 ["ipairs"] = true,
4048 ["pcall"] = true,
4049 ["xpcall"] = true,
4050 ["rawlen"] = true,
4051 }
4052 end
4053end
4054
4055local function init_globals(lax)
4056 local globals = {}
4057 local stdlib_compat53 = get_stdlib_compat53(lax)
4058
4059 for name, typ in pairs(standard_library) do
4060 globals[name] = { t = typ, needs_compat53 = stdlib_compat53[name], is_const = true }
4061 end
4062
4063
4064
4065
4066 globals["@is_va"] = { t = VARARG_ANY }
4067
4068 return globals
4069end
4070
4071function tl.init_env(lax, skip_compat53)
4072 local env = {
4073 modules = {},
4074 globals = init_globals(lax),
4075 skip_compat53 = skip_compat53,
4076 }
4077
4078
4079 for name, var in pairs(standard_library) do
4080 if var.typename == "record" then
4081 env.modules[name] = var
4082 end
4083 end
4084
4085 return env
4086end
4087
4088function tl.type_check(ast, opts)
4089 opts = opts or {}
4090 opts.env = opts.env or tl.init_env(opts.lax, opts.skip_compat53)
4091 local lax = opts.lax
4092 local filename = opts.filename
4093 local result = opts.result or {
4094 syntax_errors = {},
4095 type_errors = {},
4096 unknowns = {},
4097 }
4098
4099 local stdlib_compat53 = get_stdlib_compat53(lax)
4100
4101 local st = { opts.env.globals }
4102
4103 local all_needs_compat53 = {}
4104
4105 local errors = result.type_errors or {}
4106 local unknowns = result.unknowns or {}
4107 local module_type
4108
4109 local function find_var(name)
4110 if name == "_G" then
4111
4112 local globals = {}
4113 for k, v in pairs(st[1]) do
4114 if k:sub(1, 1) ~= "@" then
4115 globals[k] = v.t
4116 end
4117 end
4118 local field_order = {}
4119 for k, _ in pairs(globals) do
4120 table.insert(field_order, k)
4121 end
4122 return a_type({
4123 typename = "record",
4124 field_order = field_order,
4125 fields = globals,
4126 }), false
4127 end
4128 for i = #st, 1, -1 do
4129 local scope = st[i]
4130 if scope[name] then
4131 if i == 1 and scope[name].needs_compat53 then
4132 all_needs_compat53[name] = true
4133 end
4134 local typ = scope[name].t
4135
4136 return typ, scope[name].is_const
4137 end
4138 end
4139 end
4140
4141 local function resolve_typevars(t, seen)
4142 seen = seen or {}
4143 if seen[t] then
4144 return seen[t]
4145 end
4146
4147 local orig_t = t
4148 local clear_tk = false
4149 if t.typename == "typevar" then
4150 local tv = find_var(t.typevar)
4151 if tv then
4152 t = tv
4153 clear_tk = true
4154 else
4155 t = UNKNOWN
4156 end
4157 end
4158
4159 local copy = {}
4160 seen[orig_t] = copy
4161
4162 for k, v in pairs(t) do
4163 local cp = copy
4164 if type(v) == "table" then
4165 cp[k] = resolve_typevars(v, seen)
4166 else
4167 cp[k] = v
4168 end
4169 end
4170
4171 if clear_tk then
4172 copy.tk = nil
4173 end
4174
4175 return copy
4176 end
4177
4178 local function find_type(names, accept_typearg)
4179 local typ = find_var(names[1])
4180 if not typ then
4181 return nil
4182 end
4183 for i = 2, #names do
4184 local nested = typ.fields or (typ.def and typ.def.fields)
4185 if nested then
4186 typ = nested[names[i]]
4187 if typ == nil then
4188 return nil
4189 end
4190 else
4191 break
4192 end
4193 end
4194 if typ then
4195 if accept_typearg and typ.typename == "typearg" then
4196 return typ
4197 end
4198 if is_type(typ) then
4199 return typ
4200 end
4201 end
4202 return nil
4203 end
4204
4205 local function infer_var(emptytable, t, node)
4206 local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration")
4207 local nst = is_global and 1 or #st
4208 for i = nst, 1, -1 do
4209 local scope = st[i]
4210 if scope[emptytable.assigned_to] then
4211 scope[emptytable.assigned_to] = {
4212 t = t,
4213 is_const = false,
4214 }
4215 t.inferred_at = node
4216 t.inferred_at_file = filename
4217 end
4218 end
4219 end
4220
4221 local function find_global(name)
4222 local scope = st[1]
4223 if scope[name] then
4224 return scope[name].t, scope[name].is_const
4225 end
4226 end
4227
4228 local function resolve_tuple(t)
4229 if t.typename == "tuple" then
4230 t = t[1]
4231 end
4232 if t == nil then
4233 return NIL
4234 end
4235 return t
4236 end
4237
4238 local function error_in_type(where, msg, ...)
4239 local n = select("#", ...)
4240 if n > 0 then
4241 local showt = {}
4242 for i = 1, n do
4243 local t = select(i, ...)
4244 if t.typename == "invalid" then
4245 return nil
4246 end
4247 showt[i] = show_type(t)
4248 end
4249 msg = msg:format(_tl_table_unpack(showt))
4250 end
4251
4252 return {
4253 y = where.y,
4254 x = where.x,
4255 msg = msg,
4256 filename = where.filename or filename,
4257 }
4258 end
4259
4260 local function type_error(t, msg, ...)
4261 local e = error_in_type(t, msg, ...)
4262 if e then
4263 table.insert(errors, e)
4264 return true
4265 else
4266 return false
4267 end
4268 end
4269
4270 local function node_error(node, msg, ...)
4271 local ok = type_error(node, msg, ...)
4272 node.type = INVALID
4273 return node.type
4274 end
4275
4276 local function terr(t, s, ...)
4277 return { error_in_type(t, s, ...) }
4278 end
4279
4280 local function add_unknown(node, name)
4281 table.insert(unknowns, { y = node.y, x = node.x, msg = name, filename = filename })
4282 end
4283
4284 local function add_var(node, var, valtype, is_const, is_narrowing)
4285 if lax and node and is_unknown(valtype) and (var ~= "self" and var ~= "...") then
4286 add_unknown(node, var)
4287 end
4288 if st[#st][var] and is_narrowing then
4289 if not st[#st][var].is_narrowed then
4290 st[#st][var].narrowed_from = st[#st][var].t
4291 end
4292 st[#st][var].is_narrowed = true
4293 st[#st][var].t = valtype
4294 else
4295 st[#st][var] = { t = valtype, is_const = is_const, is_narrowed = is_narrowing }
4296 end
4297 end
4298
4299 local CompareTypes = {}
4300
4301 local function compare_typevars(t1, t2, comp)
4302 local tv1 = find_var(t1.typevar)
4303 local tv2 = find_var(t2.typevar)
4304 if t1.typevar == t2.typevar then
4305 local has_t1 = not not tv1
4306 local has_t2 = not not tv2
4307 if has_t1 == has_t2 then
4308 return true
4309 end
4310 end
4311 local function cmp(k, v, a, b)
4312 if find_var(k) then
4313 return comp(a, b)
4314 else
4315 add_var(nil, k, resolve_typevars(v))
4316 return true
4317 end
4318 end
4319 if t2.typename == "typevar" then
4320 return cmp(t2.typevar, t1, t1, tv2)
4321 else
4322 return cmp(t1.typevar, t2, tv1, t2)
4323 end
4324 end
4325
4326 local function add_errs_prefixing(src, dst, prefix, node)
4327 if not src then
4328 return
4329 end
4330 for i, err in ipairs(src) do
4331 err.msg = prefix .. err.msg
4332
4333
4334 if node and node.y and (
4335 (err.filename ~= filename) or
4336 (not err.y) or
4337 (node.y > err.y or (node.y == err.y and node.x > err.x))) then
4338
4339 err.y = node.y
4340 err.x = node.x
4341 err.filename = filename
4342 end
4343
4344 table.insert(dst, err)
4345 end
4346 end
4347
4348 local is_a
4349
4350 local TypeGetter = {}
4351
4352 local function match_record_fields(t1, t2, cmp)
4353 cmp = cmp or is_a
4354 local fielderrs = {}
4355 for _, k in ipairs(t1.field_order) do
4356 local f = t1.fields[k]
4357 local t2k = t2(k)
4358 if t2k == nil then
4359 if not lax then
4360 table.insert(fielderrs, error_in_type(f, "unknown field " .. k))
4361 end
4362 else
4363 local match, errs = is_a(f, t2k)
4364 add_errs_prefixing(errs, fielderrs, "record field doesn't match: " .. k .. ": ")
4365 end
4366 end
4367 if #fielderrs > 0 then
4368 return false, fielderrs
4369 end
4370 return true
4371 end
4372
4373 local function match_fields_to_record(t1, t2, cmp)
4374 return match_record_fields(t1, function(k) return t2.fields[k] end, cmp)
4375 end
4376
4377 local function match_fields_to_map(t1, t2)
4378 if not match_record_fields(t1, function(_) return t2.values end) then
4379 return false, { error_in_type(t1, "not all fields have type %s", t2.values) }
4380 end
4381 return true
4382 end
4383
4384 local function arg_check(cmp, a, b, at, n, errs)
4385 local matches, match_errs = cmp(a, b)
4386 if not matches then
4387 add_errs_prefixing(match_errs, errs, "argument " .. n .. ": ", at)
4388 return false
4389 end
4390 return true
4391 end
4392
4393 local same_type
4394
4395 local function has_all_types_of(t1s, t2s)
4396 for _, t1 in ipairs(t1s) do
4397 local found = false
4398 for _, t2 in ipairs(t2s) do
4399 if is_a(t2, t1) then
4400 found = true
4401 break
4402 end
4403 end
4404 if not found then
4405 return false
4406 end
4407 end
4408 return true
4409 end
4410
4411 local function any_errors(all_errs)
4412 if #all_errs == 0 then
4413 return true
4414 else
4415 return false, all_errs
4416 end
4417 end
4418
4419 local function are_same_nominals(t1, t2)
4420 local same_names
4421 if t1.found and t2.found then
4422 same_names = t1.found.typeid == t2.found.typeid
4423 else
4424 local ft1 = t1.found or find_type(t1.names)
4425 local ft2 = t2.found or find_type(t2.names)
4426 if ft1 and ft2 then
4427 same_names = ft1.typeid == ft2.typeid
4428 else
4429 if not ft1 then
4430 type_error(t1, "unknown type %s", t1)
4431 end
4432 if not ft2 then
4433 type_error(t2, "unknown type %s", t2)
4434 end
4435 return false, {}
4436 end
4437 end
4438
4439 if same_names then
4440 if t1.typevals == nil and t2.typevals == nil then
4441 return true
4442 elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then
4443 local all_errs = {}
4444 for i = 1, #t1.typevals do
4445 local ok, errs = same_type(t2.typevals[i], t1.typevals[i])
4446 add_errs_prefixing(errs, all_errs, "type parameter <" .. show_type(t1.typevals[i]) .. ">: ", t1)
4447 end
4448 if #all_errs == 0 then
4449 return true
4450 else
4451 return false, all_errs
4452 end
4453 end
4454 else
4455 return false, terr(t1, "%s is not a %s", t1, t2)
4456 end
4457 end
4458
4459 same_type = function(t1, t2)
4460 assert(type(t1) == "table")
4461 assert(type(t2) == "table")
4462
4463 if t1.typename == "typevar" or t2.typename == "typevar" then
4464 return compare_typevars(t1, t2, same_type)
4465 end
4466
4467 if t1.typename ~= t2.typename then
4468 return false, terr(t1, "got %s, expected %s", t1, t2)
4469 end
4470 if t1.typename == "array" then
4471 return same_type(t1.elements, t2.elements)
4472 elseif t1.typename == "map" then
4473 local all_errs = {}
4474 local k_ok, k_errs = same_type(t1.keys, t2.keys)
4475 if not k_ok then
4476 add_errs_prefixing(k_errs, all_errs, "keys", t1)
4477 end
4478 local v_ok, v_errs = same_type(t1.values, t2.values)
4479 if not v_ok then
4480 add_errs_prefixing(v_errs, all_errs, "values", t1)
4481 end
4482 return any_errors(all_errs)
4483 elseif t1.typename == "union" then
4484 if has_all_types_of(t1.types, t2.types) and
4485 has_all_types_of(t2.types, t1.types) then
4486 return true
4487 else
4488 return false, terr(t1, "got %s, expected %s", t1, t2)
4489 end
4490 elseif t1.typename == "nominal" then
4491 return are_same_nominals(t1, t2)
4492 elseif t1.typename == "record" then
4493 return match_fields_to_record(t1, t2, same_type)
4494 elseif t1.typename == "function" then
4495 if #t1.args ~= #t2.args then
4496 return false, terr(t1, "different number of input arguments: got " .. #t1.args .. ", expected " .. #t2.args)
4497 end
4498 if #t1.rets ~= #t2.rets then
4499 return false, terr(t1, "different number of return values: got " .. #t1.args .. ", expected " .. #t2.args)
4500 end
4501 local all_errs = {}
4502 for i = 1, #t1.args do
4503 arg_check(same_type, t1.args[i], t2.args[i], t1, i, all_errs)
4504 end
4505 for i = 1, #t1.rets do
4506 local ok, errs = same_type(t1.rets[i], t2.rets[i])
4507 add_errs_prefixing(errs, all_errs, "return " .. i, t1)
4508 end
4509 return any_errors(all_errs)
4510 elseif t1.typename == "arrayrecord" then
4511 local ok, errs = same_type(t1.elements, t2.elements)
4512 if not ok then
4513 return ok, errs
4514 end
4515 return match_fields_to_record(t1, t2, same_type)
4516 end
4517 return true
4518 end
4519
4520 local function a_union(types)
4521 local ts = {}
4522 local stack = {}
4523 local i = 1
4524 while types[i] or stack[1] do
4525 local t
4526 if stack[1] then
4527 t = table.remove(stack)
4528 else
4529 t = types[i]
4530 i = i + 1
4531 end
4532 if t.typename == "union" then
4533 for _, s in ipairs(t.types) do
4534 table.insert(stack, s)
4535 end
4536 else
4537 table.insert(ts, t)
4538 end
4539 end
4540 return a_type({
4541 typename = "union",
4542 types = ts,
4543 })
4544 end
4545
4546 local function is_vararg(t)
4547 return t.args and #t.args > 0 and t.args[#t.args].is_va
4548 end
4549
4550 local function combine_errs(...)
4551 local errs
4552 for i = 1, select("#", ...) do
4553 local e = select(i, ...)
4554 if e then
4555 errs = errs or {}
4556 for _, err in ipairs(e) do
4557 table.insert(errs, err)
4558 end
4559 end
4560 end
4561 if not errs then
4562 return true
4563 else
4564 return false, errs
4565 end
4566 end
4567
4568 local resolve_unary = nil
4569
4570 local function is_known_table_type(t)
4571 return (t.typename == "array" or t.typename == "map" or t.typename == "record" or t.typename == "arrayrecord")
4572 end
4573
4574 is_a = function(t1, t2, for_equality)
4575 assert(type(t1) == "table")
4576 assert(type(t2) == "table")
4577
4578 if lax and (is_unknown(t1) or is_unknown(t2)) then
4579 return true
4580 end
4581
4582 if t1.typename == "nil" then
4583 return true
4584 end
4585
4586 if t2.typename ~= "tuple" then
4587 t1 = resolve_tuple(t1)
4588 end
4589 if t2.typename == "tuple" and t1.typename ~= "tuple" then
4590 t1 = a_type({
4591 typename = "tuple",
4592 [1] = t1,
4593 })
4594 end
4595
4596 if t1.typename == "typevar" or t2.typename == "typevar" then
4597 return compare_typevars(t1, t2, is_a)
4598 end
4599
4600 if t2.typename == "any" then
4601 return true
4602 elseif t2.typename == "poly" then
4603 for _, t in ipairs(t2.types) do
4604 if is_a(t1, t, for_equality) then
4605 return true
4606 end
4607 end
4608 return false, terr(t1, "cannot match against any alternatives of the polymorphic type")
4609 elseif t1.typename == "union" and t2.typename == "union" then
4610 if has_all_types_of(t1.types, t2.types) then
4611 return true
4612 else
4613 return false, terr(t1, "got %s, expected %s", t1, t2)
4614 end
4615 elseif t2.typename == "union" then
4616 for _, t in ipairs(t2.types) do
4617 if is_a(t1, t, for_equality) then
4618 return true
4619 end
4620 end
4621 elseif t1.typename == "poly" then
4622 for _, t in ipairs(t1.types) do
4623 if is_a(t, t2, for_equality) then
4624 return true
4625 end
4626 end
4627 return false, terr(t1, "cannot match against any alternatives of the polymorphic type")
4628 elseif t1.typename == "nominal" and t2.typename == "nominal" and #t2.names == 1 and t2.names[1] == "any" then
4629 return true
4630 elseif t1.typename == "nominal" and t2.typename == "nominal" then
4631 return are_same_nominals(t1, t2)
4632 elseif t1.typename == "enum" and t2.typename == "string" then
4633 local ok
4634 if for_equality then
4635 ok = t2.tk and t1.enumset[unquote(t2.tk)]
4636 else
4637 ok = true
4638 end
4639 if ok then
4640 return true
4641 else
4642 return false, terr(t1, "enum is incompatible with %s", t2)
4643 end
4644 elseif t1.typename == "string" and t2.typename == "enum" then
4645 local ok = t1.tk and t2.enumset[unquote(t1.tk)]
4646 if ok then
4647 return true
4648 else
4649 if t1.tk then
4650 return false, terr(t1, "%s is not a member of %s", t1, t2)
4651 else
4652 return false, terr(t1, "string is not a %s", t2)
4653 end
4654 end
4655 elseif t1.typename == "nominal" or t2.typename == "nominal" then
4656 local t1u = resolve_unary(t1)
4657 local t2u = resolve_unary(t2)
4658 local ok, errs = is_a(t1u, t2u, for_equality)
4659 if errs and #errs == 1 then
4660 if errs[1].msg:match("^got ") then
4661
4662
4663 errs = terr(t1, "got %s, expected %s", t1, t2)
4664 end
4665 end
4666 return ok, errs
4667 elseif t1.typename == "emptytable" and is_known_table_type(t2) then
4668 return true
4669 elseif t2.typename == "array" then
4670 if is_array_type(t1) then
4671 if is_a(t1.elements, t2.elements) then
4672 return true
4673 end
4674 elseif t1.typename == "map" then
4675 local _, errs_keys = is_a(t1.keys, NUMBER)
4676 local _, errs_values = is_a(t1.values, t2.elements)
4677 return combine_errs(errs_keys, errs_values)
4678 end
4679 elseif t2.typename == "record" then
4680 if is_record_type(t1) then
4681 return match_fields_to_record(t1, t2)
4682 elseif t1.typename == "typetype" and t1.def.typename == "record" then
4683 return is_a(t1.def, t2, for_equality)
4684 end
4685 elseif t2.typename == "arrayrecord" then
4686 if t1.typename == "array" then
4687 return is_a(t1.elements, t2.elements)
4688 elseif t1.typename == "record" then
4689 return match_fields_to_record(t1, t2)
4690 elseif t1.typename == "arrayrecord" then
4691 if not is_a(t1.elements, t2.elements) then
4692 return false, terr(t1, "array parts have incompatible element types")
4693 end
4694 return match_fields_to_record(t1, t2)
4695 end
4696 elseif t2.typename == "map" then
4697 if t1.typename == "map" then
4698 local _, errs_keys = is_a(t1.keys, t2.keys)
4699 local _, errs_values = is_a(t2.values, t1.values)
4700 if t2.values.typename == "any" then
4701 errs_values = {}
4702 end
4703 return combine_errs(errs_keys, errs_values)
4704 elseif t1.typename == "array" then
4705 local _, errs_keys = is_a(NUMBER, t2.keys)
4706 local _, errs_values = is_a(t1.elements, t2.values)
4707 return combine_errs(errs_keys, errs_values)
4708 elseif is_record_type(t1) then
4709 if not is_a(t2.keys, STRING) then
4710 return false, terr(t1, "can't match a record to a map with non-string keys")
4711 end
4712 if t2.keys.typename == "enum" then
4713 for _, k in ipairs(t1.field_order) do
4714 if not t2.keys.enumset[k] then
4715 return false, terr(t1, "key is not an enum value: " .. k)
4716 end
4717 end
4718 end
4719 return match_fields_to_map(t1, t2)
4720 end
4721 elseif t1.typename == "function" and t2.typename == "function" then
4722 local all_errs = {}
4723 if (not is_vararg(t2)) and #t1.args > #t2.args then
4724 t1.args.typename = "tuple"
4725 t2.args.typename = "tuple"
4726 table.insert(all_errs, error_in_type(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args))
4727 else
4728 for i = (t1.is_method and 2 or 1), #t1.args do
4729 arg_check(is_a, t1.args[i], t2.args[i] or ANY, nil, i, all_errs)
4730 end
4731 end
4732 local diff_by_va = #t2.rets - #t1.rets == 1 and t2.rets[#t2.rets].is_va
4733 if #t1.rets < #t2.rets and not diff_by_va then
4734 t1.rets.typename = "tuple"
4735 t2.rets.typename = "tuple"
4736 table.insert(all_errs, error_in_type(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets))
4737 else
4738 local nrets = #t2.rets
4739 if diff_by_va then
4740 nrets = nrets - 1
4741 end
4742 for i = 1, nrets do
4743 local ok, errs = is_a(t1.rets[i], t2.rets[i])
4744 add_errs_prefixing(errs, all_errs, "return " .. i .. ": ")
4745 end
4746 end
4747 if #all_errs == 0 then
4748 return true
4749 else
4750 return false, all_errs
4751 end
4752 elseif lax and ((not for_equality) and t2.typename == "boolean") then
4753
4754 return true
4755 elseif t1.typename == t2.typename then
4756 return true
4757 end
4758
4759 return false, terr(t1, "got %s, expected %s", t1, t2)
4760 end
4761
4762 local function assert_is_a(node, t1, t2, context, name)
4763 t1 = resolve_tuple(t1)
4764 t2 = resolve_tuple(t2)
4765 if lax and (is_unknown(t1) or is_unknown(t2)) then
4766 return
4767 end
4768
4769 if t2.typename == "unknown_emptytable_value" then
4770 if same_type(t2.emptytable_type.keys, NUMBER) then
4771 infer_var(t2.emptytable_type, a_type({ typename = "array", elements = t1 }), node)
4772 else
4773 infer_var(t2.emptytable_type, a_type({ typename = "map", keys = t2.emptytable_type.keys, values = t1 }), node)
4774 end
4775 return
4776 elseif t2.typename == "emptytable" then
4777 if is_known_table_type(t1) then
4778 infer_var(t2, t1, node)
4779 elseif t1.typename ~= "emptytable" then
4780 node_error(node, "in " .. context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1)
4781 end
4782 return
4783 end
4784
4785 local match, match_errs = is_a(t1, t2)
4786 add_errs_prefixing(match_errs, errors, "in " .. context .. ": " .. (name and (name .. ": ") or ""), node)
4787 end
4788
4789 local function close_types(vars)
4790 for name, var in pairs(vars) do
4791 if var.t.typename == "typetype" then
4792 var.t.closed = true
4793 end
4794 end
4795 end
4796
4797 local function begin_scope()
4798 table.insert(st, {})
4799 end
4800
4801 local function end_scope()
4802 local unresolved = st[#st]["@unresolved"]
4803 if unresolved then
4804 local upper = st[#st - 1]["@unresolved"]
4805 if upper then
4806 for name, nodes in pairs(unresolved.t.labels) do
4807 for _, node in ipairs(nodes) do
4808 upper.t.labels[name] = upper.t.labels[name] or {}
4809 table.insert(upper.t.labels[name], node)
4810 end
4811 end
4812 for name, types in pairs(unresolved.t.nominals) do
4813 for _, typ in ipairs(types) do
4814 upper.t.nominals[name] = upper.t.nominals[name] or {}
4815 table.insert(upper.t.nominals[name], typ)
4816 end
4817 end
4818 else
4819 st[#st - 1]["@unresolved"] = unresolved
4820 end
4821 end
4822 close_types(st[#st])
4823 table.remove(st)
4824 end
4825
4826 local type_check_function_call
4827 do
4828 local function try_match_func_args(node, f, args, is_method, argdelta)
4829 local ok = true
4830 local errs = {}
4831
4832 if is_method then
4833 argdelta = -1
4834 elseif not argdelta then
4835 argdelta = 0
4836 end
4837
4838 if f.is_method and not is_method and not (args[1] and is_a(args[1], f.args[1])) then
4839 table.insert(errs, { y = node.y, x = node.x, msg = "invoked method as a regular function: use ':' instead of '.'", filename = filename })
4840 return nil, errs
4841 end
4842
4843 local va = is_vararg(f)
4844 local nargs = va and
4845 math.max(#args, #f.args) or
4846 math.min(#args, #f.args)
4847
4848 for a = 1, nargs do
4849 local arg = args[a]
4850 local farg = f.args[a] or (va and f.args[#f.args])
4851 if arg == nil then
4852 if farg.is_va then
4853 break
4854 end
4855 else
4856 local at = node.e2 and node.e2[a] or node
4857 if not arg_check(is_a, arg, farg, at, (a + argdelta), errs) then
4858 ok = false
4859 break
4860 end
4861 end
4862 end
4863 if ok == true then
4864 f.rets.typename = "tuple"
4865
4866
4867 for a = 1, #args do
4868 local arg = args[a]
4869 local farg = f.args[a] or (va and f.args[#f.args])
4870 if arg.typename == "emptytable" then
4871 infer_var(arg, resolve_typevars(farg), node.e2[a])
4872 end
4873 end
4874
4875 return resolve_typevars(f.rets)
4876 end
4877 return nil, errs
4878 end
4879
4880 local function revert_typeargs(func)
4881 if func.typeargs then
4882 for _, arg in ipairs(func.typeargs) do
4883 if st[#st][arg.typearg] then
4884 st[#st][arg.typearg] = nil
4885 end
4886 end
4887 end
4888 end
4889
4890 local function remove_sorted_duplicates(t)
4891 local prev = nil
4892 for i = #t, 1, -1 do
4893 if t[i] == prev then
4894 table.remove(t, i)
4895 else
4896 prev = t[i]
4897 end
4898 end
4899 end
4900
4901 local function check_call(node, func, args, is_method, argdelta)
4902 assert(type(func) == "table")
4903 assert(type(args) == "table")
4904
4905 if lax and is_unknown(func) then
4906 func = a_type({ typename = "function", args = { VARARG_UNKNOWN }, rets = { VARARG_UNKNOWN } })
4907 end
4908
4909 func = resolve_unary(func)
4910
4911 args = args or {}
4912 local poly = func.typename == "poly" and func or { types = { func } }
4913 local first_errs
4914 local expects = {}
4915
4916 local tried = {}
4917 for i, f in ipairs(poly.types) do
4918 if not tried[i] then
4919 if f.typename ~= "function" then
4920 if lax and is_unknown(f) then
4921 return UNKNOWN
4922 end
4923 return node_error(node, "not a function: %s", f)
4924 end
4925 table.insert(expects, tostring(#f.args or 0))
4926 local va = is_vararg(f)
4927 if #args == (#f.args or 0) or (va and #args > #f.args) then
4928 tried[i] = true
4929 local matched, errs = try_match_func_args(node, f, args, is_method, argdelta)
4930 if matched then
4931 return matched
4932 else
4933 revert_typeargs(f)
4934 end
4935 first_errs = first_errs or errs
4936 end
4937 end
4938 end
4939
4940 for i, f in ipairs(poly.types) do
4941 if not tried[i] then
4942 tried[i] = true
4943 if #args < (#f.args or 0) then
4944 tried[i] = true
4945 local matched, errs = try_match_func_args(node, f, args, is_method, argdelta)
4946 if matched then
4947 return matched
4948 else
4949 revert_typeargs(f)
4950 end
4951 first_errs = first_errs or errs
4952 end
4953 end
4954 end
4955
4956 for i, f in ipairs(poly.types) do
4957 if not tried[i] then
4958 if is_vararg(f) and #args > (#f.args or 0) then
4959 tried[i] = true
4960 local matched, errs = try_match_func_args(node, f, args, is_method, argdelta)
4961 if matched then
4962 return matched
4963 else
4964 revert_typeargs(f)
4965 end
4966 first_errs = first_errs or errs
4967 end
4968 end
4969 end
4970
4971 if not first_errs then
4972 table.sort(expects)
4973 remove_sorted_duplicates(expects)
4974 node_error(node, "wrong number of arguments (given " .. #args .. ", expects " .. table.concat(expects, " or ") .. ")")
4975 else
4976 for _, err in ipairs(first_errs) do
4977 table.insert(errors, err)
4978 end
4979 end
4980
4981 poly.types[1].rets.typename = "tuple"
4982 return resolve_typevars(poly.types[1].rets)
4983 end
4984
4985 type_check_function_call = function(node, func, args, is_method, argdelta)
4986 begin_scope()
4987 local ret = check_call(node, func, args, is_method, argdelta)
4988 end_scope()
4989 return ret
4990 end
4991 end
4992
4993 local unknown_dots = {}
4994
4995 local function add_unknown_dot(node, name)
4996 if not unknown_dots[name] then
4997 unknown_dots[name] = true
4998 add_unknown(node, name)
4999 end
5000 end
5001
5002 local function get_self_type(t)
5003 if t.typename == "typetype" then
5004 return t.def
5005 else
5006 return t
5007 end
5008 end
5009
5010 local function match_record_key(node, tbl, key, orig_tbl)
5011 assert(type(tbl) == "table")
5012 assert(type(key) == "table")
5013
5014 tbl = resolve_unary(tbl)
5015 local type_description = tbl.typename
5016 if tbl.typename == "string" or tbl.typename == "enum" then
5017 tbl = find_var("string")
5018 end
5019
5020 if lax and (is_unknown(tbl) or tbl.typename == "typevar") then
5021 if node.e1.kind == "variable" and node.op.op ~= "@funcall" then
5022 add_unknown_dot(node, node.e1.tk .. "." .. key.tk)
5023 end
5024 return UNKNOWN
5025 end
5026
5027 tbl = get_self_type(tbl)
5028
5029 if tbl.typename == "emptytable" then
5030 elseif is_record_type(tbl) then
5031 assert(tbl.fields, "record has no fields!?")
5032
5033 if key.kind == "string" or key.kind == "identifier" then
5034 if tbl.fields[key.tk] then
5035 return tbl.fields[key.tk]
5036 end
5037 end
5038 else
5039 if is_unknown(tbl) then
5040 if not lax then
5041 node_error(node, "cannot index a value of unknown type")
5042 end
5043 else
5044 node_error(node, "cannot index something that is not a record: %s", tbl)
5045 end
5046 return INVALID
5047 end
5048
5049 if lax then
5050 if node.e1.kind == "variable" and node.op.op ~= "@funcall" then
5051 add_unknown_dot(node, node.e1.tk .. "." .. key.tk)
5052 end
5053 return UNKNOWN
5054 end
5055
5056 local description
5057 if node.e1.kind == "variable" then
5058 description = type_description .. " '" .. node.e1.tk .. "' of type " .. show_type(resolve_tuple(orig_tbl))
5059 else
5060 description = "type " .. show_type(resolve_tuple(orig_tbl))
5061 end
5062
5063 return node_error(key, "invalid key '" .. key.tk .. "' in " .. description)
5064 end
5065
5066 local function widen_in_scope(scope, var)
5067 if scope[var].is_narrowed then
5068 if scope[var].narrowed_from then
5069 scope[var].t = scope[var].narrowed_from
5070 scope[var].narrowed_from = nil
5071 scope[var].is_narrowed = false
5072 else
5073 scope[var] = nil
5074 end
5075 return true
5076 end
5077 return false
5078 end
5079
5080 local function widen_back_var(var)
5081 local widened = false
5082 for i = #st, 1, -1 do
5083 if st[i][var] then
5084 if widen_in_scope(st[i], var) then
5085 widened = true
5086 else
5087 break
5088 end
5089 end
5090 end
5091 return widened
5092 end
5093
5094 local function widen_all_unions()
5095 for i = #st, 1, -1 do
5096 for var, _ in pairs(st[i]) do
5097 widen_in_scope(st[i], var)
5098 end
5099 end
5100 end
5101
5102 local function add_global(node, var, valtype, is_const)
5103 if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then
5104 add_unknown(node, var)
5105 end
5106 st[1][var] = { t = valtype, is_const = is_const }
5107 end
5108
5109 local check_typevars
5110
5111 local function check_all_typevars(node, ts)
5112 if ts ~= nil then
5113 for _, arg in ipairs(ts) do
5114 check_typevars(node, arg)
5115 end
5116 end
5117 end
5118
5119 check_typevars = function(node, t)
5120 if t == nil then
5121 return
5122 end
5123 if t.typename == "typevar" then
5124 if not find_var(t.typevar) then
5125 node_error(node, "unknown type variable " .. t.typevar)
5126 end
5127 return
5128 end
5129 check_typevars(node, t.elements)
5130 check_typevars(node, t.keys)
5131 check_typevars(node, t.values)
5132 check_all_typevars(node, t.typeargs)
5133 check_all_typevars(node, t.args)
5134 check_all_typevars(node, t.rets)
5135 end
5136
5137 local function get_rets(rets)
5138 if lax and (#rets == 0) then
5139 return { a_type({ typename = "unknown", is_va = true }) }
5140 end
5141 return rets
5142 end
5143
5144 local function begin_function_scope(node, recurse)
5145 begin_scope()
5146 local args = {}
5147 if node.typeargs then
5148 for i, arg in ipairs(node.typeargs) do
5149 add_var(nil, arg.typearg, arg)
5150 end
5151 end
5152 local is_va = false
5153 for i, arg in ipairs(node.args) do
5154 local t = arg.decltype
5155 if not t then
5156 t = a_type({ typename = "unknown" })
5157 end
5158 if arg.tk == "..." then
5159 is_va = true
5160 t.is_va = true
5161 if i ~= #node.args then
5162 node_error(node, "'...' can only be last argument")
5163 end
5164 end
5165 check_typevars(arg, t)
5166 table.insert(args, t)
5167 add_var(arg, arg.tk, t)
5168 end
5169
5170 add_var(nil, "@is_va", is_va and VARARG_ANY or NIL)
5171
5172 add_var(nil, "@return", node.rets or a_type({ typename = "tuple" }))
5173 if recurse then
5174 add_var(nil, node.name.tk, a_type({
5175 typename = "function",
5176 args = args,
5177 rets = get_rets(node.rets),
5178 }))
5179 end
5180 end
5181
5182 local function fail_unresolved()
5183 local unresolved = st[#st]["@unresolved"]
5184 if unresolved then
5185 st[#st]["@unresolved"] = nil
5186 for name, nodes in pairs(unresolved.t.labels) do
5187 for _, node in ipairs(nodes) do
5188 node_error(node, "no visible label '" .. name .. "' for goto")
5189 end
5190 end
5191 for name, types in pairs(unresolved.t.nominals) do
5192 for _, typ in ipairs(types) do
5193 assert(typ.x)
5194 assert(typ.y)
5195 type_error(typ, "unknown type %s", typ)
5196 end
5197 end
5198 end
5199 end
5200
5201 local function end_function_scope()
5202 fail_unresolved()
5203 end_scope()
5204 end
5205
5206 local function match_typevals(t, def)
5207 if t.typevals and def.typeargs then
5208 if #t.typevals ~= #def.typeargs then
5209 type_error(t, "mismatch in number of type arguments")
5210 return nil
5211 end
5212
5213 begin_scope()
5214 for i, tt in ipairs(t.typevals) do
5215 add_var(nil, def.typeargs[i].typearg, tt)
5216 end
5217 local ret = resolve_typevars(def)
5218 end_scope()
5219 return ret
5220 elseif t.typevals then
5221 type_error(t, "spurious type arguments")
5222 return nil
5223 elseif def.typeargs then
5224 type_error(t, "missing type arguments in %s", def)
5225 return nil
5226 else
5227 return def
5228 end
5229 end
5230
5231 local function resolve_nominal(t)
5232 if t.resolved then
5233 return t.resolved
5234 end
5235
5236 local resolved
5237
5238 local typetype = t.found or find_type(t.names)
5239 if not typetype then
5240 type_error(t, "unknown type %s", t)
5241 elseif is_type(typetype) then
5242 resolved = match_typevals(t, typetype.def)
5243 else
5244 type_error(t, table.concat(t.names, ".") .. " is not a type")
5245 end
5246
5247 if not resolved then
5248 resolved = a_type({ typename = "bad_nominal", names = t.names })
5249 end
5250
5251 t.found = typetype
5252 t.resolved = resolved
5253 return resolved
5254 end
5255
5256 resolve_unary = function(t)
5257 t = resolve_tuple(t)
5258 if t.typename == "nominal" then
5259 return resolve_nominal(t)
5260 end
5261 return t
5262 end
5263
5264 local function flatten_list(list)
5265 local exps = {}
5266 for i = 1, #list - 1 do
5267 table.insert(exps, resolve_unary(list[i]))
5268 end
5269 if #list > 0 then
5270 local last = list[#list]
5271 if last.typename == "tuple" then
5272 for _, val in ipairs(last) do
5273 table.insert(exps, val)
5274 end
5275 else
5276 table.insert(exps, last)
5277 end
5278 end
5279 return exps
5280 end
5281
5282 local function get_assignment_values(vals, wanted)
5283 local ret = {}
5284 if vals == nil then
5285 return ret
5286 end
5287
5288 for i = 1, #vals - 1 do
5289 ret[i] = vals[i]
5290 end
5291 local last = vals[#vals]
5292
5293 if last.typename == "tuple" then
5294 for _, v in ipairs(last) do
5295 table.insert(ret, v)
5296 end
5297
5298 elseif last.is_va and #ret < wanted then
5299 while #ret < wanted do
5300 table.insert(ret, last)
5301 end
5302
5303 else
5304 table.insert(ret, last)
5305 end
5306 return ret
5307 end
5308
5309 local function match_all_record_field_names(node, a, field_names, errmsg)
5310 local t
5311 for _, k in ipairs(field_names) do
5312 local f = a.fields[k]
5313 if not t then
5314 t = f
5315 else
5316 if not same_type(f, t) then
5317 t = nil
5318 break
5319 end
5320 end
5321 end
5322 if t then
5323 return t
5324 else
5325 return node_error(node, errmsg)
5326 end
5327 end
5328
5329 local function type_check_index(node, idxnode, a, b)
5330 local orig_a = a
5331 local orig_b = b
5332 a = resolve_unary(a)
5333 b = resolve_unary(b)
5334
5335 if is_array_type(a) and is_a(b, NUMBER) then
5336 return a.elements
5337 elseif a.typename == "emptytable" then
5338 if a.keys == nil then
5339 a.keys = b
5340 a.keys_inferred_at = node
5341 a.keys_inferred_at_file = filename
5342 else
5343 if not is_a(b, a.keys) then
5344 local inferred = " (type of keys inferred at " .. a.keys_inferred_at_file .. ":" .. a.keys_inferred_at.y .. ":" .. a.keys_inferred_at.x .. ": )"
5345 return node_error(idxnode, "inconsistent index type: %s, expected %s" .. inferred, b, a.keys)
5346 end
5347 end
5348 return a_type({ y = node.y, x = node.x, typename = "unknown_emptytable_value", emptytable_type = a })
5349 elseif a.typename == "map" then
5350 if is_a(b, a.keys) then
5351 return a.values
5352 else
5353 return node_error(idxnode, "wrong index type: %s, expected %s", orig_b, a.keys)
5354 end
5355 elseif node.e2.kind == "string" or node.e2.kind == "enum_item" then
5356 return match_record_key(node, a, { y = node.e2.y, x = node.e2.x, kind = "string", tk = assert(node.e2.conststr) }, orig_a)
5357 elseif is_record_type(a) and b.typename == "enum" then
5358 local field_names = {}
5359 for k, _ in pairs(b.enumset) do
5360 table.insert(field_names, k)
5361 end
5362 table.sort(field_names)
5363 for _, k in ipairs(field_names) do
5364 if not a.fields[k] then
5365 return node_error(idxnode, "enum value '" .. k .. "' is not a field in %s", a)
5366 end
5367 end
5368 return match_all_record_field_names(idxnode, a, field_names,
5369"cannot index, not all enum values map to record fields of the same type")
5370 elseif lax and is_unknown(a) then
5371 return UNKNOWN
5372 else
5373 if is_a(b, STRING) then
5374 return node_error(idxnode, "cannot index object of type %s with a string, consider using an enum", orig_a)
5375 end
5376 return node_error(idxnode, "cannot index object of type %s with %s", orig_a, orig_b)
5377 end
5378 end
5379
5380 local function expand_type(where, old, new)
5381 if not old then
5382 return new
5383 else
5384 if not is_a(new, old) then
5385 if old.typename == "map" and is_record_type(new) then
5386 if old.keys.typename == "string" then
5387 for _, ftype in pairs(new.fields) do
5388 old.values = expand_type(where, old.values, ftype)
5389 end
5390 else
5391 node_error(where, "cannot determine table literal type")
5392 end
5393 elseif is_record_type(old) and is_record_type(new) then
5394 old.typename = "map"
5395 old.keys = STRING
5396 for _, ftype in pairs(old.fields) do
5397 if not old.values then
5398 old.values = ftype
5399 else
5400 old.values = expand_type(where, old.values, ftype)
5401 end
5402 end
5403 for _, ftype in pairs(new.fields) do
5404 if not old.values then
5405 new.values = ftype
5406 else
5407 new.values = expand_type(where, old.values, ftype)
5408 end
5409 end
5410 old.fields = nil
5411 old.field_order = nil
5412 elseif old.typename == "union" then
5413 new.tk = nil
5414 table.insert(old.types, new)
5415 else
5416 old.tk = nil
5417 new.tk = nil
5418 return a_union({ old, new })
5419 end
5420 end
5421 end
5422 return old
5423 end
5424
5425 local function find_in_scope(exp)
5426 if exp.kind == "variable" then
5427 local t = find_var(exp.tk)
5428 if t.def then
5429 if not t.def.closed and not t.closed then
5430 return t.def
5431 end
5432 end
5433 if not t.closed then
5434 return t
5435 end
5436 elseif exp.kind == "op" and exp.op.op == "." then
5437 local t = find_in_scope(exp.e1)
5438 if not t then
5439 return nil
5440 end
5441 while exp.e2.kind == "op" and exp.e2.op.op == "." do
5442 t = t.fields[exp.e2.e1.tk]
5443 if not t then
5444 return nil
5445 end
5446 exp = exp.e2
5447 end
5448 t = t.fields[exp.e2.tk]
5449 return t
5450 end
5451 end
5452
5453 local facts_and
5454 local facts_or
5455 local facts_not
5456 do
5457 local function join_facts(fss)
5458 local vars = {}
5459
5460 for _, fs in ipairs(fss) do
5461 for _, f in ipairs(fs) do
5462 if not vars[f.var] then
5463 vars[f.var] = {}
5464 end
5465 table.insert(vars[f.var], f)
5466 end
5467 end
5468 return vars
5469 end
5470
5471 local function intersect(xs, ys, same)
5472 local rs = {}
5473 for i = #xs, 1, -1 do
5474 local x = xs[i]
5475 for _, y in ipairs(ys) do
5476 if same(x, y) then
5477 table.insert(rs, x)
5478 break
5479 end
5480 end
5481 end
5482 return rs
5483 end
5484
5485 local function same_type_for_intersect(t, u)
5486 return (same_type(t, u))
5487 end
5488
5489 local function intersect_facts(fs, errnode)
5490 local all_is = true
5491 local types = {}
5492 for i, f in ipairs(fs) do
5493 if f.fact ~= "is" then
5494 all_is = false
5495 break
5496 end
5497 if f.typ.typename == "union" then
5498 if i == 1 then
5499 types = f.typ.types
5500 else
5501 types = intersect(types, f.typ.types, same_type_for_intersect)
5502 end
5503 else
5504 if i == 1 then
5505 types = { f.typ }
5506 else
5507 types = intersect(types, { f.typ }, same_type_for_intersect)
5508 end
5509 end
5510 end
5511
5512 if #types == 0 then
5513 node_error(errnode, "branch is always false")
5514 return false
5515 end
5516
5517 if all_is then
5518 if #types == 1 then
5519 return true, types[1]
5520 else
5521 return true, a_union(types)
5522 end
5523 else
5524 return false
5525 end
5526 end
5527
5528 local function sum_facts(fs)
5529 local all_is = true
5530 local types = {}
5531 for _, f in ipairs(fs) do
5532 if f.fact ~= "is" then
5533 all_is = false
5534 break
5535 end
5536 table.insert(types, f.typ)
5537 end
5538
5539 if all_is then
5540 if #types == 1 then
5541 return true, types[1]
5542 else
5543 return true, a_union(types)
5544 end
5545 else
5546 return false
5547 end
5548 end
5549
5550 local function subtract_types(u1, u2, errt)
5551 local types = {}
5552 for _, rt in ipairs(u1.types or { u1 }) do
5553 local not_present = true
5554 for _, ft in ipairs(u2.types or { u2 }) do
5555 if same_type(rt, ft) then
5556 not_present = false
5557 break
5558 end
5559 end
5560 if not_present then
5561 table.insert(types, rt)
5562 end
5563 end
5564
5565 if #types == 0 then
5566 type_error(errt, "branch is always false")
5567 return INVALID
5568 end
5569
5570 if #types == 1 then
5571 return types[1]
5572 else
5573 return a_union(types)
5574 end
5575 end
5576
5577 facts_and = function(f1, f2, errnode)
5578 if not f1 then
5579 return f2
5580 end
5581 if not f2 then
5582 return f1
5583 end
5584
5585 local out = {}
5586 for v, fs in pairs(join_facts({ f1, f2 })) do
5587 local ok, u = intersect_facts(fs, errnode)
5588
5589 if ok then
5590 table.insert(out, { fact = "is", var = v, typ = u })
5591 else
5592
5593 for _, f in ipairs(fs) do
5594 table.insert(out, f)
5595 end
5596 end
5597 end
5598 return out
5599 end
5600
5601 facts_or = function(f1, f2)
5602 if not f1 or not f2 then
5603 return nil
5604 end
5605
5606 local out = {}
5607 for v, fs in pairs(join_facts({ f1, f2 })) do
5608 local ok, u = sum_facts(fs)
5609 if ok then
5610 table.insert(out, { fact = "is", var = v, typ = u })
5611 else
5612
5613 for _, f in ipairs(fs) do
5614 table.insert(out, f)
5615 end
5616 end
5617 end
5618 return out
5619 end
5620
5621 facts_not = function(f1)
5622 if not f1 then
5623 return nil
5624 end
5625
5626 local out = {}
5627 for v, fs in pairs(join_facts({ f1 })) do
5628 local realtype = find_var(v)
5629 if realtype then
5630 local ok, u = sum_facts(fs)
5631 if ok then
5632 local not_typ = subtract_types(realtype, u, fs[1].typ)
5633 table.insert(out, { fact = "is", var = v, typ = not_typ })
5634 end
5635 end
5636 end
5637 return out
5638 end
5639 end
5640
5641 local function apply_facts(where, facts)
5642 if not facts then
5643 return
5644 end
5645 for _, f in ipairs(facts) do
5646 if f.fact == "is" then
5647 local t = resolve_typevars(f.typ)
5648 t.inferred_at = where
5649 t.inferred_at_file = filename
5650 add_var(nil, f.var, t, nil, true)
5651 end
5652 end
5653 end
5654
5655 local function dismiss_unresolved(name)
5656 local unresolved = st[#st]["@unresolved"]
5657 if unresolved then
5658 if unresolved.t.nominals[name] then
5659 for _, t in ipairs(unresolved.t.nominals[name]) do
5660 resolve_nominal(t)
5661 end
5662 end
5663 unresolved.t.nominals[name] = nil
5664 end
5665 end
5666
5667 local function type_check_funcall(node, a, b, argdelta)
5668 argdelta = argdelta or 0
5669 if node.e1.tk == "rawget" then
5670 if #b == 2 then
5671 local b1 = resolve_unary(b[1])
5672 local b2 = resolve_unary(b[2])
5673 local knode = node.e2[2]
5674 if is_record_type(b1) and knode.conststr then
5675 return match_record_key(node, b1, { y = knode.y, x = knode.x, kind = "string", tk = assert(knode.conststr) }, b1)
5676 else
5677 return type_check_index(node, knode, b1, b2)
5678 end
5679 else
5680 node_error(node, "rawget expects two arguments")
5681 end
5682 elseif node.e1.tk == "print_type" then
5683 print(show_type(b))
5684 return BOOLEAN
5685 elseif node.e1.tk == "require" then
5686 if #b == 1 then
5687 if node.e2[1].kind == "string" then
5688 local module_name = assert(node.e2[1].conststr)
5689 local t, found = require_module(module_name, lax, opts.env, result)
5690 if not found then
5691 node_error(node, "module not found: '" .. module_name .. "'")
5692 elseif not lax and is_unknown(t) then
5693 node_error(node, "no type information for required module: '" .. module_name .. "'")
5694 end
5695 return t
5696 else
5697 node_error(node, "don't know how to resolve a dynamic require")
5698 end
5699 else
5700 node_error(node, "require expects one literal argument")
5701 end
5702 elseif node.e1.tk == "pcall" then
5703 local ftype = table.remove(b, 1)
5704 local fe2 = {}
5705 for i = 2, #node.e2 do
5706 table.insert(fe2, node.e2[i])
5707 end
5708 local fnode = {
5709 y = node.y,
5710 x = node.x,
5711 typename = "op",
5712 op = { op = "@funcall" },
5713 e1 = node.e2[1],
5714 e2 = fe2,
5715 }
5716 local rets = type_check_funcall(fnode, ftype, b, argdelta + 1)
5717 if rets.typename ~= "tuple" then
5718 rets = a_type({ typename = "tuple", rets })
5719 end
5720 table.insert(rets, 1, BOOLEAN)
5721 return rets
5722 elseif node.e1.op and node.e1.op.op == ":" then
5723 local func = node.e1.type
5724 if func.typename == "function" or func.typename == "poly" then
5725 table.insert(b, 1, node.e1.e1.type)
5726 return type_check_function_call(node, func, b, true)
5727 else
5728 if lax and (is_unknown(func)) then
5729 if node.e1.e1.kind == "variable" then
5730 add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk)
5731 end
5732 return VARARG_UNKNOWN
5733 else
5734 return INVALID
5735 end
5736 end
5737 else
5738 return type_check_function_call(node, a, b, false, argdelta)
5739 end
5740 return UNKNOWN
5741 end
5742
5743 local visit_node = {}
5744
5745 visit_node.cbs = {
5746 ["statements"] = {
5747 before = function()
5748 begin_scope()
5749 end,
5750 after = function(node, children)
5751
5752 if #st == 2 then
5753 fail_unresolved()
5754 end
5755
5756 if not node.is_repeat then
5757 end_scope()
5758 end
5759
5760 node.type = NONE
5761 end,
5762 },
5763 ["local_type"] = {
5764 before = function(node)
5765 add_var(node.var, node.var.tk, node.value.newtype, node.var.is_const)
5766 end,
5767 after = function(node, children)
5768 dismiss_unresolved(node.var.tk)
5769 node.type = NONE
5770 end,
5771 },
5772 ["global_type"] = {
5773 before = function(node)
5774 add_global(node.var, node.var.tk, node.value.newtype, node.var.is_const)
5775 end,
5776 after = function(node, children)
5777 local existing, existing_is_const = find_global(node.var.tk)
5778 local var = node.var
5779 if existing then
5780 if existing_is_const == true and not var.is_const then
5781 node_error(var, "global was previously declared as <const>: " .. var.tk)
5782 end
5783 if existing_is_const == false and var.is_const then
5784 node_error(var, "global was previously declared as not <const>: " .. var.tk)
5785 end
5786 if not same_type(existing, node.value.newtype) then
5787 node_error(var, "cannot redeclare global with a different type: previous type of " .. var.tk .. " is %s", existing)
5788 end
5789 end
5790 dismiss_unresolved(var.tk)
5791 node.type = NONE
5792 end,
5793 },
5794 ["local_declaration"] = {
5795 after = function(node, children)
5796 local vals = get_assignment_values(children[2], #node.vars)
5797 for i, var in ipairs(node.vars) do
5798 local decltype = node.decltype and node.decltype[i]
5799 local infertype = vals and vals[i]
5800 if lax and infertype and infertype.typename == "nil" then
5801 infertype = nil
5802 end
5803 if decltype and infertype then
5804 assert_is_a(node.vars[i], infertype, decltype, "local declaration", var.tk)
5805 end
5806 local t = decltype or infertype
5807 if t == nil then
5808 t = a_type({ typename = "unknown" })
5809 if not lax then
5810 if node.exps then
5811 node_error(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. var.tk .. "'")
5812 else
5813 node_error(node.vars[i], "variable '" .. var.tk .. "' has no type or initial value")
5814 end
5815 end
5816 elseif t.typename == "emptytable" then
5817 t.declared_at = node
5818 t.assigned_to = var.tk
5819 end
5820 assert(var)
5821 add_var(var, var.tk, t, var.is_const)
5822
5823 dismiss_unresolved(var.tk)
5824 end
5825 node.type = NONE
5826 end,
5827 },
5828 ["global_declaration"] = {
5829 after = function(node, children)
5830 local vals = get_assignment_values(children[2], #node.vars)
5831 for i, var in ipairs(node.vars) do
5832 local decltype = node.decltype and node.decltype[i]
5833 local infertype = vals and vals[i]
5834 if lax and infertype and infertype.typename == "nil" then
5835 infertype = nil
5836 end
5837 if decltype and infertype then
5838 assert_is_a(node.vars[i], infertype, decltype, "global declaration", var.tk)
5839 end
5840 local t = decltype or infertype
5841 local existing, existing_is_const = find_global(var.tk)
5842 if existing then
5843 if infertype and existing_is_const then
5844 node_error(var, "cannot reassign to <const> global: " .. var.tk)
5845 end
5846 if existing_is_const == true and not var.is_const then
5847 node_error(var, "global was previously declared as <const>: " .. var.tk)
5848 end
5849 if existing_is_const == false and var.is_const then
5850 node_error(var, "global was previously declared as not <const>: " .. var.tk)
5851 end
5852 if not same_type(existing, t) then
5853 node_error(var, "cannot redeclare global with a different type: previous type of " .. var.tk .. " is %s", existing)
5854 end
5855 else
5856 if t == nil then
5857 t = a_type({ typename = "unknown" })
5858 elseif t.typename == "emptytable" then
5859 t.declared_at = node
5860 t.assigned_to = var.tk
5861 end
5862 add_global(var, var.tk, t, var.is_const)
5863
5864 dismiss_unresolved(var.tk)
5865 end
5866 end
5867 node.type = NONE
5868 end,
5869 },
5870 ["assignment"] = {
5871 after = function(node, children)
5872 local vals = get_assignment_values(children[2], #children[1])
5873 local exps = flatten_list(vals)
5874 for i, vartype in ipairs(children[1]) do
5875 local varnode = node.vars[i]
5876 if varnode.is_const then
5877 node_error(varnode, "cannot assign to <const> variable")
5878 end
5879 if varnode.kind == "variable" then
5880 if widen_back_var(varnode.tk) then
5881 vartype = find_var(varnode.tk)
5882 end
5883 end
5884 if vartype then
5885 local val = exps[i]
5886 if resolve_unary(vartype).typename == "typetype" then
5887 node_error(varnode, "cannot reassign a type")
5888 elseif val then
5889 assert_is_a(varnode, val, vartype, "assignment")
5890 if varnode.kind == "variable" and vartype.typename == "union" then
5891
5892 add_var(varnode, varnode.tk, val, false, true)
5893 end
5894 else
5895 node_error(varnode, "variable is not being assigned a value")
5896 end
5897 else
5898 node_error(varnode, "unknown variable")
5899 end
5900 end
5901 node.type = NONE
5902 end,
5903 },
5904 ["do"] = {
5905 after = function(node, children)
5906 node.type = NONE
5907 end,
5908 },
5909 ["if"] = {
5910 before_statements = function(node)
5911 begin_scope()
5912 apply_facts(node.exp, node.exp.facts)
5913 end,
5914 after = function(node, children)
5915 end_scope()
5916 node.type = NONE
5917 end,
5918 },
5919 ["elseif"] = {
5920 before = function(node)
5921 end_scope()
5922 begin_scope()
5923 end,
5924 before_statements = function(node)
5925 local f = facts_not(node.parent_if.exp.facts)
5926 for e = 1, node.elseif_n - 1 do
5927 f = facts_and(f, facts_not(node.parent_if.elseifs[e].exp.facts), node)
5928 end
5929 f = facts_and(f, node.exp.facts, node)
5930 apply_facts(node.exp, f)
5931 end,
5932 after = function(node, children)
5933 node.type = NONE
5934 end,
5935 },
5936 ["else"] = {
5937 before = function(node)
5938 end_scope()
5939 begin_scope()
5940 local f = facts_not(node.parent_if.exp.facts)
5941 for _, elseifnode in ipairs(node.parent_if.elseifs) do
5942 f = facts_and(f, facts_not(elseifnode.exp.facts), node)
5943 end
5944 apply_facts(node, f)
5945 end,
5946 after = function(node, children)
5947 node.type = NONE
5948 end,
5949 },
5950 ["while"] = {
5951 before = function()
5952
5953 widen_all_unions()
5954 end,
5955 before_statements = function(node)
5956 begin_scope()
5957 apply_facts(node.exp, node.exp.facts)
5958 end,
5959 after = function(node, children)
5960 end_scope()
5961 node.type = NONE
5962 end,
5963 },
5964 ["label"] = {
5965 before = function(node)
5966
5967 widen_all_unions()
5968 local label_id = "::" .. node.label .. "::"
5969 if st[#st][label_id] then
5970 node_error(node, "label '" .. node.label .. "' already defined at " .. filename)
5971 end
5972 local unresolved = st[#st]["@unresolved"]
5973 if unresolved then
5974 unresolved.t.labels[node.label] = nil
5975 end
5976 node.type = a_type({ y = node.y, x = node.x, typename = "none" })
5977 add_var(node, label_id, node.type)
5978 end,
5979 },
5980 ["goto"] = {
5981 after = function(node, children)
5982 if not find_var("::" .. node.label .. "::") then
5983 local unresolved = st[#st]["@unresolved"] and st[#st]["@unresolved"].t
5984 if not unresolved then
5985 unresolved = { typename = "unresolved", labels = {}, nominals = {} }
5986 add_var(node, "@unresolved", unresolved)
5987 end
5988 unresolved.labels[node.label] = unresolved.labels[node.label] or {}
5989 table.insert(unresolved.labels[node.label], node)
5990 end
5991 node.type = NONE
5992 end,
5993 },
5994 ["repeat"] = {
5995 before = function()
5996
5997 widen_all_unions()
5998 end,
5999 after = function(node, children)
6000
6001 end_scope()
6002 node.type = NONE
6003 end,
6004 },
6005 ["forin"] = {
6006 before = function()
6007 begin_scope()
6008 end,
6009 before_statements = function(node)
6010 local exp1 = node.exps[1]
6011 local exp1type = resolve_tuple(exp1.type)
6012 if exp1type.typename == "function" then
6013
6014 if exp1.op and exp1.op.op == "@funcall" then
6015 local t = resolve_unary(exp1.e2.type)
6016 if exp1.e1.tk == "pairs" and not (t.typename == "map" or t.typename == "record") then
6017 if not (lax and is_unknown(t)) then
6018 node_error(exp1, "attempting pairs loop on something that's not a map or record: %s", exp1.e2.type)
6019 end
6020 elseif exp1.e1.tk == "ipairs" and not is_array_type(t) then
6021 if not (lax and (is_unknown(t) or t.typename == "emptytable")) then
6022 node_error(exp1, "attempting ipairs loop on something that's not an array: %s", exp1.e2.type)
6023 end
6024 end
6025 end
6026 local last
6027 for i, v in ipairs(node.vars) do
6028 local r = exp1type.rets[i]
6029 if not r then
6030 if last and last.is_va then
6031 r = last
6032 else
6033 r = UNKNOWN
6034 end
6035 end
6036 add_var(v, v.tk, r)
6037 last = r
6038 end
6039 else
6040 if not (lax and is_unknown(exp1type)) then
6041 node_error(exp1, "expression in for loop does not return an iterator")
6042 end
6043 end
6044 end,
6045 after = function(node, children)
6046 end_scope()
6047 node.type = NONE
6048 end,
6049 },
6050 ["fornum"] = {
6051 before = function(node)
6052 begin_scope()
6053 add_var(nil, node.var.tk, NUMBER)
6054 end,
6055 after = function(node, children)
6056 end_scope()
6057 node.type = NONE
6058 end,
6059 },
6060 ["return"] = {
6061 after = function(node, children)
6062 local rets = assert(find_var("@return"))
6063 local nrets = #rets
6064 local vatype
6065 if nrets > 0 then
6066 vatype = rets[nrets].is_va and rets[nrets]
6067 end
6068
6069 if #children[1] > nrets and (not lax) and not vatype then
6070 rets.typename = "tuple"
6071 children[1].typename = "tuple"
6072 node_error(node, "excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1])
6073 end
6074
6075 for i = 1, #children[1] do
6076 local expected = rets[i] or vatype
6077 if expected then
6078 expected = resolve_unary(expected)
6079 local where = (node.exps[i] and node.exps[i].x) and
6080 node.exps[i] or
6081 node.exps
6082 assert(where and where.x)
6083 assert_is_a(where, children[1][i], expected, "return value")
6084 end
6085 end
6086
6087
6088 if #st == 2 then
6089 module_type = resolve_unary(children[1])
6090 end
6091
6092 node.type = NONE
6093 end,
6094 },
6095 ["variables"] = {
6096 after = function(node, children)
6097 node.type = children
6098
6099
6100 local n = #children
6101 if n > 0 and children[n].typename == "tuple" then
6102 local tuple = children[n]
6103 for i, c in ipairs(tuple) do
6104 children[n + i - 1] = c
6105 end
6106 end
6107
6108 node.type.typename = "tuple"
6109 end,
6110 },
6111 ["table_literal"] = {
6112 after = function(node, children)
6113 node.type = a_type({
6114 y = node.y,
6115 x = node.x,
6116 typename = "emptytable",
6117 })
6118 local is_record = false
6119 local is_array = false
6120 local is_map = false
6121 for i, child in ipairs(children) do
6122 assert(child.typename == "table_item")
6123 if child.kname then
6124 is_record = true
6125 if not node.type.fields then
6126 node.type.fields = {}
6127 node.type.field_order = {}
6128 end
6129 node.type.fields[child.kname] = child.vtype
6130 table.insert(node.type.field_order, child.kname)
6131 elseif child.ktype.typename == "number" then
6132 is_array = true
6133 if i == #children and node[i].key_parsed == "implicit" and child.vtype.typename == "tuple" then
6134
6135 for _, c in ipairs(child.vtype) do
6136 node.type.elements = expand_type(node, node.type.elements, c)
6137 end
6138 else
6139 node.type.elements = expand_type(node, node.type.elements, child.vtype)
6140 end
6141 if not node.type.elements then
6142 node_error(node, "cannot determine type of array elements")
6143 is_array = false
6144 end
6145 else
6146 is_map = true
6147 node.type.keys = expand_type(node, node.type.keys, child.ktype)
6148 node.type.values = expand_type(node, node.type.values, child.vtype)
6149 end
6150 end
6151 if is_array and is_map then
6152 node_error(node, "cannot determine type of table literal")
6153 elseif is_record and is_array then
6154 node.type.typename = "arrayrecord"
6155 elseif is_record and is_map then
6156 if node.type.keys.typename == "string" then
6157 node.type.typename = "map"
6158 for _, ftype in pairs(node.type.fields) do
6159 node.type.values = expand_type(node, node.type.values, ftype)
6160 end
6161 node.type.fields = nil
6162 node.type.field_order = nil
6163 else
6164 node_error(node, "cannot determine type of table literal")
6165 end
6166 elseif is_array then
6167 node.type.typename = "array"
6168 elseif is_record then
6169 node.type.typename = "record"
6170 elseif is_map then
6171 node.type.typename = "map"
6172 end
6173 end,
6174 },
6175 ["table_item"] = {
6176 after = function(node, children)
6177 local kname = node.key.conststr
6178 local ktype = children[1]
6179 local vtype = children[2]
6180 if node.decltype then
6181 vtype = node.decltype
6182 assert_is_a(node.value, children[2], node.decltype, "table item")
6183 end
6184 node.type = a_type({
6185 y = node.y,
6186 x = node.x,
6187 typename = "table_item",
6188 kname = kname,
6189 ktype = ktype,
6190 vtype = vtype,
6191 })
6192 end,
6193 },
6194 ["local_function"] = {
6195 before = function(node)
6196 begin_function_scope(node, true)
6197 end,
6198 after = function(node, children)
6199 end_function_scope()
6200 local rets = get_rets(children[3])
6201
6202 add_var(nil, node.name.tk, a_type({
6203 typename = "function",
6204 args = children[2],
6205 rets = rets,
6206 }))
6207 node.type = NONE
6208 end,
6209 },
6210 ["global_function"] = {
6211 before = function(node)
6212 begin_function_scope(node, true)
6213 end,
6214 after = function(node, children)
6215 end_function_scope()
6216 add_global(nil, node.name.tk, a_type({
6217 typename = "function",
6218 args = children[2],
6219 rets = get_rets(children[3]),
6220 }))
6221 node.type = NONE
6222 end,
6223 },
6224 ["record_function"] = {
6225 before = function(node)
6226 begin_function_scope(node)
6227 end,
6228 before_statements = function(node, children)
6229 if node.is_method then
6230 local rtype = get_self_type(children[1])
6231 children[3][1] = rtype
6232 add_var(nil, "self", rtype)
6233 end
6234
6235 local rtype = resolve_unary(get_self_type(children[1]))
6236 if rtype.typename == "emptytable" then
6237 rtype.typename = "record"
6238 end
6239 if is_record_type(rtype) then
6240 local fn_type = a_type({
6241 y = node.y,
6242 x = node.x,
6243 typename = "function",
6244 is_method = node.is_method,
6245 args = children[3],
6246 rets = get_rets(children[4]),
6247 })
6248
6249 local ok = false
6250 if lax then
6251 ok = true
6252 elseif rtype.fields and rtype.fields[node.name.tk] and is_a(fn_type, rtype.fields[node.name.tk]) then
6253 ok = true
6254 elseif find_in_scope(node.fn_owner) == rtype then
6255 ok = true
6256 end
6257
6258 if ok then
6259 rtype.fields = rtype.fields or {}
6260 rtype.field_order = rtype.field_order or {}
6261 rtype.fields[node.name.tk] = fn_type
6262 table.insert(rtype.field_order, node.name.tk)
6263 else
6264 local name = tl.pretty_print_ast(node.fn_owner, { preserve_indent = true, preserve_newlines = false })
6265 node_error(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. name .. "' was originally declared")
6266 end
6267 else
6268 if (not lax) or (rtype.typename ~= "unknown") then
6269 node_error(node, "not a module: %s", rtype)
6270 end
6271 end
6272 end,
6273 after = function(node, children)
6274 end_function_scope()
6275
6276 node.type = NONE
6277 end,
6278 },
6279 ["function"] = {
6280 before = function(node)
6281 begin_function_scope(node)
6282 end,
6283 after = function(node, children)
6284 end_function_scope()
6285
6286
6287 node.type = a_type({
6288 y = node.y,
6289 x = node.x,
6290 typename = "function",
6291 args = children[1],
6292 rets = children[2],
6293 })
6294 end,
6295 },
6296 ["cast"] = {
6297 after = function(node, children)
6298 node.type = node.casttype
6299 end,
6300 },
6301 ["paren"] = {
6302 after = function(node, children)
6303 node.type = resolve_unary(children[1])
6304 end,
6305 },
6306 ["op"] = {
6307 before = function(node)
6308 begin_scope()
6309 end,
6310 before_e2 = function(node)
6311 if node.op.op == "and" then
6312 apply_facts(node, node.e1.facts)
6313 elseif node.op.op == "or" then
6314 apply_facts(node, facts_not(node.e1.facts))
6315 end
6316 end,
6317 after = function(node, children)
6318 end_scope()
6319
6320 local a = children[1]
6321 local b = children[3]
6322
6323 local orig_a = a
6324 local orig_b = b
6325 local ua = a and resolve_unary(a)
6326 local ub = b and resolve_unary(b)
6327 if node.op.op == "@funcall" then
6328 node.type = type_check_funcall(node, a, b)
6329 elseif node.op.op == "@index" then
6330 node.type = type_check_index(node, node.e2, a, b)
6331 elseif node.op.op == "as" then
6332 node.type = b
6333 elseif node.op.op == "is" then
6334 if node.e1.kind == "variable" then
6335 node.facts = { { fact = "is", var = node.e1.tk, typ = b } }
6336 else
6337 node_error(node, "can only use 'is' on variables")
6338 end
6339 node.type = BOOLEAN
6340 elseif node.op.op == "." then
6341 a = ua
6342 if a.typename == "map" then
6343 if is_a(a.keys, STRING) or is_a(a.keys, ANY) then
6344 node.type = a.values
6345 else
6346 node_error(node, "cannot use . index, expects keys of type %s", a.keys)
6347 end
6348 else
6349 node.type = match_record_key(node, a, { y = node.e2.y, x = node.e2.x, kind = "string", tk = node.e2.tk }, orig_a)
6350 if node.type.needs_compat53 and not opts.skip_compat53 then
6351 local key = node.e1.tk .. "." .. node.e2.tk
6352 node.kind = "variable"
6353 node.tk = "_tl_" .. node.e1.tk .. "_" .. node.e2.tk
6354 all_needs_compat53[key] = true
6355 end
6356 end
6357 elseif node.op.op == ":" then
6358 node.type = match_record_key(node, node.e1.type, node.e2, orig_a)
6359 elseif node.op.op == "not" then
6360 node.facts = facts_not(node.e1.facts)
6361 node.type = BOOLEAN
6362 elseif node.op.op == "and" then
6363 node.facts = facts_and(node.e1.facts, node.e2.facts, node)
6364 node.type = resolve_tuple(b)
6365 elseif node.op.op == "or" and b.typename == "emptytable" then
6366 node.facts = nil
6367 node.type = resolve_tuple(a)
6368 elseif node.op.op == "or" and same_type(ua, ub) then
6369 node.facts = facts_or(node.e1.facts, node.e2.facts)
6370 node.type = resolve_tuple(a)
6371 elseif node.op.op == "or" and b.typename == "nil" then
6372 node.facts = nil
6373 node.type = resolve_tuple(a)
6374 elseif node.op.op == "or" and
6375 ((ua.typename == "enum" and ub.typename == "string" and is_a(ub, ua)) or
6376 (ua.typename == "string" and ub.typename == "enum" and is_a(ua, ub))) then
6377 node.facts = nil
6378 node.type = (ua.typename == "enum" and ua or ub)
6379 elseif node.op.op == "or" and
6380 (a.typename == "nominal" or a.typename == "map") and
6381 is_record_type(b) and
6382 is_a(b, a) then
6383 node.facts = nil
6384 node.type = resolve_tuple(a)
6385 elseif node.op.op == "==" or node.op.op == "~=" then
6386 if is_a(a, b, true) or is_a(b, a, true) then
6387 node.type = BOOLEAN
6388 else
6389 if lax and (is_unknown(a) or is_unknown(b)) then
6390 node.type = UNKNOWN
6391 else
6392 node_error(node, "types are not comparable for equality: %s and %s", a, b)
6393 end
6394 end
6395 elseif node.op.arity == 1 and unop_types[node.op.op] then
6396 a = ua
6397 local types_op = unop_types[node.op.op]
6398 node.type = types_op[a.typename]
6399 if not node.type then
6400 if lax and is_unknown(a) then
6401 node.type = UNKNOWN
6402 else
6403 node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", orig_a)
6404 end
6405 end
6406 elseif node.op.arity == 2 and binop_types[node.op.op] then
6407 if node.op.op == "or" then
6408 node.facts = facts_or(node.e1.facts, node.e2.facts)
6409 end
6410
6411 a = ua
6412 b = ub
6413 local types_op = binop_types[node.op.op]
6414 node.type = types_op[a.typename] and types_op[a.typename][b.typename]
6415 if not node.type then
6416 if lax and (is_unknown(a) or is_unknown(b)) then
6417 node.type = UNKNOWN
6418 else
6419 node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", orig_a, orig_b)
6420 end
6421 end
6422 else
6423 error("unknown node op " .. node.op.op)
6424 end
6425 end,
6426 },
6427 ["variable"] = {
6428 after = function(node, children)
6429 if node.tk == "..." then
6430 local va_sentinel = find_var("@is_va")
6431 if not va_sentinel or va_sentinel.typename == "nil" then
6432 node.type = UNKNOWN
6433 node_error(node, "cannot use '...' outside a vararg function")
6434 end
6435 end
6436
6437 node.type, node.is_const = find_var(node.tk)
6438 if node.type == nil then
6439 node.type = a_type({ typename = "unknown" })
6440 if lax then
6441 add_unknown(node, node.tk)
6442 else
6443 node_error(node, "unknown variable: " .. node.tk)
6444 end
6445 end
6446 end,
6447 },
6448 ["identifier"] = {
6449 after = function(node, children)
6450 node.type = NONE
6451 end,
6452 },
6453 ["newtype"] = {
6454 after = function(node, children)
6455 node.type = node.newtype
6456 end,
6457 },
6458 }
6459
6460 visit_node.cbs["break"] = visit_node.cbs["do"]
6461
6462 visit_node.cbs["values"] = visit_node.cbs["variables"]
6463 visit_node.cbs["expression_list"] = visit_node.cbs["variables"]
6464 visit_node.cbs["argument_list"] = visit_node.cbs["variables"]
6465 visit_node.cbs["argument"] = visit_node.cbs["variable"]
6466
6467 visit_node.cbs["string"] = {
6468 after = function(node, children)
6469 node.type = a_type({
6470 y = node.y,
6471 x = node.x,
6472 typename = node.kind,
6473 tk = node.tk,
6474 })
6475 return node.type
6476 end,
6477 }
6478 visit_node.cbs["number"] = visit_node.cbs["string"]
6479 visit_node.cbs["nil"] = visit_node.cbs["string"]
6480 visit_node.cbs["boolean"] = visit_node.cbs["string"]
6481 visit_node.cbs["..."] = visit_node.cbs["variable"]
6482
6483 visit_node.after = {
6484 after = function(node, children)
6485 assert(type(node.type) == "table", node.kind .. " did not produce a type")
6486 assert(type(node.type.typename) == "string", node.kind .. " type does not have a typename")
6487 return node.type
6488 end,
6489 }
6490
6491 local visit_type = {
6492 cbs = {
6493 ["string"] = {
6494 after = function(typ, children)
6495 return typ
6496 end,
6497 },
6498 ["function"] = {
6499 before = function(typ, children)
6500 begin_scope()
6501 end,
6502 after = function(typ, children)
6503 end_scope()
6504 return typ
6505 end,
6506 },
6507 ["record"] = {
6508 before = function(typ, children)
6509 begin_scope()
6510 for name, typ in pairs(typ.fields) do
6511 if typ.typename == "typetype" then
6512 typ.typename = "nestedtype"
6513 add_var(nil, name, typ)
6514 end
6515 end
6516 end,
6517 after = function(typ, children)
6518 end_scope()
6519 for name, typ in pairs(typ.fields) do
6520 if typ.typename == "nestedtype" then
6521 typ.typename = "typetype"
6522 end
6523 end
6524 return typ
6525 end,
6526 },
6527 ["typearg"] = {
6528 after = function(typ, children)
6529 add_var(nil, typ.typearg, a_type({
6530 y = typ.y,
6531 x = typ.x,
6532 typename = "typearg",
6533 typearg = typ.typearg,
6534 }))
6535 return typ
6536 end,
6537 },
6538 ["nominal"] = {
6539 after = function(typ, children)
6540 local t = find_type(typ.names, true)
6541 if t then
6542 if t.typename == "typearg" then
6543
6544 typ.names = nil
6545 typ.typename = "typevar"
6546 typ.typevar = t.typearg
6547 else
6548 typ.found = t
6549 end
6550 else
6551 local name = typ.names[1]
6552 local unresolved = find_var("@unresolved")
6553 if not unresolved then
6554 unresolved = { typename = "unresolved", labels = {}, nominals = {} }
6555 add_var(nil, "@unresolved", unresolved)
6556 end
6557 unresolved.nominals[name] = unresolved.nominals[name] or {}
6558 table.insert(unresolved.nominals[name], typ)
6559 end
6560 return typ
6561 end,
6562 },
6563 ["union"] = {
6564 after = function(typ, children)
6565
6566
6567 local n_table_types = 0
6568 local n_function_types = 0
6569 local n_string_enum = 0
6570 for _, t in ipairs(typ.types) do
6571 t = resolve_unary(t)
6572 if table_types[t.typename] then
6573 n_table_types = n_table_types + 1
6574 if n_table_types > 1 then
6575 type_error(typ, "cannot discriminate a union between multiple table types: %s", typ)
6576 break
6577 end
6578 elseif t.typename == "function" then
6579 n_function_types = n_function_types + 1
6580 if n_function_types > 1 then
6581 type_error(typ, "cannot discriminate a union between multiple function types: %s", typ)
6582 break
6583 end
6584 elseif t.typename == "string" or t.typename == "enum" then
6585 n_string_enum = n_string_enum + 1
6586 if n_string_enum > 1 then
6587 type_error(typ, "cannot discriminate a union between multiple string/enum types: %s", typ)
6588 break
6589 end
6590 end
6591 end
6592 return typ
6593 end,
6594 },
6595 },
6596 after = {
6597 after = function(typ, children, ret)
6598 assert(type(ret) == "table", typ.typename .. " did not produce a type")
6599 assert(type(ret.typename) == "string", "type node does not have a typename")
6600 return ret
6601 end,
6602 },
6603 }
6604
6605 visit_type.cbs["typetype"] = visit_type.cbs["string"]
6606 visit_type.cbs["nestedtype"] = visit_type.cbs["string"]
6607 visit_type.cbs["typevar"] = visit_type.cbs["string"]
6608 visit_type.cbs["array"] = visit_type.cbs["string"]
6609 visit_type.cbs["map"] = visit_type.cbs["string"]
6610 visit_type.cbs["arrayrecord"] = visit_type.cbs["string"]
6611 visit_type.cbs["enum"] = visit_type.cbs["string"]
6612 visit_type.cbs["boolean"] = visit_type.cbs["string"]
6613 visit_type.cbs["nil"] = visit_type.cbs["string"]
6614 visit_type.cbs["number"] = visit_type.cbs["string"]
6615 visit_type.cbs["thread"] = visit_type.cbs["string"]
6616 visit_type.cbs["bad_nominal"] = visit_type.cbs["string"]
6617 visit_type.cbs["emptytable"] = visit_type.cbs["string"]
6618 visit_type.cbs["table_item"] = visit_type.cbs["string"]
6619 visit_type.cbs["unknown_emptytable_value"] = visit_type.cbs["string"]
6620 visit_type.cbs["tuple"] = visit_type.cbs["string"]
6621 visit_type.cbs["poly"] = visit_type.cbs["string"]
6622 visit_type.cbs["any"] = visit_type.cbs["string"]
6623 visit_type.cbs["unknown"] = visit_type.cbs["string"]
6624 visit_type.cbs["invalid"] = visit_type.cbs["string"]
6625 visit_type.cbs["unresolved"] = visit_type.cbs["string"]
6626 visit_type.cbs["none"] = visit_type.cbs["string"]
6627
6628 recurse_node(ast, visit_node, visit_type)
6629
6630 close_types(st[1])
6631
6632 local redundant = {}
6633 local lastx, lasty = 0, 0
6634 table.sort(errors, function(a, b)
6635 return ((a.filename and b.filename) and a.filename < b.filename) or
6636 (a.filename == b.filename and ((a.y < b.y) or (a.y == b.y and a.x < b.x)))
6637 end)
6638 for i, err in ipairs(errors) do
6639 if err.x == lastx and err.y == lasty then
6640 table.insert(redundant, i)
6641 end
6642 lastx, lasty = err.x, err.y
6643 end
6644 for i = #redundant, 1, -1 do
6645 table.remove(errors, redundant[i])
6646 end
6647
6648 if not opts.skip_compat53 then
6649 add_compat53_entries(ast, all_needs_compat53)
6650 end
6651
6652 return errors, unknowns, module_type
6653end
6654
6655function tl.process(filename, env, result, preload_modules)
6656 local fd, err = io.open(filename, "r")
6657 if not fd then
6658 return nil, "could not open " .. filename .. ": " .. err
6659 end
6660
6661 local input, err = fd:read("*a")
6662 fd:close()
6663 if not input then
6664 return nil, "could not read " .. filename .. ": " .. err
6665 end
6666
6667 local basename, extension = filename:match("(.*)%.([a-z]+)$")
6668 extension = extension and extension:lower()
6669
6670 local is_lua
6671 if extension == "tl" then
6672 is_lua = false
6673 elseif extension == "lua" then
6674 is_lua = true
6675 else
6676 is_lua = input:match("^#![^\n]*lua[^\n]*\n")
6677 end
6678
6679 result, err = tl.process_string(input, is_lua, env, result, preload_modules, filename)
6680
6681 if err then
6682 return nil, err
6683 end
6684
6685 return result
6686end
6687
6688function tl.process_string(input, is_lua, env, result, preload_modules,
6689filename)
6690
6691 env = env or tl.init_env(is_lua)
6692 result = result or {
6693 syntax_errors = {},
6694 type_errors = {},
6695 unknowns = {},
6696 }
6697 preload_modules = preload_modules or {}
6698 filename = filename or ""
6699
6700 local tokens, errs = tl.lex(input)
6701 if errs then
6702 for i, err in ipairs(errs) do
6703 table.insert(result.syntax_errors, {
6704 y = err.y,
6705 x = err.x,
6706 msg = "invalid token '" .. err.tk .. "'",
6707 filename = filename,
6708 })
6709 end
6710 end
6711
6712 local i, program = tl.parse_program(tokens, result.syntax_errors, filename)
6713 if #result.syntax_errors > 0 then
6714 return result
6715 end
6716
6717
6718 for _, name in ipairs(preload_modules) do
6719 local module_type = require_module(name, is_lua, env, result)
6720
6721 if module_type == UNKNOWN then
6722 return nil, string.format("Error: could not preload module '%s'", name)
6723 end
6724 end
6725
6726 local error, unknown
6727 local opts = {
6728 lax = is_lua,
6729 filename = filename,
6730 env = env,
6731 result = result,
6732 skip_compat53 = env.skip_compat53,
6733 }
6734 error, unknown, result.type = tl.type_check(program, opts)
6735
6736 result.ast = program
6737 result.env = env
6738
6739 return result
6740end
6741
6742function tl.gen(input, env)
6743 env = env or tl.init_env()
6744 local result, err = tl.process_string(input, false, env)
6745
6746 if err then
6747 return nil, nil
6748 end
6749
6750 if not result.ast then
6751 return nil, result
6752 end
6753
6754 return tl.pretty_print_ast(result.ast), result
6755end
6756
6757local function tl_package_loader(module_name)
6758 local found_filename, fd, tried = tl.search_module(module_name, false)
6759 if found_filename then
6760 local input = fd:read("*a")
6761 fd:close()
6762 local errs = {}
6763 local _, program = tl.parse_program(tl.lex(input), errs, module_name)
6764 if #errs > 0 then
6765 error(module_name .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg)
6766 end
6767 local code = tl.pretty_print_ast(program, true)
6768 local chunk, err = load(code, module_name, "t")
6769 if chunk then
6770 return function()
6771 local ret = chunk()
6772 package.loaded[module_name] = ret
6773 return ret
6774 end
6775 else
6776 error("Internal Compiler Error: Teal generator produced invalid Lua. Please report a bug at https://github.com/teal-language/tl")
6777 end
6778 end
6779 return table.concat(tried, "\n\t")
6780end
6781
6782function tl.loader()
6783 if package.searchers then
6784 table.insert(package.searchers, 2, tl_package_loader)
6785 else
6786 table.insert(package.loaders, 2, tl_package_loader)
6787 end
6788end
6789
6790function tl.load(input, chunkname, mode, env)
6791 local tokens = tl.lex(input)
6792 local errs = {}
6793 local i, program = tl.parse_program(tokens, errs, chunkname)
6794 if #errs > 0 then
6795 return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg
6796 end
6797 local code = tl.pretty_print_ast(program, true)
6798 return load(code, chunkname, mode, env)
6799end
6800
6801return tl