jinnax.nn

   1#Functions to train NN
   2import jax
   3import jax.numpy as jnp
   4import optax
   5from alive_progress import alive_bar
   6import math
   7import time
   8import pandas as pd
   9import numpy as np
  10import matplotlib.pyplot as plt
  11import dill as pickle
  12from genree import bolstering as gb
  13from genree import kernel as gk
  14from jax import random
  15from jinnax import data as jd
  16import os
  17import numpy as np
  18from scipy.fft import dst, idst
  19from itertools import product
  20from functools import partial
  21import orthax
  22from jax import lax
  23from jaxopt import LBFGS
  24
  25__docformat__ = "numpy"
  26
  27#MSE
  28@jax.jit
  29def MSE(pred,true):
  30    """
  31    Squared error
  32    ----------
  33    Parameters
  34    ----------
  35    pred : jax.numpy.array
  36
  37        A JAX numpy array with the predicted values
  38
  39    true : jax.numpy.array
  40
  41        A JAX numpy array with the true values
  42
  43    Returns
  44    -------
  45    squared error
  46    """
  47    return (true - pred) ** 2
  48
  49#MSE self-adaptative
  50@jax.jit
  51def MSE_SA(pred,true,w,q = 2):
  52    """
  53    Self-adaptative squared error
  54    ----------
  55    Parameters
  56    ----------
  57    pred : jax.numpy.array
  58
  59        A JAX numpy array with the predicted values
  60
  61    true : jax.numpy.array
  62
  63        A JAX numpy array with the true values
  64
  65    weight : jax.numpy.array
  66
  67        A JAX numpy array with the weights
  68
  69    q : float
  70
  71        Power for the weights mask
  72
  73    Returns
  74    -------
  75    self-adaptative squared error with polynomial mask
  76    """
  77    return (w ** q) * ((true - pred) ** 2)
  78
  79#L2 error
  80@jax.jit
  81def L2error(pred,true):
  82    """
  83    L2-error in percentage (%)
  84    ----------
  85    Parameters
  86    ----------
  87    pred : jax.numpy.array
  88
  89        A JAX numpy array with the predicted values
  90
  91    true : jax.numpy.array
  92
  93        A JAX numpy array with the true values
  94
  95    Returns
  96    -------
  97    L2-error
  98    """
  99    return 100*jnp.sqrt(jnp.sum((true - pred)**2))/jnp.sqrt(jnp.sum(true ** 2))
 100
#Auxiliary functions to sample singular Matern
 102def idst1(x,axis = -1):
 103    """
 104    Inverse Discrete Sine Transform of type I with orthonormal scaling
 105    ----------
 106    Parameters
 107    ----------
 108    x : jax.numpy.array
 109
 110        Array to apply the transformation
 111
 112    axis : int
 113
 114        Axis to apply the transformation over
 115
 116    Returns
 117    -------
 118    jax.numpy.array
 119    """
 120    return idst(x,type = 1,axis = axis,norm = 'ortho')
 121
 122def dstn(x,axes = None):
 123    """
 124    Discrete Sine Transform of type I with orthonormal scaling over many axes
 125    ----------
 126    Parameters
 127    ----------
 128    x : jax.numpy.array
 129
 130        Array to apply the transformation
 131
 132    axes : int
 133
 134        Axes to apply the transformation over
 135
 136    Returns
 137    -------
 138    jax.numpy.array
 139    """
 140    if axes is None:
 141        axes = tuple(range(x.ndim))
 142    y = x
 143    for ax in axes:
 144        y = dst(x,type = 1,axis = ax,norm = 'ortho')
 145    return y
 146
 147def idstn(x,axes = None):
 148    """
 149    Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes
 150    ----------
 151    Parameters
 152    ----------
 153    x : jax.numpy.array
 154
 155        Array to apply the transformation
 156
 157    axes : int
 158
 159        Axes to apply the transformation over
 160
 161    Returns
 162    -------
 163    jax.numpy.array
 164    """
 165    if axes is None:
 166        axes = tuple(range(x.ndim))
 167    y = x
 168    for ax in axes:
 169        y = idst1(y,axis = ax)
 170    return y
 171
 172def dirichlet_eigs_nd(n,L):
 173    """
 174    Eigenvalues of the discrete Dirichlet-Laplace operator in a rectangle
 175    ----------
 176    Parameters
 177    ----------
 178    n : list
 179
 180        List with the number of points in the grid in each dimension
 181
 182    L : list
 183
 184        List with the upper limit of the interval of the domain in each dimension. Assumed the lower limit is zero
 185
 186    Returns
 187    -------
 188    jax.numpy.array
 189    """
 190    #Unidimensional eigenvalues
 191    lam_axes = []
 192    for ni, Li in zip(n,L):
 193        h = Li / (ni + 1.0)
 194        k = jnp.arange(1,ni + 1,dtype = np.float32)
 195        ln = (2.0 / (h*h)) * (1.0 - jnp.cos(jnp.pi * k / (ni + 1.0)))
 196        lam_axes.append(ln)
 197    grids = jnp.meshgrid(*lam_axes, indexing='ij')
 198    Lam = jnp.zeros_like(grids[0])
 199    for g in grids:
 200        Lam += g
 201    return Lam
 202
 203
 204#Sample from d-dimensional Matern process
 205def generate_matern_sample(key,d = 2,N = 128,L = 1.0,kappa = 1,alpha = 1,sigma = 1,periodic = False):
 206    """
 207    Sample d-dimensional Matern process
 208    ----------
 209    Parameters
 210    ----------
 211    key : int
 212
 213        Seed for randomization
 214
 215    d : int
 216
 217        Dimension. Default 2
 218
 219    N : int
 220
 221        Size of grid in each dimension. Default 128
 222
 223    L : list of float
 224
 225        The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1
 226
 227    kappa,alpha,sigma : float
 228
 229        Parameters of the Matern process
 230
 231    periodic : logical
 232
 233        Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT
 234
 235    Returns
 236    -------
 237    jax.numpy.array
 238    """
 239    if periodic:
 240        #Shape and key
 241        key = jax.random.PRNGKey(key)
 242        shape = (N,) * d
 243        if isinstance(L,float) or isinstance(L,int):
 244            L = d*[L]
 245        if isinstance(N,float) or isinstance(N,int):
 246            N = d*[N]
 247
 248        #Setup Frequency Grid (2D)
 249        freq = [jnp.fft.fftfreq(N[j],d = L[j]/N[j]) * 2 * jnp.pi for j in range(d)]
 250        grids = jnp.meshgrid(*freq, indexing='ij')
 251        sq_norm_xi = sum(g**2 for g in grids)
 252
 253        #Generate White Noise in Fourier Space
 254        key_re, key_im = jax.random.split(key)
 255        white_noise_f = (jax.random.normal(key_re, shape) +
 256                         1j * jax.random.normal(key_im, shape))
 257
 258        #Apply the Whittle Filter
 259        amplitude_filter = (kappa ** 2 + sq_norm_xi) ** (-alpha / 2.0)
 260        field_f = white_noise_f * amplitude_filter
 261
 262        #Transform back to Physical Space
 263        sample = jnp.real(jnp.fft.ifftn(field_f))
 264        return sigma*sample
 265    else: #NOT JAX
 266        #Shape and key
 267        rng = np.random.default_rng(seed = key)
 268        if isinstance(L,float) or isinstance(L,int):
 269            L = d*[L]
 270        if isinstance(N,float) or isinstance(N,int):
 271            N = d*[N]
 272        shape = tuple(N)
 273
 274        #White noise in real space
 275        W = rng.standard_normal(size = shape)
 276
 277        #To Dirichlet eigenbasis via separable DST-I (orthonormal)
 278        W_hat = dstn(W)
 279
 280        #Discrete Dirichlet Laplacian eigenvalues
 281        lam = dirichlet_eigs_nd(N, L)
 282
 283        #Spectral filter
 284        filt = ((kappa + lam) ** (-alpha/2.0))
 285        psi_hat = filt * W_hat
 286
 287        #Back to real space
 288        psi = idstn(psi_hat)
 289        return jnp.array(sigma*psi)
 290
 291#Vectorized generate_matern_sample
 292def generate_matern_sample_batch(d = 2,N = 512,L = 1.0,kappa = 10.0,alpha = 1,sigma = 10,periodic = False):
 293    """
 294    Create function to sample d-dimensional Matern process
 295    ----------
 296    Parameters
 297    ----------
 298    d : int
 299
 300        Dimension. Default 2
 301
 302    N : int
 303
 304        Size of grid in each dimension. Default 128
 305
 306    L : list of float
 307
 308        The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1
 309
 310    kappa,alpha,sigma : float
 311
 312        Parameters of the Matern process
 313
 314    periodic : logical
 315
 316        Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT
 317
 318    Returns
 319    -------
 320    function
 321    """
 322    if periodic:
 323        return jax.vmap(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic))
 324    else:
 325        return lambda keys: jnp.array(np.apply_along_axis(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic),1,keys.reshape((keys.shape[0],1))))
 326
 327#Build function to compute the eigenfunctions of Laplacian
 328def eigenf_laplace(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None):
 329    """
 330    Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace.
 331    ----------
 332    Parameters
 333    ----------
 334    L_vec : list of float
 335
 336        The domain of the function in each coordinate is [0,L[1]]
 337
 338    kmax_per_axis : list
 339
 340        List with the maximum number of eigenfunctions per dimension. Consider d * max(kmax_per_axis) eigenfunctions
 341
 342    bc : str
 343
 344        Boundary condition. 'dirichlet' or 'neumann'
 345
 346    max_ef : int
 347
 348        Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions
 349
 350    Returns
 351    -------
 352    function to compute eigenfunctions,eigenvalues of the eigenfunctions considered
 353    """
 354    #Parameters
 355    L_vec = jnp.asarray(L_vec,dtype = jnp.float32)
 356    d = L_vec.shape[0]
 357    bc = bc.lower()
 358
 359    #Maximum number of functions
 360    if max_ef is None:
 361        if d == 1:
 362            max_ef = jnp.max(jnp.array(kmax_per_axis))
 363        else:
 364            max_ef = jnp.max(d * jnp.array(kmax_per_axis))
 365
 366    #Build the candidate multi-indices per axis
 367    kmax_per_axis = list(map(int, kmax_per_axis))
 368    if bc.startswith("d"):
 369        axis_ranges = [range(1, km + 1) for km in kmax_per_axis]
 370    elif bc.startswith("n"):
 371        axis_ranges = [range(0, km + 1) for km in kmax_per_axis]
 372
 373    #Get all multi-indices
 374    Ks_list = list(product(*axis_ranges))
 375    Ks = jnp.array(Ks_list,dtype = jnp.float32)
 376
 377    #Eigenvalues of the continuous Laplacian
 378    pi_over_L = jnp.pi / L_vec
 379    lambdas_all = jnp.sum((Ks * pi_over_L) ** 2, axis=1)
 380
 381    #Sort by eigenvalue
 382    order = jnp.argsort(lambdas_all)
 383    Ks = Ks[order]
 384    lambdas_all = lambdas_all[order]
 385
 386    #Keep first max_ef
 387    Ks = Ks[:max_ef]
 388    lambdas = lambdas_all[:max_ef]
 389    m = Ks.shape[0]
 390
 391    #Precompute per-feature normalization factor (closed form)
 392    def per_axis_norm_factor(k_i, L_i, is_dirichlet):
 393        if is_dirichlet:
 394            return jnp.sqrt(2.0 / L_i)
 395        else:
 396            return jnp.where(k_i == 0, jnp.sqrt(1.0 / L_i), jnp.sqrt(2.0 / L_i))
 397    if bc.startswith("d"):
 398        nf = jnp.prod(jnp.sqrt(2.0 / L_vec)[None, :],axis = 1)
 399        norm_factors = jnp.ones((m,),dtype = jnp.float32) * nf
 400    else:
 401        # per-mode product across axes
 402        def nf_row(k_row):
 403            return jnp.prod(per_axis_norm_factor(k_row, L_vec, False))
 404        norm_factors = jax.vmap(nf_row)(Ks)
 405
 406    #Build the callable function
 407    Ks_int = Ks  # float array, but only integer values
 408    L_vec_f = L_vec
 409    @jax.jit
 410    def phi(x):
 411        x = jnp.asarray(x,dtype = jnp.float32)
 412        #Initialize with ones
 413        vals = jnp.ones(x.shape[:-1] + (m,), dtype=jnp.float32)
 414        #Compute eigenfunction
 415        for i in range(d):
 416            ang = (jnp.pi / L_vec_f[i]) * x[..., i][..., None] * Ks_int[:, i]
 417            if bc.startswith("d"):
 418                comp = jnp.sin(ang)
 419            else:
 420                comp = jnp.cos(ang)
 421            vals = vals * comp
 422        #Apply L2-normalizing constants
 423        vals = vals * norm_factors[None, ...] if vals.ndim > 1 else vals * norm_factors
 424        return vals
 425    return phi, lambdas
 426
#Compute multiple frequencies of domain-aware Fourier features
 428def multiple_daff(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None):
 429    """
 430    Create function to compute multiple frequences of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. Each frequences is a different domain.
 431    ----------
 432    Parameters
 433    ----------
 434    L_vec : list of lists of float
 435
 436        List with the domain of each frequence of the eigenfunctions in the form [0,L[i][1]]
 437
 438    kmax_per_axis : list
 439
 440        List with the maximum number of eigenfunctions per dimension.
 441
 442    bc : str
 443
 444        Boundary condition. 'dirichlet' or 'neumann'
 445
 446    max_ef : int
 447
 448        Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions
 449
 450    Returns
 451    -------
 452    function to compute daff,eigenvalues of the eigenfunctions considered
 453    """
 454    psi = []
 455    lamb = []
 456    for L in L_vec:
 457        tmp,l = eigenf_laplace(L,kmax_per_axis,bc,max_ef) #Get function
 458        lamb.append(l)
 459        psi.append(tmp)
 460        del tmp
 461    #Create function to compute features
 462    @jax.jit
 463    def mff(x):
 464        y = []
 465        for i in range(len(psi)):
 466            y.append(psi[i](x))
 467        if len(psi) == 1:
 468            return y[0]
 469        else:
 470            return jnp.concatenate(y,1)
 471    return mff,jnp.concatenate(lamb)
 472
#Code for Chebyshev polynomials written by AI (deprecated)
 474def _chebyshev_T_all(t, K: int):
 475    """
 476    Compute T_0..T_K(t) with the standard recurrence.
 477    t shape should be (..., d). We DO NOT squeeze any axis to preserve 'd'
 478    even when d == 1.
 479    Returns: array of shape (K+1, ...) matching t's batch dims, including d.
 480    """
 481    # Expect t to have last axis = d (keep it, even if d == 1)
 482    T0 = jnp.ones_like(t)           # (..., d)
 483    if K == 0:
 484        return T0[None, ...]        # (1, ..., d)
 485
 486    T1 = t                          # (..., d)
 487    if K == 1:
 488        return jnp.stack([T0, T1], axis=0)  # (2, ..., d)
 489
 490    def body(carry, _):
 491        Tkm1, Tk = carry            # each (..., d)
 492        Tkp1 = 2.0 * t * Tk - Tkm1  # (..., d)
 493        return (Tk, Tkp1), Tkp1
 494
 495    # K >= 2: produce T_2..T_K
 496    (_, _), T2_to_TK = lax.scan(body, (T0, T1), jnp.arange(K - 1))  # (K-1, ..., d)
 497    return jnp.concatenate([T0[None, ...], T1[None, ...], T2_to_TK], axis=0)  # (K+1, ..., d)
 498
 499@partial(jax.jit,static_argnums=(2,))  # n is static here; compile once per n
 500def multiple_cheb_fast(x, L_vec, n: int):
 501    """
 502    x: (N, d)
 503    L_vec: (L, d) containing 'b' endpoints (a is 0) for each dimension
 504    n: number of k terms (static)
 505    returns: (N, L*n)
 506    """
 507    N, d = x.shape
 508    L = L_vec.shape[0]
 509
 510    a = 0.0
 511    b = L_vec                       # (L, d)
 512    # Map x to t in [-1, 1] for each l, j: shape (L, N, d)
 513    t = (2.0 * x[None, :, :] - (a + b)[:, None, :]) / (b - a)[:, None, :]
 514
 515    # Chebyshev T_0..T_{n+2} for all (L, N, d): shape (n+3, L, N, d)
 516    T = _chebyshev_T_all(t, n + 2)
 517
 518    # phi_k = T_{k+2} - T_k, k = 0..n-1  => shape (n, L, N, d)
 519    ks = jnp.arange(n)
 520    phi = T[ks + 2, ...] - T[ks, ...]
 521
 522    # Multiply across dimensions (over the last axis = d) => (n, L, N)
 523    z = jnp.prod(phi, axis=-1)
 524
 525    # Reorder to (N, L, n) then flatten to (N, L*n)
 526    z = jnp.transpose(z, (2, 1, 0)).reshape(N, L * n)
 527    return z
 528
 529def multiple_cheb(L_vec, n: int):
 530    """
 531    Factory that closes over static n and L_vec (so shapes are constant).
 532    """
 533    L_vec = jnp.asarray(L_vec)
 534    @jax.jit  # optional; multiple_cheb_fast is already jitted
 535    def mcheb(x):
 536        x = jnp.asarray(x)
 537        return multiple_cheb_fast(x, L_vec, n)
 538    return mcheb
 539
 540
#Initialize fully connected neural network. Return the initial parameters and the function for the forward pass
def fconNN(width,activation = jax.nn.tanh,key = 0,mlp = False,ftype = None,fargs = None,static = None):
    """
    Initialize fully connected neural network
    ----------
    Parameters
    ----------
    width : list

        List with the layers width. The first entry is the input dimension and the last is the output dimension

    activation : jax.nn activation

        The activation function. Default jax.nn.tanh

    key : int

        Seed for parameters initialization. Default 0

    mlp : logical

        Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension.

    ftype : str

        Type of feature transformation to use: None, 'ff', 'daff', 'daff_bias', 'cheb', 'cheb_bias'.

    fargs : list or dict

        Arguments for the feature transformation:

        For 'ff': A list with the number of frequencies and the value of the greatest frequency standard deviation.

        For 'daff' and 'daff_bias': A dictionary with a list with the size of rectangles and the type of boundary condition. If it is a list, then the boundary condition is dirichlet.

        For 'cheb' and 'cheb_bias': the list of domain endpoints passed to multiple_cheb.

    static : function

        A static function to sum to the neural network output.

    Returns
    -------
    dict with initial parameters and the function for the forward pass
    """
    #Initialize parameters with Glorot initialization
    initializer = jax.nn.initializers.glorot_normal()
    params = list()
    if static is None:
        #Default static term contributes nothing to the output
        static = lambda x: 0.0

    #Feature mapping: phi maps raw inputs to the features fed to the first layer
    if ftype == 'ff': #Fourrier features
        #Stack fargs[0] frequency bands; band s has standard deviation
        #fargs[1] ** ((s + 1)/fargs[0]), so bands grow geometrically up to fargs[1]
        for s in range(fargs[0]):
            sd = fargs[1] ** ((s + 1)/fargs[0])
            if s == 0:
                Bff = sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2)))
            else:
                Bff = jnp.append(Bff,sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2))),1)
        @jax.jit
        def phi(x):
            x = x @ Bff
            return jnp.concatenate([jnp.sin(2 * jnp.pi * x),jnp.cos(2 * jnp.pi * x)],axis = -1)
        #Drop the input layer and re-size the first hidden layer to the
        #number of generated features (a sin and a cos per column of Bff)
        width = width[1:]
        width[0] = 2*Bff.shape[1]
    elif ftype == 'daff' or ftype == 'daff_bias':
        if not isinstance(fargs, dict):
            #A bare list means dirichlet boundary conditions
            fargs = {'L': fargs,'bc': "dirichlet"}
        phi,lamb = multiple_daff(list(fargs.values())[0],kmax_per_axis = [width[1]] * width[0],bc = list(fargs.values())[1])
        #First hidden layer width becomes the number of eigenfunctions
        width = width[1:]
        width[0] = lamb.shape[0]
    elif ftype == 'cheb' or ftype == 'cheb_bias':
        phi = multiple_cheb(fargs,n = width[1])
        width = width[1:]
        #After the slice width[0] equals the original width[1] (= n), so the
        #feature count is len(fargs) * n, matching multiple_cheb's output
        width[0] = len(fargs)*width[0]
    else:
        #No feature transformation: identity mapping
        @jax.jit
        def phi(x):
            return x

    #Initialize parameters
    if mlp:
        #Extra encoder pair (U,V) used by the modified-MLP forward pass
        k = jax.random.split(jax.random.PRNGKey(key),4)
        WU = initializer(k[0],(width[0],width[1]),jnp.float32)
        BU = initializer(k[1],(1,width[1]),jnp.float32)
        WV = initializer(k[2],(width[0],width[1]),jnp.float32)
        BV = initializer(k[3],(1,width[1]),jnp.float32)
        params.append({'WU':WU,'BU':BU,'WV':WV,'BV':BV})
    key = jax.random.split(jax.random.PRNGKey(key + 1),len(width)-1) #Seed for initialization
    #NOTE(review): W and B of each layer are drawn with the same key, so
    #their entries come from the same stream — confirm this is intended
    for key,lin,lout in zip(key,width[:-1],width[1:]):
        W = initializer(key,(lin,lout),jnp.float32)
        B = initializer(key,(1,lout),jnp.float32)
        params.append({'W':W,'B':B})

    #Define function for forward pass. The bias-free variants apply to plain
    #'daff'/'cheb'; the '_bias' variants and all other ftypes keep biases.
    #NOTE(review): in the bias-free branches the B (and BU/BV) parameters are
    #still initialized above but never used — confirm this is intended
    if mlp:
        if ftype != 'daff' and ftype != 'cheb':
            #Modified MLP with biases
            @jax.jit
            def forward(x,params):
                encode,*hidden,output = params
                sx = static(x)
                x = phi(x)
                U = activation(x @ encode['WU'] + encode['BU'])
                V = activation(x @ encode['WV'] + encode['BV'])
                for layer in hidden:
                    x = activation(x @ layer['W'] + layer['B'])
                    #Gated mixing of the hidden state with the U/V encoders
                    x = x * U + (1 - x) * V
                return x @ output['W'] + output['B'] + sx
        else:
            #Modified MLP without biases
            @jax.jit
            def forward(x,params):
                encode,*hidden,output = params
                sx = static(x)
                x = phi(x)
                U = activation(x @ encode['WU'])
                V = activation(x @ encode['WV'])
                for layer in hidden:
                    x = activation(x @ layer['W'])
                    #Gated mixing of the hidden state with the U/V encoders
                    x = x * U + (1 - x) * V
                return x @ output['W'] + sx
    else:
        if ftype != 'daff' and ftype != 'cheb':
            #Plain MLP with biases
            @jax.jit
            def forward(x,params):
                *hidden,output = params
                sx = static(x)
                x = phi(x)
                for layer in hidden:
                    x = activation(x @ layer['W'] + layer['B'])
                return x @ output['W'] + output['B'] + sx
        else:
            #Plain MLP without biases
            @jax.jit
            def forward(x,params):
                *hidden,output = params
                sx = static(x)
                x = phi(x)
                for layer in hidden:
                    x = activation(x @ layer['W'])
                return x @ output['W'] + sx

    #Return initial parameters and forward function
    return {'params': params,'forward': forward}
 681
 682#Get activation from string
 683def get_activation(act):
 684    """
 685    Return activation function from string
 686    ----------
 687    Parameters
 688    ----------
 689    act : str
 690
 691        Name of the activation function. Default 'tanh'
 692
 693    Returns
 694    -------
 695    jax.nn activation function
 696    """
 697    if act == 'tanh':
 698        return jax.nn.tanh
 699    elif act == 'relu':
 700        return jax.nn.relu
 701    elif act == 'relu6':
 702        return jax.nn.relu6
 703    elif act == 'sigmoid':
 704        return jax.nn.sigmoid
 705    elif act == 'softplus':
 706        return jax.nn.softplus
 707    elif act == 'sparse_plus':
 708        return jx.nn.sparse_plus
 709    elif act == 'soft_sign':
 710        return jax.nn.soft_sign
 711    elif act == 'silu':
 712        return jax.nn.silu
 713    elif act == 'swish':
 714        return jax.nn.swish
 715    elif act == 'log_sigmoid':
 716        return jax.nn.log_sigmoid
 717    elif act == 'leaky_relu':
 718        return jax.nn.leaky_relu
 719    elif act == 'hard_sigmoid':
 720        return jax.nn.hard_sigmoid
 721    elif act == 'hard_silu':
 722        return jax.nn.hard_silu
 723    elif act == 'hard_swish':
 724        return jax.nn.hard_swish
 725    elif act == 'hard_tanh':
 726        return jax.nn.hard_tanh
 727    elif act == 'elu':
 728        return jax.nn.elu
 729    elif act == 'celu':
 730        return jax.nn.celu
 731    elif act == 'selu':
 732        return jax.nn.selu
 733    elif act == 'gelu':
 734        return jax.nn.gelu
 735    elif act == 'glu':
 736        return jax.nn.glu
 737    elif act == 'squareplus':
 738        return  jax.nn.squareplus
 739    elif act == 'mish':
 740        return jax.nn.mish
 741
 742#Training PINN
 743def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,mlp = False):
 744    """
 745    Train a Physics-informed Neural Network
 746    ----------
 747    Parameters
 748    ----------
 749    data : dict
 750
 751        Data generated by the jinnax.data.generate_PINNdata function
 752
 753    width : list
 754
 755        A list with the width of each layer
 756
 757    pde : function
 758
 759        The partial differential operator. Its arguments are u, x and t
 760
 761    test_data : dict, None
 762
 763        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
 764
 765    epochs : int
 766
 767        Number of training epochs. Default 100
 768
 769    at_each : int
 770
 771        Save results for epochs multiple of at_each. Default 10
 772
 773    activation : str
 774
 775        The name of the activation function of the neural network. Default 'tanh'
 776
 777    neumann : logical
 778
 779        Whether to consider Neumann boundary conditions
 780
 781    oper_neumann : function
 782
 783        Penalization of Neumann boundary conditions
 784
 785    sa : logical
 786
 787        Whether to consider self-adaptative PINN
 788
 789    c : dict
 790
 791        Dictionary with the hyperparameters of the self-adaptative sigmoid mask for the initial (w0), sensor (ws) and collocation (wr) points. The weights of the boundary points is fixed to 1
 792
 793    inverse : logical
 794
 795        Whether to estimate parameters of the PDE
 796
 797    initial_par : jax.numpy.array
 798
 799        Initial value of the parameters of the PDE in an inverse problem
 800
 801    lr,b1,b2,eps,eps_root: float
 802
 803        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
 804
 805    key : int
 806
 807        Seed for parameters initialization. Default 0
 808
 809    epoch_print : int
 810
 811        Number of epochs to calculate and print test errors. Default 100
 812
 813    save : logical
 814
 815        Whether to save the current parameters. Default False
 816
 817    file_name : str
 818
 819        File prefix to save the current parameters. Default 'result_pinn'
 820
 821    exp_decay : logical
 822
 823        Whether to consider exponential decay of learning rate. Default False
 824
 825    transition_steps : int
 826
 827        Number of steps for exponential decay. Default 1000
 828
 829    decay_rate : float
 830
 831        Rate of exponential decay. Default 0.9
 832
 833    mlp : logical
 834
 835        Whether to consider modifed multi-layer perceptron
 836
 837    Returns
 838    -------
 839    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
 840    """
 841
 842    #Initialize architecture
 843    nnet = fconNN(width,get_activation(activation),key,mlp)
 844    forward = nnet['forward']
 845
 846    #Initialize self adaptative weights
 847    par_sa = {}
 848    if sa:
 849        #Initialize wheights close to zero
 850        ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000)
 851        if data['sensor'] is not None:
 852            par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))})
 853        if data['initial'] is not None:
 854            par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))})
 855        if data['collocation'] is not None:
 856            par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))})
 857        if data['boundary'] is not None:
 858            par_sa.update({'wb': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))})
 859
 860    #Store all parameters
 861    params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa}
 862
 863    #Save config file
 864    if save:
 865        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
 866
 867    #Define loss function
 868    if sa:
 869        #Define loss function
 870        @jax.jit
 871        def lf(params,x):
 872            loss = 0
 873            if x['sensor'] is not None:
 874                #Term that refers to sensor data
 875                loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws']))
 876            if x['boundary'] is not None:
 877                if neumann:
 878                    #Neumann coditions
 879                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
 880                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
 881                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb']))
 882                else:
 883                    #Term that refers to boundary data
 884                    loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb']))
 885            if x['initial'] is not None:
 886                #Term that refers to initial data
 887                loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0']))
 888            if x['collocation'] is not None:
 889                #Term that refers to collocation points
 890                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
 891                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
 892                if inverse:
 893                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr']))
 894                else:
 895                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr']))
 896            return loss
 897    else:
        #Loss without self-adaptive weights: plain MSE on each available data term
        @jax.jit
        def lf(params,x):
            """
            Total PINN loss as a sum of mean squared errors.

            Terms: sensor data fit, boundary data fit (or a Neumann penalty
            when `neumann` is set in the enclosing scope), initial data fit,
            and the PDE residual at collocation points. A term is included
            only when the corresponding entry of `x` is not None.

            params : dict
                Holds 'net' (network parameters) and, for inverse problems,
                'inverse' (PDE parameters passed to `pde`).
            x : dict
                Data dictionary with keys 'sensor', 'boundary', 'initial',
                'collocation' and matching targets 'usensor', 'uboundary',
                'uinitial'; any entry may be None.
            """
            loss = 0
            if x['sensor'] is not None:
                #Term that refers to sensor data
                loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
            if x['boundary'] is not None:
                if neumann:
                    #Neumann conditions: split the space columns and the time
                    #column (last column of 'boundary') before applying the penalty
                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
                    #jnp.append(x,t,1) rebuilds the (x,t) input expected by the network
                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb))
                else:
                    #Term that refers to boundary data
                    loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary']))
            if x['initial'] is not None:
                #Term that refers to initial data
                loss = loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial']))
            if x['collocation'] is not None:
                #Term that refers to collocation points: the PDE residual is driven to zero
                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
                if inverse:
                    #Inverse problem: the PDE also receives the current parameter estimate
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0))
                else:
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0))
            return loss
 925
 926    #Initialize Adam Optmizer
 927    if exp_decay:
 928        lr = optax.exponential_decay(lr,transition_steps,decay_rate)
 929    optimizer = optax.adam(lr,b1,b2,eps,eps_root)
 930    opt_state = optimizer.init(params)
 931
 932    #Define the gradient function
 933    grad_loss = jax.jit(jax.grad(lf,0))
 934
 935    #Define update function
 936    @jax.jit
 937    def update(opt_state,params,x):
 938        #Compute gradient
 939        grads = grad_loss(params,x)
 940        #Invert gradient of self-adaptative wheights
 941        if sa:
 942            for w in grads['sa']:
 943                grads['sa'][w] = - grads['sa'][w]
 944        #Calculate parameters updates
 945        updates, opt_state = optimizer.update(grads, opt_state)
 946        #Update parameters
 947        params = optax.apply_updates(params, updates)
 948        #Return state of optmizer and updated parameters
 949        return opt_state,params
 950
 951    ###Training###
 952    t0 = time.time()
 953    #Initialize alive_bar for tracing in terminal
 954    with alive_bar(epochs) as bar:
 955        #For each epoch
 956        for e in range(epochs):
 957            #Update optimizer state and parameters
 958            opt_state,params = update(opt_state,params,data)
 959            #After epoch_print epochs
 960            if e % epoch_print == 0:
 961                #Compute elapsed time and current error
 962                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6))
 963                #If there is test data, compute current L2 error
 964                if test_data is not None:
 965                    #Compute L2 error
 966                    l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist()
 967                    l = l + ' L2 error: ' + str(jnp.round(l2_test,3))
 968                if inverse:
 969                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
 970                #Print
 971                print(l)
 972            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
 973                #Save current parameters
 974                pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
 975            #Update alive_bar
 976            bar()
 977    #Define estimated function
 978    def u(xt):
 979        return forward(xt,params['net'])
 980
 981    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0}
 982
 983#Training PINN
 984def train_Matern_PINN(data,width,pde,test_data = None,params = None,d = 2,N = 128,L = 1,alpha = 1,kappa = 1,sigma = 100,bsize = 1024,resample = False,epochs = 100,at_each = 10,activation = 'tanh',
 985    neumann = False,oper_neumann = None,inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 1,save = False,file_name = 'result_pinn',
 986    exp_decay = True,transition_steps = 100,decay_rate = 0.9,mlp = True,ftype = None,fargs = None,q = 4,w = None,periodic = False,static = None,opt = 'LBFGS'):
 987    """
 988    Train a Physics-informed Neural Network
 989    ----------
 990    Parameters
 991    ----------
 992    data : dict
 993
 994        Data generated by the jinnax.data.generate_PINNdata function
 995
 996    width : list
 997
 998        A list with the width of each layer
 999
1000    pde : function
1001
1002        The partial differential operator. Its arguments are u, x and t
1003
1004    test_data : dict, None
1005
1006        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
1007
1008    params : list
1009
1010        Initial parameters for the neural network. Default None to initialize randomly
1011
1012    d : int
1013
1014        Dimension of the problem including the time variable if present. Default 2
1015
1016    N : int
1017
1018        Size of grid in each dimension. Default 128
1019
1020    L : list of float
1021
1022        The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1
1023
1024    kappa,alpha,sigma : float
1025
1026        Parameters of the Matern process
1027
1028    bsize : int
1029
1030        Batch size for weak norm computation. Default 1024
1031
1032    resample : logical
1033
1034        Whether to resample the test functions at each epoch
1035
1036    epochs : int
1037
1038        Number of training epochs. Default 100
1039
1040    at_each : int
1041
1042        Save results for epochs multiple of at_each. Default 10
1043
1044    activation : str
1045
1046        The name of the activation function of the neural network. Default 'tanh'
1047
1048    neumann : logical
1049
1050        Whether to consider Neumann boundary conditions
1051
1052    oper_neumann : function
1053
1054        Penalization of Neumann boundary conditions
1055
1056    inverse : logical
1057
1058        Whether to estimate parameters of the PDE
1059
1060    initial_par : jax.numpy.array
1061
1062        Initial value of the parameters of the PDE in an inverse problem
1063
1064    lr,b1,b2,eps,eps_root: float
1065
1066        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
1067
1068    key : int
1069
1070        Seed for parameters initialization. Default 0
1071
1072    epoch_print : int
1073
1074        Number of epochs to calculate and print test errors. Default 1
1075
1076    save : logical
1077
1078        Whether to save the current parameters. Default False
1079
1080    file_name : str
1081
1082        File prefix to save the current parameters. Default 'result_pinn'
1083
1084    exp_decay : logical
1085
1086        Whether to consider exponential decay of learning rate. Default True
1087
1088    transition_steps : int
1089
1090        Number of steps for exponential decay. Default 100
1091
1092    decay_rate : float
1093
1094        Rate of exponential decay. Default 0.9
1095
1096    mlp : logical
1097
1098        Whether to consider modifed multilayer perceptron
1099
1100    ftype : str
1101
1102        Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.
1103
1104    fargs : list
1105
1106        Arguments for deature transformation:
1107
1108        For 'ff': A list with the number of frequences and value of greatest frequence standard deviation.
1109
1110        For 'daff' and 'daff' bias: A dicitionary with a list with the size of rectangles and the type of boundary condition. If its a list, than boundary conditions is dirichlet.
1111
1112    q : int
1113
1114        Power of weights mask. Default 4
1115
1116    w : dict
1117
1118        Initila weights for self-adaptive scheme.
1119
1120    periodic : logical
1121
1122        Whether to consider periodic test functions. Default False.
1123
1124    static : function
1125
1126        A static function to sum to the neural network output.
1127
1128    opt : str
1129
1130        Optimizer. Default LBFGS.
1131
1132    Returns
1133    -------
1134    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch
1135    """
1136    #Initialize architecture
1137    nnet = fconNN(width,get_activation(activation),key,mlp,ftype,fargs,static)
1138    forward = nnet['forward']
1139    if params is not None:
1140        nnet['params'] = params
1141
1142    #Generate from Matern process
1143    if sigma > 0:
1144        if isinstance(L,float) or isinstance(L,int):
1145            L = d*[L]
1146        #Grid for weak norm
1147        grid = [jnp.linspace(0,L[i],N) for i in range(d)]
1148        grid = jnp.meshgrid(*grid, indexing='ij')
1149        grid = jnp.stack(grid, axis=-1).reshape((-1, d))
1150        #Set sigma
1151        if data['boundary'] is not None:
1152            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma)
1153            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
1154            if neumann:
1155                loss_boundary = oper_neumann(lambda x: forward(x,params['net']),data['boundary'])
1156            else:
1157                loss_boundary = jnp.mean(MSE(forward(data['boundary'],nnet['params']),data['uboundary']))
1158            output_w = pde(lambda x: forward(x,nnet['params']),grid)
1159            integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(tf)
1160            loss_res_weak = jnp.mean(integralOmega ** 2)
1161            sigma = float(jnp.sqrt(loss_boundary/loss_res_weak).tolist())
1162            del gen
1163            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
1164            tf = sigma*tf
1165        else:
1166            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
1167            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
1168
1169    #Define loss function
1170    @jax.jit
1171    def lf_each(params,x,k):
1172        if sigma > 0:
1173            #Term that refers to weak loss
1174            if resample:
1175                test_functions = gen(jax.random.split(jax.random.PRNGKey(k[0]),(bsize,))[:,0])
1176            else:
1177                test_functions = tf
1178        loss_sensor = loss_boundary = loss_initial = loss_res = loss_res_weak = 0
1179        if x['sensor'] is not None:
1180            #Term that refers to sensor data
1181            loss_sensor = jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
1182        if x['boundary'] is not None:
1183            if neumann:
1184                #Neumann coditions
1185                loss_boundary = oper_neumann(lambda x: forward(x,params['net']),x['boundary'])
1186            else:
1187                #Term that refers to boundary data
1188                loss_boundary = MSE(forward(x['boundary'],params['net']),x['uboundary'])
1189        if x['initial'] is not None:
1190            #Term that refers to initial data
1191            loss_initial = MSE(forward(x['initial'],params['net']),x['uinitial'])
1192        if x['collocation'] is not None and sigma == 0:
1193            if inverse:
1194                output = pde(lambda x: forward(x,params['net']),x['collocation'],params['inverse'])
1195                loss_res = MSE(output,0)
1196            else:
1197                output = pde(lambda x: forward(x,params['net']),x['collocation'])
1198                loss_res = MSE(output,0)
1199        if sigma > 0:
1200            #Term that refers to weak loss
1201            if inverse:
1202                output_w = pde(lambda x: forward(x,params['net']),grid,params['inverse'])
1203                integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
1204                loss_res_weak = jnp.mean(integralOmega ** 2)
1205            else:
1206                output_w = pde(lambda x: forward(x,params['net']),grid)
1207                integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
1208                loss_res_weak = jnp.mean(integralOmega ** 2)
1209        return {'ls': loss_sensor,'lb': loss_boundary,'li': loss_initial,'lc': loss_res,'lc_weak': loss_res_weak}
1210
1211    @jax.jit
1212    def lf(params,x,k):
1213        l = lf_each(params,x,k)
1214        w = params['w']
1215        loss = jnp.mean((w['ws'] ** q)*l['ls']) + jnp.mean((w['wb'] ** q)*l['lb']) + jnp.mean((w['wi'] ** q)*l['li']) + jnp.mean((w['wc'] ** q)*l['lc']) + (w['wc_weak'] ** q)*l['lc_weak']
1216        if opt != 'LBFGS':
1217            return loss
1218        else:
1219            l2 = None
1220            if test_data is not None:
1221                l2 = L2error(forward(test_data['sensor'],params['net']),test_data['usensor'])
1222            return loss,{'loss': loss,'l2': l2}
1223
1224    #Initialize self-adaptive weights
1225    if w is None:
1226        w = {'ws': jnp.array(1.0),'wb': jnp.array(1.0),'wi': jnp.array(1.0),'wc': jnp.array(1.0),'wc_weak': jnp.array(1.0)}
1227    if q != 0:
1228        if data['sensor'] is not None:
1229            w['ws'] = w['ws'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+1),(data['sensor'].shape[0],1))
1230        if data['boundary'] is not None:
1231            w['wb'] = w['wb'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+2),(data['boundary'].shape[0],1))
1232        if data['initial'] is not None:
1233            w['wi'] = w['wi'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+3),(data['initial'].shape[0],1))
1234        if data['collocation'] is not None:
1235            w['wc'] = w['wc'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+4),(data['collocation'].shape[0],1))
1236
1237    #Store all parameters
1238    params = {'net': nnet['params'],'inverse': initial_par,'w': w}
1239
1240    #Save config file
1241    if save:
1242        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
1243
1244    #Initialize Adam Optmizer
1245    if opt != 'LBFGS':
1246        print('--------- GRADIENT DESCENT OPTIMIZER ---------')
1247        if exp_decay:
1248            lr = optax.exponential_decay(lr,transition_steps,decay_rate)
1249        optimizer = optax.adam(lr,b1,b2,eps,eps_root)
1250        opt_state = optimizer.init(params)
1251
1252        #Define the gradient function
1253        grad_loss = jax.jit(jax.grad(lf,0))
1254
1255        #Define update function
1256        @jax.jit
1257        def update(opt_state,params,x,k):
1258            #Compute gradient
1259            grads = grad_loss(params,x,k)
1260            #Calculate parameters updates
1261            updates, opt_state = optimizer.update(grads, opt_state)
1262            #Update parameters
1263            if q != 0:
1264                updates = {**updates, 'w': jax.tree_util.tree_map(lambda x: -x, updates['w'])} #Change signs of weights
1265            params = optax.apply_updates(params, updates)
1266            #Return state of optmizer and updated parameters
1267            return opt_state,params
1268    else:
1269        print('--------- LBFGS OPTIMIZER ---------')
1270        @jax.jit
1271        def loss_LBFGS(params):
1272            return lf(params,data,key + 234)
1273        solver = LBFGS(fun = loss_LBFGS,has_aux = True,maxiter = epochs,tol = 1e-9,verbose = False,linesearch = 'zoom',history_size = 100)  # linesearch='zoom' by default
1274        state = solver.init_state(params)
1275
1276    ###Training###
1277    t0 = time.time()
1278    k = jax.random.split(jax.random.PRNGKey(key+234),(epochs,))
1279    sloss = []
1280    sL2 = []
1281    stime = []
1282    #Initialize alive_bar for tracing in terminal
1283    with alive_bar(epochs) as bar:
1284        #For each epoch
1285        for e in range(epochs):
1286            if opt != 'LBFGS':
1287                #Update optimizer state and parameters
1288                opt_state,params = update(opt_state,params,data,k[e,:])
1289                sloss.append(lf(params,data,k[e,:]))
1290                if test_data is not None:
1291                    sL2.append(L2error(forward(test_data['sensor'],params['net']),test_data['usensor']))
1292            else:
1293                params, state = solver.update(params, state)
1294                sL2.append(state.aux["l2"])
1295                sloss.append(state.aux["loss"])
1296            stime.append(time.time() - t0)
1297            #After epoch_print epochs
1298            if e % epoch_print == 0:
1299                #Compute elapsed time and current error
1300                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(sloss[-1],6))
1301                #If there is test data, compute current L2 error
1302                if test_data is not None:
1303                    #Compute L2 error
1304                    l = l + ' L2 error: ' + str(jnp.round(sL2[-1],6))
1305                if inverse:
1306                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
1307                #Print
1308                print(l)
1309            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
1310                #Save current parameters
1311                pickle.dump({'params': params,'width': width,'time': stime,'loss': sloss,'L2error': sL2},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
1312            #Update alive_bar
1313            bar()
1314    #Define estimated function
1315    def u(xt):
1316        return forward(xt,params['net'])
1317
1318    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0,'loss_each': lf_each(params,data,[key + 100]),'loss': sloss,'L2error': sL2}
1319
1320
1321#Process result
#Process result
def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1):
    """
    Process the results of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    fit : function

        The fitted function

    train_data : dict

        Training data generated by the jinnax.data.generate_PINNdata

    plot : logical

        Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True

    plot_test : logical

        Whether to plot the test data. Default True

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to plot 2D plot when the spatial dimension is one. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    print_res : logical

        Whether to print the L2 error. Default True

    p : int

        Output dimension. Default 1

    Returns
    -------
    pandas data frame with L2 and MSE errors
    """
    #Spatial dimension (the last column of xt is time)
    d = test_data['xt'].shape[1] - 1

    #Round the number of time slices to the nearest multiple of 5
    times = 5 * round(times/5.0)

    #Assemble training inputs/targets and evaluate the fitted function on both sets
    td = get_train_data(train_data)
    xt_train = td['x']
    u_train = td['y']
    upred_train = fit(xt_train)
    upred_test = fit(test_data['xt'])

    #Error metrics on test and train sets (as Python floats)
    metrics = {
        'l2_error_test': L2error(upred_test,test_data['u']).tolist(),
        'MSE_test': jnp.mean(MSE(upred_test,test_data['u'])).tolist(),
        'l2_error_train': L2error(upred_train,u_train).tolist(),
        'MSE_train': jnp.mean(MSE(upred_train,u_train)).tolist(),
    }

    df = pd.DataFrame([metrics],columns = ['l2_error_test','MSE_test','l2_error_train','MSE_train'])
    if print_res:
        print('L2 error test: ' + str(jnp.round(metrics['l2_error_test'],6)) + ' L2 error train: ' + str(jnp.round(metrics['l2_error_train'],6)) + ' MSE error test: ' + str(jnp.round(metrics['MSE_test'],6)) + ' MSE error train: ' + str(jnp.round(metrics['MSE_train'],6)) )

    #Plots: 1D solution slices or 2D-output trajectories
    if plot:
        if d == 1 and p == 1:
            plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name)
        elif p == 2:
            plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test)

    return df
1413
1414#Plot results for d = 1
#Plot results for d = 1
def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
    """
    Plot the prediction of a 1D PINN
    ----------

    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    d2 : logical

        Whether to plot 2D plot. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title_1d : str

        Title of 1D plot

    title_2d : str

        Title of 2D plot

    Returns
    -------
    None
    """
    #Grid of time-slice panels: `rows` rows of 5 columns
    rows = int(times/5)
    fig, ax = plt.subplots(rows,5,figsize = (10*rows,3*rows))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    #Vertical limits padded by 10% of the data range on each side
    ylo = jnp.min(u)
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u)
    yup = yup + 0.1*jnp.abs(yup)
    t_values = np.linspace(tlo,tup,times)

    #One panel per requested time slice
    for k, t_target in enumerate(t_values):
        i, j = divmod(k, 5)
        #Snap the requested time to the nearest time actually present in the data
        t = xt[jnp.abs(xt[:,-1] - t_target) == jnp.min(jnp.abs(xt[:,-1] - t_target)),-1][0].tolist()
        mask = xt[:,-1] == t
        x_plot = xt[mask,:-1]
        y_plot = upred[mask,:]
        u_plot = u[mask,:]
        #plt.subplots returns a 1D axis array when rows == 1
        axis = ax[i,j] if rows > 1 else ax[j]
        axis.plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
        axis.plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
        axis.set_title('$t = %.2f$' % (t),fontsize=10)
        axis.set_xlabel(' ')
        axis.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])

    #Title
    fig.suptitle(title_1d)
    fig.tight_layout()

    #Show and save
    fig = plt.gcf()
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close()

    #2d heatmap of exact vs predicted solution over (t,x)
    if d2:
        fig, ax = plt.subplots(1,2)
        l1 = jnp.unique(xt[:,-1]).shape[0]
        l2 = jnp.unique(xt[:,0]).shape[0]

        #Left panel: exact; right panel: predicted
        for col, (vals, name) in enumerate(((u,'Exact'),(upred,'Predicted'))):
            ax[col].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),vals[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
            ax[col].set_title(name)

        #Title
        fig.suptitle(title_2d)
        fig.tight_layout()

        #Show and save
        fig = plt.gcf()
        if show:
            plt.show()
        if save:
            fig.savefig(file_name + '_2d.png')
        plt.close()
1536
1537#Plot results for d = 1
#Plot results for 2D output
def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
    """
    Plot the prediction of a PINN with 2D output
    ----------
    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title : str

        Title of plot

    plot_test : logical

        Whether to plot the test data. Default True

    Returns
    -------
    None
    """
    #Initialize
    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    #Limits of the first output component, padded by 10%
    #NOTE(review): xlo/xup are computed but never used below — confirm whether
    #an x-axis limit (set_xlim) was intended
    xlo = jnp.min(u[:,0])
    xlo = xlo - 0.1*jnp.abs(xlo)
    xup = jnp.max(u[:,0])
    xup = xup + 0.1*jnp.abs(xup)
    #Limits of the second output component, padded by 10%
    ylo = jnp.min(u[:,1])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u[:,1])
    yup = yup + 0.1*jnp.abs(yup)
    k = 0
    t_values = np.linspace(tlo,tup,times)

    #Create one panel per time slice, plotting output component 0 vs component 1
    for i in range(int(times/5)):
        for j in range(5):
            if k < len(t_values):
                t = t_values[k]
                #Snap to the nearest time actually present in the data
                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
                xpred_plot = upred[xt[:,-1] == t,0]
                ypred_plot = upred[xt[:,-1] == t,1]
                if plot_test:
                    x_plot = u[xt[:,-1] == t,0]
                    y_plot = u[xt[:,-1] == t,1]
                if int(times/5) > 1:
                    if plot_test:
                        ax[i,j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
                    ax[i,j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[i,j].set_xlabel(' ')
                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                else:
                    if plot_test:
                        ax[j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
                    #Bug fix: was `ypred`, an undefined name (NameError whenever
                    #int(times/5) == 1); the prediction y-values are `ypred_plot`
                    ax[j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                k = k + 1

    #Title
    fig.suptitle(title)
    fig.tight_layout()

    #Show and save
    fig = plt.gcf()
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close()
1637
1638#Get train data in one array
#Get train data in one array
def get_train_data(train_data):
    """
    Process training sample
    ----------

    Parameters
    ----------
    train_data : dict

        A dictionary with train data generated by the jinnax.data.generate_PINNdata function

    Returns
    -------
    dict with the processed training data: stacked inputs 'x', targets 'y',
    the column-stacked pair 'xy', and the sample count of each data category
    """
    #Categories with targets, in the order they are stacked
    labeled = (('sensor','usensor'),('boundary','uboundary'),('initial','uinitial'))
    counts = {}
    x_parts = []
    y_parts = []
    for xkey, ykey in labeled:
        block = train_data[xkey]
        counts[xkey + '_sample'] = 0 if block is None else block.shape[0]
        if block is not None:
            x_parts.append(block)
            y_parts.append(train_data[ykey])
    #Collocation points have no targets; only their count is reported
    col = train_data['collocation']
    counts['collocation_sample'] = 0 if col is None else col.shape[0]

    #Stack whatever categories are present (None when there is no labeled data)
    if x_parts:
        xdata = jnp.vstack(x_parts)
        ydata = jnp.vstack(y_parts)
        xydata = jnp.column_stack((xdata,ydata))
    else:
        xdata = ydata = xydata = None

    return {'xy': xydata,'x': xdata,'y': ydata,**counts}
1694
#Process training
def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
    """
    Process the training of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    bolstering : logical

        Whether to compute bolstering mean square error. Default True

    mc_sample : int

        Number of sample for Monte Carlo integration in bolstering. Default 10000

    save : logical

        Whether to save the training results. Default False

    file_name_save : str

        File prefix to save the plots and the L2 error. Default 'result_pinn'

    key : int

        Key for random samples in bolstering. Default 0

    ec : float

        Stopping criteria error for EM algorithm in bolstering. Default 1e-6

    lamb : float

        Hyperparameter of EM algorithm in bolstering. Default 1

    Returns
    -------
    pandas data frame with training results
    """
    #Config (context manager avoids leaking the file handle)
    with open(file_name + '_config.pickle','rb') as f:
        config = pickle.load(f)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xydata = td['xy']
    xdata = td['x']
    ydata = td['y']
    sensor_sample = td['sensor_sample']
    boundary_sample = td['boundary_sample']
    initial_sample = td['initial_sample']
    collocation_sample = td['collocation_sample']

    #Generate one key pair per epoch for bolstering randomization
    if bolstering:
        keys = jax.random.split(jax.random.PRNGKey(key),epochs)

    #Initialize result traces ('times' avoids shadowing the imported time module)
    train_mse = []
    test_mse = []
    train_L2 = []
    test_L2 = []
    bolstX = []
    bolstXY = []
    loss = []
    times = []
    ep = []

    #Process training
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
                ep.append(e)

                #Read parameters stored for this epoch
                with open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb') as f:
                    params = pickle.load(f)

                #Training time at this epoch
                times.append(params['time'])

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Train MSE and L2 (only when supervised training data exists)
                if xdata is not None:
                    train_mse.append(jnp.mean(MSE(psi(xdata),ydata)).tolist())
                    train_L2.append(L2error(psi(xdata),ydata).tolist())
                else:
                    train_mse.append(None)
                    train_L2.append(None)

                #Test MSE and L2
                test_mse.append(jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist())
                test_L2.append(L2error(psi(test_data['xt']),test_data['u']).tolist())

                #Bolstering
                if bolstering:
                    bX = []
                    bXY = []
                    for method in ['chi','mm','mpe']:
                        kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        bX.append(gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist())
                        bXY.append(gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist())
                    #Hessian kernels for a range of bias magnitudes
                    #NOTE(review): the hessian kernel is fit on xydata but used for the X bolstering — confirm intended
                    for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]:
                        kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
                        bX.append(gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist())
                    bolstX.append(bX)
                    bolstXY.append(bXY)
                else:
                    bolstX.append(None)
                    bolstXY.append(None)

                #Loss
                loss.append(params['loss'].tolist())

                #Free per-epoch objects before the next iteration
                del params, psi
            #Update alive_bar
            bar()

    #Bolstering results
    if bolstering:
        bolstX = jnp.array(bolstX)
        bolstXY = jnp.array(bolstXY)

    #Create data frame
    if bolstering:
        df = pd.DataFrame(np.column_stack([ep,times,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
            train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]),
            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4'])
    else:
        df = pd.DataFrame(np.column_stack([ep,times,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
            train_mse,test_mse,train_L2,test_L2]),
            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2'])
    if save:
        df.to_csv(file_name_save + '.csv',index = False)

    return df
1850
#Demo video for training1D PINN
def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
    """
    Demo video with the training of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to make video demo of 2D plot. Default True

    file_name_save : str

        File prefix to save the plots and videos. Default 'result_pinn_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xt = td['x']
    u = td['y']

    #Create folder to save plots (portable, no shell, safe if it already exists)
    os.makedirs(file_name_save,exist_ok = True)

    #Create images
    k = 1
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if e % at_each == 0 or e == epochs - 1:
                #Read parameters
                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Compute L2 train, L2 test and loss
                loss = params['loss']
                L2_train = L2error(psi(xt),u)
                L2_test = L2error(psi(test_data['xt']),test_data['u'])
                #float() turns the JAX scalar into a plain Python float before rounding
                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(float(L2_train),6)) + ' L2 test = ' + str(round(float(L2_test),6))

                #Save plot
                plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
                k = k + 1

                #Delete
                del params, psi, loss, L2_train, L2_test, title_epoch
            #Update alive_bar
            bar()
    #Create demo video
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4')
    if d2:
        os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4')
1940
#Demo in time for 1D PINN
def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
    """
    Demo video with the time evolution of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    train_data = config['train_data']
    forward = config['forward']

    #Create folder to save plots (portable, no shell, safe if it already exists)
    os.makedirs(file_name_save,exist_ok = True)

    #Plot parameters: unique time points and y-limits padded by 10%
    tdom = jnp.unique(test_data['xt'][:,-1])
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)

    #Open PINN for each epoch and precompute its prediction over the test grid
    results = []
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        results = results + [tmp]
        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]

    #Create one image per time point
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data restricted to this time step
            xt_step = test_data['xt'][test_data['xt'][:,-1] == t]
            u_step = test_data['u'][test_data['xt'][:,-1] == t]
            #Initialize plot: a two-column grid when plotting more than one epoch
            if len(epochs) > 1:
                fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2))
            else:
                fig, ax = plt.subplots(1,1,figsize = (10,5))
            #Create panels
            #NOTE(review): an odd number of epochs > 2 fills only int(len(epochs)/2)*2 panels — confirm intended
            index = 0
            if int(len(epochs)/2) > 1:
                for i in range(int(len(epochs)/2)):
                    for j in range(min(2,len(epochs))):
                        upred_step = upred[index][test_data['xt'][:,-1] == t]
                        ax[i,j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                        ax[i,j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                        ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                        ax[i,j].set_xlabel(' ')
                        ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                        index = index + 1
            elif len(epochs) > 1:
                for j in range(2):
                    upred_step = upred[index][test_data['xt'][:,-1] == t]
                    ax[j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                    ax[j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                    ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                    index = index + 1
            else:
                upred_step = upred[index][test_data['xt'][:,-1] == t]
                ax.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                ax.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                ax.set_xlabel(' ')
                ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                index = index + 1

            #Title (float() turns the JAX scalar into a plain Python float before rounding)
            fig.suptitle(title + 't = ' + str(round(float(t),4)))
            fig.tight_layout()

            #Show and save
            fig = plt.gcf()
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close()
            bar()

    #Create demo video
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')
#MSE
@jax.jit
def MSE(pred,true):
    """
    Squared error
    ----------
    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    Returns
    -------
    element-wise squared error
    """
    return (true - pred) ** 2
#MSE self-adaptative
@jax.jit
def MSE_SA(pred,true,w,q = 2):
    """
    Self-adaptative squared error
    ----------
    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    w : jax.numpy.array

        A JAX numpy array with the self-adaptative weights

    q : float

        Power for the weights mask. Default 2

    Returns
    -------
    self-adaptative squared error with polynomial mask
    """
    return (w ** q) * ((true - pred) ** 2)
#L2 error
@jax.jit
def L2error(pred,true):
    """
    L2-error in percentage (%)
    ----------
    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    Returns
    -------
    relative L2-error in percentage
    """
    diff_norm = jnp.sqrt(jnp.sum((true - pred) ** 2))
    true_norm = jnp.sqrt(jnp.sum(true ** 2))
    return 100 * diff_norm / true_norm
#Inverse DST-I
def idst1(x,axis = -1):
    """
    Inverse Discrete Sine Transform of type I with orthonormal scaling
    ----------
    Parameters
    ----------
    x : jax.numpy.array

        Array to apply the transformation

    axis : int

        Axis to apply the transformation over. Default -1 (last axis)

    Returns
    -------
    array with the inverse transform applied along the given axis
    """
    return idst(x,type = 1,axis = axis,norm = 'ortho')
#DST-I over many axes
def dstn(x,axes = None):
    """
    Discrete Sine Transform of type I with orthonormal scaling over many axes
    ----------
    Parameters
    ----------
    x : jax.numpy.array

        Array to apply the transformation

    axes : tuple of int

        Axes to apply the transformation over. If None, all axes

    Returns
    -------
    array with the transform applied along every requested axis
    """
    if axes is None:
        axes = tuple(range(x.ndim))
    y = x
    for ax in axes:
        #Chain the 1D transform over the running result (transforming x each
        #iteration would discard all but the last axis)
        y = dst(y,type = 1,axis = ax,norm = 'ortho')
    return y
#Inverse DST-I over many axes
def idstn(x,axes = None):
    """
    Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes
    ----------
    Parameters
    ----------
    x : jax.numpy.array

        Array to apply the transformation

    axes : tuple of int

        Axes to apply the transformation over. If None, all axes

    Returns
    -------
    array with the inverse transform applied along every requested axis
    """
    if axes is None:
        axes = tuple(range(x.ndim))
    y = x
    for ax in axes:
        #Inverse DST-I with orthonormal scaling, chained over the running result
        y = idst(y,type = 1,axis = ax,norm = 'ortho')
    return y
#Dirichlet-Laplace eigenvalues
def dirichlet_eigs_nd(n,L):
    """
    Eigenvalues of the discrete Dirichlet-Laplace operator in a rectangle
    ----------
    Parameters
    ----------
    n : list

        List with the number of points in the grid in each dimension

    L : list

        List with the upper limit of the interval of the domain in each dimension. Assumed the lower limit is zero

    Returns
    -------
    jax.numpy.array with the eigenvalues on the full tensor-product grid
    """
    #Unidimensional eigenvalues of the three-point Laplacian on each axis
    lam_axes = []
    for ni, Li in zip(n,L):
        h = Li / (ni + 1.0)
        k = jnp.arange(1,ni + 1,dtype = jnp.float32)
        lam_axes.append((2.0 / (h * h)) * (1.0 - jnp.cos(jnp.pi * k / (ni + 1.0))))
    #Tensor-sum the per-axis eigenvalues over the full grid
    grids = jnp.meshgrid(*lam_axes,indexing = 'ij')
    return jnp.sum(jnp.stack(grids),axis = 0)
#Matern process sampler
def generate_matern_sample(key,d = 2,N = 128,L = 1.0,kappa = 1,alpha = 1,sigma = 1,periodic = False):
    """
    Sample d-dimensional Matern process
    ----------
    Parameters
    ----------
    key : int

        Seed for randomization

    d : int

        Dimension. Default 2

    N : int

        Size of grid in each dimension. Default 128

    L : list of float

        The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process

    periodic : logical

        Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT

    Returns
    -------
    jax.numpy.array
    """
    if periodic:
        #Key and grid shape (shape is taken before N is expanded to a per-axis list)
        key = jax.random.PRNGKey(key)
        shape = (N,) * d
        if isinstance(L,(float,int)):
            L = d * [L]
        if isinstance(N,(float,int)):
            N = d * [N]

        #Angular frequency grid
        freq = [jnp.fft.fftfreq(N[j],d = L[j]/N[j]) * 2 * jnp.pi for j in range(d)]
        grids = jnp.meshgrid(*freq,indexing = 'ij')
        sq_norm_xi = sum(g ** 2 for g in grids)

        #Complex white noise in Fourier space
        key_re, key_im = jax.random.split(key)
        noise_hat = jax.random.normal(key_re,shape) + 1j * jax.random.normal(key_im,shape)

        #Whittle filter
        field_hat = noise_hat * (kappa ** 2 + sq_norm_xi) ** (-alpha / 2.0)

        #Transform back to physical space
        return sigma * jnp.real(jnp.fft.ifftn(field_hat))

    #NOT JAX: Dirichlet boundary conditions via DST-I
    rng = np.random.default_rng(seed = key)
    if isinstance(L,(float,int)):
        L = d * [L]
    if isinstance(N,(float,int)):
        N = d * [N]

    #White noise in real space
    W = rng.standard_normal(size = tuple(N))

    #To Dirichlet eigenbasis via separable DST-I (orthonormal)
    W_hat = dstn(W)

    #Discrete Dirichlet Laplacian eigenvalues
    lam = dirichlet_eigs_nd(N,L)

    #Spectral filter
    #NOTE(review): periodic branch uses kappa ** 2 + |xi|^2 while this uses kappa + lam —
    #confirm the intended Matern parameterization
    psi_hat = ((kappa + lam) ** (-alpha / 2.0)) * W_hat

    #Back to real space
    return jnp.array(sigma * idstn(psi_hat))
#Batched Matern process sampler
def generate_matern_sample_batch(d = 2,N = 512,L = 1.0,kappa = 10.0,alpha = 1,sigma = 10,periodic = False):
    """
    Create function to sample d-dimensional Matern process in batches
    ----------
    Parameters
    ----------
    d : int

        Dimension. Default 2

    N : int

        Size of grid in each dimension. Default 512

    L : list of float

        The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process

    periodic : logical

        Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT

    Returns
    -------
    function mapping an array of keys to a batch of samples
    """
    def sample_one(k):
        return generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
    if periodic:
        #JAX-native sampler: vectorize over the key axis
        return jax.vmap(sample_one)
    #Non-periodic sampler is not JAX native: loop over keys with numpy
    def sample_many(keys):
        return jnp.array(np.apply_along_axis(sample_one,1,keys.reshape((keys.shape[0],1))))
    return sample_many
#Laplace eigenfunctions
def eigenf_laplace(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None):
    """
    Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace.
    ----------
    Parameters
    ----------
    L_vec : list of float

        The domain of the function in each coordinate is [0,L[1]]

    kmax_per_axis : list

        List with the maximum number of eigenfunctions per dimension. Consider d * max(kmax_per_axis) eigenfunctions

    bc : str

        Boundary condition. 'dirichlet' or 'neumann'

    max_ef : int

        Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions

    Returns
    -------
    function to compute eigenfunctions,eigenvalues of the eigenfunctions considered
    """
    #Parameters
    L_vec = jnp.asarray(L_vec,dtype = jnp.float32)
    d = L_vec.shape[0]
    bc = bc.lower()
    is_dirichlet = bc.startswith("d")

    #Default maximum number of eigenfunctions
    if max_ef is None:
        if d == 1:
            max_ef = jnp.max(jnp.array(kmax_per_axis))
        else:
            max_ef = jnp.max(d * jnp.array(kmax_per_axis))

    #Candidate wavenumbers per axis (Dirichlet excludes the constant k = 0 mode)
    kmax_per_axis = [int(km) for km in kmax_per_axis]
    if is_dirichlet:
        axis_ranges = [range(1, km + 1) for km in kmax_per_axis]
    elif bc.startswith("n"):
        axis_ranges = [range(0, km + 1) for km in kmax_per_axis]

    #All multi-indices
    Ks = jnp.array(list(product(*axis_ranges)),dtype = jnp.float32)

    #Eigenvalues of the continuous Laplacian
    lambdas_all = jnp.sum((Ks * (jnp.pi / L_vec)) ** 2,axis = 1)

    #Sort by eigenvalue and keep the first max_ef modes
    order = jnp.argsort(lambdas_all)
    Ks = Ks[order][:max_ef]
    lambdas = lambdas_all[order][:max_ef]
    m = Ks.shape[0]

    #Per-mode L2-normalizing constants (closed form)
    if is_dirichlet:
        nf = jnp.prod(jnp.sqrt(2.0 / L_vec)[None, :],axis = 1)
        norm_factors = jnp.ones((m,),dtype = jnp.float32) * nf
    else:
        #Neumann: the k = 0 factor normalizes with 1/L instead of 2/L
        def nf_row(k_row):
            return jnp.prod(jnp.where(k_row == 0,jnp.sqrt(1.0 / L_vec),jnp.sqrt(2.0 / L_vec)))
        norm_factors = jax.vmap(nf_row)(Ks)

    #Build the callable function
    @jax.jit
    def phi(x):
        x = jnp.asarray(x,dtype = jnp.float32)
        #Start from ones and multiply the per-axis factors in
        vals = jnp.ones(x.shape[:-1] + (m,),dtype = jnp.float32)
        for i in range(d):
            ang = (jnp.pi / L_vec[i]) * x[..., i][..., None] * Ks[:, i]
            vals = vals * (jnp.sin(ang) if is_dirichlet else jnp.cos(ang))
        #Apply L2-normalizing constants
        if vals.ndim > 1:
            return vals * norm_factors[None, ...]
        return vals * norm_factors
    return phi, lambdas
#Multiple-frequency Laplace eigenfunction features
def multiple_daff(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None):
    """
    Create function to compute multiple frequences of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. Each frequence is a different domain.
    ----------
    Parameters
    ----------
    L_vec : list of lists of float

        List with the domain of each frequence of the eigenfunctions in the form [0,L[i][1]]

    kmax_per_axis : list

        List with the maximum number of eigenfunctions per dimension.

    bc : str

        Boundary condition. 'dirichlet' or 'neumann'

    max_ef : int

        Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions

    Returns
    -------
    function to compute daff,eigenvalues of the eigenfunctions considered
    """
    #One eigenfunction basis per domain
    psi = []
    lamb = []
    for L in L_vec:
        f,l = eigenf_laplace(L,kmax_per_axis,bc,max_ef)
        psi.append(f)
        lamb.append(l)
    #Create function to compute features by concatenating all bases
    @jax.jit
    def mff(x):
        feats = [p(x) for p in psi]
        if len(feats) == 1:
            return feats[0]
        return jnp.concatenate(feats,1)
    return mff,jnp.concatenate(lamb)
@partial(jax.jit,static_argnums=(2,))  # n is static here; compile once per n
def multiple_cheb_fast(x, L_vec, n: int):
    """
    Evaluate Chebyshev-difference features over multiple domains.

    x: (N, d)
    L_vec: (L, d) containing 'b' endpoints (a is 0) for each dimension
    n: number of k terms (static)
    returns: (N, L*n)
    """
    N, d = x.shape
    L = L_vec.shape[0]

    #Affine map of x from [a, b] = [0, b] to t in [-1, 1] for each l, j: shape (L, N, d)
    a = 0.0
    b = L_vec
    t = (2.0 * x[None, :, :] - (a + b)[:, None, :]) / (b - a)[:, None, :]

    #Chebyshev T_0..T_{n+2} for all (L, N, d): shape (n+3, L, N, d)
    T = _chebyshev_T_all(t, n + 2)

    #Basis phi_k = T_{k+2} - T_k, k = 0..n-1 => shape (n, L, N, d)
    ks = jnp.arange(n)
    phi = T[ks + 2, ...] - T[ks, ...]

    #Multiply across spatial dimensions (last axis = d) => (n, L, N)
    z = jnp.prod(phi, axis=-1)

    #Reorder to (N, L, n) then flatten to (N, L*n)
    return jnp.transpose(z, (2, 1, 0)).reshape(N, L * n)
def multiple_cheb(L_vec, n: int):
    """
    Build a jitted Chebyshev feature map with n and L_vec baked in, so every
    call sees constant shapes.
    """
    domains = jnp.asarray(L_vec)

    @jax.jit  # optional: multiple_cheb_fast is itself jitted
    def _features(x):
        return multiple_cheb_fast(jnp.asarray(x), domains, n)

    return _features

Factory that closes over static n and L_vec (so shapes are constant).

def fconNN(width,activation = jax.nn.tanh,key = 0,mlp = False,ftype = None,fargs = None,static = None):
    """
    Initialize fully connected neural network
    ----------
    Parameters
    ----------
    width : list

        List with the layers width. width[0] is the input dimension; with a
        feature transformation, width[1] also sets the number of features.

    activation : jax.nn activation

        The activation function. Default jax.nn.tanh

    key : int

        Seed for parameters initialization. Default 0

    mlp : logical

        Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension.

    ftype : str

        Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.

    fargs : list

        Arguments for feature transformation:

        For 'ff': A list with the number of frequencies and the value of the greatest frequency standard deviation.

        For 'daff' and 'daff_bias': A dictionary with a list with the size of the rectangles and the type of boundary condition. If it is a list, the boundary condition is Dirichlet.

    static : function

        A static function to sum to the neural network output.

    Returns
    -------
    dict with initial parameters and the function for the forward pass
    """
    #Initialize parameters with Glorot initialization
    initializer = jax.nn.initializers.glorot_normal()
    params = list()
    if static is None:
        #No static term: add zero to the network output
        static = lambda x: 0.0

    #Feature mapping
    if ftype == 'ff': #Fourrier features
        #Build the frequency matrix Bff column-block by column-block, with a
        #geometric ladder of standard deviations fargs[1]**((s+1)/fargs[0])
        for s in range(fargs[0]):
            sd = fargs[1] ** ((s + 1)/fargs[0])
            if s == 0:
                Bff = sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2)))
            else:
                Bff = jnp.append(Bff,sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2))),1)
        @jax.jit
        def phi(x):
            #Random Fourier features: sin/cos of the projected inputs
            x = x @ Bff
            return jnp.concatenate([jnp.sin(2 * jnp.pi * x),jnp.cos(2 * jnp.pi * x)],axis = -1)
        #Drop the input layer: the first trainable layer now consumes the
        #2 * Bff.shape[1] Fourier features
        width = width[1:]
        width[0] = 2*Bff.shape[1]
    elif ftype == 'daff' or ftype == 'daff_bias':
        if not isinstance(fargs, dict):
            #A bare list means Dirichlet boundary conditions
            fargs = {'L': fargs,'bc': "dirichlet"}
        #NOTE(review): relies on dict insertion order (first value = L,
        #second = bc) rather than fixed keys — confirm callers follow this
        phi,lamb = multiple_daff(list(fargs.values())[0],kmax_per_axis = [width[1]] * width[0],bc = list(fargs.values())[1])
        width = width[1:]
        #First trainable layer consumes one input per retained eigenfunction
        width[0] = lamb.shape[0]
    elif ftype == 'cheb' or ftype == 'cheb_bias':
        phi = multiple_cheb(fargs,n = width[1])
        width = width[1:]
        #multiple_cheb outputs len(fargs) domains * n terms features
        width[0] = len(fargs)*width[0]
    else:
        @jax.jit
        def phi(x):
            #Identity feature map
            return x

    #Initialize parameters
    if mlp:
        #Modified MLP: extra encoder pair (U, V) used to gate every hidden layer
        k = jax.random.split(jax.random.PRNGKey(key),4)
        WU = initializer(k[0],(width[0],width[1]),jnp.float32)
        BU = initializer(k[1],(1,width[1]),jnp.float32)
        WV = initializer(k[2],(width[0],width[1]),jnp.float32)
        BV = initializer(k[3],(1,width[1]),jnp.float32)
        params.append({'WU':WU,'BU':BU,'WV':WV,'BV':BV})
    key = jax.random.split(jax.random.PRNGKey(key + 1),len(width)-1) #Seed for initialization
    #NOTE(review): the loop variable shadows `key`, and the same subkey is
    #reused for both W and B of a layer (different shapes still give different
    #draws) — confirm this is intended
    for key,lin,lout in zip(key,width[:-1],width[1:]):
        W = initializer(key,(lin,lout),jnp.float32)
        B = initializer(key,(1,lout),jnp.float32)
        params.append({'W':W,'B':B})

    #Define function for forward pass
    if mlp:
        if ftype != 'daff' and ftype != 'cheb':
            @jax.jit
            def forward(x,params):
                #Modified MLP with biases: gate each hidden activation between
                #the encoder branches U and V
                encode,*hidden,output = params
                sx = static(x)
                x = phi(x)
                U = activation(x @ encode['WU'] + encode['BU'])
                V = activation(x @ encode['WV'] + encode['BV'])
                for layer in hidden:
                    x = activation(x @ layer['W'] + layer['B'])
                    x = x * U + (1 - x) * V
                return x @ output['W'] + output['B'] + sx
        else:
            @jax.jit
            def forward(x,params):
                #Modified MLP without biases ('daff'/'cheb' features; the
                #bias-free form presumably preserves the boundary behavior of
                #the feature basis — TODO confirm)
                encode,*hidden,output = params
                sx = static(x)
                x = phi(x)
                U = activation(x @ encode['WU'])
                V = activation(x @ encode['WV'])
                for layer in hidden:
                    x = activation(x @ layer['W'])
                    x = x * U + (1 - x) * V
                return x @ output['W'] + sx
    else:
        if ftype != 'daff' and ftype != 'cheb':
            @jax.jit
            def forward(x,params):
                #Plain fully connected network with biases
                *hidden,output = params
                sx = static(x)
                x = phi(x)
                for layer in hidden:
                    x = activation(x @ layer['W'] + layer['B'])
                return x @ output['W'] + output['B'] + sx
        else:
            @jax.jit
            def forward(x,params):
                #Plain fully connected network without biases ('daff'/'cheb')
                *hidden,output = params
                sx = static(x)
                x = phi(x)
                for layer in hidden:
                    x = activation(x @ layer['W'])
                return x @ output['W'] + sx

    #Return initial parameters and forward function
    return {'params': params,'forward': forward}
Initialize fully connected neural network
Parameters
  • width (list): List with the layers width
  • activation (jax.nn activation): The activation function. Default jax.nn.tanh
  • key (int): Seed for parameters initialization. Default 0
  • mlp (logical): Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension.
  • ftype (str): Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.
  • fargs (list): Arguments for feature transformation:

    For 'ff': A list with the number of frequencies and the value of the greatest frequency standard deviation.

    For 'daff' and 'daff_bias': A dictionary with a list with the size of the rectangles and the type of boundary condition. If it is a list, the boundary condition is Dirichlet.

  • static (function): A static function to sum to the neural network output.
Returns
  • dict with initial parameters and the function for the forward pass
def get_activation(act):
    """
    Return activation function from string
    ----------
    Parameters
    ----------
    act : str

        Name of the activation function. Default 'tanh'

    Returns
    -------
    jax.nn activation function

    Raises
    ------
    ValueError

        If act is not the name of a supported activation function.
    """
    #Every supported name matches the attribute of the same name in jax.nn, so
    #a single getattr replaces the former if/elif chain (which also contained a
    #'jx.nn.sparse_plus' typo that raised NameError for act == 'sparse_plus',
    #and silently returned None for unknown names).
    supported = {
        'tanh', 'relu', 'relu6', 'sigmoid', 'softplus', 'sparse_plus',
        'soft_sign', 'silu', 'swish', 'log_sigmoid', 'leaky_relu',
        'hard_sigmoid', 'hard_silu', 'hard_swish', 'hard_tanh', 'elu',
        'celu', 'selu', 'gelu', 'glu', 'squareplus', 'mish',
    }
    if act not in supported:
        raise ValueError("Unknown activation '" + str(act) + "'. Supported: " + str(sorted(supported)))
    return getattr(jax.nn, act)
Return activation function from string
Parameters
  • act (str): Name of the activation function. Default 'tanh'
Returns
  • jax.nn activation function
def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = None,inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,mlp = False):
    """
    Train a Physics-informed Neural Network
    ----------
    Parameters
    ----------
    data : dict

        Data generated by the jinnax.data.generate_PINNdata function

    width : list

        A list with the width of each layer

    pde : function

        The partial differential operator. Its arguments are u, x and t

    test_data : dict, None

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating the L2 error

    epochs : int

        Number of training epochs. Default 100

    at_each : int

        Save results for epochs multiple of at_each. Default 10

    activation : str

        The name of the activation function of the neural network. Default 'tanh'

    neumann : logical

        Whether to consider Neumann boundary conditions

    oper_neumann : function

        Penalization of Neumann boundary conditions. Only used when neumann is True

    sa : logical

        Whether to consider a self-adaptative PINN

    c : dict

        Dictionary with the hyperparameters of the self-adaptative mask for the initial (w0), sensor (ws), collocation (wr) and boundary (wb) points. Default None for {'ws': 1,'wr': 1,'w0': 100,'wb': 1}

    inverse : logical

        Whether to estimate parameters of the PDE

    initial_par : jax.numpy.array

        Initial value of the parameters of the PDE in an inverse problem

    lr,b1,b2,eps,eps_root: float

        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0

    key : int

        Seed for parameters initialization. Default 0

    epoch_print : int

        Number of epochs to calculate and print test errors. Default 100

    save : logical

        Whether to save the current parameters. Default False

    file_name : str

        File prefix to save the current parameters. Default 'result_pinn'

    exp_decay : logical

        Whether to consider exponential decay of learning rate. Default False

    transition_steps : int

        Number of steps for exponential decay. Default 1000

    decay_rate : float

        Rate of exponential decay. Default 0.9

    mlp : logical

        Whether to consider a modified multi-layer perceptron

    Returns
    -------
    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
    """
    #Default hyperparameters of the self-adaptative mask. Kept out of the
    #signature to avoid a shared mutable default argument
    if c is None:
        c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1}

    #Initialize architecture
    nnet = fconNN(width,get_activation(activation),key,mlp)
    forward = nnet['forward']

    #Initialize self-adaptative weights
    par_sa = {}
    if sa:
        #Initialize weights close to zero, with an independent seed per data group
        ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000)
        if data['sensor'] is not None:
            par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))})
        if data['initial'] is not None:
            par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))})
        if data['collocation'] is not None:
            par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))})
        if data['boundary'] is not None:
            #Bug fix: the boundary weights were scaled by c['wr'] instead of c['wb']
            par_sa.update({'wb': c['wb'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))})

    #Store all parameters (network weights, PDE parameters, self-adaptative weights)
    params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa}

    #Save config file
    if save:
        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)

    #Define loss function
    if sa:
        #Self-adaptative loss: each residual is masked by trainable weights
        @jax.jit
        def lf(params,x):
            loss = 0
            if x['sensor'] is not None:
                #Term that refers to sensor data
                loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws']))
            if x['boundary'] is not None:
                if neumann:
                    #Neumann conditions: split the boundary points into space and time
                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb']))
                else:
                    #Term that refers to boundary data
                    loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb']))
            if x['initial'] is not None:
                #Term that refers to initial data
                loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0']))
            if x['collocation'] is not None:
                #Term that refers to collocation points: PDE residual should be zero
                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
                if inverse:
                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr']))
                else:
                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr']))
            return loss
    else:
        #Standard (unweighted) PINN loss
        @jax.jit
        def lf(params,x):
            loss = 0
            if x['sensor'] is not None:
                #Term that refers to sensor data
                loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
            if x['boundary'] is not None:
                if neumann:
                    #Neumann conditions: split the boundary points into space and time
                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb))
                else:
                    #Term that refers to boundary data
                    loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary']))
            if x['initial'] is not None:
                #Term that refers to initial data
                loss = loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial']))
            if x['collocation'] is not None:
                #Term that refers to collocation points: PDE residual should be zero
                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
                if inverse:
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0))
                else:
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0))
            return loss

    #Initialize Adam Optimizer
    if exp_decay:
        lr = optax.exponential_decay(lr,transition_steps,decay_rate)
    optimizer = optax.adam(lr,b1,b2,eps,eps_root)
    opt_state = optimizer.init(params)

    #Define the gradient function
    grad_loss = jax.jit(jax.grad(lf,0))

    #Define update function
    @jax.jit
    def update(opt_state,params,x):
        #Compute gradient
        grads = grad_loss(params,x)
        #Invert gradient of self-adaptative weights (they are maximized, not minimized)
        if sa:
            for w in grads['sa']:
                grads['sa'][w] = - grads['sa'][w]
        #Calculate parameters updates
        updates, opt_state = optimizer.update(grads, opt_state)
        #Update parameters
        params = optax.apply_updates(params, updates)
        #Return state of optimizer and updated parameters
        return opt_state,params

    ###Training###
    t0 = time.time()
    #Initialize alive_bar for tracing in terminal
    with alive_bar(epochs) as bar:
        #For each epoch
        for e in range(epochs):
            #Update optimizer state and parameters
            opt_state,params = update(opt_state,params,data)
            #After epoch_print epochs
            if e % epoch_print == 0:
                #Compute elapsed time and current error
                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6))
                #If there is test data, compute current L2 error
                if test_data is not None:
                    #Compute L2 error
                    l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist()
                    l = l + ' L2 error: ' + str(jnp.round(l2_test,3))
                if inverse:
                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
                #Print
                print(l)
            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
                #Save current parameters
                pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
            #Update alive_bar
            bar()
    #Define estimated function (closure over the trained parameters)
    def u(xt):
        return forward(xt,params['net'])

    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0}

Train a Physics-informed Neural Network

Parameters
  • data (dict): Data generated by the jinnax.data.generate_PINNdata function
  • width (list): A list with the width of each layer
  • pde (function): The partial differential operator. Its arguments are u, x and t
  • test_data (dict, None): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating the L2 error
  • epochs (int): Number of training epochs. Default 100
  • at_each (int): Save results for epochs multiple of at_each. Default 10
  • activation (str): The name of the activation function of the neural network. Default 'tanh'
  • neumann (logical): Whether to consider Neumann boundary conditions
  • oper_neumann (function): Penalization of Neumann boundary conditions
  • sa (logical): Whether to consider self-adaptative PINN
  • c (dict): Dictionary with the hyperparameters of the self-adaptative sigmoid mask for the initial (w0), sensor (ws) and collocation (wr) points. The weights of the boundary points is fixed to 1
  • inverse (logical): Whether to estimate parameters of the PDE
  • initial_par (jax.numpy.array): Initial value of the parameters of the PDE in an inverse problem
  • lr,b1,b2,eps,eps_root (float): Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
  • key (int): Seed for parameters initialization. Default 0
  • epoch_print (int): Number of epochs to calculate and print test errors. Default 100
  • save (logical): Whether to save the current parameters. Default False
  • file_name (str): File prefix to save the current parameters. Default 'result_pinn'
  • exp_decay (logical): Whether to consider exponential decay of learning rate. Default False
  • transition_steps (int): Number of steps for exponential decay. Default 1000
  • decay_rate (float): Rate of exponential decay. Default 0.9
  • mlp (logical): Whether to consider a modified multi-layer perceptron
Returns
  • dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
def train_Matern_PINN( data, width, pde, test_data=None, params=None, d=2, N=128, L=1, alpha=1, kappa=1, sigma=100, bsize=1024, resample=False, epochs=100, at_each=10, activation='tanh', neumann=False, oper_neumann=None, inverse=False, initial_par=None, lr=0.001, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, key=0, epoch_print=1, save=False, file_name='result_pinn', exp_decay=True, transition_steps=100, decay_rate=0.9, mlp=True, ftype=None, fargs=None, q=4, w=None, periodic=False, static=None, opt='LBFGS'):
 985def train_Matern_PINN(data,width,pde,test_data = None,params = None,d = 2,N = 128,L = 1,alpha = 1,kappa = 1,sigma = 100,bsize = 1024,resample = False,epochs = 100,at_each = 10,activation = 'tanh',
 986    neumann = False,oper_neumann = None,inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 1,save = False,file_name = 'result_pinn',
 987    exp_decay = True,transition_steps = 100,decay_rate = 0.9,mlp = True,ftype = None,fargs = None,q = 4,w = None,periodic = False,static = None,opt = 'LBFGS'):
 988    """
 989    Train a Physics-informed Neural Network
 990    ----------
 991    Parameters
 992    ----------
 993    data : dict
 994
 995        Data generated by the jinnax.data.generate_PINNdata function
 996
 997    width : list
 998
 999        A list with the width of each layer
1000
1001    pde : function
1002
1003        The partial differential operator. Its arguments are u, x and t
1004
1005    test_data : dict, None
1006
1007        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
1008
1009    params : list
1010
1011        Initial parameters for the neural network. Default None to initialize randomly
1012
1013    d : int
1014
1015        Dimension of the problem including the time variable if present. Default 2
1016
1017    N : int
1018
1019        Size of grid in each dimension. Default 128
1020
1021    L : list of float
1022
1023        The domain of the function in each coordinate is [0,L[1]]. If a float, repeat the same interval for all coordinates. Default 1
1024
1025    kappa,alpha,sigma : float
1026
1027        Parameters of the Matern process
1028
1029    bsize : int
1030
1031        Batch size for weak norm computation. Default 1024
1032
1033    resample : logical
1034
1035        Whether to resample the test functions at each epoch
1036
1037    epochs : int
1038
1039        Number of training epochs. Default 100
1040
1041    at_each : int
1042
1043        Save results for epochs multiple of at_each. Default 10
1044
1045    activation : str
1046
1047        The name of the activation function of the neural network. Default 'tanh'
1048
1049    neumann : logical
1050
1051        Whether to consider Neumann boundary conditions
1052
1053    oper_neumann : function
1054
1055        Penalization of Neumann boundary conditions
1056
1057    inverse : logical
1058
1059        Whether to estimate parameters of the PDE
1060
1061    initial_par : jax.numpy.array
1062
1063        Initial value of the parameters of the PDE in an inverse problem
1064
1065    lr,b1,b2,eps,eps_root: float
1066
1067        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
1068
1069    key : int
1070
1071        Seed for parameters initialization. Default 0
1072
1073    epoch_print : int
1074
1075        Number of epochs to calculate and print test errors. Default 1
1076
1077    save : logical
1078
1079        Whether to save the current parameters. Default False
1080
1081    file_name : str
1082
1083        File prefix to save the current parameters. Default 'result_pinn'
1084
1085    exp_decay : logical
1086
1087        Whether to consider exponential decay of learning rate. Default True
1088
1089    transition_steps : int
1090
1091        Number of steps for exponential decay. Default 100
1092
1093    decay_rate : float
1094
1095        Rate of exponential decay. Default 0.9
1096
1097    mlp : logical
1098
1099        Whether to consider modifed multilayer perceptron
1100
1101    ftype : str
1102
1103        Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.
1104
1105    fargs : list
1106
1107        Arguments for deature transformation:
1108
1109        For 'ff': A list with the number of frequences and value of greatest frequence standard deviation.
1110
1111        For 'daff' and 'daff' bias: A dicitionary with a list with the size of rectangles and the type of boundary condition. If its a list, than boundary conditions is dirichlet.
1112
1113    q : int
1114
1115        Power of weights mask. Default 4
1116
1117    w : dict
1118
1119        Initila weights for self-adaptive scheme.
1120
1121    periodic : logical
1122
1123        Whether to consider periodic test functions. Default False.
1124
1125    static : function
1126
1127        A static function to sum to the neural network output.
1128
1129    opt : str
1130
1131        Optimizer. Default LBFGS.
1132
1133    Returns
1134    -------
1135    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch
1136    """
1137    #Initialize architecture
1138    nnet = fconNN(width,get_activation(activation),key,mlp,ftype,fargs,static)
1139    forward = nnet['forward']
1140    if params is not None:
1141        nnet['params'] = params
1142
1143    #Generate from Matern process
1144    if sigma > 0:
1145        if isinstance(L,float) or isinstance(L,int):
1146            L = d*[L]
1147        #Grid for weak norm
1148        grid = [jnp.linspace(0,L[i],N) for i in range(d)]
1149        grid = jnp.meshgrid(*grid, indexing='ij')
1150        grid = jnp.stack(grid, axis=-1).reshape((-1, d))
1151        #Set sigma
1152        if data['boundary'] is not None:
1153            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma)
1154            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
1155            if neumann:
1156                loss_boundary = oper_neumann(lambda x: forward(x,params['net']),data['boundary'])
1157            else:
1158                loss_boundary = jnp.mean(MSE(forward(data['boundary'],nnet['params']),data['uboundary']))
1159            output_w = pde(lambda x: forward(x,nnet['params']),grid)
1160            integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(tf)
1161            loss_res_weak = jnp.mean(integralOmega ** 2)
1162            sigma = float(jnp.sqrt(loss_boundary/loss_res_weak).tolist())
1163            del gen
1164            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
1165            tf = sigma*tf
1166        else:
1167            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
1168            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
1169
1170    #Define loss function
1171    @jax.jit
1172    def lf_each(params,x,k):
1173        if sigma > 0:
1174            #Term that refers to weak loss
1175            if resample:
1176                test_functions = gen(jax.random.split(jax.random.PRNGKey(k[0]),(bsize,))[:,0])
1177            else:
1178                test_functions = tf
1179        loss_sensor = loss_boundary = loss_initial = loss_res = loss_res_weak = 0
1180        if x['sensor'] is not None:
1181            #Term that refers to sensor data
1182            loss_sensor = jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
1183        if x['boundary'] is not None:
1184            if neumann:
1185                #Neumann coditions
1186                loss_boundary = oper_neumann(lambda x: forward(x,params['net']),x['boundary'])
1187            else:
1188                #Term that refers to boundary data
1189                loss_boundary = MSE(forward(x['boundary'],params['net']),x['uboundary'])
1190        if x['initial'] is not None:
1191            #Term that refers to initial data
1192            loss_initial = MSE(forward(x['initial'],params['net']),x['uinitial'])
1193        if x['collocation'] is not None and sigma == 0:
1194            if inverse:
1195                output = pde(lambda x: forward(x,params['net']),x['collocation'],params['inverse'])
1196                loss_res = MSE(output,0)
1197            else:
1198                output = pde(lambda x: forward(x,params['net']),x['collocation'])
1199                loss_res = MSE(output,0)
1200        if sigma > 0:
1201            #Term that refers to weak loss
1202            if inverse:
1203                output_w = pde(lambda x: forward(x,params['net']),grid,params['inverse'])
1204                integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
1205                loss_res_weak = jnp.mean(integralOmega ** 2)
1206            else:
1207                output_w = pde(lambda x: forward(x,params['net']),grid)
1208                integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
1209                loss_res_weak = jnp.mean(integralOmega ** 2)
1210        return {'ls': loss_sensor,'lb': loss_boundary,'li': loss_initial,'lc': loss_res,'lc_weak': loss_res_weak}
1211
1212    @jax.jit
1213    def lf(params,x,k):
1214        l = lf_each(params,x,k)
1215        w = params['w']
1216        loss = jnp.mean((w['ws'] ** q)*l['ls']) + jnp.mean((w['wb'] ** q)*l['lb']) + jnp.mean((w['wi'] ** q)*l['li']) + jnp.mean((w['wc'] ** q)*l['lc']) + (w['wc_weak'] ** q)*l['lc_weak']
1217        if opt != 'LBFGS':
1218            return loss
1219        else:
1220            l2 = None
1221            if test_data is not None:
1222                l2 = L2error(forward(test_data['sensor'],params['net']),test_data['usensor'])
1223            return loss,{'loss': loss,'l2': l2}
1224
1225    #Initialize self-adaptive weights
1226    if w is None:
1227        w = {'ws': jnp.array(1.0),'wb': jnp.array(1.0),'wi': jnp.array(1.0),'wc': jnp.array(1.0),'wc_weak': jnp.array(1.0)}
1228    if q != 0:
1229        if data['sensor'] is not None:
1230            w['ws'] = w['ws'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+1),(data['sensor'].shape[0],1))
1231        if data['boundary'] is not None:
1232            w['wb'] = w['wb'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+2),(data['boundary'].shape[0],1))
1233        if data['initial'] is not None:
1234            w['wi'] = w['wi'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+3),(data['initial'].shape[0],1))
1235        if data['collocation'] is not None:
1236            w['wc'] = w['wc'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+4),(data['collocation'].shape[0],1))
1237
1238    #Store all parameters
1239    params = {'net': nnet['params'],'inverse': initial_par,'w': w}
1240
1241    #Save config file
1242    if save:
1243        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
1244
1245    #Initialize Adam Optmizer
1246    if opt != 'LBFGS':
1247        print('--------- GRADIENT DESCENT OPTIMIZER ---------')
1248        if exp_decay:
1249            lr = optax.exponential_decay(lr,transition_steps,decay_rate)
1250        optimizer = optax.adam(lr,b1,b2,eps,eps_root)
1251        opt_state = optimizer.init(params)
1252
1253        #Define the gradient function
1254        grad_loss = jax.jit(jax.grad(lf,0))
1255
1256        #Define update function
1257        @jax.jit
1258        def update(opt_state,params,x,k):
1259            #Compute gradient
1260            grads = grad_loss(params,x,k)
1261            #Calculate parameters updates
1262            updates, opt_state = optimizer.update(grads, opt_state)
1263            #Update parameters
1264            if q != 0:
1265                updates = {**updates, 'w': jax.tree_util.tree_map(lambda x: -x, updates['w'])} #Change signs of weights
1266            params = optax.apply_updates(params, updates)
1267            #Return state of optmizer and updated parameters
1268            return opt_state,params
1269    else:
1270        print('--------- LBFGS OPTIMIZER ---------')
1271        @jax.jit
1272        def loss_LBFGS(params):
1273            return lf(params,data,key + 234)
1274        solver = LBFGS(fun = loss_LBFGS,has_aux = True,maxiter = epochs,tol = 1e-9,verbose = False,linesearch = 'zoom',history_size = 100)  # linesearch='zoom' by default
1275        state = solver.init_state(params)
1276
1277    ###Training###
1278    t0 = time.time()
1279    k = jax.random.split(jax.random.PRNGKey(key+234),(epochs,))
1280    sloss = []
1281    sL2 = []
1282    stime = []
1283    #Initialize alive_bar for tracing in terminal
1284    with alive_bar(epochs) as bar:
1285        #For each epoch
1286        for e in range(epochs):
1287            if opt != 'LBFGS':
1288                #Update optimizer state and parameters
1289                opt_state,params = update(opt_state,params,data,k[e,:])
1290                sloss.append(lf(params,data,k[e,:]))
1291                if test_data is not None:
1292                    sL2.append(L2error(forward(test_data['sensor'],params['net']),test_data['usensor']))
1293            else:
1294                params, state = solver.update(params, state)
1295                sL2.append(state.aux["l2"])
1296                sloss.append(state.aux["loss"])
1297            stime.append(time.time() - t0)
1298            #After epoch_print epochs
1299            if e % epoch_print == 0:
1300                #Compute elapsed time and current error
1301                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(sloss[-1],6))
1302                #If there is test data, compute current L2 error
1303                if test_data is not None:
1304                    #Compute L2 error
1305                    l = l + ' L2 error: ' + str(jnp.round(sL2[-1],6))
1306                if inverse:
1307                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
1308                #Print
1309                print(l)
1310            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
1311                #Save current parameters
1312                pickle.dump({'params': params,'width': width,'time': stime,'loss': sloss,'L2error': sL2},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
1313            #Update alive_bar
1314            bar()
1315    #Define estimated function
1316    def u(xt):
1317        return forward(xt,params['net'])
1318
1319    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0,'loss_each': lf_each(params,data,[key + 100]),'loss': sloss,'L2error': sL2}

Train a Physics-informed Neural Network

Parameters
  • data (dict): Data generated by the jinnax.data.generate_PINNdata function
  • width (list): A list with the width of each layer
  • pde (function): The partial differential operator. Its arguments are u, x and t
  • test_data (dict, None): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
  • params (list): Initial parameters for the neural network. Default None to initialize randomly
  • d (int): Dimension of the problem including the time variable if present. Default 2
  • N (int): Size of grid in each dimension. Default 128
  • L (list of float): The domain of the function in each coordinate i is [0, L[i]]. If a float, repeat the same interval for all coordinates. Default 1
  • kappa,alpha,sigma (float): Parameters of the Matern process
  • bsize (int): Batch size for weak norm computation. Default 1024
  • resample (logical): Whether to resample the test functions at each epoch
  • epochs (int): Number of training epochs. Default 100
  • at_each (int): Save results for epochs multiple of at_each. Default 10
  • activation (str): The name of the activation function of the neural network. Default 'tanh'
  • neumann (logical): Whether to consider Neumann boundary conditions
  • oper_neumann (function): Penalization of Neumann boundary conditions
  • inverse (logical): Whether to estimate parameters of the PDE
  • initial_par (jax.numpy.array): Initial value of the parameters of the PDE in an inverse problem
  • lr,b1,b2,eps,eps_root (float): Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
  • key (int): Seed for parameters initialization. Default 0
  • epoch_print (int): Number of epochs to calculate and print test errors. Default 1
  • save (logical): Whether to save the current parameters. Default False
  • file_name (str): File prefix to save the current parameters. Default 'result_pinn'
  • exp_decay (logical): Whether to consider exponential decay of learning rate. Default True
  • transition_steps (int): Number of steps for exponential decay. Default 100
  • decay_rate (float): Rate of exponential decay. Default 0.9
  • mlp (logical): Whether to consider modified multilayer perceptron
  • ftype (str): Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.
  • fargs (list): Arguments for feature transformation:

    For 'ff': A list with the number of frequencies and the standard deviation of the greatest frequency.

    For 'daff' and 'daff_bias': A dictionary with a list of rectangle sizes and the type of boundary condition. If it is a list, then the boundary condition is Dirichlet.

  • q (int): Power of weights mask. Default 4
  • w (dict): Initial weights for self-adaptive scheme.
  • periodic (logical): Whether to consider periodic test functions. Default False.
  • static (function): A static function to sum to the neural network output.
  • opt (str): Optimizer. Default LBFGS.
Returns
  • dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch
def process_result( test_data, fit, train_data, plot=True, plot_test=True, times=5, d2=True, save=False, show=True, file_name='result_pinn', print_res=True, p=1):
def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1):
    """
    Process the results of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    fit : function

        The fitted function

    train_data : dict

        Training data generated by the jinnax.data.generate_PINNdata

    plot : logical

        Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True

    plot_test : logical

        Whether to plot the test data. Default True

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to plot 2D plot when the spatial dimension is one. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    print_res : logical

        Whether to print the L2 error. Default True

    p : int

        Output dimension. Default 1

    Returns
    -------
    pandas data frame with L2 and MSE errors
    """
    #Spatial dimension (the last column of xt is the time variable)
    d = test_data['xt'].shape[1] - 1

    #Round the number of plotted time slices to the nearest multiple of 5
    times = 5 * round(times/5.0)

    #Assemble the training inputs/targets and evaluate the fitted function
    splits = get_train_data(train_data)
    train_x = splits['x']
    train_y = splits['y']
    train_pred = fit(train_x)
    test_pred = fit(test_data['xt'])

    #Errors as plain Python floats
    l2_test = L2error(test_pred,test_data['u']).tolist()
    mse_test = jnp.mean(MSE(test_pred,test_data['u'])).tolist()
    l2_train = L2error(train_pred,train_y).tolist()
    mse_train = jnp.mean(MSE(train_pred,train_y)).tolist()

    #Single-row summary table
    df = pd.DataFrame([[l2_test,mse_test,l2_train,mse_train]],
        columns=['l2_error_test','MSE_test','l2_error_train','MSE_train'])
    if print_res:
        print('L2 error test: ' + str(jnp.round(l2_test,6)) + ' L2 error train: ' + str(jnp.round(l2_train,6)) + ' MSE error test: ' + str(jnp.round(mse_test,6)) + ' MSE error train: ' + str(jnp.round(mse_train,6)) )

    #Plots: slices for scalar output in one spatial dimension, trajectories for 2D output
    if d == 1 and p == 1 and plot:
        plot_pinn1D(times,test_data['xt'],test_data['u'],test_pred,d2,save,show,file_name)
    elif p == 2 and plot:
        plot_pinn_out2D(times,test_data['xt'],test_data['u'],test_pred,save,show,file_name,plot_test)

    return df

Process the results of a Physics-informed Neural Network

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • fit (function): The fitted function
  • train_data (dict): Training data generated by the jinnax.data.generate_PINNdata
  • plot (logical): Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True
  • plot_test (logical): Whether to plot the test data. Default True
  • times (int): Number of points along the time interval to plot. Default 5
  • d2 (logical): Whether to plot 2D plot when the spatial dimension is one. Default True
  • save (logical): Whether to save the plots. Default False
  • show (logical): Whether to show the plots. Default True
  • file_name (str): File prefix to save the plots. Default 'result_pinn'
  • print_res (logical): Whether to print the L2 error. Default True
  • p (int): Output dimension. Default 1
Returns
  • pandas data frame with L2 and MSE errors
def plot_pinn1D( times, xt, u, upred, d2=True, save=False, show=True, file_name='result_pinn', title_1d='', title_2d=''):
def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
    """
    Plot the prediction of a 1D PINN
    ----------

    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    d2 : logical

        Whether to plot 2D plot. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title_1d : str

        Title of 1D plot

    title_2d : str

        Title of 2D plot

    Returns
    -------
    None
    """
    #Panel grid: five columns, times/5 rows
    rows = int(times/5)
    fig, ax = plt.subplots(rows,5,figsize = (10*rows,3*rows))

    #Time range and vertical limits padded by 10%
    t_min = jnp.min(xt[:,-1])
    t_max = jnp.max(xt[:,-1])
    y_min = jnp.min(u)
    y_min = y_min - 0.1*jnp.abs(y_min)
    y_max = jnp.max(u)
    y_max = y_max + 0.1*jnp.abs(y_max)

    #Equally spaced plotting times over the observed time range
    t_values = np.linspace(t_min,t_max,times)
    idx = 0

    #Fill the panels
    for row in range(rows):
        for col in range(5):
            if idx < len(t_values):
                #Snap the requested time to the nearest time actually present in xt
                gap = jnp.abs(xt[:,-1] - t_values[idx])
                t = xt[gap == jnp.min(gap),-1][0].tolist()
                mask = xt[:,-1] == t
                x_plot = xt[mask,:-1]
                y_plot = upred[mask,:]
                u_plot = u[mask,:]
                #With a single row, plt.subplots returns a 1D axes array
                panel = ax[row,col] if rows > 1 else ax[col]
                panel.plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
                panel.plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
                panel.set_title('$t = %.2f$' % (t),fontsize=10)
                panel.set_xlabel(' ')
                panel.set_ylim([1.3 * y_min.tolist(),1.3 * y_max.tolist()])
                idx = idx + 1

    #Overall title and layout
    fig.suptitle(title_1d)
    fig.tight_layout()

    #Show and/or save the slice plot
    fig = plt.gcf()
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close()

    #Optional 2D heatmaps of exact vs predicted solutions
    if d2:
        fig, ax = plt.subplots(1,2)
        #Grid sizes inferred from the unique time and space coordinates
        n_t = jnp.unique(xt[:,-1]).shape[0]
        n_x = jnp.unique(xt[:,0]).shape[0]

        ax[0].pcolormesh(xt[:,-1].reshape((n_x,n_t)),xt[:,0].reshape((n_x,n_t)),u[:,0].reshape((n_x,n_t)),cmap = 'RdBu',vmin = y_min.tolist(),vmax = y_max.tolist())
        ax[0].set_title('Exact')
        ax[1].pcolormesh(xt[:,-1].reshape((n_x,n_t)),xt[:,0].reshape((n_x,n_t)),upred[:,0].reshape((n_x,n_t)),cmap = 'RdBu',vmin = y_min.tolist(),vmax = y_max.tolist())
        ax[1].set_title('Predicted')

        #Overall title and layout
        fig.suptitle(title_2d)
        fig.tight_layout()

        #Show and/or save the heatmap plot
        fig = plt.gcf()
        if show:
            plt.show()
        if save:
            fig.savefig(file_name + '_2d.png')
        plt.close()

Plot the prediction of a 1D PINN

Parameters
  • times (int): Number of points along the time interval to plot. Default 5
  • xt (jax.numpy.array): Test data xt array
  • u (jax.numpy.array): Test data u(x,t) array
  • upred (jax.numpy.array): Predicted upred(x,t) array on test data
  • d2 (logical): Whether to plot 2D plot. Default True
  • save (logical): Whether to save the plots. Default False
  • show (logical): Whether to show the plots. Default True
  • file_name (str): File prefix to save the plots. Default 'result_pinn'
  • title_1d (str): Title of 1D plot
  • title_2d (str): Title of 2D plot
Returns
  • None
def plot_pinn_out2D( times, xt, u, upred, save=False, show=True, file_name='result_pinn', title='', plot_test=True):
def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
    """
    Plot the prediction of a PINN with 2D output
    ----------
    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title : str

        Title of plot

    plot_test : logical

        Whether to plot the test data. Default True

    Returns
    -------
    None
    """
    #Panel grid: five columns, times/5 rows
    rows = int(times/5)
    fig, ax = plt.subplots(rows,5,figsize = (10*rows,3*rows))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    #Vertical limits from the second output coordinate, padded by 10%
    ylo = jnp.min(u[:,1])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u[:,1])
    yup = yup + 0.1*jnp.abs(yup)
    k = 0
    t_values = np.linspace(tlo,tup,times)

    #Create one phase-plane panel per plotting time
    for i in range(rows):
        for j in range(5):
            if k < len(t_values):
                #Snap the requested time to the nearest time actually present in xt
                t = t_values[k]
                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
                mask = xt[:,-1] == t
                xpred_plot = upred[mask,0]
                ypred_plot = upred[mask,1]
                #With a single row, plt.subplots returns a 1D axes array
                panel = ax[i,j] if rows > 1 else ax[j]
                if plot_test:
                    panel.plot(u[mask,0],u[mask,1],'b-',linewidth=2,label='Exact')
                #Bug fix: the single-row branch previously plotted the undefined name
                #'ypred' instead of 'ypred_plot', raising NameError for times <= 5
                panel.plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
                panel.set_title('$t = %.2f$' % (t),fontsize=10)
                panel.set_xlabel(' ')
                panel.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                k = k + 1

    #Title
    fig.suptitle(title)
    fig.tight_layout()

    #Show and save
    fig = plt.gcf()
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close()

Plot the prediction of a PINN with 2D output

Parameters
  • times (int): Number of points along the time interval to plot. Default 5
  • xt (jax.numpy.array): Test data xt array
  • u (jax.numpy.array): Test data u(x,t) array
  • upred (jax.numpy.array): Predicted upred(x,t) array on test data
  • save (logical): Whether to save the plots. Default False
  • show (logical): Whether to show the plots. Default True
  • file_name (str): File prefix to save the plots. Default 'result_pinn'
  • title (str): Title of plot
  • plot_test (logical): Whether to plot the test data. Default True
Returns
  • None
def get_train_data(train_data):
def get_train_data(train_data):
    """
    Process training sample
    ----------

    Parameters
    ----------
    train_data : dict

        A dictionary with train data generated by the jinnax.data.generate_PINNdata function

    Returns
    -------
    dict with the processed training data
    """
    xdata = None
    ydata = None
    xydata = None
    counts = {}
    #Stack the three labelled data segments in a fixed order
    for name, uname in (('sensor','usensor'),('boundary','uboundary'),('initial','uinitial')):
        seg_x = train_data[name]
        if seg_x is None:
            counts[name] = 0
            continue
        seg_y = train_data[uname]
        counts[name] = seg_x.shape[0]
        seg_xy = jnp.column_stack((seg_x,seg_y))
        if xdata is None:
            #First non-empty segment initializes the stacks
            xdata = seg_x
            ydata = seg_y
            xydata = seg_xy
        else:
            #Subsequent segments are appended row-wise
            xdata = jnp.vstack((xdata,seg_x))
            ydata = jnp.vstack((ydata,seg_y))
            xydata = jnp.vstack((xydata,seg_xy))
    #Collocation points are only counted, never stacked (they carry no labels)
    coll = train_data['collocation']
    collocation_sample = 0 if coll is None else coll.shape[0]

    return {'xy': xydata,'x': xdata,'y': ydata,'sensor_sample': counts['sensor'],'boundary_sample': counts['boundary'],'initial_sample': counts['initial'],'collocation_sample': collocation_sample}
Process training sample
Parameters
Returns
  • dict with the processed training data
def process_training( test_data, file_name, at_each=100, bolstering=True, mc_sample=10000, save=False, file_name_save='result_pinn', key=0, ec=1e-06, lamb=1):
def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
    """
    Process the training of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    bolstering : logical

        Whether to compute bolstering mean square error. Default True

    mc_sample : int

        Number of samples for Monte Carlo integration in bolstering. Default 10000

    save : logical

        Whether to save the training results. Default False

    file_name_save : str

        File prefix to save the plots and the L2 error. Default 'result_pinn'

    key : int

        Key for random samples in bolstering. Default 0

    ec : float

        Stopping criteria error for EM algorithm in bolstering. Default 1e-6

    lamb : float

        Hyperparameter of EM algorithm in bolstering. Default 1

    Returns
    -------
    pandas data frame with training results
    """
    #Load the configuration saved during training
    #NOTE(review): file handle from open() is never closed explicitly
    config = pickle.load(open(file_name + '_config.pickle', 'rb'))
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Stack sensor/boundary/initial data into combined x, y and (x,y) arrays
    td = get_train_data(train_data)
    xydata = td['xy']
    xdata = td['x']
    ydata = td['y']
    sensor_sample = td['sensor_sample']
    boundary_sample = td['boundary_sample']
    initial_sample = td['initial_sample']
    collocation_sample = td['collocation_sample']

    #One PRNG key pair per epoch for the bolstering Monte Carlo draws
    if bolstering:
        keys = jax.random.split(jax.random.PRNGKey(key),epochs)

    #Accumulators for the per-checkpoint metrics
    train_mse = []
    test_mse = []
    train_L2 = []
    test_L2 = []
    bolstX = []
    bolstXY = []
    loss = []
    #NOTE(review): this local list shadows the imported 'time' module within this function
    time = []
    ep = []

    #Walk over all epochs, processing only the checkpoints saved during training
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
                ep = ep + [e]

                #Read the parameters checkpoint saved for this epoch
                params = pickle.load(open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb'))

                #Cumulative training time recorded at this checkpoint
                time = time + [params['time']]

                #Learned function at this checkpoint (closes over this epoch's parameters)
                def psi(x):
                    return forward(x,params['params']['net'])

                #Train MSE and L2 (None placeholders keep the columns aligned when there is no labelled data)
                if xdata is not None:
                    train_mse = train_mse + [jnp.mean(MSE(psi(xdata),ydata)).tolist()]
                    train_L2 = train_L2 + [L2error(psi(xdata),ydata).tolist()]
                else:
                    train_mse = train_mse + [None]
                    train_L2 = train_L2 + [None]

                #Test MSE and L2
                test_mse = test_mse + [jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist()]
                test_L2 = test_L2 + [L2error(psi(test_data['xt']),test_data['u']).tolist()]

                #Bolstered error estimates: three kernel methods plus five hessian bias levels
                #NOTE(review): only the first component of the epoch key pair (keys[e,0]) is used — confirm intended
                if bolstering:
                    bX = []
                    bXY = []
                    for method in ['chi','mm','mpe']:
                        kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
                        bXY = bXY + [gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist()]
                    #Hessian kernels with decreasing bias (1/sqrt(n) down to 1/n^4)
                    #NOTE(review): the hessian kernel is fit on xydata but stored in kx and appended to bX — confirm intended
                    for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]:
                        kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
                    bolstX = bolstX + [bX]
                    bolstXY = bolstXY + [bXY]
                else:
                    bolstX = bolstX + [None]
                    bolstXY = bolstXY + [None]

                #Loss recorded at this checkpoint
                #NOTE(review): assumes params['loss'] supports .tolist(); verify against what the training loop saves
                loss = loss + [params['loss'].tolist()]

                #Free the checkpoint before loading the next one
                del params, psi
            #Update alive_bar
            bar()

    #Convert bolstering results to arrays for column slicing below
    if bolstering:
        bolstX = jnp.array(bolstX)
        bolstXY = jnp.array(bolstXY)

    #Create data frame: bolstX columns 0-2 pair with bolstXY (chi/mm/mpe); columns 3-7 are the hessian bias levels
    if bolstering:
        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
            train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]),
            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4'])
    else:
        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
            train_mse,test_mse,train_L2,test_L2]),
            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2'])
    if save:
        df.to_csv(file_name_save + '.csv',index = False)

    return df

Process the training of a Physics-informed Neural Network

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • at_each (int): Compute results for epochs multiple of at_each. Default 100
  • bolstering (logical): Whether to compute bolstering mean square error. Default True
  • mc_sample (int): Number of samples for Monte Carlo integration in bolstering. Default 10000
  • save (logical): Whether to save the training results. Default False
  • file_name_save (str): File prefix to save the plots and the L2 error. Default 'result_pinn'
  • key (int): Key for random samples in bolstering. Default 0
  • ec (float): Stopping criteria error for EM algorithm in bolstering. Default 1e-6
  • lamb (float): Hyperparameter of EM algorithm in bolstering. Default 1
Returns
  • pandas data frame with training results
def demo_train_pinn1D( test_data, file_name, at_each=100, times=5, d2=True, file_name_save='result_pinn_demo', title='', framerate=10):
def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
    """
    Demo video with the training of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to make video demo of 2D plot. Default True

    file_name_save : str

        File prefix to save the plots and videos. Default 'result_pinn_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Read the configuration pickled by the training routine
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xt = td['x']
    u = td['y']

    #Create folder to save plots (portable and idempotent, unlike 'mkdir' via os.system)
    os.makedirs(file_name_save, exist_ok = True)

    #Create one image per sampled epoch
    k = 1
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if e % at_each == 0 or e == epochs - 1:
                #Read parameters saved at this epoch (zero-padded epoch number in file name)
                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')

                #Define learned function for the current parameters
                def psi(x):
                    return forward(x,params['params']['net'])

                #Compute L2 train, L2 test and loss
                loss = params['loss']
                L2_train = L2error(psi(xt),u)
                L2_test = L2error(psi(test_data['xt']),test_data['u'])
                #Cast JAX scalars to Python floats so round() is well-defined
                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(float(L2_train),6)) + ' L2 test = ' + str(round(float(L2_test),6))

                #Save plot
                plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
                k = k + 1

                #Release per-epoch objects before the next iteration
                del params, psi, loss, L2_train, L2_test, title_epoch
            #Update alive_bar
            bar()
    #Create demo video from the saved frames
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4')
    if d2:
        os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4')

Demo video with the training of a 1D PINN

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • at_each (int): Compute results for epochs multiple of at_each. Default 100
  • times (int): Number of points along the time interval to plot. Default 5
  • d2 (logical): Whether to make video demo of 2D plot. Default True
  • file_name_save (str): File prefix to save the plots and videos. Default 'result_pinn_demo'
  • title (str): Title for plots
  • framerate (int): Framerate for video. Default 10
Returns
  • None
def demo_time_pinn1D( test_data, file_name, epochs, file_name_save='result_pinn_time_demo', title='', framerate=10):
def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
    """
    Demo video with the time evolution of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Read the configuration pickled by the training routine
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    train_data = config['train_data']
    forward = config['forward']

    #Create folder to save plots (portable and idempotent, unlike 'mkdir' via os.system)
    os.makedirs(file_name_save, exist_ok = True)

    #Plot parameters: time grid and y-limits padded by 10% of the data range
    tdom = jnp.unique(test_data['xt'][:,-1])
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)

    #Load the saved PINN parameters and predictions for each requested epoch
    results = []
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        results = results + [tmp]
        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]

    #Create one image per time point
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data restricted to the current time slice
            xt_step = test_data['xt'][test_data['xt'][:,-1] == t]
            u_step = test_data['u'][test_data['xt'][:,-1] == t]
            #Initialize plot: two columns of subplots when several epochs are shown
            if len(epochs) > 1:
                fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2))
            else:
                fig, ax = plt.subplots(1,1,figsize = (10,5))
            #Fill the subplot grid
            #NOTE(review): with an odd len(epochs) > 2 the last epoch gets no axis — confirm callers pass an even-length list
            index = 0
            if int(len(epochs)/2) > 1:
                for i in range(int(len(epochs)/2)):
                    for j in range(min(2,len(epochs))):
                        upred_step = upred[index][test_data['xt'][:,-1] == t]
                        ax[i,j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                        ax[i,j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                        ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                        ax[i,j].set_xlabel(' ')
                        ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                        index = index + 1
            elif len(epochs) > 1:
                for j in range(2):
                    upred_step = upred[index][test_data['xt'][:,-1] == t]
                    ax[j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                    ax[j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                    ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                    index = index + 1
            else:
                upred_step = upred[index][test_data['xt'][:,-1] == t]
                ax.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                ax.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                ax.set_xlabel(' ')
                ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                index = index + 1


            #Title: cast the JAX scalar to float so round() is well-defined
            fig.suptitle(title + 't = ' + str(round(float(t),4)))
            fig.tight_layout()

            #Show and save
            fig = plt.gcf()
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close()
            bar()

    #Create demo video from the saved frames
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')

Demo video with the time evolution of a 1D PINN

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • epochs (list): Which training epochs to plot
  • file_name_save (str): File prefix to save the plots and video. Default 'result_pinn_time_demo'
  • title (str): Title for plots
  • framerate (int): Framerate for video. Default 10
Returns
  • None