diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index b1a8ceef0fd2..b7eb17617a64 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -1,77 +1,175 @@ +""" +String manipulation nodes converted to ComfyUI v3 format. + +This module contains v3 conversions of all string manipulation nodes from nodes_string.py. +The v3 implementations provide type safety, better documentation, and cleaner APIs +while maintaining full backward compatibility with v1 through the automatic +compatibility layer. +""" + import re +from comfy_api.v3 import io + -from comfy.comfy_types.node_typing import IO +class StringConcatenate(io.ComfyNodeV3): + """Concatenates two strings with an optional delimiter between them.""" -class StringConcatenate(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "delimiter": (IO.STRING, {"multiline": False, "default": ""}) - } - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, delimiter, **kwargs): - return delimiter.join((string_a, string_b)), - -class StringSubstring(): + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="StringConcatenate", + display_name="String Concatenate", + category="utils/string", + description="Concatenates two strings together with an optional delimiter between them.", + inputs=[ + io.String.Input( + "string_a", + display_name="String A", + multiline=True, + tooltip="The first string to concatenate", + ), + io.String.Input( + "string_b", + display_name="String B", + multiline=True, + tooltip="The second string to concatenate", + ), + io.String.Input( + "delimiter", + display_name="Delimiter", + default="", + multiline=False, + tooltip="The delimiter to insert between the two strings (empty by default)", + ), + ], + outputs=[ + io.String.Output( + "concatenated", + display_name="Concatenated String", + tooltip="The result of concatenating string_a and string_b with the delimiter", + ), + ], + ) + + @classmethod + def execute(cls, string_a: str, string_b: str, delimiter: str) -> io.NodeOutput: + """Concatenates two strings with an optional delimiter.""" + result = delimiter.join((string_a, string_b)) + return io.NodeOutput(result) + + +class StringSubstring(io.ComfyNodeV3): + """Extracts a substring from a string using start and end indices.""" + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "start": (IO.INT, {}), - "end": (IO.INT, {}), - } - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, start, end, **kwargs): - return string[start:end], - -class StringLength(): + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="StringSubstring", + display_name="String Substring", + category="utils/string", + description="Extracts a portion of a string using Python slice notation [start:end].", + inputs=[ + io.String.Input( + "string", + display_name="String", + multiline=True, + tooltip="The string to extract a substring from", + ), + io.Int.Input( + "start", + display_name="Start Index", + tooltip="Starting position (inclusive). Negative values count from the end", + ), + io.Int.Input( + "end", + display_name="End Index", + tooltip="Ending position (exclusive). Negative values count from the end", + ), + ], + outputs=[ + io.String.Output( + "substring", + display_name="Substring", + tooltip="The extracted substring", + ), + ], + ) + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}) - } - } + def execute(cls, string: str, start: int, end: int) -> io.NodeOutput: + """Extracts substring using Python slice notation.""" + return io.NodeOutput(string[start:end]) - RETURN_TYPES = (IO.INT,) - RETURN_NAMES = ("length",) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, **kwargs): - length = len(string) +class StringLength(io.ComfyNodeV3): + """Returns the length of a string.""" - return length, + @classmethod + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="StringLength", + display_name="String Length", + category="utils/string", + description="Calculates the number of characters in a string.", + inputs=[ + io.String.Input( + "string", + display_name="String", + multiline=True, + tooltip="The string to measure", + ), + ], + outputs=[ + io.Int.Output( + "length", + display_name="Length", + tooltip="The number of characters in the string", + ), + ], + ) -class CaseConverter(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]}) - } - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, mode, **kwargs): + def execute(cls, string: str) -> io.NodeOutput: + """Returns the length of the input string.""" + return io.NodeOutput(len(string)) + + +class CaseConverter(io.ComfyNodeV3): + """Converts string case to uppercase, lowercase, capitalize, or title case.""" + + @classmethod + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="CaseConverter", + display_name="Case Converter", + category="utils/string", + description="Converts text to different case formats.", + inputs=[ + io.String.Input( + "string", + display_name="String", + multiline=True, + tooltip="The string to convert", + ), + io.Combo.Input( + "mode", + display_name="Mode", + options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"], + tooltip="The case conversion mode to apply", + ), + ], + outputs=[ + io.String.Output( + "converted", + display_name="Converted String", + tooltip="The string with the selected case conversion applied", + ), + ], + ) + + @classmethod + def execute(cls, string: str, mode: str) -> io.NodeOutput: + """Converts string to the selected case format.""" if mode == "UPPERCASE": result = string.upper() elif mode == "lowercase": @@ -83,24 +181,45 @@ def execute(self, string, mode, **kwargs): else: result = string - return result, + return io.NodeOutput(result) + +class StringTrim(io.ComfyNodeV3): + """Removes whitespace from the beginning, end, or both sides of a string.""" -class StringTrim(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]}) - } - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, mode, **kwargs): + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="StringTrim", + display_name="String Trim", + category="utils/string", + description="Removes leading and/or trailing whitespace from a string.", + inputs=[ + io.String.Input( + "string", + display_name="String", + multiline=True, + tooltip="The string to trim", + ), + io.Combo.Input( + "mode", + display_name="Mode", + options=["Both", "Left", "Right"], + tooltip="Which side(s) to trim whitespace from", + ), + ], + outputs=[ + io.String.Output( + "trimmed", + display_name="Trimmed String", + tooltip="The string with whitespace removed", + ), + ], + ) + + @classmethod + def execute(cls, string: str, mode: str) -> io.NodeOutput: + """Removes whitespace based on the selected mode.""" if mode == "Both": result = string.strip() elif mode == "Left": @@ -110,70 +229,157 @@ def execute(self, string, mode, **kwargs): else: result = string - return result, + return io.NodeOutput(result) + + +class StringReplace(io.ComfyNodeV3): + """Replaces all occurrences of a substring with another string.""" + + @classmethod + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="StringReplace", + display_name="String Replace", + category="utils/string", + description="Replaces all occurrences of a substring within a string.", + inputs=[ + io.String.Input( + "string", + display_name="String", + multiline=True, + tooltip="The string to search in", + ), + io.String.Input( + "find", + display_name="Find", + multiline=True, + tooltip="The substring to search for", + ), + io.String.Input( + "replace", + display_name="Replace", + multiline=True, + tooltip="The string to replace matches with", + ), + ], + outputs=[ + io.String.Output( + "result", + display_name="Result", + tooltip="The string with all replacements made", + ), + ], + ) -class StringReplace(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "find": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": True}) - } - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, find, replace, **kwargs): + def execute(cls, string: str, find: str, replace: str) -> io.NodeOutput: + """Replaces all occurrences of find with replace.""" result = string.replace(find, replace) - return result, + return io.NodeOutput(result) -class StringContains(): +class StringContains(io.ComfyNodeV3): + """Checks if a string contains a substring.""" + + @classmethod + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="StringContains", + display_name="String Contains", + category="utils/string", + description="Checks whether a string contains a specific substring.", + inputs=[ + io.String.Input( + "string", + display_name="String", + multiline=True, + tooltip="The string to search in", + ), + io.String.Input( + "substring", + display_name="Substring", + multiline=True, + tooltip="The substring to search for", + ), + io.Boolean.Input( + "case_sensitive", + display_name="Case Sensitive", + default=True, + tooltip="Whether the search should be case sensitive", + ), + ], + outputs=[ + io.Boolean.Output( + "contains", + display_name="Contains", + tooltip="True if the substring is found, False otherwise", + ), + ], + ) + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "substring": (IO.STRING, {"multiline": True}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } - - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("contains",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, substring, case_sensitive, **kwargs): + def execute( + cls, string: str, substring: str, case_sensitive: bool + ) -> io.NodeOutput: + """Checks if string contains substring with optional case sensitivity.""" if case_sensitive: contains = substring in string else: contains = substring.lower() in string.lower() - return contains, + return io.NodeOutput(contains) -class StringCompare(): +class StringCompare(io.ComfyNodeV3): + """Compares two strings with various comparison modes.""" + + @classmethod + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="StringCompare", + display_name="String Compare", + category="utils/string", + description="Compares two strings using different comparison modes.", + inputs=[ + io.String.Input( + "string_a", + display_name="String A", + multiline=True, + tooltip="The first string to compare", + ), + io.String.Input( + "string_b", + display_name="String B", + multiline=True, + tooltip="The second string to compare", + ), + io.Combo.Input( + "mode", + display_name="Mode", + options=["Starts With", "Ends With", "Equal"], + tooltip="The comparison mode to use", + ), + io.Boolean.Input( + "case_sensitive", + display_name="Case Sensitive", + default=True, + tooltip="Whether the comparison should be case sensitive", + ), + ], + outputs=[ + io.Boolean.Output( + "result", + display_name="Result", + tooltip="True if the comparison succeeds, False otherwise", + ), + ], + ) + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } - - RETURN_TYPES = (IO.BOOLEAN,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, mode, case_sensitive, **kwargs): + def execute( + cls, string_a: str, string_b: str, mode: str, case_sensitive: bool + ) -> io.NodeOutput: + """Compares two strings based on the selected mode and case sensitivity.""" if case_sensitive: a = string_a b = string_b @@ -182,31 +388,78 @@ def execute(self, string_a, string_b, mode, case_sensitive, **kwargs): b = string_b.lower() if mode == "Equal": - return a == b, + result = a == b elif mode == "Starts With": - return a.startswith(b), + result = a.startswith(b) elif mode == "Ends With": - return a.endswith(b), + result = a.endswith(b) + else: + result = False + + return io.NodeOutput(result) + + +class RegexMatch(io.ComfyNodeV3): + """Tests if a string matches a regular expression pattern.""" -class RegexMatch(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}) - } - } - - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("matches",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs): + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="RegexMatch", + display_name="Regex Match", + category="utils/string", + description="Tests whether a string matches a regular expression pattern.", + inputs=[ + io.String.Input( + "string", + display_name="String", + multiline=True, + tooltip="The string to test", + ), + io.String.Input( + "regex_pattern", + display_name="Regex Pattern", + multiline=True, + tooltip="The regular expression pattern to match against", + ), + io.Boolean.Input( + "case_insensitive", + display_name="Case Insensitive", + default=True, + tooltip="Whether to ignore case when matching", + ), + io.Boolean.Input( + "multiline", + display_name="Multiline", + default=False, + tooltip="Whether ^ and $ match line boundaries", + ), + io.Boolean.Input( + "dotall", + display_name="Dot All", + default=False, + tooltip="Whether . matches newline characters", + ), + ], + outputs=[ + io.Boolean.Output( + "matches", + display_name="Matches", + tooltip="True if the pattern matches, False otherwise", + ), + ], + ) + + @classmethod + def execute( + cls, + string: str, + regex_pattern: str, + case_insensitive: bool, + multiline: bool, + dotall: bool, + ) -> io.NodeOutput: + """Tests if string matches the regex pattern.""" flags = 0 if case_insensitive: @@ -219,33 +472,89 @@ def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, ** try: match = re.search(regex_pattern, string, flags) result = match is not None - except re.error: result = False - return result, + return io.NodeOutput(result) + +class RegexExtract(io.ComfyNodeV3): + """Extracts text from a string using regular expression patterns.""" -class RegexExtract(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}), - "group_index": (IO.INT, {"default": 1, "min": 0, "max": 100}) - } - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs): + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="RegexExtract", + display_name="Regex Extract", + category="utils/string", + description="Extracts text from a string using regular expression patterns and groups.", + inputs=[ + io.String.Input( + "string", + display_name="String", + multiline=True, + tooltip="The string to extract from", + ), + io.String.Input( + "regex_pattern", + display_name="Regex Pattern", + multiline=True, + tooltip="The regular expression pattern with optional groups", + ), + io.Combo.Input( + "mode", + display_name="Mode", + options=["First Match", "All Matches", "First Group", "All Groups"], + tooltip="What to extract from the matches", + ), + io.Boolean.Input( + "case_insensitive", + display_name="Case Insensitive", + default=True, + tooltip="Whether to ignore case when matching", + ), + io.Boolean.Input( + "multiline", + display_name="Multiline", + default=False, + tooltip="Whether ^ and $ match line boundaries", + ), + io.Boolean.Input( + "dotall", + display_name="Dot All", + default=False, + tooltip="Whether . matches newline characters", + ), + io.Int.Input( + "group_index", + display_name="Group Index", + default=1, + min=0, + max=100, + tooltip="Which capture group to extract (0 = entire match)", + ), + ], + outputs=[ + io.String.Output( + "extracted", + display_name="Extracted", + tooltip="The extracted text (multiple matches joined with newlines)", + ), + ], + ) + + @classmethod + def execute( + cls, + string: str, + regex_pattern: str, + mode: str, + case_insensitive: bool, + multiline: bool, + dotall: bool, + group_index: int, + ) -> io.NodeOutput: + """Extracts text based on regex pattern and mode.""" join_delimiter = "\n" flags = 0 @@ -294,32 +603,90 @@ def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dota except re.error: result = "" - return result, + return io.NodeOutput(result) + +class RegexReplace(io.ComfyNodeV3): + """Find and replace text using regex patterns.""" -class RegexReplace(): - DESCRIPTION = "Find and replace text using regex patterns." @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": True}), - }, - "optional": { - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}), - "count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}), - } - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs): + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="RegexReplace", + display_name="Regex Replace", + category="utils/string", + description="Find and replace text using regular expression patterns.", + inputs=[ + io.String.Input( + "string", + display_name="String", + multiline=True, + tooltip="The string to perform replacements on", + ), + io.String.Input( + "regex_pattern", + display_name="Regex Pattern", + multiline=True, + tooltip="The regular expression pattern to match", + ), + io.String.Input( + "replace", + display_name="Replace", + multiline=True, + tooltip="The replacement text (can use \\1, \\2 for capture groups)", + ), + io.Boolean.Input( + "case_insensitive", + display_name="Case Insensitive", + default=True, + optional=True, + tooltip="Whether to ignore case when matching", + ), + io.Boolean.Input( + "multiline", + display_name="Multiline", + default=False, + optional=True, + tooltip="Whether ^ and $ match line boundaries", + ), + io.Boolean.Input( + "dotall", + display_name="Dot All", + default=False, + optional=True, + tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines.", + ), + io.Int.Input( + "count", + display_name="Count", + default=0, + min=0, + max=100, + optional=True, + tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc.", + ), + ], + outputs=[ + io.String.Output( + "result", + display_name="Result", + tooltip="The string with replacements made", + ), + ], + ) + + @classmethod + def execute( + cls, + string: str, + regex_pattern: str, + replace: str, + case_insensitive: bool = True, + multiline: bool = False, + dotall: bool = False, + count: int = 0, + ) -> io.NodeOutput: + """Replaces text matching regex pattern.""" flags = 0 if case_insensitive: @@ -328,8 +695,10 @@ def execute(self, string, regex_pattern, replace, case_insensitive=True, multili flags |= re.MULTILINE if dotall: flags |= re.DOTALL + result = re.sub(regex_pattern, replace, string, count=count, flags=flags) - return result, + return io.NodeOutput(result) + NODE_CLASS_MAPPINGS = { "StringConcatenate": StringConcatenate, @@ -358,3 +727,4 @@ def execute(self, string, regex_pattern, replace, case_insensitive=True, multili "RegexExtract": "Regex Extract", "RegexReplace": "Regex Replace", } +