]> git.walde.dev - beanbeanbean/commitdiff
Add first try at link balancing main
authorDustin Walde <redacted>
Tue, 14 Nov 2023 23:05:25 +0000 (15:05 -0800)
committerDustin Walde <redacted>
Tue, 14 Nov 2023 23:05:25 +0000 (15:05 -0800)
- Code is messy
- Match by links, tags, file, or query

src/beanbeanbean/balance.py

index 078efeb48a059bcf0c67255488486c87125a88fa..380abb98b73d9571d1cf55b8d34603d1fa9a2788 100644 (file)
@@ -1,19 +1,68 @@
 from beancount.core import amount
+from beancount.core.convert import get_weight
 from beancount.core.data import (
     Amount,
+    Balance,
     Custom,
+    Decimal,
     Entries,
+    Pad,
     Transaction,
 )
+from beancount.query import query
 from .utils import flag, make_error
 
-from typing import Dict, List, Set, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 __plugins__ = ('balance',)
 
 BALANCE_KEY = 'balance'
 
 
+class BalanceMatch:
+    entry: Balance
+    amount: Amount
+    query: bool
+    links: Set
+    tags: Set
+    file: Optional[str]
+
+    def __init__(self, entry: Balance, errors) -> None:
+        assert 'match' in entry.meta or 'query' in entry.meta
+        self.entry = entry
+        self.query = 'query' in entry.meta
+        self.links = set()
+        self.tags = set()
+        self.file = None
+        self.amount = Amount(currency=entry.amount.currency, number=Decimal(0.0))
+        if 'match' in entry.meta:
+            parts = entry.meta["match"].lower().split(",")
+            for part in parts:
+                if part == "file":
+                    self.file = entry.meta["filename"]
+                elif part.startswith("^"):
+                    self.links.add(part[1:])
+                elif part.startswith("#"):
+                    self.tags.add(part[1:])
+                else:
+                    errors.append(make_error(entry, f"Invalid match item: {part}"))
+
+    def matches(self, entry: Transaction) -> bool:
+        if self.file is not None and entry.meta["filename"] == self.file:
+            return True
+        if len(self.links.intersection(entry.links)) > 0:
+            return True
+        if len(self.tags.intersection(entry.tags)) > 0:
+            return True
+        return False
+
+
+class AccountBalances:
+    def __init__(self, currency: str) -> None:
+        self.total = Amount(currency=currency, number=Decimal(0.0))
+        self.matches: List[BalanceMatch] = []
+
+
 def is_balance_entry(entry, auto_tags: Set[str]) -> bool:
     if type(entry) is not Transaction:
         return False
@@ -30,14 +79,23 @@ def is_balance_entry(entry, auto_tags: Set[str]) -> bool:
     return False
 
 
+def is_match_balance(entry: Balance) -> bool:
+    return "match" in entry.meta
+
+
+def is_query_balance(entry: Balance) -> bool:
+    return "query" in entry.meta
+
+
 # default to link or tag?
 def balance(entries: Entries, options_map, config_string=""):
-    del options_map, config_string # unused
+    del config_string # unused
     errors = []
 
     # pass 0, collect configuration entries
     auto_tags: Set[str] = set()
     auto_accounts: Set[str] = set()
+    balance_accounts: Dict[str,Dict[str,AccountBalances]] = {}
     for entry in entries:
         if type(entry) == Custom:
             if entry.type == "b3.balance_tags":
@@ -48,11 +106,80 @@ def balance(entries: Entries, options_map, config_string=""):
                 for account in entry.values:
                     assert account.dtype is str
                     auto_accounts.add(account.value)
+        elif type(entry) == Balance:
+            b3_match = None
+            if is_match_balance(entry):
+                b3_match = BalanceMatch(entry, errors)
+            elif is_query_balance(entry):
+                query_str = f"SELECT sum(position) WHERE account ~ '{entry.account}' AND {entry.meta['query']}"
+                res = query.run_query(entries, options_map, query_str, numberify=True)
+                if len(res[1]) == 0: # maybe ok if balance should be zero?
+                    errors.append(make_error(entry, "Balance query failed to match"))
+                else:
+                    if entry.amount.number != res[1][0][0]:
+                        errors.append(make_error(entry, "Balance query is unbalanced"))
+                    entry.meta["actual"] = Amount(currency=entry.amount.currency, number=res[1][0][0])
+                b3_match = BalanceMatch(entry, errors)
+
+            if b3_match is not None:
+                if entry.account not in balance_accounts:
+                    balance_accounts[entry.account] = {}
+                if entry.amount.currency not in balance_accounts[entry.account]:
+                    balance_accounts[entry.account][entry.amount.currency] = \
+                        AccountBalances(entry.amount.currency)
+                balance_accounts[entry.account][entry.amount.currency].matches.append(b3_match)
 
     # pass 1, collect link/account pairs
     link_accounts = {}
     tag_accounts = {}
+    pad_accounts = set()
     for i, entry in enumerate(entries):
+        if type(entry) is Balance:
+            if is_match_balance(entry) or is_query_balance(entry):
+                match_balance = balance_accounts[entry.account][entry.amount.currency]
+                full_balance = match_balance.total
+                balance_match = None
+                for match in match_balance.matches:
+                    if entry == match.entry:
+                        balance_match = match
+                        break
+                assert balance_match is not None
+                link_balance = balance_match.amount
+
+                bal_dict = entry._asdict()
+                entry.meta["amount"] = entry.amount
+                if "match" in entry.meta:
+                    entry.meta["actual"] = link_balance
+                link_balance = entry.meta["actual"]
+                bal_dict["amount"] = amount.add(
+                        full_balance,
+                        amount.sub(entry.amount, link_balance))
+                entries[i] = Balance(**bal_dict)
+                match_balance.matches.remove(balance_match)
+                if len(match_balance.matches) == 0:
+                    del balance_accounts[entry.account][entry.amount.currency]
+                    if len(balance_accounts[entry.account]) == 0:
+                        del balance_accounts[entry.account]
+                elif entry.account in pad_accounts:
+                    # keep track of padding for remaining balances
+                    pad_amount = amount.sub(entry.amount, full_balance)
+                    match_balance.total = entry.amount
+                    for match in match_balance.matches:
+                        match.amount = amount.add(match.amount, pad_amount)
+            elif entry.account in pad_accounts and entry.account in balance_accounts:
+                if entry.amount.currency in balance_accounts[entry.account]:
+                    balance_accounts[entry.account][entry.amount.currency].total = entry.amount
+        elif type(entry) is Pad and entry.account in balance_accounts:
+            pad_accounts.add(entry.account)
+        elif type(entry) is Transaction and len(balance_accounts) > 0:
+            for post in entry.postings:
+                post_amount = get_weight(post)
+                if post.account in balance_accounts and post_amount.currency in balance_accounts[post.account]:
+                    balances = balance_accounts[post.account][post_amount.currency]
+                    balances.total = amount.add(balances.total, post_amount)
+                    for match in balances.matches: # oops matches/matches() is confusing
+                        if match.matches(entry):
+                            match.amount = amount.add(match.amount, post_amount)
         if not is_balance_entry(entry, auto_tags):
             continue