Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pylib/cqlshlib/copyutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,7 +1753,8 @@ def format_value(self, val, cqltype):
encoding=self.encoding, colormap=NO_COLOR_MAP, date_time_format=self.date_time_format,
float_precision=cqltype.precision, nullval=self.nullval, quote=False,
decimal_sep=self.decimal_sep, thousands_sep=self.thousands_sep,
boolean_styles=self.boolean_styles)
boolean_styles=self.boolean_styles,
escape_control_chars=False)
return formatted

def close(self):
Expand Down
50 changes: 32 additions & 18 deletions pylib/cqlshlib/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def _turn_bits_red(match):

def format_by_type(val, cqltype, encoding, colormap=None, addcolor=False,
nullval=None, date_time_format=None, float_precision=None,
decimal_sep=None, thousands_sep=None, boolean_styles=None):
decimal_sep=None, thousands_sep=None, boolean_styles=None,
escape_control_chars=True):
if nullval is None:
nullval = default_null_placeholder
if val is None:
Expand All @@ -77,7 +78,7 @@ def format_by_type(val, cqltype, encoding, colormap=None, addcolor=False,
return format_value(val, cqltype=cqltype, encoding=encoding, colormap=colormap,
date_time_format=date_time_format, float_precision=float_precision,
nullval=nullval, decimal_sep=decimal_sep, thousands_sep=thousands_sep,
boolean_styles=boolean_styles)
boolean_styles=boolean_styles, escape_control_chars=escape_control_chars)


def color_text(bval, colormap, displaywidth=None):
Expand Down Expand Up @@ -477,11 +478,16 @@ def decode_zig_zag_64(n):


@formatter_for('str')
def format_value_text(val, encoding, colormap, quote=False, **_):
escapedval = val.replace('\\', '\\\\')
def format_value_text(val, encoding, colormap, quote=False, escape_control_chars=True, **_):
if escape_control_chars:
escapedval = val.replace('\\', '\\\\')
else:
escapedval = val

if quote:
escapedval = escapedval.replace("'", "''")
escapedval = UNICODE_CONTROLCHARS_RE.sub(_show_control_chars, escapedval)
if escape_control_chars:
escapedval = UNICODE_CONTROLCHARS_RE.sub(_show_control_chars, escapedval)
bval = escapedval
if quote:
bval = "'{}'".format(bval)
Expand All @@ -496,11 +502,13 @@ def format_value_text(val, encoding, colormap, quote=False, **_):

def format_simple_collection(val, cqltype, lbracket, rbracket, encoding,
colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles):
decimal_sep, thousands_sep, boolean_styles,
escape_control_chars=True):
subs = [format_value(sval, cqltype=stype, encoding=encoding, colormap=colormap,
date_time_format=date_time_format, float_precision=float_precision,
nullval=nullval, quote=True, decimal_sep=decimal_sep,
thousands_sep=thousands_sep, boolean_styles=boolean_styles)
thousands_sep=thousands_sep, boolean_styles=boolean_styles,
escape_control_chars=escape_control_chars)
for sval, stype in zip(val, cqltype.get_n_sub_types(len(val)))]
bval = lbracket + ', '.join(get_str(sval) for sval in subs) + rbracket
if colormap is NO_COLOR_MAP:
Expand All @@ -515,26 +523,29 @@ def format_simple_collection(val, cqltype, lbracket, rbracket, encoding,

@formatter_for('list')
def format_value_list(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_control_chars=True, **_):
return format_simple_collection(val, cqltype, '[', ']', encoding, colormap,
date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles)
decimal_sep, thousands_sep, boolean_styles,
escape_control_chars=escape_control_chars)


@formatter_for('tuple')
def format_value_tuple(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_control_chars=True, **_):
return format_simple_collection(val, cqltype, '(', ')', encoding, colormap,
date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles)
decimal_sep, thousands_sep, boolean_styles,
escape_control_chars=escape_control_chars)


@formatter_for('set')
def format_value_set(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_control_chars=True, **_):
return format_simple_collection(val, cqltype, '{', '}', encoding, colormap,
date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles)
decimal_sep, thousands_sep, boolean_styles,
escape_control_chars=escape_control_chars)


formatter_for('frozenset')(format_value_set)
Expand All @@ -544,12 +555,13 @@ def format_value_set(val, cqltype, encoding, colormap, date_time_format, float_p

@formatter_for('dict')
def format_value_map(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_control_chars=True, **_):
def subformat(v, t):
return format_value(v, cqltype=t, encoding=encoding, colormap=colormap,
date_time_format=date_time_format, float_precision=float_precision,
nullval=nullval, quote=True, decimal_sep=decimal_sep,
thousands_sep=thousands_sep, boolean_styles=boolean_styles)
thousands_sep=thousands_sep, boolean_styles=boolean_styles,
escape_control_chars=escape_control_chars)

subs = [(subformat(k, cqltype.sub_types[0]), subformat(v, cqltype.sub_types[1])) for (k, v) in sorted(val.items())]
bval = '{' + ', '.join(get_str(k) + ': ' + get_str(v) for (k, v) in subs) + '}'
Expand All @@ -572,17 +584,19 @@ def subformat(v, t):


def format_value_utype(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_control_chars=True, **_):
def format_field_value(v, t):
if v is None:
return colorme(nullval, colormap, 'error')
return format_value(v, cqltype=t, encoding=encoding, colormap=colormap,
date_time_format=date_time_format, float_precision=float_precision,
nullval=nullval, quote=True, decimal_sep=decimal_sep,
thousands_sep=thousands_sep, boolean_styles=boolean_styles)
thousands_sep=thousands_sep, boolean_styles=boolean_styles,
escape_control_chars=escape_control_chars)

def format_field_name(name):
return format_value_text(name, encoding=encoding, colormap=colormap, quote=False)
return format_value_text(name, encoding=encoding, colormap=colormap, quote=False,
escape_control_chars=escape_control_chars)

subs = [(format_field_name(k), format_field_value(v, t)) for ((k, v), t) in zip(list(val._asdict().items()),
cqltype.sub_types)]
Expand Down
194 changes: 194 additions & 0 deletions pylib/cqlshlib/test/test_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from collections import OrderedDict

from cqlshlib.displaying import NO_COLOR_MAP
from cqlshlib.formatting import (
format_value_text,
format_value_list,
format_value_set,
format_value_tuple,
format_value_map,
format_value_utype,
CqlType
)


class _MockUDT:
""" Mimics the driver's UDT shape (exposes _asdict()) without the
identifier restrictions Python's namedtuple imposes on field names. """
def __init__(self, items):
self._items = items

def _asdict(self):
return OrderedDict(self._items)


class TestFormatting(unittest.TestCase):

def setUp(self):
self.fmt_kwargs = {
'encoding': 'utf-8',
'colormap': NO_COLOR_MAP,
'date_time_format': None,
'float_precision': 3,
'nullval': 'null',
'decimal_sep': '.',
'thousands_sep': ',',
'boolean_styles': None
}

def test_format_value_text_control_chars(self):
"""
Test that control chars AND literal backslashes are escaped for terminal
display (default), but BOTH are preserved verbatim when
escape_control_chars=False is passed (for CSV export).
"""
self.assertEqual(
format_value_text("Hello World", encoding='utf-8', colormap=NO_COLOR_MAP),
"Hello World"
)

test_string = "C:\\Users\\alice\nHello\x00"

terminal_output = format_value_text(test_string, encoding='utf-8', colormap=NO_COLOR_MAP)
self.assertEqual(terminal_output, "C:\\\\Users\\\\alice\\nHello\\x00")

csv_output = format_value_text(test_string, encoding='utf-8', colormap=NO_COLOR_MAP,
escape_control_chars=False)
self.assertEqual(csv_output, test_string)

self.assertIn('C:\\Users', csv_output)
self.assertNotIn('C:\\\\Users', csv_output)

def test_format_value_list_control_chars(self):
""" Test control character propagation in lists """
list_val = ["line1\nline2", "null\x00byte"]
cql_type = CqlType('list<text>')

terminal_output = format_value_list(list_val, cqltype=cql_type, **self.fmt_kwargs)
self.assertEqual(terminal_output, "['line1\\nline2', 'null\\x00byte']")

csv_output = format_value_list(list_val, cqltype=cql_type, escape_control_chars=False, **self.fmt_kwargs)
self.assertEqual(csv_output, "['line1\nline2', 'null\x00byte']")

def test_format_value_map_control_chars(self):
""" Test control character propagation in map keys and values """
map_val = {"key\n1": "val\x001"}
cql_type = CqlType('map<text, text>')

terminal_output = format_value_map(map_val, cqltype=cql_type, **self.fmt_kwargs)
self.assertEqual(terminal_output, "{'key\\n1': 'val\\x001'}")

csv_output = format_value_map(map_val, cqltype=cql_type, escape_control_chars=False, **self.fmt_kwargs)
self.assertEqual(csv_output, "{'key\n1': 'val\x001'}")

def test_udt_field_name_and_value_control_chars(self):
""" Test control character propagation in UDT field names and values """
# The driver exposes UDT instances via an _asdict() shape; namedtuple
# cannot be used here because UDT field names may contain characters
# (e.g. '\n') that are not valid Python identifiers.
udt_val = _MockUDT([('field_a\n', 'val\n1'), ('field_b', 'val\x002')])

cql_type = CqlType('text')
cql_type.sub_types = [CqlType('text'), CqlType('text')]

terminal_output = format_value_utype(udt_val, cqltype=cql_type, **self.fmt_kwargs)
self.assertEqual(terminal_output, "{field_a\\n: 'val\\n1', field_b: 'val\\x002'}")

csv_output = format_value_utype(udt_val, cqltype=cql_type, escape_control_chars=False, **self.fmt_kwargs)
self.assertEqual(csv_output, "{field_a\n: 'val\n1', field_b: 'val\x002'}")

def test_format_value_text_empty_string(self):
""" Empty strings pass through cleanly in both modes (no spurious
characters introduced by the regex sub or the escape pipeline). """
self.assertEqual(
format_value_text("", encoding='utf-8', colormap=NO_COLOR_MAP),
""
)
self.assertEqual(
format_value_text("", encoding='utf-8', colormap=NO_COLOR_MAP, escape_control_chars=False),
""
)

def test_format_value_text_latin1_and_del_control_chars(self):
""" UNICODE_CONTROLCHARS_RE matches [\\x00-\\x1f\\x7f-\\xa0]: in addition
to the common C0 controls, DEL (\\x7f), C1 controls (e.g. \\x80) and
NBSP (\\xa0) must also be escaped on terminals and preserved for CSV. """
test_string = "del\x7fmid\x80end\xa0nbsp"

terminal_output = format_value_text(test_string, encoding='utf-8', colormap=NO_COLOR_MAP)
self.assertEqual(terminal_output, "del\\x7fmid\\x80end\\xa0nbsp")

csv_output = format_value_text(test_string, encoding='utf-8', colormap=NO_COLOR_MAP,
escape_control_chars=False)
self.assertEqual(csv_output, test_string)

def test_format_value_text_consecutive_control_chars(self):
""" A run of adjacent control chars must be escaped/preserved
character-by-character, not collapsed. """
test_string = "a\n\n\x00\x00b"

terminal_output = format_value_text(test_string, encoding='utf-8', colormap=NO_COLOR_MAP)
self.assertEqual(terminal_output, "a\\n\\n\\x00\\x00b")

csv_output = format_value_text(test_string, encoding='utf-8', colormap=NO_COLOR_MAP,
escape_control_chars=False)
self.assertEqual(csv_output, test_string)

def test_format_value_tuple_control_chars(self):
""" format_value_tuple delegates to format_simple_collection; verify
the flag propagates to its element formatters. """
tuple_val = ("a\n", "b\x00")
cql_type = CqlType('tuple<text, text>')

terminal_output = format_value_tuple(tuple_val, cqltype=cql_type, **self.fmt_kwargs)
self.assertEqual(terminal_output, "('a\\n', 'b\\x00')")

csv_output = format_value_tuple(tuple_val, cqltype=cql_type, escape_control_chars=False,
**self.fmt_kwargs)
self.assertEqual(csv_output, "('a\n', 'b\x00')")

def test_format_value_set_control_chars(self):
""" format_value_set delegates to format_simple_collection. A list is
passed here because format_simple_collection just iterates val and
CPython set iteration order depends on PYTHONHASHSEED. """
set_val = ["a\n", "b\x00"]
cql_type = CqlType('set<text>')

terminal_output = format_value_set(set_val, cqltype=cql_type, **self.fmt_kwargs)
self.assertEqual(terminal_output, "{'a\\n', 'b\\x00'}")

csv_output = format_value_set(set_val, cqltype=cql_type, escape_control_chars=False,
**self.fmt_kwargs)
self.assertEqual(csv_output, "{'a\n', 'b\x00'}")

def test_nested_map_of_list_control_chars(self):
""" Two-level nesting (map<text, list<text>>): the flag must propagate
through the outer map's subformat() into the inner list's element
formatters as well. Guards against regressions where the flag is
forwarded at one level but dropped at the next. """
nested_val = {"key\n1": ["v\x001", "v\n2"]}
cql_type = CqlType('map<text, list<text>>')

terminal_output = format_value_map(nested_val, cqltype=cql_type, **self.fmt_kwargs)
self.assertEqual(terminal_output, "{'key\\n1': ['v\\x001', 'v\\n2']}")

csv_output = format_value_map(nested_val, cqltype=cql_type, escape_control_chars=False,
**self.fmt_kwargs)
self.assertEqual(csv_output, "{'key\n1': ['v\x001', 'v\n2']}")