From: Dustin Walde Date: Thu, 9 Nov 2023 01:23:48 +0000 (-0800) Subject: Add balance plugin X-Git-Url: https://git.walde.dev/?a=commitdiff_plain;h=9eaddd45c201fbd5bea06bd503c207e66f3ba170;p=beanbeanbean Add balance plugin - Mark an account to balance based off of tag or link name --- diff --git a/src/beanbeanbean/balance.py b/src/beanbeanbean/balance.py new file mode 100644 index 0000000..2eb5dff --- /dev/null +++ b/src/beanbeanbean/balance.py @@ -0,0 +1,118 @@ +from beancount.core import amount +from beancount.core.data import ( + Amount, + Entries, + Transaction, +) +from .utils import flag, make_error + +from typing import Dict, List, Set, Tuple, Union + +__plugins__ = ('balance',) + +BALANCE_KEY = 'balance' + + +def is_balance_entry(entry) -> bool: + if type(entry) is not Transaction: + return False + + if BALANCE_KEY in entry.meta: + return True + for posting in entry.postings: + if BALANCE_KEY in posting.meta: + return True + + return False + + +# default to link or tag? +def balance(entries: Entries, options_map, config_string=""): + del options_map, config_string # unused + errors = [] + + # pass 1, collect link/account pairs + link_accounts = {} + tag_accounts = {} + for i, entry in enumerate(entries): + if not is_balance_entry(entry): + continue + + default_match = None + if len(entry.links) == 1: + for link in entry.links: + default_match = (link_accounts, link) + elif len(entry.tags) == 1: + for tag in entry.tags: + default_match = (tag_accounts, tag) + + if default_match is not None and BALANCE_KEY in entry.meta: + account = entry.meta[BALANCE_KEY] + found = False + for posting in entry.postings: + if posting.account == account: + found = True + break + if not found: + entries[i] = flag(entry) + errors.append(make_error(entry, "TXN balancing is missing account")) + else: + if default_match[1] not in default_match[0]: + default_match[0][default_match[1]] = set() + default_match[0][default_match[1]].add(account) + + for posting in entry.postings: + if BALANCE_KEY in posting.meta: + bal_val = posting.meta[BALANCE_KEY] + match = default_match + if bal_val is not None: + if bal_val in entry.tags: + match = (tag_accounts, bal_val) + elif bal_val in entry.links: + match = (link_accounts, bal_val) + if match is None: + entries[i] = flag(entry) + errors.append(make_error(entry, "Posting balancing match is missing or ambiguous")) + continue + if match[1] not in match[0]: + match[0][match[1]] = set() + match[0][match[1]].add(posting.account) + + # pass 2, match and balance transactions + counts = {} + for i, entry in enumerate(entries): + if type(entry) is not Transaction: + continue + _handle_counts_from_set(counts, entry, i, entry.tags, tag_accounts, "tag") + _handle_counts_from_set(counts, entry, i, entry.links, link_accounts, "link") + + # flag any unbalanced transactions + for id, vals in counts.items(): + total, entry_idxs = vals + if total.number != 0.0: + errors.append(make_error(entries[entry_idxs[0]], "Unbalanced check for " + " ".join(id))) + for idx in entry_idxs: + entries[idx] = flag(entries[idx]) + + return entries, errors + + +def _handle_counts_from_set(counts: Dict[Tuple[str,str,str],List[Union[Amount,List[int]]]], + txn: Transaction, + txn_idx: int, + labels: Set[str], + account_map: Dict[str, List[str]], + id: str) -> None: + for label in labels: + match_accounts = account_map.get(label) + if match_accounts is None: + continue + for posting in txn.postings: + if posting.account in match_accounts: + match_id = (id, label, posting.account) + if match_id in counts: + counts[match_id][0] = amount.add(counts[match_id][0], posting.units) + counts[match_id][1].append(txn_idx) + else: + counts[match_id] = [posting.units, [txn_idx]] + diff --git a/src/beanbeanbean/utils.py b/src/beanbeanbean/utils.py index 442f432..9262d2c 100644 --- a/src/beanbeanbean/utils.py +++ b/src/beanbeanbean/utils.py @@ -1,4 +1,5 @@ -from beancount.core.data import Transaction +from beancount.core.data import new_metadata, Transaction +from beancount.loader import LoadError def flag(entry): @@ -6,3 +7,10 @@ def flag(entry): tdict["flag"] = "!" return Transaction(**tdict) + +def make_error(entry, message) -> LoadError: + return LoadError( + entry=entry, + message=message, + source=new_metadata(entry.meta["filename"], entry.meta["lineno"])) +