aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHisham Muhammad <hisham@gobolinux.org>2016-10-06 11:53:43 -0300
committerGitHub <noreply@github.com>2016-10-06 11:53:43 -0300
commit2c95d93a7d02b5b3278de1841d5717a56e6ccdfe (patch)
tree49c3e583567c40a5173a603dff941dcb6e415992
parent5f3b2b79bbba5de0a571188c08b42a213aac772f (diff)
parenteb7427676a2f09556f9d5f19a1a6392fd945e8bc (diff)
downloadluarocks-2c95d93a7d02b5b3278de1841d5717a56e6ccdfe.tar.gz
luarocks-2c95d93a7d02b5b3278de1841d5717a56e6ccdfe.tar.bz2
luarocks-2c95d93a7d02b5b3278de1841d5717a56e6ccdfe.zip
Merge pull request #624 from mpeterv/coroutineless-sortedpairs
Don't use coroutines in util.sortedpairs
-rw-r--r--spec/util_spec.lua56
-rw-r--r--src/luarocks/util.lua79
2 files changed, 99 insertions, 36 deletions
diff --git a/spec/util_spec.lua b/spec/util_spec.lua
index e6776e4b..2779b1ce 100644
--- a/spec/util_spec.lua
+++ b/spec/util_spec.lua
@@ -116,3 +116,59 @@ describe("Basic tests #blackbox #b_util", function()
116 end) 116 end)
117 end) 117 end)
118end) 118end)
119
120test_env.unload_luarocks()
121local util = require("luarocks.util")
122
123describe("Luarocks util test #whitebox #w_util", function()
124 describe("util.sortedpairs", function()
125 local function collect(iter, state, var)
126 local collected = {}
127
128 while true do
129 local returns = {iter(state, var)}
130
131 if returns[1] == nil then
132 return collected
133 else
134 table.insert(collected, returns)
135 var = returns[1]
136 end
137 end
138 end
139
140 it("default sort", function()
141 assert.are.same({}, collect(util.sortedpairs({})))
142 assert.are.same({
143 {1, "v1"},
144 {2, "v2"},
145 {3, "v3"},
146 {"bar", "v5"},
147 {"foo", "v4"}
148 }, collect(util.sortedpairs({"v1", "v2", "v3", foo = "v4", bar = "v5"})))
149 end)
150
151 it("sort by function", function()
152 local function compare(a, b) return a > b end
153 assert.are.same({}, collect(util.sortedpairs({}, compare)))
154 assert.are.same({
155 {3, "v3"},
156 {2, "v2"},
157 {1, "v1"}
158 }, collect(util.sortedpairs({"v1", "v2", "v3"}, compare)))
159 end)
160
161 it("sort by priority table", function()
162 assert.are.same({}, collect(util.sortedpairs({}, {"k1", "k2"})))
163 assert.are.same({
164 {"k3", "v3"},
165 {"k2", "v2", {"sub order"}},
166 {"k1", "v1"},
167 {"k4", "v4"},
168 {"k5", "v5"},
169 }, collect(util.sortedpairs({
170 k1 = "v1", k2 = "v2", k3 = "v3", k4 = "v4", k5 = "v5"
171 }, {"k3", {"k2", {"sub order"}}, "k1"})))
172 end)
173 end)
174end)
diff --git a/src/luarocks/util.lua b/src/luarocks/util.lua
index 532bea8b..c9fb7d63 100644
--- a/src/luarocks/util.lua
+++ b/src/luarocks/util.lua
@@ -357,52 +357,59 @@ local function default_sort(a, b)
357 end 357 end
358end 358end
359 359
360-- The iterator function used internally by util.sortedpairs. 360--- A table iterator generator that returns elements sorted by key,
361-- to be used in "for" loops.
361-- @param tbl table: The table to be iterated. 362-- @param tbl table: The table to be iterated.
362-- @param sort_function function or nil: An optional comparison function 363-- @param sort_function function or table or nil: An optional comparison function
363-- to be used by table.sort when sorting keys. 364-- to be used by table.sort when sorting keys, or an array listing an explicit order
364-- @see sortedpairs 365-- for keys. If a value itself is an array, it is taken so that the first element
365local function sortedpairs_iterator(tbl, sort_function) 366-- is a string representing the field name, and the second element is a priority table
366 local ks = util.keys(tbl) 367-- for that key, which is returned by the iterator as the third value after the key
367 if not sort_function or type(sort_function) == "function" then 368-- and the value.
368 table.sort(ks, sort_function or default_sort) 369-- @return function: the iterator function.
369 for _, k in ipairs(ks) do 370function util.sortedpairs(tbl, sort_function)
370 coroutine.yield(k, tbl[k]) 371 sort_function = sort_function or default_sort
371 end 372 local keys = util.keys(tbl)
373 local sub_orders = {}
374
375 if type(sort_function) == "function" then
376 table.sort(keys, sort_function)
372 else 377 else
373 local order = sort_function 378 local order = sort_function
374 local done = {} 379 local ordered_keys = {}
375 for _, k in ipairs(order) do 380 local all_keys = keys
376 local sub_order 381 keys = {}
377 if type(k) == "table" then 382
378 sub_order = k[2] 383 for _, order_entry in ipairs(order) do
379 k = k[1] 384 local key, sub_order
385 if type(order_entry) == "table" then
386 key = order_entry[1]
387 sub_order = order_entry[2]
388 else
389 key = order_entry
380 end 390 end
381 if tbl[k] then 391
382 done[k] = true 392 if tbl[key] then
383 coroutine.yield(k, tbl[k], sub_order) 393 ordered_keys[key] = true
394 sub_orders[key] = sub_order
395 table.insert(keys, key)
384 end 396 end
385 end 397 end
386 table.sort(ks, default_sort) 398
387 for _, k in ipairs(ks) do 399 table.sort(all_keys, default_sort)
388 if not done[k] then 400 for _, key in ipairs(all_keys) do
389 coroutine.yield(k, tbl[k]) 401 if not ordered_keys[key] then
402 table.insert(keys, key)
390 end 403 end
391 end 404 end
392 end 405 end
393end
394 406
395--- A table iterator generator that returns elements sorted by key, 407 local i = 1
396-- to be used in "for" loops. 408 return function()
397-- @param tbl table: The table to be iterated. 409 local key = keys[i]
398-- @param sort_function function or table or nil: An optional comparison function 410 i = i + 1
399-- to be used by table.sort when sorting keys, or an array listing an explicit order 411 return key, tbl[key], sub_orders[key]
400-- for keys. If a value itself is an array, it is taken so that the first element 412 end
401-- is a string representing the field name, and the second element is a priority table
402-- for that key.
403-- @return function: the iterator function.
404function util.sortedpairs(tbl, sort_function)
405 return coroutine.wrap(function() sortedpairs_iterator(tbl, sort_function) end)
406end 413end
407 414
408function util.lua_versions() 415function util.lua_versions()