Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/beanahead/expired.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import timedelta
from pathlib import Path
import re
import sys

from beancount.core.data import Transaction

Expand Down Expand Up @@ -129,7 +130,8 @@ def _update_txn(txn: Transaction, path: Path) -> Transaction | None:
f"\n0 Move transaction forwards to tomorrow ({TOMORROW})."
f"\n1 Move transaction forwards to another date."
f"\n2 Remove transaction from ledger {path.stem}."
f"\n3 Leave transaction as is."
f"\n3 Leave transaction as is.",
file=sys.stderr,
)
response: str = utils.get_input("Choose one of the above options, [0-3]:")
while not utils.response_is_valid_number(response, 3):
Expand Down Expand Up @@ -247,7 +249,8 @@ def admin_expired_txns(ledgers: list[str]):
paths_string = "\n".join([str(path) for path in paths])
print(
"There are no expired transactions on any of the following"
f" ledgers:\n{paths_string}"
f" ledgers:\n{paths_string}",
file=sys.stderr,
)
return

Expand All @@ -256,7 +259,8 @@ def admin_expired_txns(ledgers: list[str]):
if not updated_paths:
print(
"\nYou have not choosen to modify any expired transactions."
"\nNo ledger has been altered."
"\nNo ledger has been altered.",
file=sys.stderr,
)
return

Expand All @@ -266,4 +270,6 @@ def admin_expired_txns(ledgers: list[str]):
updated_contents[path] = content

overwrite_ledgers(updated_contents)
print(f"\nThe following ledgers have been updated:\n{paths_string}")
print(
f"\nThe following ledgers have been updated:\n{paths_string}", file=sys.stderr
)
72 changes: 72 additions & 0 deletions src/beanahead/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from beancount.core.data import Transaction, Entries
from beancount.ingest import extract

from . import utils
from . import reconcile
from .errors import BeanaheadWriteError


class ReconcileExpected:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could a simple test be added to verify the hook is working as intended?

"""
Hook class for smart_importer to reconcile expected entries on the fly.

You also need to use the adapted duplicate hook to avoid false positives
using the new style import invocation:

...
hools = [ReconcileExpected.adapted_find_duplicate_entries]
beancount.ingest.scripts_utils.ingest(CONFIG, hooks=hooks)
"""

def __init__(self, x_txns_file):
path = utils.get_verified_path(x_txns_file)
utils.set_root_accounts_context(x_txns_file)
_ = utils.get_verified_ledger_file_key(path) # just verify that a ledger
self.expected_txns_path = path
self.expected_txns: list[Transaction] = utils.get_unverified_txns(path)

def __call__(self, importer, file, imported_entries, existing_entries) -> Entries:
"""Apply the hook and modify the imported entries.

Args:
importer: The importer that this hooks is being applied to.
file: The file that is being imported.
imported_entries: The current list of imported entries.
existing_entries: The existing entries, as passed to the extract
function.

Returns:
The updated imported entries.
"""
new_txns, new_other = reconcile.separate_out_txns(imported_entries)
reconciled_x_txns = reconcile.reconcile_x_txns(self.expected_txns, new_txns)

updated_new_txns = reconcile.update_new_txns(new_txns, reconciled_x_txns)
updated_entries = updated_new_txns + new_other

# Update expected transation file
x_txns_to_remove = []
for x_txn, _ in reconciled_x_txns:
if x_txn in self.expected_txns:
x_txns_to_remove.append(x_txn)

prev_contents = utils.get_content(self.expected_txns_path)
try:
utils.remove_txns_from_ledger(self.expected_txns_path, x_txns_to_remove)
except Exception as err:
utils.write(self.expected_txns_path, prev_contents)
raise BeanaheadWriteError(
self.expected_txns_path, [self.expected_txns_path]
) from err

return updated_entries

@staticmethod
def adapted_find_duplicate_entries(new_entries_list, existing_entries):
keep = []
# filter out expected transactions from duplicate detection
for entry in existing_entries:
if isinstance(entry, Transaction) and utils.TAGS_X & entry.tags:
continue
keep.append(entry)
return extract.find_duplicate_entries(new_entries_list, keep)
22 changes: 16 additions & 6 deletions src/beanahead/reconcile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from decimal import Decimal
from pathlib import Path
import re
import sys

import beancount
from beancount import loader
Expand Down Expand Up @@ -147,7 +148,9 @@ def get_pattern(x_txn: Transaction) -> re.Pattern:
def get_payee_matches(txns: list[Transaction], x_txn: Transaction) -> list[Transaction]:
"""Return transactions matching an Expected Transaction's payee."""
pattern = get_pattern(x_txn)
return [txn for txn in txns if pattern.search(txn.payee) is not None]
return [
txn for txn in txns if (txn.payee and pattern.search(txn.payee) is not None)
]


def get_common_accounts(a: Transaction, b: Transaction) -> set[str]:
Expand Down Expand Up @@ -362,7 +365,8 @@ def confirm_single(
print(
f"{utils.SEPARATOR_LINE}Expected Transaction:\n"
f"{utils.compose_entries_content(x_txn)}\n"
f"Incoming Transaction:\n{utils.compose_entries_content(matches[0])}"
f"Incoming Transaction:\n{utils.compose_entries_content(matches[0])}",
file=sys.stderr,
)
response = utils.get_input(MSG_SINGLE_MATCH).lower()
while response not in ["n", "y"]:
Expand Down Expand Up @@ -392,10 +396,11 @@ def get_mult_match(
print(
f"{utils.SEPARATOR_LINE}Expected Transaction:\n"
f"{utils.compose_entries_content(x_txn)}\n\n"
f"Incoming Transactions:\n"
f"Incoming Transactions:\n",
file=sys.stderr,
)
for i, match in enumerate(matches):
print(f"{i}\n{utils.compose_entries_content(match)}")
print(f"{i}\n{utils.compose_entries_content(match)}", file=sys.stderr)

max_value = len(matches) - 1
options = f"[0-{max_value}]/n"
Expand Down Expand Up @@ -521,7 +526,12 @@ def update_new_txn(new_txn: Transaction, x_txn: Transaction) -> Transaction:
new_txn_posting = get_posting_to_account(new_txn, account)

# carry over any meta not otherwise defined on new_txn
updated_posting = new_txn_posting._replace(meta=new_txn_posting.meta.copy())
if new_txn_posting.meta:
updated_posting = new_txn_posting._replace(
meta=new_txn_posting.meta.copy()
)
else:
updated_posting = new_txn_posting._replace(meta={})
for k, v in posting.meta.items():
updated_posting.meta.setdefault(k, v)

Expand Down Expand Up @@ -729,4 +739,4 @@ def reconcile_new_txns(
for path, txns in x_txns_to_remove.items():
msg += f"\n{len(txns)} transactions have been removed from ledger {path}."

print(msg)
print(msg, file=sys.stderr)
14 changes: 10 additions & 4 deletions src/beanahead/rx_txns.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools
from pathlib import Path
import re
import sys

import pandas as pd
from beancount import loader
Expand Down Expand Up @@ -209,9 +210,9 @@ def get_definition_group(definition: Transaction) -> GrouperKey:
if account == bal_sheet_account:
continue
account_type = get_account_type(account)
if account_type == "Assets":
if account_type == utils.RootAccountsContext.get("name_assets", "Assets"):
other_sides.add("Assets")
elif account_type == "Income":
elif account_type == utils.RootAccountsContext.get("name_income", "Income"):
other_sides.add("Income")
else:
other_sides.add("Expenses")
Expand Down Expand Up @@ -651,12 +652,16 @@ def add_txns(self, end: str | pd.Timestamp = END_DFLT):

new_txns, new_defs = self._get_new_txns_data(end)
if not new_txns:
print(f"There are no new Regular Expected Transactions to add with {end=}.")
print(
f"There are no new Regular Expected Transactions to add with {end=}.",
file=sys.stderr,
)
return

ledger_txns = self.rx_txns + new_txns

# ensure all new content checks out before writting anything
utils.set_root_accounts_context(self.path_ledger_main)
content_ledger = compose_new_content("rx", ledger_txns)
content_defs = compose_new_content("rx_def", new_defs)

Expand All @@ -667,5 +672,6 @@ def add_txns(self, end: str | pd.Timestamp = END_DFLT):
print(
f"{len(new_txns)} transactions have been added to the ledger"
f" '{self.path_ledger.stem}'.\nDefinitions on '{self.path_defs.stem}' have"
f" been updated to reflect the most recent transactions."
f" been updated to reflect the most recent transactions.",
file=sys.stderr,
)
7 changes: 7 additions & 0 deletions src/beanahead/scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

def make_file(args: argparse.Namespace):
"""Pass through command line args to make a new beanahead file."""
if args.main_ledger:
utils.set_root_accounts_context(args.main_ledger)
utils.create_beanahead_file(args.key, args.dirpath, args.filename)


Expand Down Expand Up @@ -89,6 +91,11 @@ def main():
choices=["x", "rx", "rx_def"],
metavar="key",
)
parser_make.add_argument(
*["-l", "--main-ledger"],
help="Path to the main ledger file to read its options.",
metavar="main_ledger",
)
parser_make.add_argument(
*["-d", "--dirpath"],
help=(
Expand Down
54 changes: 47 additions & 7 deletions src/beanahead/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
from pathlib import Path
import re
import sys

from beancount import loader
from beancount.core import data
Expand All @@ -29,6 +30,14 @@
TAG_RX = "rx_txn"
TAGS_X = set([TAG_X, TAG_RX])

NAME_OPTIONS = {
"name_assets": "Assets",
"name_liabilities": "Liabilities",
"name_income": "Income",
"name_expenses": "Expenses",
"name_equity": "Equity",
}

RX_META_DFLTS = {
"final": None,
"roll": True,
Expand Down Expand Up @@ -67,6 +76,26 @@

LEDGER_FILE_KEYS = ["x", "rx"]

RootAccountsContext = {} # global context


def set_root_accounts_context(path_ledger: str) -> dict[str]:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind adding a simple test to verify that all's working as expected when users are using non-default category names?

"""Set the root accounts context from the ledger file path.

Returns
-------
dict[str]
The name options set in the ledger.
"""
name_options: dict[str] = {}
options = get_options(path_ledger)
for opt, dflt in NAME_OPTIONS.items():
if options[opt] != dflt:
name_options[opt] = options[opt]
global RootAccountsContext
RootAccountsContext = name_options
return name_options


def validate_file_key(file_key: str):
"""Validate a file_key.
Expand Down Expand Up @@ -114,8 +143,15 @@ def compose_header_footer(file_key: str) -> tuple[str, str]:
"""
config = FILE_CONFIG[file_key]
plugin, tag, comment = config["plugin"], config["tag"], config["comment"]
extra_headers = ""
for k, v in RootAccountsContext.items():
extra_headers += f'option "{k}" "{v}"\n'

header = f"""option "title" "{config['title']}"\n"""
if extra_headers:
header += "\n"
header += extra_headers
header += "\n"
if plugin is not None:
header += f'plugin "{plugin}"\n'
header += f"pushtag #{tag}\n"
Expand Down Expand Up @@ -518,7 +554,7 @@ def reverse_automatic_balancing(txn: Transaction) -> Transaction:
"""
new_postings = []
for posting in txn.postings:
if AUTOMATIC_META in posting.meta:
if AUTOMATIC_META in (posting.meta or {}):
meta = {k: v for k, v in posting.meta.items() if k != AUTOMATIC_META}
posting = posting._replace(units=None, meta=meta)
new_postings.append(posting)
Expand All @@ -537,10 +573,7 @@ def is_assets_account(string: str) -> bool:
>>> is_assets_account("Assets:US:BofA:Checking")
True
"""
return is_account_type("Assets", string)


BAL_SHEET_ACCS = ["Assets", "Liabilities"]
return is_account_type(RootAccountsContext.get("name_assets", "Assets"), string)


def is_balance_sheet_account(string: str) -> bool:
Expand All @@ -566,7 +599,13 @@ def is_balance_sheet_account(string: str) -> bool:
>>> is_balance_sheet_account("Income:US:BayBook:Match401k")
False
"""
return any(is_account_type(acc_type, string) for acc_type in BAL_SHEET_ACCS)
return any(
is_account_type(acc_type, string)
for acc_type in [
RootAccountsContext.get("name_assets", "Assets"),
RootAccountsContext.get("name_liabilities", "Liabilities"),
]
)


def get_balance_sheet_accounts(txn: Transaction) -> list[str]:
Expand Down Expand Up @@ -845,7 +884,8 @@ def get_input(text: str) -> str:
-----
Function included to facilitate mocking user input when testing.
"""
return input(text)
print(text, file=sys.stderr, end=": ")
return input()


def response_is_valid_number(response: str, max_value: int) -> bool:
Expand Down
12 changes: 6 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def get_expected_output(string: str):
return textwrap.dedent(string)[1:]


def also_get_stdout(f: abc.Callable, *args, **kwargs) -> tuple[typing.Any, str]:
"""Return a function's return together with output to stdout."""
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
def also_get_stderr(f: abc.Callable, *args, **kwargs) -> tuple[typing.Any, str]:
"""Return a function's return together with output to stderr."""
stderr = io.StringIO()
with contextlib.redirect_stderr(stderr):
rtrn = f(*args, **kwargs)
return rtrn, stdout.getvalue()
return rtrn, stderr.getvalue()


@pytest.fixture
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(self, responses: abc.Generator[str]):
monkeypatch.setattr("beanahead.utils.get_input", self.input)

def input(self, string: str) -> str:
print(string)
print(string, file=sys.stderr)
return next(self.responses)

yield MockInput
Expand Down
Loading