From a3aae2e6ac5a1d2f0e43d3534c77e11707917b4f Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 25 Apr 2026 16:21:26 -0400 Subject: [PATCH] Allow styling prompts with HTML-like tags * use FormattedText for prompt strings rather than ANSI * do all interpretation and formatting within render_prompt_string() ( renamed from get_prompt() ) * make formatting more robust: don't depend on a special string to be substituted for backslashes * inline an error message if the user gives an unrecognized backslash prompt format string * replace backslashes with forward slashes in socket names, in case they would be further substituted as format strings Motivation: It is much easier to figure out how to apply colors and styles using this method. Bugs and limitations: Since the substitutions are still done serially on a string (or list of strings), in the rare case that say a DSN-name substitution value contained a backslash, it might again be substituted, or fail. A better way could be to use "re.split()" with a capture group, replace known prompt strings with callables, then have a substitution pass, then join at the end. But the edge case of a double-substitution is very unlikely, and the edge case of HTML characters is handled. --- changelog.md | 5 + mycli/main.py | 11 +- mycli/main_modes/repl.py | 183 ++++++++++++++++++--------- mycli/myclirc | 5 + mycli/packages/string_utils.py | 9 +- test/myclirc | 9 +- test/pytests/test_main.py | 26 ++-- test/pytests/test_main_modes_repl.py | 86 ++++++++++--- 8 files changed, 242 insertions(+), 92 deletions(-) diff --git a/changelog.md b/changelog.md index ae52dda6..89a66d76 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,11 @@ Upcoming (TBD) ============== +Features +--------- +* Allow styling prompts with HTML-like tags. + + Internal --------- * Remove unused fixture data. diff --git a/mycli/main.py b/mycli/main.py index 1c0b5e4a..01ffa72a 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -70,7 +70,7 @@ from mycli.main_modes.execute import main_execute_from_cli from mycli.main_modes.list_dsn import main_list_dsn from mycli.main_modes.list_ssh_config import main_list_ssh_config -from mycli.main_modes.repl import get_prompt, main_repl, set_all_external_titles +from mycli.main_modes.repl import main_repl, render_prompt_string, set_all_external_titles from mycli.packages import special from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme from mycli.packages.filepaths import dir_path_exists, guess_socket_location @@ -268,8 +268,8 @@ def __init__( self.min_completion_trigger = c["main"].as_int("min_completion_trigger") # a hack, pending a better way to handle settings and state repl_package.MIN_COMPLETION_TRIGGER = self.min_completion_trigger - self.last_prompt_message = ANSI('') - self.last_custom_toolbar_message = ANSI('') + self.last_prompt_message = to_formatted_text('') + self.last_custom_toolbar_message = to_formatted_text('') # Register custom special commands. self.register_special_commands() @@ -907,8 +907,9 @@ def get_output_margin(self, status: str | None = None) -> int: render_counter = self.prompt_session.app.render_counter else: render_counter = 0 - # todo: this jump back to get_prompt() in repl.py is a sign that separation is incomplete - self.prompt_lines = get_prompt(self, self.prompt_format, render_counter).count('\n') + 1 + # todo: this jump back to render_prompt_string() in repl.py is a sign that separation is incomplete + prompt_string = render_prompt_string(self, self.prompt_format, render_counter) + self.prompt_lines = to_plain_text(prompt_string).count('\n') + 1 margin = self.get_reserved_space() + self.prompt_lines if special.is_timing_enabled(): margin += 1 diff --git a/mycli/main_modes/repl.py b/mycli/main_modes/repl.py index da8f148a..10b8df9c 100644 --- a/mycli/main_modes/repl.py +++ b/mycli/main_modes/repl.py @@ -4,6 +4,7 @@ from datetime import datetime import functools from functools import partial +import html from importlib import resources import os import random @@ -13,6 +14,7 @@ import time import traceback from typing import TYPE_CHECKING, Any, Generator +from xml.parsers.expat import ExpatError import click import prompt_toolkit @@ -23,6 +25,10 @@ from prompt_toolkit.filters import Condition, has_focus, is_done from prompt_toolkit.formatted_text import ( ANSI, + HTML, + FormattedText, + to_formatted_text, + to_plain_text, ) from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor @@ -162,7 +168,13 @@ def set_external_terminal_tab_title(mycli: 'MyCli') -> None: return if not sys.stderr.isatty(): return - title = sanitize_terminal_title(get_prompt(mycli, mycli.terminal_tab_title_format, mycli.prompt_session.app.render_counter)) + title = sanitize_terminal_title( + render_prompt_string( + mycli, + mycli.terminal_tab_title_format, + mycli.prompt_session.app.render_counter, + ) + ) print(f'\x1b]1;{title}\a', file=sys.stderr, end='') sys.stderr.flush() @@ -174,7 +186,13 @@ def set_external_terminal_window_title(mycli: 'MyCli') -> None: return if not sys.stderr.isatty(): return - title = sanitize_terminal_title(get_prompt(mycli, mycli.terminal_window_title_format, mycli.prompt_session.app.render_counter)) + title = sanitize_terminal_title( + render_prompt_string( + mycli, + mycli.terminal_window_title_format, + mycli.prompt_session.app.render_counter, + ) + ) print(f'\x1b]2;{title}\a', file=sys.stderr, end='') sys.stderr.flush() @@ -186,7 +204,13 @@ def set_external_multiplex_window_title(mycli: 'MyCli') -> None: return if not mycli.prompt_session: return - title = sanitize_terminal_title(get_prompt(mycli, mycli.multiplex_window_title_format, mycli.prompt_session.app.render_counter)) + title = sanitize_terminal_title( + render_prompt_string( + mycli, + mycli.multiplex_window_title_format, + mycli.prompt_session.app.render_counter, + ) + ) try: subprocess.run( ['tmux', 'rename-window', title], @@ -208,7 +232,13 @@ def set_external_multiplex_pane_title(mycli: 'MyCli') -> None: return if not sys.stderr.isatty(): return - title = sanitize_terminal_title(get_prompt(mycli, mycli.multiplex_pane_title_format, mycli.prompt_session.app.render_counter)) + title = sanitize_terminal_title( + render_prompt_string( + mycli, + mycli.multiplex_pane_title_format, + mycli.prompt_session.app.render_counter, + ) + ) print(f'\x1b]2;{title}\x1b\\', file=sys.stderr, end='') sys.stderr.flush() @@ -216,25 +246,33 @@ def set_external_multiplex_pane_title(mycli: 'MyCli') -> None: def get_custom_toolbar( mycli: 'MyCli', toolbar_format: str, -) -> ANSI: +) -> FormattedText: if not mycli.prompt_session: - return ANSI('') + return to_formatted_text('') if not mycli.prompt_session.app: - return ANSI('') + return to_formatted_text('') if mycli.prompt_session.app.current_buffer.text: return mycli.last_custom_toolbar_message - toolbar = get_prompt(mycli, toolbar_format, mycli.prompt_session.app.render_counter) - toolbar = toolbar.replace('\\x1b', '\x1b') - mycli.last_custom_toolbar_message = ANSI(toolbar) + mycli.last_custom_toolbar_message = render_prompt_string( + mycli, + toolbar_format, + mycli.prompt_session.app.render_counter, + ) return mycli.last_custom_toolbar_message +def maybe_html_escape(string: str, is_html: bool) -> str: + if is_html: + return html.escape(string, quote=False) + return string + + @functools.lru_cache(maxsize=256) -def get_prompt( +def render_prompt_string( mycli: 'MyCli', string: str, _render_counter: int, -) -> str: +) -> FormattedText: sqlexecute = mycli.sqlexecute assert sqlexecute is not None if mycli.login_path and mycli.login_path_as_host: @@ -247,79 +285,106 @@ def get_prompt( if re.match(r'^[\d\.]+$', short_prompt_host): short_prompt_host = prompt_host now = datetime.now() - backslash_placeholder = '\ufffc_backslash' - string = string.replace('\\\\', backslash_placeholder) - string = string.replace('\\u', sqlexecute.user or '(none)') - string = string.replace('\\h', prompt_host or '(none)') - string = string.replace('\\H', short_prompt_host or '(none)') - string = string.replace('\\d', sqlexecute.dbname or '(none)') species_name = sqlexecute.server_info.species.name if sqlexecute.server_info and sqlexecute.server_info.species else 'MySQL' - string = string.replace('\\t', species_name) - string = string.replace('\\n', '\n') - string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) - string = string.replace('\\m', now.strftime('%M')) - string = string.replace('\\P', now.strftime('%p')) - string = string.replace('\\R', now.strftime('%H')) - string = string.replace('\\r', now.strftime('%I')) - string = string.replace('\\s', now.strftime('%S')) - string = string.replace('\\p', str(sqlexecute.port)) - string = string.replace('\\j', os.path.basename(sqlexecute.socket or '(none)')) - string = string.replace('\\J', sqlexecute.socket or '(none)') - string = string.replace('\\k', os.path.basename(sqlexecute.socket or str(sqlexecute.port))) - string = string.replace('\\K', sqlexecute.socket or str(sqlexecute.port)) - string = string.replace('\\A', mycli.dsn_alias or '(none)') - string = string.replace('\\_', ' ') - string = string.replace(backslash_placeholder, '\\') - + strings = string.split('\\\\') + is_html = strings[0].startswith('\\') + strings = [x.replace('\\u', maybe_html_escape(sqlexecute.user or '(none)', is_html)) for x in strings] + strings = [x.replace('\\h', maybe_html_escape(prompt_host or '(none)', is_html)) for x in strings] + strings = [x.replace('\\H', maybe_html_escape(short_prompt_host or '(none)', is_html)) for x in strings] + strings = [x.replace('\\d', maybe_html_escape(sqlexecute.dbname or '(none)', is_html)) for x in strings] + strings = [x.replace('\\t', maybe_html_escape(species_name, is_html)) for x in strings] + strings = [x.replace('\\n', '\n') for x in strings] + strings = [x.replace('\\D', maybe_html_escape(now.strftime('%a %b %d %H:%M:%S %Y'), is_html)) for x in strings] + strings = [x.replace('\\m', maybe_html_escape(now.strftime('%M'), is_html)) for x in strings] + strings = [x.replace('\\P', maybe_html_escape(now.strftime('%p'), is_html)) for x in strings] + strings = [x.replace('\\R', maybe_html_escape(now.strftime('%H'), is_html)) for x in strings] + strings = [x.replace('\\r', maybe_html_escape(now.strftime('%I'), is_html)) for x in strings] + strings = [x.replace('\\s', maybe_html_escape(now.strftime('%S'), is_html)) for x in strings] + strings = [x.replace('\\p', maybe_html_escape(str(sqlexecute.port), is_html)) for x in strings] + strings = [ + x.replace('\\j', maybe_html_escape(os.path.basename(sqlexecute.socket or '(none)').replace('\\', '/'), is_html)) for x in strings + ] + strings = [x.replace('\\J', maybe_html_escape((sqlexecute.socket or '(none)').replace('\\', '/'), is_html)) for x in strings] + strings = [ + x.replace('\\k', maybe_html_escape(os.path.basename(sqlexecute.socket or str(sqlexecute.port)).replace('\\', '/'), is_html)) + for x in strings + ] + strings = [ + x.replace('\\K', maybe_html_escape((sqlexecute.socket or str(sqlexecute.port)).replace('\\', '/'), is_html)) for x in strings + ] + strings = [x.replace('\\A', maybe_html_escape(mycli.dsn_alias or '(none)', is_html)) for x in strings] + strings = [x.replace('\\_', ' ') for x in strings] + + checker_string = ' '.join(strings) if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: - if '\\y' in string: + if '\\y' in checker_string: with sqlexecute.conn.cursor() as cur: - string = string.replace('\\y', str(get_uptime(cur)) or '(none)') - if '\\Y' in string: + strings = [x.replace('\\y', maybe_html_escape(str(get_uptime(cur)) or '(none)', is_html)) for x in strings] + if '\\Y' in checker_string: with sqlexecute.conn.cursor() as cur: - string = string.replace('\\Y', format_uptime(str(get_uptime(cur))) or '(none)') + strings = [x.replace('\\Y', maybe_html_escape(format_uptime(str(get_uptime(cur))) or '(none)', is_html)) for x in strings] else: - string = string.replace('\\y', '(none)') - string = string.replace('\\Y', '(none)') + strings = [x.replace('\\y', '(none)') for x in strings] + strings = [x.replace('\\Y', '(none)') for x in strings] if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: - if '\\T' in string: + if '\\T' in checker_string: with sqlexecute.conn.cursor() as cur: - string = string.replace('\\T', get_ssl_version(cur) or '(none)') + strings = [x.replace('\\T', maybe_html_escape(get_ssl_version(cur) or '(none)', is_html)) for x in strings] else: - string = string.replace('\\T', '(none)') + strings = [x.replace('\\T', '(none)') for x in strings] if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: - if '\\w' in string: + if '\\w' in checker_string: with sqlexecute.conn.cursor() as cur: - string = string.replace('\\w', str(get_warning_count(cur) or '(none)')) + strings = [x.replace('\\w', maybe_html_escape(str(get_warning_count(cur) or '(none)'), is_html)) for x in strings] else: - string = string.replace('\\w', '(none)') + strings = [x.replace('\\w', '(none)') for x in strings] if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: - if '\\W' in string: + if '\\W' in checker_string: with sqlexecute.conn.cursor() as cur: - string = string.replace('\\W', str(get_warning_count(cur) or '')) + strings = [x.replace('\\W', maybe_html_escape(str(get_warning_count(cur) or ''), is_html)) for x in strings] else: - string = string.replace('\\W', '') + strings = [x.replace('\\W', '') for x in strings] - return string + if is_html: + strings[0] = strings[0].removeprefix('\\') + strings[-1] = strings[-1].removesuffix('\\') + elif '\\x1b' in checker_string: + strings = [x.replace('\\x1b', '\x1b') for x in strings] + + strings = [re.sub(r'\\(.)', r'(unknown prompt format string: \\\1)', x) for x in strings] + + string = '\\'.join(strings) + + if is_html: + try: + formatted_string = to_formatted_text(HTML(string)) + except (ExpatError, ValueError): + formatted_string = to_formatted_text(HTML('(cannot parse HTML prompt string)')) + else: + formatted_string = to_formatted_text(ANSI(string)) + + return formatted_string def _get_prompt_message( mycli: 'MyCli', app: prompt_toolkit.application.application.Application, -) -> ANSI: +) -> FormattedText: if app.current_buffer.text: return mycli.last_prompt_message - prompt = get_prompt(mycli, mycli.prompt_format, app.render_counter) - if mycli.prompt_format == mycli.default_prompt and len(prompt) > mycli.max_len_prompt: - prompt = get_prompt(mycli, mycli.default_prompt_splitln, app.render_counter) - mycli.prompt_lines = prompt.count('\n') + 1 - prompt = prompt.replace('\\x1b', '\x1b') + prompt = render_prompt_string(mycli, mycli.prompt_format, app.render_counter) + prompt_plain = to_plain_text(prompt) + if mycli.prompt_format == mycli.default_prompt and len(prompt_plain) > mycli.max_len_prompt: + prompt = render_prompt_string(mycli, mycli.default_prompt_splitln, app.render_counter) + prompt_plain = to_plain_text(prompt) + mycli.prompt_lines = prompt_plain.count('\n') + 1 if not mycli.prompt_lines: - mycli.prompt_lines = prompt.count('\n') + 1 - mycli.last_prompt_message = ANSI(prompt) + mycli.prompt_lines = prompt_plain.count('\n') + 1 + + mycli.last_prompt_message = prompt return mycli.last_prompt_message diff --git a/mycli/myclirc b/mycli/myclirc index 3aa35189..b14bb824 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -137,6 +137,11 @@ wider_completion_menu = False # * \_ - a space # * \\ - a literal backslash # * \x1b[...m - an ANSI escape sequence (can style with color) +# * \ - a leading sequence indicating that the rest of the prompt be styled like HTML. +# See https://python-prompt-toolkit.readthedocs.io/en/stable/pages/printing_text.html#html . +# Characters such as "&" or literal "<" and ">" must be HTML-escaped. +# HTML styles cannot be combined with ANSI sequences. HTML takes precedence. +# HTML color example: prompt = '\root@localhost:\d> ' prompt = '\t \u@\h:\d> ' prompt_continuation = '->' diff --git a/mycli/packages/string_utils.py b/mycli/packages/string_utils.py index 89402ad5..56103330 100644 --- a/mycli/packages/string_utils.py +++ b/mycli/packages/string_utils.py @@ -1,10 +1,15 @@ import re from cli_helpers.utils import strip_ansi +from prompt_toolkit.formatted_text import ( + FormattedText, + to_plain_text, +) -def sanitize_terminal_title(title: str) -> str: - sanitized = strip_ansi(title) +def sanitize_terminal_title(title: FormattedText) -> str: + sanitized = to_plain_text(title) + sanitized = strip_ansi(sanitized) sanitized = sanitized.replace('\n', ' ') sanitized = re.sub('[\x00-\x1f\x7f]', '', sanitized) return sanitized diff --git a/test/myclirc b/test/myclirc index 811c51d2..f68f55be 100644 --- a/test/myclirc +++ b/test/myclirc @@ -125,16 +125,21 @@ wider_completion_menu = False # * \K - full connection socket path OR the port # * \T - connection SSL/TLS version # * \t - database vendor (Percona, MySQL, MariaDB, TiDB) +# * \u - username # * \w - number of warnings, or "(none)" (requires frequent trips to the server) -# * \W - number of warnings, or the empty string (requires frequent trips to the server) +# * \W - number of warnings, or the empty string (requires frequent trips to the server) # * \y - uptime in seconds (requires frequent trips to the server) # * \Y - uptime in words (requires frequent trips to the server) -# * \u - username # * \A - DSN alias # * \n - a newline # * \_ - a space # * \\ - a literal backslash # * \x1b[...m - an ANSI escape sequence (can style with color) +# * \ - a leading sequence indicating that the rest of the prompt be styled like HTML. +# See https://python-prompt-toolkit.readthedocs.io/en/stable/pages/printing_text.html#html . +# Characters such as "&" or literal "<" and ">" must be HTML-escaped. +# HTML styles cannot be combined with ANSI sequences. HTML takes precedence. +# HTML color example: prompt = '\root@localhost:\d> ' prompt = "\t \u@\h:\d> " prompt_continuation = -> diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 295e6987..949a9cda 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -13,6 +13,11 @@ import click from click.testing import CliRunner +from prompt_toolkit.formatted_text import ( + FormattedText, + to_formatted_text, + to_plain_text, +) import pymysql from pymysql.err import OperationalError import pytest @@ -391,8 +396,9 @@ def test_prompt_no_host_only_socket(executor): mycli.sqlexecute.user = DEFAULT_USER mycli.sqlexecute.dbname = DEFAULT_DATABASE mycli.sqlexecute.port = DEFAULT_PORT - prompt = repl_mode.get_prompt(mycli, mycli.prompt_format, 0) - assert prompt == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_DATABASE}> " + prompt = repl_mode.render_prompt_string(mycli, mycli.prompt_format, 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_DATABASE}> " @dbtest @@ -406,8 +412,9 @@ def test_prompt_socket_overrides_port(executor): mycli.sqlexecute.user = DEFAULT_USER mycli.sqlexecute.dbname = DEFAULT_DATABASE mycli.sqlexecute.port = DEFAULT_PORT - prompt = repl_mode.get_prompt(mycli, mycli.prompt_format, 0) - assert prompt == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:mysqld.sock {DEFAULT_DATABASE}> " + prompt = repl_mode.render_prompt_string(mycli, mycli.prompt_format, 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:mysqld.sock {DEFAULT_DATABASE}> " @dbtest @@ -421,8 +428,9 @@ def test_prompt_socket_short_host(executor): mycli.sqlexecute.user = DEFAULT_USER mycli.sqlexecute.dbname = DEFAULT_DATABASE mycli.sqlexecute.port = DEFAULT_PORT - prompt = repl_mode.get_prompt(mycli, mycli.prompt_format, 0) - assert prompt == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_PORT} {DEFAULT_DATABASE}> " + prompt = repl_mode.render_prompt_string(mycli, mycli.prompt_format, 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_PORT} {DEFAULT_DATABASE}> " @dbtest @@ -2252,11 +2260,11 @@ def test_get_output_margin_uses_prompt_session_render_counter(monkeypatch: pytes SimpleNamespace(app=SimpleNamespace(render_counter=7)), ) - def fake_get_prompt(mycli: Any, string: str, render_counter: int) -> str: + def fake_render_prompt_string(mycli: Any, string: str, render_counter: int) -> FormattedText: render_counters.append(render_counter) - return 'line1\nline2' + return to_formatted_text('line1\nline2') - monkeypatch.setattr(main, 'get_prompt', fake_get_prompt) + monkeypatch.setattr(main, 'render_prompt_string', fake_render_prompt_string) monkeypatch.setattr(main.special, 'is_timing_enabled', lambda: False) assert main.MyCli.get_output_margin(cli, 'ok') == 5 assert render_counters == [7] diff --git a/test/pytests/test_main_modes_repl.py b/test/pytests/test_main_modes_repl.py index 81f470a4..dd44fc2e 100644 --- a/test/pytests/test_main_modes_repl.py +++ b/test/pytests/test_main_modes_repl.py @@ -8,7 +8,7 @@ from types import SimpleNamespace from typing import Any, Literal, cast -from prompt_toolkit.formatted_text import to_plain_text +from prompt_toolkit.formatted_text import to_formatted_text, to_plain_text import pymysql import pytest @@ -335,8 +335,8 @@ def test_repl_show_startup_banner_and_prompt_helpers(monkeypatch: pytest.MonkeyP monkeypatch.setattr( repl_mode, - 'get_prompt', - lambda mycli, string, render_counter: '0123456' if string == cli.default_prompt else 'a\nb', + 'render_prompt_string', + lambda mycli, string, render_counter: to_formatted_text('0123456') if string == cli.default_prompt else 'a\nb', ) cli.max_len_prompt = 5 prompt_text = to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=2)))) @@ -348,7 +348,7 @@ def test_repl_show_startup_banner_and_prompt_helpers(monkeypatch: pytest.MonkeyP cli.prompt_format = 'custom' cli.prompt_lines = 0 - monkeypatch.setattr(repl_mode, 'get_prompt', lambda mycli, string, render_counter: 'single') + monkeypatch.setattr(repl_mode, 'render_prompt_string', lambda mycli, string, render_counter: to_formatted_text('single')) assert to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=4)))) == 'single' assert cli.prompt_lines == 1 @@ -397,8 +397,9 @@ def cursor(self) -> PromptCursor: cli.login_path = 'prod' cli.login_path_as_host = True cli.dsn_alias = 'dsn' - prompt = repl_mode.get_prompt(cli, r'\h|\H|\A|\y|\Y|\T|\w|\W', 0) - assert prompt == 'prod|prod|dsn|(none)|(none)|(none)|(none)|' + prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\A|\y|\Y|\T|\w|\W', 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == 'prod|prod|dsn|(none)|(none)|(none)|(none)|' sqlexecute.conn = PromptConnection() cli.login_path_as_host = False @@ -406,8 +407,9 @@ def cursor(self) -> PromptCursor: monkeypatch.setattr(repl_mode, 'format_uptime', lambda uptime: f'uptime:{uptime}') monkeypatch.setattr(repl_mode, 'get_ssl_version', lambda cur: 'TLSv1.3') monkeypatch.setattr(repl_mode, 'get_warning_count', lambda cur: 7) - prompt = repl_mode.get_prompt(cli, r'\H|\y|\Y|\T|\w|\W', 1) - assert prompt == '127.0.0.1|123|uptime:123|TLSv1.3|7|7' + prompt = repl_mode.render_prompt_string(cli, r'\H|\y|\Y|\T|\w|\W', 1) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == '127.0.0.1|123|uptime:123|TLSv1.3|7|7' cli.prompt_session = None assert to_plain_text(repl_mode.get_custom_toolbar(cli, 'fmt')) == '' @@ -420,7 +422,7 @@ def cursor(self) -> PromptCursor: assert repl_mode.get_custom_toolbar(cli, 'fmt') == cli.last_custom_toolbar_message cli.prompt_session.app.current_buffer.text = '' - monkeypatch.setattr(repl_mode, 'get_prompt', lambda mycli, string, render_counter: f'title:{string}') + monkeypatch.setattr(repl_mode, 'render_prompt_string', lambda mycli, string, render_counter: f'title:{string}') assert 'title:fmt' in str(repl_mode.get_custom_toolbar(cli, 'fmt')) cli.terminal_tab_title_format = 'tab' @@ -481,14 +483,17 @@ def cursor(self) -> PromptCursor: monkeypatch.setattr(repl_mode, 'get_ssl_version', lambda cur: 'TLSv1.3') monkeypatch.setattr(repl_mode, 'get_warning_count', lambda cur: 7) - prompt = repl_mode.get_prompt(cli, r'\h|\H|\y|\Y', 0) - assert prompt == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|123|uptime:123' + prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\y|\Y', 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|123|uptime:123' - prompt = repl_mode.get_prompt(cli, r'\h|\H|\w|\W', 1) - assert prompt == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|7|7' + prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\w|\W', 1) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|7|7' - prompt = repl_mode.get_prompt(cli, r'\h|\H|\T', 2) - assert prompt == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|TLSv1.3' + prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\T', 2) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|TLSv1.3' monkeypatch.setattr(repl_mode.sys.stderr, 'isatty', lambda: True) monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError('unexpected print'))) @@ -530,6 +535,57 @@ def cursor(self) -> PromptCursor: repl_mode.set_external_multiplex_pane_title(cli) +def test_maybe_html_escape() -> None: + assert repl_mode.maybe_html_escape('plain', False) == 'plain' + assert repl_mode.maybe_html_escape('a&b<1>', True) == 'a&b<1>' + + +def test_render_prompt_string_html() -> None: + repl_mode.render_prompt_string.cache_clear() + + cli = make_repl_cli( + SimpleNamespace( + user='ab', + host='db.example.com', + dbname='nameprod', + port=3306, + socket=None, + server_info=SimpleNamespace(species=SimpleNamespace(name='MySQL')), + conn=None, + ) + ) + cli.dsn_alias = 'aliasone' + + html_prompt = repl_mode.render_prompt_string(cli, r'\\u@\d|\A\', 0) + assert to_plain_text(html_prompt) == 'ab@nameprod|aliasone' + + bad_html_prompt = repl_mode.render_prompt_string(cli, r'\\u', 1) + assert to_plain_text(bad_html_prompt) == '(cannot parse HTML prompt string)' + + ansi_prompt = repl_mode.render_prompt_string(cli, r'\x1b[31mred\x1b[0m', 2) + assert to_plain_text(ansi_prompt) == 'red' + + +def test_render_prompt_string_ansi() -> None: + repl_mode.render_prompt_string.cache_clear() + + cli = make_repl_cli( + SimpleNamespace( + user='ab', + host='db.example.com', + dbname='nameprod', + port=3306, + socket=None, + server_info=SimpleNamespace(species=SimpleNamespace(name='MySQL')), + conn=None, + ) + ) + cli.dsn_alias = 'aliasone' + + ansi_prompt = repl_mode.render_prompt_string(cli, r'\x1b[31mred\x1b[0m', 2) + assert to_plain_text(ansi_prompt) == 'red' + + def test_output_results_covers_watch_warning_timing_beep_and_interrupts(monkeypatch: pytest.MonkeyPatch) -> None: class FakeSQLExecute: def run(self, text: str) -> list[SQLResult]: