1010from typing import Any
1111from typing import Dict
1212from datetime import datetime
13+ from decimal import Decimal
1314from amatino .denominated import Denominated
1415from amatino .denomination import Denomination
1516from amatino .decodable import Decodable
@@ -66,9 +67,9 @@ def __init__(
6667 generated_time : AmatinoTime ,
6768 global_unit_id : Optional [int ],
6869 custom_unit_id : Optional [int ],
69- assets : Optional [ List [TreeNode ] ],
70- liabilities : Optional [ List [TreeNode ] ],
71- equities : Optional [ List [TreeNode ] ],
70+ assets : List [TreeNode ],
71+ liabilities : List [TreeNode ],
72+ equities : List [TreeNode ],
7273 depth : int
7374 ) -> None :
7475
@@ -79,15 +80,12 @@ def __init__(
7980 assert isinstance (global_unit_id , int )
8081 if custom_unit_id is not None :
8182 assert isinstance (custom_unit_id , int )
82- if assets is not None :
83- assert isinstance (assets , list )
84- assert False not in [isinstance (a , TreeNode ) for a in assets ]
85- if liabilities is not None :
86- assert isinstance (liabilities , list )
87- assert False not in [isinstance (l , TreeNode ) for l in liabilities ]
88- if equities is not None :
89- assert isinstance (equities , list )
90- assert False not in [isinstance (e , TreeNode ) for e in equities ]
83+ assert isinstance (assets , list )
84+ assert False not in [isinstance (a , TreeNode ) for a in assets ]
85+ assert isinstance (liabilities , list )
86+ assert False not in [isinstance (l , TreeNode ) for l in liabilities ]
87+ assert isinstance (equities , list )
88+ assert False not in [isinstance (e , TreeNode ) for e in equities ]
9189
9290 self ._entity = entity
9391 self ._balance_time = balance_time
@@ -111,15 +109,21 @@ def __init__(
111109 liabilities = Immutable (lambda s : s ._liabilities )
112110 equities = Immutable (lambda s : s ._equities )
113111
114- has_assets = Immutable (
115- lambda s : s ._assets is not None and len (s ._assets ) > 0
116- )
117- has_liabilities = Immutable (
118- lambda s : s ._liabilities is not None and len (s ._liabilities ) > 0
119- )
120- has_equities = Immutable (
121- lambda s : s ._equities is not None and len (s ._equities ) > 0
122- )
112+ has_assets = Immutable (lambda s : len (s ._assets ) > 0 )
113+ has_liabilities = Immutable (lambda s : len (s ._liabilities ) > 0 )
114+ has_equities = Immutable (lambda s : len (s ._equities ) > 0 )
115+
116+ total_assets = Immutable (lambda s : s ._compute_total (s ._assets ))
117+ total_liabilities = Immutable (lambda s : s ._compute_total (s ._liabilities ))
118+ total_equity = Immutable (lambda s : s ._compute_total (s ._equities ))
119+
120+ def _compute_total (self , nodes : List [TreeNode ]) -> Decimal :
121+ """Return the total of all top level recursive balances"""
122+ if len (nodes ) < 1 :
123+ return Decimal (0 )
124+ total = sum ([n .recursive_balance for n in nodes ])
125+ assert isinstance (total , Decimal )
126+ return total
123127
124128 @classmethod
125129 def decode (
@@ -133,30 +137,22 @@ def decode(
133137
134138 try :
135139
136- assets = None
137- if data ['assets' ] is not None :
138- assets = TreeNode .decode_many (entity , data ['assets' ])
139-
140- liabilities = None
141- if data ['liabilities' ] is not None :
142- liabilities = TreeNode .decode_many (
143- entity ,
144- data ['liabilities' ]
145- )
146-
147- equities = None
148- if data ['equities' ] is not None :
149- equities = TreeNode .decode_many (entity , data ['equities' ])
140+ if data ['assets' ] is None :
141+ data ['assets' ] = list ()
142+ if data ['liabilities' ] is None :
143+ data ['liabilities' ] = list ()
144+ if data ['equities' ] is None :
145+ data ['equities' ] = list ()
150146
151147 position = cls (
152148 entity = entity ,
153149 balance_time = AmatinoTime .decode (data ['balance_time' ]),
154150 generated_time = AmatinoTime .decode (data ['generated_time' ]),
155151 global_unit_id = data ['global_unit_denomination' ],
156152 custom_unit_id = data ['custom_unit_denomination' ],
157- assets = assets ,
158- liabilities = liabilities ,
159- equities = equities ,
153+ assets = TreeNode . decode_many ( entity , data [ ' assets' ]) ,
154+ liabilities = TreeNode . decode_many ( entity , data [ ' liabilities' ]) ,
155+ equities = TreeNode . decode_many ( entity , data [ ' equities' ]) ,
160156 depth = data ['depth' ]
161157 )
162158
0 commit comments