From 9729318cd548c2b387dc467227074f366a3633e3 Mon Sep 17 00:00:00 2001 From: MinZhu123 <84722601+MinZhu123@users.noreply.github.com> Date: Fri, 22 Jul 2022 22:47:26 -0600 Subject: [PATCH] Backend TensorFlow 1: DeepONet customized branches (#807) --- deepxde/nn/tensorflow_compat_v1/mionet.py | 40 ++++++++++++++++------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/deepxde/nn/tensorflow_compat_v1/mionet.py b/deepxde/nn/tensorflow_compat_v1/mionet.py index 26fd73d6a..ccca404f9 100644 --- a/deepxde/nn/tensorflow_compat_v1/mionet.py +++ b/deepxde/nn/tensorflow_compat_v1/mionet.py @@ -58,13 +58,21 @@ def build(self): self._inputs = [self.X_func1, self.X_func2, self.X_loc] # Branch net 1 - y_func1 = self._net( - self.X_func1, self.layer_branch1[1:], self.activation_branch1 - ) + if callable(self.layer_branch1[1]): + # User-defined network + y_func1 = self.layer_branch1[1](self.X_func1) + else: + y_func1 = self._net( + self.X_func1, self.layer_branch1[1:], self.activation_branch1 + ) # Branch net 2 - y_func2 = self._net( - self.X_func2, self.layer_branch2[1:], self.activation_branch2 - ) + if callable(self.layer_branch2[1]): + # User-defined network + y_func2 = self.layer_branch2[1](self.X_func2) + else: + y_func2 = self._net( + self.X_func2, self.layer_branch2[1:], self.activation_branch2 + ) # Trunk net y_loc = self._net(self.X_loc, self.layer_trunk[1:], self.activation_trunk) @@ -103,13 +111,21 @@ def build(self): self._inputs = [self.X_func1, self.X_func2, self.X_loc] # Branch net 1 - y_func1 = self._net( - self.X_func1, self.layer_branch1[1:], self.activation_branch1 - ) + if callable(self.layer_branch1[1]): + # User-defined network + y_func1 = self.layer_branch1[1](self.X_func1) + else: + y_func1 = self._net( + self.X_func1, self.layer_branch1[1:], self.activation_branch1 + ) # Branch net 2 - y_func2 = self._net( - self.X_func2, self.layer_branch2[1:], self.activation_branch2 - ) + if callable(self.layer_branch2[1]): + # User-defined network + y_func2 = self.layer_branch2[1](self.X_func2) + else: + y_func2 = self._net( + self.X_func2, self.layer_branch2[1:], self.activation_branch2 + ) # Trunk net y_loc = self._net(self.X_loc, self.layer_trunk[1:], self.activation_trunk)