Skip to content

Commit 7ce080f

Browse files
authored
Generic DeepONet (#68)
* generic deeponet * example for generic deeponet * adapt tests to new interface
1 parent e227700 commit 7ce080f

File tree

5 files changed

+282
-39
lines changed

5 files changed

+282
-39
lines changed

examples/run_poisson_deeponet.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import argparse
2+
import logging
3+
4+
import torch
5+
from problems.poisson import Poisson
6+
7+
from pina import PINN, LabelTensor, Plotter
8+
from pina.model.deeponet import DeepONet, check_combos, spawn_combo_networks
9+
10+
logging.basicConfig(
11+
filename="poisson_deeponet.log", filemode="w", level=logging.INFO
12+
)
13+
14+
15+
class SinFeature(torch.nn.Module):
16+
"""
17+
Feature: sin(x)
18+
"""
19+
20+
def __init__(self, label):
21+
super().__init__()
22+
23+
if not isinstance(label, (tuple, list)):
24+
label = [label]
25+
self._label = label
26+
27+
def forward(self, x):
28+
"""
29+
Defines the computation performed at every call.
30+
31+
:param LabelTensor x: the input tensor.
32+
:return: the output computed by the model.
33+
:rtype: LabelTensor
34+
"""
35+
t = torch.sin(x.extract(self._label) * torch.pi)
36+
return LabelTensor(t, [f"sin({self._label})"])
37+
38+
39+
def prepare_deeponet_model(args, problem, extra_feature_combo_func=None):
40+
combos = tuple(map(lambda combo: combo.split("-"), args.combos.split(",")))
41+
check_combos(combos, problem.input_variables)
42+
43+
extra_feature = extra_feature_combo_func if args.extra else None
44+
networks = spawn_combo_networks(
45+
combos=combos,
46+
layers=list(map(int, args.layers.split(","))) if args.layers else [],
47+
output_dimension=args.hidden * len(problem.output_variables),
48+
func=torch.nn.Softplus,
49+
extra_feature=extra_feature,
50+
bias=not args.nobias,
51+
)
52+
53+
return DeepONet(
54+
networks,
55+
problem.output_variables,
56+
aggregator=args.aggregator,
57+
reduction=args.reduction,
58+
)
59+
60+
61+
if __name__ == "__main__":
62+
parser = argparse.ArgumentParser(description="Run PINA")
63+
parser.add_argument("-s", "--save", action="store_true")
64+
parser.add_argument("-l", "--load", action="store_true")
65+
parser.add_argument("id_run", help="Run ID", type=int)
66+
67+
parser.add_argument("--extra", help="Extra features", action="store_true")
68+
parser.add_argument("--nobias", action="store_true")
69+
parser.add_argument(
70+
"--combos",
71+
help="DeepONet internal network combinations",
72+
type=str,
73+
required=True,
74+
)
75+
parser.add_argument(
76+
"--aggregator", help="Aggregator for DeepONet", type=str, default="*"
77+
)
78+
parser.add_argument(
79+
"--reduction", help="Reduction for DeepONet", type=str, default="+"
80+
)
81+
parser.add_argument(
82+
"--hidden",
83+
help="Number of variables in the hidden DeepONet layer",
84+
type=int,
85+
required=True,
86+
)
87+
parser.add_argument(
88+
"--layers",
89+
help="Structure of the DeepONet partial layers",
90+
type=str,
91+
required=True,
92+
)
93+
cli_args = parser.parse_args()
94+
95+
poisson_problem = Poisson()
96+
97+
model = prepare_deeponet_model(
98+
cli_args,
99+
poisson_problem,
100+
extra_feature_combo_func=lambda combo: [SinFeature(combo)],
101+
)
102+
pinn = PINN(poisson_problem, model, lr=0.01, regularizer=1e-8)
103+
if cli_args.save:
104+
pinn.span_pts(
105+
20, "grid", locations=["gamma1", "gamma2", "gamma3", "gamma4"]
106+
)
107+
pinn.span_pts(20, "grid", locations=["D"])
108+
pinn.train(1.0e-10, 100)
109+
pinn.save_state(f"pina.poisson_{cli_args.id_run}")
110+
if cli_args.load:
111+
pinn.load_state(f"pina.poisson_{cli_args.id_run}")
112+
plotter = Plotter()
113+
plotter.plot(pinn)

pina/model/deeponet.py

Lines changed: 142 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,60 @@
11
"""Module for DeepONet model"""
2+
import logging
3+
from functools import partial, reduce
4+
25
import torch
36
import torch.nn as nn
47

58
from pina import LabelTensor
9+
from pina.model import FeedForward
10+
from pina.utils import is_function
11+
12+
13+
def check_combos(combos, variables):
14+
"""
15+
Check that the given combinations are subsets (overlapping
16+
is allowed) of the given set of variables.
17+
18+
:param iterable(iterable(str)) combos: Combinations of variables.
19+
:param iterable(str) variables: Variables.
20+
"""
21+
for combo in combos:
22+
for variable in combo:
23+
if variable not in variables:
24+
raise ValueError(
25+
f"Combinations should be (overlapping) subsets of input variables, {variable} is not an input variable"
26+
)
27+
28+
29+
def spawn_combo_networks(
30+
combos, layers, output_dimension, func, extra_feature, bias=True
31+
):
32+
"""
33+
Spawn internal networks for DeepONet based on the given combos.
34+
35+
:param iterable(iterable(str)) combos: Combinations of variables.
36+
:param iterable(int) layers: Size of hidden layers.
37+
:param int output_dimension: Size of the output layer of the networks.
38+
:param func: Nonlinearity.
39+
:param extra_feature: Extra feature to be considered by the networks.
40+
:param bool bias: Whether to consider bias or not.
41+
"""
42+
if is_function(extra_feature):
43+
extra_feature_func = lambda _: extra_feature
44+
else:
45+
extra_feature_func = extra_feature
46+
47+
return [
48+
FeedForward(
49+
layers=layers,
50+
input_variables=tuple(combo),
51+
output_variables=output_dimension,
52+
func=func,
53+
extra_features=extra_feature_func(combo),
54+
bias=bias,
55+
)
56+
for combo in combos
57+
]
658

759

860
class DeepONet(torch.nn.Module):
@@ -18,23 +70,27 @@ class DeepONet(torch.nn.Module):
1870
<https://doi.org/10.1038/s42256-021-00302-5>`_
1971
2072
"""
21-
def __init__(self, branch_net, trunk_net, output_variables):
73+
74+
def __init__(self, nets, output_variables, aggregator="*", reduction="+"):
2275
"""
23-
:param torch.nn.Module branch_net: the neural network to use as branch
24-
model. It has to take as input a :class:`LabelTensor`. The number
25-
of dimension of the output has to be the same of the `trunk_net`.
26-
:param torch.nn.Module trunk_net: the neural network to use as trunk
27-
model. It has to take as input a :class:`LabelTensor`. The number
28-
of dimension of the output has to be the same of the `branch_net`.
76+
:param iterable(torch.nn.Module) nets: Internal DeepONet networks
77+
(branch and trunk in the original DeepONet).
2978
:param list(str) output_variables: the list containing the labels
3079
corresponding to the components of the output computed by the
3180
model.
81+
:param string | callable aggregator: Aggregator to be used to aggregate
82+
partial results from the modules in `nets`. Partial results are
83+
aggregated component-wise. See :func:`_symbol_functions` for the
84+
available default aggregators.
85+
:param string | callable reduction: Reduction to be used to reduce
86+
the aggregated result of the modules in `nets` to the desired output
87+
dimension. See :func:`_symbol_functions` for the available default
88+
reductions.
3289
3390
:Example:
3491
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
3592
>>> trunk = FFN(input_variables=['b'], output_variables=20)
36-
>>> onet = DeepONet(trunk_net=trunk, branch_net=branch
37-
>>> output_variables=output_vars)
93+
>>> onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
3894
DeepONet(
3995
(trunk_net): FeedForward(
4096
(extra_features): Sequential()
@@ -63,22 +119,76 @@ def __init__(self, branch_net, trunk_net, output_variables):
63119
self.output_variables = output_variables
64120
self.output_dimension = len(output_variables)
65121

66-
trunk_out_dim = trunk_net.layers[-1].out_features
67-
branch_out_dim = branch_net.layers[-1].out_features
68-
69-
if trunk_out_dim != branch_out_dim:
70-
raise ValueError('Branch and trunk networks have not the same '
71-
'output dimension.')
72-
73-
self.trunk_net = trunk_net
74-
self.branch_net = branch_net
122+
self._init_aggregator(aggregator, n_nets=len(nets))
123+
hidden_size = nets[0].model[-1].out_features
124+
self._init_reduction(reduction, hidden_size=hidden_size)
125+
126+
if not DeepONet._all_nets_same_output_layer_size(nets):
127+
raise ValueError("All networks should have the same output size")
128+
self._nets = torch.nn.ModuleList(nets)
129+
logging.info("Combo DeepONet children: %s", list(self.children()))
130+
131+
@staticmethod
132+
def _symbol_functions(**kwargs):
133+
return {
134+
"+": partial(torch.sum, **kwargs),
135+
"*": partial(torch.prod, **kwargs),
136+
"mean": partial(torch.mean, **kwargs),
137+
"min": lambda x: torch.min(x, **kwargs).values,
138+
"max": lambda x: torch.max(x, **kwargs).values,
139+
}
140+
141+
def _init_aggregator(self, aggregator, n_nets):
142+
aggregator_funcs = DeepONet._symbol_functions(dim=2)
143+
if aggregator in aggregator_funcs:
144+
aggregator_func = aggregator_funcs[aggregator]
145+
elif isinstance(aggregator, nn.Module) or is_function(aggregator):
146+
aggregator_func = aggregator
147+
elif aggregator == "linear":
148+
aggregator_func = nn.Linear(n_nets, len(self.output_variables))
149+
else:
150+
raise ValueError(f"Unsupported aggregation: {str(aggregator)}")
151+
152+
self._aggregator = aggregator_func
153+
logging.info("Selected aggregator: %s", str(aggregator_func))
154+
155+
# test the aggregator
156+
test = self._aggregator(torch.ones((20, 3, n_nets)))
157+
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
158+
raise ValueError(
159+
f"Invalid aggregator output shape: {(20, 3, n_nets)} -> {test.shape}"
160+
)
75161

76-
self.reduction = nn.Linear(trunk_out_dim, self.output_dimension)
162+
def _init_reduction(self, reduction, hidden_size):
163+
reduction_funcs = DeepONet._symbol_functions(dim=2)
164+
if reduction in reduction_funcs:
165+
reduction_func = reduction_funcs[reduction]
166+
elif isinstance(reduction, nn.Module) or is_function(reduction):
167+
reduction_func = reduction
168+
elif reduction == "linear":
169+
reduction_func = nn.Linear(hidden_size, len(self.output_variables))
170+
else:
171+
raise ValueError(f"Unsupported reduction: {reduction}")
172+
173+
self._reduction = reduction_func
174+
logging.info("Selected reduction: %s", str(reduction))
175+
176+
# test the reduction
177+
test = self._reduction(torch.ones((20, 3, hidden_size)))
178+
if test.ndim < 2 or tuple(test.shape)[:2] != (20, 3):
179+
msg = f"Invalid reduction output shape: {(20, 3, hidden_size)} -> {test.shape}"
180+
raise ValueError(msg)
181+
182+
@staticmethod
183+
def _all_nets_same_output_layer_size(nets):
184+
size = nets[0].layers[-1].out_features
185+
return all((net.layers[-1].out_features == size for net in nets[1:]))
77186

78187
@property
79188
def input_variables(self):
80189
"""The input variables of the model"""
81-
return self.trunk_net.input_variables + self.branch_net.input_variables
190+
nets_input_variables = map(lambda net: net.input_variables, self._nets)
191+
return reduce(sum, nets_input_variables)
82192

83193
def forward(self, x):
84194
"""
@@ -89,15 +199,20 @@ def forward(self, x):
89199
:rtype: LabelTensor
90200
"""
91201

92-
branch_output = self.branch_net(
93-
x.extract(self.branch_net.input_variables))
202+
nets_outputs = tuple(
203+
net(x.extract(net.input_variables)) for net in self._nets
204+
)
205+
# torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets)
206+
aggregated = self._aggregator(torch.dstack(nets_outputs))
207+
# net_output_size = output_variables * hidden_size
208+
aggregated_reshaped = aggregated.view(
209+
(len(x), len(self.output_variables), -1)
210+
)
211+
output_ = self._reduction(aggregated_reshaped)
212+
output_ = torch.squeeze(torch.atleast_3d(output_), dim=2)
94213

95-
trunk_output = self.trunk_net(
96-
x.extract(self.trunk_net.input_variables))
97-
98-
output_ = self.reduction(trunk_output * branch_output)
214+
assert output_.shape == (len(x), len(self.output_variables))
99215

100216
output_ = output_.as_subclass(LabelTensor)
101217
output_.labels = self.output_variables
102-
103218
return output_

pina/model/feed_forward.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ class FeedForward(torch.nn.Module):
2626
`inner_size` are not considered.
2727
:param iterable(torch.nn.Module) extra_features: the additional input
2828
features to use ad augmented input.
29+
:param bool bias: If `True` the MLP will consider some bias.
2930
"""
3031
def __init__(self, input_variables, output_variables, inner_size=20,
31-
n_layers=2, func=nn.Tanh, layers=None, extra_features=None):
32+
n_layers=2, func=nn.Tanh, layers=None, extra_features=None,
33+
bias=True):
3234
"""
3335
"""
3436
super().__init__()
@@ -62,7 +64,9 @@ def __init__(self, input_variables, output_variables, inner_size=20,
6264

6365
self.layers = []
6466
for i in range(len(tmp_layers)-1):
65-
self.layers.append(nn.Linear(tmp_layers[i], tmp_layers[i+1]))
67+
self.layers.append(
68+
nn.Linear(tmp_layers[i], tmp_layers[i + 1], bias=bias)
69+
)
6670

6771
if isinstance(func, list):
6872
self.functions = func
@@ -94,7 +98,7 @@ def forward(self, x):
9498
if self.input_variables:
9599
x = x.extract(self.input_variables)
96100

97-
for i, feature in enumerate(self.extra_features):
101+
for feature in self.extra_features:
98102
x = x.append(feature(x))
99103

100104
output = self.model(x).as_subclass(LabelTensor)

pina/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Utils module"""
22
from functools import reduce
3+
import types
4+
35
import torch
46
from torch.utils.data import DataLoader, default_collate, ConcatDataset
57

@@ -85,6 +87,17 @@ def torch_lhs(n, dim):
8587
return samples
8688

8789

90+
def is_function(f):
91+
"""
92+
Checks whether the given object `f` is a function or lambda.
93+
94+
:param object f: The object to be checked.
95+
:return: `True` if `f` is a function, `False` otherwise.
96+
:rtype: bool
97+
"""
98+
return type(f) == types.FunctionType or type(f) == types.LambdaType
99+
100+
88101
class PinaDataset():
89102

90103
def __init__(self, pinn) -> None:
@@ -144,4 +157,4 @@ def __getitem__(self, index):
144157
return {self._location: tensor}
145158

146159
def __len__(self):
147-
return self._len
160+
return self._len

0 commit comments

Comments
 (0)