--- /dev/null
+from beancount.core import amount
+from beancount.core.data import (
+ Amount,
+ Entries,
+ Transaction,
+)
+from .utils import flag, make_error
+
+from typing import Dict, List, Set, Tuple, Union
+
+__plugins__ = ('balance',)
+
+BALANCE_KEY = 'balance'
+
+
+def is_balance_entry(entry) -> bool:
+ if type(entry) is not Transaction:
+ return False
+
+ if BALANCE_KEY in entry.meta:
+ return True
+ for posting in entry.postings:
+ if BALANCE_KEY in posting.meta:
+ return True
+
+ return False
+
+
+# default to link or tag?
+def balance(entries: Entries, options_map, config_string=""):
+ del options_map, config_string # unused
+ errors = []
+
+ # pass 1, collect link/account pairs
+ link_accounts = {}
+ tag_accounts = {}
+ for i, entry in enumerate(entries):
+ if not is_balance_entry(entry):
+ continue
+
+ default_match = None
+ if len(entry.links) == 1:
+ for link in entry.links:
+ default_match = (link_accounts, link)
+ elif len(entry.tags) == 1:
+ for tag in entry.tags:
+ default_match = (tag_accounts, tag)
+
+ if default_match is not None and BALANCE_KEY in entry.meta:
+ account = entry.meta[BALANCE_KEY]
+ found = False
+ for posting in entry.postings:
+ if posting.account == account:
+ found = True
+ break
+ if not found:
+ entries[i] = flag(entry)
+ errors.append(make_error(entry, "TXN balancing is missing account"))
+ else:
+ if default_match[1] not in default_match[0]:
+ default_match[0][default_match[1]] = set()
+ default_match[0][default_match[1]].add(account)
+
+ for posting in entry.postings:
+ if BALANCE_KEY in posting.meta:
+ bal_val = posting.meta[BALANCE_KEY]
+ match = default_match
+ if bal_val is not None:
+ if bal_val in entry.tags:
+ match = (tag_accounts, bal_val)
+ elif bal_val in entry.links:
+ match = (link_accounts, bal_val)
+ if match is None:
+ entries[i] = flag(entry)
+ errors.append(make_error(entry, "Posting balancing match is missing or ambiguous"))
+ continue
+ if match[1] not in match[0]:
+ match[0][match[1]] = set()
+ match[0][match[1]].add(posting.account)
+
+ # pass 2, match and balance transactions
+ counts = {}
+ for i, entry in enumerate(entries):
+ if type(entry) is not Transaction:
+ continue
+ _handle_counts_from_set(counts, entry, i, entry.tags, tag_accounts, "tag")
+ _handle_counts_from_set(counts, entry, i, entry.links, link_accounts, "link")
+
+ # flag any unbalanced transactions
+ for id, vals in counts.items():
+ total, entry_idxs = vals
+ if total.number != 0.0:
+ errors.append(make_error(entries[entry_idxs[0]], "Unbalanced check for " + " ".join(id)))
+ for idx in entry_idxs:
+ entries[idx] = flag(entries[idx])
+
+ return entries, errors
+
+
+def _handle_counts_from_set(counts: Dict[Tuple[str,str,str],List[Union[Amount,List[int]]]],
+ txn: Transaction,
+ txn_idx: int,
+ labels: Set[str],
+ account_map: Dict[str, List[str]],
+ id: str) -> None:
+ for label in labels:
+ match_accounts = account_map.get(label)
+ if match_accounts is None:
+ continue
+ for posting in txn.postings:
+ if posting.account in match_accounts:
+ match_id = (id, label, posting.account)
+ if match_id in counts:
+ counts[match_id][0] = amount.add(counts[match_id][0], posting.units)
+ counts[match_id][1].append(txn_idx)
+ else:
+ counts[match_id] = [posting.units, [txn_idx]]
+