Skip to content

Commit b6adcff

Browse files
authored
Issue #95 fixed and improved plot_loss (#106)
* plotter.py plot_loss fixed Issue #95 * Improved the plot_loss to show title and labels.
1 parent 67fb7fe commit b6adcff

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

pina/plotter.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,24 +186,37 @@ def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
186186
else:
187187
plt.show()
188188

189-
def plot_loss(self, pinn, label=None, log_scale=True):
189+
def plot_loss(self, pinn, label=None, log_scale=True, filename=None):
190190
"""
191191
Plot the loss function values during traininig.
192192
193193
:param PINN pinn: the PINN object.
194194
:param str label: the label to use in the legend, defaults to None.
195195
:param bool log_scale: If True, the y axis is in log scale. Default is
196196
True.
197+
:param str filename: the file name to save the plot. If None, the plot
198+
is not saved. Default is None.
197199
"""
198200

199201
if not label:
200202
label = str(pinn)
201203

202204
epochs = list(pinn.history_loss.keys())
203205
loss = np.array(list(pinn.history_loss.values()))
206+
207+
# if multiple outputs, sum the loss
204208
if loss.ndim != 1:
205-
loss = loss[:, 0]
209+
loss = np.sum(loss, axis=1)
206210

211+
# plot loss
207212
plt.plot(epochs, loss, label=label)
213+
plt.legend()
208214
if log_scale:
209215
plt.yscale('log')
216+
plt.title('Loss function')
217+
plt.xlabel('Epochs')
218+
plt.ylabel('Loss')
219+
220+
# save plot
221+
if filename:
222+
plt.savefig(filename)

0 commit comments

Comments
 (0)