]> git.walde.dev - beanbeanbean/commitdiff
Add support for auto-balance options
authorDustin Walde <redacted>
Mon, 13 Nov 2023 01:16:41 +0000 (17:16 -0800)
committerDustin Walde <redacted>
Mon, 13 Nov 2023 01:16:41 +0000 (17:16 -0800)
Custom items b3.balance_tags, b3.balance_accounts

src/beanbeanbean/balance.py

index 2eb5dffc2bcb97234c80830b379d6dd077edb771..078efeb48a059bcf0c67255488486c87125a88fa 100644 (file)
@@ -1,6 +1,7 @@
 from beancount.core import amount
 from beancount.core.data import (
     Amount,
+    Custom,
     Entries,
     Transaction,
 )
@@ -13,7 +14,7 @@ __plugins__ = ('balance',)
 BALANCE_KEY = 'balance'
 
 
-def is_balance_entry(entry) -> bool:
+def is_balance_entry(entry, auto_tags: Set[str]) -> bool:
     if type(entry) is not Transaction:
         return False
 
@@ -23,6 +24,9 @@ def is_balance_entry(entry) -> bool:
         if BALANCE_KEY in posting.meta:
             return True
 
+    if len(auto_tags.intersection(entry.tags)) > 0:
+        return True
+
     return False
 
 
@@ -31,35 +35,49 @@ def balance(entries: Entries, options_map, config_string=""):
     del options_map, config_string # unused
     errors = []
 
+    # pass 0, collect configuration entries
+    auto_tags: Set[str] = set()
+    auto_accounts: Set[str] = set()
+    for entry in entries:
+        if type(entry) == Custom:
+            if entry.type == "b3.balance_tags":
+                for tag in entry.values:
+                    assert tag.dtype is str
+                    auto_tags.add(tag.value)
+            elif entry.type == "b3.balance_accounts":
+                for account in entry.values:
+                    assert account.dtype is str
+                    auto_accounts.add(account.value)
+
     # pass 1, collect link/account pairs
     link_accounts = {}
     tag_accounts = {}
     for i, entry in enumerate(entries):
-        if not is_balance_entry(entry):
+        if not is_balance_entry(entry, auto_tags):
             continue
 
+        entry_tags = entry.tags.difference(auto_tags)
+
         default_match = None
         if len(entry.links) == 1:
             for link in entry.links:
                 default_match = (link_accounts, link)
-        elif len(entry.tags) == 1:
-            for tag in entry.tags:
+        elif len(entry_tags) == 1:
+            for tag in entry_tags:
                 default_match = (tag_accounts, tag)
         
-        if default_match is not None and BALANCE_KEY in entry.meta:
-            account = entry.meta[BALANCE_KEY]
-            found = False
+        found = False
+        if default_match is not None and (
+                BALANCE_KEY in entry.meta or
+                len(auto_tags.intersection(entry.tags)) > 0):
+            account = entry.meta.get(BALANCE_KEY)
             for posting in entry.postings:
-                if posting.account == account:
+                if posting.account == str(account) \
+                        or _is_auto_account(posting.account, auto_accounts):
                     found = True
-                    break
-            if not found:
-                entries[i] = flag(entry)
-                errors.append(make_error(entry, "TXN balancing is missing account"))
-            else:
-                if default_match[1] not in default_match[0]:
-                    default_match[0][default_match[1]] = set()
-                default_match[0][default_match[1]].add(account)
+                    if default_match[1] not in default_match[0]:
+                        default_match[0][default_match[1]] = set()
+                    default_match[0][default_match[1]].add(posting.account)
 
         for posting in entry.postings:
             if BALANCE_KEY in posting.meta:
@@ -74,10 +92,15 @@ def balance(entries: Entries, options_map, config_string=""):
                     entries[i] = flag(entry)
                     errors.append(make_error(entry, "Posting balancing match is missing or ambiguous"))
                     continue
+                found = True
                 if match[1] not in match[0]:
                     match[0][match[1]] = set()
                 match[0][match[1]].add(posting.account)
 
+        if not found:
+            entries[i] = flag(entry)
+            errors.append(make_error(entry, "TXN balancing is missing account"))
+
     # pass 2, match and balance transactions
     counts = {}
     for i, entry in enumerate(entries):
@@ -110,9 +133,20 @@ def _handle_counts_from_set(counts: Dict[Tuple[str,str,str],List[Union[Amount,Li
         for posting in txn.postings:
             if posting.account in match_accounts:
                 match_id = (id, label, posting.account)
+                if id == "tag":
+                    posting.meta["balance"] = "#" + label
+                else:
+                    posting.meta["balance"] = "^" + label
                 if match_id in counts:
                     counts[match_id][0] = amount.add(counts[match_id][0], posting.units)
                     counts[match_id][1].append(txn_idx)
                 else:
                     counts[match_id] = [posting.units, [txn_idx]]
 
+
+def _is_auto_account(account: str, auto_accounts: Set[str]) -> bool:
+    for auto in auto_accounts:
+        if account.startswith(auto):
+            return True
+    return False
+