jinnax.class_csf

   1from functools import partial
   2
   3import jax
   4import jax.numpy as jnp
   5from jax import lax, jit, grad, vmap, jacrev, hessian
   6from jax.tree_util import tree_map
   7
   8import optax
   9
  10from jaxpi import archs
  11from jaxpi.models import ForwardIVP
  12from jaxpi.evaluator import BaseEvaluator
  13from jaxpi.utils import ntk_fn
  14
  15class DN_csf(ForwardIVP):
  16    def __init__(self, config):
  17        super().__init__(config)
  18
  19        #Initial condition function
  20        self.uinitial = config.uinitial
  21
  22        #Boundary points
  23        self.xl = config.xl
  24        self.xu = config.xu
  25        self.tu = config.tu
  26
  27        #Radius left dirichlet condition
  28        self.radius = config.radius
  29
  30        #Right dirichlet point
  31        self.rd = config.rd
  32
  33        # Predictions over array of x fot t fixed
  34        self.u1_0_pred_fn = vmap(
  35            vmap(self.u1_net, (None, None, 0)), (None, None, 0)
  36        )
  37        self.u2_0_pred_fn = vmap(
  38            vmap(self.u2_net, (None, None, 0)), (None, None, 0)
  39        )
  40
  41        #Prediction over array of t for x fixed
  42        self.u2_bound_pred_fn = vmap(
  43            vmap(self.u2_net, (None, 0, None)), (None, 0, None)
  44        )
  45
  46        self.u1_bound_pred_fn = vmap(
  47            vmap(self.u1_net, (None, 0, None)), (None, 0, None)
  48        )
  49
  50        #Vmap neural net
  51        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
  52
  53        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
  54
  55        #Vmap residual operator
  56        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))
  57
  58        #Derivatives on x for x fixed and t in a array
  59        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None))
  60        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
  61
  62    #Neural net forward function
  63    def neural_net(self, params, t, x):
  64        t = t / self.tu
  65        z = jnp.stack([t, x])
  66        _, outputs = self.state.apply_fn(params, z)
  67        u1 = outputs[0]
  68        u2 = outputs[1]
  69        return u1, u2
  70
  71    #1st coordinate neural net forward function
  72    def u1_net(self, params, t, x):
  73        u1, _ = self.neural_net(params, t, x)
  74        return u1
  75
  76    #2st coordinate neural net forward function
  77    def u2_net(self, params, t, x):
  78        _, u2 = self.neural_net(params, t, x)
  79        return u2
  80
  81    #Residual operator
  82    def r_net(self, params, t, x):
  83        #Derivatives in x and t
  84        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
  85        u1_t = grad(self.u1_net, argnums = 1)(params, t, x)
  86        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
  87        u2_t = grad(self.u2_net, argnums = 1)(params, t, x)
  88
  89        #Two derivatives in x
  90        u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x)
  91        u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x)
  92
  93        #Each coordinate of residual operator
  94        return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
  95
  96    #1st coordinate residual operator
  97    def r_net1(self, params, t, x):
  98        r1,_ = self.r_net(params, t, x)
  99        return r1
 100
 101    #2nd coordinate residual operator
 102    def r_net2(self, params, t, x):
 103        _,r2 = self.r_net(params, t, x)
 104        return r2
 105
 106    #Right Dirichlet Condition residual
 107    def res_right_dirichlet(self, params, t, x):
 108        #Apply neural net to point in the boundary
 109        u1 = self.u1_bound_pred_fn(params, t, x)
 110        u2 = self.u2_bound_pred_fn(params, t, x)
 111        #Return residual
 112        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
 113
 114    #Right Dirichlet Condition residual non-vectorised
 115    def res_right_dirichlet_nv(self, params, t, x):
 116        #Apply neural net to point in the boundary
 117        u1 = self.u1_net(params, t, x)
 118        u2 = self.u2_net(params, t, x)
 119        #Return residual
 120        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
 121
 122    #Left Dirichlet Condition residual
 123    def res_left_dirichlet(self, params, t, x):
 124        #Apply neural net to point in the boundary
 125        u1 = self.u1_bound_pred_fn(params, t, x)
 126        u2 = self.u2_bound_pred_fn(params, t, x)
 127        #Return residual
 128        return (jnp.sqrt(u1  ** 2 + u2 ** 2) - self.radius) ** 2
 129
 130    #Left Dirichlet Condition residual non-vectorised
 131    def res_left_dirichlet_nv(self, params, t, x):
 132        #Apply neural net to point in the boundary
 133        u1 = self.u1_net(params, t, x)
 134        u2 = self.u2_net(params, t, x)
 135        #Return residual
 136        return (jnp.sqrt(u1  ** 2 + u2 ** 2) - self.radius) ** 2
 137
 138    #Neumman condition residual non-vectorised
 139    def res_neumann_nv(self, params, t, x):
 140        #Apply neural net
 141        u1 = self.u1_net(params, t, x)
 142        u2 = self.u2_net(params, t, x)
 143
 144        #Derivatives in x
 145        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
 146        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
 147
 148        #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t)
 149        nS = jnp.append(u1,u2)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2) ** 2)) + 1e-5)
 150
 151        #Normal at u(x,y)
 152        nu = jnp.append(u2_x,(-1)*u1_x)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
 153
 154        #Return inner product
 155        return jnp.sum(nS * nu) ** 2
 156
 157    #Neumman condition residual
 158    def res_neumann(self, params, t, x):
 159        #Apply neural net to points in the boundary
 160        u1 = self.u1_bound_pred_fn(params, t, x)
 161        u2 = self.u2_bound_pred_fn(params, t, x)
 162
 163        #Derivatives in x
 164        u1_x = self.u1_bound_x(params, t, x)
 165        u2_x = self.u2_bound_x(params, t, x)
 166
 167        #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t)
 168        nS = jnp.append(u1,u2,1)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2,1) ** 2,1)).reshape(u1.shape[0],1) + 1e-5)
 169
 170        #Normal at u(x,y)
 171        nu = jnp.append(u2_x,(-1)*u1_x,1)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
 172
 173        #Return inner product
 174        return jnp.sum(nS * nu,1).reshape(u1.shape[0],1) ** 2
 175
 176    #Compute residuals with causal weights
 177    @partial(jit, static_argnums=(0,))
 178    def res_causal(self, params, batch):
 179        # Sort temporal coordinates
 180        t_sorted = batch[:, 0].sort()
 181
 182        #Compute residuals
 183        res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])
 184
 185        #Reshape
 186        res_pred1 = res_pred1.reshape(self.num_chunks, -1)
 187        res_pred2 = res_pred2.reshape(self.num_chunks, -1)
 188
 189        #Compute mean residuals
 190        res_l1 = jnp.mean(res_pred1 ** 2, axis=1)
 191        res_l2 = jnp.mean(res_pred2 ** 2, axis=1)
 192
 193        #Compute weights
 194        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
 195        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
 196
 197        # Take minimum of the causal weights
 198        gamma = jnp.vstack([res_gamma1,res_gamma2])
 199        gamma = gamma.min(0)
 200
 201        return res_l1, res_l2, gamma
 202
 203    #Compute losses
 204    @partial(jit, static_argnums=(0,))
 205    def losses(self, params, batch):
 206        # Initial conditions loss
 207        u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 208        u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 209        u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1)))
 210
 211        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
 212        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)
 213
 214        # Residual loss
 215        if self.config.weighting.use_causal == True:
 216            res_l1, res_l2, gamma = self.res_causal(params, batch)
 217            res_loss1 = jnp.mean(res_l1 * gamma)
 218            res_loss2 = jnp.mean(res_l2 * gamma)
 219        else:
 220            res_pred1,res_pred2 = self.r_pred_fn(
 221                params, batch[:, 0], batch[:, 1]
 222            )
 223            # Compute loss
 224            res_loss1 = jnp.mean(res_pred1 ** 2)
 225            res_loss2 = jnp.mean(res_pred2 ** 2)
 226
 227        loss_dict = {
 228            "ic": u1_ic_loss + u2_ic_loss,
 229            "res1": res_loss1,
 230            "res2": res_loss2,
 231            'rd': jnp.mean(self.res_right_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)),
 232            'ld': jnp.mean(self.res_left_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)),
 233            'ln': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl))
 234        }
 235        return loss_dict
 236
 237    #Compute NTK
 238    @partial(jit, static_argnums = (0,))
 239    def compute_diag_ntk(self, params, batch):
 240        #Initial Condition
 241        u1_ic_ntk = vmap(
 242            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
 243        )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 244
 245        u2_ic_ntk = vmap(
 246            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
 247        )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 248
 249        #Right Dirichlet
 250        rd_ntk = vmap(
 251            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 252        )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
 253
 254        #Left Dirichlet
 255        ld_ntk = vmap(
 256            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 257        )(self.res_left_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
 258
 259        #Left neumann
 260        ln_ntk = vmap(
 261            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 262        )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
 263
 264        # Consider the effect of causal weights
 265        if self.config.weighting.use_causal:
 266            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
 267            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
 268                self.r_net1, params, batch[:, 0], batch[:, 1]
 269            )
 270
 271            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
 272                self.r_net2, params, batch[:, 0], batch[:, 1]
 273            )
 274
 275            res_ntk1 = res_ntk1.reshape(self.num_chunks, -1)
 276            res_ntk2 = res_ntk2.reshape(self.num_chunks, -1)
 277
 278            res_ntk1 = jnp.mean(res_ntk1, axis=1)
 279            res_ntk2 = jnp.mean(res_ntk2, axis=1)
 280
 281            _,_, casual_weights = self.res_causal(params, batch)
 282            res_ntk1 = res_ntk1 * casual_weights
 283            res_ntk2 = res_ntk2 * casual_weights
 284        else:
 285            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
 286                self.r_net1, params, batch[:, 0], batch[:, 1]
 287            )
 288            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
 289                self.r_net2, params, batch[:, 0], batch[:, 1]
 290            )
 291
 292        ntk_dict = {
 293            "ic": u1_ic_ntk + u2_ic_ntk,
 294            "res1": res_ntk1,
 295            "res2": res_ntk2,
 296            'rd': rd_ntk,
 297            'ld': ld_ntk,
 298            'ln': ln_ntk
 299        }
 300        return ntk_dict
 301
 302class DN_csf_Evaluator(BaseEvaluator):
 303    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
 304        super().__init__(config, model)
 305
 306        self.x0_test = x0_test
 307        self.tb_test = tb_test
 308        self.xc_test = xc_test
 309        self.tc_test = tc_test
 310        self.u2_0_test = u2_0_test
 311        self.u1_0_test = u1_0_test
 312
 313    def log_errors(self, params):
 314        u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test)
 315        u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test)
 316
 317        u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2)
 318        u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2)
 319
 320        res_pred1,res_pred2 = self.model.r_pred_fn(
 321            params, self.tc_test[:,0], self.xc_test[:,0]
 322        )
 323        res_loss1 = jnp.mean(res_pred1 ** 2)
 324        res_loss2 = jnp.mean(res_pred2 ** 2)
 325
 326        self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)))
 327        self.log_dict["res1_test"] = res_loss1
 328        self.log_dict["res2_test"] = res_loss2
 329        self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu))
 330        self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl))
 331        self.log_dict["ln_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xl))
 332
 333    def __call__(self, state, batch):
 334        self.log_dict = super().__call__(state, batch)
 335
 336        if self.config.logging.log_errors:
 337            self.log_errors(state.params)
 338
 339        if self.config.weighting.use_causal:
 340            _, _, causal_weight = self.model.res_causal(state.params, batch)
 341            self.log_dict["cas_weight"] = causal_weight.min()
 342
 343        return self.log_dict
 344
 345class NN_csf(ForwardIVP):
 346    def __init__(self, config):
 347        super().__init__(config)
 348
 349        #Initial condition function
 350        self.uinitial = config.uinitial
 351
 352        #Boundary points
 353        self.xl = config.xl
 354        self.xu = config.xu
 355        self.tu = config.tu
 356
 357        #Radius left dirichlet condition
 358        self.radius = config.radius
 359
 360        # Predictions over array of x fot t fixed
 361        self.u1_0_pred_fn = vmap(
 362            vmap(self.u1_net, (None, None, 0)), (None, None, 0)
 363        )
 364        self.u2_0_pred_fn = vmap(
 365            vmap(self.u2_net, (None, None, 0)), (None, None, 0)
 366        )
 367
 368        #Prediction over array of t for x fixed
 369        self.u2_bound_pred_fn = vmap(
 370            vmap(self.u2_net, (None, 0, None)), (None, 0, None)
 371        )
 372
 373        self.u1_bound_pred_fn = vmap(
 374            vmap(self.u1_net, (None, 0, None)), (None, 0, None)
 375        )
 376
 377        #Vmap neural net
 378        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
 379
 380        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
 381
 382        #Vmap residual operator
 383        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))
 384
 385        #Derivatives on x for x fixed and t in a array
 386        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None))
 387        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
 388
 389    #Neural net forward function
 390    def neural_net(self, params, t, x):
 391        t = t / self.tu
 392        z = jnp.stack([t, x])
 393        _, outputs = self.state.apply_fn(params, z)
 394        u1 = outputs[0]
 395        u2 = outputs[1]
 396        return u1, u2
 397
 398    #1st coordinate neural net forward function
 399    def u1_net(self, params, t, x):
 400        u1, _ = self.neural_net(params, t, x)
 401        return u1
 402
 403    #2st coordinate neural net forward function
 404    def u2_net(self, params, t, x):
 405        _, u2 = self.neural_net(params, t, x)
 406        return u2
 407
 408    #Residual operator
 409    def r_net(self, params, t, x):
 410        #Derivatives in x and t
 411        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
 412        u1_t = grad(self.u1_net, argnums = 1)(params, t, x)
 413        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
 414        u2_t = grad(self.u2_net, argnums = 1)(params, t, x)
 415
 416        #Two derivatives in x
 417        u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x)
 418        u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x)
 419
 420        #Each coordinate of residual operator
 421        return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
 422
 423    #1st coordinate residual operator
 424    def r_net1(self, params, t, x):
 425        r1,_ = self.r_net(params, t, x)
 426        return r1
 427
 428    #2nd coordinate residual operator
 429    def r_net2(self, params, t, x):
 430        _,r2 = self.r_net(params, t, x)
 431        return r2
 432
 433    #Right Dirichlet Condition residual
 434    def res_dirichlet(self, params, t, x):
 435        #Apply neural net to point in the boundary
 436        u1 = self.u1_bound_pred_fn(params, t, x)
 437        u2 = self.u2_bound_pred_fn(params, t, x)
 438        #Return residual
 439        return (jnp.sqrt(u1  ** 2 + u2 ** 2) - self.radius) ** 2
 440
 441    #Right Dirichlet Condition residual non-vectorised
 442    def res_dirichlet_nv(self, params, t, x):
 443        #Apply neural net to point in the boundary
 444        u1 = self.u1_net(params, t, x)
 445        u2 = self.u2_net(params, t, x)
 446        #Return residual
 447        return (jnp.sqrt(u1  ** 2 + u2 ** 2) - self.radius) ** 2
 448
 449    #Neumman condition residual non-vectorised
 450    def res_neumann_nv(self, params, t, x):
 451        #Apply neural net
 452        u1 = self.u1_net(params, t, x)
 453        u2 = self.u2_net(params, t, x)
 454
 455        #Derivatives in x
 456        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
 457        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
 458
 459        #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t)
 460        nS = jnp.append(u1,u2)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2) ** 2)) + 1e-5)
 461
 462        #Normal at u(x,y)
 463        nu = jnp.append(u2_x,(-1)*u1_x)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
 464
 465        #Return inner product
 466        return jnp.sum(nS * nu) ** 2
 467
 468    #Neumman condition residual
 469    def res_neumann(self, params, t, x):
 470        #Apply neural net to points in the boundary
 471        u1 = self.u1_bound_pred_fn(params, t, x)
 472        u2 = self.u2_bound_pred_fn(params, t, x)
 473
 474        #Derivatives in x
 475        u1_x = self.u1_bound_x(params, t, x)
 476        u2_x = self.u2_bound_x(params, t, x)
 477
 478        #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t)
 479        nS = jnp.append(u1,u2,1)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2,1) ** 2,1)).reshape(u1.shape[0],1) + 1e-5)
 480
 481        #Normal at u(x,y)
 482        nu = jnp.append(u2_x,(-1)*u1_x,1)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
 483
 484        #Return inner product
 485        return jnp.sum(nS * nu,1).reshape(u1.shape[0],1) ** 2
 486
 487    #Compute residuals with causal weights
 488    @partial(jit, static_argnums=(0,))
 489    def res_causal(self, params, batch):
 490        # Sort temporal coordinates
 491        t_sorted = batch[:, 0].sort()
 492
 493        #Compute residuals
 494        res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])
 495
 496        #Reshape
 497        res_pred1 = res_pred1.reshape(self.num_chunks, -1)
 498        res_pred2 = res_pred2.reshape(self.num_chunks, -1)
 499
 500        #Compute mean residuals
 501        res_l1 = jnp.mean(res_pred1 ** 2, axis=1)
 502        res_l2 = jnp.mean(res_pred2 ** 2, axis=1)
 503
 504        #Compute weights
 505        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
 506        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
 507
 508        # Take minimum of the causal weights
 509        gamma = jnp.vstack([res_gamma1,res_gamma2])
 510        gamma = gamma.min(0)
 511
 512        return res_l1, res_l2, gamma
 513
 514    #Compute losses
 515    @partial(jit, static_argnums=(0,))
 516    def losses(self, params, batch):
 517        # Initial conditions loss
 518        u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 519        u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 520        u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1)))
 521
 522        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
 523        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)
 524
 525        # Residual loss
 526        if self.config.weighting.use_causal == True:
 527            res_l1, res_l2, gamma = self.res_causal(params, batch)
 528            res_loss1 = jnp.mean(res_l1 * gamma)
 529            res_loss2 = jnp.mean(res_l2 * gamma)
 530        else:
 531            res_pred1,res_pred2 = self.r_pred_fn(
 532                params, batch[:, 0], batch[:, 1]
 533            )
 534            # Compute loss
 535            res_loss1 = jnp.mean(res_pred1 ** 2)
 536            res_loss2 = jnp.mean(res_pred2 ** 2)
 537
 538        loss_dict = {
 539            "ic": u1_ic_loss + u2_ic_loss,
 540            "res1": res_loss1,
 541            "res2": res_loss2,
 542            'rd': jnp.mean(self.res_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)),
 543            'ld': jnp.mean(self.res_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)),
 544            'ln': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)),
 545            'rn': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu))
 546        }
 547        return loss_dict
 548
 549    #Compute NTK
 550    @partial(jit, static_argnums = (0,))
 551    def compute_diag_ntk(self, params, batch):
 552        #Initial Condition
 553        u1_ic_ntk = vmap(
 554            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
 555        )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 556
 557        u2_ic_ntk = vmap(
 558            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
 559        )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 560
 561        #Right Dirichlet
 562        rd_ntk = vmap(
 563            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 564        )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
 565
 566        #Dirichlet
 567        ld_ntk = vmap(
 568            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 569        )(self.res_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
 570        rd_ntk = vmap(
 571            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 572        )(self.res_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
 573
 574        #Left neumann
 575        ln_ntk = vmap(
 576            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 577        )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
 578        rn_ntk = vmap(
 579            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 580        )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
 581
 582        # Consider the effect of causal weights
 583        if self.config.weighting.use_causal:
 584            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
 585            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
 586                self.r_net1, params, batch[:, 0], batch[:, 1]
 587            )
 588
 589            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
 590                self.r_net2, params, batch[:, 0], batch[:, 1]
 591            )
 592
 593            res_ntk1 = res_ntk1.reshape(self.num_chunks, -1)
 594            res_ntk2 = res_ntk2.reshape(self.num_chunks, -1)
 595
 596            res_ntk1 = jnp.mean(res_ntk1, axis=1)
 597            res_ntk2 = jnp.mean(res_ntk2, axis=1)
 598
 599            _,_, casual_weights = self.res_causal(params, batch)
 600            res_ntk1 = res_ntk1 * casual_weights
 601            res_ntk2 = res_ntk2 * casual_weights
 602        else:
 603            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
 604                self.r_net1, params, batch[:, 0], batch[:, 1]
 605            )
 606            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
 607                self.r_net2, params, batch[:, 0], batch[:, 1]
 608            )
 609
 610        ntk_dict = {
 611            "ic": u1_ic_ntk + u2_ic_ntk,
 612            "res1": res_ntk1,
 613            "res2": res_ntk2,
 614            'rd': rd_ntk,
 615            'ld': ld_ntk,
 616            'ln': ln_ntk,
 617            'rn': rn_ntk
 618        }
 619        return ntk_dict
 620
 621class NN_csf_Evaluator(BaseEvaluator):
 622    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
 623        super().__init__(config, model)
 624
 625        self.x0_test = x0_test
 626        self.tb_test = tb_test
 627        self.xc_test = xc_test
 628        self.tc_test = tc_test
 629        self.u2_0_test = u2_0_test
 630        self.u1_0_test = u1_0_test
 631
 632    def log_errors(self, params):
 633        u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test)
 634        u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test)
 635
 636        u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2)
 637        u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2)
 638
 639        res_pred1,res_pred2 = self.model.r_pred_fn(
 640            params, self.tc_test[:,0], self.xc_test[:,0]
 641        )
 642        res_loss1 = jnp.mean(res_pred1 ** 2)
 643        res_loss2 = jnp.mean(res_pred2 ** 2)
 644
 645        self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)))
 646        self.log_dict["res1_test"] = res_loss1
 647        self.log_dict["res2_test"] = res_loss2
 648        self.log_dict["rd_test"] = jnp.mean(self.model.res_dirichlet(params, self.tb_test, self.model.xu))
 649        self.log_dict["ld_test"] = jnp.mean(self.model.res_dirichlet(params, self.tb_test, self.model.xl))
 650        self.log_dict["ln_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xl))
 651        self.log_dict["rn_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xu))
 652
 653    def __call__(self, state, batch):
 654        self.log_dict = super().__call__(state, batch)
 655
 656        if self.config.logging.log_errors:
 657            self.log_errors(state.params)
 658
 659        if self.config.weighting.use_causal:
 660            _, _, causal_weight = self.model.res_causal(state.params, batch)
 661            self.log_dict["cas_weight"] = causal_weight.min()
 662
 663        return self.log_dict
 664
 665class DD_csf(ForwardIVP):
 666    def __init__(self, config):
 667        super().__init__(config)
 668
 669        #Initial condition function
 670        self.uinitial = config.uinitial
 671
 672        #Boundary points
 673        self.xl = config.xl
 674        self.xu = config.xu
 675        self.tu = config.tu
 676
 677        #Right dirichlet point
 678        self.ld = config.ld
 679        self.rd = config.rd
 680
 681        # Predictions over array of x fot t fixed
 682        self.u1_0_pred_fn = vmap(
 683            vmap(self.u1_net, (None, None, 0)), (None, None, 0)
 684        )
 685        self.u2_0_pred_fn = vmap(
 686            vmap(self.u2_net, (None, None, 0)), (None, None, 0)
 687        )
 688
 689        #Prediction over array of t for x fixed
 690        self.u2_bound_pred_fn = vmap(
 691            vmap(self.u2_net, (None, 0, None)), (None, 0, None)
 692        )
 693
 694        self.u1_bound_pred_fn = vmap(
 695            vmap(self.u1_net, (None, 0, None)), (None, 0, None)
 696        )
 697
 698        #Vmap neural net
 699        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
 700
 701        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
 702
 703        #Vmap residual operator
 704        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))
 705
 706        #Derivatives on x for x fixed and t in a array
 707        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None))
 708        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
 709
 710    #Neural net forward function
 711    def neural_net(self, params, t, x):
 712        t = t / self.tu
 713        z = jnp.stack([t, x])
 714        _, outputs = self.state.apply_fn(params, z)
 715        u1 = outputs[0]
 716        u2 = outputs[1]
 717        return u1, u2
 718
 719    #1st coordinate neural net forward function
 720    def u1_net(self, params, t, x):
 721        u1, _ = self.neural_net(params, t, x)
 722        return u1
 723
 724    #2st coordinate neural net forward function
 725    def u2_net(self, params, t, x):
 726        _, u2 = self.neural_net(params, t, x)
 727        return u2
 728
 729    #Residual operator
 730    def r_net(self, params, t, x):
 731        #Derivatives in x and t
 732        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
 733        u1_t = grad(self.u1_net, argnums = 1)(params, t, x)
 734        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
 735        u2_t = grad(self.u2_net, argnums = 1)(params, t, x)
 736
 737        #Two derivatives in x
 738        u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x)
 739        u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x)
 740
 741        #Each coordinate of residual operator
 742        return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
 743
 744    #1st coordinate residual operator
 745    def r_net1(self, params, t, x):
 746        r1,_ = self.r_net(params, t, x)
 747        return r1
 748
 749    #2nd coordinate residual operator
 750    def r_net2(self, params, t, x):
 751        _,r2 = self.r_net(params, t, x)
 752        return r2
 753
 754    #Right Dirichlet Condition residual
 755    def res_right_dirichlet(self, params, t, x):
 756        #Apply neural net to point in the boundary
 757        u1 = self.u1_bound_pred_fn(params, t, x)
 758        u2 = self.u2_bound_pred_fn(params, t, x)
 759        #Return residual
 760        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
 761
 762    #Right Dirichlet Condition residual non-vectorised
 763    def res_right_dirichlet_nv(self, params, t, x):
 764        #Apply neural net to point in the boundary
 765        u1 = self.u1_net(params, t, x)
 766        u2 = self.u2_net(params, t, x)
 767        #Return residual
 768        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
 769
 770    #Left Dirichlet Condition residual
 771    def res_left_dirichlet(self, params, t, x):
 772        #Apply neural net to point in the boundary
 773        u1 = self.u1_bound_pred_fn(params, t, x)
 774        u2 = self.u2_bound_pred_fn(params, t, x)
 775        #Return residual
 776        return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2
 777
 778    #Left Dirichlet Condition residual non-vectorised
 779    def res_left_dirichlet_nv(self, params, t, x):
 780        #Apply neural net to point in the boundary
 781        u1 = self.u1_net(params, t, x)
 782        u2 = self.u2_net(params, t, x)
 783        #Return residual
 784        return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2
 785
 786    #Compute residuals with causal weights
 787    @partial(jit, static_argnums=(0,))
 788    def res_causal(self, params, batch):
 789        # Sort temporal coordinates
 790        t_sorted = batch[:, 0].sort()
 791
 792        #Compute residuals
 793        res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])
 794
 795        #Reshape
 796        res_pred1 = res_pred1.reshape(self.num_chunks, -1)
 797        res_pred2 = res_pred2.reshape(self.num_chunks, -1)
 798
 799        #Compute mean residuals
 800        res_l1 = jnp.mean(res_pred1 ** 2, axis=1)
 801        res_l2 = jnp.mean(res_pred2 ** 2, axis=1)
 802
 803        #Compute weights
 804        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
 805        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
 806
 807        # Take minimum of the causal weights
 808        gamma = jnp.vstack([res_gamma1,res_gamma2])
 809        gamma = gamma.min(0)
 810
 811        return res_l1, res_l2, gamma
 812
 813    #Compute losses
 814    @partial(jit, static_argnums=(0,))
 815    def losses(self, params, batch):
 816        # Initial conditions loss
 817        u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 818        u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 819        u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1)))
 820
 821        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
 822        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)
 823
 824        # Residual loss
 825        if self.config.weighting.use_causal == True:
 826            res_l1, res_l2, gamma = self.res_causal(params, batch)
 827            res_loss1 = jnp.mean(res_l1 * gamma)
 828            res_loss2 = jnp.mean(res_l2 * gamma)
 829        else:
 830            res_pred1,res_pred2 = self.r_pred_fn(
 831                params, batch[:, 0], batch[:, 1]
 832            )
 833            # Compute loss
 834            res_loss1 = jnp.mean(res_pred1 ** 2)
 835            res_loss2 = jnp.mean(res_pred2 ** 2)
 836
 837        loss_dict = {
 838            "ic": u1_ic_loss + u2_ic_loss,
 839            "res1": res_loss1,
 840            "res2": res_loss2,
 841            'rd': jnp.mean(self.res_right_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)),
 842            'ld': jnp.mean(self.res_left_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl))
 843        }
 844        return loss_dict
 845
 846    #Compute NTK
 847    @partial(jit, static_argnums = (0,))
 848    def compute_diag_ntk(self, params, batch):
 849        #Initial Condition
 850        u1_ic_ntk = vmap(
 851            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
 852        )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 853
 854        u2_ic_ntk = vmap(
 855            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
 856        )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
 857
 858        #Right Dirichlet
 859        rd_ntk = vmap(
 860            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 861        )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
 862
 863        #Left Dirichlet
 864        ld_ntk = vmap(
 865            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
 866        )(self.res_left_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
 867
 868        # Consider the effect of causal weights
 869        if self.config.weighting.use_causal:
 870            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
 871            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
 872                self.r_net1, params, batch[:, 0], batch[:, 1]
 873            )
 874
 875            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
 876                self.r_net2, params, batch[:, 0], batch[:, 1]
 877            )
 878
 879            res_ntk1 = res_ntk1.reshape(self.num_chunks, -1)
 880            res_ntk2 = res_ntk2.reshape(self.num_chunks, -1)
 881
 882            res_ntk1 = jnp.mean(res_ntk1, axis=1)
 883            res_ntk2 = jnp.mean(res_ntk2, axis=1)
 884
 885            _,_, casual_weights = self.res_causal(params, batch)
 886            res_ntk1 = res_ntk1 * casual_weights
 887            res_ntk2 = res_ntk2 * casual_weights
 888        else:
 889            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
 890                self.r_net1, params, batch[:, 0], batch[:, 1]
 891            )
 892            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
 893                self.r_net2, params, batch[:, 0], batch[:, 1]
 894            )
 895
 896        ntk_dict = {
 897            "ic": u1_ic_ntk + u2_ic_ntk,
 898            "res1": res_ntk1,
 899            "res2": res_ntk2,
 900            'rd': rd_ntk,
 901            'ld': ld_ntk
 902        }
 903        return ntk_dict
 904
 905class DD_csf_Evaluator(BaseEvaluator):
 906    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
 907        super().__init__(config, model)
 908
 909        self.x0_test = x0_test
 910        self.tb_test = tb_test
 911        self.xc_test = xc_test
 912        self.tc_test = tc_test
 913        self.u2_0_test = u2_0_test
 914        self.u1_0_test = u1_0_test
 915
 916    def log_errors(self, params):
 917        u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test)
 918        u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test)
 919
 920        u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2)
 921        u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2)
 922
 923        res_pred1,res_pred2 = self.model.r_pred_fn(
 924            params, self.tc_test[:,0], self.xc_test[:,0]
 925        )
 926        res_loss1 = jnp.mean(res_pred1 ** 2)
 927        res_loss2 = jnp.mean(res_pred2 ** 2)
 928
 929        self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)))
 930        self.log_dict["res1_test"] = res_loss1
 931        self.log_dict["res2_test"] = res_loss2
 932        self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu))
 933        self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl))
 934
 935    def __call__(self, state, batch):
 936        self.log_dict = super().__call__(state, batch)
 937
 938        if self.config.logging.log_errors:
 939            self.log_errors(state.params)
 940
 941        if self.config.weighting.use_causal:
 942            _, _, causal_weight = self.model.res_causal(state.params, batch)
 943            self.log_dict["cas_weight"] = causal_weight.min()
 944
 945        return self.log_dict
 946
 947class closed_csf(ForwardIVP):
 948    def __init__(self, config):
 949        super().__init__(config)
 950
 951        #Initial condition function
 952        self.uinitial = config.uinitial
 953
 954        #Boundary points
 955        self.xl = config.xl
 956        self.xu = config.xu
 957        self.tu = config.tu
 958
 959        # Predictions over array of x fot t fixed
 960        self.u1_0_pred_fn = vmap(
 961            vmap(self.u1_net, (None, None, 0)), (None, None, 0)
 962        )
 963        self.u2_0_pred_fn = vmap(
 964            vmap(self.u2_net, (None, None, 0)), (None, None, 0)
 965        )
 966
 967        #Prediction over array of t for x fixed
 968        self.u2_bound_pred_fn = vmap(
 969            vmap(self.u2_net, (None, 0, None)), (None, 0, None)
 970        )
 971
 972        self.u1_bound_pred_fn = vmap(
 973            vmap(self.u1_net, (None, 0, None)), (None, 0, None)
 974        )
 975
 976        #Vmap neural net
 977        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
 978
 979        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
 980
 981        #Vmap residual operator
 982        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))
 983
 984        #Derivatives on x for x fixed and t in a array
 985        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None))
 986        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
 987
 988    #Neural net forward function
 989    def neural_net(self, params, t, x):
 990        t = t / self.tu
 991        z = jnp.stack([t, x])
 992        _, outputs = self.state.apply_fn(params, z)
 993        u1 = outputs[0]
 994        u2 = outputs[1]
 995        return u1, u2
 996
 997    #1st coordinate neural net forward function
 998    def u1_net(self, params, t, x):
 999        u1, _ = self.neural_net(params, t, x)
1000        return u1
1001
1002    #2st coordinate neural net forward function
1003    def u2_net(self, params, t, x):
1004        _, u2 = self.neural_net(params, t, x)
1005        return u2
1006
1007    #Residual operator
1008    def r_net(self, params, t, x):
1009        #Derivatives in x and t
1010        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
1011        u1_t = grad(self.u1_net, argnums = 1)(params, t, x)
1012        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
1013        u2_t = grad(self.u2_net, argnums = 1)(params, t, x)
1014
1015        #Two derivatives in x
1016        u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x)
1017        u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x)
1018
1019        #Each coordinate of residual operator
1020        return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
1021
1022    #1st coordinate residual operator
1023    def r_net1(self, params, t, x):
1024        r1,_ = self.r_net(params, t, x)
1025        return r1
1026
1027    #2nd coordinate residual operator
1028    def r_net2(self, params, t, x):
1029        _,r2 = self.r_net(params, t, x)
1030        return r2
1031
1032    #Compute residuals with causal weights
1033    @partial(jit, static_argnums=(0,))
1034    def res_causal(self, params, batch):
1035        # Sort temporal coordinates
1036        t_sorted = batch[:, 0].sort()
1037
1038        #Compute residuals
1039        res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])
1040
1041        #Reshape
1042        res_pred1 = res_pred1.reshape(self.num_chunks, -1)
1043        res_pred2 = res_pred2.reshape(self.num_chunks, -1)
1044
1045        #Compute mean residuals
1046        res_l1 = jnp.mean(res_pred1 ** 2, axis=1)
1047        res_l2 = jnp.mean(res_pred2 ** 2, axis=1)
1048
1049        #Compute weights
1050        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
1051        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
1052
1053        # Take minimum of the causal weights
1054        gamma = jnp.vstack([res_gamma1,res_gamma2])
1055        gamma = gamma.min(0)
1056
1057        return res_l1, res_l2, gamma
1058
1059    #Compute losses
1060    @partial(jit, static_argnums=(0,))
1061    def losses(self, params, batch):
1062        # Initial conditions loss
1063        u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
1064        u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
1065        u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1)))
1066
1067        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
1068        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)
1069
1070        periodic1 = jnp.mean((self.u1_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xl) - self.u1_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xu)) ** 2)
1071        periodic2 = jnp.mean((self.u2_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xl) - self.u2_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xu)) ** 2)
1072
1073        # Residual loss
1074        if self.config.weighting.use_causal == True:
1075            res_l1, res_l2, gamma = self.res_causal(params, batch)
1076            res_loss1 = jnp.mean(res_l1 * gamma)
1077            res_loss2 = jnp.mean(res_l2 * gamma)
1078        else:
1079            res_pred1,res_pred2 = self.r_pred_fn(
1080                params, batch[:, 0], batch[:, 1]
1081            )
1082            # Compute loss
1083            res_loss1 = jnp.mean(res_pred1 ** 2)
1084            res_loss2 = jnp.mean(res_pred2 ** 2)
1085
1086        loss_dict = {
1087            "ic": u1_ic_loss + u2_ic_loss,
1088            "res1": res_loss1,
1089            "res2": res_loss2,
1090            'periodic1': periodic1,
1091            'periodic2': periodic2
1092        }
1093        return loss_dict
1094
1095    #Compute NTK
1096    @partial(jit, static_argnums = (0,))
1097    def compute_diag_ntk(self, params, batch):
1098        #Initial Condition
1099        u1_ic_ntk = vmap(
1100            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
1101        )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
1102
1103        u2_ic_ntk = vmap(
1104            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
1105        )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
1106
1107        # Consider the effect of causal weights
1108        if self.config.weighting.use_causal:
1109            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
1110            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
1111                self.r_net1, params, batch[:, 0], batch[:, 1]
1112            )
1113
1114            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
1115                self.r_net2, params, batch[:, 0], batch[:, 1]
1116            )
1117
1118            res_ntk1 = res_ntk1.reshape(self.num_chunks, -1)
1119            res_ntk2 = res_ntk2.reshape(self.num_chunks, -1)
1120
1121            res_ntk1 = jnp.mean(res_ntk1, axis=1)
1122            res_ntk2 = jnp.mean(res_ntk2, axis=1)
1123
1124            _,_, casual_weights = self.res_causal(params, batch)
1125            res_ntk1 = res_ntk1 * casual_weights
1126            res_ntk2 = res_ntk2 * casual_weights
1127        else:
1128            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
1129                self.r_net1, params, batch[:, 0], batch[:, 1]
1130            )
1131            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
1132                self.r_net2, params, batch[:, 0], batch[:, 1]
1133            )
1134
1135        ntk_dict = {
1136            "ic": u1_ic_ntk + u2_ic_ntk,
1137            "res1": res_ntk1,
1138            "res2": res_ntk2
1139        }
1140        return ntk_dict
1141
class closed_csf_Evaluator(BaseEvaluator):
    """Evaluator for ``closed_csf``-style models with periodic ends.

    On top of what ``BaseEvaluator`` logs, it records on fixed test points:
    * the relative initial-condition error;
    * the mean squared PDE residuals;
    * the two periodicity losses;
    * the smallest causal weight, when causal training is enabled.
    """

    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
        super().__init__(config, model)

        # x points where the initial condition (t = 0) is checked.
        self.x0_test = x0_test
        # Times at which the periodicity gap is evaluated.
        self.tb_test = tb_test
        # Collocation points for the PDE residual (column 0 of each is used).
        self.xc_test = xc_test
        self.tc_test = tc_test
        # Reference initial-condition values at x0_test.
        self.u2_0_test = u2_0_test
        self.u1_0_test = u1_0_test

    def log_errors(self, params):
        """Write test-set metrics for ``params`` into ``self.log_dict``."""
        model = self.model

        # Relative L2 error of the initial condition, pooled over both components.
        u1_pred = model.u1_0_pred_fn(params, 0.0, self.x0_test)
        u2_pred = model.u2_0_pred_fn(params, 0.0, self.x0_test)
        ic_err = jnp.mean((u1_pred - self.u1_0_test) ** 2) + jnp.mean((u2_pred - self.u2_0_test) ** 2)
        ic_ref = jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)
        self.log_dict["ic_rel_test"] = jnp.sqrt(ic_err / ic_ref)

        # r_net returns *squared* residuals, so their mean already is the mean
        # squared residual (squaring again reported a 4th power).
        res_pred1, res_pred2 = model.r_pred_fn(params, self.tc_test[:, 0], self.xc_test[:, 0])
        self.log_dict["res1_test"] = jnp.mean(res_pred1)
        self.log_dict["res2_test"] = jnp.mean(res_pred2)

        # Periodicity gap between the model's own end points, as in its losses().
        for key, bound_fn in (
            ("periodic1_test", model.u1_bound_pred_fn),
            ("periodic2_test", model.u2_bound_pred_fn),
        ):
            gap = bound_fn(params, self.tb_test, model.xl) - bound_fn(params, self.tb_test, model.xu)
            self.log_dict[key] = jnp.mean(gap ** 2)

    def __call__(self, state, batch):
        """Run the base logging, then add the test metrics and causal weight."""
        self.log_dict = super().__call__(state, batch)

        if self.config.logging.log_errors:
            self.log_errors(state.params)

        if self.config.weighting.use_causal:
            # Smallest causal weight over the time chunks of this batch.
            _, _, causal_weight = self.model.res_causal(state.params, batch)
            self.log_dict["cas_weight"] = causal_weight.min()

        return self.log_dict
class DN_csf(ForwardIVP):
    """Two-component PINN with a fixed point at ``xu`` and a circle constraint at ``xl``.

    The network maps ``(t, x)`` to ``(u1, u2)``. Loss terms:
    * ``ic``: match ``config.uinitial`` at ``t = 0``;
    * ``res1`` / ``res2``: the residual of ``u_t = u_xx / (|u_x|^2 + 1e-5)``;
    * ``rd``: ``u(t, xu)`` equals the point ``config.rd``;
    * ``ld``: ``|u(t, xl)|`` equals ``config.radius``;
    * ``ln``: at ``xl``, the squared inner product of ``u / |u|`` with the
      curve normal ``(u2_x, -u1_x) / |u_x|``.
    """

    def __init__(self, config):
        super().__init__(config)

        # Initial condition: called on x and expected to return (u1_0, u2_0).
        self.uinitial = config.uinitial

        # Domain limits: x in [xl, xu]; tu rescales time inside neural_net.
        self.xl = config.xl
        self.xu = config.xu
        self.tu = config.tu

        # Radius of the circle |u| = radius imposed at the left end.
        self.radius = config.radius

        # Target point (rd[0], rd[1]) imposed at the right end.
        self.rd = config.rd

        # Double-vmapped predictions over an array of x at a fixed t.
        self.u1_0_pred_fn = vmap(vmap(self.u1_net, (None, None, 0)), (None, None, 0))
        self.u2_0_pred_fn = vmap(vmap(self.u2_net, (None, None, 0)), (None, None, 0))

        # Double-vmapped predictions over an array of t at a fixed x.
        self.u2_bound_pred_fn = vmap(vmap(self.u2_net, (None, 0, None)), (None, 0, None))
        self.u1_bound_pred_fn = vmap(vmap(self.u1_net, (None, 0, None)), (None, 0, None))

        # Point-wise predictions over paired arrays of (t, x).
        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))

        # Point-wise (squared) PDE residuals over paired arrays of (t, x).
        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))

        # d/dx of each component over an array of t at a fixed x.
        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums=2), (None, 0, None)), (None, 0, None))
        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums=2), (None, 0, None)), (None, 0, None))

    def neural_net(self, params, t, x):
        """Evaluate the network at (t, x), with t divided by ``self.tu``."""
        t = t / self.tu
        z = jnp.stack([t, x])
        # apply_fn returns a pair; its second element holds the outputs.
        _, outputs = self.state.apply_fn(params, z)
        u1 = outputs[0]
        u2 = outputs[1]
        return u1, u2

    def u1_net(self, params, t, x):
        """First output coordinate of the network."""
        u1, _ = self.neural_net(params, t, x)
        return u1

    def u2_net(self, params, t, x):
        """Second output coordinate of the network."""
        _, u2 = self.neural_net(params, t, x)
        return u2

    def r_net(self, params, t, x):
        """Squared residuals of ``u_t - u_xx / (|u_x|^2 + 1e-5)`` for both components.

        The returned values are already squared; callers average them directly.
        """
        u1_x = grad(self.u1_net, argnums=2)(params, t, x)
        u1_t = grad(self.u1_net, argnums=1)(params, t, x)
        u2_x = grad(self.u2_net, argnums=2)(params, t, x)
        u2_t = grad(self.u2_net, argnums=1)(params, t, x)

        u1_xx = hessian(self.u1_net, argnums=2)(params, t, x)
        u2_xx = hessian(self.u2_net, argnums=2)(params, t, x)

        # Regularised squared speed |u_x|^2 shared by both components.
        speed_sq = u1_x ** 2 + u2_x ** 2 + 1e-5
        return (u1_t - u1_xx / speed_sq) ** 2, (u2_t - u2_xx / speed_sq) ** 2

    def r_net1(self, params, t, x):
        """First (squared) residual component; scalar-valued for ``ntk_fn``."""
        r1, _ = self.r_net(params, t, x)
        return r1

    def r_net2(self, params, t, x):
        """Second (squared) residual component; scalar-valued for ``ntk_fn``."""
        _, r2 = self.r_net(params, t, x)
        return r2

    def res_right_dirichlet(self, params, t, x):
        """Squared distance to ``self.rd`` at ``x``, over an array of times ``t``."""
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)
        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2

    def res_right_dirichlet_nv(self, params, t, x):
        """Single-point version of ``res_right_dirichlet``."""
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)
        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2

    def res_left_dirichlet(self, params, t, x):
        """Squared gap between ``|u|`` and ``self.radius`` at ``x``, over times ``t``."""
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)
        return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2

    def res_left_dirichlet_nv(self, params, t, x):
        """Single-point version of ``res_left_dirichlet``."""
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)
        return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2

    def res_neumann_nv(self, params, t, x):
        """Single-point squared inner product of ``u / |u|`` and the curve normal.

        Both vectors are normalised with a ``1e-5`` regulariser in the denominator.
        """
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)

        u1_x = grad(self.u1_net, argnums=2)(params, t, x)
        u2_x = grad(self.u2_net, argnums=2)(params, t, x)

        # Unit vector along u(t, x): the normal of the circle |u| = const at u.
        n_s = jnp.append(u1, u2) / (jnp.sqrt(jnp.sum(jnp.append(u1, u2) ** 2)) + 1e-5)

        # Unit normal of the curve at u(t, x).
        n_u = jnp.append(u2_x, (-1) * u1_x) / (jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)

        return jnp.sum(n_s * n_u) ** 2

    def res_neumann(self, params, t, x):
        """Vectorised ``res_neumann_nv`` over an array of times ``t`` (shape ``(N, 1)``)."""
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)

        u1_x = self.u1_bound_x(params, t, x)
        u2_x = self.u2_bound_x(params, t, x)

        # Row-wise unit vectors along u(t, x).
        n_s = jnp.append(u1, u2, 1) / (
            jnp.sqrt(jnp.sum(jnp.append(u1, u2, 1) ** 2, 1)).reshape(u1.shape[0], 1) + 1e-5
        )

        # Row-wise unit normals of the curve.
        n_u = jnp.append(u2_x, (-1) * u1_x, 1) / (jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)

        return jnp.sum(n_s * n_u, 1).reshape(u1.shape[0], 1) ** 2

    @partial(jit, static_argnums=(0,))
    def res_causal(self, params, batch):
        """Per-chunk PDE residual losses and their causal weights.

        Times are sorted and paired with the unsorted ``x`` column, then split
        into ``self.num_chunks`` chunks. Each chunk gets the weight
        ``exp(-tol * (M @ loss))``. ``r_net`` already returns squared
        residuals, so the per-chunk loss is their plain mean.

        Returns:
            ``(res_l1, res_l2, gamma)``: the per-chunk losses of each
            component, and the element-wise minimum of both components'
            weights (gradients stopped).
        """
        t_sorted = batch[:, 0].sort()
        res_pred1, res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])

        # Already squared by r_net -> average, do not square again.
        res_l1 = jnp.mean(res_pred1.reshape(self.num_chunks, -1), axis=1)
        res_l2 = jnp.mean(res_pred2.reshape(self.num_chunks, -1), axis=1)

        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
        gamma = jnp.minimum(res_gamma1, res_gamma2)

        return res_l1, res_l2, gamma

    @partial(jit, static_argnums=(0,))
    def losses(self, params, batch):
        """Return the unweighted loss terms for one batch.

        Args:
            params: network parameters.
            batch: array whose column 0 holds ``t`` and column 1 holds ``x``.

        Returns:
            dict with keys ``"ic"``, ``"res1"``, ``"res2"``, ``"rd"``, ``"ld"``, ``"ln"``.
        """
        n_pts = batch.shape[0]
        x_col = batch[:, 1].reshape((n_pts, 1))
        t_col = batch[:, 0].reshape((n_pts, 1))

        # Initial condition at t = 0.
        u1_pred = self.u1_0_pred_fn(params, 0.0, x_col)
        u2_pred = self.u2_0_pred_fn(params, 0.0, x_col)
        u1_0, u2_0 = self.uinitial(x_col)
        ic_loss = jnp.mean((u1_pred - u1_0) ** 2) + jnp.mean((u2_pred - u2_0) ** 2)

        # PDE residual: r_net output is already squared, so average it directly.
        if self.config.weighting.use_causal:
            res_l1, res_l2, gamma = self.res_causal(params, batch)
            res_loss1 = jnp.mean(res_l1 * gamma)
            res_loss2 = jnp.mean(res_l2 * gamma)
        else:
            res_pred1, res_pred2 = self.r_pred_fn(params, batch[:, 0], batch[:, 1])
            res_loss1 = jnp.mean(res_pred1)
            res_loss2 = jnp.mean(res_pred2)

        return {
            "ic": ic_loss,
            "res1": res_loss1,
            "res2": res_loss2,
            "rd": jnp.mean(self.res_right_dirichlet(params, t_col, self.xu)),
            "ld": jnp.mean(self.res_left_dirichlet(params, t_col, self.xl)),
            "ln": jnp.mean(self.res_neumann(params, t_col, self.xl)),
        }

    @partial(jit, static_argnums=(0,))
    def compute_diag_ntk(self, params, batch):
        """Diagonal NTK estimates for each loss term, keyed like ``losses``."""
        n_pts = batch.shape[0]
        x_col = batch[:, 1].reshape((n_pts, 1))
        t_col = batch[:, 0].reshape((n_pts, 1))

        # ntk_fn mapped over x (t fixed), over t (x fixed), and point-wise over (t, x).
        ntk_over_x = vmap(vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0))
        ntk_over_t = vmap(vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None))
        ntk_pointwise = vmap(ntk_fn, (None, None, 0, 0))

        u1_ic_ntk = ntk_over_x(self.u1_net, params, 0.0, x_col)
        u2_ic_ntk = ntk_over_x(self.u2_net, params, 0.0, x_col)
        rd_ntk = ntk_over_t(self.res_right_dirichlet_nv, params, t_col, self.xu)
        ld_ntk = ntk_over_t(self.res_left_dirichlet_nv, params, t_col, self.xl)
        ln_ntk = ntk_over_t(self.res_neumann_nv, params, t_col, self.xl)

        if self.config.weighting.use_causal:
            # Same time-sorting as res_causal: t sorted, x left in place.
            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
            res_ntk1 = ntk_pointwise(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = ntk_pointwise(self.r_net2, params, batch[:, 0], batch[:, 1])

            res_ntk1 = jnp.mean(res_ntk1.reshape(self.num_chunks, -1), axis=1)
            res_ntk2 = jnp.mean(res_ntk2.reshape(self.num_chunks, -1), axis=1)

            _, _, causal_weights = self.res_causal(params, batch)
            res_ntk1 = res_ntk1 * causal_weights
            res_ntk2 = res_ntk2 * causal_weights
        else:
            res_ntk1 = ntk_pointwise(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = ntk_pointwise(self.r_net2, params, batch[:, 0], batch[:, 1])

        return {
            "ic": u1_ic_ntk + u2_ic_ntk,
            "res1": res_ntk1,
            "res2": res_ntk2,
            "rd": rd_ntk,
            "ld": ld_ntk,
            "ln": ln_ntk,
        }
DN_csf(config)
17    def __init__(self, config):
18        super().__init__(config)
19
20        #Initial condition function
21        self.uinitial = config.uinitial
22
23        #Boundary points
24        self.xl = config.xl
25        self.xu = config.xu
26        self.tu = config.tu
27
28        #Radius left dirichlet condition
29        self.radius = config.radius
30
31        #Right dirichlet point
32        self.rd = config.rd
33
34        # Predictions over array of x fot t fixed
35        self.u1_0_pred_fn = vmap(
36            vmap(self.u1_net, (None, None, 0)), (None, None, 0)
37        )
38        self.u2_0_pred_fn = vmap(
39            vmap(self.u2_net, (None, None, 0)), (None, None, 0)
40        )
41
42        #Prediction over array of t for x fixed
43        self.u2_bound_pred_fn = vmap(
44            vmap(self.u2_net, (None, 0, None)), (None, 0, None)
45        )
46
47        self.u1_bound_pred_fn = vmap(
48            vmap(self.u1_net, (None, 0, None)), (None, 0, None)
49        )
50
51        #Vmap neural net
52        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
53
54        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
55
56        #Vmap residual operator
57        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))
58
59        #Derivatives on x for x fixed and t in a array
60        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None))
61        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
uinitial
xl
xu
tu
radius
rd
u1_0_pred_fn
u2_0_pred_fn
u2_bound_pred_fn
u1_bound_pred_fn
u1_pred_fn
u2_pred_fn
r_pred_fn
u1_bound_x
u2_bound_x
def neural_net(self, params, t, x):
64    def neural_net(self, params, t, x):
65        t = t / self.tu
66        z = jnp.stack([t, x])
67        _, outputs = self.state.apply_fn(params, z)
68        u1 = outputs[0]
69        u2 = outputs[1]
70        return u1, u2
def u1_net(self, params, t, x):
73    def u1_net(self, params, t, x):
74        u1, _ = self.neural_net(params, t, x)
75        return u1
def u2_net(self, params, t, x):
78    def u2_net(self, params, t, x):
79        _, u2 = self.neural_net(params, t, x)
80        return u2
def r_net(self, params, t, x):
83    def r_net(self, params, t, x):
84        #Derivatives in x and t
85        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
86        u1_t = grad(self.u1_net, argnums = 1)(params, t, x)
87        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
88        u2_t = grad(self.u2_net, argnums = 1)(params, t, x)
89
90        #Two derivatives in x
91        u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x)
92        u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x)
93
94        #Each coordinate of residual operator
95        return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
def r_net1(self, params, t, x):
 98    def r_net1(self, params, t, x):
 99        r1,_ = self.r_net(params, t, x)
100        return r1
def r_net2(self, params, t, x):
103    def r_net2(self, params, t, x):
104        _,r2 = self.r_net(params, t, x)
105        return r2
def res_right_dirichlet(self, params, t, x):
108    def res_right_dirichlet(self, params, t, x):
109        #Apply neural net to point in the boundary
110        u1 = self.u1_bound_pred_fn(params, t, x)
111        u2 = self.u2_bound_pred_fn(params, t, x)
112        #Return residual
113        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
def res_right_dirichlet_nv(self, params, t, x):
116    def res_right_dirichlet_nv(self, params, t, x):
117        #Apply neural net to point in the boundary
118        u1 = self.u1_net(params, t, x)
119        u2 = self.u2_net(params, t, x)
120        #Return residual
121        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
def res_left_dirichlet(self, params, t, x):
    """Squared deviation of ``|u|`` from ``self.radius`` at boundary ``x``.

    The residual vanishes when the network value lies on the circle of
    radius ``self.radius`` centred at the origin. ``t`` is an array of
    times and ``x`` a single boundary location.
    """
    first = self.u1_bound_pred_fn(params, t, x)
    second = self.u2_bound_pred_fn(params, t, x)
    distance_from_origin = jnp.sqrt(first ** 2 + second ** 2)
    return (distance_from_origin - self.radius) ** 2
def res_left_dirichlet_nv(self, params, t, x):
    """Single-point version of ``res_left_dirichlet``: ``(|u(t, x)| - radius)^2``."""
    first = self.u1_net(params, t, x)
    second = self.u2_net(params, t, x)
    distance_from_origin = jnp.sqrt(first ** 2 + second ** 2)
    return (distance_from_origin - self.radius) ** 2
def res_neumann_nv(self, params, t, x):
    """Single-point boundary residual: squared cosine between two unit normals.

    ``radial`` is ``u / |u|`` (the normal of an origin-centred circle through
    ``u``) and ``normal`` is ``(u2_x, -u1_x) / |u_x|`` (a normal of the curve).
    The residual vanishes when these are orthogonal. Both denominators carry
    a ``1e-5`` term to avoid division by zero.
    """
    point = jnp.append(self.u1_net(params, t, x), self.u2_net(params, t, x))
    dx_1 = grad(self.u1_net, argnums=2)(params, t, x)
    dx_2 = grad(self.u2_net, argnums=2)(params, t, x)

    radial = point / (jnp.sqrt(jnp.sum(point ** 2)) + 1e-5)
    normal = jnp.append(dx_2, -dx_1) / (jnp.sqrt(dx_1 ** 2 + dx_2 ** 2) + 1e-5)
    return jnp.sum(radial * normal) ** 2
def res_neumann(self, params, t, x):
    """Vectorised ``res_neumann_nv`` over an array of times at boundary ``x``.

    The bound helpers return one value per row of ``t``. The stacked arrays
    are concatenated along axis 1, and the result is reshaped to
    ``(rows, 1)``.
    """
    first = self.u1_bound_pred_fn(params, t, x)
    second = self.u2_bound_pred_fn(params, t, x)
    dx_1 = self.u1_bound_x(params, t, x)
    dx_2 = self.u2_bound_x(params, t, x)
    rows = first.shape[0]

    # Unit radial direction u / |u| for every row
    points = jnp.concatenate([first, second], axis=1)
    radial_len = jnp.sqrt(jnp.sum(points ** 2, 1)).reshape(rows, 1)
    radial = points / (radial_len + 1e-5)

    # Unit curve normal (u2_x, -u1_x) / |u_x| for every row
    normal = jnp.concatenate([dx_2, -dx_1], axis=1) / (jnp.sqrt(dx_1 ** 2 + dx_2 ** 2) + 1e-5)

    return jnp.sum(radial * normal, 1).reshape(rows, 1) ** 2
@partial(jit, static_argnums=(0,))
def res_causal(self, params, batch):
    """Chunked mean-squared residuals plus causal weights.

    Column 0 of ``batch`` (time) is sorted independently of column 1. The
    residuals are split into ``self.num_chunks`` consecutive chunks.
    Per-chunk weights are ``exp(-tol * (M @ losses))``, taken with the
    element-wise minimum over both residual components and detached from
    the gradient. ``M`` and ``tol`` come from the base class (presumably
    ``M`` accumulates earlier chunks — confirm in jaxpi).

    Returns:
        ``(res_l1, res_l2, gamma)``, each of length ``num_chunks``.
    """
    times = jnp.sort(batch[:, 0])
    raw_1, raw_2 = self.r_pred_fn(params, times, batch[:, 1])

    def chunk_mean_sq(residual):
        return jnp.mean(residual.reshape(self.num_chunks, -1) ** 2, axis=1)

    def causal_weight(chunk_loss):
        return lax.stop_gradient(jnp.exp(-self.tol * (self.M @ chunk_loss)))

    res_l1 = chunk_mean_sq(raw_1)
    res_l2 = chunk_mean_sq(raw_2)
    gamma = jnp.minimum(causal_weight(res_l1), causal_weight(res_l2))
    return res_l1, res_l2, gamma
@partial(jit, static_argnums=(0,))
def losses(self, params, batch):
    """Compute every DN_csf loss term for one batch.

    Args:
        params: network parameters.
        batch: 2-D array whose column 0 is time and column 1 is x.

    Returns:
        Dict with keys ``ic`` (initial condition at t = 0), ``res1``/``res2``
        (PDE residual terms, causally weighted when
        ``config.weighting.use_causal`` is set), ``rd`` (right Dirichlet at
        ``xu``), ``ld`` (left Dirichlet at ``xl``) and ``ln`` (left Neumann
        at ``xl``).
    """
    n = batch.shape[0]
    t_col = batch[:, 0].reshape((n, 1))
    x_col = batch[:, 1].reshape((n, 1))

    # Initial condition loss at t = 0
    u1_pred = self.u1_0_pred_fn(params, 0.0, x_col)
    u2_pred = self.u2_0_pred_fn(params, 0.0, x_col)
    u1_0, u2_0 = self.uinitial(x_col)
    u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
    u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)

    # Residual loss; plain truthiness, consistent with compute_diag_ntk
    if self.config.weighting.use_causal:
        res_l1, res_l2, gamma = self.res_causal(params, batch)
        res_loss1 = jnp.mean(res_l1 * gamma)
        res_loss2 = jnp.mean(res_l2 * gamma)
    else:
        res_pred1, res_pred2 = self.r_pred_fn(params, batch[:, 0], batch[:, 1])
        res_loss1 = jnp.mean(res_pred1 ** 2)
        res_loss2 = jnp.mean(res_pred2 ** 2)

    return {
        "ic": u1_ic_loss + u2_ic_loss,
        "res1": res_loss1,
        "res2": res_loss2,
        'rd': jnp.mean(self.res_right_dirichlet(params, t_col, self.xu)),
        'ld': jnp.mean(self.res_left_dirichlet(params, t_col, self.xl)),
        'ln': jnp.mean(self.res_neumann(params, t_col, self.xl)),
    }
@partial(jit, static_argnums=(0,))
def compute_diag_ntk(self, params, batch):
239    @partial(jit, static_argnums = (0,))
240    def compute_diag_ntk(self, params, batch):
241        #Initial Condition
242        u1_ic_ntk = vmap(
243            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
244        )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
245
246        u2_ic_ntk = vmap(
247            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
248        )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
249
250        #Right Dirichlet
251        rd_ntk = vmap(
252            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
253        )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
254
255        #Left Dirichlet
256        ld_ntk = vmap(
257            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
258        )(self.res_left_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
259
260        #Left neumann
261        ln_ntk = vmap(
262            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
263        )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
264
265        # Consider the effect of causal weights
266        if self.config.weighting.use_causal:
267            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
268            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
269                self.r_net1, params, batch[:, 0], batch[:, 1]
270            )
271
272            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
273                self.r_net2, params, batch[:, 0], batch[:, 1]
274            )
275
276            res_ntk1 = res_ntk1.reshape(self.num_chunks, -1)
277            res_ntk2 = res_ntk2.reshape(self.num_chunks, -1)
278
279            res_ntk1 = jnp.mean(res_ntk1, axis=1)
280            res_ntk2 = jnp.mean(res_ntk2, axis=1)
281
282            _,_, casual_weights = self.res_causal(params, batch)
283            res_ntk1 = res_ntk1 * casual_weights
284            res_ntk2 = res_ntk2 * casual_weights
285        else:
286            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
287                self.r_net1, params, batch[:, 0], batch[:, 1]
288            )
289            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
290                self.r_net2, params, batch[:, 0], batch[:, 1]
291            )
292
293        ntk_dict = {
294            "ic": u1_ic_ntk + u2_ic_ntk,
295            "res1": res_ntk1,
296            "res2": res_ntk2,
297            'rd': rd_ntk,
298            'ld': ld_ntk,
299            'ln': ln_ntk
300        }
301        return ntk_dict
class DN_csf_Evaluator(jaxpi.evaluator.BaseEvaluator):
303class DN_csf_Evaluator(BaseEvaluator):
304    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
305        super().__init__(config, model)
306
307        self.x0_test = x0_test
308        self.tb_test = tb_test
309        self.xc_test = xc_test
310        self.tc_test = tc_test
311        self.u2_0_test = u2_0_test
312        self.u1_0_test = u1_0_test
313
314    def log_errors(self, params):
315        u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test)
316        u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test)
317
318        u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2)
319        u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2)
320
321        res_pred1,res_pred2 = self.model.r_pred_fn(
322            params, self.tc_test[:,0], self.xc_test[:,0]
323        )
324        res_loss1 = jnp.mean(res_pred1 ** 2)
325        res_loss2 = jnp.mean(res_pred2 ** 2)
326
327        self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)))
328        self.log_dict["res1_test"] = res_loss1
329        self.log_dict["res2_test"] = res_loss2
330        self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu))
331        self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl))
332        self.log_dict["ln_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xl))
333
334    def __call__(self, state, batch):
335        self.log_dict = super().__call__(state, batch)
336
337        if self.config.logging.log_errors:
338            self.log_errors(state.params)
339
340        if self.config.weighting.use_causal:
341            _, _, causal_weight = self.model.res_causal(state.params, batch)
342            self.log_dict["cas_weight"] = causal_weight.min()
343
344        return self.log_dict
DN_csf_Evaluator( config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test)
304    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
305        super().__init__(config, model)
306
307        self.x0_test = x0_test
308        self.tb_test = tb_test
309        self.xc_test = xc_test
310        self.tc_test = tc_test
311        self.u2_0_test = u2_0_test
312        self.u1_0_test = u1_0_test
x0_test
tb_test
xc_test
tc_test
u2_0_test
u1_0_test
def log_errors(self, params):
314    def log_errors(self, params):
315        u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test)
316        u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test)
317
318        u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2)
319        u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2)
320
321        res_pred1,res_pred2 = self.model.r_pred_fn(
322            params, self.tc_test[:,0], self.xc_test[:,0]
323        )
324        res_loss1 = jnp.mean(res_pred1 ** 2)
325        res_loss2 = jnp.mean(res_pred2 ** 2)
326
327        self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)))
328        self.log_dict["res1_test"] = res_loss1
329        self.log_dict["res2_test"] = res_loss2
330        self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu))
331        self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl))
332        self.log_dict["ln_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xl))
class NN_csf(jaxpi.models.ForwardIVP):
346class NN_csf(ForwardIVP):
347    def __init__(self, config):
348        super().__init__(config)
349
350        #Initial condition function
351        self.uinitial = config.uinitial
352
353        #Boundary points
354        self.xl = config.xl
355        self.xu = config.xu
356        self.tu = config.tu
357
358        #Radius left dirichlet condition
359        self.radius = config.radius
360
361        # Predictions over array of x fot t fixed
362        self.u1_0_pred_fn = vmap(
363            vmap(self.u1_net, (None, None, 0)), (None, None, 0)
364        )
365        self.u2_0_pred_fn = vmap(
366            vmap(self.u2_net, (None, None, 0)), (None, None, 0)
367        )
368
369        #Prediction over array of t for x fixed
370        self.u2_bound_pred_fn = vmap(
371            vmap(self.u2_net, (None, 0, None)), (None, 0, None)
372        )
373
374        self.u1_bound_pred_fn = vmap(
375            vmap(self.u1_net, (None, 0, None)), (None, 0, None)
376        )
377
378        #Vmap neural net
379        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
380
381        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
382
383        #Vmap residual operator
384        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))
385
386        #Derivatives on x for x fixed and t in a array
387        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None))
388        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
389
390    #Neural net forward function
391    def neural_net(self, params, t, x):
392        t = t / self.tu
393        z = jnp.stack([t, x])
394        _, outputs = self.state.apply_fn(params, z)
395        u1 = outputs[0]
396        u2 = outputs[1]
397        return u1, u2
398
399    #1st coordinate neural net forward function
400    def u1_net(self, params, t, x):
401        u1, _ = self.neural_net(params, t, x)
402        return u1
403
404    #2st coordinate neural net forward function
405    def u2_net(self, params, t, x):
406        _, u2 = self.neural_net(params, t, x)
407        return u2
408
409    #Residual operator
410    def r_net(self, params, t, x):
411        #Derivatives in x and t
412        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
413        u1_t = grad(self.u1_net, argnums = 1)(params, t, x)
414        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
415        u2_t = grad(self.u2_net, argnums = 1)(params, t, x)
416
417        #Two derivatives in x
418        u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x)
419        u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x)
420
421        #Each coordinate of residual operator
422        return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
423
424    #1st coordinate residual operator
425    def r_net1(self, params, t, x):
426        r1,_ = self.r_net(params, t, x)
427        return r1
428
429    #2nd coordinate residual operator
430    def r_net2(self, params, t, x):
431        _,r2 = self.r_net(params, t, x)
432        return r2
433
434    #Right Dirichlet Condition residual
435    def res_dirichlet(self, params, t, x):
436        #Apply neural net to point in the boundary
437        u1 = self.u1_bound_pred_fn(params, t, x)
438        u2 = self.u2_bound_pred_fn(params, t, x)
439        #Return residual
440        return (jnp.sqrt(u1  ** 2 + u2 ** 2) - self.radius) ** 2
441
442    #Right Dirichlet Condition residual non-vectorised
443    def res_dirichlet_nv(self, params, t, x):
444        #Apply neural net to point in the boundary
445        u1 = self.u1_net(params, t, x)
446        u2 = self.u2_net(params, t, x)
447        #Return residual
448        return (jnp.sqrt(u1  ** 2 + u2 ** 2) - self.radius) ** 2
449
450    #Neumman condition residual non-vectorised
451    def res_neumann_nv(self, params, t, x):
452        #Apply neural net
453        u1 = self.u1_net(params, t, x)
454        u2 = self.u2_net(params, t, x)
455
456        #Derivatives in x
457        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
458        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
459
460        #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t)
461        nS = jnp.append(u1,u2)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2) ** 2)) + 1e-5)
462
463        #Normal at u(x,y)
464        nu = jnp.append(u2_x,(-1)*u1_x)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
465
466        #Return inner product
467        return jnp.sum(nS * nu) ** 2
468
469    #Neumman condition residual
470    def res_neumann(self, params, t, x):
471        #Apply neural net to points in the boundary
472        u1 = self.u1_bound_pred_fn(params, t, x)
473        u2 = self.u2_bound_pred_fn(params, t, x)
474
475        #Derivatives in x
476        u1_x = self.u1_bound_x(params, t, x)
477        u2_x = self.u2_bound_x(params, t, x)
478
479        #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t)
480        nS = jnp.append(u1,u2,1)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2,1) ** 2,1)).reshape(u1.shape[0],1) + 1e-5)
481
482        #Normal at u(x,y)
483        nu = jnp.append(u2_x,(-1)*u1_x,1)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
484
485        #Return inner product
486        return jnp.sum(nS * nu,1).reshape(u1.shape[0],1) ** 2
487
488    #Compute residuals with causal weights
489    @partial(jit, static_argnums=(0,))
490    def res_causal(self, params, batch):
491        # Sort temporal coordinates
492        t_sorted = batch[:, 0].sort()
493
494        #Compute residuals
495        res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])
496
497        #Reshape
498        res_pred1 = res_pred1.reshape(self.num_chunks, -1)
499        res_pred2 = res_pred2.reshape(self.num_chunks, -1)
500
501        #Compute mean residuals
502        res_l1 = jnp.mean(res_pred1 ** 2, axis=1)
503        res_l2 = jnp.mean(res_pred2 ** 2, axis=1)
504
505        #Compute weights
506        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
507        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
508
509        # Take minimum of the causal weights
510        gamma = jnp.vstack([res_gamma1,res_gamma2])
511        gamma = gamma.min(0)
512
513        return res_l1, res_l2, gamma
514
515    #Compute losses
516    @partial(jit, static_argnums=(0,))
517    def losses(self, params, batch):
518        # Initial conditions loss
519        u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
520        u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
521        u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1)))
522
523        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
524        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)
525
526        # Residual loss
527        if self.config.weighting.use_causal == True:
528            res_l1, res_l2, gamma = self.res_causal(params, batch)
529            res_loss1 = jnp.mean(res_l1 * gamma)
530            res_loss2 = jnp.mean(res_l2 * gamma)
531        else:
532            res_pred1,res_pred2 = self.r_pred_fn(
533                params, batch[:, 0], batch[:, 1]
534            )
535            # Compute loss
536            res_loss1 = jnp.mean(res_pred1 ** 2)
537            res_loss2 = jnp.mean(res_pred2 ** 2)
538
539        loss_dict = {
540            "ic": u1_ic_loss + u2_ic_loss,
541            "res1": res_loss1,
542            "res2": res_loss2,
543            'rd': jnp.mean(self.res_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)),
544            'ld': jnp.mean(self.res_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)),
545            'ln': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)),
546            'rn': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu))
547        }
548        return loss_dict
549
550    #Compute NTK
551    @partial(jit, static_argnums = (0,))
552    def compute_diag_ntk(self, params, batch):
553        #Initial Condition
554        u1_ic_ntk = vmap(
555            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
556        )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
557
558        u2_ic_ntk = vmap(
559            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
560        )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
561
562        #Right Dirichlet
563        rd_ntk = vmap(
564            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
565        )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
566
567        #Dirichlet
568        ld_ntk = vmap(
569            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
570        )(self.res_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
571        rd_ntk = vmap(
572            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
573        )(self.res_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
574
575        #Left neumann
576        ln_ntk = vmap(
577            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
578        )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
579        rn_ntk = vmap(
580            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
581        )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
582
583        # Consider the effect of causal weights
584        if self.config.weighting.use_causal:
585            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
586            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
587                self.r_net1, params, batch[:, 0], batch[:, 1]
588            )
589
590            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
591                self.r_net2, params, batch[:, 0], batch[:, 1]
592            )
593
594            res_ntk1 = res_ntk1.reshape(self.num_chunks, -1)
595            res_ntk2 = res_ntk2.reshape(self.num_chunks, -1)
596
597            res_ntk1 = jnp.mean(res_ntk1, axis=1)
598            res_ntk2 = jnp.mean(res_ntk2, axis=1)
599
600            _,_, casual_weights = self.res_causal(params, batch)
601            res_ntk1 = res_ntk1 * casual_weights
602            res_ntk2 = res_ntk2 * casual_weights
603        else:
604            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
605                self.r_net1, params, batch[:, 0], batch[:, 1]
606            )
607            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
608                self.r_net2, params, batch[:, 0], batch[:, 1]
609            )
610
611        ntk_dict = {
612            "ic": u1_ic_ntk + u2_ic_ntk,
613            "res1": res_ntk1,
614            "res2": res_ntk2,
615            'rd': rd_ntk,
616            'ld': ld_ntk,
617            'ln': ln_ntk,
618            'rn': rn_ntk
619        }
620        return ntk_dict
NN_csf(config)
347    def __init__(self, config):
348        super().__init__(config)
349
350        #Initial condition function
351        self.uinitial = config.uinitial
352
353        #Boundary points
354        self.xl = config.xl
355        self.xu = config.xu
356        self.tu = config.tu
357
358        #Radius left dirichlet condition
359        self.radius = config.radius
360
361        # Predictions over array of x fot t fixed
362        self.u1_0_pred_fn = vmap(
363            vmap(self.u1_net, (None, None, 0)), (None, None, 0)
364        )
365        self.u2_0_pred_fn = vmap(
366            vmap(self.u2_net, (None, None, 0)), (None, None, 0)
367        )
368
369        #Prediction over array of t for x fixed
370        self.u2_bound_pred_fn = vmap(
371            vmap(self.u2_net, (None, 0, None)), (None, 0, None)
372        )
373
374        self.u1_bound_pred_fn = vmap(
375            vmap(self.u1_net, (None, 0, None)), (None, 0, None)
376        )
377
378        #Vmap neural net
379        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
380
381        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
382
383        #Vmap residual operator
384        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))
385
386        #Derivatives on x for x fixed and t in a array
387        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None))
388        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
uinitial
xl
xu
tu
radius
u1_0_pred_fn
u2_0_pred_fn
u2_bound_pred_fn
u1_bound_pred_fn
u1_pred_fn
u2_pred_fn
r_pred_fn
u1_bound_x
u2_bound_x
def neural_net(self, params, t, x):
391    def neural_net(self, params, t, x):
392        t = t / self.tu
393        z = jnp.stack([t, x])
394        _, outputs = self.state.apply_fn(params, z)
395        u1 = outputs[0]
396        u2 = outputs[1]
397        return u1, u2
def u1_net(self, params, t, x):
400    def u1_net(self, params, t, x):
401        u1, _ = self.neural_net(params, t, x)
402        return u1
def u2_net(self, params, t, x):
405    def u2_net(self, params, t, x):
406        _, u2 = self.neural_net(params, t, x)
407        return u2
def r_net(self, params, t, x):
410    def r_net(self, params, t, x):
411        #Derivatives in x and t
412        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
413        u1_t = grad(self.u1_net, argnums = 1)(params, t, x)
414        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
415        u2_t = grad(self.u2_net, argnums = 1)(params, t, x)
416
417        #Two derivatives in x
418        u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x)
419        u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x)
420
421        #Each coordinate of residual operator
422        return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
def r_net1(self, params, t, x):
425    def r_net1(self, params, t, x):
426        r1,_ = self.r_net(params, t, x)
427        return r1
def r_net2(self, params, t, x):
430    def r_net2(self, params, t, x):
431        _,r2 = self.r_net(params, t, x)
432        return r2
def res_dirichlet(self, params, t, x):
435    def res_dirichlet(self, params, t, x):
436        #Apply neural net to point in the boundary
437        u1 = self.u1_bound_pred_fn(params, t, x)
438        u2 = self.u2_bound_pred_fn(params, t, x)
439        #Return residual
440        return (jnp.sqrt(u1  ** 2 + u2 ** 2) - self.radius) ** 2
def res_dirichlet_nv(self, params, t, x):
443    def res_dirichlet_nv(self, params, t, x):
444        #Apply neural net to point in the boundary
445        u1 = self.u1_net(params, t, x)
446        u2 = self.u2_net(params, t, x)
447        #Return residual
448        return (jnp.sqrt(u1  ** 2 + u2 ** 2) - self.radius) ** 2
def res_neumann_nv(self, params, t, x):
451    def res_neumann_nv(self, params, t, x):
452        #Apply neural net
453        u1 = self.u1_net(params, t, x)
454        u2 = self.u2_net(params, t, x)
455
456        #Derivatives in x
457        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
458        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
459
460        #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t)
461        nS = jnp.append(u1,u2)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2) ** 2)) + 1e-5)
462
463        #Normal at u(x,y)
464        nu = jnp.append(u2_x,(-1)*u1_x)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
465
466        #Return inner product
467        return jnp.sum(nS * nu) ** 2
def res_neumann(self, params, t, x):
470    def res_neumann(self, params, t, x):
471        #Apply neural net to points in the boundary
472        u1 = self.u1_bound_pred_fn(params, t, x)
473        u2 = self.u2_bound_pred_fn(params, t, x)
474
475        #Derivatives in x
476        u1_x = self.u1_bound_x(params, t, x)
477        u2_x = self.u2_bound_x(params, t, x)
478
479        #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t)
480        nS = jnp.append(u1,u2,1)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2,1) ** 2,1)).reshape(u1.shape[0],1) + 1e-5)
481
482        #Normal at u(x,y)
483        nu = jnp.append(u2_x,(-1)*u1_x,1)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
484
485        #Return inner product
486        return jnp.sum(nS * nu,1).reshape(u1.shape[0],1) ** 2
@partial(jit, static_argnums=(0,))
def res_causal(self, params, batch):
489    @partial(jit, static_argnums=(0,))
490    def res_causal(self, params, batch):
491        # Sort temporal coordinates
492        t_sorted = batch[:, 0].sort()
493
494        #Compute residuals
495        res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])
496
497        #Reshape
498        res_pred1 = res_pred1.reshape(self.num_chunks, -1)
499        res_pred2 = res_pred2.reshape(self.num_chunks, -1)
500
501        #Compute mean residuals
502        res_l1 = jnp.mean(res_pred1 ** 2, axis=1)
503        res_l2 = jnp.mean(res_pred2 ** 2, axis=1)
504
505        #Compute weights
506        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
507        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
508
509        # Take minimum of the causal weights
510        gamma = jnp.vstack([res_gamma1,res_gamma2])
511        gamma = gamma.min(0)
512
513        return res_l1, res_l2, gamma
@partial(jit, static_argnums=(0,))
def losses(self, params, batch):
516    @partial(jit, static_argnums=(0,))
517    def losses(self, params, batch):
518        # Initial conditions loss
519        u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
520        u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
521        u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1)))
522
523        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
524        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)
525
526        # Residual loss
527        if self.config.weighting.use_causal == True:
528            res_l1, res_l2, gamma = self.res_causal(params, batch)
529            res_loss1 = jnp.mean(res_l1 * gamma)
530            res_loss2 = jnp.mean(res_l2 * gamma)
531        else:
532            res_pred1,res_pred2 = self.r_pred_fn(
533                params, batch[:, 0], batch[:, 1]
534            )
535            # Compute loss
536            res_loss1 = jnp.mean(res_pred1 ** 2)
537            res_loss2 = jnp.mean(res_pred2 ** 2)
538
539        loss_dict = {
540            "ic": u1_ic_loss + u2_ic_loss,
541            "res1": res_loss1,
542            "res2": res_loss2,
543            'rd': jnp.mean(self.res_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)),
544            'ld': jnp.mean(self.res_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)),
545            'ln': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)),
546            'rn': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu))
547        }
548        return loss_dict
@partial(jit, static_argnums=(0,))
def compute_diag_ntk(self, params, batch):
551    @partial(jit, static_argnums = (0,))
552    def compute_diag_ntk(self, params, batch):
553        #Initial Condition
554        u1_ic_ntk = vmap(
555            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
556        )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
557
558        u2_ic_ntk = vmap(
559            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
560        )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
561
562        #Right Dirichlet
563        rd_ntk = vmap(
564            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
565        )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
566
567        #Dirichlet
568        ld_ntk = vmap(
569            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
570        )(self.res_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
571        rd_ntk = vmap(
572            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
573        )(self.res_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
574
575        #Left neumann
576        ln_ntk = vmap(
577            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
578        )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
579        rn_ntk = vmap(
580            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
581        )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
582
583        # Consider the effect of causal weights
584        if self.config.weighting.use_causal:
585            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
586            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
587                self.r_net1, params, batch[:, 0], batch[:, 1]
588            )
589
590            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
591                self.r_net2, params, batch[:, 0], batch[:, 1]
592            )
593
594            res_ntk1 = res_ntk1.reshape(self.num_chunks, -1)
595            res_ntk2 = res_ntk2.reshape(self.num_chunks, -1)
596
597            res_ntk1 = jnp.mean(res_ntk1, axis=1)
598            res_ntk2 = jnp.mean(res_ntk2, axis=1)
599
600            _,_, casual_weights = self.res_causal(params, batch)
601            res_ntk1 = res_ntk1 * casual_weights
602            res_ntk2 = res_ntk2 * casual_weights
603        else:
604            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
605                self.r_net1, params, batch[:, 0], batch[:, 1]
606            )
607            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
608                self.r_net2, params, batch[:, 0], batch[:, 1]
609            )
610
611        ntk_dict = {
612            "ic": u1_ic_ntk + u2_ic_ntk,
613            "res1": res_ntk1,
614            "res2": res_ntk2,
615            'rd': rd_ntk,
616            'ld': ld_ntk,
617            'ln': ln_ntk,
618            'rn': rn_ntk
619        }
620        return ntk_dict
class NN_csf_Evaluator(jaxpi.evaluator.BaseEvaluator):
622class NN_csf_Evaluator(BaseEvaluator):
623    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
624        super().__init__(config, model)
625
626        self.x0_test = x0_test
627        self.tb_test = tb_test
628        self.xc_test = xc_test
629        self.tc_test = tc_test
630        self.u2_0_test = u2_0_test
631        self.u1_0_test = u1_0_test
632
633    def log_errors(self, params):
634        u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test)
635        u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test)
636
637        u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2)
638        u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2)
639
640        res_pred1,res_pred2 = self.model.r_pred_fn(
641            params, self.tc_test[:,0], self.xc_test[:,0]
642        )
643        res_loss1 = jnp.mean(res_pred1 ** 2)
644        res_loss2 = jnp.mean(res_pred2 ** 2)
645
646        self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)))
647        self.log_dict["res1_test"] = res_loss1
648        self.log_dict["res2_test"] = res_loss2
649        self.log_dict["rd_test"] = jnp.mean(self.model.res_dirichlet(params, self.tb_test, self.model.xu))
650        self.log_dict["ld_test"] = jnp.mean(self.model.res_dirichlet(params, self.tb_test, self.model.xl))
651        self.log_dict["ln_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xl))
652        self.log_dict["rn_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xu))
653
654    def __call__(self, state, batch):
655        self.log_dict = super().__call__(state, batch)
656
657        if self.config.logging.log_errors:
658            self.log_errors(state.params)
659
660        if self.config.weighting.use_causal:
661            _, _, causal_weight = self.model.res_causal(state.params, batch)
662            self.log_dict["cas_weight"] = causal_weight.min()
663
664        return self.log_dict
NN_csf_Evaluator( config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test)
623    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
624        super().__init__(config, model)
625
626        self.x0_test = x0_test
627        self.tb_test = tb_test
628        self.xc_test = xc_test
629        self.tc_test = tc_test
630        self.u2_0_test = u2_0_test
631        self.u1_0_test = u1_0_test
x0_test
tb_test
xc_test
tc_test
u2_0_test
u1_0_test
def log_errors(self, params):
633    def log_errors(self, params):
634        u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test)
635        u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test)
636
637        u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2)
638        u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2)
639
640        res_pred1,res_pred2 = self.model.r_pred_fn(
641            params, self.tc_test[:,0], self.xc_test[:,0]
642        )
643        res_loss1 = jnp.mean(res_pred1 ** 2)
644        res_loss2 = jnp.mean(res_pred2 ** 2)
645
646        self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)))
647        self.log_dict["res1_test"] = res_loss1
648        self.log_dict["res2_test"] = res_loss2
649        self.log_dict["rd_test"] = jnp.mean(self.model.res_dirichlet(params, self.tb_test, self.model.xu))
650        self.log_dict["ld_test"] = jnp.mean(self.model.res_dirichlet(params, self.tb_test, self.model.xl))
651        self.log_dict["ln_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xl))
652        self.log_dict["rn_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xu))
class DD_csf(jaxpi.models.ForwardIVP):
666class DD_csf(ForwardIVP):
667    def __init__(self, config):
668        super().__init__(config)
669
670        #Initial condition function
671        self.uinitial = config.uinitial
672
673        #Boundary points
674        self.xl = config.xl
675        self.xu = config.xu
676        self.tu = config.tu
677
678        #Right dirichlet point
679        self.ld = config.ld
680        self.rd = config.rd
681
682        # Predictions over array of x fot t fixed
683        self.u1_0_pred_fn = vmap(
684            vmap(self.u1_net, (None, None, 0)), (None, None, 0)
685        )
686        self.u2_0_pred_fn = vmap(
687            vmap(self.u2_net, (None, None, 0)), (None, None, 0)
688        )
689
690        #Prediction over array of t for x fixed
691        self.u2_bound_pred_fn = vmap(
692            vmap(self.u2_net, (None, 0, None)), (None, 0, None)
693        )
694
695        self.u1_bound_pred_fn = vmap(
696            vmap(self.u1_net, (None, 0, None)), (None, 0, None)
697        )
698
699        #Vmap neural net
700        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
701
702        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
703
704        #Vmap residual operator
705        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))
706
707        #Derivatives on x for x fixed and t in a array
708        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None))
709        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
710
711    #Neural net forward function
712    def neural_net(self, params, t, x):
713        t = t / self.tu
714        z = jnp.stack([t, x])
715        _, outputs = self.state.apply_fn(params, z)
716        u1 = outputs[0]
717        u2 = outputs[1]
718        return u1, u2
719
720    #1st coordinate neural net forward function
721    def u1_net(self, params, t, x):
722        u1, _ = self.neural_net(params, t, x)
723        return u1
724
725    #2st coordinate neural net forward function
726    def u2_net(self, params, t, x):
727        _, u2 = self.neural_net(params, t, x)
728        return u2
729
730    #Residual operator
731    def r_net(self, params, t, x):
732        #Derivatives in x and t
733        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
734        u1_t = grad(self.u1_net, argnums = 1)(params, t, x)
735        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
736        u2_t = grad(self.u2_net, argnums = 1)(params, t, x)
737
738        #Two derivatives in x
739        u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x)
740        u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x)
741
742        #Each coordinate of residual operator
743        return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
744
745    #1st coordinate residual operator
746    def r_net1(self, params, t, x):
747        r1,_ = self.r_net(params, t, x)
748        return r1
749
750    #2nd coordinate residual operator
751    def r_net2(self, params, t, x):
752        _,r2 = self.r_net(params, t, x)
753        return r2
754
755    #Right Dirichlet Condition residual
756    def res_right_dirichlet(self, params, t, x):
757        #Apply neural net to point in the boundary
758        u1 = self.u1_bound_pred_fn(params, t, x)
759        u2 = self.u2_bound_pred_fn(params, t, x)
760        #Return residual
761        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
762
763    #Right Dirichlet Condition residual non-vectorised
764    def res_right_dirichlet_nv(self, params, t, x):
765        #Apply neural net to point in the boundary
766        u1 = self.u1_net(params, t, x)
767        u2 = self.u2_net(params, t, x)
768        #Return residual
769        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
770
771    #Left Dirichlet Condition residual
772    def res_left_dirichlet(self, params, t, x):
773        #Apply neural net to point in the boundary
774        u1 = self.u1_bound_pred_fn(params, t, x)
775        u2 = self.u2_bound_pred_fn(params, t, x)
776        #Return residual
777        return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2
778
779    #Left Dirichlet Condition residual non-vectorised
780    def res_left_dirichlet_nv(self, params, t, x):
781        #Apply neural net to point in the boundary
782        u1 = self.u1_net(params, t, x)
783        u2 = self.u2_net(params, t, x)
784        #Return residual
785        return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2
786
787    #Compute residuals with causal weights
788    @partial(jit, static_argnums=(0,))
789    def res_causal(self, params, batch):
790        # Sort temporal coordinates
791        t_sorted = batch[:, 0].sort()
792
793        #Compute residuals
794        res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])
795
796        #Reshape
797        res_pred1 = res_pred1.reshape(self.num_chunks, -1)
798        res_pred2 = res_pred2.reshape(self.num_chunks, -1)
799
800        #Compute mean residuals
801        res_l1 = jnp.mean(res_pred1 ** 2, axis=1)
802        res_l2 = jnp.mean(res_pred2 ** 2, axis=1)
803
804        #Compute weights
805        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
806        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
807
808        # Take minimum of the causal weights
809        gamma = jnp.vstack([res_gamma1,res_gamma2])
810        gamma = gamma.min(0)
811
812        return res_l1, res_l2, gamma
813
814    #Compute losses
815    @partial(jit, static_argnums=(0,))
816    def losses(self, params, batch):
817        # Initial conditions loss
818        u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
819        u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
820        u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1)))
821
822        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
823        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)
824
825        # Residual loss
826        if self.config.weighting.use_causal == True:
827            res_l1, res_l2, gamma = self.res_causal(params, batch)
828            res_loss1 = jnp.mean(res_l1 * gamma)
829            res_loss2 = jnp.mean(res_l2 * gamma)
830        else:
831            res_pred1,res_pred2 = self.r_pred_fn(
832                params, batch[:, 0], batch[:, 1]
833            )
834            # Compute loss
835            res_loss1 = jnp.mean(res_pred1 ** 2)
836            res_loss2 = jnp.mean(res_pred2 ** 2)
837
838        loss_dict = {
839            "ic": u1_ic_loss + u2_ic_loss,
840            "res1": res_loss1,
841            "res2": res_loss2,
842            'rd': jnp.mean(self.res_right_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)),
843            'ld': jnp.mean(self.res_left_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl))
844        }
845        return loss_dict
846
847    #Compute NTK
848    @partial(jit, static_argnums = (0,))
849    def compute_diag_ntk(self, params, batch):
850        #Initial Condition
851        u1_ic_ntk = vmap(
852            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
853        )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
854
855        u2_ic_ntk = vmap(
856            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
857        )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
858
859        #Right Dirichlet
860        rd_ntk = vmap(
861            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
862        )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
863
864        #Left Dirichlet
865        ld_ntk = vmap(
866            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
867        )(self.res_left_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
868
869        # Consider the effect of causal weights
870        if self.config.weighting.use_causal:
871            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
872            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
873                self.r_net1, params, batch[:, 0], batch[:, 1]
874            )
875
876            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
877                self.r_net2, params, batch[:, 0], batch[:, 1]
878            )
879
880            res_ntk1 = res_ntk1.reshape(self.num_chunks, -1)
881            res_ntk2 = res_ntk2.reshape(self.num_chunks, -1)
882
883            res_ntk1 = jnp.mean(res_ntk1, axis=1)
884            res_ntk2 = jnp.mean(res_ntk2, axis=1)
885
886            _,_, casual_weights = self.res_causal(params, batch)
887            res_ntk1 = res_ntk1 * casual_weights
888            res_ntk2 = res_ntk2 * casual_weights
889        else:
890            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
891                self.r_net1, params, batch[:, 0], batch[:, 1]
892            )
893            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
894                self.r_net2, params, batch[:, 0], batch[:, 1]
895            )
896
897        ntk_dict = {
898            "ic": u1_ic_ntk + u2_ic_ntk,
899            "res1": res_ntk1,
900            "res2": res_ntk2,
901            'rd': rd_ntk,
902            'ld': ld_ntk
903        }
904        return ntk_dict
DD_csf(config)
667    def __init__(self, config):
668        super().__init__(config)
669
670        #Initial condition function
671        self.uinitial = config.uinitial
672
673        #Boundary points
674        self.xl = config.xl
675        self.xu = config.xu
676        self.tu = config.tu
677
678        #Right dirichlet point
679        self.ld = config.ld
680        self.rd = config.rd
681
682        # Predictions over array of x fot t fixed
683        self.u1_0_pred_fn = vmap(
684            vmap(self.u1_net, (None, None, 0)), (None, None, 0)
685        )
686        self.u2_0_pred_fn = vmap(
687            vmap(self.u2_net, (None, None, 0)), (None, None, 0)
688        )
689
690        #Prediction over array of t for x fixed
691        self.u2_bound_pred_fn = vmap(
692            vmap(self.u2_net, (None, 0, None)), (None, 0, None)
693        )
694
695        self.u1_bound_pred_fn = vmap(
696            vmap(self.u1_net, (None, 0, None)), (None, 0, None)
697        )
698
699        #Vmap neural net
700        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
701
702        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
703
704        #Vmap residual operator
705        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))
706
707        #Derivatives on x for x fixed and t in a array
708        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None))
709        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
uinitial
xl
xu
tu
ld
rd
u1_0_pred_fn
u2_0_pred_fn
u2_bound_pred_fn
u1_bound_pred_fn
u1_pred_fn
u2_pred_fn
r_pred_fn
u1_bound_x
u2_bound_x
def neural_net(self, params, t, x):
712    def neural_net(self, params, t, x):
713        t = t / self.tu
714        z = jnp.stack([t, x])
715        _, outputs = self.state.apply_fn(params, z)
716        u1 = outputs[0]
717        u2 = outputs[1]
718        return u1, u2
def u1_net(self, params, t, x):
721    def u1_net(self, params, t, x):
722        u1, _ = self.neural_net(params, t, x)
723        return u1
def u2_net(self, params, t, x):
726    def u2_net(self, params, t, x):
727        _, u2 = self.neural_net(params, t, x)
728        return u2
def r_net(self, params, t, x):
731    def r_net(self, params, t, x):
732        #Derivatives in x and t
733        u1_x = grad(self.u1_net, argnums = 2)(params, t, x)
734        u1_t = grad(self.u1_net, argnums = 1)(params, t, x)
735        u2_x = grad(self.u2_net, argnums = 2)(params, t, x)
736        u2_t = grad(self.u2_net, argnums = 1)(params, t, x)
737
738        #Two derivatives in x
739        u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x)
740        u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x)
741
742        #Each coordinate of residual operator
743        return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
def r_net1(self, params, t, x):
746    def r_net1(self, params, t, x):
747        r1,_ = self.r_net(params, t, x)
748        return r1
def r_net2(self, params, t, x):
751    def r_net2(self, params, t, x):
752        _,r2 = self.r_net(params, t, x)
753        return r2
def res_right_dirichlet(self, params, t, x):
756    def res_right_dirichlet(self, params, t, x):
757        #Apply neural net to point in the boundary
758        u1 = self.u1_bound_pred_fn(params, t, x)
759        u2 = self.u2_bound_pred_fn(params, t, x)
760        #Return residual
761        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
def res_right_dirichlet_nv(self, params, t, x):
764    def res_right_dirichlet_nv(self, params, t, x):
765        #Apply neural net to point in the boundary
766        u1 = self.u1_net(params, t, x)
767        u2 = self.u2_net(params, t, x)
768        #Return residual
769        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2
def res_left_dirichlet(self, params, t, x):
772    def res_left_dirichlet(self, params, t, x):
773        #Apply neural net to point in the boundary
774        u1 = self.u1_bound_pred_fn(params, t, x)
775        u2 = self.u2_bound_pred_fn(params, t, x)
776        #Return residual
777        return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2
def res_left_dirichlet_nv(self, params, t, x):
780    def res_left_dirichlet_nv(self, params, t, x):
781        #Apply neural net to point in the boundary
782        u1 = self.u1_net(params, t, x)
783        u2 = self.u2_net(params, t, x)
784        #Return residual
785        return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2
@partial(jit, static_argnums=(0,))
def res_causal(self, params, batch):
788    @partial(jit, static_argnums=(0,))
789    def res_causal(self, params, batch):
790        # Sort temporal coordinates
791        t_sorted = batch[:, 0].sort()
792
793        #Compute residuals
794        res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])
795
796        #Reshape
797        res_pred1 = res_pred1.reshape(self.num_chunks, -1)
798        res_pred2 = res_pred2.reshape(self.num_chunks, -1)
799
800        #Compute mean residuals
801        res_l1 = jnp.mean(res_pred1 ** 2, axis=1)
802        res_l2 = jnp.mean(res_pred2 ** 2, axis=1)
803
804        #Compute weights
805        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
806        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
807
808        # Take minimum of the causal weights
809        gamma = jnp.vstack([res_gamma1,res_gamma2])
810        gamma = gamma.min(0)
811
812        return res_l1, res_l2, gamma
@partial(jit, static_argnums=(0,))
def losses(self, params, batch):
815    @partial(jit, static_argnums=(0,))
816    def losses(self, params, batch):
817        # Initial conditions loss
818        u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
819        u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
820        u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1)))
821
822        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
823        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)
824
825        # Residual loss
826        if self.config.weighting.use_causal == True:
827            res_l1, res_l2, gamma = self.res_causal(params, batch)
828            res_loss1 = jnp.mean(res_l1 * gamma)
829            res_loss2 = jnp.mean(res_l2 * gamma)
830        else:
831            res_pred1,res_pred2 = self.r_pred_fn(
832                params, batch[:, 0], batch[:, 1]
833            )
834            # Compute loss
835            res_loss1 = jnp.mean(res_pred1 ** 2)
836            res_loss2 = jnp.mean(res_pred2 ** 2)
837
838        loss_dict = {
839            "ic": u1_ic_loss + u2_ic_loss,
840            "res1": res_loss1,
841            "res2": res_loss2,
842            'rd': jnp.mean(self.res_right_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)),
843            'ld': jnp.mean(self.res_left_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl))
844        }
845        return loss_dict
@partial(jit, static_argnums=(0,))
def compute_diag_ntk(self, params, batch):
848    @partial(jit, static_argnums = (0,))
849    def compute_diag_ntk(self, params, batch):
850        #Initial Condition
851        u1_ic_ntk = vmap(
852            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
853        )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
854
855        u2_ic_ntk = vmap(
856            vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0)
857        )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1)))
858
859        #Right Dirichlet
860        rd_ntk = vmap(
861            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
862        )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)
863
864        #Left Dirichlet
865        ld_ntk = vmap(
866            vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None)
867        )(self.res_left_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)
868
869        # Consider the effect of causal weights
870        if self.config.weighting.use_causal:
871            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
872            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
873                self.r_net1, params, batch[:, 0], batch[:, 1]
874            )
875
876            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
877                self.r_net2, params, batch[:, 0], batch[:, 1]
878            )
879
880            res_ntk1 = res_ntk1.reshape(self.num_chunks, -1)
881            res_ntk2 = res_ntk2.reshape(self.num_chunks, -1)
882
883            res_ntk1 = jnp.mean(res_ntk1, axis=1)
884            res_ntk2 = jnp.mean(res_ntk2, axis=1)
885
886            _,_, casual_weights = self.res_causal(params, batch)
887            res_ntk1 = res_ntk1 * casual_weights
888            res_ntk2 = res_ntk2 * casual_weights
889        else:
890            res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))(
891                self.r_net1, params, batch[:, 0], batch[:, 1]
892            )
893            res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))(
894                self.r_net2, params, batch[:, 0], batch[:, 1]
895            )
896
897        ntk_dict = {
898            "ic": u1_ic_ntk + u2_ic_ntk,
899            "res1": res_ntk1,
900            "res2": res_ntk2,
901            'rd': rd_ntk,
902            'ld': ld_ntk
903        }
904        return ntk_dict
class DD_csf_Evaluator(jaxpi.evaluator.BaseEvaluator):
906class DD_csf_Evaluator(BaseEvaluator):
907    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
908        super().__init__(config, model)
909
910        self.x0_test = x0_test
911        self.tb_test = tb_test
912        self.xc_test = xc_test
913        self.tc_test = tc_test
914        self.u2_0_test = u2_0_test
915        self.u1_0_test = u1_0_test
916
917    def log_errors(self, params):
918        u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test)
919        u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test)
920
921        u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2)
922        u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2)
923
924        res_pred1,res_pred2 = self.model.r_pred_fn(
925            params, self.tc_test[:,0], self.xc_test[:,0]
926        )
927        res_loss1 = jnp.mean(res_pred1 ** 2)
928        res_loss2 = jnp.mean(res_pred2 ** 2)
929
930        self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)))
931        self.log_dict["res1_test"] = res_loss1
932        self.log_dict["res2_test"] = res_loss2
933        self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu))
934        self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl))
935
936    def __call__(self, state, batch):
937        self.log_dict = super().__call__(state, batch)
938
939        if self.config.logging.log_errors:
940            self.log_errors(state.params)
941
942        if self.config.weighting.use_causal:
943            _, _, causal_weight = self.model.res_causal(state.params, batch)
944            self.log_dict["cas_weight"] = causal_weight.min()
945
946        return self.log_dict
DD_csf_Evaluator( config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test)
907    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
908        super().__init__(config, model)
909
910        self.x0_test = x0_test
911        self.tb_test = tb_test
912        self.xc_test = xc_test
913        self.tc_test = tc_test
914        self.u2_0_test = u2_0_test
915        self.u1_0_test = u1_0_test
x0_test
tb_test
xc_test
tc_test
u2_0_test
u1_0_test
def log_errors(self, params):
917    def log_errors(self, params):
918        u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test)
919        u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test)
920
921        u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2)
922        u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2)
923
924        res_pred1,res_pred2 = self.model.r_pred_fn(
925            params, self.tc_test[:,0], self.xc_test[:,0]
926        )
927        res_loss1 = jnp.mean(res_pred1 ** 2)
928        res_loss2 = jnp.mean(res_pred2 ** 2)
929
930        self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)))
931        self.log_dict["res1_test"] = res_loss1
932        self.log_dict["res2_test"] = res_loss2
933        self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu))
934        self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl))
class closed_csf(ForwardIVP):
    """Physics-informed network for the flow of a closed planar curve.

    The network maps ``(t, x)`` to the two coordinates ``(u1, u2)`` of a curve
    parametrised by ``x`` on ``[xl, xu]``. The loss combines:

    * an initial-condition term against ``config.uinitial`` at ``t = 0``;
    * the PDE residual ``u_t - u_xx / (u1_x**2 + u2_x**2 + 1e-5)`` for each
      coordinate (curve-shortening flow in a general parametrisation, with a
      ``1e-5`` regulariser in the denominator);
    * periodicity terms ``u(t, xl) = u(t, xu)`` that keep the curve closed.

    ``self.state``, ``self.config``, ``self.num_chunks``, ``self.tol`` and
    ``self.M`` are used but not assigned here. They are presumably provided
    by ``ForwardIVP`` (TODO confirm).
    """

    def __init__(self, config):
        super().__init__(config)

        # Initial condition function: maps x values to (u1_0, u2_0).
        self.uinitial = config.uinitial

        # Parameter interval [xl, xu] and the time scale used to normalise t.
        self.xl = config.xl
        self.xu = config.xu
        self.tu = config.tu

        # Predictions over an array of x for fixed t. x is mapped over two
        # axes, e.g. the (N, 1) columns built in `losses`.
        self.u1_0_pred_fn = vmap(vmap(self.u1_net, (None, None, 0)), (None, None, 0))
        self.u2_0_pred_fn = vmap(vmap(self.u2_net, (None, None, 0)), (None, None, 0))

        # Predictions over an array of t for fixed x.
        self.u2_bound_pred_fn = vmap(vmap(self.u2_net, (None, 0, None)), (None, 0, None))
        self.u1_bound_pred_fn = vmap(vmap(self.u1_net, (None, 0, None)), (None, 0, None))

        # Point-wise vmapped network outputs and residual.
        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))
        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))

        # x-derivatives at fixed x over an array of t.
        self.u1_bound_x = vmap(
            vmap(grad(self.u1_net, argnums=2), (None, 0, None)), (None, 0, None)
        )
        self.u2_bound_x = vmap(
            vmap(grad(self.u2_net, argnums=2), (None, 0, None)), (None, 0, None)
        )

    def neural_net(self, params, t, x):
        """Evaluate the network at one point ``(t, x)`` and return ``(u1, u2)``.

        ``t`` is divided by ``self.tu`` before being stacked with ``x``.
        ``apply_fn`` returns a pair; its second element holds the outputs.
        """
        t = t / self.tu
        z = jnp.stack([t, x])
        _, outputs = self.state.apply_fn(params, z)
        return outputs[0], outputs[1]

    def u1_net(self, params, t, x):
        """First curve coordinate at ``(t, x)``."""
        u1, _ = self.neural_net(params, t, x)
        return u1

    def u2_net(self, params, t, x):
        """Second curve coordinate at ``(t, x)``."""
        _, u2 = self.neural_net(params, t, x)
        return u2

    def r_net(self, params, t, x):
        """Return the PDE residuals ``(r1, r2)`` at a single point ``(t, x)``.

        ``r_i = u_i_t - u_i_xx / (u1_x**2 + u2_x**2 + 1e-5)``.

        The residuals are returned unsquared. `losses` and `res_causal`
        square them themselves. Previously they were squared here as well,
        which made the training loss the mean of the fourth power.
        """
        # First derivatives in x and t.
        u1_x = grad(self.u1_net, argnums=2)(params, t, x)
        u1_t = grad(self.u1_net, argnums=1)(params, t, x)
        u2_x = grad(self.u2_net, argnums=2)(params, t, x)
        u2_t = grad(self.u2_net, argnums=1)(params, t, x)

        # Second derivatives in x.
        u1_xx = hessian(self.u1_net, argnums=2)(params, t, x)
        u2_xx = hessian(self.u2_net, argnums=2)(params, t, x)

        # Squared parametrisation speed, regularised to stay away from zero.
        speed_sq = u1_x ** 2 + u2_x ** 2 + 1e-5
        return u1_t - u1_xx / speed_sq, u2_t - u2_xx / speed_sq

    def r_net1(self, params, t, x):
        """First-coordinate residual (scalar output, as needed by `ntk_fn`)."""
        r1, _ = self.r_net(params, t, x)
        return r1

    def r_net2(self, params, t, x):
        """Second-coordinate residual (scalar output, as needed by `ntk_fn`)."""
        _, r2 = self.r_net(params, t, x)
        return r2

    @partial(jit, static_argnums=(0,))
    def res_causal(self, params, batch):
        """Per-chunk mean squared residuals and the shared causal weights.

        Only the time column of ``batch`` is sorted; the x column is kept as is.

        Returns:
            ``(res_l1, res_l2, gamma)``: the chunk means of ``r1**2`` and
            ``r2**2``, and the element-wise minimum of both coordinates'
            causal weights ``exp(-tol * M @ res_l)`` (gradient stopped).
        """
        t_sorted = batch[:, 0].sort()
        res_pred1, res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])

        # Split the time-ordered residuals into num_chunks consecutive groups.
        res_pred1 = res_pred1.reshape(self.num_chunks, -1)
        res_pred2 = res_pred2.reshape(self.num_chunks, -1)

        res_l1 = jnp.mean(res_pred1 ** 2, axis=1)
        res_l2 = jnp.mean(res_pred2 ** 2, axis=1)

        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))

        # Use the more conservative (smaller) weight of the two coordinates.
        gamma = jnp.vstack([res_gamma1, res_gamma2]).min(0)
        return res_l1, res_l2, gamma

    @partial(jit, static_argnums=(0,))
    def losses(self, params, batch):
        """Return the unweighted loss terms for a ``(t, x)`` collocation batch.

        ``batch[:, 0]`` is time and ``batch[:, 1]`` is the curve parameter.

        Keys: ``ic`` (both coordinates summed), ``res1``, ``res2``,
        ``periodic1``, ``periodic2``.
        """
        n = batch.shape[0]
        t_col = batch[:, 0].reshape((n, 1))
        x_col = batch[:, 1].reshape((n, 1))

        # Initial condition at t = 0.
        u1_pred = self.u1_0_pred_fn(params, 0.0, x_col)
        u2_pred = self.u2_0_pred_fn(params, 0.0, x_col)
        u1_0, u2_0 = self.uinitial(x_col)
        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)

        # Closed curve: the two ends of the parameter interval must coincide.
        periodic1 = jnp.mean(
            (self.u1_bound_pred_fn(params, t_col, self.xl)
             - self.u1_bound_pred_fn(params, t_col, self.xu)) ** 2
        )
        periodic2 = jnp.mean(
            (self.u2_bound_pred_fn(params, t_col, self.xl)
             - self.u2_bound_pred_fn(params, t_col, self.xu)) ** 2
        )

        # Residual loss, optionally with causal weights.
        if self.config.weighting.use_causal:
            res_l1, res_l2, gamma = self.res_causal(params, batch)
            res_loss1 = jnp.mean(res_l1 * gamma)
            res_loss2 = jnp.mean(res_l2 * gamma)
        else:
            res_pred1, res_pred2 = self.r_pred_fn(params, batch[:, 0], batch[:, 1])
            res_loss1 = jnp.mean(res_pred1 ** 2)
            res_loss2 = jnp.mean(res_pred2 ** 2)

        return {
            "ic": u1_ic_loss + u2_ic_loss,
            "res1": res_loss1,
            "res2": res_loss2,
            "periodic1": periodic1,
            "periodic2": periodic2,
        }

    def _periodic1_net(self, params, t):
        """Scalar gap ``u1(t, xl) - u1(t, xu)`` whose square is the periodic1 loss."""
        return self.u1_net(params, t, self.xl) - self.u1_net(params, t, self.xu)

    def _periodic2_net(self, params, t):
        """Scalar gap ``u2(t, xl) - u2(t, xu)`` whose square is the periodic2 loss."""
        return self.u2_net(params, t, self.xl) - self.u2_net(params, t, self.xu)

    @partial(jit, static_argnums=(0,))
    def compute_diag_ntk(self, params, batch):
        """Return the diagonal NTK of every loss term, keyed like ``losses``.

        The periodic entries were previously missing. The NTK weighting
        scheme presumably pairs one weight per key with the ``losses`` dict
        (TODO confirm against ``ForwardIVP.compute_weights``), so the key
        sets must match.
        """
        n = batch.shape[0]
        x_col = batch[:, 1].reshape((n, 1))

        # Initial condition.
        ic_ntk_fn = vmap(vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0))
        u1_ic_ntk = ic_ntk_fn(self.u1_net, params, 0.0, x_col)
        u2_ic_ntk = ic_ntk_fn(self.u2_net, params, 0.0, x_col)

        # Periodicity, one NTK entry per time sample.
        periodic_ntk_fn = vmap(ntk_fn, (None, None, 0))
        periodic1_ntk = periodic_ntk_fn(self._periodic1_net, params, batch[:, 0])
        periodic2_ntk = periodic_ntk_fn(self._periodic2_net, params, batch[:, 0])

        # Residuals, optionally chunk-averaged and scaled by the causal weights.
        point_ntk_fn = vmap(ntk_fn, (None, None, 0, 0))
        if self.config.weighting.use_causal:
            batch = jnp.stack([jnp.sort(batch[:, 0]), batch[:, 1]], axis=1)
            res_ntk1 = point_ntk_fn(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = point_ntk_fn(self.r_net2, params, batch[:, 0], batch[:, 1])

            res_ntk1 = jnp.mean(res_ntk1.reshape(self.num_chunks, -1), axis=1)
            res_ntk2 = jnp.mean(res_ntk2.reshape(self.num_chunks, -1), axis=1)

            _, _, causal_weights = self.res_causal(params, batch)
            res_ntk1 = res_ntk1 * causal_weights
            res_ntk2 = res_ntk2 * causal_weights
        else:
            res_ntk1 = point_ntk_fn(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = point_ntk_fn(self.r_net2, params, batch[:, 0], batch[:, 1])

        return {
            "ic": u1_ic_ntk + u2_ic_ntk,
            "res1": res_ntk1,
            "res2": res_ntk2,
            "periodic1": periodic1_ntk,
            "periodic2": periodic2_ntk,
        }
def __init__(self, config):
    """Read the problem data from ``config`` and build the vmapped helpers.

    Args:
        config: run configuration; ``uinitial``, ``xl``, ``xu`` and ``tu``
            are read from it, and it is forwarded to the base class.
    """
    super().__init__(config)

    # Problem data taken directly from the config.
    self.uinitial = config.uinitial
    self.xl, self.xu, self.tu = config.xl, config.xu, config.tu

    # Axis specs for (params, t, x).
    over_x = (None, None, 0)
    over_t = (None, 0, None)
    over_tx = (None, 0, 0)

    # Fixed t, x mapped over two axes.
    self.u1_0_pred_fn = vmap(vmap(self.u1_net, over_x), over_x)
    self.u2_0_pred_fn = vmap(vmap(self.u2_net, over_x), over_x)

    # Fixed x, t mapped over two axes.
    self.u2_bound_pred_fn = vmap(vmap(self.u2_net, over_t), over_t)
    self.u1_bound_pred_fn = vmap(vmap(self.u1_net, over_t), over_t)

    # Point-wise over paired (t, x) arrays.
    self.u1_pred_fn = vmap(self.u1_net, over_tx)
    self.u2_pred_fn = vmap(self.u2_net, over_tx)
    self.r_pred_fn = vmap(self.r_net, over_tx)

    # x-derivatives at fixed x, t mapped over two axes.
    self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums=2), over_t), over_t)
    self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums=2), over_t), over_t)
uinitial
xl
xu
tu
u1_0_pred_fn
u2_0_pred_fn
u2_bound_pred_fn
u1_bound_pred_fn
u1_pred_fn
u2_pred_fn
r_pred_fn
u1_bound_x
u2_bound_x
def neural_net(self, params, t, x):
    """Evaluate the network at one point and return the pair ``(u1, u2)``.

    Time is rescaled by ``self.tu`` and stacked with ``x`` as the input.
    ``apply_fn`` returns a pair whose second element holds the outputs.
    """
    inputs = jnp.stack([t / self.tu, x])
    _features, net_out = self.state.apply_fn(params, inputs)
    first, second = net_out[0], net_out[1]
    return first, second
def u1_net(self, params, t, x):
    """Return only the first coordinate of ``neural_net`` at ``(t, x)``."""
    (first_coord, _unused) = self.neural_net(params, t, x)
    return first_coord
def u2_net(self, params, t, x):
    """Return only the second coordinate of ``neural_net`` at ``(t, x)``."""
    (_unused, second_coord) = self.neural_net(params, t, x)
    return second_coord
def r_net(self, params, t, x):
    """Return the PDE residuals ``(r1, r2)`` at a single point ``(t, x)``.

    ``r_i = u_i_t - u_i_xx / (u1_x**2 + u2_x**2 + 1e-5)``.

    The residuals are returned unsquared. The loss code (``losses``,
    ``res_causal``) and the evaluators square them themselves. Previously
    they were squared here as well, which made the loss the mean of the
    fourth power.
    """
    # First derivatives in x and t.
    u1_x = grad(self.u1_net, argnums=2)(params, t, x)
    u1_t = grad(self.u1_net, argnums=1)(params, t, x)
    u2_x = grad(self.u2_net, argnums=2)(params, t, x)
    u2_t = grad(self.u2_net, argnums=1)(params, t, x)

    # Second derivatives in x.
    u1_xx = hessian(self.u1_net, argnums=2)(params, t, x)
    u2_xx = hessian(self.u2_net, argnums=2)(params, t, x)

    # Squared parametrisation speed, regularised away from zero.
    speed_sq = u1_x ** 2 + u2_x ** 2 + 1e-5
    return u1_t - u1_xx / speed_sq, u2_t - u2_xx / speed_sq
def r_net1(self, params, t, x):
    """Scalar first-coordinate residual at ``(t, x)``, as required by ``ntk_fn``."""
    (first_residual, _unused) = self.r_net(params, t, x)
    return first_residual
def r_net2(self, params, t, x):
    """Scalar second-coordinate residual at ``(t, x)``, as required by ``ntk_fn``."""
    (_unused, second_residual) = self.r_net(params, t, x)
    return second_residual
@partial(jit, static_argnums=(0,))
def res_causal(self, params, batch):
    """Per-chunk mean squared residuals plus the shared causal weights.

    Only the time column of ``batch`` is sorted; the x column is used as is.
    The time-ordered residuals are split into ``self.num_chunks`` groups.

    Returns:
        ``(res_l1, res_l2, gamma)``: the chunk means of the squared
        residuals of each coordinate, and the element-wise minimum of the
        two coordinates' weights ``exp(-tol * M @ res_l)``, gradient-stopped.
    """
    times = jnp.sort(batch[:, 0])
    first, second = self.r_pred_fn(params, times, batch[:, 1])

    chunk_means = [
        jnp.mean(res.reshape(self.num_chunks, -1) ** 2, axis=1)
        for res in (first, second)
    ]
    weights = [
        lax.stop_gradient(jnp.exp(-self.tol * (self.M @ means)))
        for means in chunk_means
    ]
    # The more conservative (smaller) weight of the two coordinates wins.
    gamma = jnp.stack(weights).min(axis=0)
    return chunk_means[0], chunk_means[1], gamma
@partial(jit, static_argnums=(0,))
def losses(self, params, batch):
    """Return the unweighted loss terms for a ``(t, x)`` collocation batch.

    ``batch[:, 0]`` is time and ``batch[:, 1]`` is the curve parameter.

    Returns:
        dict with keys ``ic`` (both coordinates summed), ``res1``, ``res2``,
        ``periodic1`` and ``periodic2``.
    """
    n = batch.shape[0]
    # Column views used by the doubly-vmapped prediction helpers.
    t_col = batch[:, 0].reshape((n, 1))
    x_col = batch[:, 1].reshape((n, 1))

    # Initial condition at t = 0.
    u1_pred = self.u1_0_pred_fn(params, 0.0, x_col)
    u2_pred = self.u2_0_pred_fn(params, 0.0, x_col)
    u1_0, u2_0 = self.uinitial(x_col)
    u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
    u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)

    # Closed curve: the two ends of the parameter interval must coincide.
    periodic1 = jnp.mean(
        (self.u1_bound_pred_fn(params, t_col, self.xl)
         - self.u1_bound_pred_fn(params, t_col, self.xu)) ** 2
    )
    periodic2 = jnp.mean(
        (self.u2_bound_pred_fn(params, t_col, self.xl)
         - self.u2_bound_pred_fn(params, t_col, self.xu)) ** 2
    )

    # Residual loss. Plain truth test, consistent with compute_diag_ntk.
    if self.config.weighting.use_causal:
        res_l1, res_l2, gamma = self.res_causal(params, batch)
        res_loss1 = jnp.mean(res_l1 * gamma)
        res_loss2 = jnp.mean(res_l2 * gamma)
    else:
        res_pred1, res_pred2 = self.r_pred_fn(params, batch[:, 0], batch[:, 1])
        res_loss1 = jnp.mean(res_pred1 ** 2)
        res_loss2 = jnp.mean(res_pred2 ** 2)

    return {
        "ic": u1_ic_loss + u2_ic_loss,
        "res1": res_loss1,
        "res2": res_loss2,
        "periodic1": periodic1,
        "periodic2": periodic2,
    }
@partial(jit, static_argnums=(0,))
def compute_diag_ntk(self, params, batch):
    """Return the diagonal NTK of every loss term, keyed like ``losses``.

    ``losses`` returns ``periodic1``/``periodic2`` terms, and the NTK
    weighting scheme presumably produces one weight per key here to pair
    with that dict (TODO confirm against ``ForwardIVP.compute_weights``).
    The periodic entries were previously missing, so they are now computed
    from the scalar gaps ``u(t, xl) - u(t, xu)``.
    """
    n = batch.shape[0]
    x_col = batch[:, 1].reshape((n, 1))

    # Initial condition.
    ic_ntk_fn = vmap(vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0))
    u1_ic_ntk = ic_ntk_fn(self.u1_net, params, 0.0, x_col)
    u2_ic_ntk = ic_ntk_fn(self.u2_net, params, 0.0, x_col)

    # Periodicity: scalar end-point gaps, one NTK entry per time sample.
    def periodic1_net(p, t):
        return self.u1_net(p, t, self.xl) - self.u1_net(p, t, self.xu)

    def periodic2_net(p, t):
        return self.u2_net(p, t, self.xl) - self.u2_net(p, t, self.xu)

    periodic_ntk_fn = vmap(ntk_fn, (None, None, 0))
    periodic1_ntk = periodic_ntk_fn(periodic1_net, params, batch[:, 0])
    periodic2_ntk = periodic_ntk_fn(periodic2_net, params, batch[:, 0])

    # Residuals, optionally chunk-averaged and scaled by the causal weights.
    point_ntk_fn = vmap(ntk_fn, (None, None, 0, 0))
    if self.config.weighting.use_causal:
        batch = jnp.stack([jnp.sort(batch[:, 0]), batch[:, 1]], axis=1)
        res_ntk1 = point_ntk_fn(self.r_net1, params, batch[:, 0], batch[:, 1])
        res_ntk2 = point_ntk_fn(self.r_net2, params, batch[:, 0], batch[:, 1])

        res_ntk1 = jnp.mean(res_ntk1.reshape(self.num_chunks, -1), axis=1)
        res_ntk2 = jnp.mean(res_ntk2.reshape(self.num_chunks, -1), axis=1)

        _, _, causal_weights = self.res_causal(params, batch)
        res_ntk1 = res_ntk1 * causal_weights
        res_ntk2 = res_ntk2 * causal_weights
    else:
        res_ntk1 = point_ntk_fn(self.r_net1, params, batch[:, 0], batch[:, 1])
        res_ntk2 = point_ntk_fn(self.r_net2, params, batch[:, 0], batch[:, 1])

    return {
        "ic": u1_ic_ntk + u2_ic_ntk,
        "res1": res_ntk1,
        "res2": res_ntk2,
        "periodic1": periodic1_ntk,
        "periodic2": periodic2_ntk,
    }
class closed_csf_Evaluator(BaseEvaluator):
    """Evaluator for ``closed_csf``.

    On top of ``BaseEvaluator``'s output it logs held-out initial-condition,
    residual and periodicity metrics, plus the smallest causal weight when
    causal weighting is enabled.
    """

    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
        super().__init__(config, model)

        # Initial-condition test data: points and reference values.
        self.x0_test = x0_test
        self.u1_0_test = u1_0_test
        self.u2_0_test = u2_0_test

        # Time points for the periodicity check, and residual test points.
        self.tb_test = tb_test
        self.xc_test = xc_test
        self.tc_test = tc_test

    def _periodic_mse(self, bound_pred_fn, params):
        """Mean squared gap of ``bound_pred_fn`` between ``config.xl`` and ``config.xu`` over ``tb_test``."""
        at_left = bound_pred_fn(params, self.tb_test, self.config.xl)
        at_right = bound_pred_fn(params, self.tb_test, self.config.xu)
        return jnp.mean((at_left - at_right) ** 2)

    def log_errors(self, params):
        """Write held-out test metrics for ``params`` into ``self.log_dict``.

        Keys, in order: ``ic_rel_test``, ``res1_test``, ``res2_test``,
        ``periodic1_test``, ``periodic2_test``.
        """
        model = self.model

        u1_at_t0 = model.u1_0_pred_fn(params, 0.0, self.x0_test)
        u2_at_t0 = model.u2_0_pred_fn(params, 0.0, self.x0_test)
        ic_mse_u1 = jnp.mean((u1_at_t0 - self.u1_0_test) ** 2)
        ic_mse_u2 = jnp.mean((u2_at_t0 - self.u2_0_test) ** 2)
        ref_power = jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)

        res1, res2 = model.r_pred_fn(params, self.tc_test[:, 0], self.xc_test[:, 0])

        log = self.log_dict
        log["ic_rel_test"] = jnp.sqrt((ic_mse_u1 + ic_mse_u2) / ref_power)
        log["res1_test"] = jnp.mean(res1 ** 2)
        log["res2_test"] = jnp.mean(res2 ** 2)
        log["periodic1_test"] = self._periodic_mse(model.u1_bound_pred_fn, params)
        log["periodic2_test"] = self._periodic_mse(model.u2_bound_pred_fn, params)

    def __call__(self, state, batch):
        """Run the base evaluation, then append test metrics and the causal weight."""
        self.log_dict = super().__call__(state, batch)
        params = state.params

        if self.config.logging.log_errors:
            self.log_errors(params)

        if self.config.weighting.use_causal:
            causal_weight = self.model.res_causal(params, batch)[2]
            self.log_dict["cas_weight"] = causal_weight.min()

        return self.log_dict
def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
    """Keep the held-out test data that ``log_errors`` evaluates against.

    Args:
        config: run configuration, forwarded to ``BaseEvaluator``.
        model: the model being evaluated, forwarded to ``BaseEvaluator``.
        x0_test: spatial points where the initial condition is checked.
        tb_test: time points used for the periodicity metrics.
        xc_test: spatial coordinates of the residual test points.
        tc_test: temporal coordinates of the residual test points.
        u1_0_test: reference first coordinate at ``x0_test``.
        u2_0_test: reference second coordinate at ``x0_test``.
    """
    super().__init__(config, model)

    # Initial-condition test data.
    self.x0_test = x0_test
    self.u1_0_test = u1_0_test
    self.u2_0_test = u2_0_test

    # Periodicity time points and residual collocation points.
    self.tb_test = tb_test
    self.xc_test = xc_test
    self.tc_test = tc_test
x0_test
tb_test
xc_test
tc_test
u2_0_test
u1_0_test
def log_errors(self, params):
    """Write held-out test metrics for ``params`` into ``self.log_dict``.

    Keys written, in this order:
        ic_rel_test: sqrt of the pooled initial-condition MSE of both
            coordinates divided by the summed mean squares of the references.
        res1_test / res2_test: mean of the squared values returned by
            ``model.r_pred_fn`` at the ``(tc_test, xc_test)`` points.
        periodic1_test / periodic2_test: mean squared gap of each coordinate
            between ``config.xl`` and ``config.xu`` over ``tb_test``.
    """
    model = self.model
    left, right = self.config.xl, self.config.xu

    u1_at_t0 = model.u1_0_pred_fn(params, 0.0, self.x0_test)
    u2_at_t0 = model.u2_0_pred_fn(params, 0.0, self.x0_test)
    ic_mse_u1 = jnp.mean((u1_at_t0 - self.u1_0_test) ** 2)
    ic_mse_u2 = jnp.mean((u2_at_t0 - self.u2_0_test) ** 2)
    ref_power = jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)

    res1, res2 = model.r_pred_fn(params, self.tc_test[:, 0], self.xc_test[:, 0])

    gap1 = model.u1_bound_pred_fn(params, self.tb_test, left) - model.u1_bound_pred_fn(params, self.tb_test, right)
    gap2 = model.u2_bound_pred_fn(params, self.tb_test, left) - model.u2_bound_pred_fn(params, self.tb_test, right)

    log = self.log_dict
    log["ic_rel_test"] = jnp.sqrt((ic_mse_u1 + ic_mse_u2) / ref_power)
    log["res1_test"] = jnp.mean(res1 ** 2)
    log["res2_test"] = jnp.mean(res2 ** 2)
    log["periodic1_test"] = jnp.mean(gap1 ** 2)
    log["periodic2_test"] = jnp.mean(gap2 ** 2)