]> git.walde.dev - beanbeanbean/commitdiff
Update recurring to better handle multi-currency
authorDustin Walde <redacted>
Sat, 11 Nov 2023 03:05:52 +0000 (19:05 -0800)
committerDustin Walde <redacted>
Sat, 11 Nov 2023 03:05:52 +0000 (19:05 -0800)
src/beanbeanbean/recurring.py

index 590ca4ba4a6cb9ce3b1d2355ca2ea925ca4842c7..676109a8d194da0b7f9b10d145a291d0c6d6b8c0 100644 (file)
@@ -4,8 +4,10 @@ from decimal import (
     localcontext,
 )
 
-from beancount.core import amount
+from beancount.core import amount, position
+from beancount.core.convert import get_weight
 from beancount.core.data import (
+    Amount,
     Entries,
     new_metadata,
     Posting,
@@ -15,6 +17,8 @@ from beancount.loader import LoadError
 from dateutil import rrule
 import recurrent
 
+from beanbeanbean.balance import balance
+
 from .utils import flag
 
 from typing import List, Optional, Union
@@ -51,8 +55,17 @@ def is_recurring_transaction(entry) -> bool:
 
 def handle_recurring_transaction(txn: Transaction) -> Union[List[Transaction], LoadError]:
     post_vals = []
+    balances = {}
     for post in txn.postings:
-        post_vals.append([post.units, post.cost])
+        weight_amt = get_weight(post)
+        balances[weight_amt.currency] = amount.add(
+            balances.get(weight_amt.currency,
+                         Amount(Decimal(0.0), weight_amt.currency)),
+            weight_amt)
+        post_vals.append([post.units, weight_amt])
+
+    balanced = True
+    tolerances = txn.meta['__tolerances__']
 
     match_key = None
     amortize = False
@@ -69,6 +82,12 @@ def handle_recurring_transaction(txn: Transaction) -> Union[List[Transaction], L
                 match_key = rkey
                 phrase = txn.meta[rkey]
                 break
+    else:
+        # only force each amortized transaction to balance if the source did
+        for amt in balances.values():
+            if amount.abs(amt).number > tolerances[amt.currency]:
+                balanced = False
+                break
 
     if phrase is None:
         return LoadError(
@@ -84,31 +103,71 @@ def handle_recurring_transaction(txn: Transaction) -> Union[List[Transaction], L
     del txn_dict["date"]
     del txn_dict["postings"]
 
-    tolerances = txn.meta['__tolerances__']
     entries = []
 
     for i, dt in enumerate(dates):
         new_txn = Transaction(date=dt.date(), postings=[], **txn_dict)
         if amortize:
+            new_postings = []
+            balances = {}
             for j, posting in enumerate(txn.postings):
                 post_dict = posting._asdict()
-                for amt, key in ((post_vals[j][0], 'units'), (post_vals[j][1], 'cost')):
-                    if post_dict[key] is None:
-                        continue
-
-                    with localcontext() as ctx:
-                        ctx.prec = len(str(tolerances[amt.currency])) - 2
-                        amortized_amount = amount.div(amt, Decimal(len(dates)-i))
-                        post_dict[key] = amortized_amount
-                    post_vals[j][0] = amount.sub(amt, amortized_amount)
-                new_txn.postings.append(Posting(**post_dict))
+                remaining = post_vals[j]
+
+                amortized_cost = None
+                with localcontext() as ctx:
+                    post_tolerance = tolerances[remaining[0].currency]
+                    if post_tolerance > Decimal(0.0):
+                        ctx.prec = len(str(post_tolerance)) - 2
+                    amortized_post = amount.div(remaining[0], Decimal(len(dates)-i))
+                    if remaining[0].currency != remaining[1].currency:
+                        cost_tolerance = tolerances[remaining[1].currency]
+                        if cost_tolerance > Decimal(0.0):
+                            ctx.prec = len(str(cost_tolerance)) - 2
+                        amortized_cost = amount.div(remaining[1], Decimal(len(dates)-i))
+
+                weight_amt = amortized_post
+                if amortized_cost is not None:
+                    weight_amt = amortized_cost
+                if weight_amt.currency not in balances:
+                    balances[weight_amt.currency] = weight_amt
+                else:
+                    balances[weight_amt.currency] = amount.add(balances[weight_amt.currency], weight_amt)
+
+                post_dict['units'] = amortized_post
+                post_dict['cost'] = amortized_cost
+                post_dict['price'] = None
+                remaining[0] = amount.sub(remaining[0], amortized_post)
+                if amortized_cost is None:
+                    remaining[1] = remaining[0]
+                else:
+                    remaining[1] = amount.sub(remaining[1], amortized_cost)
+                new_postings.append(post_dict)
+
+            if balanced:
+                # verify it all still balances
+                for remainder in balances.values():
+                    if amount.abs(remainder).number >= tolerances[remainder.currency]:
+                        for post in new_postings:
+                            cost = post['cost']
+                            if cost is not None and cost.currency == remainder.currency:
+                                post['cost'] = amount.sub(cost, remainder)
+                                break
+                            elif cost is None and post['units'].currency == remainder.currency:
+                                post['units'] = amount.sub(post['units'], remainder)
+                                break
+
+            for pd in new_postings:
+                if pd['cost'] is not None:
+                    pd['cost'] = position.Cost(pd['cost'].number, pd['cost'].currency, txn.date, 'amortized')
+                new_txn.postings.append(Posting(**pd))
         else:
             for posting in txn.postings:
                 new_txn.postings.append(posting)
         entries.append(new_txn)
 
     # replace metadata with computed values to show it was processed
-    # this dict is shared among all entries..
+    # this dict is a shared reference among all entries..
     txn.meta['a͏mortize'] = amortize
     txn.meta['p͏hrase'] = txn.meta[match_key]
     del txn.meta[match_key]
@@ -133,6 +192,8 @@ def recurring(entries: Entries, options_map, config_string=""):
                 else:
                     errors.append(res)
             except Exception as e:
+                import traceback
+                print(traceback.format_exc())
                 errors.append(LoadError(
                     source=new_metadata(entry.meta["filename"], entry.meta["lineno"]),
                     message="Failed to handle recurring transaction: {}".format(e),