Skip to content

Commit d5ff028

Browse files
authored
Fix: Enable JOPT to support open-system optimization with TRACEDIFF fidelity (qutip#49)
1 parent ab2f132 commit d5ff028

File tree

3 files changed

+43
-4
lines changed

3 files changed

+43
-4
lines changed

src/qutip_qoc/_jopt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def _infid(self, params):
150150
if self._fid_type == "TRACEDIFF":
151151
diff = X - self._target
152152
# to prevent if/else in qobj.dag() and qobj.tr()
153-
diff_dag = Qobj(diff.data.adjoint(), dims=diff.dims)
154-
g = 1 / 2 * (diff_dag * diff).data.trace()
153+
diff_dag = diff.dag() # direct access to JAX array, no fallback!
154+
g = 1 / 2 * jnp.trace(diff_dag.data._jxa @ diff.data._jxa)
155155
infid = jnp.real(self._norm_fac * g)
156156
else:
157157
g = self._norm_fac * self._target.overlap(X)
@@ -160,4 +160,4 @@ def _infid(self, params):
160160
elif self._fid_type == "SU": # f_SU (incl global phase)
161161
infid = 1 - jnp.real(g)
162162

163-
return infid
163+
return infid

src/qutip_qoc/result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
try:
1414
import jax
1515
import jaxlib
16-
_jitfun_type = jaxlib.xla_extension.PjitFunction
16+
_jitfun_type = type(jax.jit(lambda x: x))
1717
except ImportError:
1818
_jitfun_type = None
1919

tests/test_jopt_open_system_bug.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as np
2+
import qutip as qt
3+
from qutip_qoc import Objective, optimize_pulses
4+
5+
from jax import jit, numpy
6+
7+
def test_open_system_jopt_runs_without_error():
8+
Hd = qt.Qobj(np.diag([1, 2]))
9+
c_ops = [np.sqrt(0.1) * qt.sigmam()]
10+
Hc = qt.sigmax()
11+
12+
Ld = qt.liouvillian(H=Hd, c_ops=c_ops)
13+
Lc = qt.liouvillian(Hc)
14+
15+
initial_state = qt.fock_dm(2, 0)
16+
target_state = qt.fock_dm(2, 1)
17+
18+
times = np.linspace(0, 2 * np.pi, 250)
19+
20+
@jit
21+
def sin_x(t, c, **kwargs):
22+
return c[0] * numpy.sin(c[1] * t)
23+
L = [Ld, [Lc, sin_x]]
24+
25+
guess_params = [1, 0.5]
26+
27+
res_jopt = optimize_pulses(
28+
objectives = Objective(initial_state, L, target_state),
29+
control_parameters = {
30+
"ctrl_x": {"guess": guess_params, "bounds": [(-1, 1), (0, 2 * np.pi)]}
31+
},
32+
tlist = times,
33+
algorithm_kwargs = {
34+
"alg": "JOPT",
35+
"fid_err_targ": 0.001,
36+
},
37+
)
38+
39+
assert res_jopt.infidelity < 0.25, f"Fidelity error too high: {res_jopt.infidelity}"

0 commit comments

Comments
 (0)