]> git.walde.dev - beanbeanbean/commitdiff
Add balance plugin
authorDustin Walde <redacted>
Thu, 9 Nov 2023 01:23:48 +0000 (17:23 -0800)
committerDustin Walde <redacted>
Thu, 9 Nov 2023 01:23:48 +0000 (17:23 -0800)
- Mark an account to balance based off of tag or link name

src/beanbeanbean/balance.py [new file with mode: 0644]
src/beanbeanbean/utils.py

diff --git a/src/beanbeanbean/balance.py b/src/beanbeanbean/balance.py
new file mode 100644 (file)
index 0000000..2eb5dff
--- /dev/null
@@ -0,0 +1,118 @@
+from beancount.core import amount
+from beancount.core.data import (
+    Amount,
+    Entries,
+    Transaction,
+)
+from .utils import flag, make_error
+
+from typing import Dict, List, Set, Tuple, Union
+
+__plugins__ = ('balance',)
+
+BALANCE_KEY = 'balance'
+
+
+def is_balance_entry(entry) -> bool:
+    if type(entry) is not Transaction:
+        return False
+
+    if BALANCE_KEY in entry.meta:
+        return True
+    for posting in entry.postings:
+        if BALANCE_KEY in posting.meta:
+            return True
+
+    return False
+
+
+# default to link or tag?
+def balance(entries: Entries, options_map, config_string=""):
+    del options_map, config_string # unused
+    errors = []
+
+    # pass 1, collect link/account pairs
+    link_accounts = {}
+    tag_accounts = {}
+    for i, entry in enumerate(entries):
+        if not is_balance_entry(entry):
+            continue
+
+        default_match = None
+        if len(entry.links) == 1:
+            for link in entry.links:
+                default_match = (link_accounts, link)
+        elif len(entry.tags) == 1:
+            for tag in entry.tags:
+                default_match = (tag_accounts, tag)
+        
+        if default_match is not None and BALANCE_KEY in entry.meta:
+            account = entry.meta[BALANCE_KEY]
+            found = False
+            for posting in entry.postings:
+                if posting.account == account:
+                    found = True
+                    break
+            if not found:
+                entries[i] = flag(entry)
+                errors.append(make_error(entry, "TXN balancing is missing account"))
+            else:
+                if default_match[1] not in default_match[0]:
+                    default_match[0][default_match[1]] = set()
+                default_match[0][default_match[1]].add(account)
+
+        for posting in entry.postings:
+            if BALANCE_KEY in posting.meta:
+                bal_val = posting.meta[BALANCE_KEY]
+                match = default_match
+                if bal_val is not None:
+                    if bal_val in entry.tags:
+                        match = (tag_accounts, bal_val)
+                    elif bal_val in entry.links:
+                        match = (link_accounts, bal_val)
+                if match is None:
+                    entries[i] = flag(entry)
+                    errors.append(make_error(entry, "Posting balancing match is missing or ambiguous"))
+                    continue
+                if match[1] not in match[0]:
+                    match[0][match[1]] = set()
+                match[0][match[1]].add(posting.account)
+
+    # pass 2, match and balance transactions
+    counts = {}
+    for i, entry in enumerate(entries):
+        if type(entry) is not Transaction:
+            continue
+        _handle_counts_from_set(counts, entry, i, entry.tags, tag_accounts, "tag")
+        _handle_counts_from_set(counts, entry, i, entry.links, link_accounts, "link")
+
+    # flag any unbalanced transactions
+    for id, vals in counts.items():
+        total, entry_idxs = vals
+        if total.number != 0.0:
+            errors.append(make_error(entries[entry_idxs[0]], "Unbalanced check for " + " ".join(id)))
+            for idx in entry_idxs:
+                entries[idx] = flag(entries[idx])
+
+    return entries, errors
+
+
+def _handle_counts_from_set(counts: Dict[Tuple[str,str,str],List[Union[Amount,List[int]]]],
+                            txn: Transaction,
+                            txn_idx: int,
+                            labels: Set[str],
+                            account_map: Dict[str, List[str]],
+                            id: str) -> None:
+    for label in labels:
+        match_accounts = account_map.get(label)
+        if match_accounts is None:
+            continue
+        for posting in txn.postings:
+            if posting.account in match_accounts:
+                match_id = (id, label, posting.account)
+                if match_id in counts:
+                    counts[match_id][0] = amount.add(counts[match_id][0], posting.units)
+                    counts[match_id][1].append(txn_idx)
+                else:
+                    counts[match_id] = [posting.units, [txn_idx]]
+
index 442f43232f90759a3a4291a5eed5a5dd52ac6f36..9262d2c7b753eef5fc3c83dfc4572b283c799f29 100644 (file)
@@ -1,4 +1,5 @@
-from beancount.core.data import Transaction
+from beancount.core.data import new_metadata, Transaction
+from beancount.loader import LoadError
 
 
 def flag(entry):
@@ -6,3 +7,10 @@ def flag(entry):
     tdict["flag"] = "!"
     return Transaction(**tdict)
 
+
+def make_error(entry, message) -> LoadError:
+    return LoadError(
+        entry=entry,
+        message=message,
+        source=new_metadata(entry.meta["filename"], entry.meta["lineno"]))
+