
Commit 8e018a8

update calvo for faster descent
1 parent 0ddd8b4 commit 8e018a8

File tree

1 file changed (+29, -21 lines)


lectures/calvo_gradient.md

Lines changed: 29 additions & 21 deletions
@@ -11,7 +11,6 @@ kernelspec:
   name: python3
 ---
 
-
 # A Model of Calvo
 
 This lecture describes a linear-quadratic version of a model that Guillermo Calvo {cite}`Calvo1978` used to illustrate the **time inconsistency** of optimal government
@@ -318,8 +317,7 @@ We use the following imports in this lecture
 from quantecon import LQ
 import numpy as np
 import jax.numpy as jnp
-from jax import jit
-import jax
+from jax import jit, grad
 import optax
 import statsmodels.api as sm
 import matplotlib.pyplot as plt
@@ -449,7 +447,7 @@ Now we compute the value of $V$ under this setup, and compare it against those o
 
 ```{code-cell} ipython3
 # Assume β=0.85, c=2, T=40.
-T=40
+T = 40
 clq = ChangLQ(β=0.85, c=2, T=T)
 ```
 
@@ -459,30 +457,40 @@ def compute_θ(μ, α=1):
     λ = α / (1 + α)
     T = len(μ) - 1
     μbar = μ[-1]
-    θ = jnp.zeros(len(μ))
-
-    for t in range(T):
-        temp = sum(λ**j * μ[t + j] for j in range(T - t))
-        θ = θ.at[t].set((1 - λ) * temp + λ**(T - t) * μbar)
-
-    θ = θ.at[-1].set(μbar)
+
+    # Create an array of powers for λ
+    λ_powers = λ ** jnp.arange(T + 1)
+
+    # Compute the weighted sums for all t
+    weighted_sums = jnp.array(
+        [jnp.sum(λ_powers[:T-t] * μ[t:T]) for t in range(T)])
+
+    # Compute θ values except for the last element
+    θ = (1 - λ) * weighted_sums + λ**(T - jnp.arange(T)) * μbar
+
+    # Set the last element
+    θ = jnp.append(θ, μbar)
+
     return θ
-
+
 @jit
 def compute_V(μ, β, c, α=1, u0=1, u1=0.5, u2=3):
     θ = compute_θ(μ, α)
 
     h0 = u0
     h1 = -u1 * α
     h2 = -0.5 * u2 * α**2
-
+
     T = len(μ) - 1
-    V = 0
+    t = np.arange(T)
+
+    # Compute sum except for the last element
+    V_sum = np.sum(β**t * (h0 + h1 * θ[:T] + h2 * θ[:T]**2 - 0.5 * c * μ[:T]**2))
 
-    for t in range(T):
-        V += β**t * (h0 + h1 * θ[t] + h2 * θ[t]**2 - 0.5 * c * μ[t]**2)
+    # Compute the final term
+    V_final = (β**T / (1 - β)) * (h0 + h1 * μ[-1] + h2 * μ[-1]**2 - 0.5 * c * μ[-1]**2)
 
-    V += (β**T / (1 - β)) * (h0 + h1 * μ[-1] + h2 * μ[-1]**2 - 0.5 * c * μ[-1]**2)
+    V = V_sum + V_final
 
     return V
 ```
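
As a quick standalone sanity check of this refactor (not part of the commit), the sketch below compares a loop-based `compute_θ` in the spirit of the removed code against the new vectorized version on a made-up path `μ_test`; both implement $\theta_t = (1-\lambda)\sum_{j=0}^{T-t-1} \lambda^j \mu_{t+j} + \lambda^{T-t} \bar\mu$ with $\theta_T = \bar\mu$.

```python
# Standalone check (illustrative only): the vectorized compute_θ should match
# the loop-based version.  μ_test is a made-up path used just for this test.
import numpy as np
import jax.numpy as jnp

def compute_θ_loop(μ, α=1):
    # Loop version: sum the geometric weights term by term for each t
    λ = α / (1 + α)
    T = len(μ) - 1
    μbar = μ[-1]
    θ = np.zeros(len(μ))
    for t in range(T):
        temp = sum(λ**j * μ[t + j] for j in range(T - t))
        θ[t] = (1 - λ) * temp + λ**(T - t) * μbar
    θ[-1] = μbar
    return θ

def compute_θ_vec(μ, α=1):
    # Vectorized version: precompute powers of λ and batch the weighted sums
    λ = α / (1 + α)
    T = len(μ) - 1
    μbar = μ[-1]
    λ_powers = λ ** jnp.arange(T + 1)
    weighted_sums = jnp.array(
        [jnp.sum(λ_powers[:T-t] * μ[t:T]) for t in range(T)])
    θ = (1 - λ) * weighted_sums + λ**(T - jnp.arange(T)) * μbar
    return jnp.append(θ, μbar)

μ_test = jnp.linspace(0.1, 0.0, 11)
print(np.allclose(compute_θ_loop(np.asarray(μ_test)), compute_θ_vec(μ_test)))
```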
@@ -501,7 +509,7 @@ We will use the [`optax.adam`](https://optax.readthedocs.io/en/latest/api/optimi
 ```{code-cell} ipython3
 def adam_optimizer(grad_func, init_params,
                    lr=0.1,
-                   max_iter=1_000,
+                   max_iter=10_000,
                    error_tol=1e-7):
 
     # Set initial parameters and optimizer
@@ -531,7 +539,7 @@ def adam_optimizer(grad_func, init_params,
     return params
 ```
 
-Here we use automatic differentiation functionality in JAX with `jax.grad`.
+Here we use automatic differentiation functionality in JAX with `grad`.
 
 ```{code-cell} ipython3
 :tags: [scroll-output]
@@ -540,7 +548,7 @@ Here we use automatic differentiation functionality in JAX with `jax.grad`.
 μ_init = jnp.zeros(T)
 
 # Maximization instead of minimization
-grad_V = jit(jax.grad(
+grad_V = jit(grad(
     lambda μ: -compute_V(μ, β=0.85, c=2)))
 ```
 
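
For reference (not part of the commit), here is a minimal sketch of the same `jit(grad(...))` plus `optax.adam` update pattern used by `adam_optimizer`, with a made-up quadratic objective `toy_V` standing in for `compute_V`, which needs the rest of the lecture's setup.

```python
# Minimal sketch of gradient ascent via optax.adam, assuming a hypothetical
# stand-in objective toy_V in place of compute_V.
import jax.numpy as jnp
from jax import jit, grad
import optax

def toy_V(μ):
    # Made-up concave objective, maximized at μ = 0.5
    return -jnp.sum((μ - 0.5)**2)

# Maximization instead of minimization: take gradients of -toy_V
grad_toy = jit(grad(lambda μ: -toy_V(μ)))

μ = jnp.zeros(5)
optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(μ)

for _ in range(200):
    grads = grad_toy(μ)
    updates, opt_state = optimizer.update(grads, opt_state)
    μ = optax.apply_updates(μ, updates)

print(μ)  # each entry ≈ 0.5
```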
@@ -691,7 +699,7 @@ In this case, we restrict $\mu_t = \bar \mu \text{ for } \forall t$
 μ_init = jnp.zeros(1)
 
 # Maximization instead of minimization
-grad_V = jit(jax.grad(
+grad_V = jit(grad(
     lambda μ: -compute_V(μ, β=0.85, c=2)))
 
 # Optimize μ
