1
1
"""Module for DeepONet model"""
2
+ import logging
3
+ from functools import partial , reduce
4
+
2
5
import torch
3
6
import torch .nn as nn
4
7
5
8
from pina import LabelTensor
9
+ from pina .model import FeedForward
10
+ from pina .utils import is_function
11
+
12
+
13
+ def check_combos (combos , variables ):
14
+ """
15
+ Check that the given combinations are subsets (overlapping
16
+ is allowed) of the given set of variables.
17
+
18
+ :param iterable(iterable(str)) combos: Combinations of variables.
19
+ :param iterable(str) variables: Variables.
20
+ """
21
+ for combo in combos :
22
+ for variable in combo :
23
+ if variable not in variables :
24
+ raise ValueError (
25
+ f"Combinations should be (overlapping) subsets of input variables, { variable } is not an input variable"
26
+ )
27
+
28
+
29
+ def spawn_combo_networks (
30
+ combos , layers , output_dimension , func , extra_feature , bias = True
31
+ ):
32
+ """
33
+ Spawn internal networks for DeepONet based on the given combos.
34
+
35
+ :param iterable(iterable(str)) combos: Combinations of variables.
36
+ :param iterable(int) layers: Size of hidden layers.
37
+ :param int output_dimension: Size of the output layer of the networks.
38
+ :param func: Nonlinearity.
39
+ :param extra_feature: Extra feature to be considered by the networks.
40
+ :param bool bias: Whether to consider bias or not.
41
+ """
42
+ if is_function (extra_feature ):
43
+ extra_feature_func = lambda _ : extra_feature
44
+ else :
45
+ extra_feature_func = extra_feature
46
+
47
+ return [
48
+ FeedForward (
49
+ layers = layers ,
50
+ input_variables = tuple (combo ),
51
+ output_variables = output_dimension ,
52
+ func = func ,
53
+ extra_features = extra_feature_func (combo ),
54
+ bias = bias ,
55
+ )
56
+ for combo in combos
57
+ ]
6
58
7
59
8
60
class DeepONet (torch .nn .Module ):
@@ -18,23 +70,27 @@ class DeepONet(torch.nn.Module):
18
70
<https://doi.org/10.1038/s42256-021-00302-5>`_
19
71
20
72
"""
21
- def __init__ (self , branch_net , trunk_net , output_variables ):
73
+
74
+ def __init__ (self , nets , output_variables , aggregator = "*" , reduction = "+" ):
22
75
"""
23
- :param torch.nn.Module branch_net: the neural network to use as branch
24
- model. It has to take as input a :class:`LabelTensor`. The number
25
- of dimension of the output has to be the same of the `trunk_net`.
26
- :param torch.nn.Module trunk_net: the neural network to use as trunk
27
- model. It has to take as input a :class:`LabelTensor`. The number
28
- of dimension of the output has to be the same of the `branch_net`.
76
+ :param iterable(torch.nn.Module) nets: Internal DeepONet networks
77
+ (branch and trunk in the original DeepONet).
29
78
:param list(str) output_variables: the list containing the labels
30
79
corresponding to the components of the output computed by the
31
80
model.
81
+ :param string | callable aggregator: Aggregator to be used to aggregate
82
+ partial results from the modules in `nets`. Partial results are
83
+ aggregated component-wise. See :func:`_symbol_functions` for the
84
+ available default aggregators.
85
+ :param string | callable reduction: Reduction to be used to reduce
86
+ the aggregated result of the modules in `nets` to the desired output
87
+ dimension. See :func:`_symbol_functions` for the available default
88
+ reductions.
32
89
33
90
:Example:
34
91
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
35
92
>>> trunk = FFN(input_variables=['b'], output_variables=20)
36
- >>> onet = DeepONet(trunk_net=trunk, branch_net=branch
37
- >>> output_variables=output_vars)
93
+ >>> onet = DeepONet(nets=[trunk, branch], output_variables=output_vars)
38
94
DeepONet(
39
95
(trunk_net): FeedForward(
40
96
(extra_features): Sequential()
@@ -63,22 +119,76 @@ def __init__(self, branch_net, trunk_net, output_variables):
63
119
self .output_variables = output_variables
64
120
self .output_dimension = len (output_variables )
65
121
66
- trunk_out_dim = trunk_net .layers [- 1 ].out_features
67
- branch_out_dim = branch_net .layers [- 1 ].out_features
68
-
69
- if trunk_out_dim != branch_out_dim :
70
- raise ValueError ('Branch and trunk networks have not the same '
71
- 'output dimension.' )
72
-
73
- self .trunk_net = trunk_net
74
- self .branch_net = branch_net
122
+ self ._init_aggregator (aggregator , n_nets = len (nets ))
123
+ hidden_size = nets [0 ].model [- 1 ].out_features
124
+ self ._init_reduction (reduction , hidden_size = hidden_size )
125
+
126
+ if not DeepONet ._all_nets_same_output_layer_size (nets ):
127
+ raise ValueError ("All networks should have the same output size" )
128
+ self ._nets = torch .nn .ModuleList (nets )
129
+ logging .info ("Combo DeepONet children: %s" , list (self .children ()))
130
+
131
+ @staticmethod
132
+ def _symbol_functions (** kwargs ):
133
+ return {
134
+ "+" : partial (torch .sum , ** kwargs ),
135
+ "*" : partial (torch .prod , ** kwargs ),
136
+ "mean" : partial (torch .mean , ** kwargs ),
137
+ "min" : lambda x : torch .min (x , ** kwargs ).values ,
138
+ "max" : lambda x : torch .max (x , ** kwargs ).values ,
139
+ }
140
+
141
+ def _init_aggregator (self , aggregator , n_nets ):
142
+ aggregator_funcs = DeepONet ._symbol_functions (dim = 2 )
143
+ if aggregator in aggregator_funcs :
144
+ aggregator_func = aggregator_funcs [aggregator ]
145
+ elif isinstance (aggregator , nn .Module ) or is_function (aggregator ):
146
+ aggregator_func = aggregator
147
+ elif aggregator == "linear" :
148
+ aggregator_func = nn .Linear (n_nets , len (self .output_variables ))
149
+ else :
150
+ raise ValueError (f"Unsupported aggregation: { str (aggregator )} " )
151
+
152
+ self ._aggregator = aggregator_func
153
+ logging .info ("Selected aggregator: %s" , str (aggregator_func ))
154
+
155
+ # test the aggregator
156
+ test = self ._aggregator (torch .ones ((20 , 3 , n_nets )))
157
+ if test .ndim < 2 or tuple (test .shape )[:2 ] != (20 , 3 ):
158
+ raise ValueError (
159
+ f"Invalid aggregator output shape: { (20 , 3 , n_nets )} -> { test .shape } "
160
+ )
75
161
76
- self .reduction = nn .Linear (trunk_out_dim , self .output_dimension )
162
+ def _init_reduction (self , reduction , hidden_size ):
163
+ reduction_funcs = DeepONet ._symbol_functions (dim = 2 )
164
+ if reduction in reduction_funcs :
165
+ reduction_func = reduction_funcs [reduction ]
166
+ elif isinstance (reduction , nn .Module ) or is_function (reduction ):
167
+ reduction_func = reduction
168
+ elif reduction == "linear" :
169
+ reduction_func = nn .Linear (hidden_size , len (self .output_variables ))
170
+ else :
171
+ raise ValueError (f"Unsupported reduction: { reduction } " )
172
+
173
+ self ._reduction = reduction_func
174
+ logging .info ("Selected reduction: %s" , str (reduction ))
175
+
176
+ # test the reduction
177
+ test = self ._reduction (torch .ones ((20 , 3 , hidden_size )))
178
+ if test .ndim < 2 or tuple (test .shape )[:2 ] != (20 , 3 ):
179
+ msg = f"Invalid reduction output shape: { (20 , 3 , hidden_size )} -> { test .shape } "
180
+ raise ValueError (msg )
181
+
182
+ @staticmethod
183
+ def _all_nets_same_output_layer_size (nets ):
184
+ size = nets [0 ].layers [- 1 ].out_features
185
+ return all ((net .layers [- 1 ].out_features == size for net in nets [1 :]))
77
186
78
187
@property
79
188
def input_variables (self ):
80
189
"""The input variables of the model"""
81
- return self .trunk_net .input_variables + self .branch_net .input_variables
190
+ nets_input_variables = map (lambda net : net .input_variables , self ._nets )
191
+ return reduce (sum , nets_input_variables )
82
192
83
193
def forward (self , x ):
84
194
"""
@@ -89,15 +199,20 @@ def forward(self, x):
89
199
:rtype: LabelTensor
90
200
"""
91
201
92
- branch_output = self .branch_net (
93
- x .extract (self .branch_net .input_variables ))
202
+ nets_outputs = tuple (
203
+ net (x .extract (net .input_variables )) for net in self ._nets
204
+ )
205
+ # torch.dstack(nets_outputs): (batch_size, net_output_size, n_nets)
206
+ aggregated = self ._aggregator (torch .dstack (nets_outputs ))
207
+ # net_output_size = output_variables * hidden_size
208
+ aggregated_reshaped = aggregated .view (
209
+ (len (x ), len (self .output_variables ), - 1 )
210
+ )
211
+ output_ = self ._reduction (aggregated_reshaped )
212
+ output_ = torch .squeeze (torch .atleast_3d (output_ ), dim = 2 )
94
213
95
- trunk_output = self .trunk_net (
96
- x .extract (self .trunk_net .input_variables ))
97
-
98
- output_ = self .reduction (trunk_output * branch_output )
214
+ assert output_ .shape == (len (x ), len (self .output_variables ))
99
215
100
216
output_ = output_ .as_subclass (LabelTensor )
101
217
output_ .labels = self .output_variables
102
-
103
218
return output_
0 commit comments