|
| 1 | +import numpy as np |
| 2 | +from pathlib import Path |
| 3 | +from mpi4py import MPI |
| 4 | + |
| 5 | +from pySDC.helpers.stats_helper import get_sorted |
| 6 | + |
| 7 | +from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI |
| 8 | +from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order |
| 9 | +from pySDC.implementations.problem_classes.NonlinearSchroedinger_MPIFFT import nonlinearschroedinger_imex |
| 10 | +from pySDC.implementations.transfer_classes.TransferMesh_MPIFFT import fft_to_fft |
| 11 | +from pySDC.projects.Resilience.hook import LogData, hook_collection |
| 12 | + |
| 13 | + |
| 14 | +def run_Schroedinger( |
| 15 | + custom_description=None, |
| 16 | + num_procs=1, |
| 17 | + Tend=1.0, |
| 18 | + hook_class=LogData, |
| 19 | + fault_stuff=None, |
| 20 | + custom_controller_params=None, |
| 21 | + custom_problem_params=None, |
| 22 | + use_MPI=False, |
| 23 | + space_comm=None, |
| 24 | + **kwargs, |
| 25 | +): |
| 26 | + """ |
| 27 | + Run a Schroedinger problem with default parameters. |
| 28 | +
|
| 29 | + Args: |
| 30 | + custom_description (dict): Overwrite presets |
| 31 | + num_procs (int): Number of steps for MSSDC |
| 32 | + Tend (float): Time to integrate to |
| 33 | + hook_class (pySDC.Hook): A hook to store data |
| 34 | + fault_stuff (dict): A dictionary with information on how to add faults |
| 35 | + custom_controller_params (dict): Overwrite presets |
| 36 | + custom_problem_params (dict): Overwrite presets |
| 37 | + use_MPI (bool): Whether or not to use MPI |
| 38 | +
|
| 39 | + Returns: |
| 40 | + dict: The stats object |
| 41 | + controller: The controller |
| 42 | + Tend: The time that was supposed to be integrated to |
| 43 | + """ |
| 44 | + |
| 45 | + space_comm = MPI.COMM_WORLD if space_comm is None else space_comm |
| 46 | + rank = space_comm.Get_rank() |
| 47 | + |
| 48 | + # initialize level parameters |
| 49 | + level_params = dict() |
| 50 | + level_params['restol'] = 1e-08 |
| 51 | + level_params['dt'] = 1e-01 / 2 |
| 52 | + level_params['nsweeps'] = 1 |
| 53 | + |
| 54 | + # initialize sweeper parameters |
| 55 | + sweeper_params = dict() |
| 56 | + sweeper_params['quad_type'] = 'RADAU-RIGHT' |
| 57 | + sweeper_params['num_nodes'] = 3 |
| 58 | + sweeper_params['QI'] = 'IE' |
| 59 | + sweeper_params['initial_guess'] = 'spread' |
| 60 | + |
| 61 | + # initialize problem parameters |
| 62 | + problem_params = dict() |
| 63 | + problem_params['nvars'] = (128, 128) |
| 64 | + problem_params['spectral'] = False |
| 65 | + problem_params['comm'] = space_comm |
| 66 | + |
| 67 | + if custom_problem_params is not None: |
| 68 | + problem_params = {**problem_params, **custom_problem_params} |
| 69 | + |
| 70 | + # initialize step parameters |
| 71 | + step_params = dict() |
| 72 | + step_params['maxiter'] = 50 |
| 73 | + |
| 74 | + # initialize controller parameters |
| 75 | + controller_params = dict() |
| 76 | + controller_params['logger_level'] = 30 if rank == 0 else 99 |
| 77 | + controller_params['hook_class'] = hook_collection + (hook_class if type(hook_class) == list else [hook_class]) |
| 78 | + controller_params['mssdc_jac'] = False |
| 79 | + |
| 80 | + # fill description dictionary for easy step instantiation |
| 81 | + if custom_controller_params is not None: |
| 82 | + controller_params = {**controller_params, **custom_controller_params} |
| 83 | + |
| 84 | + description = dict() |
| 85 | + description['problem_params'] = problem_params |
| 86 | + description['problem_class'] = nonlinearschroedinger_imex |
| 87 | + description['sweeper_class'] = imex_1st_order |
| 88 | + description['sweeper_params'] = sweeper_params |
| 89 | + description['level_params'] = level_params |
| 90 | + description['step_params'] = step_params |
| 91 | + |
| 92 | + if custom_description is not None: |
| 93 | + for k in custom_description.keys(): |
| 94 | + if type(custom_description[k]) == dict: |
| 95 | + description[k] = {**description.get(k, {}), **custom_description.get(k, {})} |
| 96 | + else: |
| 97 | + description[k] = custom_description[k] |
| 98 | + |
| 99 | + # set time parameters |
| 100 | + t0 = 0.0 |
| 101 | + |
| 102 | + # instantiate controller |
| 103 | + assert use_MPI == False, "MPI version in time not implemented" |
| 104 | + controller = controller_nonMPI(num_procs=num_procs, controller_params=controller_params, description=description) |
| 105 | + |
| 106 | + # get initial values on finest level |
| 107 | + P = controller.MS[0].levels[0].prob |
| 108 | + uinit = P.u_exact(t0) |
| 109 | + |
| 110 | + # insert faults |
| 111 | + if fault_stuff is not None: |
| 112 | + from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults |
| 113 | + |
| 114 | + nvars = [me / 2 for me in problem_params['nvars']] |
| 115 | + nvars[0] += 1 |
| 116 | + |
| 117 | + rnd_args = {'iteration': 5, 'problem_pos': nvars, 'min_node': 1} |
| 118 | + args = {'time': 0.3, 'target': 0} |
| 119 | + prepare_controller_for_faults(controller, fault_stuff, rnd_args, args) |
| 120 | + |
| 121 | + # call main function to get things done... |
| 122 | + uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) |
| 123 | + |
| 124 | + return stats, controller, Tend |
| 125 | + |
| 126 | + |
| 127 | +def plot_solution(stats): # pragma: no cover |
| 128 | + import matplotlib.pyplot as plt |
| 129 | + |
| 130 | + u = get_sorted(stats, type='u') |
| 131 | + fig, axs = plt.subplots(1, 2, figsize=(12, 5)) |
| 132 | + axs[0].imshow(np.abs(u[0][1])) |
| 133 | + axs[0].set_title(f't={u[0][0]:.2f}') |
| 134 | + for i in range(len(u)): |
| 135 | + axs[1].cla() |
| 136 | + axs[1].imshow(np.abs(u[i][1])) |
| 137 | + axs[1].set_title(f't={u[i][0]:.2f}') |
| 138 | + plt.pause(1e-1) |
| 139 | + fig.tight_layout() |
| 140 | + plt.show() |
| 141 | + |
| 142 | + |
| 143 | +def main(): |
| 144 | + stats, _, _ = run_Schroedinger(space_comm=MPI.COMM_WORLD) |
| 145 | + plot_solution(stats) |
| 146 | + |
| 147 | + |
| 148 | +if __name__ == "__main__": |
| 149 | + main() |
0 commit comments