jinnax.nn
1#Functions to train NN 2import jax 3import jax.numpy as jnp 4import optax 5from alive_progress import alive_bar 6import math 7import time 8import pandas as pd 9import numpy as np 10import matplotlib.pyplot as plt 11import dill as pickle 12from genree import bolstering as gb 13from genree import kernel as gk 14from jax import random 15from jinnax import data as jd 16import os 17import numpy as np 18from scipy.fft import dst, idst 19from itertools import product 20from functools import partial 21import orthax 22from jax import lax 23from jaxopt import LBFGS 24 25__docformat__ = "numpy" 26 27#MSE 28@jax.jit 29def MSE(pred,true): 30 """ 31 Squared error 32 ---------- 33 Parameters 34 ---------- 35 pred : jax.numpy.array 36 37 A JAX numpy array with the predicted values 38 39 true : jax.numpy.array 40 41 A JAX numpy array with the true values 42 43 Returns 44 ------- 45 squared error 46 """ 47 return (true - pred) ** 2 48 49#MSE self-adaptative 50@jax.jit 51def MSE_SA(pred,true,w,q = 2): 52 """ 53 Self-adaptative squared error 54 ---------- 55 Parameters 56 ---------- 57 pred : jax.numpy.array 58 59 A JAX numpy array with the predicted values 60 61 true : jax.numpy.array 62 63 A JAX numpy array with the true values 64 65 weight : jax.numpy.array 66 67 A JAX numpy array with the weights 68 69 q : float 70 71 Power for the weights mask 72 73 Returns 74 ------- 75 self-adaptative squared error with polynomial mask 76 """ 77 return (w ** q) * ((true - pred) ** 2) 78 79#L2 error 80@jax.jit 81def L2error(pred,true): 82 """ 83 L2-error in percentage (%) 84 ---------- 85 Parameters 86 ---------- 87 pred : jax.numpy.array 88 89 A JAX numpy array with the predicted values 90 91 true : jax.numpy.array 92 93 A JAX numpy array with the true values 94 95 Returns 96 ------- 97 L2-error 98 """ 99 return 100*jnp.sqrt(jnp.sum((true - pred)**2))/jnp.sqrt(jnp.sum(true ** 2)) 100 101#Auxialiry functions to sample singular Matern 102def idst1(x,axis = -1): 103 """ 104 Inverse Discrete Sine 
Transform of type I with orthonormal scaling 105 ---------- 106 Parameters 107 ---------- 108 x : jax.numpy.array 109 110 Array to apply the transformation 111 112 axis : int 113 114 Axis to apply the transformation over 115 116 Returns 117 ------- 118 jax.numpy.array 119 """ 120 return idst(x,type = 1,axis = axis,norm = 'ortho') 121 122def dstn(x,axes = None): 123 """ 124 Discrete Sine Transform of type I with orthonormal scaling over many axes 125 ---------- 126 Parameters 127 ---------- 128 x : jax.numpy.array 129 130 Array to apply the transformation 131 132 axes : int 133 134 Axes to apply the transformation over 135 136 Returns 137 ------- 138 jax.numpy.array 139 """ 140 if axes is None: 141 axes = tuple(range(x.ndim)) 142 y = x 143 for ax in axes: 144 y = dst(x,type = 1,axis = ax,norm = 'ortho') 145 return y 146 147def idstn(x,axes = None): 148 """ 149 Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes 150 ---------- 151 Parameters 152 ---------- 153 x : jax.numpy.array 154 155 Array to apply the transformation 156 157 axes : int 158 159 Axes to apply the transformation over 160 161 Returns 162 ------- 163 jax.numpy.array 164 """ 165 if axes is None: 166 axes = tuple(range(x.ndim)) 167 y = x 168 for ax in axes: 169 y = idst1(y,axis = ax) 170 return y 171 172def dirichlet_eigs_nd(n,L): 173 """ 174 Eigenvalues of the discrete Dirichlet-Laplace operator in a rectangle 175 ---------- 176 Parameters 177 ---------- 178 n : list 179 180 List with the number of points in the grid in each dimension 181 182 L : list 183 184 List with the upper limit of the interval of the domain in each dimension. 
Assumed the lower limit is zero 185 186 Returns 187 ------- 188 jax.numpy.array 189 """ 190 #Unidimensional eigenvalues 191 lam_axes = [] 192 for ni, Li in zip(n,L): 193 h = Li / (ni + 1.0) 194 k = jnp.arange(1,ni + 1,dtype = np.float32) 195 ln = (2.0 / (h*h)) * (1.0 - jnp.cos(jnp.pi * k / (ni + 1.0))) 196 lam_axes.append(ln) 197 grids = jnp.meshgrid(*lam_axes, indexing='ij') 198 Lam = jnp.zeros_like(grids[0]) 199 for g in grids: 200 Lam += g 201 return Lam 202 203 204#Sample from d-dimensional Matern process 205def generate_matern_sample(key,d = 2,N = 128,L = 1.0,kappa = 1,alpha = 1,sigma = 1,periodic = False): 206 """ 207 Sample d-dimensional Matern process 208 ---------- 209 Parameters 210 ---------- 211 key : int 212 213 Seed for randomization 214 215 d : int 216 217 Dimension. Default 2 218 219 N : int 220 221 Size of grid in each dimension. Default 128 222 223 L : list of float 224 225 The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1 226 227 kappa,alpha,sigma : float 228 229 Parameters of the Matern process 230 231 periodic : logical 232 233 Whether to sample with periodic boundary conditions. 
Periodic = False is not JAX native and does not work with JIT 234 235 Returns 236 ------- 237 jax.numpy.array 238 """ 239 if periodic: 240 #Shape and key 241 key = jax.random.PRNGKey(key) 242 shape = (N,) * d 243 if isinstance(L,float) or isinstance(L,int): 244 L = d*[L] 245 if isinstance(N,float) or isinstance(N,int): 246 N = d*[N] 247 248 #Setup Frequency Grid (2D) 249 freq = [jnp.fft.fftfreq(N[j],d = L[j]/N[j]) * 2 * jnp.pi for j in range(d)] 250 grids = jnp.meshgrid(*freq, indexing='ij') 251 sq_norm_xi = sum(g**2 for g in grids) 252 253 #Generate White Noise in Fourier Space 254 key_re, key_im = jax.random.split(key) 255 white_noise_f = (jax.random.normal(key_re, shape) + 256 1j * jax.random.normal(key_im, shape)) 257 258 #Apply the Whittle Filter 259 amplitude_filter = (kappa ** 2 + sq_norm_xi) ** (-alpha / 2.0) 260 field_f = white_noise_f * amplitude_filter 261 262 #Transform back to Physical Space 263 sample = jnp.real(jnp.fft.ifftn(field_f)) 264 return sigma*sample 265 else: #NOT JAX 266 #Shape and key 267 rng = np.random.default_rng(seed = key) 268 if isinstance(L,float) or isinstance(L,int): 269 L = d*[L] 270 if isinstance(N,float) or isinstance(N,int): 271 N = d*[N] 272 shape = tuple(N) 273 274 #White noise in real space 275 W = rng.standard_normal(size = shape) 276 277 #To Dirichlet eigenbasis via separable DST-I (orthonormal) 278 W_hat = dstn(W) 279 280 #Discrete Dirichlet Laplacian eigenvalues 281 lam = dirichlet_eigs_nd(N, L) 282 283 #Spectral filter 284 filt = ((kappa + lam) ** (-alpha/2.0)) 285 psi_hat = filt * W_hat 286 287 #Back to real space 288 psi = idstn(psi_hat) 289 return jnp.array(sigma*psi) 290 291#Vectorized generate_matern_sample 292def generate_matern_sample_batch(d = 2,N = 512,L = 1.0,kappa = 10.0,alpha = 1,sigma = 10,periodic = False): 293 """ 294 Create function to sample d-dimensional Matern process 295 ---------- 296 Parameters 297 ---------- 298 d : int 299 300 Dimension. 
Default 2 301 302 N : int 303 304 Size of grid in each dimension. Default 128 305 306 L : list of float 307 308 The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1 309 310 kappa,alpha,sigma : float 311 312 Parameters of the Matern process 313 314 periodic : logical 315 316 Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT 317 318 Returns 319 ------- 320 function 321 """ 322 if periodic: 323 return jax.vmap(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)) 324 else: 325 return lambda keys: jnp.array(np.apply_along_axis(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic),1,keys.reshape((keys.shape[0],1)))) 326 327#Build function to compute the eigenfunctions of Laplacian 328def eigenf_laplace(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None): 329 """ 330 Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. 331 ---------- 332 Parameters 333 ---------- 334 L_vec : list of float 335 336 The domain of the function in each coordinate is [0,L[1]] 337 338 kmax_per_axis : list 339 340 List with the maximum number of eigenfunctions per dimension. Consider d * max(kmax_per_axis) eigenfunctions 341 342 bc : str 343 344 Boundary condition. 'dirichlet' or 'neumann' 345 346 max_ef : int 347 348 Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. 
If None, considers d * max(kmax_per_axis) eigenfunctions 349 350 Returns 351 ------- 352 function to compute eigenfunctions,eigenvalues of the eigenfunctions considered 353 """ 354 #Parameters 355 L_vec = jnp.asarray(L_vec,dtype = jnp.float32) 356 d = L_vec.shape[0] 357 bc = bc.lower() 358 359 #Maximum number of functions 360 if max_ef is None: 361 if d == 1: 362 max_ef = jnp.max(jnp.array(kmax_per_axis)) 363 else: 364 max_ef = jnp.max(d * jnp.array(kmax_per_axis)) 365 366 #Build the candidate multi-indices per axis 367 kmax_per_axis = list(map(int, kmax_per_axis)) 368 if bc.startswith("d"): 369 axis_ranges = [range(1, km + 1) for km in kmax_per_axis] 370 elif bc.startswith("n"): 371 axis_ranges = [range(0, km + 1) for km in kmax_per_axis] 372 373 #Get all multi-indices 374 Ks_list = list(product(*axis_ranges)) 375 Ks = jnp.array(Ks_list,dtype = jnp.float32) 376 377 #Eigenvalues of the continuous Laplacian 378 pi_over_L = jnp.pi / L_vec 379 lambdas_all = jnp.sum((Ks * pi_over_L) ** 2, axis=1) 380 381 #Sort by eigenvalue 382 order = jnp.argsort(lambdas_all) 383 Ks = Ks[order] 384 lambdas_all = lambdas_all[order] 385 386 #Keep first max_ef 387 Ks = Ks[:max_ef] 388 lambdas = lambdas_all[:max_ef] 389 m = Ks.shape[0] 390 391 #Precompute per-feature normalization factor (closed form) 392 def per_axis_norm_factor(k_i, L_i, is_dirichlet): 393 if is_dirichlet: 394 return jnp.sqrt(2.0 / L_i) 395 else: 396 return jnp.where(k_i == 0, jnp.sqrt(1.0 / L_i), jnp.sqrt(2.0 / L_i)) 397 if bc.startswith("d"): 398 nf = jnp.prod(jnp.sqrt(2.0 / L_vec)[None, :],axis = 1) 399 norm_factors = jnp.ones((m,),dtype = jnp.float32) * nf 400 else: 401 # per-mode product across axes 402 def nf_row(k_row): 403 return jnp.prod(per_axis_norm_factor(k_row, L_vec, False)) 404 norm_factors = jax.vmap(nf_row)(Ks) 405 406 #Build the callable function 407 Ks_int = Ks # float array, but only integer values 408 L_vec_f = L_vec 409 @jax.jit 410 def phi(x): 411 x = jnp.asarray(x,dtype = jnp.float32) 412 
#Initialize with ones 413 vals = jnp.ones(x.shape[:-1] + (m,), dtype=jnp.float32) 414 #Compute eigenfunction 415 for i in range(d): 416 ang = (jnp.pi / L_vec_f[i]) * x[..., i][..., None] * Ks_int[:, i] 417 if bc.startswith("d"): 418 comp = jnp.sin(ang) 419 else: 420 comp = jnp.cos(ang) 421 vals = vals * comp 422 #Apply L2-normalizing constants 423 vals = vals * norm_factors[None, ...] if vals.ndim > 1 else vals * norm_factors 424 return vals 425 return phi, lambdas 426 427#Compute multiple frequences of domain aware fourrier fesatures 428def multiple_daff(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None): 429 """ 430 Create function to compute multiple frequences of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. Each frequences is a different domain. 431 ---------- 432 Parameters 433 ---------- 434 L_vec : list of lists of float 435 436 List with the domain of each frequence of the eigenfunctions in the form [0,L[i][1]] 437 438 kmax_per_axis : list 439 440 List with the maximum number of eigenfunctions per dimension. 441 442 bc : str 443 444 Boundary condition. 'dirichlet' or 'neumann' 445 446 max_ef : int 447 448 Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions 449 450 Returns 451 ------- 452 function to compute daff,eigenvalues of the eigenfunctions considered 453 """ 454 psi = [] 455 lamb = [] 456 for L in L_vec: 457 tmp,l = eigenf_laplace(L,kmax_per_axis,bc,max_ef) #Get function 458 lamb.append(l) 459 psi.append(tmp) 460 del tmp 461 #Create function to compute features 462 @jax.jit 463 def mff(x): 464 y = [] 465 for i in range(len(psi)): 466 y.append(psi[i](x)) 467 if len(psi) == 1: 468 return y[0] 469 else: 470 return jnp.concatenate(y,1) 471 return mff,jnp.concatenate(lamb) 472 473#Code for chebyshev polynomials writeen by AI (deprecated) 474def _chebyshev_T_all(t, K: int): 475 """ 476 Compute T_0..T_K(t) with the standard recurrence. 
477 t shape should be (..., d). We DO NOT squeeze any axis to preserve 'd' 478 even when d == 1. 479 Returns: array of shape (K+1, ...) matching t's batch dims, including d. 480 """ 481 # Expect t to have last axis = d (keep it, even if d == 1) 482 T0 = jnp.ones_like(t) # (..., d) 483 if K == 0: 484 return T0[None, ...] # (1, ..., d) 485 486 T1 = t # (..., d) 487 if K == 1: 488 return jnp.stack([T0, T1], axis=0) # (2, ..., d) 489 490 def body(carry, _): 491 Tkm1, Tk = carry # each (..., d) 492 Tkp1 = 2.0 * t * Tk - Tkm1 # (..., d) 493 return (Tk, Tkp1), Tkp1 494 495 # K >= 2: produce T_2..T_K 496 (_, _), T2_to_TK = lax.scan(body, (T0, T1), jnp.arange(K - 1)) # (K-1, ..., d) 497 return jnp.concatenate([T0[None, ...], T1[None, ...], T2_to_TK], axis=0) # (K+1, ..., d) 498 499@partial(jax.jit,static_argnums=(2,)) # n is static here; compile once per n 500def multiple_cheb_fast(x, L_vec, n: int): 501 """ 502 x: (N, d) 503 L_vec: (L, d) containing 'b' endpoints (a is 0) for each dimension 504 n: number of k terms (static) 505 returns: (N, L*n) 506 """ 507 N, d = x.shape 508 L = L_vec.shape[0] 509 510 a = 0.0 511 b = L_vec # (L, d) 512 # Map x to t in [-1, 1] for each l, j: shape (L, N, d) 513 t = (2.0 * x[None, :, :] - (a + b)[:, None, :]) / (b - a)[:, None, :] 514 515 # Chebyshev T_0..T_{n+2} for all (L, N, d): shape (n+3, L, N, d) 516 T = _chebyshev_T_all(t, n + 2) 517 518 # phi_k = T_{k+2} - T_k, k = 0..n-1 => shape (n, L, N, d) 519 ks = jnp.arange(n) 520 phi = T[ks + 2, ...] - T[ks, ...] 521 522 # Multiply across dimensions (over the last axis = d) => (n, L, N) 523 z = jnp.prod(phi, axis=-1) 524 525 # Reorder to (N, L, n) then flatten to (N, L*n) 526 z = jnp.transpose(z, (2, 1, 0)).reshape(N, L * n) 527 return z 528 529def multiple_cheb(L_vec, n: int): 530 """ 531 Factory that closes over static n and L_vec (so shapes are constant). 
532 """ 533 L_vec = jnp.asarray(L_vec) 534 @jax.jit # optional; multiple_cheb_fast is already jitted 535 def mcheb(x): 536 x = jnp.asarray(x) 537 return multiple_cheb_fast(x, L_vec, n) 538 return mcheb 539 540 541#Initialize fully connected neyral network Return the initial parameters and the function for the forward pass 542def fconNN(width,activation = jax.nn.tanh,key = 0,mlp = False,ftype = None,fargs = None,static = None): 543 """ 544 Initialize fully connected neural network 545 ---------- 546 Parameters 547 ---------- 548 width : list 549 550 List with the layers width 551 552 activation : jax.nn activation 553 554 The activation function. Default jax.nn.tanh 555 556 key : int 557 558 Seed for parameters initialization. Default 0 559 560 mlp : logical 561 562 Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension. 563 564 ftype : str 565 566 Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'. 567 568 fargs : list 569 570 Arguments for deature transformation: 571 572 For 'ff': A list with the number of frequences and value of greatest frequence standard deviation. 573 574 For 'daff' and 'daff' bias: A dicitionary with a list with the size of rectangles and the type of boundary condition. If its a list, than boundary conditions is dirichlet. 575 576 static : function 577 578 A static function to sum to the neural network output. 
579 580 Returns 581 ------- 582 dict with initial parameters and the function for the forward pass 583 """ 584 #Initialize parameters with Glorot initialization 585 initializer = jax.nn.initializers.glorot_normal() 586 params = list() 587 if static is None: 588 static = lambda x: 0.0 589 590 #Feature mapping 591 if ftype == 'ff': #Fourrier features 592 for s in range(fargs[0]): 593 sd = fargs[1] ** ((s + 1)/fargs[0]) 594 if s == 0: 595 Bff = sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2))) 596 else: 597 Bff = jnp.append(Bff,sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2))),1) 598 @jax.jit 599 def phi(x): 600 x = x @ Bff 601 return jnp.concatenate([jnp.sin(2 * jnp.pi * x),jnp.cos(2 * jnp.pi * x)],axis = -1) 602 width = width[1:] 603 width[0] = 2*Bff.shape[1] 604 elif ftype == 'daff' or ftype == 'daff_bias': 605 if not isinstance(fargs, dict): 606 fargs = {'L': fargs,'bc': "dirichlet"} 607 phi,lamb = multiple_daff(list(fargs.values())[0],kmax_per_axis = [width[1]] * width[0],bc = list(fargs.values())[1]) 608 width = width[1:] 609 width[0] = lamb.shape[0] 610 elif ftype == 'cheb' or ftype == 'cheb_bias': 611 phi = multiple_cheb(fargs,n = width[1]) 612 width = width[1:] 613 width[0] = len(fargs)*width[0] 614 else: 615 @jax.jit 616 def phi(x): 617 return x 618 619 #Initialize parameters 620 if mlp: 621 k = jax.random.split(jax.random.PRNGKey(key),4) 622 WU = initializer(k[0],(width[0],width[1]),jnp.float32) 623 BU = initializer(k[1],(1,width[1]),jnp.float32) 624 WV = initializer(k[2],(width[0],width[1]),jnp.float32) 625 BV = initializer(k[3],(1,width[1]),jnp.float32) 626 params.append({'WU':WU,'BU':BU,'WV':WV,'BV':BV}) 627 key = jax.random.split(jax.random.PRNGKey(key + 1),len(width)-1) #Seed for initialization 628 for key,lin,lout in zip(key,width[:-1],width[1:]): 629 W = initializer(key,(lin,lout),jnp.float32) 630 B = initializer(key,(1,lout),jnp.float32) 631 params.append({'W':W,'B':B}) 632 633 #Define 
function for forward pass 634 if mlp: 635 if ftype != 'daff' and ftype != 'cheb': 636 @jax.jit 637 def forward(x,params): 638 encode,*hidden,output = params 639 sx = static(x) 640 x = phi(x) 641 U = activation(x @ encode['WU'] + encode['BU']) 642 V = activation(x @ encode['WV'] + encode['BV']) 643 for layer in hidden: 644 x = activation(x @ layer['W'] + layer['B']) 645 x = x * U + (1 - x) * V 646 return x @ output['W'] + output['B'] + sx 647 else: 648 @jax.jit 649 def forward(x,params): 650 encode,*hidden,output = params 651 sx = static(x) 652 x = phi(x) 653 U = activation(x @ encode['WU']) 654 V = activation(x @ encode['WV']) 655 for layer in hidden: 656 x = activation(x @ layer['W']) 657 x = x * U + (1 - x) * V 658 return x @ output['W'] + sx 659 else: 660 if ftype != 'daff' and ftype != 'cheb': 661 @jax.jit 662 def forward(x,params): 663 *hidden,output = params 664 sx = static(x) 665 x = phi(x) 666 for layer in hidden: 667 x = activation(x @ layer['W'] + layer['B']) 668 return x @ output['W'] + output['B'] + sx 669 else: 670 @jax.jit 671 def forward(x,params): 672 *hidden,output = params 673 sx = static(x) 674 x = phi(x) 675 for layer in hidden: 676 x = activation(x @ layer['W']) 677 return x @ output['W'] + sx 678 679 #Return initial parameters and forward function 680 return {'params': params,'forward': forward} 681 682#Get activation from string 683def get_activation(act): 684 """ 685 Return activation function from string 686 ---------- 687 Parameters 688 ---------- 689 act : str 690 691 Name of the activation function. 
Default 'tanh' 692 693 Returns 694 ------- 695 jax.nn activation function 696 """ 697 if act == 'tanh': 698 return jax.nn.tanh 699 elif act == 'relu': 700 return jax.nn.relu 701 elif act == 'relu6': 702 return jax.nn.relu6 703 elif act == 'sigmoid': 704 return jax.nn.sigmoid 705 elif act == 'softplus': 706 return jax.nn.softplus 707 elif act == 'sparse_plus': 708 return jx.nn.sparse_plus 709 elif act == 'soft_sign': 710 return jax.nn.soft_sign 711 elif act == 'silu': 712 return jax.nn.silu 713 elif act == 'swish': 714 return jax.nn.swish 715 elif act == 'log_sigmoid': 716 return jax.nn.log_sigmoid 717 elif act == 'leaky_relu': 718 return jax.nn.leaky_relu 719 elif act == 'hard_sigmoid': 720 return jax.nn.hard_sigmoid 721 elif act == 'hard_silu': 722 return jax.nn.hard_silu 723 elif act == 'hard_swish': 724 return jax.nn.hard_swish 725 elif act == 'hard_tanh': 726 return jax.nn.hard_tanh 727 elif act == 'elu': 728 return jax.nn.elu 729 elif act == 'celu': 730 return jax.nn.celu 731 elif act == 'selu': 732 return jax.nn.selu 733 elif act == 'gelu': 734 return jax.nn.gelu 735 elif act == 'glu': 736 return jax.nn.glu 737 elif act == 'squareplus': 738 return jax.nn.squareplus 739 elif act == 'mish': 740 return jax.nn.mish 741 742#Training PINN 743def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,mlp = False): 744 """ 745 Train a Physics-informed Neural Network 746 ---------- 747 Parameters 748 ---------- 749 data : dict 750 751 Data generated by the jinnax.data.generate_PINNdata function 752 753 width : list 754 755 A list with the width of each layer 756 757 pde : function 758 759 The partial differential operator. 
Its arguments are u, x and t 760 761 test_data : dict, None 762 763 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error 764 765 epochs : int 766 767 Number of training epochs. Default 100 768 769 at_each : int 770 771 Save results for epochs multiple of at_each. Default 10 772 773 activation : str 774 775 The name of the activation function of the neural network. Default 'tanh' 776 777 neumann : logical 778 779 Whether to consider Neumann boundary conditions 780 781 oper_neumann : function 782 783 Penalization of Neumann boundary conditions 784 785 sa : logical 786 787 Whether to consider self-adaptative PINN 788 789 c : dict 790 791 Dictionary with the hyperparameters of the self-adaptative sigmoid mask for the initial (w0), sensor (ws) and collocation (wr) points. The weights of the boundary points is fixed to 1 792 793 inverse : logical 794 795 Whether to estimate parameters of the PDE 796 797 initial_par : jax.numpy.array 798 799 Initial value of the parameters of the PDE in an inverse problem 800 801 lr,b1,b2,eps,eps_root: float 802 803 Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0 804 805 key : int 806 807 Seed for parameters initialization. Default 0 808 809 epoch_print : int 810 811 Number of epochs to calculate and print test errors. Default 100 812 813 save : logical 814 815 Whether to save the current parameters. Default False 816 817 file_name : str 818 819 File prefix to save the current parameters. Default 'result_pinn' 820 821 exp_decay : logical 822 823 Whether to consider exponential decay of learning rate. Default False 824 825 transition_steps : int 826 827 Number of steps for exponential decay. Default 1000 828 829 decay_rate : float 830 831 Rate of exponential decay. 
Default 0.9 832 833 mlp : logical 834 835 Whether to consider modifed multi-layer perceptron 836 837 Returns 838 ------- 839 dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time 840 """ 841 842 #Initialize architecture 843 nnet = fconNN(width,get_activation(activation),key,mlp) 844 forward = nnet['forward'] 845 846 #Initialize self adaptative weights 847 par_sa = {} 848 if sa: 849 #Initialize wheights close to zero 850 ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000) 851 if data['sensor'] is not None: 852 par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))}) 853 if data['initial'] is not None: 854 par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))}) 855 if data['collocation'] is not None: 856 par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))}) 857 if data['boundary'] is not None: 858 par_sa.update({'wb': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))}) 859 860 #Store all parameters 861 params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa} 862 863 #Save config file 864 if save: 865 pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 866 867 #Define loss function 868 if sa: 869 #Define loss function 870 @jax.jit 871 def lf(params,x): 872 loss = 0 873 if x['sensor'] is not None: 874 #Term that refers to sensor data 875 loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws'])) 876 if 
x['boundary'] is not None: 877 if neumann: 878 #Neumann coditions 879 xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1)) 880 tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1)) 881 loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb'])) 882 else: 883 #Term that refers to boundary data 884 loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb'])) 885 if x['initial'] is not None: 886 #Term that refers to initial data 887 loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0'])) 888 if x['collocation'] is not None: 889 #Term that refers to collocation points 890 x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1)) 891 t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1)) 892 if inverse: 893 loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr'])) 894 else: 895 loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr'])) 896 return loss 897 else: 898 @jax.jit 899 def lf(params,x): 900 loss = 0 901 if x['sensor'] is not None: 902 #Term that refers to sensor data 903 loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor'])) 904 if x['boundary'] is not None: 905 if neumann: 906 #Neumann coditions 907 xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1)) 908 tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1)) 909 loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb)) 910 else: 911 #Term that refers to boundary data 912 loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary'])) 913 if x['initial'] is not None: 914 #Term that refers to initial data 915 loss = 
loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial'])) 916 if x['collocation'] is not None: 917 #Term that refers to collocation points 918 x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1)) 919 t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1)) 920 if inverse: 921 loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0)) 922 else: 923 loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0)) 924 return loss 925 926 #Initialize Adam Optmizer 927 if exp_decay: 928 lr = optax.exponential_decay(lr,transition_steps,decay_rate) 929 optimizer = optax.adam(lr,b1,b2,eps,eps_root) 930 opt_state = optimizer.init(params) 931 932 #Define the gradient function 933 grad_loss = jax.jit(jax.grad(lf,0)) 934 935 #Define update function 936 @jax.jit 937 def update(opt_state,params,x): 938 #Compute gradient 939 grads = grad_loss(params,x) 940 #Invert gradient of self-adaptative wheights 941 if sa: 942 for w in grads['sa']: 943 grads['sa'][w] = - grads['sa'][w] 944 #Calculate parameters updates 945 updates, opt_state = optimizer.update(grads, opt_state) 946 #Update parameters 947 params = optax.apply_updates(params, updates) 948 #Return state of optmizer and updated parameters 949 return opt_state,params 950 951 ###Training### 952 t0 = time.time() 953 #Initialize alive_bar for tracing in terminal 954 with alive_bar(epochs) as bar: 955 #For each epoch 956 for e in range(epochs): 957 #Update optimizer state and parameters 958 opt_state,params = update(opt_state,params,data) 959 #After epoch_print epochs 960 if e % epoch_print == 0: 961 #Compute elapsed time and current error 962 l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6)) 963 #If there is test data, compute current L2 error 964 if test_data is not None: 965 #Compute L2 error 966 l2_test = 
L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist() 967 l = l + ' L2 error: ' + str(jnp.round(l2_test,3)) 968 if inverse: 969 l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6)) 970 #Print 971 print(l) 972 if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save: 973 #Save current parameters 974 pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 975 #Update alive_bar 976 bar() 977 #Define estimated function 978 def u(xt): 979 return forward(xt,params['net']) 980 981 return {'u': u,'params': params,'forward': forward,'time': time.time() - t0} 982 983#Training PINN 984def train_Matern_PINN(data,width,pde,test_data = None,params = None,d = 2,N = 128,L = 1,alpha = 1,kappa = 1,sigma = 100,bsize = 1024,resample = False,epochs = 100,at_each = 10,activation = 'tanh', 985 neumann = False,oper_neumann = None,inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 1,save = False,file_name = 'result_pinn', 986 exp_decay = True,transition_steps = 100,decay_rate = 0.9,mlp = True,ftype = None,fargs = None,q = 4,w = None,periodic = False,static = None,opt = 'LBFGS'): 987 """ 988 Train a Physics-informed Neural Network 989 ---------- 990 Parameters 991 ---------- 992 data : dict 993 994 Data generated by the jinnax.data.generate_PINNdata function 995 996 width : list 997 998 A list with the width of each layer 999 1000 pde : function 1001 1002 The partial differential operator. Its arguments are u, x and t 1003 1004 test_data : dict, None 1005 1006 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error 1007 1008 params : list 1009 1010 Initial parameters for the neural network. 
Default None to initialize randomly 1011 1012 d : int 1013 1014 Dimension of the problem including the time variable if present. Default 2 1015 1016 N : int 1017 1018 Size of grid in each dimension. Default 128 1019 1020 L : list of float 1021 1022 The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1 1023 1024 kappa,alpha,sigma : float 1025 1026 Parameters of the Matern process 1027 1028 bsize : int 1029 1030 Batch size for weak norm computation. Default 1024 1031 1032 resample : logical 1033 1034 Whether to resample the test functions at each epoch 1035 1036 epochs : int 1037 1038 Number of training epochs. Default 100 1039 1040 at_each : int 1041 1042 Save results for epochs multiple of at_each. Default 10 1043 1044 activation : str 1045 1046 The name of the activation function of the neural network. Default 'tanh' 1047 1048 neumann : logical 1049 1050 Whether to consider Neumann boundary conditions 1051 1052 oper_neumann : function 1053 1054 Penalization of Neumann boundary conditions 1055 1056 inverse : logical 1057 1058 Whether to estimate parameters of the PDE 1059 1060 initial_par : jax.numpy.array 1061 1062 Initial value of the parameters of the PDE in an inverse problem 1063 1064 lr,b1,b2,eps,eps_root: float 1065 1066 Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0 1067 1068 key : int 1069 1070 Seed for parameters initialization. Default 0 1071 1072 epoch_print : int 1073 1074 Number of epochs to calculate and print test errors. Default 1 1075 1076 save : logical 1077 1078 Whether to save the current parameters. Default False 1079 1080 file_name : str 1081 1082 File prefix to save the current parameters. Default 'result_pinn' 1083 1084 exp_decay : logical 1085 1086 Whether to consider exponential decay of learning rate. Default True 1087 1088 transition_steps : int 1089 1090 Number of steps for exponential decay. 
Default 100 1091 1092 decay_rate : float 1093 1094 Rate of exponential decay. Default 0.9 1095 1096 mlp : logical 1097 1098 Whether to consider modifed multilayer perceptron 1099 1100 ftype : str 1101 1102 Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'. 1103 1104 fargs : list 1105 1106 Arguments for deature transformation: 1107 1108 For 'ff': A list with the number of frequences and value of greatest frequence standard deviation. 1109 1110 For 'daff' and 'daff' bias: A dicitionary with a list with the size of rectangles and the type of boundary condition. If its a list, than boundary conditions is dirichlet. 1111 1112 q : int 1113 1114 Power of weights mask. Default 4 1115 1116 w : dict 1117 1118 Initila weights for self-adaptive scheme. 1119 1120 periodic : logical 1121 1122 Whether to consider periodic test functions. Default False. 1123 1124 static : function 1125 1126 A static function to sum to the neural network output. 1127 1128 opt : str 1129 1130 Optimizer. Default LBFGS. 
1131 1132 Returns 1133 ------- 1134 dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch 1135 """ 1136 #Initialize architecture 1137 nnet = fconNN(width,get_activation(activation),key,mlp,ftype,fargs,static) 1138 forward = nnet['forward'] 1139 if params is not None: 1140 nnet['params'] = params 1141 1142 #Generate from Matern process 1143 if sigma > 0: 1144 if isinstance(L,float) or isinstance(L,int): 1145 L = d*[L] 1146 #Grid for weak norm 1147 grid = [jnp.linspace(0,L[i],N) for i in range(d)] 1148 grid = jnp.meshgrid(*grid, indexing='ij') 1149 grid = jnp.stack(grid, axis=-1).reshape((-1, d)) 1150 #Set sigma 1151 if data['boundary'] is not None: 1152 gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma) 1153 tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0]) 1154 if neumann: 1155 loss_boundary = oper_neumann(lambda x: forward(x,params['net']),data['boundary']) 1156 else: 1157 loss_boundary = jnp.mean(MSE(forward(data['boundary'],nnet['params']),data['uboundary'])) 1158 output_w = pde(lambda x: forward(x,nnet['params']),grid) 1159 integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(tf) 1160 loss_res_weak = jnp.mean(integralOmega ** 2) 1161 sigma = float(jnp.sqrt(loss_boundary/loss_res_weak).tolist()) 1162 del gen 1163 gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic) 1164 tf = sigma*tf 1165 else: 1166 gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic) 1167 tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0]) 1168 1169 #Define loss function 1170 @jax.jit 1171 def lf_each(params,x,k): 1172 if sigma > 0: 1173 #Term that refers to weak loss 1174 if resample: 1175 test_functions = 
gen(jax.random.split(jax.random.PRNGKey(k[0]),(bsize,))[:,0]) 1176 else: 1177 test_functions = tf 1178 loss_sensor = loss_boundary = loss_initial = loss_res = loss_res_weak = 0 1179 if x['sensor'] is not None: 1180 #Term that refers to sensor data 1181 loss_sensor = jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor'])) 1182 if x['boundary'] is not None: 1183 if neumann: 1184 #Neumann coditions 1185 loss_boundary = oper_neumann(lambda x: forward(x,params['net']),x['boundary']) 1186 else: 1187 #Term that refers to boundary data 1188 loss_boundary = MSE(forward(x['boundary'],params['net']),x['uboundary']) 1189 if x['initial'] is not None: 1190 #Term that refers to initial data 1191 loss_initial = MSE(forward(x['initial'],params['net']),x['uinitial']) 1192 if x['collocation'] is not None and sigma == 0: 1193 if inverse: 1194 output = pde(lambda x: forward(x,params['net']),x['collocation'],params['inverse']) 1195 loss_res = MSE(output,0) 1196 else: 1197 output = pde(lambda x: forward(x,params['net']),x['collocation']) 1198 loss_res = MSE(output,0) 1199 if sigma > 0: 1200 #Term that refers to weak loss 1201 if inverse: 1202 output_w = pde(lambda x: forward(x,params['net']),grid,params['inverse']) 1203 integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions) 1204 loss_res_weak = jnp.mean(integralOmega ** 2) 1205 else: 1206 output_w = pde(lambda x: forward(x,params['net']),grid) 1207 integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions) 1208 loss_res_weak = jnp.mean(integralOmega ** 2) 1209 return {'ls': loss_sensor,'lb': loss_boundary,'li': loss_initial,'lc': loss_res,'lc_weak': loss_res_weak} 1210 1211 @jax.jit 1212 def lf(params,x,k): 1213 l = lf_each(params,x,k) 1214 w = params['w'] 1215 loss = jnp.mean((w['ws'] ** q)*l['ls']) + jnp.mean((w['wb'] ** q)*l['lb']) + jnp.mean((w['wi'] ** q)*l['li']) + jnp.mean((w['wc'] ** q)*l['lc']) + (w['wc_weak'] ** q)*l['lc_weak'] 1216 if opt != 
'LBFGS': 1217 return loss 1218 else: 1219 l2 = None 1220 if test_data is not None: 1221 l2 = L2error(forward(test_data['sensor'],params['net']),test_data['usensor']) 1222 return loss,{'loss': loss,'l2': l2} 1223 1224 #Initialize self-adaptive weights 1225 if w is None: 1226 w = {'ws': jnp.array(1.0),'wb': jnp.array(1.0),'wi': jnp.array(1.0),'wc': jnp.array(1.0),'wc_weak': jnp.array(1.0)} 1227 if q != 0: 1228 if data['sensor'] is not None: 1229 w['ws'] = w['ws'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+1),(data['sensor'].shape[0],1)) 1230 if data['boundary'] is not None: 1231 w['wb'] = w['wb'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+2),(data['boundary'].shape[0],1)) 1232 if data['initial'] is not None: 1233 w['wi'] = w['wi'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+3),(data['initial'].shape[0],1)) 1234 if data['collocation'] is not None: 1235 w['wc'] = w['wc'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+4),(data['collocation'].shape[0],1)) 1236 1237 #Store all parameters 1238 params = {'net': nnet['params'],'inverse': initial_par,'w': w} 1239 1240 #Save config file 1241 if save: 1242 pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 1243 1244 #Initialize Adam Optmizer 1245 if opt != 'LBFGS': 1246 print('--------- GRADIENT DESCENT OPTIMIZER ---------') 1247 if exp_decay: 1248 lr = optax.exponential_decay(lr,transition_steps,decay_rate) 1249 optimizer = optax.adam(lr,b1,b2,eps,eps_root) 1250 opt_state = optimizer.init(params) 1251 1252 #Define the gradient function 1253 grad_loss = jax.jit(jax.grad(lf,0)) 1254 1255 #Define update function 1256 @jax.jit 1257 def update(opt_state,params,x,k): 1258 #Compute gradient 1259 grads = grad_loss(params,x,k) 1260 #Calculate parameters updates 1261 
updates, opt_state = optimizer.update(grads, opt_state) 1262 #Update parameters 1263 if q != 0: 1264 updates = {**updates, 'w': jax.tree_util.tree_map(lambda x: -x, updates['w'])} #Change signs of weights 1265 params = optax.apply_updates(params, updates) 1266 #Return state of optmizer and updated parameters 1267 return opt_state,params 1268 else: 1269 print('--------- LBFGS OPTIMIZER ---------') 1270 @jax.jit 1271 def loss_LBFGS(params): 1272 return lf(params,data,key + 234) 1273 solver = LBFGS(fun = loss_LBFGS,has_aux = True,maxiter = epochs,tol = 1e-9,verbose = False,linesearch = 'zoom',history_size = 100) # linesearch='zoom' by default 1274 state = solver.init_state(params) 1275 1276 ###Training### 1277 t0 = time.time() 1278 k = jax.random.split(jax.random.PRNGKey(key+234),(epochs,)) 1279 sloss = [] 1280 sL2 = [] 1281 stime = [] 1282 #Initialize alive_bar for tracing in terminal 1283 with alive_bar(epochs) as bar: 1284 #For each epoch 1285 for e in range(epochs): 1286 if opt != 'LBFGS': 1287 #Update optimizer state and parameters 1288 opt_state,params = update(opt_state,params,data,k[e,:]) 1289 sloss.append(lf(params,data,k[e,:])) 1290 if test_data is not None: 1291 sL2.append(L2error(forward(test_data['sensor'],params['net']),test_data['usensor'])) 1292 else: 1293 params, state = solver.update(params, state) 1294 sL2.append(state.aux["l2"]) 1295 sloss.append(state.aux["loss"]) 1296 stime.append(time.time() - t0) 1297 #After epoch_print epochs 1298 if e % epoch_print == 0: 1299 #Compute elapsed time and current error 1300 l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(sloss[-1],6)) 1301 #If there is test data, compute current L2 error 1302 if test_data is not None: 1303 #Compute L2 error 1304 l = l + ' L2 error: ' + str(jnp.round(sL2[-1],6)) 1305 if inverse: 1306 l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6)) 1307 #Print 1308 print(l) 1309 if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and 
save: 1310 #Save current parameters 1311 pickle.dump({'params': params,'width': width,'time': stime,'loss': sloss,'L2error': sL2},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 1312 #Update alive_bar 1313 bar() 1314 #Define estimated function 1315 def u(xt): 1316 return forward(xt,params['net']) 1317 1318 return {'u': u,'params': params,'forward': forward,'time': time.time() - t0,'loss_each': lf_each(params,data,[key + 100]),'loss': sloss,'L2error': sL2} 1319 1320 1321#Process result 1322def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1): 1323 """ 1324 Process the results of a Physics-informed Neural Network 1325 ---------- 1326 1327 Parameters 1328 ---------- 1329 test_data : dict 1330 1331 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 1332 1333 fit : function 1334 1335 The fitted function 1336 1337 train_data : dict 1338 1339 Training data generated by the jinnax.data.generate_PINNdata 1340 1341 plot : logical 1342 1343 Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True 1344 1345 plot_test : logical 1346 1347 Whether to plot the test data. Default True 1348 1349 times : int 1350 1351 Number of points along the time interval to plot. Default 5 1352 1353 d2 : logical 1354 1355 Whether to plot 2D plot when the spatial dimension is one. Default True 1356 1357 save : logical 1358 1359 Whether to save the plots. Default False 1360 1361 show : logical 1362 1363 Whether to show the plots. Default True 1364 1365 file_name : str 1366 1367 File prefix to save the plots. Default 'result_pinn' 1368 1369 print_res : logical 1370 1371 Whether to print the L2 error. Default True 1372 1373 p : int 1374 1375 Output dimension. 
        Default 1

    Returns
    -------
    pandas data frame with L2 and MSE errors
    """

    #Dimension (last column of xt is time)
    d = test_data['xt'].shape[1] - 1

    #Number of plots multiple of 5
    times = 5 * round(times/5.0)

    #Data: stack sensor/boundary/initial training points into single arrays
    td = get_train_data(train_data)
    xt_train = td['x']
    u_train = td['y']
    upred_train = fit(xt_train)
    upred_test = fit(test_data['xt'])

    #Results: L2 (%) and MSE on both test and train sets
    l2_error_test = L2error(upred_test,test_data['u']).tolist()
    MSE_test = jnp.mean(MSE(upred_test,test_data['u'])).tolist()
    l2_error_train = L2error(upred_train,u_train).tolist()
    MSE_train = jnp.mean(MSE(upred_train,u_train)).tolist()

    df = pd.DataFrame(np.array([l2_error_test,MSE_test,l2_error_train,MSE_train]).reshape((1,4)),
                      columns=['l2_error_test','MSE_test','l2_error_train','MSE_train'])
    if print_res:
        print('L2 error test: ' + str(jnp.round(l2_error_test,6)) + ' L2 error train: ' + str(jnp.round(l2_error_train,6)) + ' MSE error test: ' + str(jnp.round(MSE_test,6)) + ' MSE error train: ' + str(jnp.round(MSE_train,6)) )

    #Plots: 1D solution slices, or trajectory plot for 2D output
    if d == 1 and p ==1 and plot:
        plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name)
    elif p == 2 and plot:
        plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test)

    return df

#Plot results for d = 1
def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
    """
    Plot the prediction of a 1D PINN
    ----------

    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    d2 : logical

        Whether to plot 2D plot. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title_1d : str

        Title of 1D plot

    title_2d : str

        Title of 2D plot

    Returns
    -------
    None
    """
    #Initialize: times/5 rows of 5 time slices each, with common y-limits
    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    ylo = jnp.min(u)
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u)
    yup = yup + 0.1*jnp.abs(yup)
    k = 0
    t_values = np.linspace(tlo,tup,times)

    #Create one panel per time slice
    for i in range(int(times/5)):
        for j in range(5):
            if k < len(t_values):
                #Snap the requested time to the closest time present in xt
                t = t_values[k]
                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
                x_plot = xt[xt[:,-1] == t,:-1]
                y_plot = upred[xt[:,-1] == t,:]
                u_plot = u[xt[:,-1] == t,:]
                #With a single row, subplots returns a 1D axes array
                if int(times/5) > 1:
                    ax[i,j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
                    ax[i,j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[i,j].set_xlabel(' ')
                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                else:
                    ax[j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
                    ax[j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
            k = k + 1

    #Title
    fig.suptitle(title_1d)
    fig.tight_layout()

    #Show and save
    fig = plt.gcf()
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close()

    #2d plot: exact vs predicted heatmaps over (t,x)
    if d2:
        #Initialize
        fig, ax = plt.subplots(1,2)
        l1 = jnp.unique(xt[:,-1]).shape[0]
        l2 = jnp.unique(xt[:,0]).shape[0]

        #Create — assumes xt is a full (x,t) grid of shape l2*l1; verify against caller
        ax[0].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[0].set_title('Exact')
        ax[1].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[1].set_title('Predicted')

        #Title
        fig.suptitle(title_2d)
        fig.tight_layout()

        #Show and save
        fig = plt.gcf()
        if show:
            plt.show()
        if save:
            fig.savefig(file_name + '_2d.png')
        plt.close()

#Plot results for a PINN with 2D output
def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
    """
    Plot the prediction of a PINN with 2D output
    ----------
    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title : str

        Title of plot

    plot_test : logical

        Whether to plot the test data.
Default True 1579 1580 Returns 1581 ------- 1582 None 1583 """ 1584 #Initialize 1585 fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5))) 1586 tlo = jnp.min(xt[:,-1]) 1587 tup = jnp.max(xt[:,-1]) 1588 xlo = jnp.min(u[:,0]) 1589 xlo = xlo - 0.1*jnp.abs(xlo) 1590 xup = jnp.max(u[:,0]) 1591 xup = xup + 0.1*jnp.abs(xup) 1592 ylo = jnp.min(u[:,1]) 1593 ylo = ylo - 0.1*jnp.abs(ylo) 1594 yup = jnp.max(u[:,1]) 1595 yup = yup + 0.1*jnp.abs(yup) 1596 k = 0 1597 t_values = np.linspace(tlo,tup,times) 1598 1599 #Create 1600 for i in range(int(times/5)): 1601 for j in range(5): 1602 if k < len(t_values): 1603 t = t_values[k] 1604 t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist() 1605 xpred_plot = upred[xt[:,-1] == t,0] 1606 ypred_plot = upred[xt[:,-1] == t,1] 1607 if plot_test: 1608 x_plot = u[xt[:,-1] == t,0] 1609 y_plot = u[xt[:,-1] == t,1] 1610 if int(times/5) > 1: 1611 if plot_test: 1612 ax[i,j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact') 1613 ax[i,j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction') 1614 ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10) 1615 ax[i,j].set_xlabel(' ') 1616 ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1617 else: 1618 if plot_test: 1619 ax[j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact') 1620 ax[j].plot(xpred_plot,ypred,'r-',linewidth=2,label='Prediction') 1621 ax[j].set_title('$t = %.2f$' % (t),fontsize=10) 1622 ax[j].set_xlabel(' ') 1623 ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1624 k = k + 1 1625 1626 #Title 1627 fig.suptitle(title) 1628 fig.tight_layout() 1629 1630 #Show and save 1631 fig = plt.gcf() 1632 if show: 1633 plt.show() 1634 if save: 1635 fig.savefig(file_name + '_slices.png') 1636 plt.close() 1637 1638#Get train data in one array 1639def get_train_data(train_data): 1640 """ 1641 Process training sample 1642 ---------- 1643 1644 Parameters 1645 ---------- 1646 train_data : dict 1647 1648 A dictionay with train data 
generated by the jinnax.data.generate_PINNdata function 1649 1650 Returns 1651 ------- 1652 dict with the processed training data 1653 """ 1654 xdata = None 1655 ydata = None 1656 xydata = None 1657 if train_data['sensor'] is not None: 1658 sensor_sample = train_data['sensor'].shape[0] 1659 xdata = train_data['sensor'] 1660 ydata = train_data['usensor'] 1661 xydata = jnp.column_stack((train_data['sensor'],train_data['usensor'])) 1662 else: 1663 sensor_sample = 0 1664 if train_data['boundary'] is not None: 1665 boundary_sample = train_data['boundary'].shape[0] 1666 if xdata is not None: 1667 xdata = jnp.vstack((xdata,train_data['boundary'])) 1668 ydata = jnp.vstack((ydata,train_data['uboundary'])) 1669 xydata = jnp.vstack((xydata,jnp.column_stack((train_data['boundary'],train_data['uboundary'])))) 1670 else: 1671 xdata = train_data['boundary'] 1672 ydata = train_data['uboundary'] 1673 xydata = jnp.column_stack((train_data['boundary'],train_data['uboundary'])) 1674 else: 1675 boundary_sample = 0 1676 if train_data['initial'] is not None: 1677 initial_sample = train_data['initial'].shape[0] 1678 if xdata is not None: 1679 xdata = jnp.vstack((xdata,train_data['initial'])) 1680 ydata = jnp.vstack((ydata,train_data['uinitial'])) 1681 xydata = jnp.vstack((xydata,jnp.column_stack((train_data['initial'],train_data['uinitial'])))) 1682 else: 1683 xdata = train_data['initial'] 1684 ydata = train_data['uinitial'] 1685 xydata = jnp.column_stack((train_data['initial'],train_data['uinitial'])) 1686 else: 1687 initial_sample = 0 1688 if train_data['collocation'] is not None: 1689 collocation_sample = train_data['collocation'].shape[0] 1690 else: 1691 collocation_sample = 0 1692 1693 return {'xy': xydata,'x': xdata,'y': ydata,'sensor_sample': sensor_sample,'boundary_sample': boundary_sample,'initial_sample': initial_sample,'collocation_sample': collocation_sample} 1694 1695#Process training 1696def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 
#Process training
def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
    """
    Process the training of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    bolstering : logical

        Whether to compute bolstering mean square error. Default True

    mc_sample : int

        Number of sample for Monte Carlo integration in bolstering. Default 10000

    save : logical

        Whether to save the training results. Default False

    file_name_save : str

        File prefix to save the plots and the L2 error. Default 'result_pinn'

    key : int

        Key for random samples in bolstering. Default 0

    ec : float

        Stopping criteria error for EM algorithm in bolstering. Default 1e-6

    lamb : float

        Hyperparameter of EM algorithm in bolstering. Default 1

    Returns
    -------
    pandas data frame with training results
    """
    #Config — context manager closes the pickle file promptly; the previous
    #pickle.load(open(...)) leaked the file handle
    with open(file_name + '_config.pickle', 'rb') as f:
        config = pickle.load(f)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xydata = td['xy']
    xdata = td['x']
    ydata = td['y']
    sensor_sample = td['sensor_sample']
    boundary_sample = td['boundary_sample']
    initial_sample = td['initial_sample']
    collocation_sample = td['collocation_sample']

    #Generate keys
    if bolstering:
        keys = jax.random.split(jax.random.PRNGKey(key),epochs)

    #Initialize trackers ('times' instead of 'time' so the time module
    #imported at the top of the file is not shadowed inside this function)
    train_mse = []
    test_mse = []
    train_L2 = []
    test_L2 = []
    bolstX = []
    bolstXY = []
    loss = []
    times = []
    ep = []

    #Process training: only epochs that are multiples of at_each (plus the
    #last epoch) are evaluated
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
                ep = ep + [e]

                #Read parameters saved for this epoch (handle closed via with)
                with open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb') as f:
                    params = pickle.load(f)

                #Time
                times = times + [params['time']]

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Train MSE and L2
                if xdata is not None:
                    train_mse = train_mse + [jnp.mean(MSE(psi(xdata),ydata)).tolist()]
                    train_L2 = train_L2 + [L2error(psi(xdata),ydata).tolist()]
                else:
                    train_mse = train_mse + [None]
                    train_L2 = train_L2 + [None]

                #Test MSE and L2
                test_mse = test_mse + [jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist()]
                test_L2 = test_L2 + [L2error(psi(test_data['xt']),test_data['u']).tolist()]

                #Bolstering: three kernel estimation methods on both the X and
                #XY data, plus Hessian-based estimation for several bias levels
                if bolstering:
                    bX = []
                    bXY = []
                    for method in ['chi','mm','mpe']:
                        kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
                        bXY = bXY + [gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist()]
                    for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]:
                        kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
                    bolstX = bolstX + [bX]
                    bolstXY = bolstXY + [bXY]
                else:
                    bolstX = bolstX + [None]
                    bolstXY = bolstXY + [None]

                #Loss
                loss = loss + [params['loss'].tolist()]

                #Delete per-epoch objects to release memory before the next load
                del params, psi
            #Update alive_bar
            bar()

    #Bolstering results
    if bolstering:
        bolstX = jnp.array(bolstX)
        bolstXY = jnp.array(bolstXY)

    #Create data frame
    if bolstering:
        df = pd.DataFrame(np.column_stack([ep,times,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
                                           train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]),
                          columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4'])
    else:
        df = pd.DataFrame(np.column_stack([ep,times,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
                                           train_mse,test_mse,train_L2,test_L2]),
                          columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2'])
    if save:
        df.to_csv(file_name_save + '.csv',index = False)

    return df
#Demo video for training 1D PINN
def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
    """
    Demo video with the training of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to make video demo of 2D plot. Default True

    file_name_save : str

        File prefix to save the plots and videos. Default 'result_pinn_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xt = td['x']
    u = td['y']

    #Create folder to save plots — os.makedirs is portable and not subject to
    #shell interpretation of file_name_save, unlike os.system('mkdir ...')
    os.makedirs(file_name_save,exist_ok = True)

    #Create one image per evaluated epoch
    k = 1
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if e % at_each == 0 or e == epochs - 1:
                #Read parameters
                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Compute L2 train, L2 test and loss
                loss = params['loss']
                L2_train = L2error(psi(xt),u)
                L2_test = L2error(psi(test_data['xt']),test_data['u'])
                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(L2_train,6)) + ' L2 test = ' + str(round(L2_test,6))

                #Save plot
                plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
                k = k + 1

                #Delete per-epoch objects to release memory before the next load
                del params, psi, loss, L2_train, L2_test, title_epoch
            #Update alive_bar
            bar()
    #Create demo video (requires ffmpeg on PATH)
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4')
    if d2:
        os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4')
#Demo in time for 1D PINN
def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
    """
    Demo video with the time evolution of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    train_data = config['train_data']
    forward = config['forward']

    #Create folder to save plots — os.makedirs is portable and not subject to
    #shell interpretation of file_name_save, unlike os.system('mkdir ...')
    os.makedirs(file_name_save,exist_ok = True)

    #Plot parameters: time values present in the test grid and padded y limits
    tdom = jnp.unique(test_data['xt'][:,-1])
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)

    #Open PINN for each epoch and precompute its prediction on the test grid
    results = []
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        results = results + [tmp]
        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]

    #Create one image per time value, with one panel per requested epoch
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data
            xt_step = test_data['xt'][test_data['xt'][:,-1] == t]
            u_step = test_data['u'][test_data['xt'][:,-1] == t]
            #Initialize plot
            if len(epochs) > 1:
                fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2))
            else:
                fig, ax = plt.subplots(1,1,figsize = (10,5))
            #Create panels; matplotlib returns a 2D, 1D or scalar axes object
            #depending on the subplot grid, hence the three branches
            index = 0
            if int(len(epochs)/2) > 1:
                for i in range(int(len(epochs)/2)):
                    for j in range(min(2,len(epochs))):
                        upred_step = upred[index][test_data['xt'][:,-1] == t]
                        ax[i,j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                        ax[i,j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                        ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                        ax[i,j].set_xlabel(' ')
                        ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                        index = index + 1
            elif len(epochs) > 1:
                for j in range(2):
                    upred_step = upred[index][test_data['xt'][:,-1] == t]
                    ax[j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                    ax[j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                    ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                    index = index + 1
            else:
                upred_step = upred[index][test_data['xt'][:,-1] == t]
                ax.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                ax.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                ax.set_xlabel(' ')
                ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                index = index + 1


            #Title
            fig.suptitle(title + 't = ' + str(round(t,4)))
            fig.tight_layout()

            #Show and save
            fig = plt.gcf()
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close()
            bar()

    #Create demo video (requires ffmpeg on PATH)
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')
29@jax.jit 30def MSE(pred,true): 31 """ 32 Squared error 33 ---------- 34 Parameters 35 ---------- 36 pred : jax.numpy.array 37 38 A JAX numpy array with the predicted values 39 40 true : jax.numpy.array 41 42 A JAX numpy array with the true values 43 44 Returns 45 ------- 46 squared error 47 """ 48 return (true - pred) ** 2
Squared error
Parameters
- pred (jax.numpy.array): A JAX numpy array with the predicted values
- true (jax.numpy.array): A JAX numpy array with the true values
Returns
- squared error
51@jax.jit 52def MSE_SA(pred,true,w,q = 2): 53 """ 54 Self-adaptative squared error 55 ---------- 56 Parameters 57 ---------- 58 pred : jax.numpy.array 59 60 A JAX numpy array with the predicted values 61 62 true : jax.numpy.array 63 64 A JAX numpy array with the true values 65 66 weight : jax.numpy.array 67 68 A JAX numpy array with the weights 69 70 q : float 71 72 Power for the weights mask 73 74 Returns 75 ------- 76 self-adaptative squared error with polynomial mask 77 """ 78 return (w ** q) * ((true - pred) ** 2)
Self-adaptive squared error
Parameters
- pred (jax.numpy.array): A JAX numpy array with the predicted values
- true (jax.numpy.array): A JAX numpy array with the true values
- weight (jax.numpy.array): A JAX numpy array with the weights
- q (float): Power for the weights mask
Returns
- self-adaptive squared error with polynomial mask
81@jax.jit 82def L2error(pred,true): 83 """ 84 L2-error in percentage (%) 85 ---------- 86 Parameters 87 ---------- 88 pred : jax.numpy.array 89 90 A JAX numpy array with the predicted values 91 92 true : jax.numpy.array 93 94 A JAX numpy array with the true values 95 96 Returns 97 ------- 98 L2-error 99 """ 100 return 100*jnp.sqrt(jnp.sum((true - pred)**2))/jnp.sqrt(jnp.sum(true ** 2))
L2-error in percentage (%)
Parameters
- pred (jax.numpy.array): A JAX numpy array with the predicted values
- true (jax.numpy.array): A JAX numpy array with the true values
Returns
- L2-error
103def idst1(x,axis = -1): 104 """ 105 Inverse Discrete Sine Transform of type I with orthonormal scaling 106 ---------- 107 Parameters 108 ---------- 109 x : jax.numpy.array 110 111 Array to apply the transformation 112 113 axis : int 114 115 Axis to apply the transformation over 116 117 Returns 118 ------- 119 jax.numpy.array 120 """ 121 return idst(x,type = 1,axis = axis,norm = 'ortho')
Inverse Discrete Sine Transform of type I with orthonormal scaling
Parameters
- x (jax.numpy.array): Array to apply the transformation
- axis (int): Axis to apply the transformation over
Returns
- jax.numpy.array
123def dstn(x,axes = None): 124 """ 125 Discrete Sine Transform of type I with orthonormal scaling over many axes 126 ---------- 127 Parameters 128 ---------- 129 x : jax.numpy.array 130 131 Array to apply the transformation 132 133 axes : int 134 135 Axes to apply the transformation over 136 137 Returns 138 ------- 139 jax.numpy.array 140 """ 141 if axes is None: 142 axes = tuple(range(x.ndim)) 143 y = x 144 for ax in axes: 145 y = dst(x,type = 1,axis = ax,norm = 'ortho') 146 return y
Discrete Sine Transform of type I with orthonormal scaling over many axes
Parameters
- x (jax.numpy.array): Array to apply the transformation
- axes (int): Axes to apply the transformation over
Returns
- jax.numpy.array
148def idstn(x,axes = None): 149 """ 150 Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes 151 ---------- 152 Parameters 153 ---------- 154 x : jax.numpy.array 155 156 Array to apply the transformation 157 158 axes : int 159 160 Axes to apply the transformation over 161 162 Returns 163 ------- 164 jax.numpy.array 165 """ 166 if axes is None: 167 axes = tuple(range(x.ndim)) 168 y = x 169 for ax in axes: 170 y = idst1(y,axis = ax) 171 return y
Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes
Parameters
- x (jax.numpy.array): Array to apply the transformation
- axes (int): Axes to apply the transformation over
Returns
- jax.numpy.array
173def dirichlet_eigs_nd(n,L): 174 """ 175 Eigenvalues of the discrete Dirichlet-Laplace operator in a rectangle 176 ---------- 177 Parameters 178 ---------- 179 n : list 180 181 List with the number of points in the grid in each dimension 182 183 L : list 184 185 List with the upper limit of the interval of the domain in each dimension. Assumed the lower limit is zero 186 187 Returns 188 ------- 189 jax.numpy.array 190 """ 191 #Unidimensional eigenvalues 192 lam_axes = [] 193 for ni, Li in zip(n,L): 194 h = Li / (ni + 1.0) 195 k = jnp.arange(1,ni + 1,dtype = np.float32) 196 ln = (2.0 / (h*h)) * (1.0 - jnp.cos(jnp.pi * k / (ni + 1.0))) 197 lam_axes.append(ln) 198 grids = jnp.meshgrid(*lam_axes, indexing='ij') 199 Lam = jnp.zeros_like(grids[0]) 200 for g in grids: 201 Lam += g 202 return Lam
Eigenvalues of the discrete Dirichlet-Laplace operator in a rectangle
Parameters
- n (list): List with the number of points in the grid in each dimension
- L (list): List with the upper limit of the interval of the domain in each dimension. Assumed the lower limit is zero
Returns
- jax.numpy.array
#Sample Matern
def generate_matern_sample(key,d = 2,N = 128,L = 1.0,kappa = 1,alpha = 1,sigma = 1,periodic = False):
    """
    Sample d-dimensional Matern process
    ----------
    Parameters
    ----------
    key : int

        Seed for randomization

    d : int

        Dimension. Default 2

    N : int or list of int

        Size of grid in each dimension. Default 128

    L : list of float

        The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process

    periodic : logical

        Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT

    Returns
    -------
    jax.numpy.array
    """
    if periodic:
        #Key and per-dimension parameters. Normalize N and L to lists BEFORE
        #deriving the grid shape — the previous code built shape from N prior
        #to this conversion, which broke list-valued N
        key = jax.random.PRNGKey(key)
        if isinstance(L,float) or isinstance(L,int):
            L = d*[L]
        if isinstance(N,float) or isinstance(N,int):
            N = d*[N]
        shape = tuple(N)

        #Setup Frequency Grid
        freq = [jnp.fft.fftfreq(N[j],d = L[j]/N[j]) * 2 * jnp.pi for j in range(d)]
        grids = jnp.meshgrid(*freq, indexing='ij')
        sq_norm_xi = sum(g**2 for g in grids)

        #Generate White Noise in Fourier Space
        key_re, key_im = jax.random.split(key)
        white_noise_f = (jax.random.normal(key_re, shape) +
                         1j * jax.random.normal(key_im, shape))

        #Apply the Whittle Filter
        amplitude_filter = (kappa ** 2 + sq_norm_xi) ** (-alpha / 2.0)
        field_f = white_noise_f * amplitude_filter

        #Transform back to Physical Space
        sample = jnp.real(jnp.fft.ifftn(field_f))
        return sigma*sample
    else: #NOT JAX
        #Shape and key
        rng = np.random.default_rng(seed = key)
        if isinstance(L,float) or isinstance(L,int):
            L = d*[L]
        if isinstance(N,float) or isinstance(N,int):
            N = d*[N]
        shape = tuple(N)

        #White noise in real space
        W = rng.standard_normal(size = shape)

        #To Dirichlet eigenbasis via separable DST-I (orthonormal)
        W_hat = dstn(W)

        #Discrete Dirichlet Laplacian eigenvalues
        lam = dirichlet_eigs_nd(N, L)

        #Spectral filter. NOTE(review): this branch filters with (kappa + lam)
        #while the periodic branch uses kappa ** 2 — presumably intentional
        #(different parametrizations); confirm
        filt = ((kappa + lam) ** (-alpha/2.0))
        psi_hat = filt * W_hat

        #Back to real space
        psi = idstn(psi_hat)
        return jnp.array(sigma*psi)
Sample d-dimensional Matern process
Parameters
- key (int): Seed for randomization
- d (int): Dimension. Default 2
- N (int): Size of grid in each dimension. Default 128
- L (list of float): The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1
- kappa,alpha,sigma (float): Parameters of the Matern process
- periodic (logical): Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT
Returns
- jax.numpy.array
293def generate_matern_sample_batch(d = 2,N = 512,L = 1.0,kappa = 10.0,alpha = 1,sigma = 10,periodic = False): 294 """ 295 Create function to sample d-dimensional Matern process 296 ---------- 297 Parameters 298 ---------- 299 d : int 300 301 Dimension. Default 2 302 303 N : int 304 305 Size of grid in each dimension. Default 128 306 307 L : list of float 308 309 The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1 310 311 kappa,alpha,sigma : float 312 313 Parameters of the Matern process 314 315 periodic : logical 316 317 Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT 318 319 Returns 320 ------- 321 function 322 """ 323 if periodic: 324 return jax.vmap(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)) 325 else: 326 return lambda keys: jnp.array(np.apply_along_axis(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic),1,keys.reshape((keys.shape[0],1))))
Create function to sample d-dimensional Matern process
Parameters
- d (int): Dimension. Default 2
- N (int): Size of grid in each dimension. Default 128
- L (list of float): The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1
- kappa,alpha,sigma (float): Parameters of the Matern process
- periodic (logical): Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT
Returns
- function
def eigenf_laplace(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None):
    """
    Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace.
    ----------
    Parameters
    ----------
    L_vec : list of float

        The domain of the function in each coordinate is [0,L[1]]

    kmax_per_axis : list

        List with the maximum number of eigenfunctions per dimension. Consider d * max(kmax_per_axis) eigenfunctions

    bc : str

        Boundary condition. 'dirichlet' or 'neumann'

    max_ef : int

        Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions

    Returns
    -------
    function to compute eigenfunctions,eigenvalues of the eigenfunctions considered
    """
    #Parameters: domain lengths per coordinate and spatial dimension d
    L_vec = jnp.asarray(L_vec,dtype = jnp.float32)
    d = L_vec.shape[0]
    bc = bc.lower()

    #Maximum number of functions
    #NOTE(review): despite the default, kmax_per_axis = None would fail below
    #at list(map(int, ...)) — callers are expected to always pass it
    if max_ef is None:
        if d == 1:
            max_ef = jnp.max(jnp.array(kmax_per_axis))
        else:
            max_ef = jnp.max(d * jnp.array(kmax_per_axis))

    #Build the candidate multi-indices per axis
    #Dirichlet modes start at k = 1 (sin(0) vanishes); Neumann includes k = 0
    kmax_per_axis = list(map(int, kmax_per_axis))
    if bc.startswith("d"):
        axis_ranges = [range(1, km + 1) for km in kmax_per_axis]
    elif bc.startswith("n"):
        axis_ranges = [range(0, km + 1) for km in kmax_per_axis]

    #Get all multi-indices (cartesian product over axes)
    Ks_list = list(product(*axis_ranges))
    Ks = jnp.array(Ks_list,dtype = jnp.float32)

    #Eigenvalues of the continuous Laplacian: sum_i (k_i * pi / L_i)^2
    pi_over_L = jnp.pi / L_vec
    lambdas_all = jnp.sum((Ks * pi_over_L) ** 2, axis=1)

    #Sort by eigenvalue (ascending)
    order = jnp.argsort(lambdas_all)
    Ks = Ks[order]
    lambdas_all = lambdas_all[order]

    #Keep first max_ef (lowest-frequency modes)
    Ks = Ks[:max_ef]
    lambdas = lambdas_all[:max_ef]
    m = Ks.shape[0]

    #Precompute per-feature normalization factor (closed form)
    #sqrt(2/L) per axis, except the Neumann k = 0 mode which uses sqrt(1/L)
    def per_axis_norm_factor(k_i, L_i, is_dirichlet):
        if is_dirichlet:
            return jnp.sqrt(2.0 / L_i)
        else:
            return jnp.where(k_i == 0, jnp.sqrt(1.0 / L_i), jnp.sqrt(2.0 / L_i))
    if bc.startswith("d"):
        #Dirichlet: all modes share the same product of per-axis constants
        nf = jnp.prod(jnp.sqrt(2.0 / L_vec)[None, :],axis = 1)
        norm_factors = jnp.ones((m,),dtype = jnp.float32) * nf
    else:
        # per-mode product across axes
        def nf_row(k_row):
            return jnp.prod(per_axis_norm_factor(k_row, L_vec, False))
        norm_factors = jax.vmap(nf_row)(Ks)

    #Build the callable function
    Ks_int = Ks # float array, but only integer values
    L_vec_f = L_vec
    @jax.jit
    def phi(x):
        #x has shape (..., d); returns eigenfunction values of shape (..., m)
        x = jnp.asarray(x,dtype = jnp.float32)
        #Initialize with ones
        vals = jnp.ones(x.shape[:-1] + (m,), dtype=jnp.float32)
        #Compute eigenfunction as a product of per-axis sin (Dirichlet) or
        #cos (Neumann) terms; the python loop over d unrolls under jit
        for i in range(d):
            ang = (jnp.pi / L_vec_f[i]) * x[..., i][..., None] * Ks_int[:, i]
            if bc.startswith("d"):
                comp = jnp.sin(ang)
            else:
                comp = jnp.cos(ang)
            vals = vals * comp
        #Apply L2-normalizing constants
        vals = vals * norm_factors[None, ...] if vals.ndim > 1 else vals * norm_factors
        return vals
    return phi, lambdas
Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace.
Parameters
- L_vec (list of float): The domain of the function in each coordinate is [0,L[1]]
- kmax_per_axis (list): List with the maximum number of eigenfunctions per dimension. Consider d * max(kmax_per_axis) eigenfunctions
- bc (str): Boundary condition. 'dirichlet' or 'neumann'
- max_ef (int): Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions
Returns
- function to compute eigenfunctions,eigenvalues of the eigenfunctions considered
429def multiple_daff(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None): 430 """ 431 Create function to compute multiple frequences of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. Each frequences is a different domain. 432 ---------- 433 Parameters 434 ---------- 435 L_vec : list of lists of float 436 437 List with the domain of each frequence of the eigenfunctions in the form [0,L[i][1]] 438 439 kmax_per_axis : list 440 441 List with the maximum number of eigenfunctions per dimension. 442 443 bc : str 444 445 Boundary condition. 'dirichlet' or 'neumann' 446 447 max_ef : int 448 449 Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions 450 451 Returns 452 ------- 453 function to compute daff,eigenvalues of the eigenfunctions considered 454 """ 455 psi = [] 456 lamb = [] 457 for L in L_vec: 458 tmp,l = eigenf_laplace(L,kmax_per_axis,bc,max_ef) #Get function 459 lamb.append(l) 460 psi.append(tmp) 461 del tmp 462 #Create function to compute features 463 @jax.jit 464 def mff(x): 465 y = [] 466 for i in range(len(psi)): 467 y.append(psi[i](x)) 468 if len(psi) == 1: 469 return y[0] 470 else: 471 return jnp.concatenate(y,1) 472 return mff,jnp.concatenate(lamb)
Create function to compute multiple frequences of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. Each frequences is a different domain.
Parameters
- L_vec (list of lists of float): List with the domain of each frequence of the eigenfunctions in the form [0,L[i][1]]
- kmax_per_axis (list): List with the maximum number of eigenfunctions per dimension.
- bc (str): Boundary condition. 'dirichlet' or 'neumann'
- max_ef (int): Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions
Returns
- function to compute daff,eigenvalues of the eigenfunctions considered
500@partial(jax.jit,static_argnums=(2,)) # n is static here; compile once per n 501def multiple_cheb_fast(x, L_vec, n: int): 502 """ 503 x: (N, d) 504 L_vec: (L, d) containing 'b' endpoints (a is 0) for each dimension 505 n: number of k terms (static) 506 returns: (N, L*n) 507 """ 508 N, d = x.shape 509 L = L_vec.shape[0] 510 511 a = 0.0 512 b = L_vec # (L, d) 513 # Map x to t in [-1, 1] for each l, j: shape (L, N, d) 514 t = (2.0 * x[None, :, :] - (a + b)[:, None, :]) / (b - a)[:, None, :] 515 516 # Chebyshev T_0..T_{n+2} for all (L, N, d): shape (n+3, L, N, d) 517 T = _chebyshev_T_all(t, n + 2) 518 519 # phi_k = T_{k+2} - T_k, k = 0..n-1 => shape (n, L, N, d) 520 ks = jnp.arange(n) 521 phi = T[ks + 2, ...] - T[ks, ...] 522 523 # Multiply across dimensions (over the last axis = d) => (n, L, N) 524 z = jnp.prod(phi, axis=-1) 525 526 # Reorder to (N, L, n) then flatten to (N, L*n) 527 z = jnp.transpose(z, (2, 1, 0)).reshape(N, L * n) 528 return z
x: (N, d) L_vec: (L, d) containing 'b' endpoints (a is 0) for each dimension n: number of k terms (static) returns: (N, L*n)
530def multiple_cheb(L_vec, n: int): 531 """ 532 Factory that closes over static n and L_vec (so shapes are constant). 533 """ 534 L_vec = jnp.asarray(L_vec) 535 @jax.jit # optional; multiple_cheb_fast is already jitted 536 def mcheb(x): 537 x = jnp.asarray(x) 538 return multiple_cheb_fast(x, L_vec, n) 539 return mcheb
Factory that closes over static n and L_vec (so shapes are constant).
def fconNN(width,activation = jax.nn.tanh,key = 0,mlp = False,ftype = None,fargs = None,static = None):
    """
    Initialize fully connected neural network
    ----------
    Parameters
    ----------
    width : list

        List with the layers width

    activation : jax.nn activation

        The activation function. Default jax.nn.tanh

    key : int

        Seed for parameters initialization. Default 0

    mlp : logical

        Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension.

    ftype : str

        Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.

    fargs : list

        Arguments for feature transformation:

        For 'ff': A list with the number of frequences and value of greatest frequence standard deviation.

        For 'daff' and 'daff_bias': A dictionary with a list with the size of rectangles and the type of boundary condition. If it is a list, then the boundary condition is Dirichlet.

    static : function

        A static function to sum to the neural network output.

    Returns
    -------
    dict with initial parameters and the function for the forward pass
    """
    #Initialize parameters with Glorot initialization
    initializer = jax.nn.initializers.glorot_normal()
    params = list()
    if static is None:
        static = lambda x: 0.0

    #Feature mapping: phi transforms the raw input before the dense layers,
    #and width is rewritten so width[0] matches phi's output dimension
    if ftype == 'ff': #Fourier features
        #Stack fargs[0] frequency blocks; block s has standard deviation
        #fargs[1] ** ((s + 1)/fargs[0]) so frequencies grow geometrically
        for s in range(fargs[0]):
            sd = fargs[1] ** ((s + 1)/fargs[0])
            if s == 0:
                Bff = sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2)))
            else:
                Bff = jnp.append(Bff,sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2))),1)
        @jax.jit
        def phi(x):
            x = x @ Bff
            return jnp.concatenate([jnp.sin(2 * jnp.pi * x),jnp.cos(2 * jnp.pi * x)],axis = -1)
        width = width[1:]
        width[0] = 2*Bff.shape[1]
    elif ftype == 'daff' or ftype == 'daff_bias':
        #Laplace eigenfunction features; a bare list means Dirichlet b.c.
        if not isinstance(fargs, dict):
            fargs = {'L': fargs,'bc': "dirichlet"}
        phi,lamb = multiple_daff(list(fargs.values())[0],kmax_per_axis = [width[1]] * width[0],bc = list(fargs.values())[1])
        width = width[1:]
        width[0] = lamb.shape[0]
    elif ftype == 'cheb' or ftype == 'cheb_bias':
        #Chebyshev features over len(fargs) domains with width[1] terms each
        phi = multiple_cheb(fargs,n = width[1])
        width = width[1:]
        width[0] = len(fargs)*width[0]
    else:
        #No feature transformation
        @jax.jit
        def phi(x):
            return x

    #Initialize parameters. For the modified MLP, an extra encoder layer with
    #two gates (U and V) is prepended to the parameter list
    if mlp:
        k = jax.random.split(jax.random.PRNGKey(key),4)
        WU = initializer(k[0],(width[0],width[1]),jnp.float32)
        BU = initializer(k[1],(1,width[1]),jnp.float32)
        WV = initializer(k[2],(width[0],width[1]),jnp.float32)
        BV = initializer(k[3],(1,width[1]),jnp.float32)
        params.append({'WU':WU,'BU':BU,'WV':WV,'BV':BV})
    key = jax.random.split(jax.random.PRNGKey(key + 1),len(width)-1) #Seed for initialization
    #NOTE(review): the loop variable reuses the name 'key' (shadowing the
    #split array) and the same key initializes both W and B of a layer —
    #presumably intentional; confirm
    for key,lin,lout in zip(key,width[:-1],width[1:]):
        W = initializer(key,(lin,lout),jnp.float32)
        B = initializer(key,(1,lout),jnp.float32)
        params.append({'W':W,'B':B})

    #Define function for forward pass. The bias-free variants are selected for
    #ftype 'daff'/'cheb' (note: 'daff_bias'/'cheb_bias' keep the biases)
    if mlp:
        if ftype != 'daff' and ftype != 'cheb':
            @jax.jit
            def forward(x,params):
                encode,*hidden,output = params
                sx = static(x)
                x = phi(x)
                #Modified MLP: each hidden activation is mixed with the two
                #encoder gates U and V
                U = activation(x @ encode['WU'] + encode['BU'])
                V = activation(x @ encode['WV'] + encode['BV'])
                for layer in hidden:
                    x = activation(x @ layer['W'] + layer['B'])
                    x = x * U + (1 - x) * V
                return x @ output['W'] + output['B'] + sx
        else:
            @jax.jit
            def forward(x,params):
                encode,*hidden,output = params
                sx = static(x)
                x = phi(x)
                U = activation(x @ encode['WU'])
                V = activation(x @ encode['WV'])
                for layer in hidden:
                    x = activation(x @ layer['W'])
                    x = x * U + (1 - x) * V
                return x @ output['W'] + sx
    else:
        if ftype != 'daff' and ftype != 'cheb':
            @jax.jit
            def forward(x,params):
                *hidden,output = params
                sx = static(x)
                x = phi(x)
                for layer in hidden:
                    x = activation(x @ layer['W'] + layer['B'])
                return x @ output['W'] + output['B'] + sx
        else:
            @jax.jit
            def forward(x,params):
                *hidden,output = params
                sx = static(x)
                x = phi(x)
                for layer in hidden:
                    x = activation(x @ layer['W'])
                return x @ output['W'] + sx

    #Return initial parameters and forward function
    return {'params': params,'forward': forward}
Initialize fully connected neural network
Parameters
- width (list): List with the layers width
- activation (jax.nn activation): The activation function. Default jax.nn.tanh
- key (int): Seed for parameters initialization. Default 0
- mlp (logical): Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension.
- ftype (str): Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.
- fargs (list): Arguments for feature transformation:
For 'ff': A list with the number of frequencies and the value of the greatest frequency standard deviation.
For 'daff' and 'daff_bias': A dictionary with a list with the sizes of the rectangles and the type of boundary condition. If it is a list, then the boundary condition is Dirichlet.
- static (function): A static function to sum to the neural network output.
Returns
- dict with initial parameters and the function for the forward pass
def get_activation(act):
    """
    Return activation function from string
    ----------
    Parameters
    ----------
    act : str

        Name of the activation function. Default 'tanh'

    Returns
    -------
    jax.nn activation function

    Raises
    ------
    ValueError
        If act is not the name of a supported activation function
        (previously an unknown name silently returned None)
    """
    #Supported names; each matches the attribute of the same name on jax.nn.
    #Fixes the previous 'jx.nn.sparse_plus' typo that raised NameError.
    valid = {
        'tanh','relu','relu6','sigmoid','softplus','sparse_plus','soft_sign',
        'silu','swish','log_sigmoid','leaky_relu','hard_sigmoid','hard_silu',
        'hard_swish','hard_tanh','elu','celu','selu','gelu','glu','squareplus','mish',
    }
    #Lazy attribute lookup so unsupported names never touch jax.nn
    if act in valid:
        return getattr(jax.nn, act)
    raise ValueError('Unknown activation function: ' + repr(act))
Return activation function from string
Parameters
- act (str): Name of the activation function. Default 'tanh'
Returns
- jax.nn activation function
def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,mlp = False):
    """
    Train a Physics-informed Neural Network
    ----------
    Parameters
    ----------
    data : dict

        Data generated by the jinnax.data.generate_PINNdata function

    width : list

        A list with the width of each layer

    pde : function

        The partial differential operator. Its arguments are u, x and t

    test_data : dict, None

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error

    epochs : int

        Number of training epochs. Default 100

    at_each : int

        Save results for epochs multiple of at_each. Default 10

    activation : str

        The name of the activation function of the neural network. Default 'tanh'

    neumann : logical

        Whether to consider Neumann boundary conditions

    oper_neumann : function

        Penalization of Neumann boundary conditions

    sa : logical

        Whether to consider self-adaptive PINN

    c : dict

        Dictionary with the hyperparameters of the self-adaptive mask for the initial (w0), sensor (ws), collocation (wr) and boundary (wb) points

    inverse : logical

        Whether to estimate parameters of the PDE

    initial_par : jax.numpy.array

        Initial value of the parameters of the PDE in an inverse problem

    lr,b1,b2,eps,eps_root : float

        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0

    key : int

        Seed for parameters initialization. Default 0

    epoch_print : int

        Number of epochs to calculate and print test errors. Default 100

    save : logical

        Whether to save the current parameters. Default False

    file_name : str

        File prefix to save the current parameters. Default 'result_pinn'

    exp_decay : logical

        Whether to consider exponential decay of learning rate. Default False

    transition_steps : int

        Number of steps for exponential decay. Default 1000

    decay_rate : float

        Rate of exponential decay. Default 0.9

    mlp : logical

        Whether to consider modified multi-layer perceptron

    Returns
    -------
    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
    """

    #Initialize architecture
    nnet = fconNN(width,get_activation(activation),key,mlp)
    forward = nnet['forward']

    #Initialize self-adaptive weights: one weight per data point, drawn uniformly
    #in [0,c[.]) so that training can rescale the contribution of each point
    par_sa = {}
    if sa:
        #Derive independent seeds for each weight group
        ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000)
        if data['sensor'] is not None:
            par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))})
        if data['initial'] is not None:
            par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))})
        if data['collocation'] is not None:
            par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))})
        if data['boundary'] is not None:
            #NOTE(review): boundary weights are scaled by c['wr'], not c['wb'] — looks like a
            #copy-paste slip (harmless with the default c where both are 1); confirm intent
            par_sa.update({'wb': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))})

    #Store all parameters in one pytree so Adam updates net, inverse and sa weights together
    params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa}

    #Save config file
    if save:
        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)

    #Define loss function (self-adaptive variant weights each residual by params['sa'])
    if sa:
        @jax.jit
        def lf(params,x):
            loss = 0
            if x['sensor'] is not None:
                #Term that refers to sensor data
                loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws']))
            if x['boundary'] is not None:
                if neumann:
                    #Neumann conditions: split boundary points into space and time coordinates
                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb']))
                else:
                    #Term that refers to boundary data
                    loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb']))
            if x['initial'] is not None:
                #Term that refers to initial data
                loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0']))
            if x['collocation'] is not None:
                #Term that refers to collocation points (PDE residual driven to zero)
                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
                if inverse:
                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr']))
                else:
                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr']))
            return loss
    else:
        @jax.jit
        def lf(params,x):
            loss = 0
            if x['sensor'] is not None:
                #Term that refers to sensor data
                loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
            if x['boundary'] is not None:
                if neumann:
                    #Neumann conditions: split boundary points into space and time coordinates
                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb))
                else:
                    #Term that refers to boundary data
                    loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary']))
            if x['initial'] is not None:
                #Term that refers to initial data
                loss = loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial']))
            if x['collocation'] is not None:
                #Term that refers to collocation points (PDE residual driven to zero)
                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
                if inverse:
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0))
                else:
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0))
            return loss

    #Initialize Adam optimizer
    if exp_decay:
        lr = optax.exponential_decay(lr,transition_steps,decay_rate)
    optimizer = optax.adam(lr,b1,b2,eps,eps_root)
    opt_state = optimizer.init(params)

    #Define the gradient function
    grad_loss = jax.jit(jax.grad(lf,0))

    #Define update function
    @jax.jit
    def update(opt_state,params,x):
        #Compute gradient
        grads = grad_loss(params,x)
        #Invert gradient of self-adaptive weights: they are trained by gradient ASCENT
        if sa:
            for w in grads['sa']:
                grads['sa'][w] = - grads['sa'][w]
        #Calculate parameters updates
        updates, opt_state = optimizer.update(grads, opt_state)
        #Update parameters
        params = optax.apply_updates(params, updates)
        #Return state of optimizer and updated parameters
        return opt_state,params

    ###Training###
    t0 = time.time()
    #Initialize alive_bar for tracing in terminal
    with alive_bar(epochs) as bar:
        #For each epoch
        for e in range(epochs):
            #Update optimizer state and parameters
            opt_state,params = update(opt_state,params,data)
            #After epoch_print epochs
            if e % epoch_print == 0:
                #Compute elapsed time and current error
                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6))
                #If there is test data, compute current L2 error
                if test_data is not None:
                    #Compute L2 error
                    l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist()
                    l = l + ' L2 error: ' + str(jnp.round(l2_test,3))
                if inverse:
                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
                #Print
                print(l)
            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
                #Save current parameters
                pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
            #Update alive_bar
            bar()
    #Define estimated function (closes over the trained parameters)
    def u(xt):
        return forward(xt,params['net'])

    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0}
Train a Physics-informed Neural Network
Parameters
- data (dict): Data generated by the jinnax.data.generate_PINNdata function
- width (list): A list with the width of each layer
- pde (function): The partial differential operator. Its arguments are u, x and t
- test_data (dict, None): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
- epochs (int): Number of training epochs. Default 100
- at_each (int): Save results for epochs multiple of at_each. Default 10
- activation (str): The name of the activation function of the neural network. Default 'tanh'
- neumann (logical): Whether to consider Neumann boundary conditions
- oper_neumann (function): Penalization of Neumann boundary conditions
- sa (logical): Whether to consider self-adaptative PINN
- c (dict): Dictionary with the hyperparameters of the self-adaptative sigmoid mask for the initial (w0), sensor (ws) and collocation (wr) points. The weights of the boundary points are fixed to 1
- inverse (logical): Whether to estimate parameters of the PDE
- initial_par (jax.numpy.array): Initial value of the parameters of the PDE in an inverse problem
- lr,b1,b2,eps,eps_root (float): Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
- key (int): Seed for parameters initialization. Default 0
- epoch_print (int): Number of epochs to calculate and print test errors. Default 100
- save (logical): Whether to save the current parameters. Default False
- file_name (str): File prefix to save the current parameters. Default 'result_pinn'
- exp_decay (logical): Whether to consider exponential decay of learning rate. Default False
- transition_steps (int): Number of steps for exponential decay. Default 1000
- decay_rate (float): Rate of exponential decay. Default 0.9
- mlp (logical): Whether to consider a modified multi-layer perceptron
Returns
- dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
def train_Matern_PINN(data,width,pde,test_data = None,params = None,d = 2,N = 128,L = 1,alpha = 1,kappa = 1,sigma = 100,bsize = 1024,resample = False,epochs = 100,at_each = 10,activation = 'tanh',
                      neumann = False,oper_neumann = None,inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 1,save = False,file_name = 'result_pinn',
                      exp_decay = True,transition_steps = 100,decay_rate = 0.9,mlp = True,ftype = None,fargs = None,q = 4,w = None,periodic = False,static = None,opt = 'LBFGS'):
    """
    Train a Physics-informed Neural Network with a weak (Matern test function) residual
    ----------
    Parameters
    ----------
    data : dict

        Data generated by the jinnax.data.generate_PINNdata function

    width : list

        A list with the width of each layer

    pde : function

        The partial differential operator. Its arguments are u, x and t

    test_data : dict, None

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error

    params : list

        Initial parameters for the neural network. Default None to initialize randomly

    d : int

        Dimension of the problem including the time variable if present. Default 2

    N : int

        Size of grid in each dimension. Default 128

    L : list of float

        The domain of the function in each coordinate is [0,L[i]]. If a float, repeat the same interval for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process. sigma = 0 disables the weak residual

    bsize : int

        Batch size for weak norm computation. Default 1024

    resample : logical

        Whether to resample the test functions at each epoch

    epochs : int

        Number of training epochs. Default 100

    at_each : int

        Save results for epochs multiple of at_each. Default 10

    activation : str

        The name of the activation function of the neural network. Default 'tanh'

    neumann : logical

        Whether to consider Neumann boundary conditions

    oper_neumann : function

        Penalization of Neumann boundary conditions

    inverse : logical

        Whether to estimate parameters of the PDE

    initial_par : jax.numpy.array

        Initial value of the parameters of the PDE in an inverse problem

    lr,b1,b2,eps,eps_root : float

        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0

    key : int

        Seed for parameters initialization. Default 0

    epoch_print : int

        Number of epochs to calculate and print test errors. Default 1

    save : logical

        Whether to save the current parameters. Default False

    file_name : str

        File prefix to save the current parameters. Default 'result_pinn'

    exp_decay : logical

        Whether to consider exponential decay of learning rate. Default True

    transition_steps : int

        Number of steps for exponential decay. Default 100

    decay_rate : float

        Rate of exponential decay. Default 0.9

    mlp : logical

        Whether to consider modified multilayer perceptron

    ftype : str

        Type of feature transformation to use: None, 'ff', 'daff', 'daff_bias', 'cheb', 'cheb_bias'.

    fargs : list

        Arguments for feature transformation:

        For 'ff': A list with the number of frequencies and the value of the greatest frequency standard deviation.

        For 'daff' and 'daff_bias': A dictionary with a list with the sizes of the rectangles and the type of boundary condition. If it is a list, then the boundary condition is Dirichlet.

    q : int

        Power of weights mask. q = 0 disables the self-adaptive weights. Default 4

    w : dict

        Initial weights for the self-adaptive scheme.

    periodic : logical

        Whether to consider periodic test functions. Default False.

    static : function

        A static function to sum to the neural network output.

    opt : str

        Optimizer. Default LBFGS.

    Returns
    -------
    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch
    """
    #Initialize architecture
    nnet = fconNN(width,get_activation(activation),key,mlp,ftype,fargs,static)
    forward = nnet['forward']
    if params is not None:
        nnet['params'] = params

    #Generate test functions from a Matern process for the weak residual
    if sigma > 0:
        if isinstance(L,float) or isinstance(L,int):
            L = d*[L]
        #Grid for weak norm
        grid = [jnp.linspace(0,L[i],N) for i in range(d)]
        grid = jnp.meshgrid(*grid, indexing='ij')
        grid = jnp.stack(grid, axis=-1).reshape((-1, d))
        #Set sigma so the weak residual and the boundary loss start at comparable scale
        if data['boundary'] is not None:
            #generate_matern_sample_batch is defined elsewhere in this module
            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma)
            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
            if neumann:
                #NOTE(review): 'params' here is still the raw function argument (a list or
                #None), so params['net'] looks wrong — probably nnet['params'] was intended,
                #as in the else branch below; confirm
                loss_boundary = oper_neumann(lambda x: forward(x,params['net']),data['boundary'])
            else:
                loss_boundary = jnp.mean(MSE(forward(data['boundary'],nnet['params']),data['uboundary']))
            output_w = pde(lambda x: forward(x,nnet['params']),grid)
            integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(tf)
            loss_res_weak = jnp.mean(integralOmega ** 2)
            sigma = float(jnp.sqrt(loss_boundary/loss_res_weak).tolist())
            del gen
            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
            #Rescale the already-drawn test functions by the calibrated sigma
            tf = sigma*tf
        else:
            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])

    #Define loss function returning each term separately
    @jax.jit
    def lf_each(params,x,k):
        if sigma > 0:
            #Term that refers to weak loss
            if resample:
                test_functions = gen(jax.random.split(jax.random.PRNGKey(k[0]),(bsize,))[:,0])
            else:
                test_functions = tf
        loss_sensor = loss_boundary = loss_initial = loss_res = loss_res_weak = 0
        if x['sensor'] is not None:
            #Term that refers to sensor data
            loss_sensor = jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
        if x['boundary'] is not None:
            if neumann:
                #Neumann conditions
                loss_boundary = oper_neumann(lambda x: forward(x,params['net']),x['boundary'])
            else:
                #Term that refers to boundary data (per-point, reduced later in lf)
                loss_boundary = MSE(forward(x['boundary'],params['net']),x['uboundary'])
        if x['initial'] is not None:
            #Term that refers to initial data (per-point, reduced later in lf)
            loss_initial = MSE(forward(x['initial'],params['net']),x['uinitial'])
        if x['collocation'] is not None and sigma == 0:
            #Strong-form residual, only used when the weak residual is disabled
            if inverse:
                output = pde(lambda x: forward(x,params['net']),x['collocation'],params['inverse'])
                loss_res = MSE(output,0)
            else:
                output = pde(lambda x: forward(x,params['net']),x['collocation'])
                loss_res = MSE(output,0)
        if sigma > 0:
            #Term that refers to weak loss: residual integrated against each test function
            if inverse:
                output_w = pde(lambda x: forward(x,params['net']),grid,params['inverse'])
                integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
                loss_res_weak = jnp.mean(integralOmega ** 2)
            else:
                output_w = pde(lambda x: forward(x,params['net']),grid)
                integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
                loss_res_weak = jnp.mean(integralOmega ** 2)
        return {'ls': loss_sensor,'lb': loss_boundary,'li': loss_initial,'lc': loss_res,'lc_weak': loss_res_weak}

    #Total loss: each term weighted by its self-adaptive weight raised to q
    @jax.jit
    def lf(params,x,k):
        l = lf_each(params,x,k)
        w = params['w']
        loss = jnp.mean((w['ws'] ** q)*l['ls']) + jnp.mean((w['wb'] ** q)*l['lb']) + jnp.mean((w['wi'] ** q)*l['li']) + jnp.mean((w['wc'] ** q)*l['lc']) + (w['wc_weak'] ** q)*l['lc_weak']
        if opt != 'LBFGS':
            return loss
        else:
            #LBFGS needs (value, aux) because has_aux = True below
            l2 = None
            if test_data is not None:
                l2 = L2error(forward(test_data['sensor'],params['net']),test_data['usensor'])
            return loss,{'loss': loss,'l2': l2}

    #Initialize self-adaptive weights (small noise breaks the symmetry between points)
    if w is None:
        w = {'ws': jnp.array(1.0),'wb': jnp.array(1.0),'wi': jnp.array(1.0),'wc': jnp.array(1.0),'wc_weak': jnp.array(1.0)}
        if q != 0:
            if data['sensor'] is not None:
                w['ws'] = w['ws'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+1),(data['sensor'].shape[0],1))
            if data['boundary'] is not None:
                w['wb'] = w['wb'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+2),(data['boundary'].shape[0],1))
            if data['initial'] is not None:
                w['wi'] = w['wi'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+3),(data['initial'].shape[0],1))
            if data['collocation'] is not None:
                w['wc'] = w['wc'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+4),(data['collocation'].shape[0],1))

    #Store all parameters
    params = {'net': nnet['params'],'inverse': initial_par,'w': w}

    #Save config file
    if save:
        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)

    #Initialize optimizer: Adam with optional exponential decay, or LBFGS
    if opt != 'LBFGS':
        print('--------- GRADIENT DESCENT OPTIMIZER ---------')
        if exp_decay:
            lr = optax.exponential_decay(lr,transition_steps,decay_rate)
        optimizer = optax.adam(lr,b1,b2,eps,eps_root)
        opt_state = optimizer.init(params)

        #Define the gradient function
        grad_loss = jax.jit(jax.grad(lf,0))

        #Define update function
        @jax.jit
        def update(opt_state,params,x,k):
            #Compute gradient
            grads = grad_loss(params,x,k)
            #Calculate parameters updates
            updates, opt_state = optimizer.update(grads, opt_state)
            #Update parameters
            if q != 0:
                updates = {**updates, 'w': jax.tree_util.tree_map(lambda x: -x, updates['w'])} #Change signs of weights: they are trained by gradient ascent
            params = optax.apply_updates(params, updates)
            #Return state of optimizer and updated parameters
            return opt_state,params
    else:
        print('--------- LBFGS OPTIMIZER ---------')
        @jax.jit
        def loss_LBFGS(params):
            #NOTE(review): passes the scalar key + 234 where lf_each may index k[0] when
            #resample is True — only safe with resample = False under LBFGS; confirm
            return lf(params,data,key + 234)
        solver = LBFGS(fun = loss_LBFGS,has_aux = True,maxiter = epochs,tol = 1e-9,verbose = False,linesearch = 'zoom',history_size = 100) # linesearch='zoom' by default
        state = solver.init_state(params)

    ###Training###
    t0 = time.time()
    k = jax.random.split(jax.random.PRNGKey(key+234),(epochs,))
    sloss = []
    sL2 = []
    stime = []
    #Initialize alive_bar for tracing in terminal
    with alive_bar(epochs) as bar:
        #For each epoch
        for e in range(epochs):
            if opt != 'LBFGS':
                #Update optimizer state and parameters
                opt_state,params = update(opt_state,params,data,k[e,:])
                sloss.append(lf(params,data,k[e,:]))
                if test_data is not None:
                    sL2.append(L2error(forward(test_data['sensor'],params['net']),test_data['usensor']))
            else:
                #One LBFGS step; loss and L2 come back through the aux output of lf
                params, state = solver.update(params, state)
                sL2.append(state.aux["l2"])
                sloss.append(state.aux["loss"])
            stime.append(time.time() - t0)
            #After epoch_print epochs
            if e % epoch_print == 0:
                #Compute elapsed time and current error
                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(sloss[-1],6))
                #If there is test data, compute current L2 error
                if test_data is not None:
                    #Compute L2 error
                    l = l + ' L2 error: ' + str(jnp.round(sL2[-1],6))
                if inverse:
                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
                #Print
                print(l)
            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
                #Save current parameters
                pickle.dump({'params': params,'width': width,'time': stime,'loss': sloss,'L2error': sL2},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
            #Update alive_bar
            bar()
    #Define estimated function (closes over the trained parameters)
    def u(xt):
        return forward(xt,params['net'])

    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0,'loss_each': lf_each(params,data,[key + 100]),'loss': sloss,'L2error': sL2}
Train a Physics-informed Neural Network
Parameters
- data (dict): Data generated by the jinnax.data.generate_PINNdata function
- width (list): A list with the width of each layer
- pde (function): The partial differential operator. Its arguments are u, x and t
- test_data (dict, None): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
- params (list): Initial parameters for the neural network. Default None to initialize randomly
- d (int): Dimension of the problem including the time variable if present. Default 2
- N (int): Size of grid in each dimension. Default 128
- L (list of float): The domain of the function in each coordinate is [0,L[i]]. If a float, repeat the same interval for all coordinates. Default 1
- kappa,alpha,sigma (float): Parameters of the Matern process
- bsize (int): Batch size for weak norm computation. Default 1024
- resample (logical): Whether to resample the test functions at each epoch
- epochs (int): Number of training epochs. Default 100
- at_each (int): Save results for epochs multiple of at_each. Default 10
- activation (str): The name of the activation function of the neural network. Default 'tanh'
- neumann (logical): Whether to consider Neumann boundary conditions
- oper_neumann (function): Penalization of Neumann boundary conditions
- inverse (logical): Whether to estimate parameters of the PDE
- initial_par (jax.numpy.array): Initial value of the parameters of the PDE in an inverse problem
- lr,b1,b2,eps,eps_root (float): Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
- key (int): Seed for parameters initialization. Default 0
- epoch_print (int): Number of epochs to calculate and print test errors. Default 1
- save (logical): Whether to save the current parameters. Default False
- file_name (str): File prefix to save the current parameters. Default 'result_pinn'
- exp_decay (logical): Whether to consider exponential decay of learning rate. Default True
- transition_steps (int): Number of steps for exponential decay. Default 100
- decay_rate (float): Rate of exponential decay. Default 0.9
- mlp (logical): Whether to consider a modified multilayer perceptron
- ftype (str): Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.
- fargs (list): Arguments for feature transformation:
For 'ff': A list with the number of frequencies and the value of the greatest frequency standard deviation.
For 'daff' and 'daff_bias': A dictionary with a list with the sizes of the rectangles and the type of boundary condition. If it is a list, then the boundary condition is Dirichlet.
- q (int): Power of weights mask. Default 4
- w (dict): Initial weights for the self-adaptive scheme.
- periodic (logical): Whether to consider periodic test functions. Default False.
- static (function): A static function to sum to the neural network output.
- opt (str): Optimizer. Default LBFGS.
Returns
- dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch
def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1):
    """
    Process the results of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    fit : function

        The fitted function

    train_data : dict

        Training data generated by the jinnax.data.generate_PINNdata

    plot : logical

        Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True

    plot_test : logical

        Whether to plot the test data. Default True

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to plot 2D plot when the spatial dimension is one. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    print_res : logical

        Whether to print the L2 error. Default True

    p : int

        Output dimension. Default 1

    Returns
    -------
    pandas data frame with L2 and MSE errors
    """

    #Spatial dimension (the last column of xt is time)
    dim = test_data['xt'].shape[1] - 1

    #Round the number of plotted time slices to the nearest multiple of 5
    times = 5 * round(times/5.0)

    #Predictions on the train and test points
    td = get_train_data(train_data)
    pred_train = fit(td['x'])
    pred_test = fit(test_data['xt'])

    #Error metrics on both splits
    l2_test = L2error(pred_test,test_data['u']).tolist()
    mse_test = jnp.mean(MSE(pred_test,test_data['u'])).tolist()
    l2_train = L2error(pred_train,td['y']).tolist()
    mse_train = jnp.mean(MSE(pred_train,td['y'])).tolist()

    #Assemble the one-row results table
    df = pd.DataFrame(np.array([l2_test,mse_test,l2_train,mse_train]).reshape((1,4)),
                      columns=['l2_error_test','MSE_test','l2_error_train','MSE_train'])
    if print_res:
        print('L2 error test: ' + str(jnp.round(l2_test,6)) + ' L2 error train: ' + str(jnp.round(l2_train,6)) + ' MSE error test: ' + str(jnp.round(mse_test,6)) + ' MSE error train: ' + str(jnp.round(mse_train,6)) )

    #Plots: 1D solution plot for scalar output, 2D output plot for p = 2
    if plot and dim == 1 and p == 1:
        plot_pinn1D(times,test_data['xt'],test_data['u'],pred_test,d2,save,show,file_name)
    elif plot and p == 2:
        plot_pinn_out2D(times,test_data['xt'],test_data['u'],pred_test,save,show,file_name,plot_test)

    return df
Process the results of a Physics-informed Neural Network
Parameters
- test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
- fit (function): The fitted function
- train_data (dict): Training data generated by the jinnax.data.generate_PINNdata
- plot (logical): Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True
- plot_test (logical): Whether to plot the test data. Default True
- times (int): Number of points along the time interval to plot. Default 5
- d2 (logical): Whether to plot 2D plot when the spatial dimension is one. Default True
- save (logical): Whether to save the plots. Default False
- show (logical): Whether to show the plots. Default True
- file_name (str): File prefix to save the plots. Default 'result_pinn'
- print_res (logical): Whether to print the L2 error. Default True
- p (int): Output dimension. Default 1
Returns
- pandas data frame with L2 and MSE errors
def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
    """
    Plot the prediction of a 1D PINN
    ----------

    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    d2 : logical

        Whether to plot 2D plot. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title_1d : str

        Title of 1D plot

    title_2d : str

        Title of 2D plot

    Returns
    -------
    None
    """
    #Initialize a times/5 x 5 grid of axes and common y-limits with a 10% margin around the data range
    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    ylo = jnp.min(u)
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u)
    yup = yup + 0.1*jnp.abs(yup)
    k = 0
    t_values = np.linspace(tlo,tup,times)

    #One subplot per time slice: exact solution versus prediction
    for i in range(int(times/5)):
        for j in range(5):
            if k < len(t_values):
                t = t_values[k]
                #Snap t to the closest time value actually present in xt
                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
                x_plot = xt[xt[:,-1] == t,:-1]
                y_plot = upred[xt[:,-1] == t,:]
                u_plot = u[xt[:,-1] == t,:]
                #plt.subplots returns a 1D axes array when there is a single row, so indexing differs
                if int(times/5) > 1:
                    ax[i,j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
                    ax[i,j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[i,j].set_xlabel(' ')
                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                else:
                    ax[j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
                    ax[j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                k = k + 1

    #Title
    fig.suptitle(title_1d)
    fig.tight_layout()

    #Show and save
    fig = plt.gcf()
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close()

    #2d plot: space-time heatmaps of the exact and predicted solutions
    if d2:
        #Initialize
        fig, ax = plt.subplots(1,2)
        l1 = jnp.unique(xt[:,-1]).shape[0]
        l2 = jnp.unique(xt[:,0]).shape[0]

        #NOTE(review): the (l2,l1) reshape assumes xt enumerates a full tensor-product grid ordered
        #space-major — confirm against jinnax.data.generate_PINNdata before reusing on other grids
        ax[0].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[0].set_title('Exact')
        ax[1].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[1].set_title('Predicted')

        #Title
        fig.suptitle(title_2d)
        fig.tight_layout()

        #Show and save
        fig = plt.gcf()
        if show:
            plt.show()
        if save:
            fig.savefig(file_name + '_2d.png')
        plt.close()
Plot the prediction of a 1D PINN
Parameters
- times (int): Number of points along the time interval to plot. Default 5
- xt (jax.numpy.array): Test data xt array
- u (jax.numpy.array): Test data u(x,t) array
- upred (jax.numpy.array): Predicted upred(x,t) array on test data
- d2 (logical): Whether to plot 2D plot. Default True
- save (logical): Whether to save the plots. Default False
- show (logical): Whether to show the plots. Default True
- file_name (str): File prefix to save the plots. Default 'result_pinn'
- title_1d (str): Title of 1D plot
- title_2d (str): Title of 2D plot
Returns
- None
def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
    """
    Plot the prediction of a PINN with 2D output
    ----------
    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array (two columns, one per output coordinate)

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title : str

        Title of plot

    plot_test : logical

        Whether to plot the test data. Default True

    Returns
    -------
    None
    """
    #Initialize a times/5 x 5 grid of axes and axis limits with a 10% margin around the data range
    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    #NOTE(review): xlo/xup are computed but never applied (set_xlim is never called) — possibly
    #intended to bound the horizontal axis like ylo/yup bound the vertical one
    xlo = jnp.min(u[:,0])
    xlo = xlo - 0.1*jnp.abs(xlo)
    xup = jnp.max(u[:,0])
    xup = xup + 0.1*jnp.abs(xup)
    ylo = jnp.min(u[:,1])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u[:,1])
    yup = yup + 0.1*jnp.abs(yup)
    k = 0
    t_values = np.linspace(tlo,tup,times)

    #One subplot per time slice: curve in the 2D output plane, exact versus prediction
    for i in range(int(times/5)):
        for j in range(5):
            if k < len(t_values):
                t = t_values[k]
                #Snap t to the closest time value actually present in xt
                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
                xpred_plot = upred[xt[:,-1] == t,0]
                ypred_plot = upred[xt[:,-1] == t,1]
                if plot_test:
                    x_plot = u[xt[:,-1] == t,0]
                    y_plot = u[xt[:,-1] == t,1]
                #plt.subplots returns a 1D axes array when there is a single row, so indexing differs
                if int(times/5) > 1:
                    if plot_test:
                        ax[i,j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
                    ax[i,j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[i,j].set_xlabel(' ')
                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                else:
                    if plot_test:
                        ax[j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
                    #Bug fix: the original referenced the undefined name 'ypred' here, raising a
                    #NameError whenever times <= 5; the prediction y-coordinates are ypred_plot
                    ax[j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                k = k + 1

    #Title
    fig.suptitle(title)
    fig.tight_layout()

    #Show and save
    fig = plt.gcf()
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close()
Plot the prediction of a PINN with 2D output
Parameters
- times (int): Number of points along the time interval to plot. Default 5
- xt (jax.numpy.array): Test data xt array
- u (jax.numpy.array): Test data u(x,t) array
- upred (jax.numpy.array): Predicted upred(x,t) array on test data
- save (logical): Whether to save the plots. Default False
- show (logical): Whether to show the plots. Default True
- file_name (str): File prefix to save the plots. Default 'result_pinn'
- title (str): Title of plot
- plot_test (logical): Whether to plot the test data. Default True
Returns
- None
def get_train_data(train_data):
    """
    Process training sample
    ----------

    Parameters
    ----------
    train_data : dict

        A dictionary with train data generated by the jinnax.data.generate_PINNdata function

    Returns
    -------
    dict with the processed training data: stacked inputs 'x', targets 'y', the
    column-stacked pairs 'xy' (all None when no supervised points exist), and the
    sample counts of each point category
    """
    counts = {}
    xdata = None
    ydata = None
    xydata = None

    #Stack the supervised point categories (sensor, boundary, initial) in this fixed order
    for x_key, u_key, count_key in (('sensor','usensor','sensor_sample'),
                                    ('boundary','uboundary','boundary_sample'),
                                    ('initial','uinitial','initial_sample')):
        block = train_data[x_key]
        if block is None:
            counts[count_key] = 0
            continue
        counts[count_key] = block.shape[0]
        ublock = train_data[u_key]
        pair = jnp.column_stack((block,ublock))
        if xdata is None:
            #First non-empty category starts the accumulators
            xdata, ydata, xydata = block, ublock, pair
        else:
            xdata = jnp.vstack((xdata,block))
            ydata = jnp.vstack((ydata,ublock))
            xydata = jnp.vstack((xydata,pair))

    #Collocation points are unsupervised: only their count is reported
    coll = train_data['collocation']
    counts['collocation_sample'] = 0 if coll is None else coll.shape[0]

    return {'xy': xydata,'x': xdata,'y': ydata,
            'sensor_sample': counts['sensor_sample'],
            'boundary_sample': counts['boundary_sample'],
            'initial_sample': counts['initial_sample'],
            'collocation_sample': counts['collocation_sample']}
Process training sample
Parameters
- train_data (dict): A dictionary with train data generated by the jinnax.data.generate_PINNdata function
Returns
- dict with the processed training data
def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
    """
    Process the training of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    bolstering : logical

        Whether to compute bolstering mean square error. Default True

    mc_sample : int

        Number of samples for Monte Carlo integration in bolstering. Default 10000

    save : logical

        Whether to save the training results. Default False

    file_name_save : str

        File prefix to save the plots and the L2 error. Default 'result_pinn'

    key : int

        Key for random samples in bolstering. Default 0

    ec : float

        Stopping criteria error for EM algorithm in bolstering. Default 1e-6

    lamb : float

        Hyperparameter of EM algorithm in bolstering. Default 1

    Returns
    -------
    pandas data frame with training results
    """
    #Config: epochs, training data and forward pass are read back from the training-time pickle
    config = pickle.load(open(file_name + '_config.pickle', 'rb'))
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xydata = td['xy']
    xdata = td['x']
    ydata = td['y']
    sensor_sample = td['sensor_sample']
    boundary_sample = td['boundary_sample']
    initial_sample = td['initial_sample']
    collocation_sample = td['collocation_sample']

    #Generate one PRNG key per epoch for the bolstering Monte Carlo draws
    if bolstering:
        keys = jax.random.split(jax.random.PRNGKey(key),epochs)

    #Initialize result accumulators
    #NOTE(review): the local list 'time' shadows the 'time' module imported at file level —
    #harmless inside this function, but worth renaming
    train_mse = []
    test_mse = []
    train_L2 = []
    test_L2 = []
    bolstX = []
    bolstXY = []
    loss = []
    time = []
    ep = []

    #Process training: only epochs that are multiples of at_each (plus the last one) are evaluated
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
                ep = ep + [e]

                #Read the per-epoch parameter checkpoint (zero-padded epoch number in the file name)
                params = pickle.load(open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb'))

                #Time
                time = time + [params['time']]

                #Define the learned function at this checkpoint
                def psi(x):
                    return forward(x,params['params']['net'])

                #Train MSE and L2 (None when there are no supervised training points)
                if xdata is not None:
                    train_mse = train_mse + [jnp.mean(MSE(psi(xdata),ydata)).tolist()]
                    train_L2 = train_L2 + [L2error(psi(xdata),ydata).tolist()]
                else:
                    train_mse = train_mse + [None]
                    train_L2 = train_L2 + [None]

                #Test MSE and L2
                test_mse = test_mse + [jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist()]
                test_L2 = test_L2 + [L2error(psi(test_data['xt']),test_data['u']).tolist()]

                #Bolstering: kernel estimators on x-only and (x,y) data for three methods,
                #then hessian-based estimators over a ladder of bias values
                if bolstering:
                    bX = []
                    bXY = []
                    for method in ['chi','mm','mpe']:
                        kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
                        bXY = bXY + [gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist()]
                    for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]:
                        #NOTE(review): this estimator is stored in 'kx' but is fitted on xydata,
                        #unlike the x-only 'kx' above — confirm whether xdata was intended here
                        kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
                    bolstX = bolstX + [bX]
                    bolstXY = bolstXY + [bXY]
                else:
                    bolstX = bolstX + [None]
                    bolstXY = bolstXY + [None]

                #Loss
                loss = loss + [params['loss'].tolist()]

                #Delete checkpoint objects before the next iteration
                del params, psi
            #Update alive_bar
            bar()

    #Bolstering results: bolstX has 8 columns (3 methods + 5 hessian biases), bolstXY has 3
    if bolstering:
        bolstX = jnp.array(bolstX)
        bolstXY = jnp.array(bolstXY)

    #Create data frame
    if bolstering:
        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
                                           train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]),
                          columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4'])
    else:
        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
                                           train_mse,test_mse,train_L2,test_L2]),
                          columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2'])
    if save:
        df.to_csv(file_name_save + '.csv',index = False)

    return df
Process the training of a Physics-informed Neural Network
Parameters
- test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
- file_name (str): Name of the files saved during training
- at_each (int): Compute results for epochs multiple of at_each. Default 100
- bolstering (logical): Whether to compute bolstering mean square error. Default True
- mc_sample (int): Number of samples for Monte Carlo integration in bolstering. Default 10000
- save (logical): Whether to save the training results. Default False
- file_name_save (str): File prefix to save the plots and the L2 error. Default 'result_pinn'
- key (int): Key for random samples in bolstering. Default 0
- ec (float): Stopping criteria error for EM algorithm in bolstering. Default 1e-6
- lamb (float): Hyperparameter of EM algorithm in bolstering. Default 1
Returns
- pandas data frame with training results
def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
    """
    Demo video with the training of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to make video demo of 2D plot. Default True

    file_name_save : str

        File prefix to save the plots and videos. Default 'result_pinn_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Config: epochs, training data and forward pass are read back from the training-time pickle
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xt = td['x']
    u = td['y']

    #Create folder to save plots
    #NOTE(review): shelling out to 'mkdir' is non-portable and unsafe if file_name_save contains
    #shell metacharacters — os.makedirs(file_name_save, exist_ok=True) would be safer
    os.system('mkdir ' + file_name_save)

    #Create one image per sampled epoch; k numbers the frames consumed by ffmpeg below
    k = 1
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if e % at_each == 0 or e == epochs - 1:
                #Read the per-epoch parameter checkpoint (zero-padded epoch number in the file name)
                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')

                #Define the learned function at this checkpoint
                def psi(x):
                    return forward(x,params['params']['net'])

                #Compute L2 train, L2 test and loss
                #NOTE(review): 'loss' is loaded but never used before being deleted
                loss = params['loss']
                L2_train = L2error(psi(xt),u)
                L2_test = L2error(psi(test_data['xt']),test_data['u'])
                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(L2_train,6)) + ' L2 test = ' + str(round(L2_test,6))

                #Save plot for this frame (plot_pinn1D writes '<k>_slices.png' and optionally '<k>_2d.png')
                plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
                k = k + 1

                #Delete per-frame objects before the next iteration
                del params, psi, loss, L2_train, L2_test, title_epoch
            #Update alive_bar
            bar()
    #Create demo video from the numbered frames
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4')
    if d2:
        os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4')
Demo video with the training of a 1D PINN
Parameters
- test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
- file_name (str): Name of the files saved during training
- at_each (int): Compute results for epochs multiple of at_each. Default 100
- times (int): Number of points along the time interval to plot. Default 5
- d2 (logical): Whether to make video demo of 2D plot. Default True
- file_name_save (str): File prefix to save the plots and videos. Default 'result_pinn_demo'
- title (str): Title for plots
- framerate (int): Framerate for video. Default 10
Returns
- None
def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
    """
    Demo video with the time evolution of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Config: training data and forward pass are read back from the training-time pickle
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    train_data = config['train_data']
    forward = config['forward']

    #Create folder to save plots
    #NOTE(review): shelling out to 'mkdir' is non-portable and unsafe with untrusted names —
    #os.makedirs(file_name_save, exist_ok=True) would be safer
    os.system('mkdir ' + file_name_save)

    #Plot parameters: unique time values and y-limits with a 10% margin around the data range
    tdom = jnp.unique(test_data['xt'][:,-1])
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)

    #Load each requested checkpoint and precompute its prediction on the full test grid
    results = []
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        results = results + [tmp]
        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]

    #Create one image per time value; k numbers the frames consumed by ffmpeg below
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data restricted to the current time slice
            xt_step = test_data['xt'][test_data['xt'][:,-1] == t]
            u_step = test_data['u'][test_data['xt'][:,-1] == t]
            #Initialize plot: a len(epochs)/2 x 2 grid, or a single axis for one epoch
            #NOTE(review): for an odd len(epochs) > 2, int(len(epochs)/2) rows leave the last
            #epoch(s) unplotted — epochs appears to be assumed of even length; confirm
            if len(epochs) > 1:
                fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2))
            else:
                fig, ax = plt.subplots(1,1,figsize = (10,5))
            #Create: plt.subplots returns 2D, 1D or scalar axes depending on the grid shape,
            #so each case indexes differently
            index = 0
            if int(len(epochs)/2) > 1:
                for i in range(int(len(epochs)/2)):
                    for j in range(min(2,len(epochs))):
                        upred_step = upred[index][test_data['xt'][:,-1] == t]
                        ax[i,j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                        ax[i,j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                        ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                        ax[i,j].set_xlabel(' ')
                        ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                        index = index + 1
            elif len(epochs) > 1:
                for j in range(2):
                    upred_step = upred[index][test_data['xt'][:,-1] == t]
                    ax[j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                    ax[j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                    ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                    index = index + 1
            else:
                upred_step = upred[index][test_data['xt'][:,-1] == t]
                ax.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                ax.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                ax.set_xlabel(' ')
                ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                index = index + 1


            #Title
            fig.suptitle(title + 't = ' + str(round(t,4)))
            fig.tight_layout()

            #Show and save
            fig = plt.gcf()
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close()
            bar()

    #Create demo video from the numbered frames
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')
Demo video with the time evolution of a 1D PINN
Parameters
- test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
- file_name (str): Name of the files saved during training
- epochs (list): Which training epochs to plot
- file_name_save (str): File prefix to save the plots and video. Default 'result_pinn_time_demo'
- title (str): Title for plots
- framerate (int): Framerate for video. Default 10
Returns
- None