jinnax.nn
#Functions to train NN
import jax
import jax.numpy as jnp
import optax
from alive_progress import alive_bar
import math
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import dill as pickle
from genree import bolstering as gb
from genree import kernel as gk
from jax import random
from jinnax import data as jd
import os
import numpy as np
from scipy.fft import dst, idst
from itertools import product
from functools import partial
import orthax
from jax import lax
from jaxopt import LBFGS
from jax.tree_util import tree_flatten

__docformat__ = "numpy"


def _non_float64_leaves(tree, floating_only=False):
    """
    List the leaves of a pytree whose dtype is not float64
    ----------
    Parameters
    ----------
    tree : pytree

        Any JAX pytree (dict, list, tuple, optimizer state, ...)

    floating_only : logical

        If False, every leaf exposing a ``dtype`` attribute is checked.
        If True, only JAX arrays with a floating-point dtype are checked,
        so integer/boolean leaves (e.g. iteration counters) are ignored

    Returns
    -------
    list of (leaf index, dtype) tuples for the offending leaves
    """
    leaves, _ = tree_flatten(tree)
    if floating_only:
        return [
            (i, x.dtype)
            for i, x in enumerate(leaves)
            if isinstance(x, jnp.ndarray)
            and jnp.issubdtype(x.dtype, jnp.floating)
            and x.dtype != jnp.float64
        ]
    return [
        (i, x.dtype)
        for i, x in enumerate(leaves)
        if hasattr(x, "dtype") and x.dtype != jnp.float64
    ]


def assert_tree_float64(tree, name="object"):
    """
    Raise an error if any leaf of a pytree is not float64
    ----------
    Parameters
    ----------
    tree : pytree

        Pytree to check. Leaves without a ``dtype`` attribute are skipped

    name : str

        Label used in the error message. Default "object"

    Raises
    ------
    RuntimeError if any leaf with a dtype is not float64
    """
    bad = _non_float64_leaves(tree)
    if bad:
        raise RuntimeError(
            f"[FLOAT64 CHECK FAILED] {name} contains non-float64 dtypes:\n"
            + "\n".join([f" leaf {i}: {dtype}" for i, dtype in bad])
        )


def warn_tree_float64(tree, name="object"):
    """
    Same check as assert_tree_float64, but prints a warning instead of raising
    ----------
    Parameters
    ----------
    tree : pytree

        Pytree to check. Leaves without a ``dtype`` attribute are skipped

    name : str

        Label used in the warning. Default "object"
    """
    leaves, _ = tree_flatten(tree)
    dtypes = {x.dtype for x in leaves if hasattr(x, "dtype")}
    # Compare element-wise rather than as sets: np.dtype('float64') == jnp.float64,
    # but the two hash differently, so `dtypes != {jnp.float64}` was always True.
    # An empty tree is accepted, consistent with assert_tree_float64.
    if any(dt != jnp.float64 for dt in dtypes):
        print(f"[WARNING] {name} dtypes = {dtypes}")


def check_grads_float64(grads):
    """
    Raise an error if any gradient leaf is not float64
    ----------
    Parameters
    ----------
    grads : pytree

        Gradient pytree (e.g. the output of jax.grad)

    Raises
    ------
    RuntimeError if any leaf with a dtype is not float64
    """
    bad = _non_float64_leaves(grads)
    if bad:
        raise RuntimeError(
            "[FLOAT64 CHECK FAILED] Gradients are not float64:\n"
            + "\n".join([f" grad {i}: {dtype}" for i, dtype in bad])
        )


def assert_lbfgs_state_float64(state):
    """
    Ensure that all *floating-point* values inside the LBFGS state are float64.
    Integers and booleans are allowed.
    ----------
    Parameters
    ----------
    state : pytree

        Optimizer state (e.g. a jaxopt LBFGS state). Only JAX arrays are inspected

    Raises
    ------
    RuntimeError if a floating-point JAX array is not float64
    """
    bad = _non_float64_leaves(state, floating_only=True)
    if bad:
        raise RuntimeError(
            "[FLOAT64 CHECK FAILED] LBFGS floating-point buffers are not float64:\n"
            + "\n".join([f" leaf {i}: {dtype}" for i, dtype in bad])
        )


#Change to float64
def to_float64(tree):
    """
    Cast the numeric leaves of a pytree to float64
    ----------
    Parameters
    ----------
    tree : pytree

        Pytree whose JAX/NumPy arrays and Python int/float leaves are cast.
        Other leaves are returned unchanged. Note that integer arrays are cast too

    Returns
    -------
    pytree with the same structure
    """
    def cast(x):
        if isinstance(x, (jax.Array, np.ndarray, float, int)):
            # Plain Python int/float have no .astype; wrap them in a JAX array instead
            if hasattr(x, "astype"):
                return x.astype(jnp.float64)
            return jnp.asarray(x, dtype=jnp.float64)
        return x
    return jax.tree_util.tree_map(cast, tree)


#MSE
@jax.jit
def MSE(pred, true):
    """
    Element-wise squared error
    ----------
    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    Returns
    -------
    element-wise squared error (true - pred) ** 2
    """
    return (true - pred) ** 2


#MSE self-adaptative
@jax.jit
def MSE_SA(pred, true, w, q=2):
    """
    Element-wise self-adaptive squared error
    ----------
    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    w : jax.numpy.array

        A JAX numpy array with the self-adaptive weights

    q : float

        Power applied to the weights (polynomial mask). Default 2

    Returns
    -------
    (w ** q) * (true - pred) ** 2
    """
    return (w ** q) * ((true - pred) ** 2)


#L2 error
@jax.jit
def L2error(pred, true):
    """
    Relative L2 error in percentage (%)
    ----------
    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    Returns
    -------
    100 * ||true - pred||_2 / ||true||_2 (not finite when true is identically zero)
    """
    return 100 * jnp.sqrt(jnp.sum((true - pred) ** 2)) / jnp.sqrt(jnp.sum(true ** 2))
functions to sample singular Matern 172def idst1(x,axis = -1): 173 """ 174 Inverse Discrete Sine Transform of type I with orthonormal scaling 175 ---------- 176 Parameters 177 ---------- 178 x : jax.numpy.array 179 180 Array to apply the transformation 181 182 axis : int 183 184 Axis to apply the transformation over 185 186 Returns 187 ------- 188 jax.numpy.array 189 """ 190 return idst(x,type = 1,axis = axis,norm = 'ortho') 191 192def dstn(x,axes = None): 193 """ 194 Discrete Sine Transform of type I with orthonormal scaling over many axes 195 ---------- 196 Parameters 197 ---------- 198 x : jax.numpy.array 199 200 Array to apply the transformation 201 202 axes : int 203 204 Axes to apply the transformation over 205 206 Returns 207 ------- 208 jax.numpy.array 209 """ 210 if axes is None: 211 axes = tuple(range(x.ndim)) 212 y = x 213 for ax in axes: 214 y = dst(y,type = 1,axis = ax,norm = 'ortho') 215 return y 216 217def idstn(x,axes = None): 218 """ 219 Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes 220 ---------- 221 Parameters 222 ---------- 223 x : jax.numpy.array 224 225 Array to apply the transformation 226 227 axes : tuple 228 229 Axes to apply the transformation over 230 231 Returns 232 ------- 233 jax.numpy.array 234 """ 235 if axes is None: 236 axes = tuple(range(x.ndim)) 237 y = x 238 for ax in axes: 239 y = idst1(y,axis = ax) 240 return y 241 242def dirichlet_eigs_nd(n,L): 243 """ 244 Eigenvalues of the discrete Dirichlet-Laplace operator in a rectangle 245 ---------- 246 Parameters 247 ---------- 248 n : list 249 250 List with the number of points in the grid in each dimension 251 252 L : list 253 254 List with the upper limit of the interval of the domain in each dimension. 
Assumed the lower limit is zero 255 256 Returns 257 ------- 258 jax.numpy.array 259 """ 260 #Unidimensional eigenvalues 261 lam_axes = [] 262 for ni, Li in zip(n,L): 263 h = Li / (ni + 1) 264 k = jnp.arange(1,ni + 1) 265 ln = (2 / (h*h)) * (1 - jnp.cos(jnp.pi * k / (ni + 1))) 266 lam_axes.append(ln) 267 grids = jnp.meshgrid(*lam_axes, indexing='ij') 268 Lam = jnp.zeros_like(grids[0]) 269 for g in grids: 270 Lam += g 271 return Lam 272 273 274#Sample from d-dimensional Matern process 275def generate_matern_sample(key,d = 2,N = 128,L = 1.0,kappa = 1,alpha = 1,sigma = 1,periodic = False): 276 """ 277 Sample d-dimensional Matern process 278 ---------- 279 Parameters 280 ---------- 281 key : int 282 283 Seed for randomization 284 285 d : int 286 287 Dimension. Default 2 288 289 N : int 290 291 Size of grid in each dimension. Default 128 292 293 L : list of float 294 295 The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1 296 297 kappa,alpha,sigma : float 298 299 Parameters of the Matern process 300 301 periodic : logical 302 303 Whether to sample with periodic boundary conditions. 
Periodic = False is not JAX native and does not work with JIT 304 305 Returns 306 ------- 307 jax.numpy.array 308 """ 309 if periodic: 310 #Shape and key 311 key = jax.random.PRNGKey(key) 312 shape = (N,) * d 313 if isinstance(L,float) or isinstance(L,int): 314 L = d*[L] 315 if isinstance(N,float) or isinstance(N,int): 316 N = d*[N] 317 318 #Setup Frequency Grid (2D) 319 freq = [jnp.fft.fftfreq(N[j],d = L[j]/N[j]) * 2 * jnp.pi for j in range(d)] 320 grids = jnp.meshgrid(*freq, indexing='ij') 321 sq_norm_xi = sum(g**2 for g in grids) 322 323 #Generate White Noise in Fourier Space 324 key_re, key_im = jax.random.split(key) 325 white_noise_f = (jax.random.normal(key_re, shape) + 326 1j * jax.random.normal(key_im, shape)) 327 328 #Apply the Whittle Filter 329 amplitude_filter = (kappa ** 2 + sq_norm_xi) ** (-alpha / 2) 330 field_f = white_noise_f * amplitude_filter 331 332 #Transform back to Physical Space 333 sample = jnp.real(jnp.fft.ifftn(field_f)) 334 return sigma*sample 335 else: #NOT JAX 336 #Shape and key 337 rng = np.random.default_rng(seed = key) 338 if isinstance(L,float) or isinstance(L,int): 339 L = d*[L] 340 if isinstance(N,float) or isinstance(N,int): 341 N = d*[N] 342 shape = tuple(N) 343 344 #White noise in real space 345 W = rng.standard_normal(size = shape) 346 347 #To Dirichlet eigenbasis via separable DST-I (orthonormal) 348 W_hat = dstn(W) 349 350 #Discrete Dirichlet Laplacian eigenvalues 351 lam = dirichlet_eigs_nd(N, L) 352 353 #Spectral filter 354 filt = ((kappa + lam) ** (-alpha/2)) 355 psi_hat = filt * W_hat 356 357 #Back to real space 358 psi = idstn(psi_hat) 359 return jnp.array(sigma*psi) 360 361#Vectorized generate_matern_sample 362def generate_matern_sample_batch(d = 2,N = 512,L = 1.0,kappa = 10.0,alpha = 1,sigma = 10,periodic = False): 363 """ 364 Create function to sample d-dimensional Matern process 365 ---------- 366 Parameters 367 ---------- 368 d : int 369 370 Dimension. 
Default 2 371 372 N : int 373 374 Size of grid in each dimension. Default 128 375 376 L : list of float 377 378 The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1 379 380 kappa,alpha,sigma : float 381 382 Parameters of the Matern process 383 384 periodic : logical 385 386 Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT 387 388 Returns 389 ------- 390 function 391 """ 392 if periodic: 393 return jax.vmap(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)) 394 else: 395 return lambda keys: jnp.array(np.apply_along_axis(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic),1,keys.reshape((keys.shape[0],1)))) 396 397#Build function to compute the eigenfunctions of Laplacian 398def eigenf_laplace(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None): 399 """ 400 Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. 401 ---------- 402 Parameters 403 ---------- 404 L_vec : list of float 405 406 The domain of the function in each coordinate is [0,L[1]] 407 408 kmax_per_axis : list 409 410 List with the maximum number of eigenfunctions per dimension. Consider d * max(kmax_per_axis) eigenfunctions 411 412 bc : str 413 414 Boundary condition. 'dirichlet' or 'neumann' 415 416 max_ef : int 417 418 Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. 
If None, considers d * max(kmax_per_axis) eigenfunctions 419 420 Returns 421 ------- 422 function to compute eigenfunctions,eigenvalues of the eigenfunctions considered 423 """ 424 #Parameters 425 L_vec = jnp.asarray(L_vec) 426 d = L_vec.shape[0] 427 bc = bc.lower() 428 429 #Maximum number of functions 430 if max_ef is None: 431 if d == 1: 432 max_ef = jnp.max(jnp.array(kmax_per_axis)) 433 else: 434 max_ef = jnp.max(d * jnp.array(kmax_per_axis)) 435 436 #Build the candidate multi-indices per axis 437 kmax_per_axis = list(map(int, kmax_per_axis)) 438 if bc.startswith("d"): 439 axis_ranges = [range(1, km + 1) for km in kmax_per_axis] 440 elif bc.startswith("n"): 441 axis_ranges = [range(0, km + 1) for km in kmax_per_axis] 442 443 #Get all multi-indices 444 Ks_list = list(product(*axis_ranges)) 445 Ks = jnp.array(Ks_list) 446 447 #Eigenvalues of the continuous Laplacian 448 pi_over_L = jnp.pi / L_vec 449 lambdas_all = jnp.sum((Ks * pi_over_L) ** 2, axis=1) 450 451 #Sort by eigenvalue 452 order = jnp.argsort(lambdas_all) 453 Ks = Ks[order] 454 lambdas_all = lambdas_all[order] 455 456 #Keep first max_ef 457 Ks = Ks[:max_ef] 458 lambdas = lambdas_all[:max_ef] 459 m = Ks.shape[0] 460 461 #Precompute per-feature normalization factor (closed form) 462 def per_axis_norm_factor(k_i, L_i, is_dirichlet): 463 if is_dirichlet: 464 return jnp.sqrt(2 / L_i) 465 else: 466 return jnp.where(k_i == 0, jnp.sqrt(1 / L_i), jnp.sqrt(2 / L_i)) 467 if bc.startswith("d"): 468 nf = jnp.prod(jnp.sqrt(2 / L_vec)[None, :],axis = 1) 469 norm_factors = jnp.ones((m,)) * nf 470 else: 471 # per-mode product across axes 472 def nf_row(k_row): 473 return jnp.prod(per_axis_norm_factor(k_row, L_vec, False)) 474 norm_factors = jax.vmap(nf_row)(Ks) 475 476 #Build the callable function 477 Ks_int = Ks # float array, but only integer values 478 L_vec_f = L_vec 479 @jax.jit 480 def phi(x): 481 x = jnp.asarray(x) 482 #Initialize with ones 483 vals = jnp.ones(x.shape[:-1] + (m,)) 484 #Compute eigenfunction 485 
for i in range(d): 486 ang = (jnp.pi / L_vec_f[i]) * x[..., i][..., None] * Ks_int[:, i] 487 if bc.startswith("d"): 488 comp = jnp.sin(ang) 489 else: 490 comp = jnp.cos(ang) 491 vals = vals * comp 492 #Apply L2-normalizing constants 493 vals = vals * norm_factors[None, ...] if vals.ndim > 1 else vals * norm_factors 494 return vals 495 return phi, lambdas 496 497#Compute multiple frequences of domain aware fourrier fesatures 498def multiple_daff(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None): 499 """ 500 Create function to compute multiple frequences of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. Each frequences is a different domain. 501 ---------- 502 Parameters 503 ---------- 504 L_vec : list of lists of float 505 506 List with the domain of each frequence of the eigenfunctions in the form [0,L[i][1]] 507 508 kmax_per_axis : list 509 510 List with the maximum number of eigenfunctions per dimension. 511 512 bc : str 513 514 Boundary condition. 'dirichlet' or 'neumann' 515 516 max_ef : int 517 518 Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions 519 520 Returns 521 ------- 522 function to compute daff,eigenvalues of the eigenfunctions considered 523 """ 524 psi = [] 525 lamb = [] 526 for L in L_vec: 527 tmp,l = eigenf_laplace(L,kmax_per_axis,bc,max_ef) #Get function 528 lamb.append(l) 529 psi.append(tmp) 530 del tmp 531 #Create function to compute features 532 @jax.jit 533 def mff(x): 534 y = [] 535 for i in range(len(psi)): 536 y.append(psi[i](x)) 537 if len(psi) == 1: 538 return y[0] 539 else: 540 return jnp.concatenate(y,1) 541 return mff,jnp.concatenate(lamb) 542 543#Code for chebyshev polynomials writeen by AI (deprecated) 544def _chebyshev_T_all(t, K: int): 545 """ 546 Compute T_0..T_K(t) with the standard recurrence. 547 t shape should be (..., d). We DO NOT squeeze any axis to preserve 'd' 548 even when d == 1. 
549 Returns: array of shape (K+1, ...) matching t's batch dims, including d. 550 """ 551 # Expect t to have last axis = d (keep it, even if d == 1) 552 T0 = jnp.ones_like(t) # (..., d) 553 if K == 0: 554 return T0[None, ...] # (1, ..., d) 555 556 T1 = t # (..., d) 557 if K == 1: 558 return jnp.stack([T0, T1], axis=0) # (2, ..., d) 559 560 def body(carry, _): 561 Tkm1, Tk = carry # each (..., d) 562 Tkp1 = 2 * t * Tk - Tkm1 # (..., d) 563 return (Tk, Tkp1), Tkp1 564 565 # K >= 2: produce T_2..T_K 566 (_, _), T2_to_TK = lax.scan(body, (T0, T1), jnp.arange(K - 1)) # (K-1, ..., d) 567 return jnp.concatenate([T0[None, ...], T1[None, ...], T2_to_TK], axis=0) # (K+1, ..., d) 568 569@partial(jax.jit,static_argnums=(2,)) # n is static here; compile once per n 570def multiple_cheb_fast(x, L_vec, n: int): 571 """ 572 x: (N, d) 573 L_vec: (L, d) containing 'b' endpoints (a is 0) for each dimension 574 n: number of k terms (static) 575 returns: (N, L*n) 576 """ 577 N, d = x.shape 578 L = L_vec.shape[0] 579 580 a = 0 581 b = L_vec # (L, d) 582 # Map x to t in [-1, 1] for each l, j: shape (L, N, d) 583 t = (2 * x[None, :, :] - (a + b)[:, None, :]) / (b - a)[:, None, :] 584 585 # Chebyshev T_0..T_{n+2} for all (L, N, d): shape (n+3, L, N, d) 586 T = _chebyshev_T_all(t, n + 2) 587 588 # phi_k = T_{k+2} - T_k, k = 0..n-1 => shape (n, L, N, d) 589 ks = jnp.arange(n) 590 phi = T[ks + 2, ...] - T[ks, ...] 591 592 # Multiply across dimensions (over the last axis = d) => (n, L, N) 593 z = jnp.prod(phi, axis=-1) 594 595 # Reorder to (N, L, n) then flatten to (N, L*n) 596 z = jnp.transpose(z, (2, 1, 0)).reshape(N, L * n) 597 return z 598 599def multiple_cheb(L_vec, n: int): 600 """ 601 Factory that closes over static n and L_vec (so shapes are constant). 
602 """ 603 L_vec = jnp.asarray(L_vec) 604 @jax.jit # optional; multiple_cheb_fast is already jitted 605 def mcheb(x): 606 x = jnp.asarray(x) 607 return multiple_cheb_fast(x, L_vec, n) 608 return mcheb 609 610 611#Initialize fully connected neyral network Return the initial parameters and the function for the forward pass 612def fconNN(width,activation = jax.nn.tanh,key = 0,mlp = False,ftype = None,fargs = None,static = None,daff = None): 613 """ 614 Initialize fully connected neural network 615 ---------- 616 Parameters 617 ---------- 618 width : list 619 620 List with the layers width 621 622 activation : jax.nn activation 623 624 The activation function. Default jax.nn.tanh 625 626 key : int 627 628 Seed for parameters initialization. Default 0 629 630 mlp : logical 631 632 Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension. 633 634 ftype : str 635 636 Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'. 637 638 fargs : list 639 640 Arguments for deature transformation: 641 642 For 'ff': A list with the number of frequences and value of greatest frequence standard deviation. 643 644 For 'daff' and 'daff' bias: A dicitionary with a list with the size of rectangles and the type of boundary condition. If its a list, than boundary conditions is dirichlet. 645 646 static : function 647 648 A static function to sum to the neural network output. 649 650 daff : list 651 652 List with function to compute daff and the number of daff. If None computes assuming rectangular domain. 
653 654 Returns 655 ------- 656 dict with initial parameters and the function for the forward pass 657 """ 658 #Initialize parameters with Glorot initialization 659 initializer = jax.nn.initializers.glorot_normal() 660 params = list() 661 if static is None: 662 static = lambda x: 0 663 664 #Feature mapping 665 if ftype == 'ff': #Fourrier features 666 for s in range(fargs[0]): 667 sd = fargs[1] ** ((s + 1)/fargs[0]) 668 if s == 0: 669 Bff = sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2))) 670 else: 671 Bff = jnp.append(Bff,sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2))),1) 672 @jax.jit 673 def phi(x): 674 x = x @ Bff 675 return jnp.concatenate([jnp.sin(2 * jnp.pi * x),jnp.cos(2 * jnp.pi * x)],axis = -1) 676 width = width[1:] 677 width[0] = 2*Bff.shape[1] 678 elif ftype == 'daff' or ftype == 'daff_bias': 679 if not isinstance(fargs, dict): 680 fargs = {'L': fargs,'bc': "dirichlet"} 681 if daff is None: 682 phi,lamb = multiple_daff(list(fargs.values())[0],kmax_per_axis = [width[1]] * width[0],bc = list(fargs.values())[1]) 683 width = width[1:] 684 width[0] = lamb.shape[0] 685 else: 686 phi = daff[0] 687 width = width[1:] 688 width[0] = daff[1] 689 elif ftype == 'cheb' or ftype == 'cheb_bias': 690 phi = multiple_cheb(fargs,n = width[1]) 691 width = width[1:] 692 width[0] = len(fargs)*width[0] 693 else: 694 @jax.jit 695 def phi(x): 696 return x 697 698 #Initialize parameters 699 if mlp: 700 k = jax.random.split(jax.random.PRNGKey(key),4) 701 WU = initializer(k[0],(width[0],width[1])) 702 BU = initializer(k[1],(1,width[1])) 703 WV = initializer(k[2],(width[0],width[1])) 704 BV = initializer(k[3],(1,width[1])) 705 params.append({'WU':WU,'BU':BU,'WV':WV,'BV':BV}) 706 key = jax.random.split(jax.random.PRNGKey(key + 1),len(width)-1) #Seed for initialization 707 for key,lin,lout in zip(key,width[:-1],width[1:]): 708 W = initializer(key,(lin,lout)) 709 B = initializer(key,(1,lout)) 710 
params.append({'W':W,'B':B}) 711 712 #Define function for forward pass 713 if mlp: 714 if ftype != 'daff' and ftype != 'cheb': 715 @jax.jit 716 def forward(x,params): 717 encode,*hidden,output = params 718 sx = static(x) 719 x = phi(x) 720 U = activation(x @ encode['WU'] + encode['BU']) 721 V = activation(x @ encode['WV'] + encode['BV']) 722 for layer in hidden: 723 x = activation(x @ layer['W'] + layer['B']) 724 x = x * U + (1 - x) * V 725 return x @ output['W'] + output['B'] + sx 726 else: 727 @jax.jit 728 def forward(x,params): 729 encode,*hidden,output = params 730 sx = static(x) 731 x = phi(x) 732 U = activation(x @ encode['WU']) 733 V = activation(x @ encode['WV']) 734 for layer in hidden: 735 x = activation(x @ layer['W']) 736 x = x * U + (1 - x) * V 737 return x @ output['W'] + sx 738 else: 739 if ftype != 'daff' and ftype != 'cheb': 740 @jax.jit 741 def forward(x,params): 742 *hidden,output = params 743 sx = static(x) 744 x = phi(x) 745 for layer in hidden: 746 x = activation(x @ layer['W'] + layer['B']) 747 return x @ output['W'] + output['B'] + sx 748 else: 749 @jax.jit 750 def forward(x,params): 751 *hidden,output = params 752 sx = static(x) 753 x = phi(x) 754 for layer in hidden: 755 x = activation(x @ layer['W']) 756 return x @ output['W'] + sx 757 758 #Return initial parameters and forward function 759 return {'params': params,'forward': forward} 760 761#Get activation from string 762def get_activation(act): 763 """ 764 Return activation function from string 765 ---------- 766 Parameters 767 ---------- 768 act : str 769 770 Name of the activation function. 
Default 'tanh' 771 772 Returns 773 ------- 774 jax.nn activation function 775 """ 776 if act == 'tanh': 777 return jax.nn.tanh 778 elif act == 'relu': 779 return jax.nn.relu 780 elif act == 'relu6': 781 return jax.nn.relu6 782 elif act == 'sigmoid': 783 return jax.nn.sigmoid 784 elif act == 'softplus': 785 return jax.nn.softplus 786 elif act == 'sparse_plus': 787 return jax.nn.sparse_plus 788 elif act == 'soft_sign': 789 return jax.nn.soft_sign 790 elif act == 'silu': 791 return jax.nn.silu 792 elif act == 'swish': 793 return jax.nn.swish 794 elif act == 'log_sigmoid': 795 return jax.nn.log_sigmoid 796 elif act == 'leaky_relu': 797 return jax.nn.leaky_relu 798 elif act == 'hard_sigmoid': 799 return jax.nn.hard_sigmoid 800 elif act == 'hard_silu': 801 return jax.nn.hard_silu 802 elif act == 'hard_swish': 803 return jax.nn.hard_swish 804 elif act == 'hard_tanh': 805 return jax.nn.hard_tanh 806 elif act == 'elu': 807 return jax.nn.elu 808 elif act == 'celu': 809 return jax.nn.celu 810 elif act == 'selu': 811 return jax.nn.selu 812 elif act == 'gelu': 813 return jax.nn.gelu 814 elif act == 'glu': 815 return jax.nn.glu 816 elif act == 'squareplus': 817 return jax.nn.squareplus 818 elif act == 'mish': 819 return jax.nn.mish 820 821#Training PINN 822def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,mlp = False): 823 """ 824 Train a Physics-informed Neural Network 825 ---------- 826 Parameters 827 ---------- 828 data : dict 829 830 Data generated by the jinnax.data.generate_PINNdata function 831 832 width : list 833 834 A list with the width of each layer 835 836 pde : function 837 838 The partial differential operator. 
Its arguments are u, x and t 839 840 test_data : dict, None 841 842 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error 843 844 epochs : int 845 846 Number of training epochs. Default 100 847 848 at_each : int 849 850 Save results for epochs multiple of at_each. Default 10 851 852 activation : str 853 854 The name of the activation function of the neural network. Default 'tanh' 855 856 neumann : logical 857 858 Whether to consider Neumann boundary conditions 859 860 oper_neumann : function 861 862 Penalization of Neumann boundary conditions 863 864 sa : logical 865 866 Whether to consider self-adaptative PINN 867 868 c : dict 869 870 Dictionary with the hyperparameters of the self-adaptative sigmoid mask for the initial (w0), sensor (ws) and collocation (wr) points. The weights of the boundary points is fixed to 1 871 872 inverse : logical 873 874 Whether to estimate parameters of the PDE 875 876 initial_par : jax.numpy.array 877 878 Initial value of the parameters of the PDE in an inverse problem 879 880 lr,b1,b2,eps,eps_root: float 881 882 Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0 883 884 key : int 885 886 Seed for parameters initialization. Default 0 887 888 epoch_print : int 889 890 Number of epochs to calculate and print test errors. Default 100 891 892 save : logical 893 894 Whether to save the current parameters. Default False 895 896 file_name : str 897 898 File prefix to save the current parameters. Default 'result_pinn' 899 900 exp_decay : logical 901 902 Whether to consider exponential decay of learning rate. Default False 903 904 transition_steps : int 905 906 Number of steps for exponential decay. Default 1000 907 908 decay_rate : float 909 910 Rate of exponential decay. 
Default 0.9 911 912 mlp : logical 913 914 Whether to consider modifed multi-layer perceptron 915 916 Returns 917 ------- 918 dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time 919 """ 920 921 #Initialize architecture 922 nnet = fconNN(width,get_activation(activation),key,mlp) 923 forward = nnet['forward'] 924 925 #Initialize self adaptative weights 926 par_sa = {} 927 if sa: 928 #Initialize wheights close to zero 929 ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000) 930 if data['sensor'] is not None: 931 par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))}) 932 if data['initial'] is not None: 933 par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))}) 934 if data['collocation'] is not None: 935 par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))}) 936 if data['boundary'] is not None: 937 par_sa.update({'wb': c['wb'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))}) 938 939 #Store all parameters 940 params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa} 941 942 #Save config file 943 if save: 944 pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 945 946 #Define loss function 947 if sa: 948 #Define loss function 949 @jax.jit 950 def lf(params,x): 951 loss = 0 952 if x['sensor'] is not None: 953 #Term that refers to sensor data 954 loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws'])) 955 if 
x['boundary'] is not None: 956 if neumann: 957 #Neumann coditions 958 xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1)) 959 tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1)) 960 loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb'])) 961 else: 962 #Term that refers to boundary data 963 loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb'])) 964 if x['initial'] is not None: 965 #Term that refers to initial data 966 loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0'])) 967 if x['collocation'] is not None: 968 #Term that refers to collocation points 969 x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1)) 970 t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1)) 971 if inverse: 972 loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr'])) 973 else: 974 loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr'])) 975 return loss 976 else: 977 @jax.jit 978 def lf(params,x): 979 loss = 0 980 if x['sensor'] is not None: 981 #Term that refers to sensor data 982 loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor'])) 983 if x['boundary'] is not None: 984 if neumann: 985 #Neumann coditions 986 xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1)) 987 tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1)) 988 loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb)) 989 else: 990 #Term that refers to boundary data 991 loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary'])) 992 if x['initial'] is not None: 993 #Term that refers to initial data 994 loss = 
loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial'])) 995 if x['collocation'] is not None: 996 #Term that refers to collocation points 997 x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1)) 998 t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1)) 999 if inverse: 1000 loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0)) 1001 else: 1002 loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0)) 1003 return loss 1004 1005 #Initialize Adam Optmizer 1006 if exp_decay: 1007 lr = optax.exponential_decay(lr,transition_steps,decay_rate) 1008 optimizer = optax.adam(lr,b1,b2,eps,eps_root) 1009 opt_state = optimizer.init(params) 1010 1011 #Define the gradient function 1012 grad_loss = jax.jit(jax.grad(lf,0)) 1013 1014 #Define update function 1015 @jax.jit 1016 def update(opt_state,params,x): 1017 #Compute gradient 1018 grads = grad_loss(params,x) 1019 #Invert gradient of self-adaptative wheights 1020 if sa: 1021 for w in grads['sa']: 1022 grads['sa'][w] = - grads['sa'][w] 1023 #Calculate parameters updates 1024 updates, opt_state = optimizer.update(grads, opt_state) 1025 #Update parameters 1026 params = optax.apply_updates(params, updates) 1027 #Return state of optmizer and updated parameters 1028 return opt_state,params 1029 1030 ###Training### 1031 t0 = time.time() 1032 #Initialize alive_bar for tracing in terminal 1033 with alive_bar(epochs) as bar: 1034 #For each epoch 1035 for e in range(epochs): 1036 #Update optimizer state and parameters 1037 opt_state,params = update(opt_state,params,data) 1038 #After epoch_print epochs 1039 if e % epoch_print == 0: 1040 #Compute elapsed time and current error 1041 l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6)) 1042 #If there is test data, compute current L2 error 1043 if test_data is not None: 1044 
#Compute L2 error 1045 l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist() 1046 l = l + ' L2 error: ' + str(jnp.round(l2_test,3)) 1047 if inverse: 1048 l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6)) 1049 #Print 1050 print(l) 1051 if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save: 1052 #Save current parameters 1053 pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 1054 #Update alive_bar 1055 bar() 1056 #Define estimated function 1057 def u(xt): 1058 return forward(xt,params['net']) 1059 1060 return {'u': u,'params': params,'forward': forward,'time': time.time() - t0} 1061 1062#Training PINN 1063def train_Matern_PINN(data,width,pde,test_data = None,params = None,d = 2,N = 128,L = 1,alpha = 1,kappa = 1,sigma = 100,bsize = 1024,resample = False,epochs = 100,at_each = 10,activation = 'tanh', 1064 neumann = False,oper_neumann = None,inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 1,save = False,file_name = 'result_pinn', 1065 exp_decay = True,transition_steps = 100,decay_rate = 0.9,mlp = True,ftype = None,fargs = None,q = 4,w = None,periodic = False,static = None,opt = 'LBFGS',float64 = False,restart = None): 1066 """ 1067 Train a Physics-informed Neural Network 1068 ---------- 1069 Parameters 1070 ---------- 1071 data : dict 1072 1073 Data generated by the jinnax.data.generate_PINNdata function 1074 1075 width : list 1076 1077 A list with the width of each layer 1078 1079 pde : function 1080 1081 The partial differential operator. Its arguments are u, x and t 1082 1083 test_data : dict, None 1084 1085 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. 
Default None for not calculating L2 error 1086 1087 params : list 1088 1089 Initial parameters for the neural network. Default None to initialize randomly 1090 1091 d : int 1092 1093 Dimension of the problem including the time variable if present. Default 2 1094 1095 N : int 1096 1097 Size of grid in each dimension. Default 128 1098 1099 L : list of float 1100 1101 The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1 1102 1103 kappa,alpha,sigma : float 1104 1105 Parameters of the Matern process 1106 1107 bsize : int 1108 1109 Batch size for weak norm computation. Default 1024 1110 1111 resample : logical 1112 1113 Whether to resample the test functions at each epoch 1114 1115 epochs : int 1116 1117 Number of training epochs. Default 100 1118 1119 at_each : int 1120 1121 Save results for epochs multiple of at_each. Default 10 1122 1123 activation : str 1124 1125 The name of the activation function of the neural network. Default 'tanh' 1126 1127 neumann : logical 1128 1129 Whether to consider Neumann boundary conditions 1130 1131 oper_neumann : function 1132 1133 Penalization of Neumann boundary conditions 1134 1135 inverse : logical 1136 1137 Whether to estimate parameters of the PDE 1138 1139 initial_par : jax.numpy.array 1140 1141 Initial value of the parameters of the PDE in an inverse problem 1142 1143 lr,b1,b2,eps,eps_root: float 1144 1145 Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0 1146 1147 key : int 1148 1149 Seed for parameters initialization. Default 0 1150 1151 epoch_print : int 1152 1153 Number of epochs to calculate and print test errors. Default 1 1154 1155 save : logical 1156 1157 Whether to save the current parameters. Default False 1158 1159 file_name : str 1160 1161 File prefix to save the current parameters. 
Default 'result_pinn' 1162 1163 exp_decay : logical 1164 1165 Whether to consider exponential decay of learning rate. Default True 1166 1167 transition_steps : int 1168 1169 Number of steps for exponential decay. Default 100 1170 1171 decay_rate : float 1172 1173 Rate of exponential decay. Default 0.9 1174 1175 mlp : logical 1176 1177 Whether to consider modifed multilayer perceptron 1178 1179 ftype : str 1180 1181 Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'. 1182 1183 fargs : list 1184 1185 Arguments for deature transformation: 1186 1187 For 'ff': A list with the number of frequences and value of greatest frequence standard deviation. 1188 1189 For 'daff' and 'daff' bias: A dicitionary with a list with the size of rectangles and the type of boundary condition. If its a list, than boundary conditions is dirichlet. 1190 1191 q : int 1192 1193 Power of weights mask. Default 4 1194 1195 w : dict 1196 1197 Initila weights for self-adaptive scheme. 1198 1199 periodic : logical 1200 1201 Whether to consider periodic test functions. Default False. 1202 1203 static : function 1204 1205 A static function to sum to the neural network output. 1206 1207 opt : str 1208 1209 Optimizer. Default LBFGS. 1210 1211 float64 : logical 1212 1213 Whether to train with float64 1214 1215 restart : int 1216 1217 Epochs to restart L-BFGS 1218 1219 Returns 1220 ------- 1221 dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch 1222 """ 1223 #Initialize architecture 1224 nnet = fconNN(width,get_activation(activation),key,mlp,ftype,fargs,static) 1225 if float64: 1226 forward = lambda x,params: nnet['forward'](x,params).astype(jnp.float64) 1227 assert jax.config.jax_enable_x64, "JAX is NOT running in float64 mode!" 
1228 else: 1229 forward = nnet['forward'] 1230 if params is not None: 1231 nnet['params'] = params 1232 if float64: 1233 data = to_float64(data) 1234 if test_data is not None: 1235 test_data = to_float64(test_data) 1236 1237 #Generate from Matern process 1238 if sigma > 0: 1239 if isinstance(L,float) or isinstance(L,int): 1240 L = d*[L] 1241 #Grid for weak norm 1242 if float64: 1243 grid = [jnp.linspace(0,L[i],N,dtype = jnp.float64) for i in range(d)] 1244 else: 1245 grid = [jnp.linspace(0,L[i],N) for i in range(d)] 1246 grid = jnp.meshgrid(*grid, indexing='ij') 1247 grid = jnp.stack(grid, axis=-1).reshape((-1, d)) 1248 #Set sigma 1249 if data['boundary'] is not None: 1250 gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma) 1251 tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0]) 1252 if neumann: 1253 loss_boundary = oper_neumann(lambda x: forward(x,nnet['params']),data['boundary']) 1254 else: 1255 loss_boundary = jnp.mean(MSE(forward(data['boundary'],nnet['params']),data['uboundary'])) 1256 output_w = pde(lambda x: forward(x,nnet['params']),grid) 1257 integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(tf) 1258 loss_res_weak = jnp.mean(integralOmega ** 2) 1259 sigma = float(jnp.sqrt(loss_boundary/loss_res_weak).tolist()) 1260 del gen 1261 gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic) 1262 tf = sigma*tf 1263 else: 1264 gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic) 1265 tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0]) 1266 if float64 and tf is not None: 1267 tf = to_float64(tf) 1268 grid = to_float64(grid) 1269 else: 1270 tf = None 1271 grid = None 1272 1273 #Define loss function 1274 @jax.jit 1275 def lf_each(params,x,k,tf,grid): 1276 if sigma > 0: 1277 #Term that refers to weak loss 1278 if resample: 1279 
test_functions = to_float64(gen(jax.random.split(jax.random.PRNGKey(k[0]),(bsize,))[:,0])) 1280 else: 1281 test_functions = tf 1282 loss_sensor = loss_boundary = loss_initial = loss_res = loss_res_weak = 0 1283 if x['sensor'] is not None: 1284 #Term that refers to sensor data 1285 loss_sensor = MSE(forward(x['sensor'],params['net']),x['usensor']) 1286 if x['boundary'] is not None: 1287 if neumann: 1288 #Neumann coditions 1289 loss_boundary = oper_neumann(lambda x: forward(x,nnet['params']),x['boundary']) 1290 else: 1291 #Term that refers to boundary data 1292 loss_boundary = MSE(forward(x['boundary'],params['net']),x['uboundary']) 1293 if x['initial'] is not None: 1294 #Term that refers to initial data 1295 loss_initial = MSE(forward(x['initial'],params['net']),x['uinitial']) 1296 if x['collocation'] is not None and sigma == 0: 1297 if inverse: 1298 output = pde(lambda x: forward(x,params['net']),x['collocation'],params['inverse']) 1299 loss_res = MSE(output,0) 1300 else: 1301 output = pde(lambda x: forward(x,params['net']),x['collocation']) 1302 loss_res = MSE(output,0) 1303 if sigma > 0: 1304 #Term that refers to weak loss 1305 if inverse: 1306 output_w = pde(lambda x: forward(x,params['net']),grid,params['inverse']) 1307 integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions) 1308 loss_res_weak = jnp.mean(integralOmega ** 2) 1309 else: 1310 output_w = pde(lambda x: forward(x,params['net']),grid) 1311 integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions) 1312 loss_res_weak = jnp.mean(integralOmega ** 2) 1313 return {'ls': loss_sensor,'lb': loss_boundary,'li': loss_initial,'lc': loss_res,'lc_weak': loss_res_weak} 1314 1315 @jax.jit 1316 def lf(params,x,k,tf,grid): 1317 l = lf_each(params,x,k,tf,grid) 1318 if opt != 'LBFGS': 1319 loss = jnp.mean((params['w']['ws'] ** q)*l['ls']) + jnp.mean((params['w']['wb'] ** q)*l['lb']) + jnp.mean((params['w']['wi'] ** q)*l['li']) + 
jnp.mean((params['w']['wc'] ** q)*l['lc']) + (params['w']['wc_weak'] ** q)*l['lc_weak'] 1320 return loss 1321 else: 1322 loss = jnp.mean((w['ws'] ** q)*l['ls']) + jnp.mean((w['wb'] ** q)*l['lb']) + jnp.mean((w['wi'] ** q)*l['li']) + jnp.mean((w['wc'] ** q)*l['lc']) + (w['wc_weak'] ** q)*l['lc_weak'] 1323 l2 = None 1324 if test_data is not None: 1325 l2 = L2error(forward(test_data['sensor'],params['net']),test_data['usensor']) 1326 return loss,{'loss': loss,'l2': l2} 1327 1328 #Initialize self-adaptive weights 1329 if float64: 1330 typ = jnp.float64 1331 else: 1332 typ = jnp.float32 1333 if w is None: 1334 w = {'ws': jnp.array(1.0),'wb': jnp.array(1.0),'wi': jnp.array(1.0),'wc': jnp.array(1.0),'wc_weak': jnp.array(1.0)} 1335 if q != 0 and opt != 'LBFGS': 1336 if data['sensor'] is not None: 1337 w['ws'] = w['ws'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+1),(data['sensor'].shape[0],1),dtype = typ) 1338 if data['boundary'] is not None: 1339 w['wb'] = w['wb'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+2),(data['boundary'].shape[0],1),dtype = typ) 1340 if data['initial'] is not None: 1341 w['wi'] = w['wi'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+3),(data['initial'].shape[0],1),dtype = typ) 1342 if data['collocation'] is not None: 1343 w['wc'] = w['wc'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+4),(data['collocation'].shape[0],1),dtype = typ) 1344 1345 #Store all parameters 1346 if opt != 'LBFGS': 1347 params = {'net': nnet['params'],'inverse': initial_par,'w': w} 1348 else: 1349 params = {'net': nnet['params'],'inverse': initial_par} 1350 if float64: 1351 params = to_float64(params) 1352 1353 #Save config file 1354 if save: 1355 pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 1356 1357 
#Initialize Adam Optmizer 1358 if opt != 'LBFGS': 1359 print('--------- GRADIENT DESCENT OPTIMIZER ---------') 1360 if exp_decay: 1361 lr = optax.exponential_decay(lr,transition_steps,decay_rate) 1362 optimizer = optax.adam(lr,b1,b2,eps,eps_root) 1363 opt_state = optimizer.init(params) 1364 1365 #Define the gradient function 1366 grad_loss = jax.jit(jax.grad(lf,0)) 1367 1368 #Define update function 1369 @jax.jit 1370 def update(opt_state,params,x,k,tf,grid): 1371 #Compute gradient 1372 grads = grad_loss(params,x,k,tf,grid) 1373 #Calculate parameters updates 1374 updates, opt_state = optimizer.update(grads, opt_state) 1375 #Update parameters 1376 if q != 0: 1377 updates = {**updates, 'w': jax.tree_util.tree_map(lambda x: -x, updates['w'])} #Change signs of weights 1378 params = optax.apply_updates(params, updates) 1379 #Return state of optmizer and updated parameters 1380 return opt_state,params 1381 else: 1382 print('--------- LBFGS OPTIMIZER ---------') 1383 @jax.jit 1384 def loss_LBFGS(params): 1385 return lf(params,data,key + 234,tf,grid) 1386 solver = LBFGS(fun = loss_LBFGS,has_aux = True,maxiter = epochs,tol = 1e-9,verbose = False,linesearch = 'zoom',history_size = 200) # linesearch='zoom' by default 1387 state = solver.init_state(params) 1388 if float64: 1389 assert_lbfgs_state_float64(state) 1390 1391 ###Training### 1392 t0 = time.time() 1393 k = jax.random.split(jax.random.PRNGKey(key+234),(epochs,)) 1394 sloss = [] 1395 sL2 = [] 1396 stime = [] 1397 #Initialize alive_bar for tracing in terminal 1398 with alive_bar(epochs) as bar: 1399 #For each epoch 1400 for e in range(epochs): 1401 if opt != 'LBFGS': 1402 if float64 and e < 10: 1403 assert_tree_float64(params, name="params (before update)") 1404 grads = grad_loss(params, data, k[e,:], tf, grid) 1405 check_grads_float64(grads) 1406 #Update optimizer state and parameters 1407 opt_state,params = update(opt_state,params,data,k[e,:],tf,grid) 1408 sloss.append(lf(params,data,k[e,:],tf,grid)) 1409 if test_data 
is not None: 1410 sL2.append(L2error(forward(test_data['sensor'],params['net']),test_data['usensor'])) 1411 if float64 and e < 10: 1412 assert_tree_float64(params, name="params (before update)") 1413 grads = grad_loss(params, data, k[e,:], tf, grid) 1414 check_grads_float64(grads) 1415 else: 1416 if float64 and e < 10: 1417 assert_tree_float64(params, name="params (before LBFGS step)") 1418 params, state = solver.update(params, state) 1419 if float64 and e < 10: 1420 assert_tree_float64(params, name="params (after LBFGS step)") 1421 assert_lbfgs_state_float64(state) 1422 sL2.append(state.aux["l2"]) 1423 sloss.append(state.aux["loss"]) 1424 if float64 and e < 10: 1425 assert_tree_float64(params, name="params (before update)") 1426 if restart is not None: 1427 if (e + 1) % restart == 0: 1428 state = solver.init_state(params) 1429 stime.append(time.time() - t0) 1430 #After epoch_print epochs 1431 if e % epoch_print == 0: 1432 #Compute elapsed time and current error 1433 l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(sloss[-1],6)) 1434 #If there is test data, compute current L2 error 1435 if test_data is not None: 1436 #Compute L2 error 1437 l = l + ' L2 error: ' + str(jnp.round(sL2[-1],6)) 1438 if inverse: 1439 l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6)) 1440 #Print 1441 print(l) 1442 if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save: 1443 #Save current parameters 1444 pickle.dump({'params': params,'width': width,'time': stime,'loss': sloss,'L2error': sL2},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 1445 #Update alive_bar 1446 bar() 1447 #Define estimated function 1448 def u(xt): 1449 return forward(xt,params['net']) 1450 1451 return {'u': u,'params': params,'forward': forward,'time': time.time() - t0,'loss_each': lf_each(params,data,[key + 100],tf,grid),'loss': sloss,'L2error': sL2} 1452 1453 1454#Process result 1455def 
process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1): 1456 """ 1457 Process the results of a Physics-informed Neural Network 1458 ---------- 1459 1460 Parameters 1461 ---------- 1462 test_data : dict 1463 1464 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 1465 1466 fit : function 1467 1468 The fitted function 1469 1470 train_data : dict 1471 1472 Training data generated by the jinnax.data.generate_PINNdata 1473 1474 plot : logical 1475 1476 Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True 1477 1478 plot_test : logical 1479 1480 Whether to plot the test data. Default True 1481 1482 times : int 1483 1484 Number of points along the time interval to plot. Default 5 1485 1486 d2 : logical 1487 1488 Whether to plot 2D plot when the spatial dimension is one. Default True 1489 1490 save : logical 1491 1492 Whether to save the plots. Default False 1493 1494 show : logical 1495 1496 Whether to show the plots. Default True 1497 1498 file_name : str 1499 1500 File prefix to save the plots. Default 'result_pinn' 1501 1502 print_res : logical 1503 1504 Whether to print the L2 error. Default True 1505 1506 p : int 1507 1508 Output dimension. 
Default 1 1509 1510 Returns 1511 ------- 1512 pandas data frame with L2 and MSE errors 1513 """ 1514 1515 #Dimension 1516 d = test_data['xt'].shape[1] - 1 1517 1518 #Number of plots multiple of 5 1519 times = 5 * round(times/5) 1520 1521 #Data 1522 td = get_train_data(train_data) 1523 xt_train = td['x'] 1524 u_train = td['y'] 1525 upred_train = fit(xt_train) 1526 upred_test = fit(test_data['xt']) 1527 1528 #Results 1529 l2_error_test = L2error(upred_test,test_data['u']).tolist() 1530 MSE_test = jnp.mean(MSE(upred_test,test_data['u'])).tolist() 1531 l2_error_train = L2error(upred_train,u_train).tolist() 1532 MSE_train = jnp.mean(MSE(upred_train,u_train)).tolist() 1533 1534 df = pd.DataFrame(np.array([l2_error_test,MSE_test,l2_error_train,MSE_train]).reshape((1,4)), 1535 columns=['l2_error_test','MSE_test','l2_error_train','MSE_train']) 1536 if print_res: 1537 print('L2 error test: ' + str(jnp.round(l2_error_test,6)) + ' L2 error train: ' + str(jnp.round(l2_error_train,6)) + ' MSE error test: ' + str(jnp.round(MSE_test,6)) + ' MSE error train: ' + str(jnp.round(MSE_train,6)) ) 1538 1539 #Plots 1540 if d == 1 and p ==1 and plot: 1541 plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name) 1542 elif p == 2 and plot: 1543 plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test) 1544 1545 return df 1546 1547#Plot results for d = 1 1548def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''): 1549 """ 1550 Plot the prediction of a 1D PINN 1551 ---------- 1552 1553 Parameters 1554 ---------- 1555 times : int 1556 1557 Number of points along the time interval to plot. Default 5 1558 1559 xt : jax.numpy.array 1560 1561 Test data xt array 1562 1563 u : jax.numpy.array 1564 1565 Test data u(x,t) array 1566 1567 upred : jax.numpy.array 1568 1569 Predicted upred(x,t) array on test data 1570 1571 d2 : logical 1572 1573 Whether to plot 2D plot. 
Default True 1574 1575 save : logical 1576 1577 Whether to save the plots. Default False 1578 1579 show : logical 1580 1581 Whether to show the plots. Default True 1582 1583 file_name : str 1584 1585 File prefix to save the plots. Default 'result_pinn' 1586 1587 title_1d : str 1588 1589 Title of 1D plot 1590 1591 title_2d : str 1592 1593 Title of 2D plot 1594 1595 Returns 1596 ------- 1597 None 1598 """ 1599 #Initialize 1600 fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5))) 1601 tlo = jnp.min(xt[:,-1]) 1602 tup = jnp.max(xt[:,-1]) 1603 ylo = jnp.min(u) 1604 ylo = ylo - 0.1*jnp.abs(ylo) 1605 yup = jnp.max(u) 1606 yup = yup + 0.1*jnp.abs(yup) 1607 k = 0 1608 t_values = np.linspace(tlo,tup,times) 1609 1610 #Create 1611 for i in range(int(times/5)): 1612 for j in range(5): 1613 if k < len(t_values): 1614 t = t_values[k] 1615 t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist() 1616 x_plot = xt[xt[:,-1] == t,:-1] 1617 y_plot = upred[xt[:,-1] == t,:] 1618 u_plot = u[xt[:,-1] == t,:] 1619 if int(times/5) > 1: 1620 ax[i,j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact') 1621 ax[i,j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction') 1622 ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10) 1623 ax[i,j].set_xlabel(' ') 1624 ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1625 else: 1626 ax[j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact') 1627 ax[j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction') 1628 ax[j].set_title('$t = %.2f$' % (t),fontsize=10) 1629 ax[j].set_xlabel(' ') 1630 ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1631 k = k + 1 1632 1633 #Title 1634 fig.suptitle(title_1d) 1635 fig.tight_layout() 1636 1637 #Show and save 1638 fig = plt.gcf() 1639 if show: 1640 plt.show() 1641 if save: 1642 fig.savefig(file_name + '_slices.png') 1643 plt.close() 1644 1645 #2d plot 1646 if d2: 1647 #Initialize 1648 fig, ax = plt.subplots(1,2) 1649 l1 = 
jnp.unique(xt[:,-1]).shape[0] 1650 l2 = jnp.unique(xt[:,0]).shape[0] 1651 1652 #Create 1653 ax[0].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist()) 1654 ax[0].set_title('Exact') 1655 ax[1].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist()) 1656 ax[1].set_title('Predicted') 1657 1658 #Title 1659 fig.suptitle(title_2d) 1660 fig.tight_layout() 1661 1662 #Show and save 1663 fig = plt.gcf() 1664 if show: 1665 plt.show() 1666 if save: 1667 fig.savefig(file_name + '_2d.png') 1668 plt.close() 1669 1670#Plot results for d = 1 1671def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True): 1672 """ 1673 Plot the prediction of a PINN with 2D output 1674 ---------- 1675 Parameters 1676 ---------- 1677 times : int 1678 1679 Number of points along the time interval to plot. Default 5 1680 1681 xt : jax.numpy.array 1682 1683 Test data xt array 1684 1685 u : jax.numpy.array 1686 1687 Test data u(x,t) array 1688 1689 upred : jax.numpy.array 1690 1691 Predicted upred(x,t) array on test data 1692 1693 save : logical 1694 1695 Whether to save the plots. Default False 1696 1697 show : logical 1698 1699 Whether to show the plots. Default True 1700 1701 file_name : str 1702 1703 File prefix to save the plots. Default 'result_pinn' 1704 1705 title : str 1706 1707 Title of plot 1708 1709 plot_test : logical 1710 1711 Whether to plot the test data. 
Default True 1712 1713 Returns 1714 ------- 1715 None 1716 """ 1717 #Initialize 1718 fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5))) 1719 tlo = jnp.min(xt[:,-1]) 1720 tup = jnp.max(xt[:,-1]) 1721 xlo = jnp.min(u[:,0]) 1722 xlo = xlo - 0.1*jnp.abs(xlo) 1723 xup = jnp.max(u[:,0]) 1724 xup = xup + 0.1*jnp.abs(xup) 1725 ylo = jnp.min(u[:,1]) 1726 ylo = ylo - 0.1*jnp.abs(ylo) 1727 yup = jnp.max(u[:,1]) 1728 yup = yup + 0.1*jnp.abs(yup) 1729 k = 0 1730 t_values = np.linspace(tlo,tup,times) 1731 1732 #Create 1733 for i in range(int(times/5)): 1734 for j in range(5): 1735 if k < len(t_values): 1736 t = t_values[k] 1737 t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist() 1738 xpred_plot = upred[xt[:,-1] == t,0] 1739 ypred_plot = upred[xt[:,-1] == t,1] 1740 if plot_test: 1741 x_plot = u[xt[:,-1] == t,0] 1742 y_plot = u[xt[:,-1] == t,1] 1743 if int(times/5) > 1: 1744 if plot_test: 1745 ax[i,j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact') 1746 ax[i,j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction') 1747 ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10) 1748 ax[i,j].set_xlabel(' ') 1749 ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1750 else: 1751 if plot_test: 1752 ax[j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact') 1753 ax[j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction') 1754 ax[j].set_title('$t = %.2f$' % (t),fontsize=10) 1755 ax[j].set_xlabel(' ') 1756 ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1757 k = k + 1 1758 1759 #Title 1760 fig.suptitle(title) 1761 fig.tight_layout() 1762 1763 #Show and save 1764 fig = plt.gcf() 1765 if show: 1766 plt.show() 1767 if save: 1768 fig.savefig(file_name + '_slices.png') 1769 plt.close() 1770 1771#Get train data in one array 1772def get_train_data(train_data): 1773 """ 1774 Process training sample 1775 ---------- 1776 1777 Parameters 1778 ---------- 1779 train_data : dict 1780 1781 A dictionay with train data 
generated by the jinnax.data.generate_PINNdata function 1782 1783 Returns 1784 ------- 1785 dict with the processed training data 1786 """ 1787 xdata = None 1788 ydata = None 1789 xydata = None 1790 if train_data['sensor'] is not None: 1791 sensor_sample = train_data['sensor'].shape[0] 1792 xdata = train_data['sensor'] 1793 ydata = train_data['usensor'] 1794 xydata = jnp.column_stack((train_data['sensor'],train_data['usensor'])) 1795 else: 1796 sensor_sample = 0 1797 if train_data['boundary'] is not None: 1798 boundary_sample = train_data['boundary'].shape[0] 1799 if xdata is not None: 1800 xdata = jnp.vstack((xdata,train_data['boundary'])) 1801 ydata = jnp.vstack((ydata,train_data['uboundary'])) 1802 xydata = jnp.vstack((xydata,jnp.column_stack((train_data['boundary'],train_data['uboundary'])))) 1803 else: 1804 xdata = train_data['boundary'] 1805 ydata = train_data['uboundary'] 1806 xydata = jnp.column_stack((train_data['boundary'],train_data['uboundary'])) 1807 else: 1808 boundary_sample = 0 1809 if train_data['initial'] is not None: 1810 initial_sample = train_data['initial'].shape[0] 1811 if xdata is not None: 1812 xdata = jnp.vstack((xdata,train_data['initial'])) 1813 ydata = jnp.vstack((ydata,train_data['uinitial'])) 1814 xydata = jnp.vstack((xydata,jnp.column_stack((train_data['initial'],train_data['uinitial'])))) 1815 else: 1816 xdata = train_data['initial'] 1817 ydata = train_data['uinitial'] 1818 xydata = jnp.column_stack((train_data['initial'],train_data['uinitial'])) 1819 else: 1820 initial_sample = 0 1821 if train_data['collocation'] is not None: 1822 collocation_sample = train_data['collocation'].shape[0] 1823 else: 1824 collocation_sample = 0 1825 1826 return {'xy': xydata,'x': xdata,'y': ydata,'sensor_sample': sensor_sample,'boundary_sample': boundary_sample,'initial_sample': initial_sample,'collocation_sample': collocation_sample} 1827 1828#Process training 1829def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 
10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1): 1830 """ 1831 Process the training of a Physics-informed Neural Network 1832 ---------- 1833 1834 Parameters 1835 ---------- 1836 test_data : dict 1837 1838 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 1839 1840 file_name : str 1841 1842 Name of the files saved during training 1843 1844 at_each : int 1845 1846 Compute results for epochs multiple of at_each. Default 100 1847 1848 bolstering : logical 1849 1850 Whether to compute bolstering mean square error. Default True 1851 1852 mc_sample : int 1853 1854 Number of sample for Monte Carlo integration in bolstering. Default 10000 1855 1856 save : logical 1857 1858 Whether to save the training results. Default False 1859 1860 file_name_save : str 1861 1862 File prefix to save the plots and the L2 error. Default 'result_pinn' 1863 1864 key : int 1865 1866 Key for random samples in bolstering. Default 0 1867 1868 ec : float 1869 1870 Stopping criteria error for EM algorithm in bolstering. Default 1e-6 1871 1872 lamb : float 1873 1874 Hyperparameter of EM algorithm in bolstering. 
Default 1 1875 1876 Returns 1877 ------- 1878 pandas data frame with training results 1879 """ 1880 #Config 1881 config = pickle.load(open(file_name + '_config.pickle', 'rb')) 1882 epochs = config['epochs'] 1883 train_data = config['train_data'] 1884 forward = config['forward'] 1885 1886 #Get train data 1887 td = get_train_data(train_data) 1888 xydata = td['xy'] 1889 xdata = td['x'] 1890 ydata = td['y'] 1891 sensor_sample = td['sensor_sample'] 1892 boundary_sample = td['boundary_sample'] 1893 initial_sample = td['initial_sample'] 1894 collocation_sample = td['collocation_sample'] 1895 1896 #Generate keys 1897 if bolstering: 1898 keys = jax.random.split(jax.random.PRNGKey(key),epochs) 1899 1900 #Initialize loss 1901 train_mse = [] 1902 test_mse = [] 1903 train_L2 = [] 1904 test_L2 = [] 1905 bolstX = [] 1906 bolstXY = [] 1907 loss = [] 1908 time = [] 1909 ep = [] 1910 1911 #Process training 1912 with alive_bar(epochs) as bar: 1913 for e in range(epochs): 1914 if (e % at_each == 0 and at_each != epochs) or e == epochs - 1: 1915 ep = ep + [e] 1916 1917 #Read parameters 1918 params = pickle.load(open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb')) 1919 1920 #Time 1921 time = time + [params['time']] 1922 1923 #Define learned function 1924 def psi(x): 1925 return forward(x,params['params']['net']) 1926 1927 #Train MSE and L2 1928 if xdata is not None: 1929 train_mse = train_mse + [jnp.mean(MSE(psi(xdata),ydata)).tolist()] 1930 train_L2 = train_L2 + [L2error(psi(xdata),ydata).tolist()] 1931 else: 1932 train_mse = train_mse + [None] 1933 train_L2 = train_L2 + [None] 1934 1935 #Test MSE and L2 1936 test_mse = test_mse + [jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist()] 1937 test_L2 = test_L2 + [L2error(psi(test_data['xt']),test_data['u']).tolist()] 1938 1939 #Bolstering 1940 if bolstering: 1941 bX = [] 1942 bXY = [] 1943 for method in ['chi','mm','mpe']: 1944 kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = 
ec,psi = psi) 1945 kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi) 1946 bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()] 1947 bXY = bXY + [gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist()] 1948 for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]: 1949 kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias) 1950 bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()] 1951 bolstX = bolstX + [bX] 1952 bolstXY = bolstXY + [bXY] 1953 else: 1954 bolstX = bolstX + [None] 1955 bolstXY = bolstXY + [None] 1956 1957 #Loss 1958 loss = loss + [params['loss'].tolist()] 1959 1960 #Delete 1961 del params, psi 1962 #Update alive_bar 1963 bar() 1964 1965 #Bolstering results 1966 if bolstering: 1967 bolstX = jnp.array(bolstX) 1968 bolstXY = jnp.array(bolstXY) 1969 1970 #Create data frame 1971 if bolstering: 1972 df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss, 1973 train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]), 1974 columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4']) 1975 else: 1976 df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss, 1977 
train_mse,test_mse,train_L2,test_L2]), 1978 columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2']) 1979 if save: 1980 df.to_csv(file_name_save + '.csv',index = False) 1981 1982 return df 1983 1984#Demo video for training1D PINN 1985def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10): 1986 """ 1987 Demo video with the training of a 1D PINN 1988 ---------- 1989 1990 Parameters 1991 ---------- 1992 test_data : dict 1993 1994 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 1995 1996 file_name : str 1997 1998 Name of the files saved during training 1999 2000 at_each : int 2001 2002 Compute results for epochs multiple of at_each. Default 100 2003 2004 times : int 2005 2006 Number of points along the time interval to plot. Default 5 2007 2008 d2 : logical 2009 2010 Whether to make video demo of 2D plot. Default True 2011 2012 file_name_save : str 2013 2014 File prefix to save the plots and videos. Default 'result_pinn_demo' 2015 2016 title : str 2017 2018 Title for plots 2019 2020 framerate : int 2021 2022 Framerate for video. 
Default 10 2023 2024 Returns 2025 ------- 2026 None 2027 """ 2028 #Config 2029 with open(file_name + '_config.pickle', 'rb') as file: 2030 config = pickle.load(file) 2031 epochs = config['epochs'] 2032 train_data = config['train_data'] 2033 forward = config['forward'] 2034 2035 #Get train data 2036 td = get_train_data(train_data) 2037 xt = td['x'] 2038 u = td['y'] 2039 2040 #Create folder to save plots 2041 os.system('mkdir ' + file_name_save) 2042 2043 #Create images 2044 k = 1 2045 with alive_bar(epochs) as bar: 2046 for e in range(epochs): 2047 if e % at_each == 0 or e == epochs - 1: 2048 #Read parameters 2049 params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle') 2050 2051 #Define learned function 2052 def psi(x): 2053 return forward(x,params['params']['net']) 2054 2055 #Compute L2 train, L2 test and loss 2056 loss = params['loss'] 2057 L2_train = L2error(psi(xt),u) 2058 L2_test = L2error(psi(test_data['xt']),test_data['u']) 2059 title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(L2_train,6)) + ' L2 test = ' + str(round(L2_test,6)) 2060 2061 #Save plot 2062 plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch) 2063 k = k + 1 2064 2065 #Delete 2066 del params, psi, loss, L2_train, L2_test, title_epoch 2067 #Update alive_bar 2068 bar() 2069 #Create demo video 2070 os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4') 2071 if d2: 2072 os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4') 2073 2074#Demo in time for 1D PINN 2075def 
demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10): 2076 """ 2077 Demo video with the time evolution of a 1D PINN 2078 ---------- 2079 2080 Parameters 2081 ---------- 2082 test_data : dict 2083 2084 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 2085 2086 file_name : str 2087 2088 Name of the files saved during training 2089 2090 epochs : list 2091 2092 Which training epochs to plot 2093 2094 file_name_save : str 2095 2096 File prefix to save the plots and video. Default 'result_pinn_time_demo' 2097 2098 title : str 2099 2100 Title for plots 2101 2102 framerate : int 2103 2104 Framerate for video. Default 10 2105 2106 Returns 2107 ------- 2108 None 2109 """ 2110 #Config 2111 with open(file_name + '_config.pickle', 'rb') as file: 2112 config = pickle.load(file) 2113 train_data = config['train_data'] 2114 forward = config['forward'] 2115 2116 #Create folder to save plots 2117 os.system('mkdir ' + file_name_save) 2118 2119 #Plot parameters 2120 tdom = jnp.unique(test_data['xt'][:,-1]) 2121 ylo = jnp.min(test_data['u']) 2122 ylo = ylo - 0.1*jnp.abs(ylo) 2123 yup = jnp.max(test_data['u']) 2124 yup = yup + 0.1*jnp.abs(yup) 2125 2126 #Open PINN for each epoch 2127 results = [] 2128 upred = [] 2129 for e in epochs: 2130 tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle') 2131 results = results + [tmp] 2132 upred = upred + [forward(test_data['xt'],tmp['params']['net'])] 2133 2134 #Create images 2135 k = 1 2136 with alive_bar(len(tdom)) as bar: 2137 for t in tdom: 2138 #Test data 2139 xt_step = test_data['xt'][test_data['xt'][:,-1] == t] 2140 u_step = test_data['u'][test_data['xt'][:,-1] == t] 2141 #Initialize plot 2142 if len(epochs) > 1: 2143 fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2)) 2144 else: 2145 fig, ax = plt.subplots(1,1,figsize = (10,5)) 2146 #Create 2147 index = 0 2148 if 
int(len(epochs)/2) > 1: 2149 for i in range(int(len(epochs)/2)): 2150 for j in range(min(2,len(epochs))): 2151 upred_step = upred[index][test_data['xt'][:,-1] == t] 2152 ax[i,j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact') 2153 ax[i,j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction') 2154 ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10) 2155 ax[i,j].set_xlabel(' ') 2156 ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 2157 index = index + 1 2158 elif len(epochs) > 1: 2159 for j in range(2): 2160 upred_step = upred[index][test_data['xt'][:,-1] == t] 2161 ax[j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact') 2162 ax[j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction') 2163 ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10) 2164 ax[j].set_xlabel(' ') 2165 ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 2166 index = index + 1 2167 else: 2168 upred_step = upred[index][test_data['xt'][:,-1] == t] 2169 ax.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact') 2170 ax.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction') 2171 ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10) 2172 ax.set_xlabel(' ') 2173 ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 2174 index = index + 1 2175 2176 2177 #Title 2178 fig.suptitle(title + 't = ' + str(round(t,4))) 2179 fig.tight_layout() 2180 2181 #Show and save 2182 fig = plt.gcf() 2183 fig.savefig(file_name_save + '/' + str(k) + '.png') 2184 k = k + 1 2185 plt.close() 2186 bar() 2187 2188 #Create demo video 2189 os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')
def assert_tree_float64(tree, name="object"):
    """
    Check that every dtype-carrying leaf of a pytree is float64.

    Parameters
    ----------
    tree : pytree
        Any JAX pytree (dicts, lists, tuples, ... of arrays). Leaves that have
        no ``dtype`` attribute (e.g. plain Python floats) are not checked.
    name : str
        Label used in the error message. Default "object".

    Raises
    ------
    RuntimeError
        If at least one leaf has a dtype other than float64. The message lists
        the flattened leaf index and dtype of every offending leaf.
    """
    leaves = tree_flatten(tree)[0]
    offenders = []
    for position, leaf in enumerate(leaves):
        # Leaves without a dtype (Python scalars, etc.) are skipped.
        if not hasattr(leaf, "dtype"):
            continue
        if leaf.dtype != jnp.float64:
            offenders.append((position, leaf.dtype))
    if not offenders:
        return
    details = "\n".join([f" leaf {i}: {dtype}" for i, dtype in offenders])
    raise RuntimeError(
        f"[FLOAT64 CHECK FAILED] {name} contains non-float64 dtypes:\n" + details
    )
Raises a RuntimeError if any pytree leaf that has a dtype is not float64 (leaves without a dtype are skipped).
def warn_tree_float64(tree, name="object"):
    """
    Print a warning when a pytree contains any non-float64 leaf.

    Non-raising counterpart of ``assert_tree_float64``: leaves without a
    ``dtype`` attribute are ignored, and the set of distinct leaf dtypes is
    printed when at least one of them differs from float64.

    Parameters
    ----------
    tree : pytree
        Any JAX pytree whose array leaves are expected to be float64.
    name : str
        Label used in the printed message. Default "object".

    Returns
    -------
    None
    """
    leaves, _ = tree_flatten(tree)
    dtypes = {x.dtype for x in leaves if hasattr(x, "dtype")}
    # Compare dtype by dtype with ``!=`` (as assert_tree_float64 does) instead of
    # ``dtypes != {jnp.float64}``: set equality matches elements by hash, and
    # ``jnp.float64`` (a JAX scalar type) does not hash like
    # ``numpy.dtype('float64')``, so the old test warned even for trees that were
    # entirely float64. Trees with no dtype-carrying leaves no longer warn either.
    if any(dtype != jnp.float64 for dtype in dtypes):
        print(f"[WARNING] {name} dtypes = {dtypes}")
Same as assert_tree_float64, but prints a warning instead of raising.
def check_grads_float64(grads):
    """
    Verify that a gradient pytree holds only float64 leaves.

    Parameters
    ----------
    grads : pytree
        Gradient pytree (e.g. the output of ``jax.grad``). Leaves without a
        ``dtype`` attribute are skipped.

    Raises
    ------
    RuntimeError
        If any leaf has a dtype other than float64. The message lists the
        flattened index and dtype of every offending leaf.
    """
    report = []
    for idx, leaf in enumerate(tree_flatten(grads)[0]):
        if hasattr(leaf, "dtype") and leaf.dtype != jnp.float64:
            report.append(f" grad {idx}: {leaf.dtype}")
    if report:
        raise RuntimeError(
            "[FLOAT64 CHECK FAILED] Gradients are not float64:\n" + "\n".join(report)
        )
def assert_lbfgs_state_float64(state):
    """
    Check that the floating-point buffers of an LBFGS solver state are float64.

    Only leaves that are JAX arrays (``jnp.ndarray``) with a floating dtype are
    inspected. Integer and boolean arrays, and leaves that are not JAX arrays,
    are allowed and ignored.

    Parameters
    ----------
    state : pytree
        Solver state, e.g. the value returned by ``LBFGS.init_state`` or
        ``LBFGS.update``.

    Raises
    ------
    RuntimeError
        If any floating-point JAX array leaf is not float64. The message lists
        the flattened leaf index and dtype of every offending leaf.
    """
    def _is_non_f64_float(leaf):
        # True only for JAX float arrays whose precision is not float64.
        return (
            isinstance(leaf, jnp.ndarray)
            and jnp.issubdtype(leaf.dtype, jnp.floating)
            and leaf.dtype != jnp.float64
        )

    flat = tree_flatten(state)[0]
    bad = [(i, leaf.dtype) for i, leaf in enumerate(flat) if _is_non_f64_float(leaf)]
    if bad:
        raise RuntimeError(
            "[FLOAT64 CHECK FAILED] LBFGS floating-point buffers are not float64:\n"
            + "\n".join([f" leaf {i}: {dtype}" for i, dtype in bad])
        )
Ensures that all floating-point values inside the LBFGS state are float64. Integers and booleans are allowed.
@jax.jit
def MSE(pred,true):
    """
    Element-wise squared error.

    Despite its name, this does not average: it returns ``(true - pred) ** 2``
    for each element, and callers reduce it themselves (e.g. with ``jnp.mean``).

    Parameters
    ----------
    pred : jax.numpy.array
        Predicted values.
    true : jax.numpy.array
        True values. Anything broadcastable against ``pred`` works, including a
        scalar such as 0.

    Returns
    -------
    jax.numpy.array
        Squared differences with the broadcast shape of the inputs.
    """
    residual = true - pred
    return jnp.square(residual)
Element-wise squared error
Parameters
- pred (jax.numpy.array): A JAX numpy array with the predicted values
- true (jax.numpy.array): A JAX numpy array with the true values
Returns
- squared error
@jax.jit
def MSE_SA(pred,true,w,q = 2):
    """
    Self-adaptive element-wise squared error.

    Computes ``(w ** q) * (true - pred) ** 2`` element by element, i.e. the
    squared error masked by a polynomial function of the weights.

    Parameters
    ----------
    pred : jax.numpy.array
        Predicted values.
    true : jax.numpy.array
        True values.
    w : jax.numpy.array
        Self-adaptive weights, broadcast against the squared error.
    q : float
        Exponent of the polynomial weight mask. Default 2.

    Returns
    -------
    jax.numpy.array
        Weighted squared differences.
    """
    mask = w ** q
    return mask * jnp.square(true - pred)
Self-adaptive squared error
Parameters
- pred (jax.numpy.array): A JAX numpy array with the predicted values
- true (jax.numpy.array): A JAX numpy array with the true values
- w (jax.numpy.array): A JAX numpy array with the self-adaptive weights
- q (float): Power for the weights mask
Returns
- self-adaptive squared error with polynomial weight mask
@jax.jit
def L2error(pred,true):
    """
    Relative L2 error, in percent.

    Computes ``100 * ||true - pred||_2 / ||true||_2`` over all elements. There
    is no guard for an all-zero ``true``; that case divides by zero.

    Parameters
    ----------
    pred : jax.numpy.array
        Predicted values.
    true : jax.numpy.array
        True (reference) values.

    Returns
    -------
    jax.numpy.array
        Scalar relative L2 error in percent.
    """
    error_norm = jnp.sqrt(jnp.sum((true - pred) ** 2))
    reference_norm = jnp.sqrt(jnp.sum(true ** 2))
    return 100 * error_norm / reference_norm
Relative L2 error in percentage (%)
Parameters
- pred (jax.numpy.array): A JAX numpy array with the predicted values
- true (jax.numpy.array): A JAX numpy array with the true values
Returns
- L2-error
def idst1(x,axis = -1):
    """
    Orthonormal inverse Discrete Sine Transform of type I along one axis.

    Thin wrapper around ``scipy.fft.idst`` with ``type=1`` and
    ``norm='ortho'``.

    Parameters
    ----------
    x : array_like
        Array to transform.
    axis : int
        Axis along which the transform is applied. Default -1 (last axis).

    Returns
    -------
    numpy.ndarray
        The transformed array, as returned by ``scipy.fft.idst``.
    """
    options = {"type": 1, "norm": "ortho"}
    return idst(x, axis=axis, **options)
Inverse Discrete Sine Transform of type I with orthonormal scaling
Parameters
- x (jax.numpy.array): Array to apply the transformation
- axis (int): Axis to apply the transformation over
Returns
- jax.numpy.array
def dstn(x,axes = None):
    """
    Orthonormal Discrete Sine Transform of type I over several axes.

    Applies ``scipy.fft.dst(..., type=1, norm='ortho')`` successively along
    each requested axis.

    Parameters
    ----------
    x : array_like
        Array to transform.
    axes : int or iterable of int, optional
        Axis or axes to transform. ``None`` (default) transforms every axis.
        A single integer is accepted; previously it raised ``TypeError``
        even though the docstring documented ``int``.

    Returns
    -------
    numpy.ndarray
        The transformed array.
    """
    if axes is None:
        axes = tuple(range(np.ndim(x)))
    elif isinstance(axes, (int, np.integer)):
        # Wrap a lone axis so the loop below can iterate over it.
        axes = (int(axes),)
    y = x
    for ax in axes:
        y = dst(y, type = 1, axis = ax, norm = 'ortho')
    return y
Discrete Sine Transform of type I with orthonormal scaling over many axes
Parameters
- x (jax.numpy.array): Array to apply the transformation
- axes (tuple of int, optional): Axes to apply the transformation over. Default None applies it over all axes
Returns
- jax.numpy.array
def idstn(x,axes = None):
    """
    Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes
    ----------
    Parameters
    ----------
    x : jax.numpy.array

        Array to apply the transformation

    axes : int or iterable of int, optional

        Axes to apply the transformation over. A single int is treated as a
        one-element tuple. Default None applies the transform over all axes

    Returns
    -------
    Array with the inverse transform applied (computed with scipy.fft, so a NumPy array)
    """
    if axes is None:
        axes = tuple(range(np.ndim(x)))
    elif isinstance(axes, (int, np.integer)):
        axes = (axes,)
    y = x
    # Same per-axis operation as idst1: inverse DST-I with orthonormal scaling
    for ax in axes:
        y = idst(y, type=1, axis=ax, norm='ortho')
    return y
Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes
Parameters
- x (jax.numpy.array): Array to apply the transformation
- axes (tuple of int, optional): Axes to apply the transformation over. Default None applies it over all axes
Returns
- jax.numpy.array
def dirichlet_eigs_nd(n,L):
    """
    Eigenvalues of the discrete Dirichlet-Laplace operator on a rectangle
    ----------
    Parameters
    ----------
    n : list

        Number of interior grid points in each dimension

    L : list

        Upper limit of the interval in each dimension; the lower limit is zero

    Returns
    -------
    jax.numpy.array with shape tuple(n), where entry (k_1, ..., k_d) is the sum
    of the one-dimensional eigenvalues of modes k_1, ..., k_d
    """
    def _axis_eigenvalues(num_points, length):
        # 1-D eigenvalues of the second-difference matrix with spacing h = L / (n + 1)
        spacing = length / (num_points + 1)
        modes = jnp.arange(1, num_points + 1)
        return (2 / (spacing * spacing)) * (1 - jnp.cos(jnp.pi * modes / (num_points + 1)))

    per_axis = [_axis_eigenvalues(num_points, length) for num_points, length in zip(n, L)]
    mesh = jnp.meshgrid(*per_axis, indexing='ij')
    # The operator is separable, so the n-D eigenvalues are sums of the 1-D ones
    total = jnp.zeros_like(mesh[0])
    for component in mesh:
        total = total + component
    return total
Eigenvalues of the discrete Dirichlet-Laplace operator in a rectangle
Parameters
- n (list): List with the number of points in the grid in each dimension
- L (list): List with the upper limit of the interval of the domain in each dimension. Assumed the lower limit is zero
Returns
- jax.numpy.array
def generate_matern_sample(key,d = 2,N = 128,L = 1.0,kappa = 1,alpha = 1,sigma = 1,periodic = False):
    """
    Sample d-dimensional Matern process
    ----------
    Parameters
    ----------
    key : int

        Seed for randomization

    d : int

        Dimension. Default 2

    N : int or list of int

        Size of grid in each dimension. If an int, the same size is used for all dimensions. Default 128

    L : float or list of float

        The domain of the function in coordinate i is [0, L[i]]. If a float, the same interval is used for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process

    periodic : logical

        Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT

    Returns
    -------
    jax.numpy.array with shape tuple(N)
    """
    # Broadcast scalar sizes to one entry per dimension BEFORE deriving the grid
    # shape; previously the periodic branch built (N,) * d from the raw N, which
    # produced an invalid shape whenever N was given as a list.
    if isinstance(L, (int, float)):
        L = d * [L]
    if isinstance(N, (int, float)):
        N = d * [N]
    N = [int(n_j) for n_j in N]
    shape = tuple(N)

    if periodic:
        prng_key = jax.random.PRNGKey(key)

        # Angular frequency grid
        freq = [jnp.fft.fftfreq(N[j], d = L[j] / N[j]) * 2 * jnp.pi for j in range(d)]
        grids = jnp.meshgrid(*freq, indexing='ij')
        sq_norm_xi = sum(g ** 2 for g in grids)

        # Complex white noise in Fourier space
        key_re, key_im = jax.random.split(prng_key)
        white_noise_f = (jax.random.normal(key_re, shape) +
                         1j * jax.random.normal(key_im, shape))

        # Whittle filter (kappa^2 + |xi|^2)^(-alpha/2)
        amplitude_filter = (kappa ** 2 + sq_norm_xi) ** (-alpha / 2)
        field_f = white_noise_f * amplitude_filter

        # Back to physical space
        sample = jnp.real(jnp.fft.ifftn(field_f))
        return sigma * sample

    # NOT JAX: NumPy RNG and SciPy transforms
    rng = np.random.default_rng(seed = key)

    # White noise in real space
    W = rng.standard_normal(size = shape)

    # To Dirichlet eigenbasis via separable DST-I (orthonormal)
    W_hat = dstn(W)

    # Discrete Dirichlet Laplacian eigenvalues
    lam = dirichlet_eigs_nd(N, L)

    # Spectral filter
    # NOTE(review): the periodic branch uses kappa**2 here while this branch uses
    # kappa; kept unchanged to preserve existing samples -- confirm which
    # parameterization is intended.
    filt = ((kappa + lam) ** (-alpha / 2))
    psi_hat = filt * W_hat

    # Back to real space
    psi = idstn(psi_hat)
    return jnp.array(sigma * psi)
Sample d-dimensional Matern process
Parameters
- key (int): Seed for randomization
- d (int): Dimension. Default 2
- N (int or list of int): Size of the grid in each dimension. If an int, the same size is used for all dimensions. Default 128
- L (float or list of float): The domain of the function in coordinate i is [0, L[i]]. If a float, the same interval is used for all coordinates. Default 1
- kappa,alpha,sigma (float): Parameters of the Matern process
- periodic (logical): Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT
Returns
- jax.numpy.array
def generate_matern_sample_batch(d = 2,N = 512,L = 1.0,kappa = 10.0,alpha = 1,sigma = 10,periodic = False):
    """
    Create function to sample d-dimensional Matern process in batches
    ----------
    Parameters
    ----------
    d : int

        Dimension. Default 2

    N : int

        Size of grid in each dimension. Default 512

    L : list of float

        The domain of the function in coordinate i is [0, L[i]]. If a float, the same interval is used for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process. Defaults 10, 1 and 10

    periodic : logical

        Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT

    Returns
    -------
    function mapping an array of B integer seeds to the B samples stacked along the first axis
    """
    def _single(k):
        return generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)

    if periodic:
        # Pure JAX path: vectorise over the seeds
        return jax.vmap(_single)

    def _batch(keys):
        # One seed per row. Each row is a length-1 NumPy array, exactly what
        # np.apply_along_axis passed before, so the generated samples are unchanged.
        seeds = np.asarray(keys).reshape((keys.shape[0], 1))
        return jnp.array(np.stack([np.asarray(_single(row)) for row in seeds]))

    return _batch
Create function to sample d-dimensional Matern process
Parameters
- d (int): Dimension. Default 2
- N (int): Size of the grid in each dimension. Default 512
- L (list of float): The domain of the function in coordinate i is [0, L[i]]. If a float, the same interval is used for all coordinates. Default 1
- kappa,alpha,sigma (float): Parameters of the Matern process
- periodic (logical): Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT
Returns
- function
def eigenf_laplace(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None):
    """
    Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace.
    ----------
    Parameters
    ----------
    L_vec : list of float

        The domain of the function in coordinate i is [0, L_vec[i]]

    kmax_per_axis : list of int

        Largest mode index per dimension. Required despite the None default

    bc : str

        Boundary condition. 'dirichlet' or 'neumann' (only the first letter is checked, case-insensitively)

    max_ef : int

        Number of eigenfunctions to keep, taking those with the smallest eigenvalues. If None, keeps d * max(kmax_per_axis) eigenfunctions

    Returns
    -------
    function to compute eigenfunctions,eigenvalues of the eigenfunctions considered

    Raises
    ------
    ValueError

        If kmax_per_axis is None or bc is not recognised
    """
    L_vec = jnp.asarray(L_vec)
    d = L_vec.shape[0]
    bc = bc.lower()

    # Previously an unrecognised bc left axis_ranges undefined (NameError)
    if bc.startswith("d"):
        is_dirichlet = True
        first_mode = 1
    elif bc.startswith("n"):
        is_dirichlet = False
        first_mode = 0
    else:
        raise ValueError(f"Unknown boundary condition '{bc}'; expected 'dirichlet' or 'neumann'")
    if kmax_per_axis is None:
        raise ValueError("kmax_per_axis must be provided")

    kmax_per_axis = [int(km) for km in kmax_per_axis]
    if max_ef is None:
        max_ef = d * max(kmax_per_axis)
    max_ef = int(max_ef)

    # All candidate multi-indices
    axis_ranges = [range(first_mode, km + 1) for km in kmax_per_axis]
    Ks = jnp.array(list(product(*axis_ranges)))

    # Eigenvalues of the continuous Laplacian, sorted ascending; keep the first max_ef
    pi_over_L = jnp.pi / L_vec
    lambdas_all = jnp.sum((Ks * pi_over_L) ** 2, axis=1)
    order = jnp.argsort(lambdas_all)
    Ks = Ks[order][:max_ef]
    lambdas = lambdas_all[order][:max_ef]
    m = Ks.shape[0]

    # Closed-form L2-normalisation constants (product over axes)
    if is_dirichlet:
        norm_factors = jnp.ones((m,)) * jnp.prod(jnp.sqrt(2 / L_vec))
    else:
        # Neumann: the constant mode k_i = 0 on an axis has factor sqrt(1/L_i)
        axis_factors = jnp.where(Ks == 0, jnp.sqrt(1 / L_vec), jnp.sqrt(2 / L_vec))
        norm_factors = jnp.prod(axis_factors, axis=1)

    @jax.jit
    def phi(x):
        # x[..., i] is coordinate i; output has shape x.shape[:-1] + (m,)
        x = jnp.asarray(x)
        vals = jnp.ones(x.shape[:-1] + (m,))
        for i in range(d):
            ang = (jnp.pi / L_vec[i]) * x[..., i][..., None] * Ks[:, i]
            vals = vals * (jnp.sin(ang) if is_dirichlet else jnp.cos(ang))
        # Broadcasting over the trailing (m,) axis covers batched and single inputs
        return vals * norm_factors

    return phi, lambdas
Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace.
Parameters
- L_vec (list of float): The domain of the function in coordinate i is [0, L_vec[i]]
- kmax_per_axis (list): List with the maximum number of eigenfunctions per dimension. Consider d * max(kmax_per_axis) eigenfunctions
- bc (str): Boundary condition. 'dirichlet' or 'neumann'
- max_ef (int): Number of eigenfunctions to keep, taking those with the smallest eigenvalues. If None, keeps d * max(kmax_per_axis) eigenfunctions
Returns
- function to compute eigenfunctions,eigenvalues of the eigenfunctions considered
def multiple_daff(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None):
    """
    Create function to compute several frequency bands of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. Each band corresponds to a different domain.
    ----------
    Parameters
    ----------
    L_vec : list of lists of float

        One entry per band; entry i lists the upper limits of the domain [0, L_vec[i][j]] in each dimension j

    kmax_per_axis : list

        List with the maximum mode index per dimension.

    bc : str

        Boundary condition. 'dirichlet' or 'neumann'

    max_ef : int

        Number of eigenfunctions to keep per band, taking those with the smallest eigenvalues. If None, keeps d * max(kmax_per_axis) eigenfunctions

    Returns
    -------
    function to compute daff (features of all bands concatenated along the last axis),eigenvalues of the eigenfunctions considered

    Raises
    ------
    ValueError

        If L_vec is empty
    """
    if len(L_vec) == 0:
        raise ValueError("L_vec must contain at least one domain")
    feature_fns = []
    eigenvalues = []
    for L in L_vec:
        fn, lam = eigenf_laplace(L, kmax_per_axis, bc, max_ef)
        feature_fns.append(fn)
        eigenvalues.append(lam)

    @jax.jit
    def mff(x):
        feats = [fn(x) for fn in feature_fns]
        if len(feats) == 1:
            return feats[0]
        # Last axis is the feature axis: identical to axis 1 for (batch, d)
        # inputs and also works for a single un-batched point
        return jnp.concatenate(feats, axis=-1)

    return mff, jnp.concatenate(eigenvalues)
Create function to compute several frequency bands of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. Each band corresponds to a different domain.
Parameters
- L_vec (list of lists of float): One entry per frequency band; entry i lists the upper limits of the domain [0, L[i][j]] in each dimension j
- kmax_per_axis (list): List with the maximum number of eigenfunctions per dimension.
- bc (str): Boundary condition. 'dirichlet' or 'neumann'
- max_ef (int): Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions
Returns
- function to compute daff,eigenvalues of the eigenfunctions considered
@partial(jax.jit, static_argnums=(2,))  # n is static: one compilation per n
def multiple_cheb_fast(x, L_vec, n: int):
    """
    Chebyshev-difference features on several rectangular domains.

    x: (N, d) array of points
    L_vec: (L, d) array with the upper endpoints b of each domain (lower endpoint a is 0)
    n: number of terms phi_k = T_{k+2} - T_k, k = 0..n-1 (static)
    returns: (N, L*n)
    """
    num_points, _ = x.shape
    num_domains = L_vec.shape[0]

    # Lower endpoint is 0, so the affine map [0, b] -> [-1, 1] is t = (2x - b) / b
    upper = 0 + L_vec[:, None, :]  # (L, 1, d)
    t = (2 * x[None, :, :] - upper) / upper  # (L, N, d)

    # T_0 .. T_{n+2} stacked on the leading axis: (n+3, L, N, d)
    cheb = _chebyshev_T_all(t, n + 2)

    # phi_k = T_{k+2} - T_k for k = 0..n-1 -> (n, L, N, d)
    basis = cheb[2:n + 2] - cheb[:n]

    # Tensor product across dimensions -> (n, L, N)
    per_point = jnp.prod(basis, axis=-1)

    # Reverse axes to (N, L, n), then flatten the last two
    return per_point.T.reshape(num_points, num_domains * n)
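x: (N, d) array of points
L_vec: (L, d) array with the upper endpoints b of each domain (the lower endpoint a is 0) in each dimension
n: number of terms phi_k = T_{k+2} - T_k per domain (static)
returns: (N, L*n) array of features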
def multiple_cheb(L_vec, n: int):
    """
    Factory that closes over static n and L_vec (so shapes are constant).

    Returns a jitted function x -> multiple_cheb_fast(x, L_vec, n).
    """
    endpoints = jnp.asarray(L_vec)

    def mcheb(x):
        return multiple_cheb_fast(jnp.asarray(x), endpoints, n)

    # Optional: multiple_cheb_fast is already jitted
    return jax.jit(mcheb)
Factory that closes over static n and L_vec (so shapes are constant).
def fconNN(width,activation = jax.nn.tanh,key = 0,mlp = False,ftype = None,fargs = None,static = None,daff = None):
    """
    Initialize fully connected neural network
    ----------
    Parameters
    ----------
    width : list

        List with the layers width. For feature transformations, width[1] is the feature size parameter and is replaced by the actual number of features

    activation : jax.nn activation

        The activation function. Default jax.nn.tanh

    key : int

        Seed for parameters initialization. Default 0

    mlp : logical

        Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension.

    ftype : str

        Type of feature transformation to use: None, 'ff', 'daff', 'daff_bias', 'cheb', 'cheb_bias'. With 'daff' and 'cheb' the layers have no bias

    fargs : list or dict

        Arguments for the feature transformation:

        For 'ff': a list with the number of frequency scales and the largest frequency standard deviation.

        For 'daff' and 'daff_bias': a dictionary whose first value is the list of rectangle sizes and whose second value is the boundary condition. If a list is given, the boundary condition is dirichlet.

        For 'cheb' and 'cheb_bias': the list of domain upper endpoints passed to multiple_cheb.

    static : function

        A static function to sum to the neural network output. Default None adds 0

    daff : list

        Function to compute the daff features and the number of features. If None, they are computed assuming a rectangular domain.

    Returns
    -------
    dict with initial parameters ('params') and the function for the forward pass ('forward')
    """
    initializer = jax.nn.initializers.glorot_normal()
    params = []
    if static is None:
        static = lambda x: 0

    # ---------- feature mapping ----------
    if ftype == 'ff':
        # Random Fourier features, one block per scale with geometrically growing std
        n_scales, max_sd = fargs[0], fargs[1]
        half = int(width[1] / 2)
        blocks = [
            (max_sd ** ((s + 1) / n_scales)) * jax.random.normal(jax.random.PRNGKey(key + s + 1), (width[0], half))
            for s in range(n_scales)
        ]
        Bff = jnp.concatenate(blocks, axis=1)

        @jax.jit
        def phi(x):
            proj = x @ Bff
            return jnp.concatenate([jnp.sin(2 * jnp.pi * proj), jnp.cos(2 * jnp.pi * proj)], axis = -1)

        width = width[1:]
        width[0] = 2 * Bff.shape[1]
    elif ftype in ('daff', 'daff_bias'):
        if not isinstance(fargs, dict):
            fargs = {'L': fargs, 'bc': "dirichlet"}
        if daff is None:
            farg_values = list(fargs.values())
            phi, lamb = multiple_daff(farg_values[0], kmax_per_axis = [width[1]] * width[0], bc = farg_values[1])
            width = width[1:]
            width[0] = lamb.shape[0]
        else:
            phi = daff[0]
            width = width[1:]
            width[0] = daff[1]
    elif ftype in ('cheb', 'cheb_bias'):
        phi = multiple_cheb(fargs, n = width[1])
        width = width[1:]
        width[0] = len(fargs) * width[0]
    else:
        @jax.jit
        def phi(x):
            return x

    # ---------- parameters ----------
    if mlp:
        enc_keys = jax.random.split(jax.random.PRNGKey(key), 4)
        params.append({
            'WU': initializer(enc_keys[0], (width[0], width[1])),
            'BU': initializer(enc_keys[1], (1, width[1])),
            'WV': initializer(enc_keys[2], (width[0], width[1])),
            'BV': initializer(enc_keys[3], (1, width[1])),
        })
    layer_keys = jax.random.split(jax.random.PRNGKey(key + 1), len(width) - 1)
    for layer_key, fan_in, fan_out in zip(layer_keys, width[:-1], width[1:]):
        # NOTE(review): W and B are drawn from the same PRNG key; kept so that
        # existing seeds reproduce exactly the same initialisation.
        params.append({'W': initializer(layer_key, (fan_in, fan_out)),
                       'B': initializer(layer_key, (1, fan_out))})

    # ---------- forward pass ----------
    # Only plain 'daff' and 'cheb' drop the biases ('*_bias' variants keep them)
    use_bias = ftype not in ('daff', 'cheb')

    def _affine(h, layer, wname, bname):
        out = h @ layer[wname]
        return out + layer[bname] if use_bias else out

    @jax.jit
    def forward(x, params):
        sx = static(x)
        h = phi(x)
        if mlp:
            encode, *hidden, output = params
            U = activation(_affine(h, encode, 'WU', 'BU'))
            V = activation(_affine(h, encode, 'WV', 'BV'))
        else:
            *hidden, output = params
        for layer in hidden:
            h = activation(_affine(h, layer, 'W', 'B'))
            if mlp:
                # Modified MLP gating between the two encoder branches
                h = h * U + (1 - h) * V
        return _affine(h, output, 'W', 'B') + sx

    return {'params': params, 'forward': forward}
Initialize fully connected neural network
Parameters
- width (list): List with the layers width
- activation (jax.nn activation): The activation function. Default jax.nn.tanh
- key (int): Seed for parameters initialization. Default 0
- mlp (logical): Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension.
- ftype (str): Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.
- fargs (list or dict): Arguments for the feature transformation:
For 'ff': A list with the number of frequency scales and the value of the largest frequency standard deviation.
For 'daff' and 'daff_bias': A dictionary with a list with the size of the rectangles and the type of boundary condition. If it is a list, then the boundary condition is dirichlet.
- static (function): A static function to sum to the neural network output.
- daff (list): List with the function to compute the daff features and the number of features. If None, they are computed assuming a rectangular domain.
Returns
- dict with initial parameters and the function for the forward pass
def get_activation(act = 'tanh'):
    """
    Return activation function from string
    ----------
    Parameters
    ----------
    act : str

        Name of the activation function, identical to its name in jax.nn. Default 'tanh'

    Returns
    -------
    jax.nn activation function

    Raises
    ------
    ValueError

        If act is not a supported activation name (previously None was returned silently)
    """
    supported = (
        'tanh', 'relu', 'relu6', 'sigmoid', 'softplus', 'sparse_plus', 'soft_sign',
        'silu', 'swish', 'log_sigmoid', 'leaky_relu', 'hard_sigmoid', 'hard_silu',
        'hard_swish', 'hard_tanh', 'elu', 'celu', 'selu', 'gelu', 'glu',
        'squareplus', 'mish',
    )
    if act not in supported:
        raise ValueError(f"Unknown activation '{act}'; expected one of {', '.join(supported)}")
    # Looked up lazily so activations missing from older jax releases only fail when requested
    return getattr(jax.nn, act)
Return activation function from string
Parameters
- act (str): Name of the activation function, e.g. 'tanh' or 'relu'
Returns
- jax.nn activation function
def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,mlp = False):
    """
    Train a Physics-informed Neural Network
    ----------
    Parameters
    ----------
    data : dict

        Data generated by the jinnax.data.generate_PINNdata function

    width : list

        A list with the width of each layer

    pde : function

        The partial differential operator. Called as pde(u, x, t), or pde(u, x, t, par) when inverse is True

    test_data : dict, None

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error

    epochs : int

        Number of training epochs. Default 100

    at_each : int

        Save results for epochs multiple of at_each. Default 10

    activation : str

        The name of the activation function of the neural network. Default 'tanh'

    neumann : logical

        Whether to consider Neumann boundary conditions

    oper_neumann : function

        Penalization of Neumann boundary conditions. Called as oper_neumann(u, x, t), with the boundary weights as a fourth argument when sa is True

    sa : logical

        Whether to consider self-adaptative PINN

    c : dict

        Scale of the initial self-adaptative weights for the sensor (ws), collocation (wr), initial (w0) and boundary (wb) points; each weight is drawn as c[...] times a uniform random number in [0, 1). Not modified by this function

    inverse : logical

        Whether to estimate parameters of the PDE

    initial_par : jax.numpy.array

        Initial value of the parameters of the PDE in an inverse problem

    lr,b1,b2,eps,eps_root: float

        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0

    key : int

        Seed for parameters initialization. Default 0

    epoch_print : int

        Number of epochs to calculate and print test errors. Default 100

    save : logical

        Whether to save the current parameters. Default False

    file_name : str

        File prefix to save the current parameters. Default 'result_pinn'

    exp_decay : logical

        Whether to consider exponential decay of learning rate. Default False

    transition_steps : int

        Number of steps for exponential decay. Default 1000

    decay_rate : float

        Rate of exponential decay. Default 0.9

    mlp : logical

        Whether to consider modified multi-layer perceptron

    Returns
    -------
    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
    """

    def _dump(obj, path):
        # Context manager closes the file even if pickling fails (the previous
        # pickle.dump(obj, open(...)) leaked one handle per save)
        with open(path, 'wb') as fh:
            pickle.dump(obj, fh, protocol = pickle.HIGHEST_PROTOCOL)

    def _split_xt(arr):
        # Last column is time; the remaining columns are space
        n_rows, n_cols = arr.shape[0], arr.shape[1]
        return arr[:, :-1].reshape((n_rows, n_cols - 1)), arr[:, -1].reshape((n_rows, 1))

    #Initialize architecture
    nnet = fconNN(width, get_activation(activation), key, mlp)
    forward = nnet['forward']

    #Initialize self adaptative weights
    par_sa = {}
    if sa:
        #Initialize weights close to zero
        ksa = jax.random.randint(jax.random.PRNGKey(key), (5,), 1, 1000000)
        if data['sensor'] is not None:
            par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]), shape = (data['sensor'].shape[0], 1))})
        if data['initial'] is not None:
            par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]), shape = (data['initial'].shape[0], 1))})
        if data['collocation'] is not None:
            par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]), shape = (data['collocation'].shape[0], 1))})
        if data['boundary'] is not None:
            par_sa.update({'wb': c['wb'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]), shape = (data['boundary'].shape[0], 1))})

    #Store all parameters
    params = {'net': nnet['params'], 'inverse': initial_par, 'sa': par_sa}

    #Save config file
    if save:
        _dump({'train_data': data, 'epochs': epochs, 'activation': activation, 'init_params': params, 'forward': forward, 'width': width, 'pde': pde, 'lr': lr, 'b1': b1, 'b2': b2, 'eps': eps, 'eps_root': eps_root, 'key': key, 'inverse': inverse, 'sa': sa}, file_name + '_config.pickle')

    #Define loss function (self-adaptative or plain, selected at trace time)
    @jax.jit
    def lf(params, x):
        def u(x, t):
            return forward(jnp.append(x, t, 1), params['net'])

        def err(pred, true, wname):
            if sa:
                return MSE_SA(pred, true, params['sa'][wname])
            return MSE(pred, true)

        loss = 0
        if x['sensor'] is not None:
            #Term that refers to sensor data
            loss = loss + jnp.mean(err(forward(x['sensor'], params['net']), x['usensor'], 'ws'))
        if x['boundary'] is not None:
            if neumann:
                #Neumann conditions
                xb, tb = _split_xt(x['boundary'])
                extra = (params['sa']['wb'],) if sa else ()
                loss = loss + jnp.mean(oper_neumann(u, xb, tb, *extra))
            else:
                #Term that refers to boundary data
                loss = loss + jnp.mean(err(forward(x['boundary'], params['net']), x['uboundary'], 'wb'))
        if x['initial'] is not None:
            #Term that refers to initial data
            loss = loss + jnp.mean(err(forward(x['initial'], params['net']), x['uinitial'], 'w0'))
        if x['collocation'] is not None:
            #Term that refers to collocation points
            x_col, t_col = _split_xt(x['collocation'])
            pde_args = (params['inverse'],) if inverse else ()
            residual = pde(u, x_col, t_col, *pde_args)
            loss = loss + jnp.mean(err(residual, 0, 'wr'))
        return loss

    #Initialize Adam Optimizer
    if exp_decay:
        lr = optax.exponential_decay(lr, transition_steps, decay_rate)
    optimizer = optax.adam(lr, b1, b2, eps, eps_root)
    opt_state = optimizer.init(params)

    #Define the gradient function
    grad_loss = jax.jit(jax.grad(lf, 0))

    #Define update function
    @jax.jit
    def update(opt_state, params, x):
        grads = grad_loss(params, x)
        #Gradient ascent on the self-adaptative weights
        if sa:
            for w in grads['sa']:
                grads['sa'][w] = - grads['sa'][w]
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return opt_state, params

    ###Training###
    t0 = time.time()
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            opt_state, params = update(opt_state, params, data)
            if e % epoch_print == 0:
                #Compute elapsed time and current error
                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params, data), 6))
                if test_data is not None:
                    l2_test = L2error(forward(test_data['xt'], params['net']), test_data['u']).tolist()
                    l = l + ' L2 error: ' + str(jnp.round(l2_test, 3))
                if inverse:
                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(), 6))
                print(l)
            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
                #Save current parameters
                _dump({'params': params, 'width': width, 'time': time.time() - t0, 'loss': lf(params, data)}, file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
            bar()

    #Define estimated function
    def u(xt):
        return forward(xt, params['net'])

    return {'u': u, 'params': params, 'forward': forward, 'time': time.time() - t0}
Train a Physics-informed Neural Network
Parameters
- data (dict): Data generated by the jinnax.data.generate_PINNdata function
- width (list): A list with the width of each layer
- pde (function): The partial differential operator. Its arguments are u, x and t, plus the PDE parameters when inverse is True
- test_data (dict, None): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
- epochs (int): Number of training epochs. Default 100
- at_each (int): Save results for epochs multiple of at_each. Default 10
- activation (str): The name of the activation function of the neural network. Default 'tanh'
- neumann (logical): Whether to consider Neumann boundary conditions
- oper_neumann (function): Penalization of Neumann boundary conditions
- sa (logical): Whether to consider self-adaptative PINN
- c (dict): Dictionary with the scale of the initial self-adaptive weights for the sensor (ws), collocation (wr), initial (w0) and boundary (wb) points. Each weight is initialized as c[...] times a uniform random number in [0, 1)
- inverse (logical): Whether to estimate parameters of the PDE
- initial_par (jax.numpy.array): Initial value of the parameters of the PDE in an inverse problem
- lr,b1,b2,eps,eps_root (float): Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
- key (int): Seed for parameters initialization. Default 0
- epoch_print (int): Number of epochs to calculate and print test errors. Default 100
- save (logical): Whether to save the current parameters. Default False
- file_name (str): File prefix to save the current parameters. Default 'result_pinn'
- exp_decay (logical): Whether to consider exponential decay of learning rate. Default False
- transition_steps (int): Number of steps for exponential decay. Default 1000
- decay_rate (float): Rate of exponential decay. Default 0.9
- mlp (logical): Whether to use the modified multi-layer perceptron
Returns
- dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
def train_Matern_PINN(data,width,pde,test_data = None,params = None,d = 2,N = 128,L = 1,alpha = 1,kappa = 1,sigma = 100,bsize = 1024,resample = False,epochs = 100,at_each = 10,activation = 'tanh',
                      neumann = False,oper_neumann = None,inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 1,save = False,file_name = 'result_pinn',
                      exp_decay = True,transition_steps = 100,decay_rate = 0.9,mlp = True,ftype = None,fargs = None,q = 4,w = None,periodic = False,static = None,opt = 'LBFGS',float64 = False,restart = None):
    """
    Train a Physics-informed Neural Network whose PDE residual is penalized in weak form against
    random test functions sampled from a Matern process (or in strong form when sigma == 0)
    ----------
    Parameters
    ----------
    data : dict

        Data generated by the jinnax.data.generate_PINNdata function. The entries 'sensor', 'boundary',
        'initial' and 'collocation' (with their targets 'usensor', 'uboundary', 'uinitial') may be None

    width : list

        A list with the width of each layer

    pde : function

        The partial differential operator. Called as pde(u, x), or pde(u, x, par) when inverse = True,
        where u is the network function and x an array of points

    test_data : dict, None

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function.
        The L2 error is computed on test_data['sensor'] against test_data['usensor']. Default None for not calculating L2 error

    params : list

        Initial parameters for the neural network. Default None to initialize randomly

    d : int

        Dimension of the problem including the time variable if present. Default 2

    N : int

        Size of grid in each dimension. Default 128

    L : list of float

        The domain of coordinate i is [0, L[i]]. If a float, the same interval is used for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process. sigma > 0 uses the weak residual on a regular grid; sigma == 0 uses the
        strong residual on data['collocation']. When sigma > 0 and data['boundary'] is given, sigma is recalibrated so
        that the initial weak residual loss equals the initial boundary loss

    bsize : int

        Batch size (number of test functions) for weak norm computation. Default 1024

    resample : logical

        Whether to resample the test functions at each epoch. With opt = 'LBFGS' the objective must be deterministic,
        so the test functions are drawn once from a fixed key

    epochs : int

        Number of training epochs. Default 100

    at_each : int

        Save results for epochs multiple of at_each. Default 10

    activation : str

        The name of the activation function of the neural network. Default 'tanh'

    neumann : logical

        Whether to consider Neumann boundary conditions

    oper_neumann : function

        Penalization of Neumann boundary conditions, called as oper_neumann(u, boundary_points)

    inverse : logical

        Whether to estimate parameters of the PDE

    initial_par : jax.numpy.array

        Initial value of the parameters of the PDE in an inverse problem

    lr,b1,b2,eps,eps_root: float

        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0

    key : int

        Seed for parameters initialization (offsets of it also seed the test functions and self-adaptive weights). Default 0

    epoch_print : int

        Number of epochs to calculate and print test errors. Default 1

    save : logical

        Whether to save the current parameters. Default False

    file_name : str

        File prefix to save the current parameters. Default 'result_pinn'

    exp_decay : logical

        Whether to consider exponential decay of learning rate. Default True

    transition_steps : int

        Number of steps for exponential decay. Default 100

    decay_rate : float

        Rate of exponential decay. Default 0.9

    mlp : logical

        Whether to use the modified multilayer perceptron

    ftype : str

        Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.

    fargs : list

        Arguments for the feature transformation:

        For 'ff': A list with the number of frequencies and the value of the largest frequency standard deviation.

        For 'daff' and 'daff_bias': A dictionary with a list with the size of rectangles and the type of boundary condition. If it is a list, then the boundary condition is Dirichlet.

    q : int

        Power of weights mask. Default 4

    w : dict

        Initial weights for the self-adaptive scheme, with keys 'ws', 'wb', 'wi', 'wc' and 'wc_weak'.
        The caller's dictionary is not modified

    periodic : logical

        Whether to consider periodic test functions. Default False.

    static : function

        A static function to sum to the neural network output.

    opt : str

        Optimizer. 'LBFGS' uses L-BFGS with fixed weights w; any other value uses Adam with trainable
        self-adaptive weights. Default LBFGS.

    float64 : logical

        Whether to train with float64. Requires JAX to run with x64 enabled

    restart : int

        Epochs to restart L-BFGS

    Returns
    -------
    dict with the estimated function ('u'), the estimated parameters ('params'), the neural network function for the
    forward pass ('forward'), the total training time ('time'), the individual loss terms at the final parameters
    ('loss_each'), and the loss ('loss') and L2 error ('L2error') at each epoch
    """
    #Initialize architecture
    nnet = fconNN(width,get_activation(activation),key,mlp,ftype,fargs,static)
    if float64:
        #Explicit check instead of a bare assert so it is not stripped under `python -O`
        if not jax.config.jax_enable_x64:
            raise AssertionError("JAX is NOT running in float64 mode!")

        def forward(x,params):
            return nnet['forward'](x,params).astype(jnp.float64)
    else:
        forward = nnet['forward']
    if params is not None:
        nnet['params'] = params
    if float64:
        data = to_float64(data)
        if test_data is not None:
            test_data = to_float64(test_data)

    #Fix the residual mode from the user-supplied sigma. sigma may be recalibrated below, and the
    #calibrated value (which can be 0 or NaN) must not silently switch between weak and strong form
    weak_form = sigma > 0
    strong_form = sigma == 0

    #Generate from Matern process
    gen = None
    if weak_form:
        if isinstance(L,(float,int)):
            L = d*[L]
        #Grid for weak norm
        if float64:
            grid = [jnp.linspace(0,L[i],N,dtype = jnp.float64) for i in range(d)]
        else:
            grid = [jnp.linspace(0,L[i],N) for i in range(d)]
        grid = jnp.meshgrid(*grid, indexing='ij')
        grid = jnp.stack(grid, axis=-1).reshape((-1, d))
        #Set sigma
        if data['boundary'] is not None:
            #The test functions actually used for training are drawn here, so honour `periodic`
            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])

            def u_init(z):
                return forward(z,nnet['params'])

            if neumann:
                loss_boundary = oper_neumann(u_init,data['boundary'])
            else:
                loss_boundary = jnp.mean(MSE(forward(data['boundary'],nnet['params']),data['uboundary']))
            #Inverse problems need the PDE parameters to evaluate the operator
            if inverse:
                output_w = pde(u_init,grid,initial_par)
            else:
                output_w = pde(u_init,grid)
            integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(tf)
            loss_res_weak = jnp.mean(integralOmega ** 2)
            #The weak loss is quadratic in the test functions, so scaling them by this factor makes
            #the initial weak loss equal to the initial boundary loss
            sigma = float(jnp.sqrt(loss_boundary/loss_res_weak).tolist())
            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
            tf = sigma*tf
        else:
            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
        if float64 and tf is not None:
            tf = to_float64(tf)
            grid = to_float64(grid)
    else:
        tf = None
        grid = None

    #Define loss function: returns each (unweighted, per-point) loss term
    @jax.jit
    def lf_each(params,x,k,tf,grid):
        def u_net(z):
            return forward(z,params['net'])

        loss_sensor = loss_boundary = loss_initial = loss_res = loss_res_weak = 0
        if x['sensor'] is not None:
            #Term that refers to sensor data
            loss_sensor = MSE(u_net(x['sensor']),x['usensor'])
        if x['boundary'] is not None:
            if neumann:
                #Neumann conditions, evaluated on the parameters being trained
                loss_boundary = oper_neumann(u_net,x['boundary'])
            else:
                #Term that refers to boundary data
                loss_boundary = MSE(u_net(x['boundary']),x['uboundary'])
        if x['initial'] is not None:
            #Term that refers to initial data
            loss_initial = MSE(u_net(x['initial']),x['uinitial'])
        if x['collocation'] is not None and strong_form:
            #Strong-form residual on collocation points
            if inverse:
                output = pde(u_net,x['collocation'],params['inverse'])
            else:
                output = pde(u_net,x['collocation'])
            loss_res = MSE(output,0)
        if weak_form:
            #Term that refers to weak loss
            if resample:
                #k is indexable (a split key row, or a one-element list); its first entry seeds the draw
                test_functions = gen(jax.random.split(jax.random.PRNGKey(k[0]),(bsize,))[:,0])
                if float64:
                    test_functions = to_float64(test_functions)
            else:
                test_functions = tf
            if inverse:
                output_w = pde(u_net,grid,params['inverse'])
            else:
                output_w = pde(u_net,grid)
            integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
            loss_res_weak = jnp.mean(integralOmega ** 2)
        return {'ls': loss_sensor,'lb': loss_boundary,'li': loss_initial,'lc': loss_res,'lc_weak': loss_res_weak}

    #Weighted total loss. Adam: scalar loss with trainable weights in params['w'].
    #L-BFGS: (loss, aux) with the fixed weights `w` from this scope
    @jax.jit
    def lf(params,x,k,tf,grid):
        l = lf_each(params,x,k,tf,grid)
        ww = params['w'] if opt != 'LBFGS' else w
        loss = (jnp.mean((ww['ws'] ** q)*l['ls']) + jnp.mean((ww['wb'] ** q)*l['lb']) + jnp.mean((ww['wi'] ** q)*l['li'])
                + jnp.mean((ww['wc'] ** q)*l['lc']) + (ww['wc_weak'] ** q)*l['lc_weak'])
        if opt != 'LBFGS':
            return loss
        l2 = None
        if test_data is not None:
            l2 = L2error(forward(test_data['sensor'],params['net']),test_data['usensor'])
        return loss,{'loss': loss,'l2': l2}

    #Initialize self-adaptive weights
    typ = jnp.float64 if float64 else jnp.float32
    if w is None:
        w = {'ws': jnp.array(1.0),'wb': jnp.array(1.0),'wi': jnp.array(1.0),'wc': jnp.array(1.0),'wc_weak': jnp.array(1.0)}
    else:
        #Work on a copy so the caller's dictionary is not modified
        w = dict(w)
    if q != 0 and opt != 'LBFGS':
        #Per-point weights: small Gaussian perturbation around the initial value
        for wname,dname,offset in (('ws','sensor',1),('wb','boundary',2),('wi','initial',3),('wc','collocation',4)):
            if data[dname] is not None:
                w[wname] = w[wname] + 0.05*jax.random.normal(jax.random.PRNGKey(key+offset),(data[dname].shape[0],1),dtype = typ)

    #Store all parameters
    if opt != 'LBFGS':
        params = {'net': nnet['params'],'inverse': initial_par,'w': w}
    else:
        params = {'net': nnet['params'],'inverse': initial_par}
    if float64:
        params = to_float64(params)

    #Save config file
    if save:
        with open(file_name + '_config.pickle','wb') as fh:
            pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse},fh, protocol = pickle.HIGHEST_PROTOCOL)

    #Initialize optimizer
    if opt != 'LBFGS':
        print('--------- GRADIENT DESCENT OPTIMIZER ---------')
        if exp_decay:
            lr = optax.exponential_decay(lr,transition_steps,decay_rate)
        optimizer = optax.adam(lr,b1,b2,eps,eps_root)
        opt_state = optimizer.init(params)

        #Define the gradient function
        grad_loss = jax.jit(jax.grad(lf,0))

        #Define update function
        @jax.jit
        def update(opt_state,params,x,k,tf,grid):
            #Compute gradient
            grads = grad_loss(params,x,k,tf,grid)
            #Calculate parameters updates
            updates, opt_state = optimizer.update(grads, opt_state)
            if q != 0:
                #Gradient ascent on the self-adaptive weights (min-max problem)
                updates = {**updates, 'w': jax.tree_util.tree_map(lambda v: -v, updates['w'])}
            params = optax.apply_updates(params, updates)
            #Return state of optimizer and updated parameters
            return opt_state,params
    else:
        print('--------- LBFGS OPTIMIZER ---------')
        @jax.jit
        def loss_LBFGS(params):
            #A one-element list keeps `k[0]` valid in lf_each; the fixed key keeps the objective deterministic
            return lf(params,data,[key + 234],tf,grid)
        solver = LBFGS(fun = loss_LBFGS,has_aux = True,maxiter = epochs,tol = 1e-9,verbose = False,linesearch = 'zoom',history_size = 200)
        state = solver.init_state(params)
        if float64:
            assert_lbfgs_state_float64(state)

    ###Training###
    t0 = time.time()
    k = jax.random.split(jax.random.PRNGKey(key+234),(epochs,))
    sloss = []
    sL2 = []
    stime = []
    #Initialize alive_bar for tracing in terminal
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            #Only the first epochs are checked, to keep the float64 diagnostics cheap
            check64 = float64 and e < 10
            if opt != 'LBFGS':
                if check64:
                    assert_tree_float64(params, name="params (before update)")
                    check_grads_float64(grad_loss(params, data, k[e,:], tf, grid))
                #Update optimizer state and parameters
                opt_state,params = update(opt_state,params,data,k[e,:],tf,grid)
                sloss.append(lf(params,data,k[e,:],tf,grid))
                if test_data is not None:
                    sL2.append(L2error(forward(test_data['sensor'],params['net']),test_data['usensor']))
                if check64:
                    assert_tree_float64(params, name="params (after update)")
                    check_grads_float64(grad_loss(params, data, k[e,:], tf, grid))
            else:
                if check64:
                    assert_tree_float64(params, name="params (before LBFGS step)")
                params, state = solver.update(params, state)
                if check64:
                    assert_tree_float64(params, name="params (after LBFGS step)")
                    assert_lbfgs_state_float64(state)
                sL2.append(state.aux["l2"])
                sloss.append(state.aux["loss"])
                if restart is not None and (e + 1) % restart == 0:
                    #Drop the curvature history and restart L-BFGS from the current point
                    state = solver.init_state(params)
            stime.append(time.time() - t0)
            #After epoch_print epochs
            if e % epoch_print == 0:
                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(sloss[-1],6))
                if test_data is not None:
                    l = l + ' L2 error: ' + str(jnp.round(sL2[-1],6))
                if inverse:
                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
                print(l)
            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
                #Save current parameters
                with open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb') as fh:
                    pickle.dump({'params': params,'width': width,'time': stime,'loss': sloss,'L2error': sL2},fh, protocol = pickle.HIGHEST_PROTOCOL)
            #Update alive_bar
            bar()

    #Define estimated function
    def u(xt):
        return forward(xt,params['net'])

    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0,'loss_each': lf_each(params,data,[key + 100],tf,grid),'loss': sloss,'L2error': sL2}
Train a Physics-informed Neural Network
Parameters
- data (dict): Data generated by the jinnax.data.generate_PINNdata function
- width (list): A list with the width of each layer
- pde (function): The partial differential operator. Its arguments are u, x and t
- test_data (dict, None): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
- params (list): Initial parameters for the neural network. Default None to initialize randomly
- d (int): Dimension of the problem including the time variable if present. Default 2
- N (int): Size of grid in each dimension. Default 128
- L (list of float): The domain of coordinate i is [0, L[i]]. If a float, the same interval is used for all coordinates. Default 1
- kappa,alpha,sigma (float): Parameters of the Matern process
- bsize (int): Batch size for weak norm computation. Default 1024
- resample (logical): Whether to resample the test functions at each epoch
- epochs (int): Number of training epochs. Default 100
- at_each (int): Save results for epochs multiple of at_each. Default 10
- activation (str): The name of the activation function of the neural network. Default 'tanh'
- neumann (logical): Whether to consider Neumann boundary conditions
- oper_neumann (function): Penalization of Neumann boundary conditions
- inverse (logical): Whether to estimate parameters of the PDE
- initial_par (jax.numpy.array): Initial value of the parameters of the PDE in an inverse problem
- lr,b1,b2,eps,eps_root (float): Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
- key (int): Seed for parameters initialization. Default 0
- epoch_print (int): Number of epochs to calculate and print test errors. Default 1
- save (logical): Whether to save the current parameters. Default False
- file_name (str): File prefix to save the current parameters. Default 'result_pinn'
- exp_decay (logical): Whether to consider exponential decay of learning rate. Default True
- transition_steps (int): Number of steps for exponential decay. Default 100
- decay_rate (float): Rate of exponential decay. Default 0.9
- mlp (logical): Whether to use the modified multilayer perceptron
- ftype (str): Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.
- fargs (list): Arguments for the feature transformation:
For 'ff': A list with the number of frequencies and the value of the largest frequency standard deviation.
For 'daff' and 'daff_bias': A dictionary with a list with the size of rectangles and the type of boundary condition. If it is a list, then the boundary condition is Dirichlet.
- q (int): Power of weights mask. Default 4
- w (dict): Initial weights for the self-adaptive scheme.
- periodic (logical): Whether to consider periodic test functions. Default False.
- static (function): A static function to sum to the neural network output.
- opt (str): Optimizer. Default LBFGS.
- float64 (logical): Whether to train with float64
- restart (int): Epochs to restart L-BFGS
Returns
- dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch
def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1):
    """
    Process the results of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    fit : function

        The fitted function

    train_data : dict

        Training data generated by the jinnax.data.generate_PINNdata

    plot : logical

        Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True

    plot_test : logical

        Whether to plot the test data. Default True

    times : int

        Number of points along the time interval to plot. Rounded to the nearest positive multiple of 5. Default 5

    d2 : logical

        Whether to plot 2D plot when the spatial dimension is one. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    print_res : logical

        Whether to print the L2 error. Default True

    p : int

        Output dimension. Default 1

    Returns
    -------
    pandas data frame with L2 and MSE errors on test and training data
    """

    #Spatial dimension (last column of xt is time)
    d = test_data['xt'].shape[1] - 1

    #Number of plots: a positive multiple of 5. Values below 2.5 used to round to 0 rows,
    #which made plt.subplots fail
    times = max(5, 5 * round(times/5))

    #Data
    td = get_train_data(train_data)
    xt_train = td['x']
    u_train = td['y']
    upred_train = fit(xt_train)
    upred_test = fit(test_data['xt'])

    #Results
    l2_error_test = L2error(upred_test,test_data['u']).tolist()
    MSE_test = jnp.mean(MSE(upred_test,test_data['u'])).tolist()
    l2_error_train = L2error(upred_train,u_train).tolist()
    MSE_train = jnp.mean(MSE(upred_train,u_train)).tolist()

    df = pd.DataFrame(np.array([l2_error_test,MSE_test,l2_error_train,MSE_train]).reshape((1,4)),
                      columns=['l2_error_test','MSE_test','l2_error_train','MSE_train'])
    if print_res:
        print('L2 error test: ' + str(jnp.round(l2_error_test,6)) + ' L2 error train: ' + str(jnp.round(l2_error_train,6)) + ' MSE error test: ' + str(jnp.round(MSE_test,6)) + ' MSE error train: ' + str(jnp.round(MSE_train,6)) )

    #Plots
    if d == 1 and p == 1 and plot:
        plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name)
    elif p == 2 and plot:
        plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test)

    return df
Process the results of a Physics-informed Neural Network
Parameters
- test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
- fit (function): The fitted function
- train_data (dict): Training data generated by the jinnax.data.generate_PINNdata
- plot (logical): Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True
- plot_test (logical): Whether to plot the test data. Default True
- times (int): Number of points along the time interval to plot. Default 5
- d2 (logical): Whether to plot 2D plot when the spatial dimension is one. Default True
- save (logical): Whether to save the plots. Default False
- show (logical): Whether to show the plots. Default True
- file_name (str): File prefix to save the plots. Default 'result_pinn'
- print_res (logical): Whether to print the L2 error. Default True
- p (int): Output dimension. Default 1
Returns
- pandas data frame with L2 and MSE errors
def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
    """
    Plot the prediction of a 1D PINN
    ----------

    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Panels are laid out 5 per row
        (at least one row); at most 5 * max(1, int(times/5)) panels are drawn

    xt : jax.numpy.array

        Test data xt array (last column is time)

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    d2 : logical

        Whether to plot 2D plot. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title_1d : str

        Title of 1D plot

    title_2d : str

        Title of 2D plot

    Returns
    -------
    None
    """
    #Initialize. At least one row: times < 5 used to request 0 rows and crash plt.subplots
    n_rows = max(1, int(times/5))
    fig, ax = plt.subplots(n_rows,5,figsize = (10*n_rows,3*n_rows))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    ylo = jnp.min(u)
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u)
    yup = yup + 0.1*jnp.abs(yup)
    k = 0
    t_values = np.linspace(tlo,tup,times)

    #Create one panel per requested time
    for i in range(n_rows):
        for j in range(5):
            if k < len(t_values):
                t = t_values[k]
                #Snap to the closest time actually present in xt
                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
                mask = xt[:,-1] == t
                x_plot = xt[mask,:-1]
                y_plot = upred[mask,:]
                u_plot = u[mask,:]
                #plt.subplots returns a 1D axes array when there is a single row
                axis = ax[i,j] if n_rows > 1 else ax[j]
                axis.plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
                axis.plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
                axis.set_title('$t = %.2f$' % (t),fontsize=10)
                axis.set_xlabel(' ')
                axis.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                k = k + 1

    #Title
    fig.suptitle(title_1d)
    fig.tight_layout()

    #Save before showing, so the file never depends on what the GUI backend does to the figure in show()
    if save:
        fig.savefig(file_name + '_slices.png')
    if show:
        plt.show()
    plt.close(fig)

    #2d plot
    if d2:
        #Initialize
        fig, ax = plt.subplots(1,2)
        l1 = jnp.unique(xt[:,-1]).shape[0]
        l2 = jnp.unique(xt[:,0]).shape[0]

        #Create (assumes xt is a full tensor grid of l2 space points by l1 times)
        ax[0].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[0].set_title('Exact')
        ax[1].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[1].set_title('Predicted')

        #Title
        fig.suptitle(title_2d)
        fig.tight_layout()

        #Save before showing
        if save:
            fig.savefig(file_name + '_2d.png')
        if show:
            plt.show()
        plt.close(fig)
Plot the prediction of a 1D PINN
Parameters
- times (int): Number of points along the time interval to plot. Default 5
- xt (jax.numpy.array): Test data xt array
- u (jax.numpy.array): Test data u(x,t) array
- upred (jax.numpy.array): Predicted upred(x,t) array on test data
- d2 (logical): Whether to plot 2D plot. Default True
- save (logical): Whether to save the plots. Default False
- show (logical): Whether to show the plots. Default True
- file_name (str): File prefix to save the plots. Default 'result_pinn'
- title_1d (str): Title of 1D plot
- title_2d (str): Title of 2D plot
Returns
- None
def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
    """
    Plot the prediction of a PINN with 2D output as curves (u_1, u_2) at several times
    ----------
    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Panels are laid out 5 per row
        (at least one row); at most 5 * max(1, int(times/5)) panels are drawn

    xt : jax.numpy.array

        Test data xt array (last column is time)

    u : jax.numpy.array

        Test data u(x,t) array with two output columns

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data with two output columns

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title : str

        Title of plot

    plot_test : logical

        Whether to plot the test data. Default True

    Returns
    -------
    None
    """
    #Initialize. At least one row: times < 5 used to request 0 rows and crash plt.subplots
    n_rows = max(1, int(times/5))
    fig, ax = plt.subplots(n_rows,5,figsize = (10*n_rows,3*n_rows))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    xlo = jnp.min(u[:,0])
    xlo = xlo - 0.1*jnp.abs(xlo)
    xup = jnp.max(u[:,0])
    xup = xup + 0.1*jnp.abs(xup)
    ylo = jnp.min(u[:,1])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u[:,1])
    yup = yup + 0.1*jnp.abs(yup)
    k = 0
    t_values = np.linspace(tlo,tup,times)

    #Create one panel per requested time
    for i in range(n_rows):
        for j in range(5):
            if k < len(t_values):
                t = t_values[k]
                #Snap to the closest time actually present in xt
                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
                mask = xt[:,-1] == t
                xpred_plot = upred[mask,0]
                ypred_plot = upred[mask,1]
                #plt.subplots returns a 1D axes array when there is a single row
                axis = ax[i,j] if n_rows > 1 else ax[j]
                if plot_test:
                    axis.plot(u[mask,0],u[mask,1],'b-',linewidth=2,label='Exact')
                axis.plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
                axis.set_title('$t = %.2f$' % (t),fontsize=10)
                axis.set_xlabel(' ')
                #Both axes span the test-data range (xlo/xup were previously computed but never applied)
                axis.set_xlim([1.3 * xlo.tolist(),1.3 * xup.tolist()])
                axis.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                k = k + 1

    #Title
    fig.suptitle(title)
    fig.tight_layout()

    #Save before showing, so the file never depends on what the GUI backend does to the figure in show()
    if save:
        fig.savefig(file_name + '_slices.png')
    if show:
        plt.show()
    plt.close(fig)
Plot the prediction of a PINN with 2D output
Parameters
- times (int): Number of points along the time interval to plot. Default 5
- xt (jax.numpy.array): Test data xt array
- u (jax.numpy.array): Test data u(x,t) array
- upred (jax.numpy.array): Predicted upred(x,t) array on test data
- save (logical): Whether to save the plots. Default False
- show (logical): Whether to show the plots. Default True
- file_name (str): File prefix to save the plots. Default 'result_pinn'
- title (str): Title of plot
- plot_test (logical): Whether to plot the test data. Default True
Returns
- None
def get_train_data(train_data):
    """
    Process training sample
    ----------

    Parameters
    ----------
    train_data : dict

        A dictionary with train data generated by the jinnax.data.generate_PINNdata function

    Returns
    -------
    dict with the processed training data: the sensor, boundary and initial inputs stacked in that
    order ('x'), their targets ('y'), inputs and targets side by side ('xy'), and the number of
    points of each kind ('sensor_sample', 'boundary_sample', 'initial_sample', 'collocation_sample').
    'x', 'y' and 'xy' are None when none of the three data sets is present
    """
    stacked_x = stacked_y = stacked_xy = None
    n_points = {}
    #Sources with targets are stacked row-wise in a fixed order
    for source in ('sensor', 'boundary', 'initial'):
        inputs = train_data[source]
        if inputs is None:
            n_points[source] = 0
            continue
        n_points[source] = inputs.shape[0]
        targets = train_data['u' + source]
        joined = jnp.column_stack((inputs, targets))
        if stacked_x is None:
            #First available source is used as-is
            stacked_x, stacked_y, stacked_xy = inputs, targets, joined
        else:
            stacked_x = jnp.vstack((stacked_x, inputs))
            stacked_y = jnp.vstack((stacked_y, targets))
            stacked_xy = jnp.vstack((stacked_xy, joined))

    #Collocation points have no targets; only their count is reported
    colloc = train_data['collocation']
    n_colloc = 0 if colloc is None else colloc.shape[0]

    return {'xy': stacked_xy,
            'x': stacked_x,
            'y': stacked_y,
            'sensor_sample': n_points['sensor'],
            'boundary_sample': n_points['boundary'],
            'initial_sample': n_points['initial'],
            'collocation_sample': n_colloc}
Process training sample
Parameters
- train_data (dict): A dictionary with train data generated by the jinnax.data.generate_PINNdata function
Returns
- dict with the processed training data
def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
    """
    Process the training of a Physics-informed Neural Network
    ----------

    Reads the configuration saved in ``file_name + '_config.pickle'`` and the
    per-epoch snapshots ``file_name + '_epoch' + <epoch zero-padded to 6 digits> + '.pickle'``.
    At the selected epochs it evaluates the train/test MSE and L2 errors and,
    optionally, the bolstering errors.

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function.
        Keys read: 'xt' (points) and 'u' (values)

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs that are multiples of at_each; the last epoch is always included. Default 100

    bolstering : logical

        Whether to compute bolstering mean square error. Default True

    mc_sample : int

        Number of samples for Monte Carlo integration in bolstering. Default 10000

    save : logical

        Whether to save the training results. Default False

    file_name_save : str

        File name prefix of the CSV with the training results (written only when save is True). Default 'result_pinn'

    key : int

        Key for random samples in bolstering. Default 0

    ec : float

        Error tolerance used as the stopping criterion of the EM algorithm in bolstering. Default 1e-6

    lamb : float

        Hyperparameter of EM algorithm in bolstering. Default 1

    Returns
    -------
    pandas data frame with training results

    Raises
    ------
    ValueError
        If bolstering is requested but there is no sensor, boundary or initial training data
    """
    #Config (read through a context manager so the file handle is closed)
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xydata = td['xy']
    xdata = td['x']
    ydata = td['y']
    sensor_sample = td['sensor_sample']
    boundary_sample = td['boundary_sample']
    initial_sample = td['initial_sample']
    collocation_sample = td['collocation_sample']

    #Bolstering is computed on the supervised training points, so they must exist
    if bolstering and xdata is None:
        raise ValueError('bolstering requires sensor, boundary or initial training data')

    #Generate keys
    if bolstering:
        keys = jax.random.split(jax.random.PRNGKey(key),epochs)

    #Results, one entry per processed epoch
    train_mse = []
    test_mse = []
    train_L2 = []
    test_L2 = []
    bolstX = []
    bolstXY = []
    loss = []
    training_time = [] #not called `time` to avoid shadowing the time module
    ep = []

    #Process training
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
                ep.append(e)

                #Read parameters (file closed right after loading)
                with open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb') as file:
                    params = pickle.load(file)

                #Time
                training_time.append(params['time'])

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Train MSE and L2 (prediction computed once and reused)
                if xdata is not None:
                    pred_train = psi(xdata)
                    train_mse.append(jnp.mean(MSE(pred_train,ydata)).tolist())
                    train_L2.append(L2error(pred_train,ydata).tolist())
                else:
                    train_mse.append(None)
                    train_L2.append(None)

                #Test MSE and L2 (prediction computed once and reused)
                pred_test = psi(test_data['xt'])
                test_mse.append(jnp.mean(MSE(pred_test,test_data['u'])).tolist())
                test_L2.append(L2error(pred_test,test_data['u']).tolist())

                #Bolstering
                if bolstering:
                    #NOTE(review): keys[e,0] is a single word of the e-th split key rather
                    #than a full PRNG key; kept unchanged, confirm what gk/gb expect
                    key_e = keys[e,0]
                    bX = []
                    bXY = []
                    for method in ['chi','mm','mpe']:
                        kxy = gk.kernel_estimator(data = xydata,key = key_e,method = method,lamb = lamb,ec = ec,psi = psi)
                        kx = gk.kernel_estimator(data = xdata,key = key_e,method = method,lamb = lamb,ec = ec,psi = psi)
                        bX.append(gb.bolstering(psi,xdata,ydata,kx,key = key_e,mc_sample = mc_sample).tolist())
                        bXY.append(gb.bolstering(psi,xdata,ydata,kxy,key = key_e,mc_sample = mc_sample).tolist())
                    n = xdata.shape[0]
                    for bias in [1/jnp.sqrt(n),1/n,1/(n ** 2),1/(n ** 3),1/(n ** 4)]:
                        #NOTE(review): the Hessian kernel is estimated on xydata but used for the
                        #X-only bolstering column; kept unchanged, confirm this is intended
                        kx = gk.kernel_estimator(data = xydata,key = key_e,method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
                        bX.append(gb.bolstering(psi,xdata,ydata,kx,key = key_e,mc_sample = mc_sample).tolist())
                    bolstX.append(bX)
                    bolstXY.append(bXY)
                else:
                    bolstX.append(None)
                    bolstXY.append(None)

                #Loss
                loss.append(params['loss'].tolist())

                #Release this snapshot before loading the next one
                del params, psi, pred_test
            #Update alive_bar
            bar()

    #Create data frame (np.column_stack gives all columns a common dtype, as before)
    n_rows = len(ep)
    columns = [ep,training_time,[sensor_sample] * n_rows,[boundary_sample] * n_rows,[initial_sample] * n_rows,[collocation_sample] * n_rows,loss,
               train_mse,test_mse,train_L2,test_L2]
    names = ['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2']
    if bolstering:
        bolstX = jnp.array(bolstX)
        bolstXY = jnp.array(bolstXY)
        #Kernel-estimator columns: X and XY bolstering side by side for each method
        for j, method in enumerate(['chi','mm','mpe']):
            columns += [bolstX[:,j],bolstXY[:,j]]
            names += ['bolstX_' + method,'bolstXY_' + method]
        #Hessian-kernel columns, one per bias level (X bolstering only)
        for j, suffix in enumerate(['sqrtn','n','n2','n3','n4'],start = 3):
            columns.append(bolstX[:,j])
            names.append('bolstHessian_' + suffix)
    df = pd.DataFrame(np.column_stack(columns),columns = names)
    if save:
        df.to_csv(file_name_save + '.csv',index = False)

    return df
Process the training of a Physics-informed Neural Network
Parameters
- test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
- file_name (str): Name of the files saved during training
- at_each (int): Compute results for epochs that are multiples of at_each; the last epoch is always included. Default 100
- bolstering (logical): Whether to compute bolstering mean square error. Default True
- mc_sample (int): Number of samples for Monte Carlo integration in bolstering. Default 10000
- save (logical): Whether to save the training results. Default False
- file_name_save (str): File name prefix of the CSV file with the training results (written only when save is True). Default 'result_pinn'
- key (int): Key for random samples in bolstering. Default 0
- ec (float): Error tolerance used as the stopping criterion of the EM algorithm in bolstering. Default 1e-6
- lamb (float): Hyperparameter of EM algorithm in bolstering. Default 1
Returns
- pandas data frame with training results
def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
    """
    Demo video with the training of a 1D PINN
    ----------

    For each selected epoch, a plot of the test solution and the PINN prediction
    is saved through plot_pinn1D in the folder ``file_name_save``, using the file
    prefix ``file_name_save/k`` with k = 1, 2, .... The frames
    ``k_slices.png`` (and ``k_2d.png`` when d2 is True) are then encoded with
    ffmpeg into videos inside the same folder. If ffmpeg is not installed, a
    warning is printed and no video is made.

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function.
        Keys read: 'xt' (points) and 'u' (values)

    file_name : str

        Name of the files saved during training

    at_each : int

        Plot epochs that are multiples of at_each; the last epoch is always included. Default 100

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to make video demo of 2D plot. Default True

    file_name_save : str

        Folder (created if needed) where the plots and videos are saved. Default 'result_pinn_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    import subprocess

    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xt = td['x']
    u = td['y']

    #Create folder to save plots (no shell; works for nested paths and existing folders)
    os.makedirs(file_name_save,exist_ok = True)

    #Create images
    k = 1
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if e % at_each == 0 or e == epochs - 1:
                #Read parameters
                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Compute L2 train and L2 test (test prediction reused for the plot)
                pred_test = psi(test_data['xt'])
                L2_train = L2error(psi(xt),u)
                L2_test = L2error(pred_test,test_data['u'])
                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(float(L2_train),6)) + ' L2 test = ' + str(round(float(L2_test),6))

                #Save plot
                plot_pinn1D(times,test_data['xt'],test_data['u'],pred_test,d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
                k = k + 1

                #Release this snapshot before loading the next one
                del params, psi, pred_test, L2_train, L2_test, title_epoch
            #Update alive_bar
            bar()

    #Create demo video(s). The command is an argument list (no shell), so paths with
    #spaces are safe. The video is named after the last path component of
    #file_name_save, so nested folders do not produce a nonexistent output path.
    base_name = os.path.basename(os.path.normpath(file_name_save))

    def encode(suffix):
        command = ['ffmpeg','-framerate',str(framerate),
                   '-i',os.path.join(file_name_save,'%d' + suffix + '.png'),
                   '-c:v','libx264','-profile:v','high','-crf','20','-pix_fmt','yuv420p',
                   os.path.join(file_name_save,base_name + suffix + '.mp4')]
        try:
            subprocess.run(command,check = False)
        except FileNotFoundError:
            #Best effort, as before: a missing ffmpeg does not abort the demo
            print('[WARNING] ffmpeg not found; video ' + base_name + suffix + '.mp4 was not created')

    encode('_slices')
    if d2:
        encode('_2d')
Demo video with the training of a 1D PINN
Parameters
- test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
- file_name (str): Name of the files saved during training
- at_each (int): Plot epochs that are multiples of at_each; the last epoch is always included. Default 100
- times (int): Number of points along the time interval to plot. Default 5
- d2 (logical): Whether to make video demo of 2D plot. Default True
- file_name_save (str): File prefix to save the plots and videos. Default 'result_pinn_demo'
- title (str): Title for plots
- framerate (int): Framerate for video. Default 10
Returns
- None
def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
    """
    Demo video with the time evolution of a 1D PINN
    ----------

    For each distinct time value in ``test_data['xt'][:,-1]``, one figure is saved
    as ``file_name_save/k.png`` (k = 1, 2, ...). The figure has one panel per
    requested epoch, laid out two per row, each comparing the exact solution
    with the PINN prediction. The frames are then encoded with ffmpeg into a
    video inside the same folder. If ffmpeg is not installed, a warning is
    printed and no video is made.

    Parameters
    ----------
    test_data : dict

        A dictionary with the test data to plot, generated by the jinnax.data.generate_PINNdata function.
        Keys read: 'xt' (points, time in the last column) and 'u' (values)

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot (at least one)

    file_name_save : str

        Folder (created if needed) where the plots and video are saved. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If epochs is empty
    """
    import subprocess

    n_panels = len(epochs)
    if n_panels == 0:
        raise ValueError('epochs must contain at least one epoch')

    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    forward = config['forward']

    #Create folder to save plots (no shell; works for nested paths and existing folders)
    os.makedirs(file_name_save,exist_ok = True)

    #Plot parameters
    xt = test_data['xt']
    u = test_data['u']
    tdom = jnp.unique(xt[:,-1])
    u_min = float(jnp.min(u))
    u_max = float(jnp.max(u))
    #Pad each limit outward by 43% of its magnitude. When the data straddles zero this
    #equals the previous 1.3 * (1.1 * bound) scaling, but unlike that scaling it never
    #moves a limit inside the data (it used to clip curves with positive minimum or
    #negative maximum).
    ylim = [u_min - 0.43 * abs(u_min),u_max + 0.43 * abs(u_max)]

    #Predict on the test points with the PINN of each requested epoch
    upred = []
    for e in epochs:
        snapshot = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        upred.append(forward(xt,snapshot['params']['net']))
        del snapshot

    #Grid with two panels per row (one panel when a single epoch is plotted);
    #ceil keeps the last epoch when the count is odd
    n_cols = 2 if n_panels > 1 else 1
    n_rows = math.ceil(n_panels / n_cols)

    #Frames use only the last path component, so nested folders work
    base_name = os.path.basename(os.path.normpath(file_name_save))
    title_prefix = title + ' ' if title else ''

    #Create images
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data at this time step
            mask = xt[:,-1] == t
            xt_step = xt[mask]
            u_step = u[mask]

            fig, axes = plt.subplots(n_rows,n_cols,figsize = (10,5 * n_rows),squeeze = False)
            axes = axes.ravel()
            for panel, epoch, prediction in zip(axes,epochs,upred):
                upred_step = prediction[mask]
                panel.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                panel.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                panel.set_title('Epoch = ' + str(epoch),fontsize=10)
                panel.set_xlabel(' ')
                panel.set_ylim(ylim)
            #Hide the unused slot of an odd-sized grid
            for panel in axes[n_panels:]:
                panel.set_visible(False)

            #Title
            fig.suptitle(title_prefix + 't = ' + str(round(float(t),4)))
            fig.tight_layout()

            #Save
            fig.savefig(os.path.join(file_name_save,str(k) + '.png'))
            k = k + 1
            plt.close(fig)
            bar()

    #Create demo video (argument list, no shell)
    command = ['ffmpeg','-framerate',str(framerate),
               '-i',os.path.join(file_name_save,'%d.png'),
               '-c:v','libx264','-profile:v','high','-crf','20','-pix_fmt','yuv420p',
               os.path.join(file_name_save,base_name + '_time_demo.mp4')]
    try:
        subprocess.run(command,check = False)
    except FileNotFoundError:
        #Best effort, as before: a missing ffmpeg does not abort the demo
        print('[WARNING] ffmpeg not found; video ' + base_name + '_time_demo.mp4 was not created')
Demo video with the time evolution of a 1D PINN
Parameters
- test_data (dict): A dictionary with the test data ('xt' points and 'u' values) to plot, generated by the jinnax.data.generate_PINNdata function
- file_name (str): Name of the files saved during training
- epochs (list): Which training epochs to plot
- file_name_save (str): File prefix to save the plots and video. Default 'result_pinn_time_demo'
- title (str): Title for plots
- framerate (int): Framerate for video. Default 10
Returns
- None