Skip to content

Commit ba5c078

Browse files
committed
fix everything in the gradient computation!
1 parent 2230267 commit ba5c078

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

lectures/calvo_gradient.md

Lines changed: 51 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -869,7 +869,7 @@ A
869869
870870
```{code-cell} ipython3
871871
μ_vec = μs.copy()
872-
μ_vec[-1] = μs[-1]/(1-λ)
872+
# μ_vec[-1] = μs[-1]/(1-λ)
873873
```
874874
875875
```{code-cell} ipython3
@@ -1090,26 +1090,64 @@ def compute_J(μ, β, c, α=1, u0=1, u1=0.5, u2=3):
10901090
10911091
return βθ_sum + βθ_square_sum - βμ_square_sum
10921092
1093-
def compute_μ(μ, β, c, T, α=1, u0=1, u1=0.5, u2=3):
    """Solve the planner's first-order conditions for μ in closed form.

    Builds the linear system ``(M - G) x = b`` implied by the truncated,
    discounted criterion and returns its solution via a single linear solve.

    Parameters
    ----------
    μ : array_like
        Accepted for signature compatibility with the gradient-based
        routines; it is never read by this closed-form solver.
    β : float
        Discount factor (0 < β < 1).
    c : float
        Penalty coefficient on the money-growth terms μ.
    T : int
        Truncation horizon; the returned vector has T + 1 entries, the
        last one carrying the geometric tail.
    α, u0, u1, u2 : float, optional
        Preference/technology parameters; only α, u1, u2 enter the system
        through h1 = -u1·α and h2 = -0.5·u2·α².  (u0 shifts the criterion
        by a constant and does not affect the minimizer.)

    Returns
    -------
    jnp.ndarray
        The (T + 1)-vector x solving (M - G) x = b.
    """
    h1 = -u1 * α
    h2 = -0.5 * u2 * α**2
    λ = α / (1 + α)

    # Scaling vector: the terminal entry carries the 1/(1-λ) tail weight.
    e = jnp.hstack([np.ones(T), 1/(1 - λ)])

    # B maps μ to θ (θ = B μ), where A = I - λ·(upper shift).
    A = jnp.eye(T+1) - λ*jnp.eye(T+1, k=1)
    B = (1-λ) * jnp.linalg.inv(A)

    # Discount weights; the last entry sums the geometric tail β^T/(1-β).
    β_vec = jnp.hstack([β**jnp.arange(T),
                        (β**T/(1-β))])

    b = - h1 * (B.T @ β_vec) * e
    E = jnp.diag(e)
    # (2·h2·β_vec·B.T) @ B broadcasts β_vec over B.T's columns,
    # i.e. it equals 2·h2 · B' diag(β_vec) B, then left-scaled by E.
    M = E @ (2 * β_vec * h2 * B.T @ B)
    G = c * jnp.diag(β_vec) @ jnp.linalg.inv(E)
    A = M - G
    return jnp.linalg.solve(A, b)
11091112
1110-
print('\n', compute_μ(β=0.85, c=2, T=39))
1113+
μ_vec_closed = compute_μ(jnp.ones(T), β=0.85, c=2, T=39)
1114+
e = jnp.hstack([np.ones(T-1),
1115+
1/(1 - λ)])
1116+
μ_closed = μ_vec_closed / e
1117+
print(f"closed formed μ = \n{μ_closed}")
1118+
```
1119+
1120+
```{code-cell} ipython3
1121+
print(f'deviation = {np.linalg.norm(μ_closed - clq.μ_series)}')
1122+
```
1123+
1124+
```{code-cell} ipython3
1125+
compute_V(μ_closed, β=0.85, c=2)
1126+
```
11111127
1112-
compute_V(compute_μ(β=0.85, c=2, T=39), β=0.85, c=2)
1128+
@jit
def compute_J(μ, β, c, α=1, u0=1, u1=0.5, u2=3):
    """Discounted criterion J for a (T+1)-vector of money growth rates μ.

    The last entry of μ is rescaled by 1/(1-λ) (a JAX functional update,
    since jax arrays are immutable) so that it represents the constant
    continuation rate repeated forever; θ = B μ_vec is the implied
    inflation path.

    Parameters
    ----------
    μ : jnp.ndarray
        Money-growth path of length T + 1 (last entry is the tail value).
    β : float
        Discount factor (0 < β < 1).
    c : float
        Penalty coefficient on the μ² terms.
    α, u0, u1, u2 : float, optional
        Preference/technology parameters; h1 = -u1·α, h2 = -0.5·u2·α².
        (u0 only shifts J by a constant and is not used here.)

    Returns
    -------
    jnp.ndarray
        Scalar value of the criterion.
    """
    T = len(μ) - 1

    h1 = -u1 * α
    h2 = -0.5 * u2 * α**2
    λ = α / (1 + α)

    # Fold the geometric tail into the terminal component.
    μ_vec = μ.at[-1].set(μ[-1]/(1-λ))

    # B maps μ_vec to θ, with A = I - λ·(upper shift).
    A = jnp.eye(T+1) - λ*jnp.eye(T+1, k=1)
    B = (1-λ) * jnp.linalg.inv(A)

    # Discount weights; the last entry sums the geometric tail β^T/(1-β).
    β_vec = jnp.hstack([β**jnp.arange(T),
                        (β**T/(1-β))])

    θ = B @ μ_vec
    βθ_sum = jnp.sum((β_vec * h1) * θ)
    βθ_square_sum = β_vec * h2 * θ.T @ θ       # h2 · Σ β_t θ_t²
    βμ_square_sum = 0.5 * c * β_vec * μ.T @ μ  # 0.5·c · Σ β_t μ_t² (unscaled μ)

    return βθ_sum + βθ_square_sum - βμ_square_sum
11131151
```
11141152
11151153
```{code-cell} ipython3
@@ -1130,6 +1168,10 @@ optimized_μ = adam_optimizer(grad_J, μ_init)
11301168
print(f"optimized μ = \n{optimized_μ}")
11311169
```
11321170
1171+
```{code-cell} ipython3
1172+
grad_J(optimized_μ)
1173+
```
1174+
11331175
```{code-cell} ipython3
11341176
print(f"original μ = \n{clq.μ_series}")
11351177
```

0 commit comments

Comments
 (0)