jinnax.class_csf
1from functools import partial 2 3import jax 4import jax.numpy as jnp 5from jax import lax, jit, grad, vmap, jacrev, hessian 6from jax.tree_util import tree_map 7 8import optax 9 10from jaxpi import archs 11from jaxpi.models import ForwardIVP 12from jaxpi.evaluator import BaseEvaluator 13from jaxpi.utils import ntk_fn 14 15class DN_csf(ForwardIVP): 16 def __init__(self, config): 17 super().__init__(config) 18 19 #Initial condition function 20 self.uinitial = config.uinitial 21 22 #Boundary points 23 self.xl = config.xl 24 self.xu = config.xu 25 self.tu = config.tu 26 27 #Radius left dirichlet condition 28 self.radius = config.radius 29 30 #Right dirichlet point 31 self.rd = config.rd 32 33 # Predictions over array of x fot t fixed 34 self.u1_0_pred_fn = vmap( 35 vmap(self.u1_net, (None, None, 0)), (None, None, 0) 36 ) 37 self.u2_0_pred_fn = vmap( 38 vmap(self.u2_net, (None, None, 0)), (None, None, 0) 39 ) 40 41 #Prediction over array of t for x fixed 42 self.u2_bound_pred_fn = vmap( 43 vmap(self.u2_net, (None, 0, None)), (None, 0, None) 44 ) 45 46 self.u1_bound_pred_fn = vmap( 47 vmap(self.u1_net, (None, 0, None)), (None, 0, None) 48 ) 49 50 #Vmap neural net 51 self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0)) 52 53 self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0)) 54 55 #Vmap residual operator 56 self.r_pred_fn = vmap(self.r_net, (None, 0, 0)) 57 58 #Derivatives on x for x fixed and t in a array 59 self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None)) 60 self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None)) 61 62 #Neural net forward function 63 def neural_net(self, params, t, x): 64 t = t / self.tu 65 z = jnp.stack([t, x]) 66 _, outputs = self.state.apply_fn(params, z) 67 u1 = outputs[0] 68 u2 = outputs[1] 69 return u1, u2 70 71 #1st coordinate neural net forward function 72 def u1_net(self, params, t, x): 73 u1, _ = self.neural_net(params, t, x) 74 return u1 75 76 #2st coordinate 
neural net forward function 77 def u2_net(self, params, t, x): 78 _, u2 = self.neural_net(params, t, x) 79 return u2 80 81 #Residual operator 82 def r_net(self, params, t, x): 83 #Derivatives in x and t 84 u1_x = grad(self.u1_net, argnums = 2)(params, t, x) 85 u1_t = grad(self.u1_net, argnums = 1)(params, t, x) 86 u2_x = grad(self.u2_net, argnums = 2)(params, t, x) 87 u2_t = grad(self.u2_net, argnums = 1)(params, t, x) 88 89 #Two derivatives in x 90 u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x) 91 u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x) 92 93 #Each coordinate of residual operator 94 return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2 95 96 #1st coordinate residual operator 97 def r_net1(self, params, t, x): 98 r1,_ = self.r_net(params, t, x) 99 return r1 100 101 #2nd coordinate residual operator 102 def r_net2(self, params, t, x): 103 _,r2 = self.r_net(params, t, x) 104 return r2 105 106 #Right Dirichlet Condition residual 107 def res_right_dirichlet(self, params, t, x): 108 #Apply neural net to point in the boundary 109 u1 = self.u1_bound_pred_fn(params, t, x) 110 u2 = self.u2_bound_pred_fn(params, t, x) 111 #Return residual 112 return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2 113 114 #Right Dirichlet Condition residual non-vectorised 115 def res_right_dirichlet_nv(self, params, t, x): 116 #Apply neural net to point in the boundary 117 u1 = self.u1_net(params, t, x) 118 u2 = self.u2_net(params, t, x) 119 #Return residual 120 return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2 121 122 #Left Dirichlet Condition residual 123 def res_left_dirichlet(self, params, t, x): 124 #Apply neural net to point in the boundary 125 u1 = self.u1_bound_pred_fn(params, t, x) 126 u2 = self.u2_bound_pred_fn(params, t, x) 127 #Return residual 128 return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2 129 130 #Left Dirichlet Condition residual non-vectorised 131 def 
res_left_dirichlet_nv(self, params, t, x): 132 #Apply neural net to point in the boundary 133 u1 = self.u1_net(params, t, x) 134 u2 = self.u2_net(params, t, x) 135 #Return residual 136 return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2 137 138 #Neumman condition residual non-vectorised 139 def res_neumann_nv(self, params, t, x): 140 #Apply neural net 141 u1 = self.u1_net(params, t, x) 142 u2 = self.u2_net(params, t, x) 143 144 #Derivatives in x 145 u1_x = grad(self.u1_net, argnums = 2)(params, t, x) 146 u2_x = grad(self.u2_net, argnums = 2)(params, t, x) 147 148 #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t) 149 nS = jnp.append(u1,u2)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2) ** 2)) + 1e-5) 150 151 #Normal at u(x,y) 152 nu = jnp.append(u2_x,(-1)*u1_x)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5) 153 154 #Return inner product 155 return jnp.sum(nS * nu) ** 2 156 157 #Neumman condition residual 158 def res_neumann(self, params, t, x): 159 #Apply neural net to points in the boundary 160 u1 = self.u1_bound_pred_fn(params, t, x) 161 u2 = self.u2_bound_pred_fn(params, t, x) 162 163 #Derivatives in x 164 u1_x = self.u1_bound_x(params, t, x) 165 u2_x = self.u2_bound_x(params, t, x) 166 167 #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t) 168 nS = jnp.append(u1,u2,1)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2,1) ** 2,1)).reshape(u1.shape[0],1) + 1e-5) 169 170 #Normal at u(x,y) 171 nu = jnp.append(u2_x,(-1)*u1_x,1)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5) 172 173 #Return inner product 174 return jnp.sum(nS * nu,1).reshape(u1.shape[0],1) ** 2 175 176 #Compute residuals with causal weights 177 @partial(jit, static_argnums=(0,)) 178 def res_causal(self, params, batch): 179 # Sort temporal coordinates 180 t_sorted = batch[:, 0].sort() 181 182 #Compute residuals 183 res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1]) 184 185 #Reshape 186 res_pred1 = res_pred1.reshape(self.num_chunks, -1) 187 res_pred2 = 
res_pred2.reshape(self.num_chunks, -1) 188 189 #Compute mean residuals 190 res_l1 = jnp.mean(res_pred1 ** 2, axis=1) 191 res_l2 = jnp.mean(res_pred2 ** 2, axis=1) 192 193 #Compute weights 194 res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1))) 195 res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2))) 196 197 # Take minimum of the causal weights 198 gamma = jnp.vstack([res_gamma1,res_gamma2]) 199 gamma = gamma.min(0) 200 201 return res_l1, res_l2, gamma 202 203 #Compute losses 204 @partial(jit, static_argnums=(0,)) 205 def losses(self, params, batch): 206 # Initial conditions loss 207 u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 208 u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 209 u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1))) 210 211 u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2) 212 u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2) 213 214 # Residual loss 215 if self.config.weighting.use_causal == True: 216 res_l1, res_l2, gamma = self.res_causal(params, batch) 217 res_loss1 = jnp.mean(res_l1 * gamma) 218 res_loss2 = jnp.mean(res_l2 * gamma) 219 else: 220 res_pred1,res_pred2 = self.r_pred_fn( 221 params, batch[:, 0], batch[:, 1] 222 ) 223 # Compute loss 224 res_loss1 = jnp.mean(res_pred1 ** 2) 225 res_loss2 = jnp.mean(res_pred2 ** 2) 226 227 loss_dict = { 228 "ic": u1_ic_loss + u2_ic_loss, 229 "res1": res_loss1, 230 "res2": res_loss2, 231 'rd': jnp.mean(self.res_right_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)), 232 'ld': jnp.mean(self.res_left_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)), 233 'ln': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)) 234 } 235 return loss_dict 236 237 #Compute NTK 238 @partial(jit, static_argnums = (0,)) 239 def compute_diag_ntk(self, params, batch): 240 #Initial Condition 241 u1_ic_ntk = vmap( 242 vmap(ntk_fn, (None, None, None, 
0)), (None, None, None, 0) 243 )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 244 245 u2_ic_ntk = vmap( 246 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 247 )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 248 249 #Right Dirichlet 250 rd_ntk = vmap( 251 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 252 )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu) 253 254 #Left Dirichlet 255 ld_ntk = vmap( 256 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 257 )(self.res_left_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl) 258 259 #Left neumann 260 ln_ntk = vmap( 261 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 262 )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl) 263 264 # Consider the effect of causal weights 265 if self.config.weighting.use_causal: 266 batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T 267 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 268 self.r_net1, params, batch[:, 0], batch[:, 1] 269 ) 270 271 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 272 self.r_net2, params, batch[:, 0], batch[:, 1] 273 ) 274 275 res_ntk1 = res_ntk1.reshape(self.num_chunks, -1) 276 res_ntk2 = res_ntk2.reshape(self.num_chunks, -1) 277 278 res_ntk1 = jnp.mean(res_ntk1, axis=1) 279 res_ntk2 = jnp.mean(res_ntk2, axis=1) 280 281 _,_, casual_weights = self.res_causal(params, batch) 282 res_ntk1 = res_ntk1 * casual_weights 283 res_ntk2 = res_ntk2 * casual_weights 284 else: 285 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 286 self.r_net1, params, batch[:, 0], batch[:, 1] 287 ) 288 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 289 self.r_net2, params, batch[:, 0], batch[:, 1] 290 ) 291 292 ntk_dict = { 293 "ic": u1_ic_ntk + u2_ic_ntk, 294 "res1": res_ntk1, 295 "res2": res_ntk2, 296 'rd': rd_ntk, 297 'ld': ld_ntk, 298 'ln': ln_ntk 299 } 300 return ntk_dict 301 302class DN_csf_Evaluator(BaseEvaluator): 303 
def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test): 304 super().__init__(config, model) 305 306 self.x0_test = x0_test 307 self.tb_test = tb_test 308 self.xc_test = xc_test 309 self.tc_test = tc_test 310 self.u2_0_test = u2_0_test 311 self.u1_0_test = u1_0_test 312 313 def log_errors(self, params): 314 u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test) 315 u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test) 316 317 u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2) 318 u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2) 319 320 res_pred1,res_pred2 = self.model.r_pred_fn( 321 params, self.tc_test[:,0], self.xc_test[:,0] 322 ) 323 res_loss1 = jnp.mean(res_pred1 ** 2) 324 res_loss2 = jnp.mean(res_pred2 ** 2) 325 326 self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2))) 327 self.log_dict["res1_test"] = res_loss1 328 self.log_dict["res2_test"] = res_loss2 329 self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu)) 330 self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl)) 331 self.log_dict["ln_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xl)) 332 333 def __call__(self, state, batch): 334 self.log_dict = super().__call__(state, batch) 335 336 if self.config.logging.log_errors: 337 self.log_errors(state.params) 338 339 if self.config.weighting.use_causal: 340 _, _, causal_weight = self.model.res_causal(state.params, batch) 341 self.log_dict["cas_weight"] = causal_weight.min() 342 343 return self.log_dict 344 345class NN_csf(ForwardIVP): 346 def __init__(self, config): 347 super().__init__(config) 348 349 #Initial condition function 350 self.uinitial = config.uinitial 351 352 #Boundary points 353 self.xl = config.xl 354 self.xu = config.xu 355 self.tu = config.tu 356 357 #Radius left dirichlet condition 358 
self.radius = config.radius 359 360 # Predictions over array of x fot t fixed 361 self.u1_0_pred_fn = vmap( 362 vmap(self.u1_net, (None, None, 0)), (None, None, 0) 363 ) 364 self.u2_0_pred_fn = vmap( 365 vmap(self.u2_net, (None, None, 0)), (None, None, 0) 366 ) 367 368 #Prediction over array of t for x fixed 369 self.u2_bound_pred_fn = vmap( 370 vmap(self.u2_net, (None, 0, None)), (None, 0, None) 371 ) 372 373 self.u1_bound_pred_fn = vmap( 374 vmap(self.u1_net, (None, 0, None)), (None, 0, None) 375 ) 376 377 #Vmap neural net 378 self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0)) 379 380 self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0)) 381 382 #Vmap residual operator 383 self.r_pred_fn = vmap(self.r_net, (None, 0, 0)) 384 385 #Derivatives on x for x fixed and t in a array 386 self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None)) 387 self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None)) 388 389 #Neural net forward function 390 def neural_net(self, params, t, x): 391 t = t / self.tu 392 z = jnp.stack([t, x]) 393 _, outputs = self.state.apply_fn(params, z) 394 u1 = outputs[0] 395 u2 = outputs[1] 396 return u1, u2 397 398 #1st coordinate neural net forward function 399 def u1_net(self, params, t, x): 400 u1, _ = self.neural_net(params, t, x) 401 return u1 402 403 #2st coordinate neural net forward function 404 def u2_net(self, params, t, x): 405 _, u2 = self.neural_net(params, t, x) 406 return u2 407 408 #Residual operator 409 def r_net(self, params, t, x): 410 #Derivatives in x and t 411 u1_x = grad(self.u1_net, argnums = 2)(params, t, x) 412 u1_t = grad(self.u1_net, argnums = 1)(params, t, x) 413 u2_x = grad(self.u2_net, argnums = 2)(params, t, x) 414 u2_t = grad(self.u2_net, argnums = 1)(params, t, x) 415 416 #Two derivatives in x 417 u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x) 418 u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x) 419 420 #Each coordinate of 
residual operator 421 return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2 422 423 #1st coordinate residual operator 424 def r_net1(self, params, t, x): 425 r1,_ = self.r_net(params, t, x) 426 return r1 427 428 #2nd coordinate residual operator 429 def r_net2(self, params, t, x): 430 _,r2 = self.r_net(params, t, x) 431 return r2 432 433 #Right Dirichlet Condition residual 434 def res_dirichlet(self, params, t, x): 435 #Apply neural net to point in the boundary 436 u1 = self.u1_bound_pred_fn(params, t, x) 437 u2 = self.u2_bound_pred_fn(params, t, x) 438 #Return residual 439 return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2 440 441 #Right Dirichlet Condition residual non-vectorised 442 def res_dirichlet_nv(self, params, t, x): 443 #Apply neural net to point in the boundary 444 u1 = self.u1_net(params, t, x) 445 u2 = self.u2_net(params, t, x) 446 #Return residual 447 return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2 448 449 #Neumman condition residual non-vectorised 450 def res_neumann_nv(self, params, t, x): 451 #Apply neural net 452 u1 = self.u1_net(params, t, x) 453 u2 = self.u2_net(params, t, x) 454 455 #Derivatives in x 456 u1_x = grad(self.u1_net, argnums = 2)(params, t, x) 457 u2_x = grad(self.u2_net, argnums = 2)(params, t, x) 458 459 #Assuming that u(x,t) \in S, compute the vector normal to S at u(x,t) 460 nS = jnp.append(u1,u2)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2) ** 2)) + 1e-5) 461 462 #Normal at u(x,y) 463 nu = jnp.append(u2_x,(-1)*u1_x)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5) 464 465 #Return inner product 466 return jnp.sum(nS * nu) ** 2 467 468 #Neumman condition residual 469 def res_neumann(self, params, t, x): 470 #Apply neural net to points in the boundary 471 u1 = self.u1_bound_pred_fn(params, t, x) 472 u2 = self.u2_bound_pred_fn(params, t, x) 473 474 #Derivatives in x 475 u1_x = self.u1_bound_x(params, t, x) 476 u2_x = self.u2_bound_x(params, t, x) 477 478 #Assuming that u(x,t) \in 
S, compute the vector normal to S at u(x,t) 479 nS = jnp.append(u1,u2,1)/(jnp.sqrt(jnp.sum(jnp.append(u1,u2,1) ** 2,1)).reshape(u1.shape[0],1) + 1e-5) 480 481 #Normal at u(x,y) 482 nu = jnp.append(u2_x,(-1)*u1_x,1)/(jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5) 483 484 #Return inner product 485 return jnp.sum(nS * nu,1).reshape(u1.shape[0],1) ** 2 486 487 #Compute residuals with causal weights 488 @partial(jit, static_argnums=(0,)) 489 def res_causal(self, params, batch): 490 # Sort temporal coordinates 491 t_sorted = batch[:, 0].sort() 492 493 #Compute residuals 494 res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1]) 495 496 #Reshape 497 res_pred1 = res_pred1.reshape(self.num_chunks, -1) 498 res_pred2 = res_pred2.reshape(self.num_chunks, -1) 499 500 #Compute mean residuals 501 res_l1 = jnp.mean(res_pred1 ** 2, axis=1) 502 res_l2 = jnp.mean(res_pred2 ** 2, axis=1) 503 504 #Compute weights 505 res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1))) 506 res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2))) 507 508 # Take minimum of the causal weights 509 gamma = jnp.vstack([res_gamma1,res_gamma2]) 510 gamma = gamma.min(0) 511 512 return res_l1, res_l2, gamma 513 514 #Compute losses 515 @partial(jit, static_argnums=(0,)) 516 def losses(self, params, batch): 517 # Initial conditions loss 518 u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 519 u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 520 u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1))) 521 522 u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2) 523 u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2) 524 525 # Residual loss 526 if self.config.weighting.use_causal == True: 527 res_l1, res_l2, gamma = self.res_causal(params, batch) 528 res_loss1 = jnp.mean(res_l1 * gamma) 529 res_loss2 = jnp.mean(res_l2 * gamma) 530 else: 531 res_pred1,res_pred2 = self.r_pred_fn( 532 params, batch[:, 0], batch[:, 1] 533 
) 534 # Compute loss 535 res_loss1 = jnp.mean(res_pred1 ** 2) 536 res_loss2 = jnp.mean(res_pred2 ** 2) 537 538 loss_dict = { 539 "ic": u1_ic_loss + u2_ic_loss, 540 "res1": res_loss1, 541 "res2": res_loss2, 542 'rd': jnp.mean(self.res_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)), 543 'ld': jnp.mean(self.res_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)), 544 'ln': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)), 545 'rn': jnp.mean(self.res_neumann(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)) 546 } 547 return loss_dict 548 549 #Compute NTK 550 @partial(jit, static_argnums = (0,)) 551 def compute_diag_ntk(self, params, batch): 552 #Initial Condition 553 u1_ic_ntk = vmap( 554 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 555 )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 556 557 u2_ic_ntk = vmap( 558 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 559 )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 560 561 #Right Dirichlet 562 rd_ntk = vmap( 563 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 564 )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu) 565 566 #Dirichlet 567 ld_ntk = vmap( 568 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 569 )(self.res_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl) 570 rd_ntk = vmap( 571 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 572 )(self.res_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu) 573 574 #Left neumann 575 ln_ntk = vmap( 576 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 577 )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl) 578 rn_ntk = vmap( 579 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 580 )(self.res_neumann_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu) 581 582 # 
Consider the effect of causal weights 583 if self.config.weighting.use_causal: 584 batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T 585 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 586 self.r_net1, params, batch[:, 0], batch[:, 1] 587 ) 588 589 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 590 self.r_net2, params, batch[:, 0], batch[:, 1] 591 ) 592 593 res_ntk1 = res_ntk1.reshape(self.num_chunks, -1) 594 res_ntk2 = res_ntk2.reshape(self.num_chunks, -1) 595 596 res_ntk1 = jnp.mean(res_ntk1, axis=1) 597 res_ntk2 = jnp.mean(res_ntk2, axis=1) 598 599 _,_, casual_weights = self.res_causal(params, batch) 600 res_ntk1 = res_ntk1 * casual_weights 601 res_ntk2 = res_ntk2 * casual_weights 602 else: 603 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 604 self.r_net1, params, batch[:, 0], batch[:, 1] 605 ) 606 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 607 self.r_net2, params, batch[:, 0], batch[:, 1] 608 ) 609 610 ntk_dict = { 611 "ic": u1_ic_ntk + u2_ic_ntk, 612 "res1": res_ntk1, 613 "res2": res_ntk2, 614 'rd': rd_ntk, 615 'ld': ld_ntk, 616 'ln': ln_ntk, 617 'rn': rn_ntk 618 } 619 return ntk_dict 620 621class NN_csf_Evaluator(BaseEvaluator): 622 def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test): 623 super().__init__(config, model) 624 625 self.x0_test = x0_test 626 self.tb_test = tb_test 627 self.xc_test = xc_test 628 self.tc_test = tc_test 629 self.u2_0_test = u2_0_test 630 self.u1_0_test = u1_0_test 631 632 def log_errors(self, params): 633 u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test) 634 u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test) 635 636 u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2) 637 u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2) 638 639 res_pred1,res_pred2 = self.model.r_pred_fn( 640 params, self.tc_test[:,0], self.xc_test[:,0] 641 ) 642 res_loss1 = jnp.mean(res_pred1 ** 2) 643 res_loss2 = jnp.mean(res_pred2 ** 2) 644 645 self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss 
+ u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2))) 646 self.log_dict["res1_test"] = res_loss1 647 self.log_dict["res2_test"] = res_loss2 648 self.log_dict["rd_test"] = jnp.mean(self.model.res_dirichlet(params, self.tb_test, self.model.xu)) 649 self.log_dict["ld_test"] = jnp.mean(self.model.res_dirichlet(params, self.tb_test, self.model.xl)) 650 self.log_dict["ln_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xl)) 651 self.log_dict["rn_test"] = jnp.mean(self.model.res_neumann(params, self.tb_test, self.model.xu)) 652 653 def __call__(self, state, batch): 654 self.log_dict = super().__call__(state, batch) 655 656 if self.config.logging.log_errors: 657 self.log_errors(state.params) 658 659 if self.config.weighting.use_causal: 660 _, _, causal_weight = self.model.res_causal(state.params, batch) 661 self.log_dict["cas_weight"] = causal_weight.min() 662 663 return self.log_dict 664 665class DD_csf(ForwardIVP): 666 def __init__(self, config): 667 super().__init__(config) 668 669 #Initial condition function 670 self.uinitial = config.uinitial 671 672 #Boundary points 673 self.xl = config.xl 674 self.xu = config.xu 675 self.tu = config.tu 676 677 #Right dirichlet point 678 self.ld = config.ld 679 self.rd = config.rd 680 681 # Predictions over array of x fot t fixed 682 self.u1_0_pred_fn = vmap( 683 vmap(self.u1_net, (None, None, 0)), (None, None, 0) 684 ) 685 self.u2_0_pred_fn = vmap( 686 vmap(self.u2_net, (None, None, 0)), (None, None, 0) 687 ) 688 689 #Prediction over array of t for x fixed 690 self.u2_bound_pred_fn = vmap( 691 vmap(self.u2_net, (None, 0, None)), (None, 0, None) 692 ) 693 694 self.u1_bound_pred_fn = vmap( 695 vmap(self.u1_net, (None, 0, None)), (None, 0, None) 696 ) 697 698 #Vmap neural net 699 self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0)) 700 701 self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0)) 702 703 #Vmap residual operator 704 self.r_pred_fn = vmap(self.r_net, (None, 0, 0)) 705 706 
#Derivatives on x for x fixed and t in a array 707 self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None)) 708 self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None)) 709 710 #Neural net forward function 711 def neural_net(self, params, t, x): 712 t = t / self.tu 713 z = jnp.stack([t, x]) 714 _, outputs = self.state.apply_fn(params, z) 715 u1 = outputs[0] 716 u2 = outputs[1] 717 return u1, u2 718 719 #1st coordinate neural net forward function 720 def u1_net(self, params, t, x): 721 u1, _ = self.neural_net(params, t, x) 722 return u1 723 724 #2st coordinate neural net forward function 725 def u2_net(self, params, t, x): 726 _, u2 = self.neural_net(params, t, x) 727 return u2 728 729 #Residual operator 730 def r_net(self, params, t, x): 731 #Derivatives in x and t 732 u1_x = grad(self.u1_net, argnums = 2)(params, t, x) 733 u1_t = grad(self.u1_net, argnums = 1)(params, t, x) 734 u2_x = grad(self.u2_net, argnums = 2)(params, t, x) 735 u2_t = grad(self.u2_net, argnums = 1)(params, t, x) 736 737 #Two derivatives in x 738 u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x) 739 u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x) 740 741 #Each coordinate of residual operator 742 return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2 743 744 #1st coordinate residual operator 745 def r_net1(self, params, t, x): 746 r1,_ = self.r_net(params, t, x) 747 return r1 748 749 #2nd coordinate residual operator 750 def r_net2(self, params, t, x): 751 _,r2 = self.r_net(params, t, x) 752 return r2 753 754 #Right Dirichlet Condition residual 755 def res_right_dirichlet(self, params, t, x): 756 #Apply neural net to point in the boundary 757 u1 = self.u1_bound_pred_fn(params, t, x) 758 u2 = self.u2_bound_pred_fn(params, t, x) 759 #Return residual 760 return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2 761 762 #Right Dirichlet Condition residual 
non-vectorised 763 def res_right_dirichlet_nv(self, params, t, x): 764 #Apply neural net to point in the boundary 765 u1 = self.u1_net(params, t, x) 766 u2 = self.u2_net(params, t, x) 767 #Return residual 768 return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2 769 770 #Left Dirichlet Condition residual 771 def res_left_dirichlet(self, params, t, x): 772 #Apply neural net to point in the boundary 773 u1 = self.u1_bound_pred_fn(params, t, x) 774 u2 = self.u2_bound_pred_fn(params, t, x) 775 #Return residual 776 return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2 777 778 #Left Dirichlet Condition residual non-vectorised 779 def res_left_dirichlet_nv(self, params, t, x): 780 #Apply neural net to point in the boundary 781 u1 = self.u1_net(params, t, x) 782 u2 = self.u2_net(params, t, x) 783 #Return residual 784 return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2 785 786 #Compute residuals with causal weights 787 @partial(jit, static_argnums=(0,)) 788 def res_causal(self, params, batch): 789 # Sort temporal coordinates 790 t_sorted = batch[:, 0].sort() 791 792 #Compute residuals 793 res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1]) 794 795 #Reshape 796 res_pred1 = res_pred1.reshape(self.num_chunks, -1) 797 res_pred2 = res_pred2.reshape(self.num_chunks, -1) 798 799 #Compute mean residuals 800 res_l1 = jnp.mean(res_pred1 ** 2, axis=1) 801 res_l2 = jnp.mean(res_pred2 ** 2, axis=1) 802 803 #Compute weights 804 res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1))) 805 res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2))) 806 807 # Take minimum of the causal weights 808 gamma = jnp.vstack([res_gamma1,res_gamma2]) 809 gamma = gamma.min(0) 810 811 return res_l1, res_l2, gamma 812 813 #Compute losses 814 @partial(jit, static_argnums=(0,)) 815 def losses(self, params, batch): 816 # Initial conditions loss 817 u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 818 u2_pred = 
self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 819 u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1))) 820 821 u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2) 822 u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2) 823 824 # Residual loss 825 if self.config.weighting.use_causal == True: 826 res_l1, res_l2, gamma = self.res_causal(params, batch) 827 res_loss1 = jnp.mean(res_l1 * gamma) 828 res_loss2 = jnp.mean(res_l2 * gamma) 829 else: 830 res_pred1,res_pred2 = self.r_pred_fn( 831 params, batch[:, 0], batch[:, 1] 832 ) 833 # Compute loss 834 res_loss1 = jnp.mean(res_pred1 ** 2) 835 res_loss2 = jnp.mean(res_pred2 ** 2) 836 837 loss_dict = { 838 "ic": u1_ic_loss + u2_ic_loss, 839 "res1": res_loss1, 840 "res2": res_loss2, 841 'rd': jnp.mean(self.res_right_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)), 842 'ld': jnp.mean(self.res_left_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)) 843 } 844 return loss_dict 845 846 #Compute NTK 847 @partial(jit, static_argnums = (0,)) 848 def compute_diag_ntk(self, params, batch): 849 #Initial Condition 850 u1_ic_ntk = vmap( 851 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 852 )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 853 854 u2_ic_ntk = vmap( 855 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 856 )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 857 858 #Right Dirichlet 859 rd_ntk = vmap( 860 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 861 )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu) 862 863 #Left Dirichlet 864 ld_ntk = vmap( 865 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 866 )(self.res_left_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl) 867 868 # Consider the effect of causal weights 869 if self.config.weighting.use_causal: 870 batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T 871 
res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 872 self.r_net1, params, batch[:, 0], batch[:, 1] 873 ) 874 875 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 876 self.r_net2, params, batch[:, 0], batch[:, 1] 877 ) 878 879 res_ntk1 = res_ntk1.reshape(self.num_chunks, -1) 880 res_ntk2 = res_ntk2.reshape(self.num_chunks, -1) 881 882 res_ntk1 = jnp.mean(res_ntk1, axis=1) 883 res_ntk2 = jnp.mean(res_ntk2, axis=1) 884 885 _,_, casual_weights = self.res_causal(params, batch) 886 res_ntk1 = res_ntk1 * casual_weights 887 res_ntk2 = res_ntk2 * casual_weights 888 else: 889 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 890 self.r_net1, params, batch[:, 0], batch[:, 1] 891 ) 892 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 893 self.r_net2, params, batch[:, 0], batch[:, 1] 894 ) 895 896 ntk_dict = { 897 "ic": u1_ic_ntk + u2_ic_ntk, 898 "res1": res_ntk1, 899 "res2": res_ntk2, 900 'rd': rd_ntk, 901 'ld': ld_ntk 902 } 903 return ntk_dict 904 905class DD_csf_Evaluator(BaseEvaluator): 906 def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test): 907 super().__init__(config, model) 908 909 self.x0_test = x0_test 910 self.tb_test = tb_test 911 self.xc_test = xc_test 912 self.tc_test = tc_test 913 self.u2_0_test = u2_0_test 914 self.u1_0_test = u1_0_test 915 916 def log_errors(self, params): 917 u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test) 918 u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test) 919 920 u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2) 921 u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2) 922 923 res_pred1,res_pred2 = self.model.r_pred_fn( 924 params, self.tc_test[:,0], self.xc_test[:,0] 925 ) 926 res_loss1 = jnp.mean(res_pred1 ** 2) 927 res_loss2 = jnp.mean(res_pred2 ** 2) 928 929 self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2))) 930 self.log_dict["res1_test"] = res_loss1 931 self.log_dict["res2_test"] = res_loss2 932 
self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu)) 933 self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl)) 934 935 def __call__(self, state, batch): 936 self.log_dict = super().__call__(state, batch) 937 938 if self.config.logging.log_errors: 939 self.log_errors(state.params) 940 941 if self.config.weighting.use_causal: 942 _, _, causal_weight = self.model.res_causal(state.params, batch) 943 self.log_dict["cas_weight"] = causal_weight.min() 944 945 return self.log_dict 946 947class closed_csf(ForwardIVP): 948 def __init__(self, config): 949 super().__init__(config) 950 951 #Initial condition function 952 self.uinitial = config.uinitial 953 954 #Boundary points 955 self.xl = config.xl 956 self.xu = config.xu 957 self.tu = config.tu 958 959 # Predictions over array of x fot t fixed 960 self.u1_0_pred_fn = vmap( 961 vmap(self.u1_net, (None, None, 0)), (None, None, 0) 962 ) 963 self.u2_0_pred_fn = vmap( 964 vmap(self.u2_net, (None, None, 0)), (None, None, 0) 965 ) 966 967 #Prediction over array of t for x fixed 968 self.u2_bound_pred_fn = vmap( 969 vmap(self.u2_net, (None, 0, None)), (None, 0, None) 970 ) 971 972 self.u1_bound_pred_fn = vmap( 973 vmap(self.u1_net, (None, 0, None)), (None, 0, None) 974 ) 975 976 #Vmap neural net 977 self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0)) 978 979 self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0)) 980 981 #Vmap residual operator 982 self.r_pred_fn = vmap(self.r_net, (None, 0, 0)) 983 984 #Derivatives on x for x fixed and t in a array 985 self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None)) 986 self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None)) 987 988 #Neural net forward function 989 def neural_net(self, params, t, x): 990 t = t / self.tu 991 z = jnp.stack([t, x]) 992 _, outputs = self.state.apply_fn(params, z) 993 u1 = outputs[0] 994 u2 = 
outputs[1] 995 return u1, u2 996 997 #1st coordinate neural net forward function 998 def u1_net(self, params, t, x): 999 u1, _ = self.neural_net(params, t, x) 1000 return u1 1001 1002 #2st coordinate neural net forward function 1003 def u2_net(self, params, t, x): 1004 _, u2 = self.neural_net(params, t, x) 1005 return u2 1006 1007 #Residual operator 1008 def r_net(self, params, t, x): 1009 #Derivatives in x and t 1010 u1_x = grad(self.u1_net, argnums = 2)(params, t, x) 1011 u1_t = grad(self.u1_net, argnums = 1)(params, t, x) 1012 u2_x = grad(self.u2_net, argnums = 2)(params, t, x) 1013 u2_t = grad(self.u2_net, argnums = 1)(params, t, x) 1014 1015 #Two derivatives in x 1016 u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x) 1017 u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x) 1018 1019 #Each coordinate of residual operator 1020 return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2 1021 1022 #1st coordinate residual operator 1023 def r_net1(self, params, t, x): 1024 r1,_ = self.r_net(params, t, x) 1025 return r1 1026 1027 #2nd coordinate residual operator 1028 def r_net2(self, params, t, x): 1029 _,r2 = self.r_net(params, t, x) 1030 return r2 1031 1032 #Compute residuals with causal weights 1033 @partial(jit, static_argnums=(0,)) 1034 def res_causal(self, params, batch): 1035 # Sort temporal coordinates 1036 t_sorted = batch[:, 0].sort() 1037 1038 #Compute residuals 1039 res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1]) 1040 1041 #Reshape 1042 res_pred1 = res_pred1.reshape(self.num_chunks, -1) 1043 res_pred2 = res_pred2.reshape(self.num_chunks, -1) 1044 1045 #Compute mean residuals 1046 res_l1 = jnp.mean(res_pred1 ** 2, axis=1) 1047 res_l2 = jnp.mean(res_pred2 ** 2, axis=1) 1048 1049 #Compute weights 1050 res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1))) 1051 res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2))) 1052 1053 # Take minimum of 
the causal weights 1054 gamma = jnp.vstack([res_gamma1,res_gamma2]) 1055 gamma = gamma.min(0) 1056 1057 return res_l1, res_l2, gamma 1058 1059 #Compute losses 1060 @partial(jit, static_argnums=(0,)) 1061 def losses(self, params, batch): 1062 # Initial conditions loss 1063 u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1064 u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1065 u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1))) 1066 1067 u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2) 1068 u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2) 1069 1070 periodic1 = jnp.mean((self.u1_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xl) - self.u1_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xu)) ** 2) 1071 periodic2 = jnp.mean((self.u2_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xl) - self.u2_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xu)) ** 2) 1072 1073 # Residual loss 1074 if self.config.weighting.use_causal == True: 1075 res_l1, res_l2, gamma = self.res_causal(params, batch) 1076 res_loss1 = jnp.mean(res_l1 * gamma) 1077 res_loss2 = jnp.mean(res_l2 * gamma) 1078 else: 1079 res_pred1,res_pred2 = self.r_pred_fn( 1080 params, batch[:, 0], batch[:, 1] 1081 ) 1082 # Compute loss 1083 res_loss1 = jnp.mean(res_pred1 ** 2) 1084 res_loss2 = jnp.mean(res_pred2 ** 2) 1085 1086 loss_dict = { 1087 "ic": u1_ic_loss + u2_ic_loss, 1088 "res1": res_loss1, 1089 "res2": res_loss2, 1090 'periodic1': periodic1, 1091 'periodic2': periodic2 1092 } 1093 return loss_dict 1094 1095 #Compute NTK 1096 @partial(jit, static_argnums = (0,)) 1097 def compute_diag_ntk(self, params, batch): 1098 #Initial Condition 1099 u1_ic_ntk = vmap( 1100 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 1101 )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1102 1103 u2_ic_ntk = vmap( 1104 vmap(ntk_fn, (None, None, None, 0)), (None, 
None, None, 0) 1105 )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1106 1107 # Consider the effect of causal weights 1108 if self.config.weighting.use_causal: 1109 batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T 1110 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 1111 self.r_net1, params, batch[:, 0], batch[:, 1] 1112 ) 1113 1114 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 1115 self.r_net2, params, batch[:, 0], batch[:, 1] 1116 ) 1117 1118 res_ntk1 = res_ntk1.reshape(self.num_chunks, -1) 1119 res_ntk2 = res_ntk2.reshape(self.num_chunks, -1) 1120 1121 res_ntk1 = jnp.mean(res_ntk1, axis=1) 1122 res_ntk2 = jnp.mean(res_ntk2, axis=1) 1123 1124 _,_, casual_weights = self.res_causal(params, batch) 1125 res_ntk1 = res_ntk1 * casual_weights 1126 res_ntk2 = res_ntk2 * casual_weights 1127 else: 1128 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 1129 self.r_net1, params, batch[:, 0], batch[:, 1] 1130 ) 1131 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 1132 self.r_net2, params, batch[:, 0], batch[:, 1] 1133 ) 1134 1135 ntk_dict = { 1136 "ic": u1_ic_ntk + u2_ic_ntk, 1137 "res1": res_ntk1, 1138 "res2": res_ntk2 1139 } 1140 return ntk_dict 1141 1142class closed_csf_Evaluator(BaseEvaluator): 1143 def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test): 1144 super().__init__(config, model) 1145 1146 self.x0_test = x0_test 1147 self.tb_test = tb_test 1148 self.xc_test = xc_test 1149 self.tc_test = tc_test 1150 self.u2_0_test = u2_0_test 1151 self.u1_0_test = u1_0_test 1152 1153 def log_errors(self, params): 1154 u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test) 1155 u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test) 1156 1157 u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2) 1158 u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2) 1159 1160 res_pred1,res_pred2 = self.model.r_pred_fn( 1161 params, self.tc_test[:,0], self.xc_test[:,0] 1162 ) 1163 res_loss1 = jnp.mean(res_pred1 ** 2) 1164 
res_loss2 = jnp.mean(res_pred2 ** 2) 1165 1166 self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2))) 1167 self.log_dict["res1_test"] = res_loss1 1168 self.log_dict["res2_test"] = res_loss2 1169 self.log_dict["periodic1_test"] = jnp.mean((self.model.u1_bound_pred_fn(params,self.tb_test,self.config.xl) - self.model.u1_bound_pred_fn(params,self.tb_test,self.config.xu)) ** 2) 1170 self.log_dict["periodic2_test"] = jnp.mean((self.model.u2_bound_pred_fn(params,self.tb_test,self.config.xl) - self.model.u2_bound_pred_fn(params,self.tb_test,self.config.xu)) ** 2) 1171 1172 def __call__(self, state, batch): 1173 self.log_dict = super().__call__(state, batch) 1174 1175 if self.config.logging.log_errors: 1176 self.log_errors(state.params) 1177 1178 if self.config.weighting.use_causal: 1179 _, _, causal_weight = self.model.res_causal(state.params, batch) 1180 self.log_dict["cas_weight"] = causal_weight.min() 1181 1182 return self.log_dict
class
DN_csf(jaxpi.models.ForwardIVP):
class DN_csf(ForwardIVP):
    """PINN for a planar curve u(t, x) = (u1, u2) evolving by u_t = u_xx / |u_x|^2.

    Boundary conditions enforced through the loss:

    * left end (x = xl): |u| = ``radius``, i.e. the end lies on the circle of
      that radius centred at the origin, and ``res_neumann`` drives the curve's
      normal to be orthogonal to the circle's normal u/|u| there;
    * right end (x = xu): u = ``rd``.

    ``self.state``, ``self.num_chunks``, ``self.tol`` and ``self.M`` are not set
    here; presumably they come from ``ForwardIVP`` — TODO confirm.
    """

    def __init__(self, config):
        super().__init__(config)

        # Initial condition: callable mapping an x array to the pair (u1_0, u2_0)
        self.uinitial = config.uinitial

        # Spatial endpoints and the time scale used to normalise t for the network
        self.xl = config.xl
        self.xu = config.xu
        self.tu = config.tu

        # Radius of the circle the left end must lie on (left Dirichlet condition)
        self.radius = config.radius

        # Point the right end is pinned to (right Dirichlet condition)
        self.rd = config.rd

        # Outputs over a 2-D array of x with t fixed (used for the initial condition)
        self.u1_0_pred_fn = vmap(vmap(self.u1_net, (None, None, 0)), (None, None, 0))
        self.u2_0_pred_fn = vmap(vmap(self.u2_net, (None, None, 0)), (None, None, 0))

        # Outputs over a 2-D array of t with x fixed (used for boundary terms)
        self.u2_bound_pred_fn = vmap(vmap(self.u2_net, (None, 0, None)), (None, 0, None))
        self.u1_bound_pred_fn = vmap(vmap(self.u1_net, (None, 0, None)), (None, 0, None))

        # Outputs over paired 1-D arrays of t and x
        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))

        # Residual operator over paired 1-D arrays of t and x
        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))

        # x-derivatives over a 2-D array of t with x fixed (used by res_neumann)
        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums=2), (None, 0, None)), (None, 0, None))
        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums=2), (None, 0, None)), (None, 0, None))

    def neural_net(self, params, t, x):
        """Evaluate the network at a single (t, x); t is divided by ``tu`` first."""
        t = t / self.tu
        z = jnp.stack([t, x])
        _, outputs = self.state.apply_fn(params, z)
        return outputs[0], outputs[1]

    def u1_net(self, params, t, x):
        """First coordinate of the network output at (t, x)."""
        u1, _ = self.neural_net(params, t, x)
        return u1

    def u2_net(self, params, t, x):
        """Second coordinate of the network output at (t, x)."""
        _, u2 = self.neural_net(params, t, x)
        return u2

    def r_net(self, params, t, x):
        """Return the two (unsquared) residual components u_t - u_xx / (|u_x|^2 + 1e-5).

        Callers (``losses``, ``res_causal``, the evaluator) square the result
        themselves. Squaring here too made those losses mean(r**4), and made
        the NTK of ``r_net1``/``r_net2`` scale with r**2 (its gradient is
        2*r*dr), collapsing to zero as the residual shrinks.
        """
        u1_x = grad(self.u1_net, argnums=2)(params, t, x)
        u1_t = grad(self.u1_net, argnums=1)(params, t, x)
        u2_x = grad(self.u2_net, argnums=2)(params, t, x)
        u2_t = grad(self.u2_net, argnums=1)(params, t, x)

        u1_xx = hessian(self.u1_net, argnums=2)(params, t, x)
        u2_xx = hessian(self.u2_net, argnums=2)(params, t, x)

        # 1e-5 keeps the division finite where u_x vanishes
        speed_sq = u1_x ** 2 + u2_x ** 2 + 1e-5
        return u1_t - u1_xx / speed_sq, u2_t - u2_xx / speed_sq

    def r_net1(self, params, t, x):
        """First residual component at (t, x)."""
        r1, _ = self.r_net(params, t, x)
        return r1

    def r_net2(self, params, t, x):
        """Second residual component at (t, x)."""
        _, r2 = self.r_net(params, t, x)
        return r2

    def res_right_dirichlet(self, params, t, x):
        """Squared distance of u(t, x) from ``rd`` over a 2-D array of t (x fixed)."""
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)
        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2

    def res_right_dirichlet_nv(self, params, t, x):
        """Squared distance of u(t, x) from ``rd`` at a single point (used for the NTK)."""
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)
        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2

    def res_left_dirichlet(self, params, t, x):
        """Squared deviation of |u(t, x)| from ``radius`` over a 2-D array of t (x fixed)."""
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)
        return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2

    def res_left_dirichlet_nv(self, params, t, x):
        """Squared deviation of |u(t, x)| from ``radius`` at a single point (used for the NTK)."""
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)
        return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2

    def res_neumann_nv(self, params, t, x):
        """Squared inner product of the circle normal u/|u| and the curve normal at one point.

        The curve normal is (u2_x, -u1_x)/|u_x|. A zero result means the curve
        crosses the circle at a right angle. 1e-5 guards both divisions.
        """
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)
        u1_x = grad(self.u1_net, argnums=2)(params, t, x)
        u2_x = grad(self.u2_net, argnums=2)(params, t, x)

        position = jnp.append(u1, u2)
        n_circle = position / (jnp.sqrt(jnp.sum(position ** 2)) + 1e-5)
        n_curve = jnp.append(u2_x, -u1_x) / (jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
        return jnp.sum(n_circle * n_curve) ** 2

    def res_neumann(self, params, t, x):
        """Vectorised ``res_neumann_nv`` over a 2-D array of t (x fixed); returns shape (n, 1)."""
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)
        u1_x = self.u1_bound_x(params, t, x)
        u2_x = self.u2_bound_x(params, t, x)

        n = u1.shape[0]
        position = jnp.append(u1, u2, 1)
        n_circle = position / (jnp.sqrt(jnp.sum(position ** 2, 1)).reshape(n, 1) + 1e-5)
        n_curve = jnp.append(u2_x, -u1_x, 1) / (jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)
        return jnp.sum(n_circle * n_curve, 1).reshape(n, 1) ** 2

    @partial(jit, static_argnums=(0,))
    def res_causal(self, params, batch):
        """Return per-chunk mean squared residuals and their shared causal weights.

        t is sorted and split into ``num_chunks`` chunks. Each component's chunk
        losses L give weights exp(-tol * M @ L) (no gradient), and the
        elementwise minimum over the two components is returned as ``gamma``.
        NOTE(review): only t is sorted, so (t, x) pairs are re-matched; this is
        harmless only if t and x are sampled independently.
        """
        t_sorted = batch[:, 0].sort()
        res_pred1, res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])

        res_l1 = jnp.mean(res_pred1.reshape(self.num_chunks, -1) ** 2, axis=1)
        res_l2 = jnp.mean(res_pred2.reshape(self.num_chunks, -1) ** 2, axis=1)

        res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
        res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))
        gamma = jnp.vstack([res_gamma1, res_gamma2]).min(0)

        return res_l1, res_l2, gamma

    @partial(jit, static_argnums=(0,))
    def losses(self, params, batch):
        """Return the unweighted loss terms ic, res1, res2, rd, ld, ln for a batch of (t, x) rows.

        The batch's x column feeds the initial condition and its t column
        feeds the boundary terms.
        """
        n = batch.shape[0]
        t_col = batch[:, 0].reshape((n, 1))
        x_col = batch[:, 1].reshape((n, 1))

        # Initial condition at t = 0
        u1_pred = self.u1_0_pred_fn(params, 0.0, x_col)
        u2_pred = self.u2_0_pred_fn(params, 0.0, x_col)
        u1_0, u2_0 = self.uinitial(x_col)
        ic_loss = jnp.mean((u1_pred - u1_0) ** 2) + jnp.mean((u2_pred - u2_0) ** 2)

        # PDE residual, causally weighted when enabled
        if self.config.weighting.use_causal:
            res_l1, res_l2, gamma = self.res_causal(params, batch)
            res_loss1 = jnp.mean(res_l1 * gamma)
            res_loss2 = jnp.mean(res_l2 * gamma)
        else:
            res_pred1, res_pred2 = self.r_pred_fn(params, batch[:, 0], batch[:, 1])
            res_loss1 = jnp.mean(res_pred1 ** 2)
            res_loss2 = jnp.mean(res_pred2 ** 2)

        return {
            "ic": ic_loss,
            "res1": res_loss1,
            "res2": res_loss2,
            "rd": jnp.mean(self.res_right_dirichlet(params, t_col, self.xu)),
            "ld": jnp.mean(self.res_left_dirichlet(params, t_col, self.xl)),
            "ln": jnp.mean(self.res_neumann(params, t_col, self.xl)),
        }

    @partial(jit, static_argnums=(0,))
    def compute_diag_ntk(self, params, batch):
        """Return the NTK diagonal for each loss term (same keys as ``losses``).

        With causal weighting, the residual entries are averaged per temporal
        chunk and multiplied by the causal weights.
        """
        n = batch.shape[0]
        t_col = batch[:, 0].reshape((n, 1))
        x_col = batch[:, 1].reshape((n, 1))

        ic_ntk = vmap(vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0))
        bc_ntk = vmap(vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None))
        pointwise_ntk = vmap(ntk_fn, (None, None, 0, 0))

        u1_ic_ntk = ic_ntk(self.u1_net, params, 0.0, x_col)
        u2_ic_ntk = ic_ntk(self.u2_net, params, 0.0, x_col)
        rd_ntk = bc_ntk(self.res_right_dirichlet_nv, params, t_col, self.xu)
        ld_ntk = bc_ntk(self.res_left_dirichlet_nv, params, t_col, self.xl)
        ln_ntk = bc_ntk(self.res_neumann_nv, params, t_col, self.xl)

        if self.config.weighting.use_causal:
            # Sort t so the chunks line up with those used by res_causal
            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
            res_ntk1 = pointwise_ntk(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = pointwise_ntk(self.r_net2, params, batch[:, 0], batch[:, 1])

            res_ntk1 = jnp.mean(res_ntk1.reshape(self.num_chunks, -1), axis=1)
            res_ntk2 = jnp.mean(res_ntk2.reshape(self.num_chunks, -1), axis=1)

            _, _, causal_weights = self.res_causal(params, batch)
            res_ntk1 = res_ntk1 * causal_weights
            res_ntk2 = res_ntk2 * causal_weights
        else:
            res_ntk1 = pointwise_ntk(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = pointwise_ntk(self.r_net2, params, batch[:, 0], batch[:, 1])

        return {
            "ic": u1_ic_ntk + u2_ic_ntk,
            "res1": res_ntk1,
            "res2": res_ntk2,
            "rd": rd_ntk,
            "ld": ld_ntk,
            "ln": ln_ntk,
        }
DN_csf(config)
def __init__(self, config):
    """Copy the problem settings out of ``config`` and build the vmapped evaluators.

    Reads ``uinitial``, ``xl``, ``xu``, ``tu``, ``radius`` and ``rd`` from ``config``.
    """
    super().__init__(config)

    # Problem data taken straight from the configuration
    self.uinitial = config.uinitial
    self.xl, self.xu, self.tu = config.xl, config.xu, config.tu
    self.radius = config.radius
    self.rd = config.rd

    # in_axes patterns (params are never mapped)
    map_x = (None, None, 0)      # x varies, t fixed
    map_t = (None, 0, None)      # t varies, x fixed
    map_tx = (None, 0, 0)        # paired t and x

    # Initial-condition evaluators: both axes of a 2-D x array are mapped
    self.u1_0_pred_fn = vmap(vmap(self.u1_net, map_x), map_x)
    self.u2_0_pred_fn = vmap(vmap(self.u2_net, map_x), map_x)

    # Boundary evaluators: both axes of a 2-D t array are mapped
    self.u2_bound_pred_fn = vmap(vmap(self.u2_net, map_t), map_t)
    self.u1_bound_pred_fn = vmap(vmap(self.u1_net, map_t), map_t)

    # Pointwise evaluators over paired 1-D t and x
    self.u1_pred_fn = vmap(self.u1_net, map_tx)
    self.u2_pred_fn = vmap(self.u2_net, map_tx)
    self.r_pred_fn = vmap(self.r_net, map_tx)

    # Boundary x-derivatives, mapped like the boundary evaluators
    du1_dx = grad(self.u1_net, argnums=2)
    du2_dx = grad(self.u2_net, argnums=2)
    self.u1_bound_x = vmap(vmap(du1_dx, map_t), map_t)
    self.u2_bound_x = vmap(vmap(du2_dx, map_t), map_t)
def
r_net(self, params, t, x):
def r_net(self, params, t, x):
    """Return the two components of the PDE residual at a single point (t, x).

    Residual: u_t - u_xx / (|u_x|^2 + 1e-5), applied componentwise to
    u = (u1, u2). The 1e-5 keeps the division finite where u_x vanishes.

    The residual is returned unsquared. Callers (``losses``, ``res_causal``,
    the evaluator) square it themselves. Squaring here too made those losses
    mean(r**4), and made the NTK of ``r_net1``/``r_net2`` scale with r**2
    (its gradient is 2*r*dr), collapsing to zero as the residual shrinks.
    """
    # First derivatives in x and t
    u1_x = grad(self.u1_net, argnums=2)(params, t, x)
    u1_t = grad(self.u1_net, argnums=1)(params, t, x)
    u2_x = grad(self.u2_net, argnums=2)(params, t, x)
    u2_t = grad(self.u2_net, argnums=1)(params, t, x)

    # Second derivatives in x
    u1_xx = hessian(self.u1_net, argnums=2)(params, t, x)
    u2_xx = hessian(self.u2_net, argnums=2)(params, t, x)

    speed_sq = u1_x ** 2 + u2_x ** 2 + 1e-5
    return u1_t - u1_xx / speed_sq, u2_t - u2_xx / speed_sq
def
res_neumann_nv(self, params, t, x):
def res_neumann_nv(self, params, t, x):
    """Squared inner product of u/|u| and the curve normal (u2_x, -u1_x)/|u_x| at one point.

    A value of zero means the curve's normal is orthogonal to the direction
    u/|u| at u(t, x). Both normalisations add 1e-5 to the length.
    """
    point = jnp.array([self.u1_net(params, t, x), self.u2_net(params, t, x)])
    dx1 = grad(self.u1_net, argnums=2)(params, t, x)
    dx2 = grad(self.u2_net, argnums=2)(params, t, x)

    # Direction of the point itself (normal to a circle centred at the origin)
    radial = point / (jnp.sqrt(jnp.sum(point ** 2)) + 1e-5)

    # Tangent (dx1, dx2) rotated by -90 degrees, then normalised
    tangent_len = jnp.sqrt(dx1 ** 2 + dx2 ** 2) + 1e-5
    normal = jnp.array([dx2, -dx1]) / tangent_len

    return jnp.sum(radial * normal) ** 2
def
res_neumann(self, params, t, x):
def res_neumann(self, params, t, x):
    """Vectorised Neumann residual over a 2-D array of t with x fixed.

    Row-wise squared inner product of u/|u| with (u2_x, -u1_x)/|u_x|; both
    normalisations add 1e-5. Returns a column of shape (rows, 1).
    """
    u1 = self.u1_bound_pred_fn(params, t, x)
    u2 = self.u2_bound_pred_fn(params, t, x)
    dx1 = self.u1_bound_x(params, t, x)
    dx2 = self.u2_bound_x(params, t, x)

    # Row-wise unit vectors (up to the 1e-5 guard)
    points = jnp.concatenate([u1, u2], axis=1)
    radial = points / (jnp.sqrt(jnp.sum(points ** 2, axis=1, keepdims=True)) + 1e-5)
    normals = jnp.concatenate([dx2, -dx1], axis=1) / (jnp.sqrt(dx1 ** 2 + dx2 ** 2) + 1e-5)

    return jnp.sum(radial * normals, axis=1, keepdims=True) ** 2
@partial(jit, static_argnums=(0,))
def
res_causal(self, params, batch):
@partial(jit, static_argnums=(0,))
def res_causal(self, params, batch):
    """Return (res_l1, res_l2, gamma): per-chunk mean squared residuals and causal weights.

    The t column is sorted and the residuals are split into ``num_chunks`` equal
    chunks. For each component the chunk losses L give weights
    exp(-tol * M @ L) with gradients stopped; ``gamma`` is their elementwise
    minimum across the two components.
    """
    ordered_t = batch[:, 0].sort()

    # Mean squared residual per temporal chunk, one vector per component
    chunk_losses = [
        jnp.mean(component.reshape(self.num_chunks, -1) ** 2, axis=1)
        for component in self.r_pred_fn(params, ordered_t, batch[:, 1])
    ]
    res_l1, res_l2 = chunk_losses

    weights = [
        lax.stop_gradient(jnp.exp(-self.tol * (self.M @ loss)))
        for loss in chunk_losses
    ]
    gamma = jnp.vstack(weights).min(0)

    return res_l1, res_l2, gamma
@partial(jit, static_argnums=(0,))
def
losses(self, params, batch):
@partial(jit, static_argnums=(0,))
def losses(self, params, batch):
    """Return the unweighted loss terms ic, res1, res2, rd, ld, ln for a batch of (t, x) rows.

    The batch's x column feeds the initial condition, its t column feeds the
    boundary terms, and both feed the PDE residual (causally weighted when
    ``config.weighting.use_causal == True``).
    """
    rows = batch.shape[0]
    t_col = batch[:, 0].reshape((rows, 1))
    x_col = batch[:, 1].reshape((rows, 1))

    # Initial-condition mismatch at t = 0, summed over both coordinates
    target1, target2 = self.uinitial(x_col)
    ic = (
        jnp.mean((self.u1_0_pred_fn(params, 0.0, x_col) - target1) ** 2)
        + jnp.mean((self.u2_0_pred_fn(params, 0.0, x_col) - target2) ** 2)
    )

    # Strict comparison preserved on purpose (only an actual True enables causal weighting)
    causal = self.config.weighting.use_causal == True  # noqa: E712
    if causal:
        chunk1, chunk2, gamma = self.res_causal(params, batch)
        res1, res2 = jnp.mean(chunk1 * gamma), jnp.mean(chunk2 * gamma)
    else:
        r1, r2 = self.r_pred_fn(params, batch[:, 0], batch[:, 1])
        res1, res2 = jnp.mean(r1 ** 2), jnp.mean(r2 ** 2)

    return {
        "ic": ic,
        "res1": res1,
        "res2": res2,
        "rd": jnp.mean(self.res_right_dirichlet(params, t_col, self.xu)),
        "ld": jnp.mean(self.res_left_dirichlet(params, t_col, self.xl)),
        "ln": jnp.mean(self.res_neumann(params, t_col, self.xl)),
    }
@partial(jit, static_argnums=(0,))
def compute_diag_ntk(self, params, batch):
    """Return NTK-based diagnostics for every loss term of a batch.

    Every entry is computed with ``ntk_fn`` (jaxpi.utils) -- presumably the
    diagonal NTK entry of a scalar function w.r.t. ``params`` at one input;
    TODO confirm against jaxpi.

    Args:
        params: Network parameters.
        batch: 2-D array; column 0 holds times, column 1 holds positions.

    Returns:
        Dict with keys ``"ic"``, ``"res1"``, ``"res2"``, ``"rd"``, ``"ld"``,
        ``"ln"``. With causal weighting, ``"res1"``/``"res2"`` are per-chunk
        means multiplied by the causal weights from ``res_causal``;
        otherwise they are per-point values.
    """
    n = batch.shape[0]
    t_col = batch[:, 0].reshape((n, 1))
    x_col = batch[:, 1].reshape((n, 1))

    # ntk_fn mapped over a 2-D array of x at a fixed t
    ic_ntk = vmap(vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0))
    # ntk_fn mapped over a 2-D array of t at a fixed boundary x
    bc_ntk = vmap(vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None))
    # ntk_fn at paired (t, x) collocation points
    res_ntk = vmap(ntk_fn, (None, None, 0, 0))

    # Initial condition
    u1_ic_ntk = ic_ntk(self.u1_net, params, 0.0, x_col)
    u2_ic_ntk = ic_ntk(self.u2_net, params, 0.0, x_col)

    # Boundary conditions
    rd_ntk = bc_ntk(self.res_right_dirichlet_nv, params, t_col, self.xu)
    ld_ntk = bc_ntk(self.res_left_dirichlet_nv, params, t_col, self.xl)
    ln_ntk = bc_ntk(self.res_neumann_nv, params, t_col, self.xl)

    if self.config.weighting.use_causal:
        # Sort times (positions keep their order), matching res_causal
        batch = jnp.stack([jnp.sort(batch[:, 0]), batch[:, 1]], axis=1)
        res_ntk1 = res_ntk(self.r_net1, params, batch[:, 0], batch[:, 1])
        res_ntk2 = res_ntk(self.r_net2, params, batch[:, 0], batch[:, 1])

        # Average within each temporal chunk, then apply the causal weights
        res_ntk1 = jnp.mean(res_ntk1.reshape(self.num_chunks, -1), axis=1)
        res_ntk2 = jnp.mean(res_ntk2.reshape(self.num_chunks, -1), axis=1)

        _, _, causal_weights = self.res_causal(params, batch)
        res_ntk1 = res_ntk1 * causal_weights
        res_ntk2 = res_ntk2 * causal_weights
    else:
        res_ntk1 = res_ntk(self.r_net1, params, batch[:, 0], batch[:, 1])
        res_ntk2 = res_ntk(self.r_net2, params, batch[:, 0], batch[:, 1])

    return {
        "ic": u1_ic_ntk + u2_ic_ntk,
        "res1": res_ntk1,
        "res2": res_ntk2,
        "rd": rd_ntk,
        "ld": ld_ntk,
        "ln": ln_ntk,
    }
class DN_csf_Evaluator(BaseEvaluator):
    """Evaluator for ``DN_csf`` that also logs losses on held-out test data.

    Args:
        config: Run configuration; ``logging.log_errors`` and
            ``weighting.use_causal`` are read in ``__call__``.
        model: The ``DN_csf`` model being evaluated.
        x0_test: Positions for the initial-condition check, passed to the
            model's ``u*_0_pred_fn`` (2-D array).
        tb_test: Times for the boundary-condition checks (2-D array).
        xc_test: Collocation positions; only column 0 is used.
        tc_test: Collocation times; only column 0 is used.
        u1_0_test: Reference ``u1`` values at ``x0_test`` and ``t = 0``.
        u2_0_test: Reference ``u2`` values at ``x0_test`` and ``t = 0``.
    """

    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
        super().__init__(config, model)

        self.x0_test = x0_test
        self.tb_test = tb_test
        self.xc_test = xc_test
        self.tc_test = tc_test
        self.u2_0_test = u2_0_test
        self.u1_0_test = u1_0_test

    def log_errors(self, params):
        """Add test-set losses to ``self.log_dict`` (created by ``__call__``)."""
        model = self.model

        # Relative L2 error of the initial condition, both components pooled
        u1_pred = model.u1_0_pred_fn(params, 0.0, self.x0_test)
        u2_pred = model.u2_0_pred_fn(params, 0.0, self.x0_test)
        ic_sq_err = jnp.mean((u1_pred - self.u1_0_test) ** 2) + jnp.mean((u2_pred - self.u2_0_test) ** 2)
        ic_sq_ref = jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)

        # Residual losses at the test collocation points
        res_pred1, res_pred2 = model.r_pred_fn(params, self.tc_test[:, 0], self.xc_test[:, 0])

        self.log_dict["ic_rel_test"] = jnp.sqrt(ic_sq_err / ic_sq_ref)
        self.log_dict["res1_test"] = jnp.mean(res_pred1 ** 2)
        self.log_dict["res2_test"] = jnp.mean(res_pred2 ** 2)
        self.log_dict["rd_test"] = jnp.mean(model.res_right_dirichlet(params, self.tb_test, model.xu))
        self.log_dict["ld_test"] = jnp.mean(model.res_left_dirichlet(params, self.tb_test, model.xl))
        self.log_dict["ln_test"] = jnp.mean(model.res_neumann(params, self.tb_test, model.xl))

    def __call__(self, state, batch):
        """Run the base evaluation, then the optional extra diagnostics."""
        self.log_dict = super().__call__(state, batch)

        if self.config.logging.log_errors:
            self.log_errors(state.params)

        if self.config.weighting.use_causal:
            # Smallest causal weight: how far "in time" training has progressed
            _, _, causal_weight = self.model.res_causal(state.params, batch)
            self.log_dict["cas_weight"] = causal_weight.min()

        return self.log_dict
class NN_csf(ForwardIVP):
    """Two-component PINN with Dirichlet and Neumann conditions at both ends.

    The network maps ``(t, x)`` to a planar point ``(u1, u2)`` and is trained
    on the residual of ``u_t = u_xx / (|u_x|^2 + 1e-5)`` (presumably
    curve-shortening flow, as the class name suggests). At ``x = xl`` and at
    ``x = xu`` the solution must lie on the circle of radius ``radius``
    centred at the origin (Dirichlet) and meet that circle orthogonally
    (Neumann).
    """

    def __init__(self, config):
        super().__init__(config)

        # Initial condition function; called with a (n, 1) array of x and
        # expected to return the pair (u1_0, u2_0)
        self.uinitial = config.uinitial

        # Spatial boundary points, and the scale used to normalise t
        self.xl = config.xl
        self.xu = config.xu
        self.tu = config.tu

        # Radius of the circle imposed by the Dirichlet condition at both ends
        self.radius = config.radius

        # Predictions over a 2-D array of x for a fixed t
        self.u1_0_pred_fn = vmap(vmap(self.u1_net, (None, None, 0)), (None, None, 0))
        self.u2_0_pred_fn = vmap(vmap(self.u2_net, (None, None, 0)), (None, None, 0))

        # Predictions over a 2-D array of t for a fixed x
        self.u2_bound_pred_fn = vmap(vmap(self.u2_net, (None, 0, None)), (None, 0, None))
        self.u1_bound_pred_fn = vmap(vmap(self.u1_net, (None, 0, None)), (None, 0, None))

        # Network evaluated at paired (t, x) points
        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))

        # Residual operator evaluated at paired (t, x) points
        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))

        # x-derivatives over a 2-D array of t for a fixed x
        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums=2), (None, 0, None)), (None, 0, None))
        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums=2), (None, 0, None)), (None, 0, None))

    def neural_net(self, params, t, x):
        """Evaluate the network at a single point; returns ``(u1, u2)``."""
        # Time is rescaled by tu before entering the network
        z = jnp.stack([t / self.tu, x])
        _, outputs = self.state.apply_fn(params, z)
        return outputs[0], outputs[1]

    def u1_net(self, params, t, x):
        """First output coordinate of the network at a single point."""
        return self.neural_net(params, t, x)[0]

    def u2_net(self, params, t, x):
        """Second output coordinate of the network at a single point."""
        return self.neural_net(params, t, x)[1]

    def r_net(self, params, t, x):
        """Return the squared PDE residual of each component at one point.

        Both components share the denominator ``u1_x**2 + u2_x**2 + 1e-5``;
        the ``1e-5`` guards against division by zero.
        """
        u1_t = grad(self.u1_net, argnums=1)(params, t, x)
        u2_t = grad(self.u2_net, argnums=1)(params, t, x)
        u1_x = grad(self.u1_net, argnums=2)(params, t, x)
        u2_x = grad(self.u2_net, argnums=2)(params, t, x)
        u1_xx = hessian(self.u1_net, argnums=2)(params, t, x)
        u2_xx = hessian(self.u2_net, argnums=2)(params, t, x)

        speed_sq = u1_x ** 2 + u2_x ** 2 + 1e-5
        # NOTE(review): these residuals are squared here *and* again in
        # res_causal / losses / the evaluator, so those losses are mean
        # fourth powers of the PDE residual -- confirm this is intended.
        return (u1_t - u1_xx / speed_sq) ** 2, (u2_t - u2_xx / speed_sq) ** 2

    def r_net1(self, params, t, x):
        """First component of ``r_net``."""
        return self.r_net(params, t, x)[0]

    def r_net2(self, params, t, x):
        """Second component of ``r_net``."""
        return self.r_net(params, t, x)[1]

    def res_dirichlet(self, params, t, x):
        """Squared distance of u from the circle of radius ``radius``.

        Vectorised over a 2-D array of ``t`` at the fixed boundary point
        ``x``; used at both ``xl`` and ``xu``.
        """
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)
        return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2

    def res_dirichlet_nv(self, params, t, x):
        """Non-vectorised ``res_dirichlet`` at a single point."""
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)
        return (jnp.sqrt(u1 ** 2 + u2 ** 2) - self.radius) ** 2

    def res_neumann_nv(self, params, t, x):
        """Squared cosine between the circle's and the curve's normals, at one point.

        It vanishes when the curve meets the circle orthogonally.
        """
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)
        u1_x = grad(self.u1_net, argnums=2)(params, t, x)
        u2_x = grad(self.u2_net, argnums=2)(params, t, x)

        # Unit radial direction: the circle's normal, assuming u lies on it
        u = jnp.append(u1, u2)
        n_circle = u / (jnp.sqrt(jnp.sum(u ** 2)) + 1e-5)

        # Unit normal of the curve: tangent (u1_x, u2_x) rotated by -90 degrees
        n_curve = jnp.append(u2_x, -u1_x) / (jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)

        return jnp.sum(n_circle * n_curve) ** 2

    def res_neumann(self, params, t, x):
        """Vectorised ``res_neumann_nv`` over a 2-D array of ``t`` at fixed ``x``.

        Returns an array of shape ``(len(t), 1)``.
        """
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)
        u1_x = self.u1_bound_x(params, t, x)
        u2_x = self.u2_bound_x(params, t, x)
        n_rows = u1.shape[0]

        # Unit radial direction per row: the circle's normal if u lies on it
        u = jnp.append(u1, u2, 1)
        n_circle = u / (jnp.sqrt(jnp.sum(u ** 2, 1)).reshape(n_rows, 1) + 1e-5)

        # Unit normal of the curve per row
        n_curve = jnp.append(u2_x, -u1_x, 1) / (jnp.sqrt(u1_x ** 2 + u2_x ** 2) + 1e-5)

        return jnp.sum(n_circle * n_curve, 1).reshape(n_rows, 1) ** 2

    @partial(jit, static_argnums=(0,))
    def res_causal(self, params, batch):
        """Per-chunk residual losses and causal weights.

        Column 0 of ``batch`` (time) is sorted; column 1 (space) keeps its
        order, so times are re-paired with positions.

        Returns:
            ``(res_l1, res_l2, gamma)``: mean squared residual of each
            component per temporal chunk (``self.num_chunks`` entries) and
            the elementwise minimum of the two components' causal weights
            ``exp(-tol * (M @ res_l))``, with gradients stopped.
        """
        t_sorted = jnp.sort(batch[:, 0])
        res_pred1, res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])

        res_l1 = jnp.mean(res_pred1.reshape(self.num_chunks, -1) ** 2, axis=1)
        res_l2 = jnp.mean(res_pred2.reshape(self.num_chunks, -1) ** 2, axis=1)

        # Weights are constants for the optimiser
        gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
        gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))

        return res_l1, res_l2, jnp.minimum(gamma1, gamma2)

    @partial(jit, static_argnums=(0,))
    def losses(self, params, batch):
        """Return the unweighted loss terms for a training batch.

        Keys: ``"ic"``, ``"res1"``, ``"res2"``, ``"rd"``/``"ld"`` (Dirichlet
        at ``xu``/``xl``) and ``"rn"``/``"ln"`` (Neumann at ``xu``/``xl``).
        """
        n = batch.shape[0]
        t_col = batch[:, 0].reshape((n, 1))
        x_col = batch[:, 1].reshape((n, 1))

        # Initial-condition loss at t = 0
        u1_pred = self.u1_0_pred_fn(params, 0.0, x_col)
        u2_pred = self.u2_0_pred_fn(params, 0.0, x_col)
        u1_0, u2_0 = self.uinitial(x_col)
        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)

        # Residual loss, optionally weighted chunk-by-chunk in time
        if self.config.weighting.use_causal:
            res_l1, res_l2, gamma = self.res_causal(params, batch)
            res_loss1 = jnp.mean(res_l1 * gamma)
            res_loss2 = jnp.mean(res_l2 * gamma)
        else:
            res_pred1, res_pred2 = self.r_pred_fn(params, batch[:, 0], batch[:, 1])
            res_loss1 = jnp.mean(res_pred1 ** 2)
            res_loss2 = jnp.mean(res_pred2 ** 2)

        return {
            "ic": u1_ic_loss + u2_ic_loss,
            "res1": res_loss1,
            "res2": res_loss2,
            "rd": jnp.mean(self.res_dirichlet(params, t_col, self.xu)),
            "ld": jnp.mean(self.res_dirichlet(params, t_col, self.xl)),
            "ln": jnp.mean(self.res_neumann(params, t_col, self.xl)),
            "rn": jnp.mean(self.res_neumann(params, t_col, self.xu)),
        }

    @partial(jit, static_argnums=(0,))
    def compute_diag_ntk(self, params, batch):
        """Return ``ntk_fn``-based diagnostics for every loss term of a batch.

        Keys mirror ``losses``. With causal weighting, ``"res1"``/``"res2"``
        are per-chunk means multiplied by the causal weights.
        """
        n = batch.shape[0]
        t_col = batch[:, 0].reshape((n, 1))
        x_col = batch[:, 1].reshape((n, 1))

        ic_ntk = vmap(vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0))
        bc_ntk = vmap(vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None))
        res_ntk = vmap(ntk_fn, (None, None, 0, 0))

        # Initial condition
        u1_ic_ntk = ic_ntk(self.u1_net, params, 0.0, x_col)
        u2_ic_ntk = ic_ntk(self.u2_net, params, 0.0, x_col)

        # Dirichlet (same residual at both ends). NN_csf defines only
        # res_dirichlet_nv; there is no res_right_dirichlet_nv on this class.
        rd_ntk = bc_ntk(self.res_dirichlet_nv, params, t_col, self.xu)
        ld_ntk = bc_ntk(self.res_dirichlet_nv, params, t_col, self.xl)

        # Neumann at both ends
        ln_ntk = bc_ntk(self.res_neumann_nv, params, t_col, self.xl)
        rn_ntk = bc_ntk(self.res_neumann_nv, params, t_col, self.xu)

        if self.config.weighting.use_causal:
            # Sort times (positions keep their order), matching res_causal
            batch = jnp.stack([jnp.sort(batch[:, 0]), batch[:, 1]], axis=1)
            res_ntk1 = res_ntk(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = res_ntk(self.r_net2, params, batch[:, 0], batch[:, 1])

            res_ntk1 = jnp.mean(res_ntk1.reshape(self.num_chunks, -1), axis=1)
            res_ntk2 = jnp.mean(res_ntk2.reshape(self.num_chunks, -1), axis=1)

            _, _, causal_weights = self.res_causal(params, batch)
            res_ntk1 = res_ntk1 * causal_weights
            res_ntk2 = res_ntk2 * causal_weights
        else:
            res_ntk1 = res_ntk(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = res_ntk(self.r_net2, params, batch[:, 0], batch[:, 1])

        return {
            "ic": u1_ic_ntk + u2_ic_ntk,
            "res1": res_ntk1,
            "res2": res_ntk2,
            "rd": rd_ntk,
            "ld": ld_ntk,
            "ln": ln_ntk,
            "rn": rn_ntk,
        }
class NN_csf_Evaluator(BaseEvaluator):
    """Evaluator for ``NN_csf`` that also logs losses on held-out test data.

    Args:
        config: Run configuration; ``logging.log_errors`` and
            ``weighting.use_causal`` are read in ``__call__``.
        model: The ``NN_csf`` model being evaluated.
        x0_test: Positions for the initial-condition check, passed to the
            model's ``u*_0_pred_fn`` (2-D array).
        tb_test: Times for the boundary-condition checks (2-D array).
        xc_test: Collocation positions; only column 0 is used.
        tc_test: Collocation times; only column 0 is used.
        u1_0_test: Reference ``u1`` values at ``x0_test`` and ``t = 0``.
        u2_0_test: Reference ``u2`` values at ``x0_test`` and ``t = 0``.
    """

    def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test):
        super().__init__(config, model)

        self.x0_test = x0_test
        self.tb_test = tb_test
        self.xc_test = xc_test
        self.tc_test = tc_test
        self.u2_0_test = u2_0_test
        self.u1_0_test = u1_0_test

    def log_errors(self, params):
        """Add test-set losses to ``self.log_dict`` (created by ``__call__``)."""
        model = self.model

        # Relative L2 error of the initial condition, both components pooled
        u1_pred = model.u1_0_pred_fn(params, 0.0, self.x0_test)
        u2_pred = model.u2_0_pred_fn(params, 0.0, self.x0_test)
        ic_sq_err = jnp.mean((u1_pred - self.u1_0_test) ** 2) + jnp.mean((u2_pred - self.u2_0_test) ** 2)
        ic_sq_ref = jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2)

        # Residual losses at the test collocation points
        res_pred1, res_pred2 = model.r_pred_fn(params, self.tc_test[:, 0], self.xc_test[:, 0])

        self.log_dict["ic_rel_test"] = jnp.sqrt(ic_sq_err / ic_sq_ref)
        self.log_dict["res1_test"] = jnp.mean(res_pred1 ** 2)
        self.log_dict["res2_test"] = jnp.mean(res_pred2 ** 2)
        self.log_dict["rd_test"] = jnp.mean(model.res_dirichlet(params, self.tb_test, model.xu))
        self.log_dict["ld_test"] = jnp.mean(model.res_dirichlet(params, self.tb_test, model.xl))
        self.log_dict["ln_test"] = jnp.mean(model.res_neumann(params, self.tb_test, model.xl))
        self.log_dict["rn_test"] = jnp.mean(model.res_neumann(params, self.tb_test, model.xu))

    def __call__(self, state, batch):
        """Run the base evaluation, then the optional extra diagnostics."""
        self.log_dict = super().__call__(state, batch)

        if self.config.logging.log_errors:
            self.log_errors(state.params)

        if self.config.weighting.use_causal:
            # Smallest causal weight: how far "in time" training has progressed
            _, _, causal_weight = self.model.res_causal(state.params, batch)
            self.log_dict["cas_weight"] = causal_weight.min()

        return self.log_dict
class DD_csf(ForwardIVP):
    """Two-component PINN with fixed-point Dirichlet conditions at both ends.

    The network maps ``(t, x)`` to a planar point ``(u1, u2)`` and is trained
    on the residual of ``u_t = u_xx / (|u_x|^2 + 1e-5)`` (presumably
    curve-shortening flow, as the class name suggests). The solution is
    pinned to the point ``ld`` at ``x = xl`` and to ``rd`` at ``x = xu``.
    """

    def __init__(self, config):
        super().__init__(config)

        # Initial condition function; called with a (n, 1) array of x and
        # expected to return the pair (u1_0, u2_0)
        self.uinitial = config.uinitial

        # Spatial boundary points, and the scale used to normalise t
        self.xl = config.xl
        self.xu = config.xu
        self.tu = config.tu

        # Fixed (u1, u2) targets of the left and right Dirichlet conditions
        self.ld = config.ld
        self.rd = config.rd

        # Predictions over a 2-D array of x for a fixed t
        self.u1_0_pred_fn = vmap(vmap(self.u1_net, (None, None, 0)), (None, None, 0))
        self.u2_0_pred_fn = vmap(vmap(self.u2_net, (None, None, 0)), (None, None, 0))

        # Predictions over a 2-D array of t for a fixed x
        self.u2_bound_pred_fn = vmap(vmap(self.u2_net, (None, 0, None)), (None, 0, None))
        self.u1_bound_pred_fn = vmap(vmap(self.u1_net, (None, 0, None)), (None, 0, None))

        # Network evaluated at paired (t, x) points
        self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0))
        self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0))

        # Residual operator evaluated at paired (t, x) points
        self.r_pred_fn = vmap(self.r_net, (None, 0, 0))

        # x-derivatives over a 2-D array of t for a fixed x
        self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums=2), (None, 0, None)), (None, 0, None))
        self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums=2), (None, 0, None)), (None, 0, None))

    def neural_net(self, params, t, x):
        """Evaluate the network at a single point; returns ``(u1, u2)``."""
        # Time is rescaled by tu before entering the network
        z = jnp.stack([t / self.tu, x])
        _, outputs = self.state.apply_fn(params, z)
        return outputs[0], outputs[1]

    def u1_net(self, params, t, x):
        """First output coordinate of the network at a single point."""
        return self.neural_net(params, t, x)[0]

    def u2_net(self, params, t, x):
        """Second output coordinate of the network at a single point."""
        return self.neural_net(params, t, x)[1]

    def r_net(self, params, t, x):
        """Return the squared PDE residual of each component at one point.

        Both components share the denominator ``u1_x**2 + u2_x**2 + 1e-5``;
        the ``1e-5`` guards against division by zero.
        """
        u1_t = grad(self.u1_net, argnums=1)(params, t, x)
        u2_t = grad(self.u2_net, argnums=1)(params, t, x)
        u1_x = grad(self.u1_net, argnums=2)(params, t, x)
        u2_x = grad(self.u2_net, argnums=2)(params, t, x)
        u1_xx = hessian(self.u1_net, argnums=2)(params, t, x)
        u2_xx = hessian(self.u2_net, argnums=2)(params, t, x)

        speed_sq = u1_x ** 2 + u2_x ** 2 + 1e-5
        # NOTE(review): these residuals are squared here *and* again in
        # res_causal / losses, so those losses are mean fourth powers of the
        # PDE residual -- confirm this is intended.
        return (u1_t - u1_xx / speed_sq) ** 2, (u2_t - u2_xx / speed_sq) ** 2

    def r_net1(self, params, t, x):
        """First component of ``r_net``."""
        return self.r_net(params, t, x)[0]

    def r_net2(self, params, t, x):
        """Second component of ``r_net``."""
        return self.r_net(params, t, x)[1]

    def res_right_dirichlet(self, params, t, x):
        """Squared distance of u from ``rd``, over a 2-D array of ``t`` at fixed ``x``."""
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)
        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2

    def res_right_dirichlet_nv(self, params, t, x):
        """Non-vectorised ``res_right_dirichlet`` at a single point."""
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)
        return (u1 - self.rd[0]) ** 2 + (u2 - self.rd[1]) ** 2

    def res_left_dirichlet(self, params, t, x):
        """Squared distance of u from ``ld``, over a 2-D array of ``t`` at fixed ``x``."""
        u1 = self.u1_bound_pred_fn(params, t, x)
        u2 = self.u2_bound_pred_fn(params, t, x)
        return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2

    def res_left_dirichlet_nv(self, params, t, x):
        """Non-vectorised ``res_left_dirichlet`` at a single point."""
        u1 = self.u1_net(params, t, x)
        u2 = self.u2_net(params, t, x)
        return (u1 - self.ld[0]) ** 2 + (u2 - self.ld[1]) ** 2

    @partial(jit, static_argnums=(0,))
    def res_causal(self, params, batch):
        """Per-chunk residual losses and causal weights.

        Column 0 of ``batch`` (time) is sorted; column 1 (space) keeps its
        order, so times are re-paired with positions.

        Returns:
            ``(res_l1, res_l2, gamma)``: mean squared residual of each
            component per temporal chunk (``self.num_chunks`` entries) and
            the elementwise minimum of the two components' causal weights
            ``exp(-tol * (M @ res_l))``, with gradients stopped.
        """
        t_sorted = jnp.sort(batch[:, 0])
        res_pred1, res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1])

        res_l1 = jnp.mean(res_pred1.reshape(self.num_chunks, -1) ** 2, axis=1)
        res_l2 = jnp.mean(res_pred2.reshape(self.num_chunks, -1) ** 2, axis=1)

        # Weights are constants for the optimiser
        gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1)))
        gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2)))

        return res_l1, res_l2, jnp.minimum(gamma1, gamma2)

    @partial(jit, static_argnums=(0,))
    def losses(self, params, batch):
        """Return the unweighted loss terms for a training batch.

        Keys: ``"ic"``, ``"res1"``, ``"res2"``, ``"rd"`` (Dirichlet at
        ``xu``) and ``"ld"`` (Dirichlet at ``xl``).
        """
        n = batch.shape[0]
        t_col = batch[:, 0].reshape((n, 1))
        x_col = batch[:, 1].reshape((n, 1))

        # Initial-condition loss at t = 0
        u1_pred = self.u1_0_pred_fn(params, 0.0, x_col)
        u2_pred = self.u2_0_pred_fn(params, 0.0, x_col)
        u1_0, u2_0 = self.uinitial(x_col)
        u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2)
        u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2)

        # Residual loss, optionally weighted chunk-by-chunk in time
        if self.config.weighting.use_causal:
            res_l1, res_l2, gamma = self.res_causal(params, batch)
            res_loss1 = jnp.mean(res_l1 * gamma)
            res_loss2 = jnp.mean(res_l2 * gamma)
        else:
            res_pred1, res_pred2 = self.r_pred_fn(params, batch[:, 0], batch[:, 1])
            res_loss1 = jnp.mean(res_pred1 ** 2)
            res_loss2 = jnp.mean(res_pred2 ** 2)

        return {
            "ic": u1_ic_loss + u2_ic_loss,
            "res1": res_loss1,
            "res2": res_loss2,
            "rd": jnp.mean(self.res_right_dirichlet(params, t_col, self.xu)),
            "ld": jnp.mean(self.res_left_dirichlet(params, t_col, self.xl)),
        }

    @partial(jit, static_argnums=(0,))
    def compute_diag_ntk(self, params, batch):
        """Return ``ntk_fn``-based diagnostics for every loss term of a batch.

        Keys mirror ``losses``. With causal weighting, ``"res1"``/``"res2"``
        are per-chunk means multiplied by the causal weights.
        """
        n = batch.shape[0]
        t_col = batch[:, 0].reshape((n, 1))
        x_col = batch[:, 1].reshape((n, 1))

        ic_ntk = vmap(vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0))
        bc_ntk = vmap(vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None))
        res_ntk = vmap(ntk_fn, (None, None, 0, 0))

        # Initial condition
        u1_ic_ntk = ic_ntk(self.u1_net, params, 0.0, x_col)
        u2_ic_ntk = ic_ntk(self.u2_net, params, 0.0, x_col)

        # Dirichlet conditions
        rd_ntk = bc_ntk(self.res_right_dirichlet_nv, params, t_col, self.xu)
        ld_ntk = bc_ntk(self.res_left_dirichlet_nv, params, t_col, self.xl)

        if self.config.weighting.use_causal:
            # Sort times (positions keep their order), matching res_causal
            batch = jnp.stack([jnp.sort(batch[:, 0]), batch[:, 1]], axis=1)
            res_ntk1 = res_ntk(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = res_ntk(self.r_net2, params, batch[:, 0], batch[:, 1])

            res_ntk1 = jnp.mean(res_ntk1.reshape(self.num_chunks, -1), axis=1)
            res_ntk2 = jnp.mean(res_ntk2.reshape(self.num_chunks, -1), axis=1)

            _, _, causal_weights = self.res_causal(params, batch)
            res_ntk1 = res_ntk1 * causal_weights
            res_ntk2 = res_ntk2 * causal_weights
        else:
            res_ntk1 = res_ntk(self.r_net1, params, batch[:, 0], batch[:, 1])
            res_ntk2 = res_ntk(self.r_net2, params, batch[:, 0], batch[:, 1])

        return {
            "ic": u1_ic_ntk + u2_ic_ntk,
            "res1": res_ntk1,
            "res2": res_ntk2,
            "rd": rd_ntk,
            "ld": ld_ntk,
        }
def
r_net(self, params, t, x):
731 def r_net(self, params, t, x): 732 #Derivatives in x and t 733 u1_x = grad(self.u1_net, argnums = 2)(params, t, x) 734 u1_t = grad(self.u1_net, argnums = 1)(params, t, x) 735 u2_x = grad(self.u2_net, argnums = 2)(params, t, x) 736 u2_t = grad(self.u2_net, argnums = 1)(params, t, x) 737 738 #Two derivatives in x 739 u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x) 740 u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x) 741 742 #Each coordinate of residual operator 743 return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
@partial(jit, static_argnums=(0,))
def
res_causal(self, params, batch):
788 @partial(jit, static_argnums=(0,)) 789 def res_causal(self, params, batch): 790 # Sort temporal coordinates 791 t_sorted = batch[:, 0].sort() 792 793 #Compute residuals 794 res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1]) 795 796 #Reshape 797 res_pred1 = res_pred1.reshape(self.num_chunks, -1) 798 res_pred2 = res_pred2.reshape(self.num_chunks, -1) 799 800 #Compute mean residuals 801 res_l1 = jnp.mean(res_pred1 ** 2, axis=1) 802 res_l2 = jnp.mean(res_pred2 ** 2, axis=1) 803 804 #Compute weights 805 res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1))) 806 res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2))) 807 808 # Take minimum of the causal weights 809 gamma = jnp.vstack([res_gamma1,res_gamma2]) 810 gamma = gamma.min(0) 811 812 return res_l1, res_l2, gamma
@partial(jit, static_argnums=(0,))
def
losses(self, params, batch):
815 @partial(jit, static_argnums=(0,)) 816 def losses(self, params, batch): 817 # Initial conditions loss 818 u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 819 u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 820 u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1))) 821 822 u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2) 823 u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2) 824 825 # Residual loss 826 if self.config.weighting.use_causal == True: 827 res_l1, res_l2, gamma = self.res_causal(params, batch) 828 res_loss1 = jnp.mean(res_l1 * gamma) 829 res_loss2 = jnp.mean(res_l2 * gamma) 830 else: 831 res_pred1,res_pred2 = self.r_pred_fn( 832 params, batch[:, 0], batch[:, 1] 833 ) 834 # Compute loss 835 res_loss1 = jnp.mean(res_pred1 ** 2) 836 res_loss2 = jnp.mean(res_pred2 ** 2) 837 838 loss_dict = { 839 "ic": u1_ic_loss + u2_ic_loss, 840 "res1": res_loss1, 841 "res2": res_loss2, 842 'rd': jnp.mean(self.res_right_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xu)), 843 'ld': jnp.mean(self.res_left_dirichlet(params, batch[:, 0].reshape((batch.shape[0],1)), self.xl)) 844 } 845 return loss_dict
@partial(jit, static_argnums=(0,))
def
compute_diag_ntk(self, params, batch):
848 @partial(jit, static_argnums = (0,)) 849 def compute_diag_ntk(self, params, batch): 850 #Initial Condition 851 u1_ic_ntk = vmap( 852 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 853 )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 854 855 u2_ic_ntk = vmap( 856 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 857 )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 858 859 #Right Dirichlet 860 rd_ntk = vmap( 861 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 862 )(self.res_right_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xu) 863 864 #Left Dirichlet 865 ld_ntk = vmap( 866 vmap(ntk_fn, (None, None, 0, None)), (None, None, 0, None) 867 )(self.res_left_dirichlet_nv, params, batch[:, 0].reshape((batch.shape[0],1)), self.xl) 868 869 # Consider the effect of causal weights 870 if self.config.weighting.use_causal: 871 batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T 872 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 873 self.r_net1, params, batch[:, 0], batch[:, 1] 874 ) 875 876 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 877 self.r_net2, params, batch[:, 0], batch[:, 1] 878 ) 879 880 res_ntk1 = res_ntk1.reshape(self.num_chunks, -1) 881 res_ntk2 = res_ntk2.reshape(self.num_chunks, -1) 882 883 res_ntk1 = jnp.mean(res_ntk1, axis=1) 884 res_ntk2 = jnp.mean(res_ntk2, axis=1) 885 886 _,_, casual_weights = self.res_causal(params, batch) 887 res_ntk1 = res_ntk1 * casual_weights 888 res_ntk2 = res_ntk2 * casual_weights 889 else: 890 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 891 self.r_net1, params, batch[:, 0], batch[:, 1] 892 ) 893 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 894 self.r_net2, params, batch[:, 0], batch[:, 1] 895 ) 896 897 ntk_dict = { 898 "ic": u1_ic_ntk + u2_ic_ntk, 899 "res1": res_ntk1, 900 "res2": res_ntk2, 901 'rd': rd_ntk, 902 'ld': ld_ntk 903 } 904 return ntk_dict
class
DD_csf_Evaluator(jaxpi.evaluator.BaseEvaluator):
906class DD_csf_Evaluator(BaseEvaluator): 907 def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test): 908 super().__init__(config, model) 909 910 self.x0_test = x0_test 911 self.tb_test = tb_test 912 self.xc_test = xc_test 913 self.tc_test = tc_test 914 self.u2_0_test = u2_0_test 915 self.u1_0_test = u1_0_test 916 917 def log_errors(self, params): 918 u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test) 919 u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test) 920 921 u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2) 922 u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2) 923 924 res_pred1,res_pred2 = self.model.r_pred_fn( 925 params, self.tc_test[:,0], self.xc_test[:,0] 926 ) 927 res_loss1 = jnp.mean(res_pred1 ** 2) 928 res_loss2 = jnp.mean(res_pred2 ** 2) 929 930 self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2))) 931 self.log_dict["res1_test"] = res_loss1 932 self.log_dict["res2_test"] = res_loss2 933 self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu)) 934 self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl)) 935 936 def __call__(self, state, batch): 937 self.log_dict = super().__call__(state, batch) 938 939 if self.config.logging.log_errors: 940 self.log_errors(state.params) 941 942 if self.config.weighting.use_causal: 943 _, _, causal_weight = self.model.res_causal(state.params, batch) 944 self.log_dict["cas_weight"] = causal_weight.min() 945 946 return self.log_dict
DD_csf_Evaluator( config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test)
907 def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test): 908 super().__init__(config, model) 909 910 self.x0_test = x0_test 911 self.tb_test = tb_test 912 self.xc_test = xc_test 913 self.tc_test = tc_test 914 self.u2_0_test = u2_0_test 915 self.u1_0_test = u1_0_test
def
log_errors(self, params):
917 def log_errors(self, params): 918 u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test) 919 u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test) 920 921 u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2) 922 u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2) 923 924 res_pred1,res_pred2 = self.model.r_pred_fn( 925 params, self.tc_test[:,0], self.xc_test[:,0] 926 ) 927 res_loss1 = jnp.mean(res_pred1 ** 2) 928 res_loss2 = jnp.mean(res_pred2 ** 2) 929 930 self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2))) 931 self.log_dict["res1_test"] = res_loss1 932 self.log_dict["res2_test"] = res_loss2 933 self.log_dict["rd_test"] = jnp.mean(self.model.res_right_dirichlet(params, self.tb_test, self.model.xu)) 934 self.log_dict["ld_test"] = jnp.mean(self.model.res_left_dirichlet(params, self.tb_test, self.model.xl))
class
closed_csf(jaxpi.models.ForwardIVP):
948class closed_csf(ForwardIVP): 949 def __init__(self, config): 950 super().__init__(config) 951 952 #Initial condition function 953 self.uinitial = config.uinitial 954 955 #Boundary points 956 self.xl = config.xl 957 self.xu = config.xu 958 self.tu = config.tu 959 960 # Predictions over array of x fot t fixed 961 self.u1_0_pred_fn = vmap( 962 vmap(self.u1_net, (None, None, 0)), (None, None, 0) 963 ) 964 self.u2_0_pred_fn = vmap( 965 vmap(self.u2_net, (None, None, 0)), (None, None, 0) 966 ) 967 968 #Prediction over array of t for x fixed 969 self.u2_bound_pred_fn = vmap( 970 vmap(self.u2_net, (None, 0, None)), (None, 0, None) 971 ) 972 973 self.u1_bound_pred_fn = vmap( 974 vmap(self.u1_net, (None, 0, None)), (None, 0, None) 975 ) 976 977 #Vmap neural net 978 self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0)) 979 980 self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0)) 981 982 #Vmap residual operator 983 self.r_pred_fn = vmap(self.r_net, (None, 0, 0)) 984 985 #Derivatives on x for x fixed and t in a array 986 self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None)) 987 self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None)) 988 989 #Neural net forward function 990 def neural_net(self, params, t, x): 991 t = t / self.tu 992 z = jnp.stack([t, x]) 993 _, outputs = self.state.apply_fn(params, z) 994 u1 = outputs[0] 995 u2 = outputs[1] 996 return u1, u2 997 998 #1st coordinate neural net forward function 999 def u1_net(self, params, t, x): 1000 u1, _ = self.neural_net(params, t, x) 1001 return u1 1002 1003 #2st coordinate neural net forward function 1004 def u2_net(self, params, t, x): 1005 _, u2 = self.neural_net(params, t, x) 1006 return u2 1007 1008 #Residual operator 1009 def r_net(self, params, t, x): 1010 #Derivatives in x and t 1011 u1_x = grad(self.u1_net, argnums = 2)(params, t, x) 1012 u1_t = grad(self.u1_net, argnums = 1)(params, t, x) 1013 u2_x = grad(self.u2_net, argnums = 2)(params, 
t, x) 1014 u2_t = grad(self.u2_net, argnums = 1)(params, t, x) 1015 1016 #Two derivatives in x 1017 u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x) 1018 u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x) 1019 1020 #Each coordinate of residual operator 1021 return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2 1022 1023 #1st coordinate residual operator 1024 def r_net1(self, params, t, x): 1025 r1,_ = self.r_net(params, t, x) 1026 return r1 1027 1028 #2nd coordinate residual operator 1029 def r_net2(self, params, t, x): 1030 _,r2 = self.r_net(params, t, x) 1031 return r2 1032 1033 #Compute residuals with causal weights 1034 @partial(jit, static_argnums=(0,)) 1035 def res_causal(self, params, batch): 1036 # Sort temporal coordinates 1037 t_sorted = batch[:, 0].sort() 1038 1039 #Compute residuals 1040 res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1]) 1041 1042 #Reshape 1043 res_pred1 = res_pred1.reshape(self.num_chunks, -1) 1044 res_pred2 = res_pred2.reshape(self.num_chunks, -1) 1045 1046 #Compute mean residuals 1047 res_l1 = jnp.mean(res_pred1 ** 2, axis=1) 1048 res_l2 = jnp.mean(res_pred2 ** 2, axis=1) 1049 1050 #Compute weights 1051 res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1))) 1052 res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2))) 1053 1054 # Take minimum of the causal weights 1055 gamma = jnp.vstack([res_gamma1,res_gamma2]) 1056 gamma = gamma.min(0) 1057 1058 return res_l1, res_l2, gamma 1059 1060 #Compute losses 1061 @partial(jit, static_argnums=(0,)) 1062 def losses(self, params, batch): 1063 # Initial conditions loss 1064 u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1065 u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1066 u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1))) 1067 1068 u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2) 1069 
u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2) 1070 1071 periodic1 = jnp.mean((self.u1_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xl) - self.u1_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xu)) ** 2) 1072 periodic2 = jnp.mean((self.u2_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xl) - self.u2_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xu)) ** 2) 1073 1074 # Residual loss 1075 if self.config.weighting.use_causal == True: 1076 res_l1, res_l2, gamma = self.res_causal(params, batch) 1077 res_loss1 = jnp.mean(res_l1 * gamma) 1078 res_loss2 = jnp.mean(res_l2 * gamma) 1079 else: 1080 res_pred1,res_pred2 = self.r_pred_fn( 1081 params, batch[:, 0], batch[:, 1] 1082 ) 1083 # Compute loss 1084 res_loss1 = jnp.mean(res_pred1 ** 2) 1085 res_loss2 = jnp.mean(res_pred2 ** 2) 1086 1087 loss_dict = { 1088 "ic": u1_ic_loss + u2_ic_loss, 1089 "res1": res_loss1, 1090 "res2": res_loss2, 1091 'periodic1': periodic1, 1092 'periodic2': periodic2 1093 } 1094 return loss_dict 1095 1096 #Compute NTK 1097 @partial(jit, static_argnums = (0,)) 1098 def compute_diag_ntk(self, params, batch): 1099 #Initial Condition 1100 u1_ic_ntk = vmap( 1101 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 1102 )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1103 1104 u2_ic_ntk = vmap( 1105 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 1106 )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1107 1108 # Consider the effect of causal weights 1109 if self.config.weighting.use_causal: 1110 batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T 1111 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 1112 self.r_net1, params, batch[:, 0], batch[:, 1] 1113 ) 1114 1115 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 1116 self.r_net2, params, batch[:, 0], batch[:, 1] 1117 ) 1118 1119 res_ntk1 = res_ntk1.reshape(self.num_chunks, -1) 1120 res_ntk2 = res_ntk2.reshape(self.num_chunks, 
-1) 1121 1122 res_ntk1 = jnp.mean(res_ntk1, axis=1) 1123 res_ntk2 = jnp.mean(res_ntk2, axis=1) 1124 1125 _,_, casual_weights = self.res_causal(params, batch) 1126 res_ntk1 = res_ntk1 * casual_weights 1127 res_ntk2 = res_ntk2 * casual_weights 1128 else: 1129 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 1130 self.r_net1, params, batch[:, 0], batch[:, 1] 1131 ) 1132 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 1133 self.r_net2, params, batch[:, 0], batch[:, 1] 1134 ) 1135 1136 ntk_dict = { 1137 "ic": u1_ic_ntk + u2_ic_ntk, 1138 "res1": res_ntk1, 1139 "res2": res_ntk2 1140 } 1141 return ntk_dict
closed_csf(config)
949 def __init__(self, config): 950 super().__init__(config) 951 952 #Initial condition function 953 self.uinitial = config.uinitial 954 955 #Boundary points 956 self.xl = config.xl 957 self.xu = config.xu 958 self.tu = config.tu 959 960 # Predictions over array of x fot t fixed 961 self.u1_0_pred_fn = vmap( 962 vmap(self.u1_net, (None, None, 0)), (None, None, 0) 963 ) 964 self.u2_0_pred_fn = vmap( 965 vmap(self.u2_net, (None, None, 0)), (None, None, 0) 966 ) 967 968 #Prediction over array of t for x fixed 969 self.u2_bound_pred_fn = vmap( 970 vmap(self.u2_net, (None, 0, None)), (None, 0, None) 971 ) 972 973 self.u1_bound_pred_fn = vmap( 974 vmap(self.u1_net, (None, 0, None)), (None, 0, None) 975 ) 976 977 #Vmap neural net 978 self.u1_pred_fn = vmap(self.u1_net, (None, 0, 0)) 979 980 self.u2_pred_fn = vmap(self.u2_net, (None, 0, 0)) 981 982 #Vmap residual operator 983 self.r_pred_fn = vmap(self.r_net, (None, 0, 0)) 984 985 #Derivatives on x for x fixed and t in a array 986 self.u1_bound_x = vmap(vmap(grad(self.u1_net, argnums = 2), (None, 0, None)), (None, 0, None)) 987 self.u2_bound_x = vmap(vmap(grad(self.u2_net, argnums = 2), (None, 0, None)), (None, 0, None))
def
r_net(self, params, t, x):
1009 def r_net(self, params, t, x): 1010 #Derivatives in x and t 1011 u1_x = grad(self.u1_net, argnums = 2)(params, t, x) 1012 u1_t = grad(self.u1_net, argnums = 1)(params, t, x) 1013 u2_x = grad(self.u2_net, argnums = 2)(params, t, x) 1014 u2_t = grad(self.u2_net, argnums = 1)(params, t, x) 1015 1016 #Two derivatives in x 1017 u1_xx = hessian(self.u1_net, argnums = (2))(params, t, x) 1018 u2_xx = hessian(self.u2_net, argnums = (2))(params, t, x) 1019 1020 #Each coordinate of residual operator 1021 return (u1_t - u1_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2, (u2_t - u2_xx/(u1_x ** 2 + u2_x ** 2 + 1e-5)) ** 2
@partial(jit, static_argnums=(0,))
def
res_causal(self, params, batch):
1034 @partial(jit, static_argnums=(0,)) 1035 def res_causal(self, params, batch): 1036 # Sort temporal coordinates 1037 t_sorted = batch[:, 0].sort() 1038 1039 #Compute residuals 1040 res_pred1,res_pred2 = self.r_pred_fn(params, t_sorted, batch[:, 1]) 1041 1042 #Reshape 1043 res_pred1 = res_pred1.reshape(self.num_chunks, -1) 1044 res_pred2 = res_pred2.reshape(self.num_chunks, -1) 1045 1046 #Compute mean residuals 1047 res_l1 = jnp.mean(res_pred1 ** 2, axis=1) 1048 res_l2 = jnp.mean(res_pred2 ** 2, axis=1) 1049 1050 #Compute weights 1051 res_gamma1 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l1))) 1052 res_gamma2 = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ res_l2))) 1053 1054 # Take minimum of the causal weights 1055 gamma = jnp.vstack([res_gamma1,res_gamma2]) 1056 gamma = gamma.min(0) 1057 1058 return res_l1, res_l2, gamma
@partial(jit, static_argnums=(0,))
def
losses(self, params, batch):
1061 @partial(jit, static_argnums=(0,)) 1062 def losses(self, params, batch): 1063 # Initial conditions loss 1064 u1_pred = self.u1_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1065 u2_pred = self.u2_0_pred_fn(params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1066 u1_0,u2_0 = self.uinitial(batch[:, 1].reshape((batch.shape[0],1))) 1067 1068 u1_ic_loss = jnp.mean((u1_pred - u1_0) ** 2) 1069 u2_ic_loss = jnp.mean((u2_pred - u2_0) ** 2) 1070 1071 periodic1 = jnp.mean((self.u1_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xl) - self.u1_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xu)) ** 2) 1072 periodic2 = jnp.mean((self.u2_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xl) - self.u2_bound_pred_fn(params,batch[:, 0].reshape((batch.shape[0],1)),self.xu)) ** 2) 1073 1074 # Residual loss 1075 if self.config.weighting.use_causal == True: 1076 res_l1, res_l2, gamma = self.res_causal(params, batch) 1077 res_loss1 = jnp.mean(res_l1 * gamma) 1078 res_loss2 = jnp.mean(res_l2 * gamma) 1079 else: 1080 res_pred1,res_pred2 = self.r_pred_fn( 1081 params, batch[:, 0], batch[:, 1] 1082 ) 1083 # Compute loss 1084 res_loss1 = jnp.mean(res_pred1 ** 2) 1085 res_loss2 = jnp.mean(res_pred2 ** 2) 1086 1087 loss_dict = { 1088 "ic": u1_ic_loss + u2_ic_loss, 1089 "res1": res_loss1, 1090 "res2": res_loss2, 1091 'periodic1': periodic1, 1092 'periodic2': periodic2 1093 } 1094 return loss_dict
@partial(jit, static_argnums=(0,))
def
compute_diag_ntk(self, params, batch):
1097 @partial(jit, static_argnums = (0,)) 1098 def compute_diag_ntk(self, params, batch): 1099 #Initial Condition 1100 u1_ic_ntk = vmap( 1101 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 1102 )(self.u1_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1103 1104 u2_ic_ntk = vmap( 1105 vmap(ntk_fn, (None, None, None, 0)), (None, None, None, 0) 1106 )(self.u2_net, params, 0.0, batch[:, 1].reshape((batch.shape[0],1))) 1107 1108 # Consider the effect of causal weights 1109 if self.config.weighting.use_causal: 1110 batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T 1111 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 1112 self.r_net1, params, batch[:, 0], batch[:, 1] 1113 ) 1114 1115 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 1116 self.r_net2, params, batch[:, 0], batch[:, 1] 1117 ) 1118 1119 res_ntk1 = res_ntk1.reshape(self.num_chunks, -1) 1120 res_ntk2 = res_ntk2.reshape(self.num_chunks, -1) 1121 1122 res_ntk1 = jnp.mean(res_ntk1, axis=1) 1123 res_ntk2 = jnp.mean(res_ntk2, axis=1) 1124 1125 _,_, casual_weights = self.res_causal(params, batch) 1126 res_ntk1 = res_ntk1 * casual_weights 1127 res_ntk2 = res_ntk2 * casual_weights 1128 else: 1129 res_ntk1 = vmap(ntk_fn, (None, None, 0, 0))( 1130 self.r_net1, params, batch[:, 0], batch[:, 1] 1131 ) 1132 res_ntk2 = vmap(ntk_fn, (None, None, 0, 0))( 1133 self.r_net2, params, batch[:, 0], batch[:, 1] 1134 ) 1135 1136 ntk_dict = { 1137 "ic": u1_ic_ntk + u2_ic_ntk, 1138 "res1": res_ntk1, 1139 "res2": res_ntk2 1140 } 1141 return ntk_dict
class
closed_csf_Evaluator(jaxpi.evaluator.BaseEvaluator):
1143class closed_csf_Evaluator(BaseEvaluator): 1144 def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test): 1145 super().__init__(config, model) 1146 1147 self.x0_test = x0_test 1148 self.tb_test = tb_test 1149 self.xc_test = xc_test 1150 self.tc_test = tc_test 1151 self.u2_0_test = u2_0_test 1152 self.u1_0_test = u1_0_test 1153 1154 def log_errors(self, params): 1155 u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test) 1156 u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test) 1157 1158 u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2) 1159 u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2) 1160 1161 res_pred1,res_pred2 = self.model.r_pred_fn( 1162 params, self.tc_test[:,0], self.xc_test[:,0] 1163 ) 1164 res_loss1 = jnp.mean(res_pred1 ** 2) 1165 res_loss2 = jnp.mean(res_pred2 ** 2) 1166 1167 self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2))) 1168 self.log_dict["res1_test"] = res_loss1 1169 self.log_dict["res2_test"] = res_loss2 1170 self.log_dict["periodic1_test"] = jnp.mean((self.model.u1_bound_pred_fn(params,self.tb_test,self.config.xl) - self.model.u1_bound_pred_fn(params,self.tb_test,self.config.xu)) ** 2) 1171 self.log_dict["periodic2_test"] = jnp.mean((self.model.u2_bound_pred_fn(params,self.tb_test,self.config.xl) - self.model.u2_bound_pred_fn(params,self.tb_test,self.config.xu)) ** 2) 1172 1173 def __call__(self, state, batch): 1174 self.log_dict = super().__call__(state, batch) 1175 1176 if self.config.logging.log_errors: 1177 self.log_errors(state.params) 1178 1179 if self.config.weighting.use_causal: 1180 _, _, causal_weight = self.model.res_causal(state.params, batch) 1181 self.log_dict["cas_weight"] = causal_weight.min() 1182 1183 return self.log_dict
closed_csf_Evaluator( config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test)
1144 def __init__(self, config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test): 1145 super().__init__(config, model) 1146 1147 self.x0_test = x0_test 1148 self.tb_test = tb_test 1149 self.xc_test = xc_test 1150 self.tc_test = tc_test 1151 self.u2_0_test = u2_0_test 1152 self.u1_0_test = u1_0_test
def
log_errors(self, params):
1154 def log_errors(self, params): 1155 u1_pred = self.model.u1_0_pred_fn(params, 0.0, self.x0_test) 1156 u2_pred = self.model.u2_0_pred_fn(params, 0.0, self.x0_test) 1157 1158 u1_ic_loss = jnp.mean((u1_pred - self.u1_0_test) ** 2) 1159 u2_ic_loss = jnp.mean((u2_pred - self.u2_0_test) ** 2) 1160 1161 res_pred1,res_pred2 = self.model.r_pred_fn( 1162 params, self.tc_test[:,0], self.xc_test[:,0] 1163 ) 1164 res_loss1 = jnp.mean(res_pred1 ** 2) 1165 res_loss2 = jnp.mean(res_pred2 ** 2) 1166 1167 self.log_dict["ic_rel_test"] = jnp.sqrt((u1_ic_loss + u2_ic_loss)/(jnp.mean(self.u1_0_test ** 2) + jnp.mean(self.u2_0_test ** 2))) 1168 self.log_dict["res1_test"] = res_loss1 1169 self.log_dict["res2_test"] = res_loss2 1170 self.log_dict["periodic1_test"] = jnp.mean((self.model.u1_bound_pred_fn(params,self.tb_test,self.config.xl) - self.model.u1_bound_pred_fn(params,self.tb_test,self.config.xu)) ** 2) 1171 self.log_dict["periodic2_test"] = jnp.mean((self.model.u2_bound_pred_fn(params,self.tb_test,self.config.xl) - self.model.u2_bound_pred_fn(params,self.tb_test,self.config.xu)) ** 2)