Skip to content

Commit a6d863a

Browse files
authored
Merge pull request #276 from brownbaerchen/resilience_schroedinger
Resilience in Schrödinger equation
2 parents 8297fb2 + bfe353e commit a6d863a

File tree

13 files changed

+787
-169
lines changed

13 files changed

+787
-169
lines changed

pySDC/implementations/controller_classes/controller_nonMPI.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, num_procs, controller_params, description):
3939
for _ in range(num_procs - 1):
4040
self.MS.append(dill.copy(self.MS[0]))
4141
# if this fails (e.g. due to un-picklable data in the steps), initialize seperately
42-
except dill.PicklingError and TypeError:
42+
except (dill.PicklingError, TypeError):
4343
self.logger.warning('Need to initialize steps separately due to pickling error')
4444
for _ in range(num_procs - 1):
4545
self.MS.append(stepclass.step(description))

pySDC/implementations/problem_classes/NonlinearSchroedinger_MPIFFT.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def solve_system(self, rhs, factor, u0, t):
154154

155155
return me
156156

157-
def u_exact(self, t):
157+
def u_exact(self, t, **kwargs):
158158
"""
159159
Routine to compute the exact solution at time t, see (1.3) https://arxiv.org/pdf/nlin/0702010.pdf for details
160160

pySDC/projects/Resilience/Lorenz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def run_Lorenz(
6666
# initialize controller parameters
6767
controller_params = dict()
6868
controller_params['logger_level'] = 30
69-
controller_params['hook_class'] = hook_collection + hook_class if type(hook_class) == list else [hook_class]
69+
controller_params['hook_class'] = hook_collection + (hook_class if type(hook_class) == list else [hook_class])
7070
controller_params['mssdc_jac'] = False
7171

7272
if custom_controller_params is not None:

pySDC/projects/Resilience/ResilienceLorenz.py

Lines changed: 0 additions & 105 deletions
This file was deleted.
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import numpy as np
2+
from pathlib import Path
3+
from mpi4py import MPI
4+
5+
from pySDC.helpers.stats_helper import get_sorted
6+
7+
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
8+
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
9+
from pySDC.implementations.problem_classes.NonlinearSchroedinger_MPIFFT import nonlinearschroedinger_imex
10+
from pySDC.implementations.transfer_classes.TransferMesh_MPIFFT import fft_to_fft
11+
from pySDC.projects.Resilience.hook import LogData, hook_collection
12+
13+
14+
def run_Schroedinger(
15+
custom_description=None,
16+
num_procs=1,
17+
Tend=1.0,
18+
hook_class=LogData,
19+
fault_stuff=None,
20+
custom_controller_params=None,
21+
custom_problem_params=None,
22+
use_MPI=False,
23+
space_comm=None,
24+
**kwargs,
25+
):
26+
"""
27+
Run a Schroedinger problem with default parameters.
28+
29+
Args:
30+
custom_description (dict): Overwrite presets
31+
num_procs (int): Number of steps for MSSDC
32+
Tend (float): Time to integrate to
33+
hook_class (pySDC.Hook): A hook to store data
34+
fault_stuff (dict): A dictionary with information on how to add faults
35+
custom_controller_params (dict): Overwrite presets
36+
custom_problem_params (dict): Overwrite presets
37+
use_MPI (bool): Whether or not to use MPI
38+
39+
Returns:
40+
dict: The stats object
41+
controller: The controller
42+
Tend: The time that was supposed to be integrated to
43+
"""
44+
45+
space_comm = MPI.COMM_WORLD if space_comm is None else space_comm
46+
rank = space_comm.Get_rank()
47+
48+
# initialize level parameters
49+
level_params = dict()
50+
level_params['restol'] = 1e-08
51+
level_params['dt'] = 1e-01 / 2
52+
level_params['nsweeps'] = 1
53+
54+
# initialize sweeper parameters
55+
sweeper_params = dict()
56+
sweeper_params['quad_type'] = 'RADAU-RIGHT'
57+
sweeper_params['num_nodes'] = 3
58+
sweeper_params['QI'] = 'IE'
59+
sweeper_params['initial_guess'] = 'spread'
60+
61+
# initialize problem parameters
62+
problem_params = dict()
63+
problem_params['nvars'] = (128, 128)
64+
problem_params['spectral'] = False
65+
problem_params['comm'] = space_comm
66+
67+
if custom_problem_params is not None:
68+
problem_params = {**problem_params, **custom_problem_params}
69+
70+
# initialize step parameters
71+
step_params = dict()
72+
step_params['maxiter'] = 50
73+
74+
# initialize controller parameters
75+
controller_params = dict()
76+
controller_params['logger_level'] = 30 if rank == 0 else 99
77+
controller_params['hook_class'] = hook_collection + (hook_class if type(hook_class) == list else [hook_class])
78+
controller_params['mssdc_jac'] = False
79+
80+
# fill description dictionary for easy step instantiation
81+
if custom_controller_params is not None:
82+
controller_params = {**controller_params, **custom_controller_params}
83+
84+
description = dict()
85+
description['problem_params'] = problem_params
86+
description['problem_class'] = nonlinearschroedinger_imex
87+
description['sweeper_class'] = imex_1st_order
88+
description['sweeper_params'] = sweeper_params
89+
description['level_params'] = level_params
90+
description['step_params'] = step_params
91+
92+
if custom_description is not None:
93+
for k in custom_description.keys():
94+
if type(custom_description[k]) == dict:
95+
description[k] = {**description.get(k, {}), **custom_description.get(k, {})}
96+
else:
97+
description[k] = custom_description[k]
98+
99+
# set time parameters
100+
t0 = 0.0
101+
102+
# instantiate controller
103+
assert use_MPI == False, "MPI version in time not implemented"
104+
controller = controller_nonMPI(num_procs=num_procs, controller_params=controller_params, description=description)
105+
106+
# get initial values on finest level
107+
P = controller.MS[0].levels[0].prob
108+
uinit = P.u_exact(t0)
109+
110+
# insert faults
111+
if fault_stuff is not None:
112+
from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults
113+
114+
nvars = [me / 2 for me in problem_params['nvars']]
115+
nvars[0] += 1
116+
117+
rnd_args = {'iteration': 5, 'problem_pos': nvars, 'min_node': 1}
118+
args = {'time': 0.3, 'target': 0}
119+
prepare_controller_for_faults(controller, fault_stuff, rnd_args, args)
120+
121+
# call main function to get things done...
122+
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
123+
124+
return stats, controller, Tend
125+
126+
127+
def plot_solution(stats): # pragma: no cover
128+
import matplotlib.pyplot as plt
129+
130+
u = get_sorted(stats, type='u')
131+
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
132+
axs[0].imshow(np.abs(u[0][1]))
133+
axs[0].set_title(f't={u[0][0]:.2f}')
134+
for i in range(len(u)):
135+
axs[1].cla()
136+
axs[1].imshow(np.abs(u[i][1]))
137+
axs[1].set_title(f't={u[i][0]:.2f}')
138+
plt.pause(1e-1)
139+
fig.tight_layout()
140+
plt.show()
141+
142+
143+
def main():
144+
stats, _, _ = run_Schroedinger(space_comm=MPI.COMM_WORLD)
145+
plot_solution(stats)
146+
147+
148+
if __name__ == "__main__":
149+
main()

pySDC/projects/Resilience/accuracy_check.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def multiple_runs(
113113
if dt_list is not None:
114114
pass
115115
elif Tend_fixed:
116-
dt_list = 0.1 * 10.0 ** -(np.arange(5) / 2)
116+
dt_list = 0.1 * 10.0 ** -(np.arange(3) / 2)
117117
else:
118118
dt_list = 0.01 * 10.0 ** -(np.arange(20) / 10.0)
119119

@@ -377,7 +377,7 @@ def check_order_with_adaptivity():
377377
we expect is 1 + 1/k.
378378
"""
379379
setup_mpl()
380-
ks = [4, 3, 2]
380+
ks = [3, 2]
381381
for serial in [True, False]:
382382
fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))
383383
plot_all_errors(ax, ks, serial, Tend_fixed=5e-1, var='e_tol', dt_list=[1e-5, 1e-6, 1e-7], avoid_restarts=True)

0 commit comments

Comments
 (0)