diff --git a/changelog.md b/changelog.md
index 3b093e74..ea5d094e 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,6 +1,11 @@
Upcoming (TBD)
==============
+Features
+---------
+* Allow styling prompts with HTML-like tags.
+
+
Documentation
---------
* Give example for ANSI prompt colors in `~/.myclirc`.
diff --git a/mycli/main.py b/mycli/main.py
index 1c0b5e4a..01ffa72a 100755
--- a/mycli/main.py
+++ b/mycli/main.py
@@ -70,7 +70,7 @@
from mycli.main_modes.execute import main_execute_from_cli
from mycli.main_modes.list_dsn import main_list_dsn
from mycli.main_modes.list_ssh_config import main_list_ssh_config
-from mycli.main_modes.repl import get_prompt, main_repl, set_all_external_titles
+from mycli.main_modes.repl import main_repl, render_prompt_string, set_all_external_titles
from mycli.packages import special
from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme
from mycli.packages.filepaths import dir_path_exists, guess_socket_location
@@ -268,8 +268,8 @@ def __init__(
self.min_completion_trigger = c["main"].as_int("min_completion_trigger")
# a hack, pending a better way to handle settings and state
repl_package.MIN_COMPLETION_TRIGGER = self.min_completion_trigger
- self.last_prompt_message = ANSI('')
- self.last_custom_toolbar_message = ANSI('')
+ self.last_prompt_message = to_formatted_text('')
+ self.last_custom_toolbar_message = to_formatted_text('')
# Register custom special commands.
self.register_special_commands()
@@ -907,8 +907,9 @@ def get_output_margin(self, status: str | None = None) -> int:
render_counter = self.prompt_session.app.render_counter
else:
render_counter = 0
- # todo: this jump back to get_prompt() in repl.py is a sign that separation is incomplete
- self.prompt_lines = get_prompt(self, self.prompt_format, render_counter).count('\n') + 1
+ # todo: this jump back to render_prompt_string() in repl.py is a sign that separation is incomplete
+ prompt_string = render_prompt_string(self, self.prompt_format, render_counter)
+ self.prompt_lines = to_plain_text(prompt_string).count('\n') + 1
margin = self.get_reserved_space() + self.prompt_lines
if special.is_timing_enabled():
margin += 1
diff --git a/mycli/main_modes/repl.py b/mycli/main_modes/repl.py
index da8f148a..10b8df9c 100644
--- a/mycli/main_modes/repl.py
+++ b/mycli/main_modes/repl.py
@@ -4,6 +4,7 @@
from datetime import datetime
import functools
from functools import partial
+import html
from importlib import resources
import os
import random
@@ -13,6 +14,7 @@
import time
import traceback
from typing import TYPE_CHECKING, Any, Generator
+from xml.parsers.expat import ExpatError
import click
import prompt_toolkit
@@ -23,6 +25,10 @@
from prompt_toolkit.filters import Condition, has_focus, is_done
from prompt_toolkit.formatted_text import (
ANSI,
+ HTML,
+ FormattedText,
+ to_formatted_text,
+ to_plain_text,
)
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor
@@ -162,7 +168,13 @@ def set_external_terminal_tab_title(mycli: 'MyCli') -> None:
return
if not sys.stderr.isatty():
return
- title = sanitize_terminal_title(get_prompt(mycli, mycli.terminal_tab_title_format, mycli.prompt_session.app.render_counter))
+ title = sanitize_terminal_title(
+ render_prompt_string(
+ mycli,
+ mycli.terminal_tab_title_format,
+ mycli.prompt_session.app.render_counter,
+ )
+ )
print(f'\x1b]1;{title}\a', file=sys.stderr, end='')
sys.stderr.flush()
@@ -174,7 +186,13 @@ def set_external_terminal_window_title(mycli: 'MyCli') -> None:
return
if not sys.stderr.isatty():
return
- title = sanitize_terminal_title(get_prompt(mycli, mycli.terminal_window_title_format, mycli.prompt_session.app.render_counter))
+ title = sanitize_terminal_title(
+ render_prompt_string(
+ mycli,
+ mycli.terminal_window_title_format,
+ mycli.prompt_session.app.render_counter,
+ )
+ )
print(f'\x1b]2;{title}\a', file=sys.stderr, end='')
sys.stderr.flush()
@@ -186,7 +204,13 @@ def set_external_multiplex_window_title(mycli: 'MyCli') -> None:
return
if not mycli.prompt_session:
return
- title = sanitize_terminal_title(get_prompt(mycli, mycli.multiplex_window_title_format, mycli.prompt_session.app.render_counter))
+ title = sanitize_terminal_title(
+ render_prompt_string(
+ mycli,
+ mycli.multiplex_window_title_format,
+ mycli.prompt_session.app.render_counter,
+ )
+ )
try:
subprocess.run(
['tmux', 'rename-window', title],
@@ -208,7 +232,13 @@ def set_external_multiplex_pane_title(mycli: 'MyCli') -> None:
return
if not sys.stderr.isatty():
return
- title = sanitize_terminal_title(get_prompt(mycli, mycli.multiplex_pane_title_format, mycli.prompt_session.app.render_counter))
+ title = sanitize_terminal_title(
+ render_prompt_string(
+ mycli,
+ mycli.multiplex_pane_title_format,
+ mycli.prompt_session.app.render_counter,
+ )
+ )
print(f'\x1b]2;{title}\x1b\\', file=sys.stderr, end='')
sys.stderr.flush()
@@ -216,25 +246,33 @@ def set_external_multiplex_pane_title(mycli: 'MyCli') -> None:
def get_custom_toolbar(
mycli: 'MyCli',
toolbar_format: str,
-) -> ANSI:
+) -> FormattedText:
if not mycli.prompt_session:
- return ANSI('')
+ return to_formatted_text('')
if not mycli.prompt_session.app:
- return ANSI('')
+ return to_formatted_text('')
if mycli.prompt_session.app.current_buffer.text:
return mycli.last_custom_toolbar_message
- toolbar = get_prompt(mycli, toolbar_format, mycli.prompt_session.app.render_counter)
- toolbar = toolbar.replace('\\x1b', '\x1b')
- mycli.last_custom_toolbar_message = ANSI(toolbar)
+ mycli.last_custom_toolbar_message = render_prompt_string(
+ mycli,
+ toolbar_format,
+ mycli.prompt_session.app.render_counter,
+ )
return mycli.last_custom_toolbar_message
+def maybe_html_escape(string: str, is_html: bool) -> str:
+ if is_html:
+ return html.escape(string, quote=False)
+ return string
+
+
@functools.lru_cache(maxsize=256)
-def get_prompt(
+def render_prompt_string(
mycli: 'MyCli',
string: str,
_render_counter: int,
-) -> str:
+) -> FormattedText:
sqlexecute = mycli.sqlexecute
assert sqlexecute is not None
if mycli.login_path and mycli.login_path_as_host:
@@ -247,79 +285,106 @@ def get_prompt(
if re.match(r'^[\d\.]+$', short_prompt_host):
short_prompt_host = prompt_host
now = datetime.now()
- backslash_placeholder = '\ufffc_backslash'
- string = string.replace('\\\\', backslash_placeholder)
- string = string.replace('\\u', sqlexecute.user or '(none)')
- string = string.replace('\\h', prompt_host or '(none)')
- string = string.replace('\\H', short_prompt_host or '(none)')
- string = string.replace('\\d', sqlexecute.dbname or '(none)')
species_name = sqlexecute.server_info.species.name if sqlexecute.server_info and sqlexecute.server_info.species else 'MySQL'
- string = string.replace('\\t', species_name)
- string = string.replace('\\n', '\n')
- string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
- string = string.replace('\\m', now.strftime('%M'))
- string = string.replace('\\P', now.strftime('%p'))
- string = string.replace('\\R', now.strftime('%H'))
- string = string.replace('\\r', now.strftime('%I'))
- string = string.replace('\\s', now.strftime('%S'))
- string = string.replace('\\p', str(sqlexecute.port))
- string = string.replace('\\j', os.path.basename(sqlexecute.socket or '(none)'))
- string = string.replace('\\J', sqlexecute.socket or '(none)')
- string = string.replace('\\k', os.path.basename(sqlexecute.socket or str(sqlexecute.port)))
- string = string.replace('\\K', sqlexecute.socket or str(sqlexecute.port))
- string = string.replace('\\A', mycli.dsn_alias or '(none)')
- string = string.replace('\\_', ' ')
- string = string.replace(backslash_placeholder, '\\')
-
+ strings = string.split('\\\\')
+ is_html = strings[0].startswith('\\')
+ strings = [x.replace('\\u', maybe_html_escape(sqlexecute.user or '(none)', is_html)) for x in strings]
+ strings = [x.replace('\\h', maybe_html_escape(prompt_host or '(none)', is_html)) for x in strings]
+ strings = [x.replace('\\H', maybe_html_escape(short_prompt_host or '(none)', is_html)) for x in strings]
+ strings = [x.replace('\\d', maybe_html_escape(sqlexecute.dbname or '(none)', is_html)) for x in strings]
+ strings = [x.replace('\\t', maybe_html_escape(species_name, is_html)) for x in strings]
+ strings = [x.replace('\\n', '\n') for x in strings]
+ strings = [x.replace('\\D', maybe_html_escape(now.strftime('%a %b %d %H:%M:%S %Y'), is_html)) for x in strings]
+ strings = [x.replace('\\m', maybe_html_escape(now.strftime('%M'), is_html)) for x in strings]
+ strings = [x.replace('\\P', maybe_html_escape(now.strftime('%p'), is_html)) for x in strings]
+ strings = [x.replace('\\R', maybe_html_escape(now.strftime('%H'), is_html)) for x in strings]
+ strings = [x.replace('\\r', maybe_html_escape(now.strftime('%I'), is_html)) for x in strings]
+ strings = [x.replace('\\s', maybe_html_escape(now.strftime('%S'), is_html)) for x in strings]
+ strings = [x.replace('\\p', maybe_html_escape(str(sqlexecute.port), is_html)) for x in strings]
+ strings = [
+ x.replace('\\j', maybe_html_escape(os.path.basename(sqlexecute.socket or '(none)').replace('\\', '/'), is_html)) for x in strings
+ ]
+ strings = [x.replace('\\J', maybe_html_escape((sqlexecute.socket or '(none)').replace('\\', '/'), is_html)) for x in strings]
+ strings = [
+ x.replace('\\k', maybe_html_escape(os.path.basename(sqlexecute.socket or str(sqlexecute.port)).replace('\\', '/'), is_html))
+ for x in strings
+ ]
+ strings = [
+ x.replace('\\K', maybe_html_escape((sqlexecute.socket or str(sqlexecute.port)).replace('\\', '/'), is_html)) for x in strings
+ ]
+ strings = [x.replace('\\A', maybe_html_escape(mycli.dsn_alias or '(none)', is_html)) for x in strings]
+ strings = [x.replace('\\_', ' ') for x in strings]
+
+ checker_string = ' '.join(strings)
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
- if '\\y' in string:
+ if '\\y' in checker_string:
with sqlexecute.conn.cursor() as cur:
- string = string.replace('\\y', str(get_uptime(cur)) or '(none)')
- if '\\Y' in string:
+ strings = [x.replace('\\y', maybe_html_escape(str(get_uptime(cur)) or '(none)', is_html)) for x in strings]
+ if '\\Y' in checker_string:
with sqlexecute.conn.cursor() as cur:
- string = string.replace('\\Y', format_uptime(str(get_uptime(cur))) or '(none)')
+ strings = [x.replace('\\Y', maybe_html_escape(format_uptime(str(get_uptime(cur))) or '(none)', is_html)) for x in strings]
else:
- string = string.replace('\\y', '(none)')
- string = string.replace('\\Y', '(none)')
+ strings = [x.replace('\\y', '(none)') for x in strings]
+ strings = [x.replace('\\Y', '(none)') for x in strings]
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
- if '\\T' in string:
+ if '\\T' in checker_string:
with sqlexecute.conn.cursor() as cur:
- string = string.replace('\\T', get_ssl_version(cur) or '(none)')
+ strings = [x.replace('\\T', maybe_html_escape(get_ssl_version(cur) or '(none)', is_html)) for x in strings]
else:
- string = string.replace('\\T', '(none)')
+ strings = [x.replace('\\T', '(none)') for x in strings]
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
- if '\\w' in string:
+ if '\\w' in checker_string:
with sqlexecute.conn.cursor() as cur:
- string = string.replace('\\w', str(get_warning_count(cur) or '(none)'))
+ strings = [x.replace('\\w', maybe_html_escape(str(get_warning_count(cur) or '(none)'), is_html)) for x in strings]
else:
- string = string.replace('\\w', '(none)')
+ strings = [x.replace('\\w', '(none)') for x in strings]
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
- if '\\W' in string:
+ if '\\W' in checker_string:
with sqlexecute.conn.cursor() as cur:
- string = string.replace('\\W', str(get_warning_count(cur) or ''))
+ strings = [x.replace('\\W', maybe_html_escape(str(get_warning_count(cur) or ''), is_html)) for x in strings]
else:
- string = string.replace('\\W', '')
+ strings = [x.replace('\\W', '') for x in strings]
- return string
+ if is_html:
+ strings[0] = strings[0].removeprefix('\\')
+ strings[-1] = strings[-1].removesuffix('\\')
+ elif '\\x1b' in checker_string:
+ strings = [x.replace('\\x1b', '\x1b') for x in strings]
+
+ strings = [re.sub(r'\\(.)', r'(unknown prompt format string: \\\1)', x) for x in strings]
+
+ string = '\\'.join(strings)
+
+ if is_html:
+ try:
+ formatted_string = to_formatted_text(HTML(string))
+ except (ExpatError, ValueError):
+ formatted_string = to_formatted_text(HTML('(cannot parse HTML prompt string)'))
+ else:
+ formatted_string = to_formatted_text(ANSI(string))
+
+ return formatted_string
def _get_prompt_message(
mycli: 'MyCli',
app: prompt_toolkit.application.application.Application,
-) -> ANSI:
+) -> FormattedText:
if app.current_buffer.text:
return mycli.last_prompt_message
- prompt = get_prompt(mycli, mycli.prompt_format, app.render_counter)
- if mycli.prompt_format == mycli.default_prompt and len(prompt) > mycli.max_len_prompt:
- prompt = get_prompt(mycli, mycli.default_prompt_splitln, app.render_counter)
- mycli.prompt_lines = prompt.count('\n') + 1
- prompt = prompt.replace('\\x1b', '\x1b')
+ prompt = render_prompt_string(mycli, mycli.prompt_format, app.render_counter)
+ prompt_plain = to_plain_text(prompt)
+ if mycli.prompt_format == mycli.default_prompt and len(prompt_plain) > mycli.max_len_prompt:
+ prompt = render_prompt_string(mycli, mycli.default_prompt_splitln, app.render_counter)
+ prompt_plain = to_plain_text(prompt)
+ mycli.prompt_lines = prompt_plain.count('\n') + 1
if not mycli.prompt_lines:
- mycli.prompt_lines = prompt.count('\n') + 1
- mycli.last_prompt_message = ANSI(prompt)
+ mycli.prompt_lines = prompt_plain.count('\n') + 1
+
+ mycli.last_prompt_message = prompt
return mycli.last_prompt_message
diff --git a/mycli/myclirc b/mycli/myclirc
index 61d027bb..3d9a156a 100644
--- a/mycli/myclirc
+++ b/mycli/myclirc
@@ -138,6 +138,12 @@ wider_completion_menu = False
# * \\ - a literal backslash
# * \x1b[...m - an ANSI escape sequence (can style with color or attributes)
# ANSI color example: prompt = '\x1b[31mroot\x1b[0m@localhost:\d> '
+# * \ - a leading sequence indicating that the rest of the prompt be styled like HTML.
+# See https://python-prompt-toolkit.readthedocs.io/en/stable/pages/printing_text.html#html .
+# Characters such as "&" or literal "<" and ">" must be HTML-escaped in this mode.
+# HTML styles cannot be combined with ANSI sequences. HTML mode takes precedence.
+# HTML color example: prompt = '\root@localhost:\d> '
+#
prompt = '\t \u@\h:\d> '
prompt_continuation = '->'
diff --git a/mycli/packages/string_utils.py b/mycli/packages/string_utils.py
index 89402ad5..56103330 100644
--- a/mycli/packages/string_utils.py
+++ b/mycli/packages/string_utils.py
@@ -1,10 +1,15 @@
import re
from cli_helpers.utils import strip_ansi
+from prompt_toolkit.formatted_text import (
+ FormattedText,
+ to_plain_text,
+)
-def sanitize_terminal_title(title: str) -> str:
- sanitized = strip_ansi(title)
+def sanitize_terminal_title(title: FormattedText) -> str:
+ sanitized = to_plain_text(title)
+ sanitized = strip_ansi(sanitized)
sanitized = sanitized.replace('\n', ' ')
sanitized = re.sub('[\x00-\x1f\x7f]', '', sanitized)
return sanitized
diff --git a/test/myclirc b/test/myclirc
index 15fb4547..f7d0ac1a 100644
--- a/test/myclirc
+++ b/test/myclirc
@@ -125,17 +125,23 @@ wider_completion_menu = False
# * \K - full connection socket path OR the port
# * \T - connection SSL/TLS version
# * \t - database vendor (Percona, MySQL, MariaDB, TiDB)
+# * \u - username
# * \w - number of warnings, or "(none)" (requires frequent trips to the server)
-# * \W - number of warnings, or the empty string (requires frequent trips to the server)
+# * \W - number of warnings, or the empty string (requires frequent trips to the server)
# * \y - uptime in seconds (requires frequent trips to the server)
# * \Y - uptime in words (requires frequent trips to the server)
-# * \u - username
# * \A - DSN alias
# * \n - a newline
# * \_ - a space
# * \\ - a literal backslash
# * \x1b[...m - an ANSI escape sequence (can style with color or attributes)
# ANSI color example: prompt = '\x1b[31mroot\x1b[0m@localhost:\d> '
+# * \ - a leading sequence indicating that the rest of the prompt be styled like HTML.
+# See https://python-prompt-toolkit.readthedocs.io/en/stable/pages/printing_text.html#html .
+# Characters such as "&" or literal "<" and ">" must be HTML-escaped in this mode
+# HTML styles cannot be combined with ANSI sequences. HTML mode takes precedence.
+# HTML color example: prompt = '\root@localhost:\d> '
+#
prompt = "\t \u@\h:\d> "
prompt_continuation = ->
diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py
index 295e6987..949a9cda 100644
--- a/test/pytests/test_main.py
+++ b/test/pytests/test_main.py
@@ -13,6 +13,11 @@
import click
from click.testing import CliRunner
+from prompt_toolkit.formatted_text import (
+ FormattedText,
+ to_formatted_text,
+ to_plain_text,
+)
import pymysql
from pymysql.err import OperationalError
import pytest
@@ -391,8 +396,9 @@ def test_prompt_no_host_only_socket(executor):
mycli.sqlexecute.user = DEFAULT_USER
mycli.sqlexecute.dbname = DEFAULT_DATABASE
mycli.sqlexecute.port = DEFAULT_PORT
- prompt = repl_mode.get_prompt(mycli, mycli.prompt_format, 0)
- assert prompt == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_DATABASE}> "
+ prompt = repl_mode.render_prompt_string(mycli, mycli.prompt_format, 0)
+ prompt_plain = to_plain_text(prompt)
+ assert prompt_plain == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_DATABASE}> "
@dbtest
@@ -406,8 +412,9 @@ def test_prompt_socket_overrides_port(executor):
mycli.sqlexecute.user = DEFAULT_USER
mycli.sqlexecute.dbname = DEFAULT_DATABASE
mycli.sqlexecute.port = DEFAULT_PORT
- prompt = repl_mode.get_prompt(mycli, mycli.prompt_format, 0)
- assert prompt == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:mysqld.sock {DEFAULT_DATABASE}> "
+ prompt = repl_mode.render_prompt_string(mycli, mycli.prompt_format, 0)
+ prompt_plain = to_plain_text(prompt)
+ assert prompt_plain == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:mysqld.sock {DEFAULT_DATABASE}> "
@dbtest
@@ -421,8 +428,9 @@ def test_prompt_socket_short_host(executor):
mycli.sqlexecute.user = DEFAULT_USER
mycli.sqlexecute.dbname = DEFAULT_DATABASE
mycli.sqlexecute.port = DEFAULT_PORT
- prompt = repl_mode.get_prompt(mycli, mycli.prompt_format, 0)
- assert prompt == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_PORT} {DEFAULT_DATABASE}> "
+ prompt = repl_mode.render_prompt_string(mycli, mycli.prompt_format, 0)
+ prompt_plain = to_plain_text(prompt)
+ assert prompt_plain == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_PORT} {DEFAULT_DATABASE}> "
@dbtest
@@ -2252,11 +2260,11 @@ def test_get_output_margin_uses_prompt_session_render_counter(monkeypatch: pytes
SimpleNamespace(app=SimpleNamespace(render_counter=7)),
)
- def fake_get_prompt(mycli: Any, string: str, render_counter: int) -> str:
+ def fake_render_prompt_string(mycli: Any, string: str, render_counter: int) -> FormattedText:
render_counters.append(render_counter)
- return 'line1\nline2'
+ return to_formatted_text('line1\nline2')
- monkeypatch.setattr(main, 'get_prompt', fake_get_prompt)
+ monkeypatch.setattr(main, 'render_prompt_string', fake_render_prompt_string)
monkeypatch.setattr(main.special, 'is_timing_enabled', lambda: False)
assert main.MyCli.get_output_margin(cli, 'ok') == 5
assert render_counters == [7]
diff --git a/test/pytests/test_main_modes_repl.py b/test/pytests/test_main_modes_repl.py
index 81f470a4..dd44fc2e 100644
--- a/test/pytests/test_main_modes_repl.py
+++ b/test/pytests/test_main_modes_repl.py
@@ -8,7 +8,7 @@
from types import SimpleNamespace
from typing import Any, Literal, cast
-from prompt_toolkit.formatted_text import to_plain_text
+from prompt_toolkit.formatted_text import to_formatted_text, to_plain_text
import pymysql
import pytest
@@ -335,8 +335,8 @@ def test_repl_show_startup_banner_and_prompt_helpers(monkeypatch: pytest.MonkeyP
monkeypatch.setattr(
repl_mode,
- 'get_prompt',
- lambda mycli, string, render_counter: '0123456' if string == cli.default_prompt else 'a\nb',
+ 'render_prompt_string',
+ lambda mycli, string, render_counter: to_formatted_text('0123456') if string == cli.default_prompt else 'a\nb',
)
cli.max_len_prompt = 5
prompt_text = to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=2))))
@@ -348,7 +348,7 @@ def test_repl_show_startup_banner_and_prompt_helpers(monkeypatch: pytest.MonkeyP
cli.prompt_format = 'custom'
cli.prompt_lines = 0
- monkeypatch.setattr(repl_mode, 'get_prompt', lambda mycli, string, render_counter: 'single')
+ monkeypatch.setattr(repl_mode, 'render_prompt_string', lambda mycli, string, render_counter: to_formatted_text('single'))
assert to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=4)))) == 'single'
assert cli.prompt_lines == 1
@@ -397,8 +397,9 @@ def cursor(self) -> PromptCursor:
cli.login_path = 'prod'
cli.login_path_as_host = True
cli.dsn_alias = 'dsn'
- prompt = repl_mode.get_prompt(cli, r'\h|\H|\A|\y|\Y|\T|\w|\W', 0)
- assert prompt == 'prod|prod|dsn|(none)|(none)|(none)|(none)|'
+ prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\A|\y|\Y|\T|\w|\W', 0)
+ prompt_plain = to_plain_text(prompt)
+ assert prompt_plain == 'prod|prod|dsn|(none)|(none)|(none)|(none)|'
sqlexecute.conn = PromptConnection()
cli.login_path_as_host = False
@@ -406,8 +407,9 @@ def cursor(self) -> PromptCursor:
monkeypatch.setattr(repl_mode, 'format_uptime', lambda uptime: f'uptime:{uptime}')
monkeypatch.setattr(repl_mode, 'get_ssl_version', lambda cur: 'TLSv1.3')
monkeypatch.setattr(repl_mode, 'get_warning_count', lambda cur: 7)
- prompt = repl_mode.get_prompt(cli, r'\H|\y|\Y|\T|\w|\W', 1)
- assert prompt == '127.0.0.1|123|uptime:123|TLSv1.3|7|7'
+ prompt = repl_mode.render_prompt_string(cli, r'\H|\y|\Y|\T|\w|\W', 1)
+ prompt_plain = to_plain_text(prompt)
+ assert prompt_plain == '127.0.0.1|123|uptime:123|TLSv1.3|7|7'
cli.prompt_session = None
assert to_plain_text(repl_mode.get_custom_toolbar(cli, 'fmt')) == ''
@@ -420,7 +422,7 @@ def cursor(self) -> PromptCursor:
assert repl_mode.get_custom_toolbar(cli, 'fmt') == cli.last_custom_toolbar_message
cli.prompt_session.app.current_buffer.text = ''
- monkeypatch.setattr(repl_mode, 'get_prompt', lambda mycli, string, render_counter: f'title:{string}')
+ monkeypatch.setattr(repl_mode, 'render_prompt_string', lambda mycli, string, render_counter: f'title:{string}')
assert 'title:fmt' in str(repl_mode.get_custom_toolbar(cli, 'fmt'))
cli.terminal_tab_title_format = 'tab'
@@ -481,14 +483,17 @@ def cursor(self) -> PromptCursor:
monkeypatch.setattr(repl_mode, 'get_ssl_version', lambda cur: 'TLSv1.3')
monkeypatch.setattr(repl_mode, 'get_warning_count', lambda cur: 7)
- prompt = repl_mode.get_prompt(cli, r'\h|\H|\y|\Y', 0)
- assert prompt == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|123|uptime:123'
+ prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\y|\Y', 0)
+ prompt_plain = to_plain_text(prompt)
+ assert prompt_plain == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|123|uptime:123'
- prompt = repl_mode.get_prompt(cli, r'\h|\H|\w|\W', 1)
- assert prompt == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|7|7'
+ prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\w|\W', 1)
+ prompt_plain = to_plain_text(prompt)
+ assert prompt_plain == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|7|7'
- prompt = repl_mode.get_prompt(cli, r'\h|\H|\T', 2)
- assert prompt == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|TLSv1.3'
+ prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\T', 2)
+ prompt_plain = to_plain_text(prompt)
+ assert prompt_plain == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|TLSv1.3'
monkeypatch.setattr(repl_mode.sys.stderr, 'isatty', lambda: True)
monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError('unexpected print')))
@@ -530,6 +535,57 @@ def cursor(self) -> PromptCursor:
repl_mode.set_external_multiplex_pane_title(cli)
+def test_maybe_html_escape() -> None:
+ assert repl_mode.maybe_html_escape('plain', False) == 'plain'
+ assert repl_mode.maybe_html_escape('a&b<1>', True) == 'a&b<1>'
+
+
+def test_render_prompt_string_html() -> None:
+ repl_mode.render_prompt_string.cache_clear()
+
+ cli = make_repl_cli(
+ SimpleNamespace(
+ user='ab',
+ host='db.example.com',
+ dbname='nameprod',
+ port=3306,
+ socket=None,
+ server_info=SimpleNamespace(species=SimpleNamespace(name='MySQL')),
+ conn=None,
+ )
+ )
+ cli.dsn_alias = 'aliasone'
+
+ html_prompt = repl_mode.render_prompt_string(cli, r'\\u@\d|\A\', 0)
+ assert to_plain_text(html_prompt) == 'ab@nameprod|aliasone'
+
+ bad_html_prompt = repl_mode.render_prompt_string(cli, r'\\u', 1)
+ assert to_plain_text(bad_html_prompt) == '(cannot parse HTML prompt string)'
+
+ ansi_prompt = repl_mode.render_prompt_string(cli, r'\x1b[31mred\x1b[0m', 2)
+ assert to_plain_text(ansi_prompt) == 'red'
+
+
+def test_render_prompt_string_ansi() -> None:
+ repl_mode.render_prompt_string.cache_clear()
+
+ cli = make_repl_cli(
+ SimpleNamespace(
+ user='ab',
+ host='db.example.com',
+ dbname='nameprod',
+ port=3306,
+ socket=None,
+ server_info=SimpleNamespace(species=SimpleNamespace(name='MySQL')),
+ conn=None,
+ )
+ )
+ cli.dsn_alias = 'aliasone'
+
+ ansi_prompt = repl_mode.render_prompt_string(cli, r'\x1b[31mred\x1b[0m', 2)
+ assert to_plain_text(ansi_prompt) == 'red'
+
+
def test_output_results_covers_watch_warning_timing_beep_and_interrupts(monkeypatch: pytest.MonkeyPatch) -> None:
class FakeSQLExecute:
def run(self, text: str) -> list[SQLResult]: