Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ Builtin actions are all higher-order functions so they can easily have options o
| `toggle_block()` | | ✅ | | | | | | | |
| if/else <-> ternery | | ✅ | | | ✅ | | | | |
| if block/postfix | | ✅ | | | | | | | |
| comprehension -> for loop | | | | | ✅ | | | | |
| `toggle_hash_style()` | | ✅ | | | | | | | |
| `conceal_string()` | | | ✅ | | | | | | ✅ |

Expand Down
167 changes: 162 additions & 5 deletions lua/ts-node-action/filetypes/python.lua
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ local function node_trim_whitespace(node)
)
end

-- @param node TSNode
-- @return table
function node_multiline_text(node)
local buf = vim.api.nvim_get_current_buf()
local text = vim.treesitter.query.get_node_text(node, buf)
if text:match("\n") then
return vim.tbl_map(vim.trim, vim.split(text, "\n"))
else
return {text}
end
end

-- When inlined, these nodes must be parenthesized to avoid changing the
-- meaning of the code and to avoid syntax errors.
-- eg: x = lambda y: y + 1 if y else 0
Expand Down Expand Up @@ -307,7 +319,7 @@ local function destructure_conditional_expression(node)
end

--- @param stmt table
--- @return string, table, TSNode
--- @return string, table
--- @return nil
local function expand_cond_expr(stmt, padding_override)
local parent = stmt.node:parent()
Expand Down Expand Up @@ -384,7 +396,7 @@ end

--- @param stmt table { node, condition, consequence, alternative, comments }
--- @param padding_override table
--- @return string, table, TSNode
--- @return string, table
--- @return nil
local function inline_if(stmt, padding_override)

Expand Down Expand Up @@ -427,7 +439,7 @@ end

--- @param stmt table { node, condition, consequence, alternative, comments }
--- @param padding_override table
--- @return string, table, TSNode
--- @return string, table
--- @return nil
local function inline_ifelse(stmt, padding_override)

Expand Down Expand Up @@ -467,7 +479,7 @@ local function inline_if_statement(padding_override)
padding_override = padding_override or padding

--- @param if_statement TSNode
--- @return string, table, TSNode
--- @return string, table
local function action(if_statement)
local stmt = destructure_if_statement(if_statement)
-- we can't inline multiple statements within a block
Expand Down Expand Up @@ -503,7 +515,7 @@ local function expand_conditional_expression(padding_override)
padding_override = padding_override or padding

--- @param conditional_expression TSNode
--- @return string, table, TSNode
--- @return string, table
local function action(conditional_expression)
local stmt = destructure_conditional_expression(conditional_expression)
if #stmt.comments > 0 then
Expand All @@ -515,6 +527,150 @@ local function expand_conditional_expression(padding_override)
return { action, name = "Expand Conditional" }
end

local function insert_multiline_text(lines, replacement, indent, prepend, append)
local line_cnt = #lines
for i, line in ipairs(lines) do
if i == 1 then
line = prepend .. line
elseif i == 2 and line_cnt > 2 then
indent = indent .. " "
end
if i == line_cnt then
line = line .. append
if line_cnt > 2 then
indent = indent:sub(1, -5)
end
end
table.insert(replacement, indent .. line)
end
return indent
end

local function get_comprehension_config(type)
if type == "list_comprehension" then
return "[]", ".append"
elseif type == "set_comprehension" then
return "set()", ".add"
elseif type == "dictionary_comprehension" then
return "{}", ""
end
end

local function destructure_comprehension(comprehension)
local body
local clauses_and_comments = {}
local body_comments = {}

for child in comprehension:iter_children() do
if child:named() then
local child_type = child:type()
if not body then
if child_type == "comment" then
table.insert(body_comments, child)
else
body = child
end
else
table.insert(clauses_and_comments, child)
end
end
end

return {
node = comprehension,
body = body,
clauses_and_comments = clauses_and_comments,
body_comments = body_comments,
}
end

-- @param stmt tsnode
-- @return string, table
-- @return nil
local function expand_comprehension(for_in_clause)

local comprehension = for_in_clause:parent()

local comp_type = comprehension:type()
local comp_init, comp_method = get_comprehension_config(comp_type)
if not comp_init then
return
end

local stmt = destructure_comprehension(comprehension)
if not stmt then
return
end

local identifiers = {}
local parent = stmt.node:parent()
if parent:type() == "assignment" then
while parent:type() == "assignment" do
table.insert(identifiers, 1, helpers.node_text(parent:named_child(0)))
parent = parent:parent()
end
elseif parent:type() == "return_statement" then
table.insert(identifiers, 1, "result")
else
return
end

local replacement = {
table.concat(identifiers, " = ") .. " = " .. comp_init,
}

local _, start_col = parent:start()
local start_indent = string.rep(" ", start_col)
local indent = start_indent

for _, clause_or_comment in ipairs(stmt.clauses_and_comments) do
if clause_or_comment:type() == "comment" then
table.insert(replacement, indent .. helpers.node_text(clause_or_comment))
else
local lines = node_multiline_text(clause_or_comment)
insert_multiline_text(lines, replacement, indent, "", ":")
indent = indent .. " "
end
end

if stmt.body_comments then
for _, comment in ipairs(stmt.body_comments) do
table.insert(replacement, indent .. helpers.node_text(comment))
end
end

if comp_type == "dictionary_comprehension" then
local keys = node_multiline_text(stmt.body:named_child(0))
local values = node_multiline_text(stmt.body:named_child(1))
for _, identifier in ipairs(identifiers) do
local prepend = identifier .. "["
local append = "] = "
insert_multiline_text(keys, replacement, indent, prepend, append)
prepend = replacement[#replacement]
replacement[#replacement] = nil
append = ""
insert_multiline_text(values, replacement, indent, prepend, append)
end
else
local values = node_multiline_text(stmt.body)
for _, identifier in ipairs(identifiers) do
local prepend = identifier .. comp_method .. "("
local append = ")"
insert_multiline_text(values, replacement, indent, prepend, append)
end
end

if parent:type() == "return_statement" then
table.insert(replacement, start_indent .. "return " .. identifiers[1])
end

return replacement, {
cursor = {},
format = true,
target = parent,
}
end

return {
["dictionary"] = actions.toggle_multiline(padding),
["set"] = actions.toggle_multiline(padding),
Expand All @@ -532,4 +688,5 @@ return {
["integer"] = actions.toggle_int_readability(),
["conditional_expression"] = { expand_conditional_expression(padding), },
["if_statement"] = { inline_if_statement(padding), },
["for_in_clause"] = { { expand_comprehension, name = "Comprehension -> For Loop" } },
}
Loading