Skip to content

Commit ef6695e

Browse files
authored
Merge pull request #1048 from jdebacker/dask_graph
Merging
2 parents d476a06 + 0273913 commit ef6695e

File tree

6 files changed

+124
-81
lines changed

6 files changed

+124
-81
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88

9+
## [0.14.7] - 2025-08-21 17:00:00
10+
11+
### Added
12+
13+
- Refactor calls to dask in `SS.py` and `TPI.py`. See PR [#1048](https://github.com/PSLmodels/OG-Core/pull/1048)
14+
915
## [0.14.6] - 2025-08-15 14:00:00
1016

1117
### Added
@@ -403,6 +409,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
403409
- Any earlier versions of OG-USA can be found in the [`OG-Core`](https://github.com/PSLmodels/OG-Core) repository [release history](https://github.com/PSLmodels/OG-Core/releases) from [v.0.6.4](https://github.com/PSLmodels/OG-Core/releases/tag/v0.6.4) (Jul. 20, 2021) or earlier.
404410

405411

412+
[0.14.7]: https://github.com/PSLmodels/OG-Core/compare/v0.14.6...v0.14.7
406413
[0.14.6]: https://github.com/PSLmodels/OG-Core/compare/v0.14.5...v0.14.6
407414
[0.14.5]: https://github.com/PSLmodels/OG-Core/compare/v0.14.4...v0.14.5
408415
[0.14.4]: https://github.com/PSLmodels/OG-Core/compare/v0.14.3...v0.14.4

environment.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ dependencies:
1313
- dask>=2.30.0
1414
- dask-core>=2.30.0
1515
- distributed>=2.30.1
16-
- paramtools>=0.15.0
16+
- paramtools>=0.20.0
1717
- sphinx>=3.5.4
18-
- marshmallow<4.0.0
1918
- sphinx-argparse
2019
- sphinxcontrib-bibtex>=2.0.0
2120
- sphinx-math-dollar

ogcore/SS.py

Lines changed: 86 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,60 @@ def euler_equation_solver(guesses, *args):
169169
return errors
170170

171171

172+
def solve_for_j(
173+
guesses,
174+
r_p,
175+
w,
176+
p_tilde,
177+
bq_j,
178+
rm_j,
179+
tr_j,
180+
ubi_j,
181+
factor,
182+
j,
183+
p_future,
184+
):
185+
"""
186+
Solves the household's optimization problem for a given type j.
187+
188+
Args:
189+
guesses (Numpy array): initial guesses for b and n, length 2S
190+
r_p (scalar): return on household investment portfolio
191+
w (scalar): real wage rate
192+
p_tilde (scalar): composite good price
193+
bq_j (Numpy array): bequest amounts by age, length S
194+
rm_j (Numpy array): remittance amounts by age, length S
195+
tr_j (Numpy array): government transfer amount by age, length S
196+
ubi_j (vector): universal basic income (UBI) payment, length S
197+
factor (scalar): scaling factor converting model units to dollars
198+
j (int): household type index
199+
p_future (OG-Core Specifications object): future model parameters
200+
201+
Returns:
202+
root (OptimizeResult): the optimization result
203+
"""
204+
# scattered_p is either the original object (serial case)
205+
# or a Future pointing to it (distributed case)
206+
return opt.root(
207+
euler_equation_solver,
208+
guesses * 0.9,
209+
args=(
210+
r_p,
211+
w,
212+
p_tilde,
213+
bq_j,
214+
rm_j,
215+
tr_j,
216+
ubi_j,
217+
factor,
218+
j,
219+
p_future,
220+
),
221+
method=p_future.FOC_root_method,
222+
tol=MINIMIZER_TOL,
223+
)
224+
225+
172226
def inner_loop(outer_loop_vars, p, client):
173227
"""
174228
This function solves for the inner loop of the SS. That is, given
@@ -235,50 +289,19 @@ def inner_loop(outer_loop_vars, p, client):
235289
tr = household.get_tr(TR, None, p, "SS")
236290
ubi = p.ubi_nom_array[-1, :, :] / factor
237291

238-
scattered_p = client.scatter(p, broadcast=True) if client else p
239-
240-
lazy_values = []
241-
for j in range(p.J):
242-
guesses = np.append(bssmat[:, j], nssmat[:, j])
292+
results = []
293+
# from dask.base import dask_sizeof
243294

244-
# Create a delayed function that will access the scattered_p
245-
@delayed
246-
def solve_for_j(
247-
guesses,
248-
r_p,
249-
w,
250-
p_tilde,
251-
bq_j,
252-
rm_j,
253-
tr_j,
254-
ubi_j,
255-
factor,
256-
j,
257-
scattered_p_future,
258-
):
259-
# This function will be executed on workers with access to scattered_p
260-
return opt.root(
261-
euler_equation_solver,
262-
guesses * 0.9,
263-
args=(
264-
r_p,
265-
w,
266-
p_tilde,
267-
bq_j,
268-
rm_j,
269-
tr_j,
270-
ubi_j,
271-
factor,
272-
j,
273-
scattered_p_future,
274-
),
275-
method=p.FOC_root_method,
276-
tol=MINIMIZER_TOL,
277-
)
278-
279-
# Add the delayed computation to our list
280-
lazy_values.append(
281-
solve_for_j(
295+
if client:
296+
# Scatter p only once and only if client not equal None
297+
scattered_p_future = client.scatter(p, broadcast=True)
298+
299+
# Launch in parallel with submit (or map)
300+
futures = []
301+
for j in range(p.J):
302+
guesses = np.append(bssmat[:, j], nssmat[:, j])
303+
f = client.submit(
304+
solve_for_j,
282305
guesses,
283306
r_p,
284307
w,
@@ -289,21 +312,30 @@ def solve_for_j(
289312
ubi[:, j],
290313
factor,
291314
j,
292-
scattered_p,
315+
scattered_p_future,
293316
)
294-
)
317+
futures.append(f)
295318

296-
if client:
297-
# Compute all the values
298-
futures = client.compute(lazy_values)
299-
# Later, gather the results when needed
300319
results = client.gather(futures)
320+
301321
else:
302-
results = compute(
303-
*lazy_values,
304-
scheduler=dask.multiprocessing.get,
305-
num_workers=p.num_workers,
306-
)
322+
# Serial fallback (no dask client)
323+
for j in range(p.J):
324+
guesses = np.append(bssmat[:, j], nssmat[:, j])
325+
res = solve_for_j(
326+
guesses,
327+
r_p,
328+
w,
329+
p_tilde,
330+
bq[:, j],
331+
rm[:, j],
332+
tr[:, j],
333+
ubi[:, j],
334+
factor,
335+
j,
336+
p, # pass the raw object directly
337+
)
338+
results.append(res)
307339

308340
for j, result in enumerate(results):
309341
euler_errors[:, j] = result.fun
@@ -1221,7 +1253,6 @@ def run_SS(p, client=None):
12211253
results
12221254
12231255
"""
1224-
12251256
# Create list of deviation factors for initial guesses of r and TR
12261257
dev_factor_list = [
12271258
[1.00, 1.0],

ogcore/TPI.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,10 @@ def run_TPI(p, client=None):
756756
euler_errors = np.zeros((p.T, 2 * p.S, p.J))
757757
TPIdist_vec = np.zeros(p.maxiter)
758758

759+
# scatter parameters to workers
760+
if client:
761+
scattered_p_future = client.scatter(p, broadcast=True)
762+
759763
# TPI loop
760764
while (TPIiter < p.maxiter) and (TPIdist >= p.mindist_TPI):
761765
outer_loop_vars = (r_p, r, w, p_m, BQ, RM, TR, theta)
@@ -766,38 +770,40 @@ def run_TPI(p, client=None):
766770
).sum(axis=2)
767771
p_tilde = aggr.get_ptilde(p_i[:, :], p.tau_c[:, :], p.alpha_c, "TPI")
768772

769-
# scatter parameters to workers
770-
scattered_p = client.scatter(p, broadcast=True) if client else p
771-
773+
# Initialize Euler errors
772774
euler_errors = np.zeros((p.T, 2 * p.S, p.J))
773-
lazy_values = []
774-
for j in range(p.J):
775-
guesses = (guesses_b[:, :, j], guesses_n[:, :, j])
776-
777-
# Add the delayed computation to our list
778-
lazy_values.append(
779-
delayed(inner_loop)(
775+
# Solve for household decisions
776+
results = []
777+
if client:
778+
futures = []
779+
for j in range(p.J):
780+
guesses = (guesses_b[:, :, j], guesses_n[:, :, j])
781+
f = client.submit(
782+
inner_loop,
780783
guesses,
781784
outer_loop_vars,
782785
initial_values,
783786
ubi,
784787
j,
785788
ind,
786-
scattered_p,
789+
scattered_p_future,
787790
)
788-
)
789-
if client:
790-
# Compute all the values
791-
futures = client.compute(lazy_values)
792-
# Later, gather the results when needed
791+
futures.append(f)
793792
results = client.gather(futures)
794793
else:
795-
results = compute(
796-
*lazy_values,
797-
scheduler=dask.multiprocessing.get,
798-
num_workers=p.num_workers,
799-
)
800-
794+
# Serial fallback (no dask client)
795+
for j in range(p.J):
796+
guesses = (guesses_b[:, :, j], guesses_n[:, :, j])
797+
res = inner_loop(
798+
guesses,
799+
outer_loop_vars,
800+
initial_values,
801+
ubi,
802+
j,
803+
ind,
804+
p, # pass the raw object directly
805+
)
806+
results.append(res)
801807
for j, result in enumerate(results):
802808
euler_errors[:, :, j], b_mat[:, :, j], n_mat[:, :, j] = result
803809

ogcore/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020
from ogcore.txfunc import *
2121
from ogcore.utils import *
2222

23-
__version__ = "0.14.6"
23+
__version__ = "0.14.7"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="ogcore",
8-
version="0.14.6",
8+
version="0.14.7",
99
author="Jason DeBacker and Richard W. Evans",
1010
license="CC0 1.0 Universal (CC0 1.0) Public Domain Dedication",
1111
description="A general equilibrium overlapping generations model for fiscal policy analysis",

0 commit comments

Comments
 (0)