From 0e2acedd531a3619e46e7d954ed0209bae99d1e3 Mon Sep 17 00:00:00 2001 From: Alex Walker Date: Mon, 27 Feb 2023 01:15:09 -0600 Subject: [PATCH 1/5] [python] cycle style of 'from .. import ..' statements --- lua/ts-node-action/filetypes/python.lua | 32 ++- .../filetypes/python/cycle_import_from.lua | 136 ++++++++++ .../python/import_from_statement_spec.lua | 244 ++++++++++++++++++ 3 files changed, 410 insertions(+), 2 deletions(-) create mode 100644 lua/ts-node-action/filetypes/python/cycle_import_from.lua create mode 100644 spec/filetypes/python/import_from_statement_spec.lua diff --git a/lua/ts-node-action/filetypes/python.lua b/lua/ts-node-action/filetypes/python.lua index fbb01c8..e5d42d0 100644 --- a/lua/ts-node-action/filetypes/python.lua +++ b/lua/ts-node-action/filetypes/python.lua @@ -1,5 +1,6 @@ -local helpers = require("ts-node-action.helpers") -local actions = require("ts-node-action.actions") +local helpers = require("ts-node-action.helpers") +local actions = require("ts-node-action.actions") +local import_from = require("ts-node-action.filetypes.python.cycle_import_from") -- Special cases: -- Because "is" and "not" are valid by themselves, they are seen as separate @@ -515,6 +516,32 @@ local function expand_conditional_expression(padding_override) return { action, name = "Expand Conditional" } end +local import_formats = { "inline", "block", "single", "block" } + +local function cycle_import_from_statement(user_import_formats) + user_import_formats = user_import_formats or import_formats + + local function action(import_from_statement) + + local stmt = import_from.destructure(import_from_statement) + if #stmt.comments > 0 then + return + end + + local new_format + for i, f in ipairs(user_import_formats) do + if f == stmt.format then + new_format = user_import_formats[i + 1] or user_import_formats[1] + break + end + end + + return import_from.cycle(stmt, new_format) + end + + return { action, name = "Cycle Import" } +end + return { ["dictionary"] = actions.toggle_multiline(padding), ["set"] = actions.toggle_multiline(padding), @@ -532,4 +559,5 @@ return { ["integer"] = actions.toggle_int_readability(), ["conditional_expression"] = { expand_conditional_expression(padding), }, ["if_statement"] = { inline_if_statement(padding), }, + ["import_from_statement"] = { cycle_import_from_statement(import_formats), }, } diff --git a/lua/ts-node-action/filetypes/python/cycle_import_from.lua b/lua/ts-node-action/filetypes/python/cycle_import_from.lua new file mode 100644 index 0000000..9688d8a --- /dev/null +++ b/lua/ts-node-action/filetypes/python/cycle_import_from.lua @@ -0,0 +1,136 @@ +local helpers = require("ts-node-action.helpers") + +--- @param import_from_statement TSNode +--- @return table +local function destructure_import_from_statement(import_from_statement) + local module = import_from_statement:named_child(0) + local names = {} + local sibling = module:next_named_sibling() + local comments = {} + + while sibling do + if sibling:type() == "comment" then + table.insert(comments, sibling) + else + table.insert(names, helpers.node_text(sibling)) + end + sibling = sibling:next_named_sibling() + end + + local format + if helpers.node_is_multiline(import_from_statement) then + format = "block" + elseif #names > 1 then + format = "inline" + else + format = "single" + end + + return { + node = import_from_statement, + module = helpers.node_text(module), + names = names, + comments = comments, + format = format, + } +end + +-- Find direct sibling import_from_statements that import from the same module +-- as the origin statement. +-- @param origin_stmt table +local function find_sibling_imports(origin_stmt) + local stmts = {} + local names = {} + + local prev_sibling = origin_stmt.node:prev_named_sibling() + while prev_sibling do + if (prev_sibling:type() == "import_from_statement" and + helpers.node_text(prev_sibling:named_child(0)) == origin_stmt.module) + then + local stmt = destructure_import_from_statement(prev_sibling) + table.insert(stmts, 1, stmt) + prev_sibling = prev_sibling:prev_named_sibling() + else + prev_sibling = nil + end + end + + for _, stmt in ipairs(stmts) do + for _, name in ipairs(stmt.names) do table.insert(names, name) end + end + + table.insert(stmts, origin_stmt) + for _, name in ipairs(origin_stmt.names) do table.insert(names, name) end + + local next_sibling = origin_stmt.node:next_named_sibling() + while next_sibling do + if (next_sibling:type() == "import_from_statement" and + helpers.node_text(next_sibling:named_child(0)) == origin_stmt.module) + then + local stmt = destructure_import_from_statement(next_sibling) + table.insert(stmts, stmt) + for _, name in ipairs(stmt.names) do table.insert(names, name) end + next_sibling = next_sibling:next_named_sibling() + else + next_sibling = nil + end + end + + return stmts, names +end + +-- Create a stub node to represent the replacement target. +-- This is necessary because the replacement spans multiple +-- top-level nodes and so we can't target the first and last +-- nodes directly. +local function create_stub_target_node(stmts) + local start_row, start_col = stmts[1].node:start() + local _, _, last_row, last_col = stmts[#stmts].node:range() + local node = {} + function node.range(self) + return start_row, start_col, last_row, last_col + end + return node +end + +-- @param origin_stmt table +-- @param new_format string +-- @return table, table +local function cycle(origin_stmt, new_format) + local replacement = {} + local stmts, names = find_sibling_imports(origin_stmt) + + if new_format == "single" then + for _, name in ipairs(names) do + table.insert( + replacement, + "from " .. origin_stmt.module .. " import " .. name .. "" + ) + end + elseif new_format == "inline" then + table.insert( + replacement, + "from " .. origin_stmt.module .. + " import " .. table.concat(names, ", ") + ) + elseif new_format == "block" then + table.insert( + replacement, + "from " .. origin_stmt.module .. " import (" + ) + for _, name in ipairs(names) do + table.insert(replacement, " " .. name .. ",") + end + table.insert(replacement, ")") + end + + return replacement, { + target = create_stub_target_node(stmts), + cursor = {row = 0, col = 0}, + } +end + +return { + destructure = destructure_import_from_statement, + cycle = cycle, +} diff --git a/spec/filetypes/python/import_from_statement_spec.lua b/spec/filetypes/python/import_from_statement_spec.lua new file mode 100644 index 0000000..5190b17 --- /dev/null +++ b/spec/filetypes/python/import_from_statement_spec.lua @@ -0,0 +1,244 @@ +dofile("./spec/spec_helper.lua") + +local Helper = SpecHelper.new("python", { shiftwidth = 4 }) + +describe("import_from_statement", function() + + + it("cycles from single to block", function() + assert.are.same( + { + [[from foo import (]], + [[ bar,]], + [[ baz,]], + [[ qux,]], + [[)]], + }, + Helper:call({ + [[from foo import bar]], + [[from foo import baz]], + [[from foo import qux]], + }) + ) + end) + + it("cycles from inline to block", function() + assert.are.same( + { + [[from foo import (]], + [[ bar,]], + [[ baz,]], + [[ qux,]], + [[)]], + }, + Helper:call({ + [[from foo import bar, baz, qux]], + }) + ) + end) + + it("cycles from block to single", function() + assert.are.same( + { + [[from foo import bar]], + [[from foo import baz]], + [[from foo import qux]], + }, + Helper:call({ + [[from foo import (]], + [[ bar,]], + [[ baz,]], + [[ qux,]], + [[)]], + }) + ) + end) + + it("cycles from block to single (continuation makes it a block)", function() + assert.are.same( + { + [[from foo import bar]], + [[from foo import baz]], + [[from foo import qux]], + }, + Helper:call({ + [[from foo import bar, \]], + [[ baz, qux]], + }) + ) + end) + + it("cycles from inline to block with mixed siblings", function() + assert.are.same( + { + [[from foo import (]], + [[ qux,]], + [[ bar,]], + [[ baz,]], + [[)]], + }, + Helper:call({ + [[from foo import qux]], + [[from foo import bar, baz]], + }, {2, 1}) + ) + end) + + it("cycles from block to single with mixed siblings", function() + assert.are.same( + { + [[from foo import a]], + [[from foo import b]], + [[from foo import c]], + [[from foo import bar]], + [[from foo import baz]], + [[from foo import qux]], + [[from foo import d]], + [[from foo import e]], + }, + Helper:call({ + [[from foo import a, b]], + [[from foo import c]], + [[from foo import (]], + [[ bar,]], + [[ baz,]], + [[ qux,]], + [[)]], + [[from foo import d, e]], + }, {3, 1}) + ) + end) + + it("cycles from inline to block only close siblings", function() + assert.are.same( + { + [[from abc import a, b, c]], + [[from foo import (]], + [[ bar,]], + [[ baz,]], + [[ qux,]], + [[ bee,]], + [[ boo,]], + [[ hah,]], + [[)]], + [[from xyz import x, y, z]], + }, + Helper:call({ + [[from abc import a, b, c]], + [[from foo import bar, baz, qux]], + [[from foo import bee, boo]], + [[from foo import hah]], + [[from xyz import x, y, z]], + }, {3, 1}) + ) + end) + + it("cycles with relative imports", function() + assert.are.same( + { + [[from .foo import (]], + [[ bar,]], + [[ baz,]], + [[ qux,]], + [[)]], + }, + Helper:call({ + [[from .foo import bar, baz, qux]], + }) + ) + end) + + it("cycles with relative imports", function() + assert.are.same( + { + [[from .foo import bar]], + [[from .foo import baz]], + [[from .foo import qux]], + }, + Helper:call({ + [[from .foo import (]], + [[ bar,]], + [[ baz,]], + [[ qux,]], + [[)]], + }) + ) + end) + + it("cycles with deep relative imports", function() + assert.are.same( + { + [[from .foo.bar.baz import (]], + [[ qux,]], + [[ bee,]], + [[ boo,]], + [[)]], + }, + Helper:call({ + [[from .foo.bar.baz import qux, bee, boo]], + }) + ) + end) + + it("cycles with multi-level relative imports", function() + assert.are.same( + { + [[from ...foo import (]], + [[ bar,]], + [[ baz,]], + [[ qux,]], + [[)]], + }, + Helper:call({ + [[from ...foo import bar, baz, qux]], + }) + ) + end) + + it("cycles with import aliases", function() + assert.are.same( + { + [[from foo import (]], + [[ bar as b,]], + [[ baz as z,]], + [[ qux as q,]], + [[)]], + }, + Helper:call({ + [[from foo import bar as b, baz as z, qux as q]], + }) + ) + end) + + it("doesn't cycle with embedded comments", function() + local text = { + [[from foo import (]], + [[ bar, # comment]], + [[ baz, # comment]], + [[ qux, # comment]], + [[)]], + } + assert.are.same(text, Helper:call(text)) + end) + + it("cycles with sibling comments", function() + assert.are.same( + { + [[from foo import bar]], + [[# comment]], + [[from foo import (]], + [[ baz,]], + [[)]], + [[# comment]], + [[from foo import qux]], + }, + Helper:call({ + [[from foo import bar]], + [[# comment]], + [[from foo import baz]], + [[# comment]], + [[from foo import qux]], + }, {3, 1}) + ) + end) +end) From afc7604d6f1ee2c64222c6ff68070ebfca748b80 Mon Sep 17 00:00:00 2001 From: Alex Walker Date: Tue, 28 Feb 2023 10:02:21 -0600 Subject: [PATCH 2/5] [python] refactor to include import --- README.md | 1 + lua/ts-node-action/filetypes/python.lua | 64 ++- .../filetypes/python/cycle_import.lua | 382 ++++++++++++++++++ .../filetypes/python/cycle_import_from.lua | 136 ------- lua/ts-node-action/init.lua | 2 + .../python/import_from_statement_spec.lua | 134 +++++- .../python/import_statement_spec.lua | 171 ++++++++ 7 files changed, 694 insertions(+), 196 deletions(-) create mode 100644 lua/ts-node-action/filetypes/python/cycle_import.lua delete mode 100644 lua/ts-node-action/filetypes/python/cycle_import_from.lua create mode 100644 spec/filetypes/python/import_statement_spec.lua diff --git a/README.md b/README.md index 3be5900..063ad07 100644 --- a/README.md +++ b/README.md @@ -266,6 +266,7 @@ Builtin actions are all higher-order functions so they can easily have options o | if block/postfix | | ✅ | | | | | | | | | `toggle_hash_style()` | | ✅ | | | | | | | | | `conceal_string()` | | | ✅ | | | | | | ✅ | +| `cycle_import()` | | | | | ✅ | | | | | ## Testing To run the test suite, clone the repo and run `./run_spec`. It should pull all dependencies into `spec/support/` on diff --git a/lua/ts-node-action/filetypes/python.lua b/lua/ts-node-action/filetypes/python.lua index e5d42d0..df55b2d 100644 --- a/lua/ts-node-action/filetypes/python.lua +++ b/lua/ts-node-action/filetypes/python.lua @@ -1,6 +1,6 @@ -local helpers = require("ts-node-action.helpers") -local actions = require("ts-node-action.actions") -local import_from = require("ts-node-action.filetypes.python.cycle_import_from") +local helpers = require("ts-node-action.helpers") +local actions = require("ts-node-action.actions") +local cycle_import = require("ts-node-action.filetypes.python.cycle_import") -- Special cases: -- Because "is" and "not" are valid by themselves, they are seen as separate @@ -308,7 +308,7 @@ local function destructure_conditional_expression(node) end --- @param stmt table ---- @return string, table, TSNode +--- @return string, table --- @return nil local function expand_cond_expr(stmt, padding_override) local parent = stmt.node:parent() @@ -385,7 +385,7 @@ end --- @param stmt table { node, condition, consequence, alternative, comments } --- @param padding_override table ---- @return string, table, TSNode +--- @return string, table --- @return nil local function inline_if(stmt, padding_override) @@ -428,7 +428,7 @@ end --- @param stmt table { node, condition, consequence, alternative, comments } --- @param padding_override table ---- @return string, table, TSNode +--- @return string, table --- @return nil local function inline_ifelse(stmt, padding_override) @@ -463,12 +463,12 @@ local function inline_ifelse(stmt, padding_override) end --- @param padding_override table ---- @return function +--- @return table local function inline_if_statement(padding_override) - padding_override = padding_override or padding + padding_override = vim.tbl_deep_extend( + 'force', padding, padding_override or {} + ) - --- @param if_statement TSNode - --- @return string, table, TSNode local function action(if_statement) local stmt = destructure_if_statement(if_statement) -- we can't inline multiple statements within a block @@ -499,12 +499,12 @@ local function inline_if_statement(padding_override) end --- @param padding_override table ---- @return function +--- @return table local function expand_conditional_expression(padding_override) - padding_override = padding_override or padding + padding_override = vim.tbl_deep_extend( + 'force', padding, padding_override or {} + ) - --- @param conditional_expression TSNode - --- @return string, table, TSNode local function action(conditional_expression) local stmt = destructure_conditional_expression(conditional_expression) if #stmt.comments > 0 then @@ -516,31 +516,16 @@ local function expand_conditional_expression(padding_override) return { action, name = "Expand Conditional" } end -local import_formats = { "inline", "block", "single", "block" } - -local function cycle_import_from_statement(user_import_formats) - user_import_formats = user_import_formats or import_formats - - local function action(import_from_statement) - - local stmt = import_from.destructure(import_from_statement) - if #stmt.comments > 0 then - return - end - - local new_format - for i, f in ipairs(user_import_formats) do - if f == stmt.format then - new_format = user_import_formats[i + 1] or user_import_formats[1] - break - end - end - - return import_from.cycle(stmt, new_format) - end +-- see python/cycle_import.lua for more config options +local cycle_import_from_config = { + ---@type string[] list of formats to cycle through; uses the provided order + formats = { "single", "inline", "expand" }, +} - return { action, name = "Cycle Import" } -end +local cycle_import_config = { + ---@type string[] list of formats to cycle through; uses the provided order + formats = { "single", "inline" }, +} return { ["dictionary"] = actions.toggle_multiline(padding), @@ -559,5 +544,6 @@ return { ["integer"] = actions.toggle_int_readability(), ["conditional_expression"] = { expand_conditional_expression(padding), }, ["if_statement"] = { inline_if_statement(padding), }, - ["import_from_statement"] = { cycle_import_from_statement(import_formats), }, + ["import_from_statement"] = { cycle_import(cycle_import_from_config), }, + ["import_statement"] = { cycle_import(cycle_import_config), }, } diff --git a/lua/ts-node-action/filetypes/python/cycle_import.lua b/lua/ts-node-action/filetypes/python/cycle_import.lua new file mode 100644 index 0000000..4cfd467 --- /dev/null +++ b/lua/ts-node-action/filetypes/python/cycle_import.lua @@ -0,0 +1,382 @@ +local helpers = require("ts-node-action.helpers") + +local ERROR_NS = "TS:NodeAction:Python:CycleImport - " + +---@param tables table[] +---@param key string +---@return table +local function collect_values_for_key(tables, key) + local values = {} + for _, tbl in ipairs(tables) do + for _, value in ipairs(tbl[key]) do + table.insert(values, value) + end + end + return values +end + +----@param import_from_statement TSNode +----@return table +local function destructure_import_from_statement(import_from_statement) + local module = import_from_statement:named_child(0) + local names = {} + local comments = {} + + local first_sibling_row + local sibling = module:next_named_sibling() + while sibling do + + if sibling:type() == "comment" then + table.insert(comments, sibling) + else + if not first_sibling_row then + first_sibling_row = sibling:start() + end + table.insert(names, helpers.node_text(sibling)) + end + + sibling = sibling:next_named_sibling() + end + + local format = "single" + if #names > 1 then + format = first_sibling_row == module:start() and "inline" or "expand" + elseif helpers.node_is_multiline(import_from_statement) then + format = "expand" + end + + return { + type = "import_from_statement", + node = import_from_statement, + modules = { helpers.node_text(module) }, + names = names, + comments = comments, + format = format + } +end + +---@param import_statement TSNode +---@return table +local function destructure_import_statement(import_statement) + local modules = {} + local comments = {} + + for child in import_statement:iter_children() do + if child:named() then + if child:type() == "comment" then + table.insert(comments, child) + else + table.insert(modules, helpers.node_text(child)) + end + end + end + + return { + type = "import_statement", + node = import_statement, + modules = modules, + names = modules, + comments = comments, + format = #modules > 1 and "inline" or "single", + } +end + +---@param node TSNode +---@param func fun(node: TSNode): boolean +---@return TSNode[] in reverse order +local function get_prev_siblings_while(node, func) + local nodes = {} + local prev_sibling = node:prev_named_sibling() + while prev_sibling and func(prev_sibling) do + table.insert(nodes, prev_sibling) + prev_sibling = prev_sibling:prev_named_sibling() + end + return nodes +end + +---@param node TSNode +---@param func fun(node: TSNode): boolean +---@return TSNode[] +local function get_next_siblings_while(node, func) + local nodes = {} + local next_sibling = node:next_named_sibling() + while next_sibling and func(next_sibling) do + table.insert(nodes, next_sibling) + next_sibling = next_sibling:next_named_sibling() + end + return nodes +end + +-- Collect qualifying siblings adjacent to origin_stmt and destructure them. +---@param origin_stmt table +---@param prev_siblings TSNode[] in reverse order +---@param next_siblings TSNode[] +---@param destructure fun(node: TSNode): table +---@param of_any_format boolean +---@return table[] +local function assemble_sibling_stmts( + origin_stmt, prev_siblings, next_siblings, destructure, of_any_format) + local stmts = {} + + for _, node in ipairs(prev_siblings) do + local sibling_stmt = destructure(node) + if of_any_format or sibling_stmt.format == origin_stmt.format then + table.insert(stmts, 1, sibling_stmt) + else + break + end + end + + table.insert(stmts, origin_stmt) + + for _, node in ipairs(next_siblings) do + local sibling_stmt = destructure(node) + if of_any_format or sibling_stmt.format == origin_stmt.format then + table.insert(stmts, sibling_stmt) + else + break + end + end + + return stmts +end + +local cycler_types = { + import_statement = { + allowed_formats = { "single", "inline", }, + destructure = destructure_import_statement, + make_sibling_validator = function(origin_stmt) + return function(sibling) + return sibling:type() == origin_stmt.type + end + end, + cycle = { + single = function(stmts, names, indent, config) + local replacement = {} + for i, name in ipairs(names) do + table.insert(replacement, (i ~= 1 and indent or "") .. "import " .. name) + end + return replacement + end, + inline = function(stmts, names, indent, config) + local replacement = {} + local prepend = "import " + local line = indent .. prepend .. table.concat(names, ", ") + + if #line > config.line_length then + line = indent .. prepend + for _, name in ipairs(names) do + if #line + #name >= config.line_length then + table.insert(replacement, line:sub(1, -3)) + line = indent .. prepend .. name .. ", " + else + line = line .. name .. ", " + end + end + line = line:sub(1, -3) + end + table.insert(replacement, line) + + return replacement + end, + }, + }, + import_from_statement = { + allowed_formats = { "single", "inline", "expand", }, + destructure = destructure_import_from_statement, + make_sibling_validator = function(origin_stmt) + local module = origin_stmt.modules[1] + return function(sibling) + return sibling:type() == origin_stmt.type and + helpers.node_text(sibling:named_child(0)) == module + end + end, + cycle = { + single = function(stmts, names, indent, config) + local replacement = {} + for i, name in ipairs(names) do + table.insert( + replacement, + (i == 1 and "" or indent) .. + "from " .. stmts[1].modules[1] .. " import " .. name .. "" + ) + end + return replacement + end, + inline = function(stmts, names, indent, config) + local replacement = {} + local prepend = "from " .. stmts[1].modules[1] .. " import " + local line = indent .. prepend .. table.concat(names, ", ") + local line_length = config.line_length + local use_parens = config.inline_use_parens + local eol_length = use_parens and 1 or 2 + + if #line > line_length then + line = indent .. (use_parens and prepend .. "(" or prepend) + + for _, name in ipairs(names) do + if #line + #name + eol_length > line_length then + line = use_parens and line:sub(1, -2) or line .. "\\" + table.insert(replacement, line) + line = indent .. " " .. name .. ", " + else + line = line .. name .. ", " + end + end + + line = line:sub(1, -3) .. (use_parens and ")" or "") + end + table.insert(replacement, line) + return replacement + end, + expand = function(stmts, names, indent, config) + local replacement = {} + local use_parens = config.expand_use_parens + local first_eol = use_parens and "(" or "\\" + local body_eol = use_parens and "" or " \\" + table.insert( + replacement, + "from " .. stmts[1].modules[1] .. " import " .. first_eol + ) + for i, name in ipairs(names) do + local line + if i == #names then + line = indent .. " " .. name .. (use_parens and "," or "") + else + line = indent .. " " .. name .. "," .. body_eol + end + table.insert(replacement, line) + end + if use_parens then + table.insert(replacement, indent .. ")") + end + return replacement + end, + } + }, +} + +-- Create a fake node to represent the replacement target. This is necessary +-- when the replacement spans multiple nodes without a suitable parent to serve +-- as a the target (eg, a top-level node's parent is the root). +-- +-- Should be indistiguishable from a TSNode, other than type(target) == "table", +-- but range() is only what's necessary by init.lua:replace_node(). +-- +---@param first_node TSNode +---@param last_node TSNode +---@return table +local function make_target_node(first_node, last_node) + -- TSNode's are userdata, which can't be cloned/altered, so this proxy's calls + -- to it and overrides the position methods. + local target = {} + for k, _ in pairs(getmetatable(first_node)) do + target[k] = function(_, ...) + return first_node[k](first_node, ...) + end + end + local start_pos = { first_node:start() } + local end_pos = { last_node:end_() } + function target:start() return unpack(start_pos) end + function target:end_() return unpack(end_pos) end + function target:range() + return start_pos[1], start_pos[2], end_pos[1], end_pos[2] + end + + return target +end + +---@param formats table +---@param format string +---@return string|nil +local function find_next_format(formats, format) + for i, f in ipairs(formats) do + if f == format then + return formats[i + 1] or formats[1] + end + end +end + +---@param node TSNode +---@param config table +---@return table|nil, table|nil +local function cycle(node, config) + + local cycler = cycler_types[node:type()] + + local stmt = cycler.destructure(node) + if #stmt.comments > 0 then + return + end + + local format = find_next_format(config.formats, stmt.format) + if not format then + return + end + + if not vim.tbl_contains(cycler.allowed_formats, format) then + print(ERROR_NS .. "Format '" .. format .. "' not supported") + return + end + + local is_valid_sibling = cycler.make_sibling_validator(stmt) + local stmts = assemble_sibling_stmts( + stmt, + get_prev_siblings_while(stmt.node, is_valid_sibling), + get_next_siblings_while(stmt.node, is_valid_sibling), + cycler.destructure, + config.siblings_of_any_format + ) + local names = collect_values_for_key(stmts, "names") + + local start = {node:start()} + local indent = string.rep(" ", start[2]) + + local replacement = cycler.cycle[format](stmts, names, indent, config) + + return replacement, { + target = make_target_node( + stmts[1].node, + stmts[#stmts].node + ), + cursor = {row = 0, col = 0}, + format = true, + } +end + +local default_config = { + ---@type table[] formats to cycle through, in the order provided + formats = {}, + ---@type number maximum line length for inline imports + line_length = 80, + ---@type boolean include siblings when format differs + siblings_of_any_format = true, + ---@type boolean use parens for inline imports (otherwise use \) + inline_use_parens = true, + ---@type boolean use parens for expanded imports (otherwise use \) + expand_use_parens = true, +} + +---@param config table +---@return table|nil +return function(config) + config = vim.tbl_deep_extend('force', default_config, config or {}) + + vim.validate{ + formats={ config.formats, "table" }, + line_length={ config.line_length, "number" }, + siblings_of_any_format={ config.siblings_of_any_format, "boolean" }, + inline_use_parens={ config.inline_use_parens, "boolean" }, + expand_use_parens={ config.expand_use_parens, "boolean" }, + } + + if #config.formats == 0 then + print(ERROR_NS .. "Empty config.formats, no formats to cycle") + end + + local function action(node) + return cycle(node, config) + end + + return { action, name = "Cycle Import" } +end diff --git a/lua/ts-node-action/filetypes/python/cycle_import_from.lua b/lua/ts-node-action/filetypes/python/cycle_import_from.lua deleted file mode 100644 index 9688d8a..0000000 --- a/lua/ts-node-action/filetypes/python/cycle_import_from.lua +++ /dev/null @@ -1,136 +0,0 @@ -local helpers = require("ts-node-action.helpers") - ---- @param import_from_statement TSNode ---- @return table -local function destructure_import_from_statement(import_from_statement) - local module = import_from_statement:named_child(0) - local names = {} - local sibling = module:next_named_sibling() - local comments = {} - - while sibling do - if sibling:type() == "comment" then - table.insert(comments, sibling) - else - table.insert(names, helpers.node_text(sibling)) - end - sibling = sibling:next_named_sibling() - end - - local format - if helpers.node_is_multiline(import_from_statement) then - format = "block" - elseif #names > 1 then - format = "inline" - else - format = "single" - end - - return { - node = import_from_statement, - module = helpers.node_text(module), - names = names, - comments = comments, - format = format, - } -end - --- Find direct sibling import_from_statements that import from the same module --- as the origin statement. --- @param origin_stmt table -local function find_sibling_imports(origin_stmt) - local stmts = {} - local names = {} - - local prev_sibling = origin_stmt.node:prev_named_sibling() - while prev_sibling do - if (prev_sibling:type() == "import_from_statement" and - helpers.node_text(prev_sibling:named_child(0)) == origin_stmt.module) - then - local stmt = destructure_import_from_statement(prev_sibling) - table.insert(stmts, 1, stmt) - prev_sibling = prev_sibling:prev_named_sibling() - else - prev_sibling = nil - end - end - - for _, stmt in ipairs(stmts) do - for _, name in ipairs(stmt.names) do table.insert(names, name) end - end - - table.insert(stmts, origin_stmt) - for _, name in ipairs(origin_stmt.names) do table.insert(names, name) end - - local next_sibling = origin_stmt.node:next_named_sibling() - while next_sibling do - if (next_sibling:type() == "import_from_statement" and - helpers.node_text(next_sibling:named_child(0)) == origin_stmt.module) - then - local stmt = destructure_import_from_statement(next_sibling) - table.insert(stmts, stmt) - for _, name in ipairs(stmt.names) do table.insert(names, name) end - next_sibling = next_sibling:next_named_sibling() - else - next_sibling = nil - end - end - - return stmts, names -end - --- Create a stub node to represent the replacement target. --- This is necessary because the replacement spans multiple --- top-level nodes and so we can't target the first and last --- nodes directly. -local function create_stub_target_node(stmts) - local start_row, start_col = stmts[1].node:start() - local _, _, last_row, last_col = stmts[#stmts].node:range() - local node = {} - function node.range(self) - return start_row, start_col, last_row, last_col - end - return node -end - --- @param origin_stmt table --- @param new_format string --- @return table, table -local function cycle(origin_stmt, new_format) - local replacement = {} - local stmts, names = find_sibling_imports(origin_stmt) - - if new_format == "single" then - for _, name in ipairs(names) do - table.insert( - replacement, - "from " .. origin_stmt.module .. " import " .. name .. "" - ) - end - elseif new_format == "inline" then - table.insert( - replacement, - "from " .. origin_stmt.module .. - " import " .. table.concat(names, ", ") - ) - elseif new_format == "block" then - table.insert( - replacement, - "from " .. origin_stmt.module .. " import (" - ) - for _, name in ipairs(names) do - table.insert(replacement, " " .. name .. ",") - end - table.insert(replacement, ")") - end - - return replacement, { - target = create_stub_target_node(stmts), - cursor = {row = 0, col = 0}, - } -end - -return { - destructure = destructure_import_from_statement, - cycle = cycle, -} diff --git a/lua/ts-node-action/init.lua b/lua/ts-node-action/init.lua index b769f0a..7d1aeb8 100644 --- a/lua/ts-node-action/init.lua +++ b/lua/ts-node-action/init.lua @@ -1,3 +1,5 @@ +---@alias TSNode userdata + local M = {} --- @private diff --git a/spec/filetypes/python/import_from_statement_spec.lua b/spec/filetypes/python/import_from_statement_spec.lua index 5190b17..27a4902 100644 --- a/spec/filetypes/python/import_from_statement_spec.lua +++ b/spec/filetypes/python/import_from_statement_spec.lua @@ -5,14 +5,10 @@ local Helper = SpecHelper.new("python", { shiftwidth = 4 }) describe("import_from_statement", function() - it("cycles from single to block", function() + it("cycles from single to inline", function() assert.are.same( { - [[from foo import (]], - [[ bar,]], - [[ baz,]], - [[ qux,]], - [[)]], + [[from foo import bar, baz, qux]], }, Helper:call({ [[from foo import bar]], @@ -22,7 +18,7 @@ describe("import_from_statement", function() ) end) - it("cycles from inline to block", function() + it("cycles from inline to expand", function() assert.are.same( { [[from foo import (]], @@ -37,7 +33,7 @@ describe("import_from_statement", function() ) end) - it("cycles from block to single", function() + it("cycles from expand to single", function() assert.are.same( { [[from foo import bar]], @@ -54,21 +50,38 @@ describe("import_from_statement", function() ) end) - it("cycles from block to single (continuation makes it a block)", function() + it("cycles from inline to expand (inline detect w continuation)", function() assert.are.same( { - [[from foo import bar]], - [[from foo import baz]], - [[from foo import qux]], + [[from foo import (]], + [[ bar,]], + [[ baz,]], + [[)]], }, Helper:call({ [[from foo import bar, \]], - [[ baz, qux]], + [[ baz]], + }) + ) + end) + + it("cycles from inline to expand (inline detect w parens)", function() + assert.are.same( + { + [[from foo import (]], + [[ bar,]], + [[ baz,]], + [[)]], + }, + Helper:call({ + [[from foo import (bar,]], + [[ baz)]], }) ) end) - it("cycles from inline to block with mixed siblings", function() + + it("cycles from inline to expand with mixed siblings", function() assert.are.same( { [[from foo import (]], @@ -84,7 +97,7 @@ describe("import_from_statement", function() ) end) - it("cycles from block to single with mixed siblings", function() + it("cycles from expand to single with mixed siblings", function() assert.are.same( { [[from foo import a]], @@ -109,7 +122,7 @@ describe("import_from_statement", function() ) end) - it("cycles from inline to block only close siblings", function() + it("cycles from inline to expand only close siblings", function() assert.are.same( { [[from abc import a, b, c]], @@ -224,21 +237,100 @@ describe("import_from_statement", function() it("cycles with sibling comments", function() assert.are.same( { - [[from foo import bar]], + [[from abc import abc]], [[# comment]], [[from foo import (]], + [[ bar,]], [[ baz,]], + [[ qux,]], [[)]], [[# comment]], - [[from foo import qux]], + [[from xyz import x, y, z]], }, Helper:call({ - [[from foo import bar]], + [[from abc import abc]], [[# comment]], - [[from foo import baz]], + [[from foo import bar, baz, qux]], [[# comment]], - [[from foo import qux]], + [[from xyz import x, y, z]], }, {3, 1}) ) end) + + it("cycles to inline (multiline due to config.line_length = 80)", function() + assert.are.same( + { + [[from json import (loads, dumps, JSONDecodeError as foo, detect_encoding,]], + [[ loads as decode, dumps as encode)]], + }, + Helper:call({ + [[from json import loads]], + [[from json import dumps]], + [[from json import JSONDecodeError as foo]], + [[from json import detect_encoding]], + [[from json import loads as decode]], + [[from json import dumps as encode]], + }) + ) + end) + + it("cycles to expand from multiline inline", function() + assert.are.same( + { + [[from json import (]], + [[ loads,]], + [[ dumps,]], + [[ JSONDecodeError as foo,]], + [[ detect_encoding,]], + [[ loads as decode,]], + [[ dumps as encode,]], + [[)]], + }, + Helper:call({ + [[from json import (loads, dumps, JSONDecodeError as foo, detect_encoding,]], + [[ loads as decode, dumps as encode)]], + }) + ) + end) + + it("cycles to inline (multiline) from indented expand", function() + assert.are.same( + { + [[def foo():]], + [[ from json import (loads, dumps, JSONDecodeError as foo, detect_encoding,]], + [[ loads as decode, dumps as encode)]], + }, + Helper:call({ + [[def foo():]], + [[ from json import loads]], + [[ from json import dumps]], + [[ from json import JSONDecodeError as foo]], + [[ from json import detect_encoding]], + [[ from json import loads as decode]], + [[ from json import dumps as encode]], + }, {2, 5}) + ) + end) + + it("cycles to expand from multiline inline while indented", function() + assert.are.same( + { + [[def foo():]], + [[ from json import (]], + [[ loads,]], + [[ dumps,]], + [[ JSONDecodeError as foo,]], + [[ detect_encoding,]], + [[ loads as decode,]], + [[ dumps as encode,]], + [[ )]], + }, + Helper:call({ + [[def foo():]], + [[ from json import (loads, dumps, JSONDecodeError as foo, detect_encoding,]], + [[ loads as decode, dumps as encode)]], + }, {2, 5}) + ) + end) + end) diff --git a/spec/filetypes/python/import_statement_spec.lua b/spec/filetypes/python/import_statement_spec.lua new file mode 100644 index 0000000..a2f7c62 --- /dev/null +++ b/spec/filetypes/python/import_statement_spec.lua @@ -0,0 +1,171 @@ +dofile("./spec/spec_helper.lua") + +local Helper = SpecHelper.new("python", { shiftwidth = 4 }) + +describe("import_statement", function() + + + it("cycles from single to inline", function() + assert.are.same( + { + [[import bar, baz, qux]], + }, + Helper:call({ + [[import bar]], + [[import baz]], + [[import qux]], + }) + ) + end) + + it("cycles from inline to single", function() + assert.are.same( + { + [[import bar]], + [[import baz]], + [[import qux]], + }, + Helper:call({ + [[import bar, baz, qux]], + }) + ) + end) + + it("cycles from inline to single (inline detected w continuation)", function() + assert.are.same( + { + [[import bar]], + [[import baz]], + }, + Helper:call({ + [[import bar, \]], + [[ baz]], + }) + ) + end) + + it("cycles from inline to single with mixed siblings", function() + assert.are.same( + { + [[import qux]], + [[import bar]], + [[import baz]], + }, + Helper:call({ + [[import qux]], + [[import bar, baz]], + }, {2, 1}) + ) + end) + + it("cycles from single to inline only close siblings", function() + assert.are.same( + { + [[from abc import a, b, c]], + [[import bar, bee, hah]], + [[from xyz import x, y, z]], + }, + Helper:call({ + [[from abc import a, b, c]], + [[import bar]], + [[import bee]], + [[import hah]], + [[from xyz import x, y, z]], + }, {3, 1}) + ) + end) + + it("cycles with deep relative imports", function() + assert.are.same( + { + [[import foo.bar.baz.qux]], + [[import fish.sandwich]], + [[import boo.ghosts]], + }, + Helper:call({ + [[import foo.bar.baz.qux, fish.sandwich, boo.ghosts]], + }) + ) + end) + + it("cycles with import aliases", function() + assert.are.same( + { + [[import foo.bar as b]], + [[import baz as z]], + [[import qux as q]], + }, + Helper:call({ + [[import foo.bar as b, baz as z, qux as q]], + }) + ) + end) + + it("doesn't cycle with comments", function() + local text = { + [[import bar # comment]], + [[import baz # comment]], + [[import qux # comment]], + } + assert.are.same(text, Helper:call(text)) + end) + + it("cycles with sibling comments", function() + assert.are.same( + { + [[from abc import abc]], + [[# comment]], + [[import bar]], + [[import baz]], + [[import qux]], + [[# comment]], + [[from xyz import x, y, z]], + }, + Helper:call({ + [[from abc import abc]], + [[# comment]], + [[import bar, baz, qux]], + [[# comment]], + [[from xyz import x, y, z]], + }, {3, 1}) + ) + end) + + it("cycles to multiline inline (it exceeded config.line_length)", function() + assert.are.same( + { + [[import this.will.be.long, once.its.inlined, it.will.be.too.long, bar, baz, qux]], + [[import abc, xyz, to.fit.on.one.line]], + }, + Helper:call({ + [[import this.will.be.long]], + [[import once.its.inlined]], + [[import it.will.be.too.long]], + [[import bar, baz, qux, abc, xyz]], + [[import to.fit.on.one.line]], + }) + ) + end) + + it("cycles to multiline inline while indented", function() + assert.are.same( + { + [[def foo():]], + [[ import this.will.be.long, once.its.inlined, it.will.be.too.long, bar, baz]], + [[ import qux, abc, xyz, to.fit.on.one.line]], + }, + Helper:call({ + [[def foo():]], + [[ import this.will.be.long]], + [[ import once.its.inlined]], + [[ import it.will.be.too.long]], + [[ import bar, baz, qux, abc, xyz]], + [[ import to.fit.on.one.line]], + }, {2, 5}) + ) + end) + + + + +end) From f85ef32471c07699f4f2d209be6f9e61bd2894ca Mon Sep 17 00:00:00 2001 From: Alex Walker Date: Mon, 6 Mar 2023 00:09:43 -0600 Subject: [PATCH 3/5] [python] fully spec per-type config, so it's clear which options apply --- lua/ts-node-action/filetypes/python.lua | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lua/ts-node-action/filetypes/python.lua b/lua/ts-node-action/filetypes/python.lua index df55b2d..37eed85 100644 --- a/lua/ts-node-action/filetypes/python.lua +++ b/lua/ts-node-action/filetypes/python.lua @@ -520,11 +520,23 @@ end local cycle_import_from_config = { ---@type string[] list of formats to cycle through; uses the provided order formats = { "single", "inline", "expand" }, + ---@type number maximum line length for inline imports + line_length = 80, + ---@type boolean include siblings when format differs + siblings_of_any_format = true, + ---@type boolean use parens for inline imports (otherwise use \) + inline_use_parens = true, + ---@type boolean use parens for expanded imports (otherwise use \) + expand_use_parens = true, } local cycle_import_config = { ---@type string[] list of formats to cycle through; uses the provided order formats = { "single", "inline" }, + ---@type number maximum line length for inline imports + line_length = 80, + ---@type boolean include siblings when format differs + siblings_of_any_format = true, } return { From 5bd64a7df75028e7b2d909e7536f6a270c851034 Mon Sep 17 00:00:00 2001 From: Alex Walker Date: Mon, 6 Mar 2023 13:32:46 -0600 Subject: [PATCH 4/5] [python] bugfix: more robust format detection for import_from --- .../filetypes/python/cycle_import.lua | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/lua/ts-node-action/filetypes/python/cycle_import.lua b/lua/ts-node-action/filetypes/python/cycle_import.lua index 4cfd467..320c963 100644 --- a/lua/ts-node-action/filetypes/python/cycle_import.lua +++ b/lua/ts-node-action/filetypes/python/cycle_import.lua @@ -22,16 +22,20 @@ local function destructure_import_from_statement(import_from_statement) local names = {} local comments = {} - local first_sibling_row + local prev_sibling_row = module:start() + local siblings_share_line = false local sibling = module:next_named_sibling() while sibling do if sibling:type() == "comment" then table.insert(comments, sibling) else - if not first_sibling_row then - first_sibling_row = sibling:start() + local sibling_start_row = sibling:start() + if sibling_start_row == prev_sibling_row then + siblings_share_line = true end + prev_sibling_row = sibling_start_row + table.insert(names, helpers.node_text(sibling)) end @@ -39,10 +43,22 @@ local function destructure_import_from_statement(import_from_statement) end local format = "single" - if #names > 1 then - format = first_sibling_row == module:start() and "inline" or "expand" - elseif helpers.node_is_multiline(import_from_statement) then - format = "expand" + if #names > 1 or helpers.node_is_multiline(import_from_statement) then + if siblings_share_line then + format = "inline" + else + format = "expand" + local num_children = import_from_statement:child_count() + local last_child = import_from_statement:child(num_children - 1) + -- multiline inline ends with a paren sharing the same line + -- but if not using parens, then it's ambiguous + if not last_child:named() and helpers.node_text(last_child) == ")" then + local last_named_child = last_child:prev_named_sibling() + if last_child:start() == last_named_child:start() then + format = "inline" + end + end + end end return { @@ -212,7 +228,12 @@ local cycler_types = { local eol_length = use_parens and 1 or 2 if #line > line_length then + line = indent .. (use_parens and prepend .. "(" or prepend) + if #line + #names[1] > line_length then + table.insert(replacement, use_parens and line or (line .. "\\")) + line = indent .. " " + end for _, name in ipairs(names) do if #line + #name + eol_length > line_length then From e96d5c9b867831e63c63a18d8f35443bd88a4c39 Mon Sep 17 00:00:00 2001 From: Alex Walker Date: Sat, 11 Mar 2023 17:51:40 -0600 Subject: [PATCH 5/5] [python] use iters from new node_utils; add more specs; style tweaks --- lua/ts-node-action/filetypes/python.lua | 26 ++-- .../filetypes/python/conditional.lua | 75 +++++---- .../filetypes/python/cycle_import.lua | 134 +++++----------- .../filetypes/python/helpers.lua | 36 ----- .../filetypes/python/node_utils.lua | 144 ++++++++++++++++++ .../python/import_from_statement_spec.lua | 13 ++ .../python/import_statement_spec.lua | 14 +- 7 files changed, 252 insertions(+), 190 deletions(-) delete mode 100644 lua/ts-node-action/filetypes/python/helpers.lua create mode 100644 lua/ts-node-action/filetypes/python/node_utils.lua diff --git a/lua/ts-node-action/filetypes/python.lua b/lua/ts-node-action/filetypes/python.lua index b06f733..ed4329b 100644 --- a/lua/ts-node-action/filetypes/python.lua +++ b/lua/ts-node-action/filetypes/python.lua @@ -1,6 +1,6 @@ local actions = require("ts-node-action.actions") local helpers = require("ts-node-action.helpers") -local pyhelpers = require("ts-node-action.filetypes.python.helpers") +local nu = require("ts-node-action.filetypes.python.node_utils") local conditional = require("ts-node-action.filetypes.python.conditional") local cycle_import = require("ts-node-action.filetypes.python.cycle_import") @@ -60,9 +60,9 @@ local boolean_override = { ["False"] = "True", } ---- @param padding_override table ---- @param uncollapsible_override table ---- @return table +---@param padding_override table +---@param uncollapsible_override table +---@return table local function inline_if_statement(padding_override, uncollapsible_override) padding_override = vim.tbl_deep_extend( 'force', padding, padding_override or {} @@ -70,7 +70,7 @@ local function inline_if_statement(padding_override, uncollapsible_override) uncollapsible_override = vim.tbl_deep_extend( 'force', uncollapsible, uncollapsible_override or {} ) - local collapse = pyhelpers.collapse_child_nodes( + local collapse = nu.collapse_func( padding_override, uncollapsible_override ) @@ -105,23 +105,27 @@ local function inline_if_statement(padding_override, uncollapsible_override) return { action, name = "Inline Conditional" } end ---- @param padding_override table ---- @param uncollapsible_override table ---- @return table -local function expand_conditional_expression(padding_override, uncollapsible_override) +---@param padding_override table +---@param uncollapsible_override table +---@return table +local function expand_conditional_expression( + padding_override, uncollapsible_override +) padding_override = vim.tbl_deep_extend( 'force', padding, padding_override or {} ) uncollapsible_override = vim.tbl_deep_extend( 'force', uncollapsible, uncollapsible_override or {} ) - local collapse = pyhelpers.collapse_child_nodes( + local collapse = nu.collapse_func( padding_override, uncollapsible_override ) local function action(conditional_expression) - local stmt = conditional.destructure_conditional_expression(conditional_expression) + local stmt = conditional.destructure_conditional_expression( + conditional_expression + ) if #stmt.comments > 0 then return end return conditional.expand_cond_expr(stmt, collapse) diff --git a/lua/ts-node-action/filetypes/python/conditional.lua b/lua/ts-node-action/filetypes/python/conditional.lua index 2c65f04..92bd789 100644 --- a/lua/ts-node-action/filetypes/python/conditional.lua +++ b/lua/ts-node-action/filetypes/python/conditional.lua @@ -1,5 +1,5 @@ local helpers = require("ts-node-action.helpers") -local pyhelpers = require("ts-node-action.filetypes.python.helpers") +local nu = require("ts-node-action.filetypes.python.node_utils") local M = {} @@ -137,13 +137,11 @@ end ---@param comments table ---@return nil (mutates comments) local function deep_collect_comments(node, comments) - for child in node:iter_children() do - if child:named() then - if child:type() == "comment" then - table.insert(comments, child) - else - deep_collect_comments(child, comments) - end + for child in nu.iter_named_children(node) do + if child:type() == "comment" then + table.insert(comments, child) + else + deep_collect_comments(child, comments) end end end @@ -153,14 +151,12 @@ end ---@param comments table ---@return nil (mutates children and comments) local function collect_named_children(parent, children, comments) - for child in parent:iter_children() do - if child:named() then - if child:type() == "comment" then - table.insert(comments, child) - else - table.insert(children, child) - deep_collect_comments(child, comments) - end + for child in nu.iter_named_children(parent) do + if child:type() == "comment" then + table.insert(comments, child) + else + table.insert(children, child) + deep_collect_comments(child, comments) end end end @@ -173,22 +169,19 @@ M.destructure_if_statement = function(if_statement) local alternative = {} local comments = {} - for child in if_statement:iter_children() do - if child:named() then - local child_type = child:type() - - if child_type == "comment" then - table.insert(comments, child) - elseif child_type == "block" then - collect_named_children(child, consequence, comments) - elseif child_type == "else_clause" then - local block = {} - collect_named_children(child, block, comments) - collect_named_children(block[1], alternative, comments) - else - condition = child - end - + for child in nu.iter_named_children(if_statement) do + local child_type = child:type() + + if child_type == "comment" then + table.insert(comments, child) + elseif child_type == "block" then + collect_named_children(child, consequence, comments) + elseif child_type == "else_clause" then + local block = {} + collect_named_children(child, block, comments) + collect_named_children(block[1], alternative, comments) + else + condition = child end end @@ -282,7 +275,7 @@ M.expand_cond_expr = function(stmt, collapse) local callback = nil if row_parent then table.insert(replacement, 1, "") - callback = function() pyhelpers.node_trim_whitespace(parent) end + callback = function() nu.trim_whitespace(parent) end end return replacement, { @@ -313,11 +306,11 @@ M.inline_if = function(stmt, collapse) return replacement, { cursor = {} } end ---- @param cons_type string ---- @param alt_type string ---- @param cons_lhs string ---- @param alt_lhs string ---- @return boolean +---@param cons_type string +---@param alt_type string +---@param cons_lhs string +---@param alt_lhs string +---@return boolean local function body_types_are_inlineable(cons_type, alt_type, cons_lhs, alt_lhs) -- strict match if cons_type == "assignment" or alt_type == "assignment" then @@ -335,9 +328,9 @@ local function body_types_are_inlineable(cons_type, alt_type, cons_lhs, alt_lhs) mixable_match_body_types[alt_type] end ---- @param stmt table { node, condition, consequence, alternative, comments } ---- @param collapse function ---- @return string|nil, table|nil +---@param stmt table { node, condition, consequence, alternative, comments } +---@param collapse function +---@return string|nil, table|nil M.inline_ifelse = function(stmt, collapse) local cons_lhs, cons_rhs, cons_type, cons_child = node_text_lhs_rhs( diff --git a/lua/ts-node-action/filetypes/python/cycle_import.lua b/lua/ts-node-action/filetypes/python/cycle_import.lua index 320c963..f00c89b 100644 --- a/lua/ts-node-action/filetypes/python/cycle_import.lua +++ b/lua/ts-node-action/filetypes/python/cycle_import.lua @@ -1,32 +1,21 @@ local helpers = require("ts-node-action.helpers") +local nu = require("ts-node-action.filetypes.python.node_utils") -local ERROR_NS = "TS:NodeAction:Python:CycleImport - " - ----@param tables table[] ----@param key string ----@return table -local function collect_values_for_key(tables, key) - local values = {} - for _, tbl in ipairs(tables) do - for _, value in ipairs(tbl[key]) do - table.insert(values, value) - end - end - return values +local function print_error(...) + print("TS:NodeAction:Python:CycleImport - ", ...) end -----@param import_from_statement TSNode -----@return table +---@param import_from_statement TSNode +---@return table local function destructure_import_from_statement(import_from_statement) local module = import_from_statement:named_child(0) local names = {} local comments = {} - local prev_sibling_row = module:start() + local prev_sibling_row = module:start() local siblings_share_line = false - local sibling = module:next_named_sibling() - while sibling do + for sibling in nu.iter_next_named_sibling(module) do if sibling:type() == "comment" then table.insert(comments, sibling) else @@ -38,8 +27,6 @@ local function destructure_import_from_statement(import_from_statement) table.insert(names, helpers.node_text(sibling)) end - - sibling = sibling:next_named_sibling() end local format = "single" @@ -77,13 +64,11 @@ local function destructure_import_statement(import_statement) local modules = {} local comments = {} - for child in import_statement:iter_children() do - if child:named() then - if child:type() == "comment" then - table.insert(comments, child) - else - table.insert(modules, helpers.node_text(child)) - end + for node in nu.iter_named_children(import_statement) do + if node:type() == "comment" then + table.insert(comments, node) + else + table.insert(modules, helpers.node_text(node)) end end @@ -97,40 +82,15 @@ local function destructure_import_statement(import_statement) } end ----@param node TSNode ----@param func fun(node: TSNode): boolean ----@return TSNode[] in reverse order -local function get_prev_siblings_while(node, func) - local nodes = {} - local prev_sibling = node:prev_named_sibling() - while prev_sibling and func(prev_sibling) do - table.insert(nodes, prev_sibling) - prev_sibling = prev_sibling:prev_named_sibling() - end - return nodes -end - ----@param node TSNode ----@param func fun(node: TSNode): boolean ----@return TSNode[] -local function get_next_siblings_while(node, func) - local nodes = {} - local next_sibling = node:next_named_sibling() - while next_sibling and func(next_sibling) do - table.insert(nodes, next_sibling) - next_sibling = next_sibling:next_named_sibling() - end - return nodes -end - -- Collect qualifying siblings adjacent to origin_stmt and destructure them. +-- ---@param origin_stmt table ---@param prev_siblings TSNode[] in reverse order ---@param next_siblings TSNode[] ---@param destructure fun(node: TSNode): table ---@param of_any_format boolean ---@return table[] -local function assemble_sibling_stmts( +local function assemble_stmts( origin_stmt, prev_siblings, next_siblings, destructure, of_any_format) local stmts = {} @@ -167,14 +127,17 @@ local cycler_types = { end end, cycle = { - single = function(stmts, names, indent, config) + single = function(_, names, indent, _) local replacement = {} for i, name in ipairs(names) do - table.insert(replacement, (i ~= 1 and indent or "") .. "import " .. name) + table.insert( + replacement, + (i ~= 1 and indent or "") .. "import " .. name + ) end return replacement end, - inline = function(stmts, names, indent, config) + inline = function(_, names, indent, config) local replacement = {} local prepend = "import " local line = indent .. prepend .. table.concat(names, ", ") @@ -208,7 +171,7 @@ local cycler_types = { end end, cycle = { - single = function(stmts, names, indent, config) + single = function(stmts, names, indent, _) local replacement = {} for i, name in ipairs(names) do table.insert( @@ -277,36 +240,6 @@ local cycler_types = { }, } --- Create a fake node to represent the replacement target. This is necessary --- when the replacement spans multiple nodes without a suitable parent to serve --- as a the target (eg, a top-level node's parent is the root). --- --- Should be indistiguishable from a TSNode, other than type(target) == "table", --- but range() is only what's necessary by init.lua:replace_node(). --- ----@param first_node TSNode ----@param last_node TSNode ----@return table -local function make_target_node(first_node, last_node) - -- TSNode's are userdata, which can't be cloned/altered, so this proxy's calls - -- to it and overrides the position methods. - local target = {} - for k, _ in pairs(getmetatable(first_node)) do - target[k] = function(_, ...) - return first_node[k](first_node, ...) - end - end - local start_pos = { first_node:start() } - local end_pos = { last_node:end_() } - function target:start() return unpack(start_pos) end - function target:end_() return unpack(end_pos) end - function target:range() - return start_pos[1], start_pos[2], end_pos[1], end_pos[2] - end - - return target -end - ---@param formats table ---@param format string ---@return string|nil @@ -336,30 +269,33 @@ local function cycle(node, config) end if not vim.tbl_contains(cycler.allowed_formats, format) then - print(ERROR_NS .. "Format '" .. format .. "' not supported") + print_error("Format '" .. format .. "' not supported") return end local is_valid_sibling = cycler.make_sibling_validator(stmt) - local stmts = assemble_sibling_stmts( + local stmts = assemble_stmts( stmt, - get_prev_siblings_while(stmt.node, is_valid_sibling), - get_next_siblings_while(stmt.node, is_valid_sibling), + nu.takewhile(is_valid_sibling, nu.iter_prev_named_sibling(stmt.node)), + nu.takewhile(is_valid_sibling, nu.iter_next_named_sibling(stmt.node)), cycler.destructure, config.siblings_of_any_format ) - local names = collect_values_for_key(stmts, "names") + local names = vim.tbl_flatten( + vim.tbl_map(function(a_stmt) return a_stmt["names"] end, stmts) + ) local start = {node:start()} local indent = string.rep(" ", start[2]) local replacement = cycler.cycle[format](stmts, names, indent, config) - + local target = nu.make_target( + stmts[1].node, + { stmts[1].node:start() }, + { stmts[#stmts].node:end_() } + ) return replacement, { - target = make_target_node( - stmts[1].node, - stmts[#stmts].node - ), + target = target, cursor = {row = 0, col = 0}, format = true, } @@ -392,7 +328,7 @@ return function(config) } if #config.formats == 0 then - print(ERROR_NS .. "Empty config.formats, no formats to cycle") + print_error("Empty config.formats, no formats to cycle") end local function action(node) diff --git a/lua/ts-node-action/filetypes/python/helpers.lua b/lua/ts-node-action/filetypes/python/helpers.lua deleted file mode 100644 index fefeed2..0000000 --- a/lua/ts-node-action/filetypes/python/helpers.lua +++ /dev/null @@ -1,36 +0,0 @@ -local actions = require("ts-node-action.actions") -local helpers = require("ts-node-action.helpers") - -local M = {} - ----@param node TSNode -M.node_trim_whitespace = function(node) - local start_row, _, end_row, _ = node:range() - vim.cmd("silent! keeppatterns " .. (start_row + 1) .. "," .. (end_row + 1) .. "s/\\s\\+$//g") -end - --- Recreating actions.toggle_multiline.collapse_child_nodes() here because --- it is not exported. --- ----@param padding table ----@param uncollapsible table ----@return function -M.collapse_child_nodes = function(padding, uncollapsible) - - ---@param node TSNode - ---@return string - local function action(node) - if not helpers.node_is_multiline(node) then - return helpers.node_text(node) - end - - local tbl = actions.toggle_multiline(padding, uncollapsible) - local replacement = tbl[1][1](node) - - return replacement - end - - return action -end - -return M diff --git a/lua/ts-node-action/filetypes/python/node_utils.lua b/lua/ts-node-action/filetypes/python/node_utils.lua new file mode 100644 index 0000000..983087e --- /dev/null +++ b/lua/ts-node-action/filetypes/python/node_utils.lua @@ -0,0 +1,144 @@ +local actions = require("ts-node-action.actions") +local helpers = require("ts-node-action.helpers") + +-- WARN: Functions defined here should be treated as private/internal. +-- This is like an incubator and all are subject to change. + +-- NOTE: All functions are for TSNode, so rather than prefixing every function +-- name with "node_", the module is named "node_utils". + +local M = {} + +M.lines = function(node) + local lines = helpers.node_text(node) + if type(lines) == "string" then + return { lines } + end + return lines +end + +---@param node TSNode +M.trim_whitespace = function(node) + local start_row, _, end_row, _ = node:range() + vim.cmd("silent! keeppatterns " .. (start_row + 1) .. "," .. (end_row + 1) .. "s/\\s\\+$//g") +end + +-- Recreating actions.toggle_multiline.collapse_child_nodes() here because +-- it is not exported. +-- +---@param padding table +---@param uncollapsible table +---@return function @A function that takes a TSNode and returns a string +M.collapse_func = function(padding, uncollapsible) + local collapse = actions.toggle_multiline(padding, uncollapsible)[1][1] + + return function(node) + if not helpers.node_is_multiline(node) then + return helpers.node_text(node) + end + return collapse(node) + end +end + +-- Like vim.tbl_filter, but for TSNodes. +-- +---@param accept fun(node: TSNode): boolean @returns true for a valid node +---@param iter fun(): TSNode|nil @returns the next node +---@return TSNode[] +M.filter = function(accept, iter) + local nodes = {} + local node = iter() + while node and accept(node) do + table.insert(nodes, node) + node = iter() + end + return nodes +end + +-- Like filter, but stops at the first falsey value. +-- +---@param accept fun(node: TSNode): boolean @returns true for a valid node +---@param iter fun(): TSNode|nil @returns the next node +---@return TSNode[] +M.takewhile = function(accept, iter) + local nodes = {} + local node = iter() + while node and accept(node) do + table.insert(nodes, node) + node = iter() + end + return nodes +end + +M.iter_named_children = function(node) + local iter = node:iter_children() + return function() + local child = iter() + while child and not child:named() do + child = iter() + end + return child + end +end +M.iter_prev_named_sibling = function(node) + local sibling = node:prev_named_sibling() + return function() + if sibling then + local curr_sibling = sibling + sibling = sibling:prev_named_sibling() + return curr_sibling + end + end +end +M.iter_next_named_sibling = function(node) + local sibling = node:next_named_sibling() + return function() + if sibling then + local curr_sibling = sibling + sibling = sibling:next_named_sibling() + return curr_sibling + end + end +end +M.iter_parent = function(node) + local parent = node:parent() + return function() + if parent then + local curr_parent = parent + parent = parent:parent() + return curr_parent + end + end +end + + +-- Create a fake node to represent the replacement target. This is necessary +-- when the replacement spans multiple nodes without a suitable parent to serve +-- as a the target (eg, a top-level node's parent is the root and we are acting +-- on multiple children). +-- +-- This is indistiguishable from a TSNode, other than type(target) == "table". +-- +---@param node TSNode +---@param start_pos table +---@param end_pos table +---@return table +M.make_target = function(node, start_pos, end_pos) + -- TSNode's are userdata, which can't be cloned/altered, so this proxy's calls + -- to it and overrides the position methods. + local target = {} + for k, _ in pairs(getmetatable(node)) do + target[k] = function(_, ...) + return node[k](node, ...) + end + end + function target:start() return unpack(start_pos) end + function target:end_() return unpack(end_pos) end + function target:range() + return start_pos[1], start_pos[2], end_pos[1], end_pos[2] + end + + return target +end + +return M diff --git a/spec/filetypes/python/import_from_statement_spec.lua b/spec/filetypes/python/import_from_statement_spec.lua index 27a4902..53ad166 100644 --- a/spec/filetypes/python/import_from_statement_spec.lua +++ b/spec/filetypes/python/import_from_statement_spec.lua @@ -80,6 +80,19 @@ describe("import_from_statement", function() ) end) + it("cycles from inline to expand (single, slightly ambiguous)", function() + assert.are.same( + { + [[from foo import (]], + [[ bar,]], + [[)]], + }, + Helper:call({ + [[from foo import (]], + [[ bar)]], + }) + ) + end) it("cycles from inline to expand with mixed siblings", function() assert.are.same( diff --git a/spec/filetypes/python/import_statement_spec.lua b/spec/filetypes/python/import_statement_spec.lua index a2f7c62..19268a3 100644 --- a/spec/filetypes/python/import_statement_spec.lua +++ b/spec/filetypes/python/import_statement_spec.lua @@ -4,6 +4,17 @@ local Helper = SpecHelper.new("python", { shiftwidth = 4 }) describe("import_statement", function() + it("doesn't cycle with 1 import (same for both)", function() + assert.are.same( + { + [[import bar]], + }, + Helper:call({ + [[import bar]], + }) + ) + end) + it("cycles from single to inline", function() assert.are.same( @@ -165,7 +176,4 @@ describe("import_statement", function() ) end) - - - end)