Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions pgtricks/pg_dump_splitsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,23 @@
MEMORY_UNITS = {"": 1, "k": KIBIBYTE, "m": MEBIBYTE, "g": GIBIBYTE}


def try_float(s1: str, s2: str) -> tuple[str, str] | tuple[float, float]:
def try_float(s1: str, s2: str) -> tuple[float, float]:
"""Convert two strings to floats. Return original ones on conversion error."""
if not s1 or not s2 or s1[0] not in '0123456789.-' or s2[0] not in '0123456789.-':
# optimization
return s1, s2
try:
return float(s1), float(s2)
except ValueError:
return s1, s2
raise ValueError
return float(s1), float(s2)


def linecomp(l1: str, l2: str) -> int:
p1 = l1.split('\t', 1)
p2 = l2.split('\t', 1)
# TODO: unquote cast after support for Python 3.8 is dropped
v1, v2 = cast("tuple[float, float]", try_float(p1[0], p2[0]))
result = (v1 > v2) - (v1 < v2)
# modifying a line to see whether Darker works:
try:
v1, v2 = try_float(p1[0], p2[0])
result = (v1 > v2) - (v1 < v2)
except ValueError:
result = (p1[0] > p2[0]) - (p1[0] < p2[0])
if not result and len(p1) == len(p2) == 2:
return linecomp(p1[1], p2[1])
return result
Expand Down
44 changes: 24 additions & 20 deletions pgtricks/tests/test_pg_dump_splitsort.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from functools import cmp_to_key
from textwrap import dedent

Expand Down Expand Up @@ -36,29 +37,32 @@ def test_sql_copy_regular_expression(test_input, expected):
@pytest.mark.parametrize(
's1, s2, expect',
[
('', '', ('', '')),
('foo', '', ('foo', '')),
('foo', 'bar', ('foo', 'bar')),
('0', '1', (0.0, 1.0)),
('0', 'one', ('0', 'one')),
('0.0', '0.0', (0.0, 0.0)),
('0.0', 'one point zero', ('0.0', 'one point zero')),
('0.', '1.', (0.0, 1.0)),
('0.', 'one', ('0.', 'one')),
('4.2', '0.42', (4.2, 0.42)),
('4.2', 'four point two', ('4.2', 'four point two')),
('-.42', '-0.042', (-0.42, -0.042)),
('-.42', 'minus something', ('-.42', 'minus something')),
(r'\N', r'\N', (r'\N', r'\N')),
('foo', r'\N', ('foo', r'\N')),
('-4.2', r'\N', ('-4.2', r'\N')),
("", "", ValueError),
("foo", "", ValueError),
("foo", "bar", ValueError),
("0", "1", (0.0, 1.0)),
("0", "one", ValueError),
("0.0", "0.0", (0.0, 0.0)),
("0.0", "one point zero", ValueError),
("0.", "1.", (0.0, 1.0)),
("0.", "one", ValueError),
("4.2", "0.42", (4.2, 0.42)),
("4.2", "four point two", ValueError),
("-.42", "-0.042", (-0.42, -0.042)),
("-.42", "minus something", ValueError),
(r"\N", r"\N", ValueError),
("foo", r"\N", ValueError),
("-4.2", r"\N", ValueError),
],
)
def test_try_float(s1, s2, expect):
result1, result2 = try_float(s1, s2)
assert type(result1) is type(expect[0])
assert type(result2) is type(expect[1])
assert (result1, result2) == expect
with pytest.raises(expect) if expect is ValueError else nullcontext():

result1, result2 = try_float(s1, s2)

assert type(result1) is type(expect[0])
assert type(result2) is type(expect[1])
assert (result1, result2) == expect


@pytest.mark.parametrize(
Expand Down