From 6649c1027e285bf33c63fe748465c773bcea3bc2 Mon Sep 17 00:00:00 2001 From: Alex Walker Date: Sat, 25 Feb 2023 18:44:14 -0600 Subject: [PATCH 1/4] [python] expand comprehensions via for_in_clause --- lua/ts-node-action/filetypes/python.lua | 180 ++++++++++++++- spec/filetypes/python/for_in_clause_spec.lua | 227 +++++++++++++++++++ 2 files changed, 402 insertions(+), 5 deletions(-) create mode 100644 spec/filetypes/python/for_in_clause_spec.lua diff --git a/lua/ts-node-action/filetypes/python.lua b/lua/ts-node-action/filetypes/python.lua index fbb01c8..d56ccce 100644 --- a/lua/ts-node-action/filetypes/python.lua +++ b/lua/ts-node-action/filetypes/python.lua @@ -64,6 +64,18 @@ local function node_trim_whitespace(node) ) end +-- @param node TSNode +-- @return table +function node_multiline_text(node) + local buf = vim.api.nvim_get_current_buf() + local text = vim.treesitter.query.get_node_text(node, buf) + if text:match("\n") then + return vim.tbl_map(vim.trim, vim.split(text, "\n")) + else + return {text} + end +end + -- When inlined, these nodes must be parenthesized to avoid changing the -- meaning of the code and to avoid syntax errors. -- eg: x = lambda y: y + 1 if y else 0 @@ -307,7 +319,7 @@ local function destructure_conditional_expression(node) end --- @param stmt table ---- @return string, table, TSNode +--- @return string, table --- @return nil local function expand_cond_expr(stmt, padding_override) local parent = stmt.node:parent() @@ -384,7 +396,7 @@ end --- @param stmt table { node, condition, consequence, alternative, comments } --- @param padding_override table ---- @return string, table, TSNode +--- @return string, table --- @return nil local function inline_if(stmt, padding_override) @@ -427,7 +439,7 @@ end --- @param stmt table { node, condition, consequence, alternative, comments } --- @param padding_override table ---- @return string, table, TSNode +--- @return string, table --- @return nil local function inline_ifelse(stmt, padding_override) @@ -467,7 +479,7 @@ local function inline_if_statement(padding_override) padding_override = padding_override or padding --- @param if_statement TSNode - --- @return string, table, TSNode + --- @return string, table local function action(if_statement) local stmt = destructure_if_statement(if_statement) -- we can't inline multiple statements within a block @@ -503,7 +515,7 @@ local function expand_conditional_expression(padding_override) padding_override = padding_override or padding --- @param conditional_expression TSNode - --- @return string, table, TSNode + --- @return string, table local function action(conditional_expression) local stmt = destructure_conditional_expression(conditional_expression) if #stmt.comments > 0 then @@ -515,6 +527,163 @@ local function expand_conditional_expression(padding_override) return { action, name = "Expand Conditional" } end +local function insert_multiline_text(lines, replacement, indent, prepend, append) + local line_cnt = #lines + for i, line in ipairs(lines) do + if i == 1 then + line = prepend .. line + elseif i == 2 and line_cnt > 2 then + indent = indent .. " " + end + if i == line_cnt then + line = line .. append + if line_cnt > 2 then + indent = indent:sub(1, -5) + end + end + table.insert(replacement, indent .. line) + end + return indent +end + +local function get_comprehension_config(type) + if type == "list_comprehension" then + return "[]", ".append" + elseif type == "set_comprehension" then + return "set()", ".add" + elseif type == "dictionary_comprehension" then + return "{}", "" + end +end + +-- @param stmt tsnode +-- @param padding_override table +-- @return string, table +-- @return nil +local function expand_comprehension(stmt, padding_override) + + local identifiers = {} + local parent = stmt.node:parent() + if parent:type() == "assignment" then + while parent:type() == "assignment" do + table.insert(identifiers, 1, helpers.node_text(parent:named_child(0))) + parent = parent:parent() + end + elseif parent:type() == "return_statement" then + table.insert(identifiers, 1, "result") + else + return + end + + local comp_type = stmt.node:type() + local comp_init, comp_method = get_comprehension_config(comp_type) + local replacement = { + table.concat(identifiers, " = ") .. " = " .. comp_init, + } + + local _, start_col = parent:start() + local start_indent = string.rep(" ", start_col) + local indent = start_indent + + for _, clause_or_comment in ipairs(stmt.clauses_and_comments) do + if clause_or_comment:type() == "comment" then + table.insert(replacement, indent .. helpers.node_text(clause_or_comment)) + else + local lines = node_multiline_text(clause_or_comment) + insert_multiline_text(lines, replacement, indent, "", ":") + indent = indent .. " " + end + end + + if stmt.body_comments then + for _, comment in ipairs(stmt.body_comments) do + table.insert(replacement, indent .. helpers.node_text(comment)) + end + end + + if comp_type == "dictionary_comprehension" then + local keys = node_multiline_text(stmt.body:named_child(0)) + local values = node_multiline_text(stmt.body:named_child(1)) + for _, identifier in ipairs(identifiers) do + local prepend = identifier .. "[" + local append = "] = " + insert_multiline_text(keys, replacement, indent, prepend, append) + prepend = replacement[#replacement] + replacement[#replacement] = nil + append = "" + insert_multiline_text(values, replacement, indent, prepend, append) + end + else + local values = node_multiline_text(stmt.body) + for _, identifier in ipairs(identifiers) do + local prepend = identifier .. comp_method .. "(" + local append = ")" + insert_multiline_text(values, replacement, indent, prepend, append) + end + end + + if parent:type() == "return_statement" then + table.insert(replacement, start_indent .. "return " .. identifiers[1]) + end + + return replacement, { + cursor = {}, + format = true, + target = parent, + } +end + +local function destructure_comprehension(comprehension) + local body + local clauses_and_comments = {} + local body_comments = {} + + for child in comprehension:iter_children() do + if child:named() then + local child_type = child:type() + if not body then + if child_type == "comment" then + table.insert(body_comments, child) + else + body = child + end + else + table.insert(clauses_and_comments, child) + end + end + end + + return { + node = comprehension, + body = body, + clauses_and_comments = clauses_and_comments, + body_comments = body_comments, + } +end + +local function expand_for_in_clause(padding_override) + padding_override = padding_override or padding + + local function action(node) + local comprehension = node:parent() + + local type = comprehension:type() + if not get_comprehension_config(type) then + return + end + + local stmt = destructure_comprehension(comprehension) + if not stmt then + return + end + + return expand_comprehension(stmt, padding_override) + end + + return { action, name = "Expand Comprehension" } +end + + return { ["dictionary"] = actions.toggle_multiline(padding), ["set"] = actions.toggle_multiline(padding), @@ -532,4 +701,5 @@ return { ["integer"] = actions.toggle_int_readability(), ["conditional_expression"] = { expand_conditional_expression(padding), }, ["if_statement"] = { inline_if_statement(padding), }, + ["for_in_clause"] = { expand_for_in_clause(padding), }, } diff --git a/spec/filetypes/python/for_in_clause_spec.lua b/spec/filetypes/python/for_in_clause_spec.lua new file mode 100644 index 0000000..2e75c57 --- /dev/null +++ b/spec/filetypes/python/for_in_clause_spec.lua @@ -0,0 +1,227 @@ +dofile("./spec/spec_helper.lua") + +local Helper = SpecHelper:new("python") + +describe("for_in_clause", function() + + it("expands list assignment for", function() + assert.are.same( + { + "xs = []", + "for x in range(10):", + " xs.append(x)", + }, + Helper:call({"xs = [x for x in range(10)]"}, {1, 9}) + ) + end) + + it("expands list assignment for/if", function() + assert.are.same( + { + "xs = []", + "for x in range(10):", + " if -x and x - 3 == 0 and abs(x - 1) < 2:", + " xs.append(x)", + }, + Helper:call({ + "xs = [x for x in range(10) if -x and x - 3 == 0 and abs(x - 1) < 2]" + }, {1, 9}) + ) + end) + + it("expands set assignment for/if", function() + assert.are.same( + { + "xs = set()", + "for x in range(10):", + " if -x and x - 3 == 0 and abs(x - 1) < 2:", + " xs.add(x)", + }, + Helper:call({ + "xs = {x for x in range(10) if -x and x - 3 == 0 and abs(x - 1) < 2}" + }, {1, 9}) + ) + end) + + it("expands dict assignment for/if", function() + assert.are.same( + { + "xs = {}", + "for x in range(10):", + " if -x and abs(x - 1) < 2:", + " xs[x] = x + 1", + }, + Helper:call({ + "xs = {x: x + 1 for x in range(10) if -x and abs(x - 1) < 2}" + }, {1, 16}) + ) + end) + + it("expands list return for", function() + assert.are.same( + { + "result = []", + "for x in range(10):", + " result.append(x)", + "return result", + }, + Helper:call({ + "return [x for x in range(10)]" + }, {1, 11}) + ) + end) + + it("expands list return for/if", function() + assert.are.same( + { + "result = []", + "for x in range(10):", + " if -x and abs(x - 1) < 2:", + " result.append(x)", + "return result", + }, + Helper:call({ + "return [x for x in range(10) if -x and abs(x - 1) < 2]" + }, {1, 11}) + ) + end) + + it("expands set return for/if", function() + assert.are.same( + { + "result = set()", + "for x in range(10):", + " if -x and abs(x - 1) < 2:", + " result.add(x)", + "return result", + }, + Helper:call({ + "return {x for x in range(10) if -x and abs(x - 1) < 2}" + }, {1, 11}) + ) + end) + + it("expands dict return for/if", function() + assert.are.same( + { + "result = {}", + "for x in range(10):", + " if -x and abs(x - 1) < 2:", + " result[x] = x + 1", + "return result", + }, + Helper:call({ + "return {x: x + 1 for x in range(10) if -x and abs(x - 1) < 2}" + }, {1, 18}) + ) + end) + + it("doesn't expand generator assignment", function() + local text = { + "xs = (x for x in range(10))" + } + assert.are.same(text, Helper:call(text, {1, 9})) + end) + + it("doesn't expand generator return", function() + local text = { + "return (x for x in range(10))" + } + assert.are.same(text, Helper:call(text, {1, 11})) + end) + + it("expands a multiline dict assignment with comments", function() + assert.are.same( + { + "a = {}", + "# before for 1", + "for x in range(", + " # for inside arg", + " 1,", + " 5, # for inside arg 2", + " ):", + " # before if 1", + " if (x < 2 or", + " x - 3 == 0):", + " # after if 1", + " # before body", + " a[x] = foo(x) + 1", + }, + Helper:call({ + "a = { # before body", + " x: foo(x) + 1", + " # before for 1", + " for x in range(", + " # for inside arg", + " 1,", + " 5, # for inside arg 2", + " )", + " # before if 1", + " if (x < 2 or", + " x - 3 == 0)", + " # after if 1", + "}", + }, {4, 5}) + ) + end) + + it("expands absurd multiline set assignment with comments", function() + assert.are.same( + { + "a = b = c = set()", + "# before for 1", + "for x in range(", + " # for inside arg", + " 1,", + " 5, # for inside arg 2", + " ):", + " # before if 1", + " if x != y:", + " # before for 2", + " for z in {", + " # for inside arg", + " 1, 2, 3", + " # for inside arg 2", + " }:", + " # before if 2", + " if y != z:", + " # after if 2", + " # before body", + " a.add((x,", + " y, # y", + " z))", + " b.add((x,", + " y, # y", + " z))", + " c.add((x,", + " y, # y", + " z)) # after comprehension", + }, + Helper:call({ + "a = b = c = { # before body", + " (x,", + " y, # y", + " z)", + " # before for 1", + " for x in range(", + " # for inside arg", + " 1,", + " 5, # for inside arg 2", + " )", + " # before if 1", + " if x != y", + " # before for 2", + " for z in {", + " # for inside arg", + " 1, 2, 3", + " # for inside arg 2", + " }", + " # before if 2", + " if y != z", + " # after if 2", + " } # after comprehension", + }, {6, 14}) + ) + end) + +end) From 4f479dc437b59f66c0b9d8bae3ad9f4bcda86841 Mon Sep 17 00:00:00 2001 From: Alex Walker Date: Sat, 25 Feb 2023 19:40:55 -0600 Subject: [PATCH 2/4] [python] update spec for new workflow --- spec/filetypes/python/for_in_clause_spec.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spec/filetypes/python/for_in_clause_spec.lua b/spec/filetypes/python/for_in_clause_spec.lua index 2e75c57..df2bb19 100644 --- a/spec/filetypes/python/for_in_clause_spec.lua +++ b/spec/filetypes/python/for_in_clause_spec.lua @@ -1,6 +1,6 @@ dofile("./spec/spec_helper.lua") -local Helper = SpecHelper:new("python") +local Helper = SpecHelper.new("python", { shiftwidth = 4 }) describe("for_in_clause", function() From 8af96493e5111f02024d802f337b8a1979df7a7b Mon Sep 17 00:00:00 2001 From: Alex Walker Date: Sun, 26 Feb 2023 12:49:34 -0600 Subject: [PATCH 3/4] [python] update README; removed wrapper/action fn that received padding_override b/c padding isn't used --- README.md | 1 + lua/ts-node-action/filetypes/python.lua | 99 +++++++++++-------------- 2 files changed, 44 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 3be5900..e4be185 100644 --- a/README.md +++ b/README.md @@ -264,6 +264,7 @@ Builtin actions are all higher-order functions so they can easily have options o | `toggle_block()` | | ✅ | | | | | | | | | if/else <-> ternery | | ✅ | | | ✅ | | | | | | if block/postfix | | ✅ | | | | | | | | +| expand comprehension | | | | | ✅ | | | | | | `toggle_hash_style()` | | ✅ | | | | | | | | | `conceal_string()` | | | ✅ | | | | | | ✅ | diff --git a/lua/ts-node-action/filetypes/python.lua b/lua/ts-node-action/filetypes/python.lua index d56ccce..e0d5361 100644 --- a/lua/ts-node-action/filetypes/python.lua +++ b/lua/ts-node-action/filetypes/python.lua @@ -556,11 +556,51 @@ local function get_comprehension_config(type) end end +local function destructure_comprehension(comprehension) + local body + local clauses_and_comments = {} + local body_comments = {} + + for child in comprehension:iter_children() do + if child:named() then + local child_type = child:type() + if not body then + if child_type == "comment" then + table.insert(body_comments, child) + else + body = child + end + else + table.insert(clauses_and_comments, child) + end + end + end + + return { + node = comprehension, + body = body, + clauses_and_comments = clauses_and_comments, + body_comments = body_comments, + } +end + -- @param stmt tsnode --- @param padding_override table -- @return string, table -- @return nil -local function expand_comprehension(stmt, padding_override) +local function expand_comprehension(for_in_clause) + + local comprehension = for_in_clause:parent() + + local comp_type = comprehension:type() + local comp_init, comp_method = get_comprehension_config(comp_type) + if not comp_init then + return + end + + local stmt = destructure_comprehension(comprehension) + if not stmt then + return + end local identifiers = {} local parent = stmt.node:parent() @@ -575,8 +615,6 @@ local function expand_comprehension(stmt, padding_override) return end - local comp_type = stmt.node:type() - local comp_init, comp_method = get_comprehension_config(comp_type) local replacement = { table.concat(identifiers, " = ") .. " = " .. comp_init, } @@ -633,57 +671,6 @@ local function expand_comprehension(stmt, padding_override) } end -local function destructure_comprehension(comprehension) - local body - local clauses_and_comments = {} - local body_comments = {} - - for child in comprehension:iter_children() do - if child:named() then - local child_type = child:type() - if not body then - if child_type == "comment" then - table.insert(body_comments, child) - else - body = child - end - else - table.insert(clauses_and_comments, child) - end - end - end - - return { - node = comprehension, - body = body, - clauses_and_comments = clauses_and_comments, - body_comments = body_comments, - } -end - -local function expand_for_in_clause(padding_override) - padding_override = padding_override or padding - - local function action(node) - local comprehension = node:parent() - - local type = comprehension:type() - if not get_comprehension_config(type) then - return - end - - local stmt = destructure_comprehension(comprehension) - if not stmt then - return - end - - return expand_comprehension(stmt, padding_override) - end - - return { action, name = "Expand Comprehension" } -end - - return { ["dictionary"] = actions.toggle_multiline(padding), ["set"] = actions.toggle_multiline(padding), @@ -701,5 +688,5 @@ return { ["integer"] = actions.toggle_int_readability(), ["conditional_expression"] = { expand_conditional_expression(padding), }, ["if_statement"] = { inline_if_statement(padding), }, - ["for_in_clause"] = { expand_for_in_clause(padding), }, + ["for_in_clause"] = { { expand_comprehension, name = "Expand Comprehension" } }, } From 5f78b05a0182eab16df12fa655178ec6ba23a92a Mon Sep 17 00:00:00 2001 From: Alex Walker Date: Sun, 26 Feb 2023 13:07:42 -0600 Subject: [PATCH 4/4] [python] tweak wording in README/action to be clearer --- README.md | 2 +- lua/ts-node-action/filetypes/python.lua | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e4be185..7430f3b 100644 --- a/README.md +++ b/README.md @@ -264,7 +264,7 @@ Builtin actions are all higher-order functions so they can easily have options o | `toggle_block()` | | ✅ | | | | | | | | | if/else <-> ternery | | ✅ | | | ✅ | | | | | | if block/postfix | | ✅ | | | | | | | | -| expand comprehension | | | | | ✅ | | | | | +| comprehension -> for loop | | | | | ✅ | | | | | | `toggle_hash_style()` | | ✅ | | | | | | | | | `conceal_string()` | | | ✅ | | | | | | ✅ | diff --git a/lua/ts-node-action/filetypes/python.lua b/lua/ts-node-action/filetypes/python.lua index e0d5361..4eab13f 100644 --- a/lua/ts-node-action/filetypes/python.lua +++ b/lua/ts-node-action/filetypes/python.lua @@ -688,5 +688,5 @@ return { ["integer"] = actions.toggle_int_readability(), ["conditional_expression"] = { expand_conditional_expression(padding), }, ["if_statement"] = { inline_if_statement(padding), }, - ["for_in_clause"] = { { expand_comprehension, name = "Expand Comprehension" } }, + ["for_in_clause"] = { { expand_comprehension, name = "Comprehension -> For Loop" } }, }