jinnax.nn

#Functions to train neural networks
import jax
import jax.numpy as jnp
import optax
from alive_progress import alive_bar
import math
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import dill as pickle
from genree import bolstering as gb
from genree import kernel as gk
from jax import random
from jinnax import data as jd
import os
from scipy.fft import dst, idst
from itertools import product
from functools import partial
import orthax
from jax import lax
from jaxopt import LBFGS
from jax.tree_util import tree_flatten

__docformat__ = "numpy"

def assert_tree_float64(tree, name="object"):
    """
    Raises an error if any leaf in a pytree is not float64.
    """
    leaves, _ = tree_flatten(tree)
    bad = [
        (i, x.dtype)
        for i, x in enumerate(leaves)
        if hasattr(x, "dtype") and x.dtype != jnp.float64
    ]
    if bad:
        raise RuntimeError(
            f"[FLOAT64 CHECK FAILED] {name} contains non-float64 dtypes:\n"
            + "\n".join([f"  leaf {i}: {dtype}" for i, dtype in bad])
        )


def check_grads_float64(grads):
    """
    Raises an error if any gradient leaf is not float64.
    """
    leaves, _ = tree_flatten(grads)
    bad = [
        (i, x.dtype)
        for i, x in enumerate(leaves)
        if hasattr(x, "dtype") and x.dtype != jnp.float64
    ]
    if bad:
        raise RuntimeError(
            "[FLOAT64 CHECK FAILED] Gradients are not float64:\n"
            + "\n".join([f"  grad {i}: {dtype}" for i, dtype in bad])
        )


def warn_tree_float64(tree, name="object"):
    """
    Same as assert_tree_float64, but prints a warning instead of raising.
    """
    leaves, _ = tree_flatten(tree)
    dtypes = {x.dtype for x in leaves if hasattr(x, "dtype")}
    if dtypes != {jnp.float64}:
        print(f"[WARNING] {name} dtypes = {dtypes}")


def assert_lbfgs_state_float64(state):
    """
    Ensures that all *floating-point* values inside the LBFGS state
    are float64. Integers and booleans are allowed.
    """
    leaves, _ = tree_flatten(state)
    bad = []
    for i, x in enumerate(leaves):
        if isinstance(x, jnp.ndarray):
            if jnp.issubdtype(x.dtype, jnp.floating):
                if x.dtype != jnp.float64:
                    bad.append((i, x.dtype))
    if bad:
        raise RuntimeError(
            "[FLOAT64 CHECK FAILED] LBFGS floating-point buffers are not float64:\n"
            + "\n".join([f"  leaf {i}: {dtype}" for i, dtype in bad])
        )


#Change to float64
def to_float64(tree):
    def cast(x):
        if isinstance(x, (jax.Array, np.ndarray)):
            return x.astype(jnp.float64)
        if isinstance(x, (float, int)):
            return jnp.float64(x)
        return x
    return jax.tree_util.tree_map(cast, tree)


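#Example (illustrative sketch, not part of the library API): the checks above
#assume float64 mode was enabled at startup, e.g. with
#jax.config.update("jax_enable_x64", True), before any arrays were created.
#
#    params32 = {'W': jnp.ones((2, 2), dtype = jnp.float32)}
#    params64 = to_float64(params32)
#    assert_tree_float64(params64, name = "params64")  # passes
#    warn_tree_float64(params32, name = "params32")    # prints a dtype warning
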
#MSE
@jax.jit
def MSE(pred,true):
    """
    Squared error
    ----------
    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    Returns
    -------
    squared error
    """
    return (true - pred) ** 2

#MSE self-adaptive
@jax.jit
def MSE_SA(pred,true,w,q = 2):
    """
    Self-adaptive squared error
    ----------
    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    w : jax.numpy.array

        A JAX numpy array with the weights

    q : float

        Power for the weights mask

    Returns
    -------
    self-adaptive squared error with polynomial mask
    """
    return (w ** q) * ((true - pred) ** 2)

#L2 error
@jax.jit
def L2error(pred,true):
    """
    L2-error in percentage (%)
    ----------
    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    Returns
    -------
    L2-error
    """
    return 100*jnp.sqrt(jnp.sum((true - pred)**2))/jnp.sqrt(jnp.sum(true ** 2))

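#Example (illustrative sketch): MSE and MSE_SA are elementwise, so they are
#usually averaged with jnp.mean; L2error already returns a scalar.
#
#    pred = jnp.array([1.0, 2.0, 3.0])
#    true = jnp.array([1.5, 2.0, 3.0])
#    jnp.mean(MSE(pred, true))                      # 0.0833...
#    jnp.mean(MSE_SA(pred, true, w = jnp.ones(3)))  # same value, unit weights
#    L2error(pred, true)                            # relative L2 error in %
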
#Auxiliary functions to sample singular Matern
def idst1(x,axis = -1):
    """
    Inverse Discrete Sine Transform of type I with orthonormal scaling
    ----------
    Parameters
    ----------
    x : jax.numpy.array

        Array to apply the transformation

    axis : int

        Axis to apply the transformation over

    Returns
    -------
    jax.numpy.array
    """
    return idst(x,type = 1,axis = axis,norm = 'ortho')

def dstn(x,axes = None):
    """
    Discrete Sine Transform of type I with orthonormal scaling over many axes
    ----------
    Parameters
    ----------
    x : jax.numpy.array

        Array to apply the transformation

    axes : tuple

        Axes to apply the transformation over

    Returns
    -------
    jax.numpy.array
    """
    if axes is None:
        axes = tuple(range(x.ndim))
    y = x
    for ax in axes:
        y = dst(y,type = 1,axis = ax,norm = 'ortho')
    return y

def idstn(x,axes = None):
    """
    Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes
    ----------
    Parameters
    ----------
    x : jax.numpy.array

        Array to apply the transformation

    axes : tuple

        Axes to apply the transformation over

    Returns
    -------
    jax.numpy.array
    """
    if axes is None:
        axes = tuple(range(x.ndim))
    y = x
    for ax in axes:
        y = idst1(y,axis = ax)
    return y

def dirichlet_eigs_nd(n,L):
    """
    Eigenvalues of the discrete Dirichlet-Laplace operator in a rectangle
    ----------
    Parameters
    ----------
    n : list

        List with the number of points in the grid in each dimension

    L : list

        List with the upper limit of the domain interval in each dimension. The lower limit is assumed to be zero

    Returns
    -------
    jax.numpy.array
    """
    #Unidimensional eigenvalues
    lam_axes = []
    for ni, Li in zip(n,L):
        h = Li / (ni + 1)
        k = jnp.arange(1,ni + 1)
        ln = (2 / (h*h)) * (1 - jnp.cos(jnp.pi * k / (ni + 1)))
        lam_axes.append(ln)
    grids = jnp.meshgrid(*lam_axes, indexing='ij')
    Lam = jnp.zeros_like(grids[0])
    for g in grids:
        Lam += g
    return Lam


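#Example (illustrative sketch): dstn/idstn form an orthonormal transform pair,
#and dirichlet_eigs_nd returns one eigenvalue per grid mode.
#
#    W = np.random.default_rng(0).standard_normal((8, 8))
#    np.allclose(idstn(dstn(W)), W)                 # True
#    dirichlet_eigs_nd([8, 8], [1.0, 1.0]).shape    # (8, 8)
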
#Sample from d-dimensional Matern process
def generate_matern_sample(key,d = 2,N = 128,L = 1.0,kappa = 1,alpha = 1,sigma = 1,periodic = False):
    """
    Sample d-dimensional Matern process
    ----------
    Parameters
    ----------
    key : int

        Seed for randomization

    d : int

        Dimension. Default 2

    N : int

        Size of grid in each dimension. Default 128

    L : list of float

        The domain of the function in each coordinate is [0,L[i]]. If a float, repeat the same interval for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process

    periodic : logical

        Whether to sample with periodic boundary conditions. periodic = False is not JAX-native and does not work with JIT

    Returns
    -------
    jax.numpy.array
    """
    if periodic:
        #Shape and key
        key = jax.random.PRNGKey(key)
        if isinstance(L,float) or isinstance(L,int):
            L = d*[L]
        if isinstance(N,float) or isinstance(N,int):
            N = d*[N]
        shape = tuple(N)

        #Setup frequency grid
        freq = [jnp.fft.fftfreq(N[j],d = L[j]/N[j]) * 2 * jnp.pi for j in range(d)]
        grids = jnp.meshgrid(*freq, indexing='ij')
        sq_norm_xi = sum(g**2 for g in grids)

        #Generate white noise in Fourier space
        key_re, key_im = jax.random.split(key)
        white_noise_f = (jax.random.normal(key_re, shape) +
                         1j * jax.random.normal(key_im, shape))

        #Apply the Whittle filter
        amplitude_filter = (kappa ** 2 + sq_norm_xi) ** (-alpha / 2)
        field_f = white_noise_f * amplitude_filter

        #Transform back to physical space
        sample = jnp.real(jnp.fft.ifftn(field_f))
        return sigma*sample
    else: #NOT JAX
        #Shape and key
        rng = np.random.default_rng(seed = key)
        if isinstance(L,float) or isinstance(L,int):
            L = d*[L]
        if isinstance(N,float) or isinstance(N,int):
            N = d*[N]
        shape = tuple(N)

        #White noise in real space
        W = rng.standard_normal(size = shape)

        #To Dirichlet eigenbasis via separable DST-I (orthonormal)
        W_hat = dstn(W)

        #Discrete Dirichlet Laplacian eigenvalues
        lam = dirichlet_eigs_nd(N, L)

        #Spectral filter
        filt = ((kappa + lam) ** (-alpha/2))
        psi_hat = filt * W_hat

        #Back to real space
        psi = idstn(psi_hat)
        return jnp.array(sigma*psi)

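#Example (illustrative sketch): one periodic 2D Matern sample on a 64x64 grid.
#
#    sample = generate_matern_sample(0, d = 2, N = 64, L = 1.0, kappa = 5.0,
#                                    alpha = 2.0, sigma = 1.0, periodic = True)
#    sample.shape   # (64, 64)
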
#Vectorized generate_matern_sample
def generate_matern_sample_batch(d = 2,N = 512,L = 1.0,kappa = 10.0,alpha = 1,sigma = 10,periodic = False):
    """
    Create function to sample d-dimensional Matern process
    ----------
    Parameters
    ----------
    d : int

        Dimension. Default 2

    N : int

        Size of grid in each dimension. Default 512

    L : list of float

        The domain of the function in each coordinate is [0,L[i]]. If a float, repeat the same interval for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process

    periodic : logical

        Whether to sample with periodic boundary conditions. periodic = False is not JAX-native and does not work with JIT

    Returns
    -------
    function
    """
    if periodic:
        return jax.vmap(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic))
    else:
        return lambda keys: jnp.array(np.apply_along_axis(lambda k: generate_matern_sample(k,d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic),1,keys.reshape((keys.shape[0],1))))

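#Example (illustrative sketch): a batch of periodic samples, keyed the same way
#train_Matern_PINN does it below.
#
#    gen = generate_matern_sample_batch(d = 2, N = 64, L = 1.0, kappa = 5.0,
#                                       alpha = 2.0, sigma = 1.0, periodic = True)
#    keys = jax.random.split(jax.random.PRNGKey(0), 16)[:, 0]
#    gen(keys).shape   # (16, 64, 64)
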
#Build function to compute the eigenfunctions of the Laplacian
def eigenf_laplace(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None):
    """
    Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace.
    ----------
    Parameters
    ----------
    L_vec : list of float

        The domain of the function in each coordinate is [0,L[i]]

    kmax_per_axis : list

        List with the maximum number of eigenfunctions per dimension. Considers d * max(kmax_per_axis) eigenfunctions

    bc : str

        Boundary condition. 'dirichlet' or 'neumann'

    max_ef : int

        Maximum number of eigenfunctions to consider among the ones with smallest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions

    Returns
    -------
    function to compute eigenfunctions,eigenvalues of the eigenfunctions considered
    """
    #Parameters
    L_vec = jnp.asarray(L_vec)
    d = L_vec.shape[0]
    bc = bc.lower()

    #Maximum number of functions
    if max_ef is None:
        if d == 1:
            max_ef = jnp.max(jnp.array(kmax_per_axis))
        else:
            max_ef = jnp.max(d * jnp.array(kmax_per_axis))

    #Build the candidate multi-indices per axis
    kmax_per_axis = list(map(int, kmax_per_axis))
    if bc.startswith("d"):
        axis_ranges = [range(1, km + 1) for km in kmax_per_axis]
    elif bc.startswith("n"):
        axis_ranges = [range(0, km + 1) for km in kmax_per_axis]
    else:
        raise ValueError("bc must be 'dirichlet' or 'neumann'")

    #Get all multi-indices
    Ks_list = list(product(*axis_ranges))
    Ks = jnp.array(Ks_list)

    #Eigenvalues of the continuous Laplacian
    pi_over_L = jnp.pi / L_vec
    lambdas_all = jnp.sum((Ks * pi_over_L) ** 2, axis=1)

    #Sort by eigenvalue
    order = jnp.argsort(lambdas_all)
    Ks = Ks[order]
    lambdas_all = lambdas_all[order]

    #Keep first max_ef
    Ks = Ks[:max_ef]
    lambdas = lambdas_all[:max_ef]
    m = Ks.shape[0]

    #Precompute per-feature normalization factor (closed form)
    def per_axis_norm_factor(k_i, L_i, is_dirichlet):
        if is_dirichlet:
            return jnp.sqrt(2 / L_i)
        else:
            return jnp.where(k_i == 0, jnp.sqrt(1 / L_i), jnp.sqrt(2 / L_i))
    if bc.startswith("d"):
        nf = jnp.prod(jnp.sqrt(2 / L_vec)[None, :],axis = 1)
        norm_factors = jnp.ones((m,)) * nf
    else:
        # per-mode product across axes
        def nf_row(k_row):
            return jnp.prod(per_axis_norm_factor(k_row, L_vec, False))
        norm_factors = jax.vmap(nf_row)(Ks)

    #Build the callable function
    Ks_int = Ks  # integer-valued multi-indices
    L_vec_f = L_vec
    @jax.jit
    def phi(x):
        x = jnp.asarray(x)
        #Initialize with ones
        vals = jnp.ones(x.shape[:-1] + (m,))
        #Compute eigenfunction
        for i in range(d):
            ang = (jnp.pi / L_vec_f[i]) * x[..., i][..., None] * Ks_int[:, i]
            if bc.startswith("d"):
                comp = jnp.sin(ang)
            else:
                comp = jnp.cos(ang)
            vals = vals * comp
        #Apply L2-normalizing constants
        vals = vals * norm_factors[None, ...] if vals.ndim > 1 else vals * norm_factors
        return vals
    return phi, lambdas

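#Example (illustrative sketch): the 10 Dirichlet eigenfunctions on [0,1]^2 with
#smallest eigenvalues, evaluated at two points.
#
#    phi, lambdas = eigenf_laplace([1.0, 1.0], kmax_per_axis = [5, 5],
#                                  bc = 'dirichlet', max_ef = 10)
#    x = jnp.array([[0.25, 0.5], [0.75, 0.5]])
#    phi(x).shape    # (2, 10)
#    lambdas.shape   # (10,)
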
#Compute multiple frequencies of domain-aware Fourier features
def multiple_daff(L_vec,kmax_per_axis = None,bc = "dirichlet",max_ef = None):
    """
    Create function to compute multiple frequencies of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace. Each frequency corresponds to a different domain.
    ----------
    Parameters
    ----------
    L_vec : list of lists of float

        List with the domain of each frequency of the eigenfunctions in the form [0,L[i][j]]

    kmax_per_axis : list

        List with the maximum number of eigenfunctions per dimension.

    bc : str

        Boundary condition. 'dirichlet' or 'neumann'

    max_ef : int

        Maximum number of eigenfunctions to consider among the ones with smallest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions

    Returns
    -------
    function to compute daff,eigenvalues of the eigenfunctions considered
    """
    psi = []
    lamb = []
    for L in L_vec:
        tmp,l = eigenf_laplace(L,kmax_per_axis,bc,max_ef) #Get function
        lamb.append(l)
        psi.append(tmp)
        del tmp
    #Create function to compute features
    @jax.jit
    def mff(x):
        y = []
        for i in range(len(psi)):
            y.append(psi[i](x))
        if len(psi) == 1:
            return y[0]
        else:
            return jnp.concatenate(y,1)
    return mff,jnp.concatenate(lamb)

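#Example (illustrative sketch): features from two domains concatenated.
#
#    mff, lambdas = multiple_daff([[1.0, 1.0], [2.0, 2.0]], kmax_per_axis = [4, 4],
#                                 bc = 'dirichlet', max_ef = 8)
#    mff(jnp.array([[0.5, 0.5]])).shape   # (1, 16): 8 features per domain
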
#Code for Chebyshev polynomials written by AI (deprecated)
def _chebyshev_T_all(t, K: int):
    """
    Compute T_0..T_K(t) with the standard recurrence.
    t shape should be (..., d). We DO NOT squeeze any axis to preserve 'd'
    even when d == 1.
    Returns: array of shape (K+1, ...) matching t's batch dims, including d.
    """
    # Expect t to have last axis = d (keep it, even if d == 1)
    T0 = jnp.ones_like(t)           # (..., d)
    if K == 0:
        return T0[None, ...]        # (1, ..., d)

    T1 = t                          # (..., d)
    if K == 1:
        return jnp.stack([T0, T1], axis=0)  # (2, ..., d)

    def body(carry, _):
        Tkm1, Tk = carry            # each (..., d)
        Tkp1 = 2 * t * Tk - Tkm1    # (..., d)
        return (Tk, Tkp1), Tkp1

    # K >= 2: produce T_2..T_K
    (_, _), T2_to_TK = lax.scan(body, (T0, T1), jnp.arange(K - 1))  # (K-1, ..., d)
    return jnp.concatenate([T0[None, ...], T1[None, ...], T2_to_TK], axis=0)  # (K+1, ..., d)

@partial(jax.jit,static_argnums=(2,))  # n is static here; compile once per n
def multiple_cheb_fast(x, L_vec, n: int):
    """
    x: (N, d)
    L_vec: (L, d) containing 'b' endpoints (a is 0) for each dimension
    n: number of k terms (static)
    returns: (N, L*n)
    """
    N, d = x.shape
    L = L_vec.shape[0]

    a = 0
    b = L_vec                       # (L, d)
    # Map x to t in [-1, 1] for each l, j: shape (L, N, d)
    t = (2 * x[None, :, :] - (a + b)[:, None, :]) / (b - a)[:, None, :]

    # Chebyshev T_0..T_{n+2} for all (L, N, d): shape (n+3, L, N, d)
    T = _chebyshev_T_all(t, n + 2)

    # phi_k = T_{k+2} - T_k, k = 0..n-1  => shape (n, L, N, d)
    ks = jnp.arange(n)
    phi = T[ks + 2, ...] - T[ks, ...]

    # Multiply across dimensions (over the last axis = d) => (n, L, N)
    z = jnp.prod(phi, axis=-1)

    # Reorder to (N, L, n) then flatten to (N, L*n)
    z = jnp.transpose(z, (2, 1, 0)).reshape(N, L * n)
    return z

def multiple_cheb(L_vec, n: int):
    """
    Factory that closes over static n and L_vec (so shapes are constant).
    """
    L_vec = jnp.asarray(L_vec)
    @jax.jit  # optional; multiple_cheb_fast is already jitted
    def mcheb(x):
        x = jnp.asarray(x)
        return multiple_cheb_fast(x, L_vec, n)
    return mcheb


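#Example (illustrative sketch): Chebyshev features for two domains and n = 4 terms.
#
#    mcheb = multiple_cheb([[1.0, 1.0], [2.0, 2.0]], n = 4)
#    mcheb(jnp.array([[0.5, 0.5], [0.1, 0.9]])).shape   # (2, 8): 2 domains x 4 terms
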
#Initialize fully connected neural network. Returns the initial parameters and the function for the forward pass
def fconNN(width,activation = jax.nn.tanh,key = 0,mlp = False,ftype = None,fargs = None,static = None,daff = None):
    """
    Initialize fully connected neural network
    ----------
    Parameters
    ----------
    width : list

        List with the layers width

    activation : jax.nn activation

        The activation function. Default jax.nn.tanh

    key : int

        Seed for parameters initialization. Default 0

    mlp : logical

        Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension.

    ftype : str

        Type of feature transformation to use: None, 'ff', 'daff', 'daff_bias', 'cheb', 'cheb_bias'.

    fargs : list

        Arguments for feature transformation:

        For 'ff': a list with the number of frequencies and the standard deviation of the greatest frequency.

        For 'daff' and 'daff_bias': a dictionary with a list with the sizes of the rectangles and the type of boundary condition. If it is a list, then the boundary condition is Dirichlet.

    static : function

        A static function to sum to the neural network output.

    daff : list

        List with the function to compute daff and the number of daff features. If None, computes assuming a rectangular domain.

    Returns
    -------
    dict with initial parameters and the function for the forward pass
    """
    #Initialize parameters with Glorot initialization
    initializer = jax.nn.initializers.glorot_normal()
    params = list()
    if static is None:
        static = lambda x: 0

    #Feature mapping
    if ftype == 'ff': #Fourier features
        for s in range(fargs[0]):
            sd = fargs[1] ** ((s + 1)/fargs[0])
            if s == 0:
                Bff = sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2)))
            else:
                Bff = jnp.append(Bff,sd*jax.random.normal(jax.random.PRNGKey(key + s + 1),(width[0],int(width[1]/2))),1)
        @jax.jit
        def phi(x):
            x = x @ Bff
            return jnp.concatenate([jnp.sin(2 * jnp.pi * x),jnp.cos(2 * jnp.pi * x)],axis = -1)
        width = width[1:]
        width[0] = 2*Bff.shape[1]
    elif ftype == 'daff' or ftype == 'daff_bias':
        if not isinstance(fargs, dict):
            fargs = {'L': fargs,'bc': "dirichlet"}
        if daff is None:
            phi,lamb = multiple_daff(list(fargs.values())[0],kmax_per_axis = [width[1]] * width[0],bc = list(fargs.values())[1])
            width = width[1:]
            width[0] = lamb.shape[0]
        else:
            phi = daff[0]
            width = width[1:]
            width[0] = daff[1]
    elif ftype == 'cheb' or ftype == 'cheb_bias':
        phi = multiple_cheb(fargs,n = width[1])
        width = width[1:]
        width[0] = len(fargs)*width[0]
    else:
        @jax.jit
        def phi(x):
            return x

    #Initialize parameters
    if mlp:
        k = jax.random.split(jax.random.PRNGKey(key),4)
        WU = initializer(k[0],(width[0],width[1]))
        BU = initializer(k[1],(1,width[1]))
        WV = initializer(k[2],(width[0],width[1]))
        BV = initializer(k[3],(1,width[1]))
        params.append({'WU':WU,'BU':BU,'WV':WV,'BV':BV})
    keys = jax.random.split(jax.random.PRNGKey(key + 1),len(width)-1) #Seeds for initialization
    for k,lin,lout in zip(keys,width[:-1],width[1:]):
        W = initializer(k,(lin,lout))
        B = initializer(k,(1,lout))
        params.append({'W':W,'B':B})

    #Define function for forward pass
    if mlp:
        if ftype != 'daff' and ftype != 'cheb':
            @jax.jit
            def forward(x,params):
                encode,*hidden,output = params
                sx = static(x)
                x = phi(x)
                U = activation(x @ encode['WU'] + encode['BU'])
                V = activation(x @ encode['WV'] + encode['BV'])
                for layer in hidden:
                    x = activation(x @ layer['W'] + layer['B'])
                    x = x * U + (1 - x) * V
                return x @ output['W'] + output['B'] + sx
        else:
            @jax.jit
            def forward(x,params):
                encode,*hidden,output = params
                sx = static(x)
                x = phi(x)
                U = activation(x @ encode['WU'])
                V = activation(x @ encode['WV'])
                for layer in hidden:
                    x = activation(x @ layer['W'])
                    x = x * U + (1 - x) * V
                return x @ output['W'] + sx
    else:
        if ftype != 'daff' and ftype != 'cheb':
            @jax.jit
            def forward(x,params):
                *hidden,output = params
                sx = static(x)
                x = phi(x)
                for layer in hidden:
                    x = activation(x @ layer['W'] + layer['B'])
                return x @ output['W'] + output['B'] + sx
        else:
            @jax.jit
            def forward(x,params):
                *hidden,output = params
                sx = static(x)
                x = phi(x)
                for layer in hidden:
                    x = activation(x @ layer['W'])
                return x @ output['W'] + sx

    #Return initial parameters and forward function
    return {'params': params,'forward': forward}

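#Example (illustrative sketch): a plain 2-input, 1-output network.
#
#    net = fconNN([2, 32, 32, 1], activation = jax.nn.tanh, key = 0)
#    x = jnp.ones((5, 2))
#    net['forward'](x, net['params']).shape   # (5, 1)
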
#Get activation from string
def get_activation(act):
    """
    Return activation function from string
    ----------
    Parameters
    ----------
    act : str

        Name of the activation function

    Returns
    -------
    jax.nn activation function
    """
    if act == 'tanh':
        return jax.nn.tanh
    elif act == 'relu':
        return jax.nn.relu
    elif act == 'relu6':
        return jax.nn.relu6
    elif act == 'sigmoid':
        return jax.nn.sigmoid
    elif act == 'softplus':
        return jax.nn.softplus
    elif act == 'sparse_plus':
        return jax.nn.sparse_plus
    elif act == 'soft_sign':
        return jax.nn.soft_sign
    elif act == 'silu':
        return jax.nn.silu
    elif act == 'swish':
        return jax.nn.swish
    elif act == 'log_sigmoid':
        return jax.nn.log_sigmoid
    elif act == 'leaky_relu':
        return jax.nn.leaky_relu
    elif act == 'hard_sigmoid':
        return jax.nn.hard_sigmoid
    elif act == 'hard_silu':
        return jax.nn.hard_silu
    elif act == 'hard_swish':
        return jax.nn.hard_swish
    elif act == 'hard_tanh':
        return jax.nn.hard_tanh
    elif act == 'elu':
        return jax.nn.elu
    elif act == 'celu':
        return jax.nn.celu
    elif act == 'selu':
        return jax.nn.selu
    elif act == 'gelu':
        return jax.nn.gelu
    elif act == 'glu':
        return jax.nn.glu
    elif act == 'squareplus':
        return jax.nn.squareplus
    elif act == 'mish':
        return jax.nn.mish
    else:
        raise ValueError('Unknown activation: ' + str(act))

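#Example (illustrative sketch):
#
#    act = get_activation('relu')
#    act(jnp.array([-1.0, 2.0]))   # [0., 2.]
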
#Training PINN
def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,mlp = False):
    """
    Train a Physics-informed Neural Network
    ----------
    Parameters
    ----------
    data : dict

        Data generated by the jinnax.data.generate_PINNdata function

    width : list

        A list with the width of each layer

    pde : function

        The partial differential operator. Its arguments are u, x and t

    test_data : dict, None

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error

    epochs : int

        Number of training epochs. Default 100

    at_each : int

        Save results for epochs multiple of at_each. Default 10

    activation : str

        The name of the activation function of the neural network. Default 'tanh'

    neumann : logical

        Whether to consider Neumann boundary conditions

    oper_neumann : function

        Penalization of Neumann boundary conditions

    sa : logical

        Whether to consider a self-adaptive PINN

    c : dict

        Dictionary with the hyperparameters of the self-adaptive mask for the initial (w0), sensor (ws), collocation (wr) and boundary (wb) points

    inverse : logical

        Whether to estimate parameters of the PDE

    initial_par : jax.numpy.array

        Initial value of the parameters of the PDE in an inverse problem

    lr,b1,b2,eps,eps_root: float

        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0

    key : int

        Seed for parameters initialization. Default 0

    epoch_print : int

        Number of epochs to calculate and print test errors. Default 100

    save : logical

        Whether to save the current parameters. Default False

    file_name : str

        File prefix to save the current parameters. Default 'result_pinn'

    exp_decay : logical

        Whether to consider exponential decay of the learning rate. Default False

    transition_steps : int

        Number of steps for exponential decay. Default 1000

    decay_rate : float

        Rate of exponential decay. Default 0.9

    mlp : logical

        Whether to consider a modified multilayer perceptron

    Returns
    -------
    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
    """

    #Initialize architecture
    nnet = fconNN(width,get_activation(activation),key,mlp)
    forward = nnet['forward']

    #Initialize self-adaptive weights
    par_sa = {}
    if sa:
        #Initialize weights randomly
        ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000)
        if data['sensor'] is not None:
            par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))})
        if data['initial'] is not None:
            par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))})
        if data['collocation'] is not None:
            par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))})
        if data['boundary'] is not None:
            par_sa.update({'wb': c['wb'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))})

    #Store all parameters
    params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa}

    #Save config file
    if save:
        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)

    #Define loss function
    if sa:
        @jax.jit
        def lf(params,x):
            loss = 0
            if x['sensor'] is not None:
                #Term that refers to sensor data
                loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws']))
            if x['boundary'] is not None:
                if neumann:
                    #Neumann conditions
                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb']))
                else:
                    #Term that refers to boundary data
                    loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb']))
            if x['initial'] is not None:
                #Term that refers to initial data
                loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0']))
            if x['collocation'] is not None:
                #Term that refers to collocation points
                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
                if inverse:
                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr']))
                else:
                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr']))
            return loss
    else:
        @jax.jit
        def lf(params,x):
            loss = 0
            if x['sensor'] is not None:
                #Term that refers to sensor data
                loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
            if x['boundary'] is not None:
                if neumann:
                    #Neumann conditions
                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb))
                else:
                    #Term that refers to boundary data
                    loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary']))
            if x['initial'] is not None:
                #Term that refers to initial data
                loss = loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial']))
            if x['collocation'] is not None:
                #Term that refers to collocation points
                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
                if inverse:
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0))
                else:
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0))
            return loss

    #Initialize Adam optimizer
    if exp_decay:
        lr = optax.exponential_decay(lr,transition_steps,decay_rate)
    optimizer = optax.adam(lr,b1,b2,eps,eps_root)
    opt_state = optimizer.init(params)

    #Define the gradient function
    grad_loss = jax.jit(jax.grad(lf,0))

    #Define update function
    @jax.jit
    def update(opt_state,params,x):
        #Compute gradient
        grads = grad_loss(params,x)
        #Invert gradient of self-adaptive weights
        if sa:
            for w in grads['sa']:
                grads['sa'][w] = - grads['sa'][w]
        #Calculate parameters updates
        updates, opt_state = optimizer.update(grads, opt_state)
        #Update parameters
        params = optax.apply_updates(params, updates)
        #Return state of optimizer and updated parameters
        return opt_state,params

    ###Training###
    t0 = time.time()
    #Initialize alive_bar for tracking progress in the terminal
    with alive_bar(epochs) as bar:
        #For each epoch
        for e in range(epochs):
            #Update optimizer state and parameters
            opt_state,params = update(opt_state,params,data)
            #At every epoch_print epochs
            if e % epoch_print == 0:
                #Compute elapsed time and current error
                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6))
                #If there is test data, compute current L2 error
                if test_data is not None:
                    #Compute L2 error
                    l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist()
                    l = l + ' L2 error: ' + str(jnp.round(l2_test,3))
                if inverse:
                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
                #Print
                print(l)
            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
                #Save current parameters
                pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
            #Update alive_bar
            bar()
    #Define estimated function
    def u(xt):
        return forward(xt,params['net'])

    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0}

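#Example (illustrative sketch): `data` and `test_data` would come from
#jinnax.data.generate_PINNdata, and `my_pde(u, x, t)` is a user-supplied
#residual operator; both are hypothetical names here.
#
#    result = train_PINN(data, width = [2, 32, 32, 1], pde = my_pde,
#                        test_data = test_data, epochs = 1000, lr = 1e-3)
#    u_hat = result['u']       # estimated solution
#    u_hat(test_data['xt'])    # evaluate on test points
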
#Training PINN with Matern test functions
def train_Matern_PINN(data,width,pde,test_data = None,params = None,d = 2,N = 128,L = 1,alpha = 1,kappa = 1,sigma = 100,bsize = 1024,resample = False,epochs = 100,at_each = 10,activation = 'tanh',
    neumann = False,oper_neumann = None,inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 1,save = False,file_name = 'result_pinn',
    exp_decay = True,transition_steps = 100,decay_rate = 0.9,mlp = True,ftype = None,fargs = None,q = 4,w = None,periodic = False,static = None,opt = 'LBFGS',float64 = False,restart = None):
    """
    Train a Physics-informed Neural Network
    ----------
    Parameters
    ----------
    data : dict

        Data generated by the jinnax.data.generate_PINNdata function

    width : list

        A list with the width of each layer

    pde : function

        The partial differential operator. Its arguments are u, x and t

    test_data : dict, None

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error

    params : list

        Initial parameters for the neural network. Default None to initialize randomly

    d : int

        Dimension of the problem including the time variable if present. Default 2

    N : int

        Size of grid in each dimension. Default 128

    L : list of float

        The domain of the function in each coordinate is [0,L[i]]. If a float, repeat the same interval for all coordinates. Default 1

    kappa,alpha,sigma : float

        Parameters of the Matern process

    bsize : int

        Batch size for weak norm computation. Default 1024

    resample : logical

        Whether to resample the test functions at each epoch

    epochs : int

        Number of training epochs. Default 100

    at_each : int

        Save results for epochs multiple of at_each. Default 10

    activation : str

        The name of the activation function of the neural network. Default 'tanh'

    neumann : logical

        Whether to consider Neumann boundary conditions

    oper_neumann : function

        Penalization of Neumann boundary conditions

    inverse : logical

        Whether to estimate parameters of the PDE

    initial_par : jax.numpy.array

        Initial value of the parameters of the PDE in an inverse problem

    lr,b1,b2,eps,eps_root: float

        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0

    key : int

        Seed for parameters initialization. Default 0

    epoch_print : int

        Number of epochs to calculate and print test errors. Default 1

    save : logical

        Whether to save the current parameters. Default False

    file_name : str

        File prefix to save the current parameters. Default 'result_pinn'

    exp_decay : logical

        Whether to consider exponential decay of the learning rate. Default True

    transition_steps : int

        Number of steps for exponential decay. Default 100

    decay_rate : float

        Rate of exponential decay. Default 0.9

    mlp : logical

        Whether to consider a modified multilayer perceptron

    ftype : str

        Type of feature transformation to use: None, 'ff', 'daff', 'daff_bias', 'cheb', 'cheb_bias'.

    fargs : list

        Arguments for feature transformation:

        For 'ff': a list with the number of frequencies and the standard deviation of the greatest frequency.

        For 'daff' and 'daff_bias': a dictionary with a list with the sizes of the rectangles and the type of boundary condition. If it is a list, then the boundary condition is Dirichlet.

    q : int

        Power of the weights mask. Default 4

    w : dict

        Initial weights for the self-adaptive scheme.

    periodic : logical

        Whether to consider periodic test functions. Default False.

    static : function

        A static function to sum to the neural network output.

    opt : str

        Optimizer. Default 'LBFGS'.

    float64 : logical

        Whether to train with float64

    restart : int

        Number of epochs between L-BFGS restarts

    Returns
    -------
    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2 error and training time at each epoch
    """
    #Initialize architecture
    nnet = fconNN(width,get_activation(activation),key,mlp,ftype,fargs,static)
    if float64:
        forward = lambda x,params: nnet['forward'](x,params).astype(jnp.float64)
        assert jax.config.jax_enable_x64, "JAX is NOT running in float64 mode!"
    else:
        forward = nnet['forward']
    if params is not None:
        nnet['params'] = params
    if float64:
        data = to_float64(data)
        if test_data is not None:
            test_data = to_float64(test_data)

    #Generate from Matern process
    if sigma > 0:
        if isinstance(L,float) or isinstance(L,int):
            L = d*[L]
        #Grid for weak norm
        if float64:
            grid = [jnp.linspace(0,L[i],N,dtype = jnp.float64) for i in range(d)]
        else:
            grid = [jnp.linspace(0,L[i],N) for i in range(d)]
        grid = jnp.meshgrid(*grid, indexing='ij')
        grid = jnp.stack(grid, axis=-1).reshape((-1, d))
        #Set sigma
        if data['boundary'] is not None:
            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma)
            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
            if neumann:
                loss_boundary = oper_neumann(lambda x: forward(x,nnet['params']),data['boundary'])
            else:
                loss_boundary = jnp.mean(MSE(forward(data['boundary'],nnet['params']),data['uboundary']))
            output_w = pde(lambda x: forward(x,nnet['params']),grid)
            integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(tf)
            loss_res_weak = jnp.mean(integralOmega ** 2)
            sigma = float(jnp.sqrt(loss_boundary/loss_res_weak).tolist())
            del gen
            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
            tf = sigma*tf
        else:
            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
        if float64 and tf is not None:
            tf = to_float64(tf)
            grid = to_float64(grid)
    else:
        tf = None
        grid = None

    #Define loss function
    @jax.jit
    def lf_each(params,x,k,tf,grid):
        if sigma > 0:
            #Term that refers to weak loss
            if resample:
                test_functions = to_float64(gen(jax.random.split(jax.random.PRNGKey(k[0]),(bsize,))[:,0]))
            else:
                test_functions = tf
        loss_sensor = loss_boundary = loss_initial = loss_res = loss_res_weak = 0
        if x['sensor'] is not None:
            #Term that refers to sensor data
            loss_sensor = MSE(forward(x['sensor'],params['net']),x['usensor'])
        if x['boundary'] is not None:
            if neumann:
                #Neumann conditions
                loss_boundary = oper_neumann(lambda x: forward(x,params['net']),x['boundary'])
            else:
                #Term that refers to boundary data
                loss_boundary = MSE(forward(x['boundary'],params['net']),x['uboundary'])
        if x['initial'] is not None:
            #Term that refers to initial data
            loss_initial = MSE(forward(x['initial'],params['net']),x['uinitial'])
        if x['collocation'] is not None and sigma == 0:
            if inverse:
                output = pde(lambda x: forward(x,params['net']),x['collocation'],params['inverse'])
            else:
                output = pde(lambda x: forward(x,params['net']),x['collocation'])
            loss_res = MSE(output,0)
        if sigma > 0:
            #Term that refers to weak loss
            if inverse:
                output_w = pde(lambda x: forward(x,params['net']),grid,params['inverse'])
            else:
                output_w = pde(lambda x: forward(x,params['net']),grid)
            integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
            loss_res_weak = jnp.mean(integralOmega ** 2)
        return {'ls': loss_sensor,'lb': loss_boundary,'li': loss_initial,'lc': loss_res,'lc_weak': loss_res_weak}

    @jax.jit
    def lf(params,x,k,tf,grid):
        l = lf_each(params,x,k,tf,grid)
        if opt != 'LBFGS':
            loss = jnp.mean((params['w']['ws'] ** q)*l['ls']) + jnp.mean((params['w']['wb'] ** q)*l['lb']) + jnp.mean((params['w']['wi'] ** q)*l['li']) + jnp.mean((params['w']['wc'] ** q)*l['lc']) + (params['w']['wc_weak'] ** q)*l['lc_weak']
            return loss
        else:
            loss = jnp.mean((w['ws'] ** q)*l['ls']) + jnp.mean((w['wb'] ** q)*l['lb']) + jnp.mean((w['wi'] ** q)*l['li']) + jnp.mean((w['wc'] ** q)*l['lc']) + (w['wc_weak'] ** q)*l['lc_weak']
            l2 = None
            if test_data is not None:
                l2 = L2error(forward(test_data['sensor'],params['net']),test_data['usensor'])
            return loss,{'loss': loss,'l2': l2}

    #Initialize self-adaptive weights
    if float64:
        typ = jnp.float64
    else:
        typ = jnp.float32
    if w is None:
        w = {'ws': jnp.array(1.0),'wb': jnp.array(1.0),'wi': jnp.array(1.0),'wc': jnp.array(1.0),'wc_weak': jnp.array(1.0)}
    if q != 0 and opt != 'LBFGS':
        if data['sensor'] is not None:
            w['ws'] = w['ws'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+1),(data['sensor'].shape[0],1),dtype = typ)
        if data['boundary'] is not None:
            w['wb'] = w['wb'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+2),(data['boundary'].shape[0],1),dtype = typ)
        if data['initial'] is not None:
            w['wi'] = w['wi'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+3),(data['initial'].shape[0],1),dtype = typ)
        if data['collocation'] is not None:
            w['wc'] = w['wc'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+4),(data['collocation'].shape[0],1),dtype = typ)

    #Store all parameters
    if opt != 'LBFGS':
        params = {'net': nnet['params'],'inverse': initial_par,'w': w}
    else:
        params = {'net': nnet['params'],'inverse': initial_par}
    if float64:
        params = to_float64(params)

    #Save config file
    if save:
        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)

    #Initialize optimizer
    if opt != 'LBFGS':
        print('--------- GRADIENT DESCENT OPTIMIZER ---------')
        if exp_decay:
            lr = optax.exponential_decay(lr,transition_steps,decay_rate)
        optimizer = optax.adam(lr,b1,b2,eps,eps_root)
        opt_state = optimizer.init(params)

        #Define the gradient function
        grad_loss = jax.jit(jax.grad(lf,0))

        #Define update function
        @jax.jit
        def update(opt_state,params,x,k,tf,grid):
            #Compute gradient
            grads = grad_loss(params,x,k,tf,grid)
            #Calculate parameters updates
            updates, opt_state = optimizer.update(grads, opt_state)
            #Invert the sign of the self-adaptive weight updates
            if q != 0:
                updates = {**updates, 'w': jax.tree_util.tree_map(lambda x: -x, updates['w'])}
            #Update parameters
            params = optax.apply_updates(params, updates)
            #Return state of optimizer and updated parameters
            return opt_state,params
    else:
        print('--------- LBFGS OPTIMIZER ---------')
        @jax.jit
        def loss_LBFGS(params):
            return lf(params,data,[key + 234],tf,grid)
        solver = LBFGS(fun = loss_LBFGS,has_aux = True,maxiter = epochs,tol = 1e-9,verbose = False,linesearch = 'zoom',history_size = 200)
        state = solver.init_state(params)
        if float64:
            assert_lbfgs_state_float64(state)

    ###Training###
    t0 = time.time()
    k = jax.random.split(jax.random.PRNGKey(key+234),(epochs,))
    sloss = []
    sL2 = []
    stime = []
    #Initialize alive_bar for tracking progress in the terminal
    with alive_bar(epochs) as bar:
        #For each epoch
        for e in range(epochs):
            if opt != 'LBFGS':
                if float64 and e < 10:
                    assert_tree_float64(params, name="params (before update)")
                    grads = grad_loss(params, data, k[e,:], tf, grid)
                    check_grads_float64(grads)
                #Update optimizer state and parameters
                opt_state,params = update(opt_state,params,data,k[e,:],tf,grid)
                sloss.append(lf(params,data,k[e,:],tf,grid))
                if test_data is not None:
                    sL2.append(L2error(forward(test_data['sensor'],params['net']),test_data['usensor']))
                if float64 and e < 10:
                    assert_tree_float64(params, name="params (after update)")
                    grads = grad_loss(params, data, k[e,:], tf, grid)
                    check_grads_float64(grads)
            else:
                if float64 and e < 10:
                    assert_tree_float64(params, name="params (before LBFGS step)")
                params, state = solver.update(params, state)
                if float64 and e < 10:
                    assert_tree_float64(params, name="params (after LBFGS step)")
                    assert_lbfgs_state_float64(state)
                sL2.append(state.aux["l2"])
                sloss.append(state.aux["loss"])
                if restart is not None:
                    if (e + 1) % restart == 0:
                        state = solver.init_state(params)
            stime.append(time.time() - t0)
            #At every epoch_print epochs
            if e % epoch_print == 0:
                #Compute elapsed time and current error
                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(sloss[-1],6))
                #If there is test data, compute current L2 error
                if test_data is not None:
                    #Compute L2 error
                    l = l + ' L2 error: ' + str(jnp.round(sL2[-1],6))
                if inverse:
                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
                #Print
                print(l)
            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
                #Save current parameters
                pickle.dump({'params': params,'width': width,'time': stime,'loss': sloss,'L2error': sL2},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
            #Update alive_bar
            bar()
    #Define estimated function
    def u(xt):
        return forward(xt,params['net'])

    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0,'loss_each': lf_each(params,data,[key + 100],tf,grid),'loss': sloss,'L2error': sL2}


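#Example (illustrative sketch): `data` would come from jinnax.data.generate_PINNdata
#and `my_pde(u, x)` is a hypothetical stationary residual operator.
#
#    result = train_Matern_PINN(data, width = [2, 64, 64, 1], pde = my_pde,
#                               d = 2, N = 128, L = 1.0, kappa = 1.0, alpha = 1.0,
#                               sigma = 100, bsize = 256, epochs = 200, opt = 'LBFGS')
#    result['loss'][-1]   # final loss
#    result['L2error']    # per-epoch test L2 error (when test_data is given)
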
1454#Process result
1455def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1):
1456    """
1457    Process the results of a Physics-informed Neural Network
1458    ----------
1459
1460    Parameters
1461    ----------
1462    test_data : dict
1463
1464        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
1465
1466    fit : function
1467
1468        The fitted function
1469
1470    train_data : dict
1471
1472        Training data generated by the jinnax.data.generate_PINNdata
1473
1474    plot : logical
1475
1476        Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True
1477
1478    plot_test : logical
1479
1480        Whether to plot the test data. Default True
1481
1482    times : int
1483
1484        Number of points along the time interval to plot. Default 5
1485
1486    d2 : logical
1487
1488        Whether to plot 2D plot when the spatial dimension is one. Default True
1489
1490    save : logical
1491
1492        Whether to save the plots. Default False
1493
1494    show : logical
1495
1496        Whether to show the plots. Default True
1497
1498    file_name : str
1499
1500        File prefix to save the plots. Default 'result_pinn'
1501
1502    print_res : logical
1503
1504        Whether to print the L2 error. Default True
1505
1506    p : int
1507
1508        Output dimension. Default 1
1509
1510    Returns
1511    -------
1512    pandas data frame with L2 and MSE errors
1513    """
1514
1515    #Dimension
1516    d = test_data['xt'].shape[1] - 1
1517
1518    #Number of plots multiple of 5
1519    times = 5 * round(times/5)
1520
1521    #Data
1522    td = get_train_data(train_data)
1523    xt_train = td['x']
1524    u_train = td['y']
1525    upred_train = fit(xt_train)
1526    upred_test = fit(test_data['xt'])
1527
1528    #Results
1529    l2_error_test = L2error(upred_test,test_data['u']).tolist()
1530    MSE_test = jnp.mean(MSE(upred_test,test_data['u'])).tolist()
1531    l2_error_train = L2error(upred_train,u_train).tolist()
1532    MSE_train = jnp.mean(MSE(upred_train,u_train)).tolist()
1533
1534    df = pd.DataFrame(np.array([l2_error_test,MSE_test,l2_error_train,MSE_train]).reshape((1,4)),
1535        columns=['l2_error_test','MSE_test','l2_error_train','MSE_train'])
1536    if print_res:
1537        print('L2 error test: ' + str(jnp.round(l2_error_test,6)) + ' L2 error train: ' + str(jnp.round(l2_error_train,6)) + ' MSE test: ' + str(jnp.round(MSE_test,6)) + ' MSE train: ' + str(jnp.round(MSE_train,6)))
1538
1539    #Plots
1540    if d == 1 and p == 1 and plot:
1541        plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name)
1542    elif p == 2 and plot:
1543        plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test)
1544
1545    return df
1546
1547#Plot results for d = 1
1548def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
1549    """
1550    Plot the prediction of a 1D PINN
1551    ----------
1552
1553    Parameters
1554    ----------
1555    times : int
1556
1557        Number of points along the time interval to plot. Default 5
1558
1559    xt : jax.numpy.array
1560
1561        Test data xt array
1562
1563    u : jax.numpy.array
1564
1565        Test data u(x,t) array
1566
1567    upred : jax.numpy.array
1568
1569        Predicted upred(x,t) array on test data
1570
1571    d2 : logical
1572
1573        Whether to plot 2D plot. Default True
1574
1575    save : logical
1576
1577        Whether to save the plots. Default False
1578
1579    show : logical
1580
1581        Whether to show the plots. Default True
1582
1583    file_name : str
1584
1585        File prefix to save the plots. Default 'result_pinn'
1586
1587    title_1d : str
1588
1589        Title of 1D plot
1590
1591    title_2d : str
1592
1593        Title of 2D plot
1594
1595    Returns
1596    -------
1597    None
1598    """
1599    #Initialize
1600    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
1601    tlo = jnp.min(xt[:,-1])
1602    tup = jnp.max(xt[:,-1])
1603    ylo = jnp.min(u)
1604    ylo = ylo - 0.1*jnp.abs(ylo)
1605    yup = jnp.max(u)
1606    yup = yup + 0.1*jnp.abs(yup)
1607    k = 0
1608    t_values = np.linspace(tlo,tup,times)
1609
1610    #Create
1611    for i in range(int(times/5)):
1612        for j in range(5):
1613            if k < len(t_values):
1614                t = t_values[k]
1615                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
1616                x_plot = xt[xt[:,-1] == t,:-1]
1617                y_plot = upred[xt[:,-1] == t,:]
1618                u_plot = u[xt[:,-1] == t,:]
1619                if int(times/5) > 1:
1620                    ax[i,j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
1621                    ax[i,j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
1622                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
1623                    ax[i,j].set_xlabel(' ')
1624                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
1625                else:
1626                    ax[j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
1627                    ax[j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
1628                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
1629                    ax[j].set_xlabel(' ')
1630                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
1631                k = k + 1
1632
1633    #Title
1634    fig.suptitle(title_1d)
1635    fig.tight_layout()
1636
1637    #Show and save
1638    fig = plt.gcf()
1639    if show:
1640        plt.show()
1641    if save:
1642        fig.savefig(file_name + '_slices.png')
1643    plt.close()
1644
1645    #2d plot
1646    if d2:
1647        #Initialize
1648        fig, ax = plt.subplots(1,2)
1649        l1 = jnp.unique(xt[:,-1]).shape[0]
1650        l2 = jnp.unique(xt[:,0]).shape[0]
1651
1652        #Create
1653        ax[0].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
1654        ax[0].set_title('Exact')
1655        ax[1].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
1656        ax[1].set_title('Predicted')
1657
1658        #Title
1659        fig.suptitle(title_2d)
1660        fig.tight_layout()
1661
1662        #Show and save
1663        fig = plt.gcf()
1664        if show:
1665            plt.show()
1666        if save:
1667            fig.savefig(file_name + '_2d.png')
1668        plt.close()
1669
1670#Plot results for 2D output
1671def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
1672    """
1673    Plot the prediction of a PINN with 2D output
1674    ----------
1675    Parameters
1676    ----------
1677    times : int
1678
1679        Number of points along the time interval to plot. Default 5
1680
1681    xt : jax.numpy.array
1682
1683        Test data xt array
1684
1685    u : jax.numpy.array
1686
1687        Test data u(x,t) array
1688
1689    upred : jax.numpy.array
1690
1691        Predicted upred(x,t) array on test data
1692
1693    save : logical
1694
1695        Whether to save the plots. Default False
1696
1697    show : logical
1698
1699        Whether to show the plots. Default True
1700
1701    file_name : str
1702
1703        File prefix to save the plots. Default 'result_pinn'
1704
1705    title : str
1706
1707        Title of plot
1708
1709    plot_test : logical
1710
1711        Whether to plot the test data. Default True
1712
1713    Returns
1714    -------
1715    None
1716    """
1717    #Initialize
1718    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
1719    tlo = jnp.min(xt[:,-1])
1720    tup = jnp.max(xt[:,-1])
1721    xlo = jnp.min(u[:,0])
1722    xlo = xlo - 0.1*jnp.abs(xlo)
1723    xup = jnp.max(u[:,0])
1724    xup = xup + 0.1*jnp.abs(xup)
1725    ylo = jnp.min(u[:,1])
1726    ylo = ylo - 0.1*jnp.abs(ylo)
1727    yup = jnp.max(u[:,1])
1728    yup = yup + 0.1*jnp.abs(yup)
1729    k = 0
1730    t_values = np.linspace(tlo,tup,times)
1731
1732    #Create
1733    for i in range(int(times/5)):
1734        for j in range(5):
1735            if k < len(t_values):
1736                t = t_values[k]
1737                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
1738                xpred_plot = upred[xt[:,-1] == t,0]
1739                ypred_plot = upred[xt[:,-1] == t,1]
1740                if plot_test:
1741                    x_plot = u[xt[:,-1] == t,0]
1742                    y_plot = u[xt[:,-1] == t,1]
1743                if int(times/5) > 1:
1744                    if plot_test:
1745                        ax[i,j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
1746                    ax[i,j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
1747                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
1748                    ax[i,j].set_xlabel(' ')
1749                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
1750                else:
1751                    if plot_test:
1752                        ax[j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
1753                    ax[j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
1754                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
1755                    ax[j].set_xlabel(' ')
1756                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
1757                k = k + 1
1758
1759    #Title
1760    fig.suptitle(title)
1761    fig.tight_layout()
1762
1763    #Show and save
1764    fig = plt.gcf()
1765    if show:
1766        plt.show()
1767    if save:
1768        fig.savefig(file_name + '_slices.png')
1769    plt.close()
1770
1771#Get train data in one array
1772def get_train_data(train_data):
1773    """
1774    Process training sample
1775    ----------
1776
1777    Parameters
1778    ----------
1779    train_data : dict
1780
1781        A dictionary with train data generated by the jinnax.data.generate_PINNdata function
1782
1783    Returns
1784    -------
1785    dict with the processed training data
1786    """
1787    xdata = None
1788    ydata = None
1789    xydata = None
1790    if train_data['sensor'] is not None:
1791        sensor_sample = train_data['sensor'].shape[0]
1792        xdata = train_data['sensor']
1793        ydata = train_data['usensor']
1794        xydata = jnp.column_stack((train_data['sensor'],train_data['usensor']))
1795    else:
1796        sensor_sample = 0
1797    if train_data['boundary'] is not None:
1798        boundary_sample = train_data['boundary'].shape[0]
1799        if xdata is not None:
1800            xdata = jnp.vstack((xdata,train_data['boundary']))
1801            ydata = jnp.vstack((ydata,train_data['uboundary']))
1802            xydata = jnp.vstack((xydata,jnp.column_stack((train_data['boundary'],train_data['uboundary']))))
1803        else:
1804            xdata = train_data['boundary']
1805            ydata = train_data['uboundary']
1806            xydata = jnp.column_stack((train_data['boundary'],train_data['uboundary']))
1807    else:
1808        boundary_sample = 0
1809    if train_data['initial'] is not None:
1810        initial_sample = train_data['initial'].shape[0]
1811        if xdata is not None:
1812            xdata = jnp.vstack((xdata,train_data['initial']))
1813            ydata = jnp.vstack((ydata,train_data['uinitial']))
1814            xydata = jnp.vstack((xydata,jnp.column_stack((train_data['initial'],train_data['uinitial']))))
1815        else:
1816            xdata = train_data['initial']
1817            ydata = train_data['uinitial']
1818            xydata = jnp.column_stack((train_data['initial'],train_data['uinitial']))
1819    else:
1820        initial_sample = 0
1821    if train_data['collocation'] is not None:
1822        collocation_sample = train_data['collocation'].shape[0]
1823    else:
1824        collocation_sample = 0
1825
1826    return {'xy': xydata,'x': xdata,'y': ydata,'sensor_sample': sensor_sample,'boundary_sample': boundary_sample,'initial_sample': initial_sample,'collocation_sample': collocation_sample}
1827
1828#Process training
1829def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
1830    """
1831    Process the training of a Physics-informed Neural Network
1832    ----------
1833
1834    Parameters
1835    ----------
1836    test_data : dict
1837
1838        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
1839
1840    file_name : str
1841
1842        Name of the files saved during training
1843
1844    at_each : int
1845
1846        Compute results for epochs multiple of at_each. Default 100
1847
1848    bolstering : logical
1849
1850        Whether to compute bolstering mean square error. Default True
1851
1852    mc_sample : int
1853
1854        Number of samples for Monte Carlo integration in bolstering. Default 10000
1855
1856    save : logical
1857
1858        Whether to save the training results. Default False
1859
1860    file_name_save : str
1861
1862        File prefix to save the plots and the L2 error. Default 'result_pinn'
1863
1864    key : int
1865
1866        Key for random samples in bolstering. Default 0
1867
1868    ec : float
1869
1870        Stopping criterion for the EM algorithm in bolstering. Default 1e-6
1871
1872    lamb : float
1873
1874        Hyperparameter of EM algorithm in bolstering. Default 1
1875
1876    Returns
1877    -------
1878    pandas data frame with training results
1879    """
1880    #Config
1881    config = pickle.load(open(file_name + '_config.pickle', 'rb'))
1882    epochs = config['epochs']
1883    train_data = config['train_data']
1884    forward = config['forward']
1885
1886    #Get train data
1887    td = get_train_data(train_data)
1888    xydata = td['xy']
1889    xdata = td['x']
1890    ydata = td['y']
1891    sensor_sample = td['sensor_sample']
1892    boundary_sample = td['boundary_sample']
1893    initial_sample = td['initial_sample']
1894    collocation_sample = td['collocation_sample']
1895
1896    #Generate keys
1897    if bolstering:
1898        keys = jax.random.split(jax.random.PRNGKey(key),epochs)
1899
1900    #Initialize loss
1901    train_mse = []
1902    test_mse = []
1903    train_L2 = []
1904    test_L2 = []
1905    bolstX = []
1906    bolstXY = []
1907    loss = []
1908    time = []
1909    ep = []
1910
1911    #Process training
1912    with alive_bar(epochs) as bar:
1913        for e in range(epochs):
1914            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
1915                ep = ep + [e]
1916
1917                #Read parameters
1918                params = pickle.load(open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb'))
1919
1920                #Time
1921                time = time + [params['time']]
1922
1923                #Define learned function
1924                def psi(x):
1925                    return forward(x,params['params']['net'])
1926
1927                #Train MSE and L2
1928                if xdata is not None:
1929                    train_mse = train_mse + [jnp.mean(MSE(psi(xdata),ydata)).tolist()]
1930                    train_L2 = train_L2 + [L2error(psi(xdata),ydata).tolist()]
1931                else:
1932                    train_mse = train_mse + [None]
1933                    train_L2 = train_L2 + [None]
1934
1935                #Test MSE and L2
1936                test_mse = test_mse + [jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist()]
1937                test_L2 = test_L2 + [L2error(psi(test_data['xt']),test_data['u']).tolist()]
1938
1939                #Bolstering
1940                if bolstering:
1941                    bX = []
1942                    bXY = []
1943                    for method in ['chi','mm','mpe']:
1944                        kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
1945                        kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
1946                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
1947                        bXY = bXY + [gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist()]
1948                    for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]:
1949                        kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
1950                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
1951                    bolstX = bolstX + [bX]
1952                    bolstXY = bolstXY + [bXY]
1953                else:
1954                    bolstX = bolstX + [None]
1955                    bolstXY = bolstXY + [None]
1956
1957                #Loss
1958                loss = loss + [params['loss'].tolist()]
1959
1960                #Delete
1961                del params, psi
1962            #Update alive_bar
1963            bar()
1964
1965    #Bolstering results
1966    if bolstering:
1967        bolstX = jnp.array(bolstX)
1968        bolstXY = jnp.array(bolstXY)
1969
1970    #Create data frame
1971    if bolstering:
1972        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
1973            train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]),
1974            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4'])
1975    else:
1976        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
1977            train_mse,test_mse,train_L2,test_L2]),
1978            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2'])
1979    if save:
1980        df.to_csv(file_name_save + '.csv',index = False)
1981
1982    return df
1983
1984#Demo video for training1D PINN
1985def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
1986    """
1987    Demo video with the training of a 1D PINN
1988    ----------
1989
1990    Parameters
1991    ----------
1992    test_data : dict
1993
1994        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
1995
1996    file_name : str
1997
1998        Name of the files saved during training
1999
2000    at_each : int
2001
2002        Compute results for epochs multiple of at_each. Default 100
2003
2004    times : int
2005
2006        Number of points along the time interval to plot. Default 5
2007
2008    d2 : logical
2009
2010        Whether to make video demo of 2D plot. Default True
2011
2012    file_name_save : str
2013
2014        File prefix to save the plots and videos. Default 'result_pinn_demo'
2015
2016    title : str
2017
2018        Title for plots
2019
2020    framerate : int
2021
2022        Framerate for video. Default 10
2023
2024    Returns
2025    -------
2026    None
2027    """
2028    #Config
2029    with open(file_name + '_config.pickle', 'rb') as file:
2030        config = pickle.load(file)
2031    epochs = config['epochs']
2032    train_data = config['train_data']
2033    forward = config['forward']
2034
2035    #Get train data
2036    td = get_train_data(train_data)
2037    xt = td['x']
2038    u = td['y']
2039
2040    #Create folder to save plots
2041    os.makedirs(file_name_save,exist_ok = True)
2042
2043    #Create images
2044    k = 1
2045    with alive_bar(epochs) as bar:
2046        for e in range(epochs):
2047            if e % at_each == 0 or e == epochs - 1:
2048                #Read parameters
2049                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
2050
2051                #Define learned function
2052                def psi(x):
2053                    return forward(x,params['params']['net'])
2054
2055                #Compute L2 train, L2 test and loss
2056                loss = params['loss']
2057                L2_train = L2error(psi(xt),u)
2058                L2_test = L2error(psi(test_data['xt']),test_data['u'])
2059                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(L2_train,6)) + ' L2 test = ' + str(round(L2_test,6))
2060
2061                #Save plot
2062                plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
2063                k = k + 1
2064
2065                #Delete
2066                del params, psi, loss, L2_train, L2_test, title_epoch
2067            #Update alive_bar
2068            bar()
2069    #Create demo video
2070    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4')
2071    if d2:
2072        os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4')
2073
2074#Demo in time for 1D PINN
2075def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
2076    """
2077    Demo video with the time evolution of a 1D PINN
2078    ----------
2079
2080    Parameters
2081    ----------
2082    test_data : dict
2083
2084        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
2085
2086    file_name : str
2087
2088        Name of the files saved during training
2089
2090    epochs : list
2091
2092        Which training epochs to plot
2093
2094    file_name_save : str
2095
2096        File prefix to save the plots and video. Default 'result_pinn_time_demo'
2097
2098    title : str
2099
2100        Title for plots
2101
2102    framerate : int
2103
2104        Framerate for video. Default 10
2105
2106    Returns
2107    -------
2108    None
2109    """
2110    #Config
2111    with open(file_name + '_config.pickle', 'rb') as file:
2112        config = pickle.load(file)
2113    train_data = config['train_data']
2114    forward = config['forward']
2115
2116    #Create folder to save plots
2117    os.makedirs(file_name_save,exist_ok = True)
2118
2119    #Plot parameters
2120    tdom = jnp.unique(test_data['xt'][:,-1])
2121    ylo = jnp.min(test_data['u'])
2122    ylo = ylo - 0.1*jnp.abs(ylo)
2123    yup = jnp.max(test_data['u'])
2124    yup = yup + 0.1*jnp.abs(yup)
2125
2126    #Open PINN for each epoch
2127    results = []
2128    upred = []
2129    for e in epochs:
2130        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
2131        results = results + [tmp]
2132        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]
2133
2134    #Create images
2135    k = 1
2136    with alive_bar(len(tdom)) as bar:
2137        for t in tdom:
2138            #Test data
2139            xt_step = test_data['xt'][test_data['xt'][:,-1] == t]
2140            u_step = test_data['u'][test_data['xt'][:,-1] == t]
2141            #Initialize plot
2142            if len(epochs) > 1:
2143                fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2))
2144            else:
2145                fig, ax = plt.subplots(1,1,figsize = (10,5))
2146            #Create
2147            index = 0
2148            if int(len(epochs)/2) > 1:
2149                for i in range(int(len(epochs)/2)):
2150                    for j in range(min(2,len(epochs))):
2151                        upred_step = upred[index][test_data['xt'][:,-1] == t]
2152                        ax[i,j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
2153                        ax[i,j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
2154                        ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
2155                        ax[i,j].set_xlabel(' ')
2156                        ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
2157                        index = index + 1
2158            elif len(epochs) > 1:
2159                for j in range(2):
2160                    upred_step = upred[index][test_data['xt'][:,-1] == t]
2161                    ax[j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
2162                    ax[j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
2163                    ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
2164                    ax[j].set_xlabel(' ')
2165                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
2166                    index = index + 1
2167            else:
2168                upred_step = upred[index][test_data['xt'][:,-1] == t]
2169                ax.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
2170                ax.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
2171                ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10)
2172                ax.set_xlabel(' ')
2173                ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
2174                index = index + 1
2175
2176
2177            #Title
2178            fig.suptitle(title + ' t = ' + str(round(t,4)))
2179            fig.tight_layout()
2180
2181            #Show and save
2182            fig = plt.gcf()
2183            fig.savefig(file_name_save + '/' + str(k) + '.png')
2184            k = k + 1
2185            plt.close()
2186            bar()
2187
2188    #Create demo video
2189    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')
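
A minimal end-to-end sketch of the post-processing workflow above. Here my_pde is a hypothetical PDE residual function, train_data and test_data are hypothetical dictionaries from jinnax.data.generate_PINNdata, training is assumed to have run with save = True so the _config and _epoch pickle files exist, and res['u'] follows the return convention shown at line 1451:

    #Train (my_pde, train_data and test_data are placeholders)
    res = train_PINN(train_data,width = [2,32,32,1],pde = my_pde,test_data = test_data,epochs = 1000,at_each = 100,save = True,file_name = 'result_pinn')

    #Error table for the fitted function
    df = process_result(test_data,res['u'],train_data,plot = False)

    #Per-epoch metrics from the saved checkpoints
    df_train = process_training(test_data,'result_pinn',at_each = 100,bolstering = False)

    #Video of the training evolution (requires ffmpeg)
    demo_train_pinn1D(test_data,'result_pinn',at_each = 100)
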
def assert_tree_float64(tree, name='object'):

Raises an error if any leaf in a pytree is not float64.
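
A minimal usage sketch (params is a stand-in for any pytree of network parameters); float64 leaves require enabling x64 mode before the arrays are created:

    import jax
    import jax.numpy as jnp
    jax.config.update('jax_enable_x64', True)

    params = {'W': jnp.ones((3,3)),'B': jnp.zeros((1,3))}
    assert_tree_float64(params,name = 'params') #Passes silently
    warn_tree_float64({'W': jnp.ones(3,dtype = jnp.float32)}) #Prints a warning instead of raising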

def warn_tree_float64(tree, name='object'):

Same as assert_tree_float64, but prints a warning instead of raising.

def check_grads_float64(grads):

Raises an error if any gradient leaf is not float64.

def assert_lbfgs_state_float64(state):

Ensures that all floating-point values inside the LBFGS state are float64. Integers and booleans are allowed.
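
A sketch of the intended check around a jaxopt LBFGS run (the quadratic objective is illustrative only; x64 mode is assumed enabled):

    from jaxopt import LBFGS
    import jax.numpy as jnp

    solver = LBFGS(fun = lambda p: jnp.sum(p ** 2),maxiter = 10)
    params = jnp.ones(3)
    state = solver.init_state(params)
    assert_lbfgs_state_float64(state) #Integer and boolean buffers are allowed
    params, state = solver.update(params,state)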

def to_float64(tree):

Casts the array leaves of a pytree to float64.

@jax.jit
def MSE(pred, true):
Squared error
Parameters
  • pred (jax.numpy.array): A JAX numpy array with the predicted values
  • true (jax.numpy.array): A JAX numpy array with the true values
Returns
  • squared error
@jax.jit
def MSE_SA(pred, true, w, q=2):

Self-adaptive squared error

Parameters
  • pred (jax.numpy.array): A JAX numpy array with the predicted values
  • true (jax.numpy.array): A JAX numpy array with the true values
  • w (jax.numpy.array): A JAX numpy array with the weights
  • q (float): Power for the weights mask. Default 2
Returns
  • self-adaptive squared error with polynomial mask
@jax.jit
def L2error(pred, true):

L2-error in percentage (%)

Parameters
  • pred (jax.numpy.array): A JAX numpy array with the predicted values
  • true (jax.numpy.array): A JAX numpy array with the true values
Returns
  • L2-error
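
A worked example on a two-point array (values computed by hand):

    pred = jnp.array([[1.0],[2.0]])
    true = jnp.array([[1.0],[3.0]])

    MSE(pred,true)                                    #[[0.],[1.]]
    jnp.mean(MSE(pred,true))                          #0.5
    MSE_SA(pred,true,w = jnp.array([[1.0],[2.0]]))    #[[0.],[4.]], since w ** 2 scales the errors
    L2error(pred,true)                                #100 * sqrt(1)/sqrt(10) ≈ 31.62 (%)
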
def idst1(x, axis=-1):
Inverse Discrete Sine Transform of type I with orthonormal scaling
Parameters
  • x (jax.numpy.array): Array to apply the transformation
  • axis (int): Axis to apply the transformation over
Returns
  • jax.numpy.array
def dstn(x, axes=None):
Discrete Sine Transform of type I with orthonormal scaling over many axes
Parameters
  • x (jax.numpy.array): Array to apply the transformation
  • axes (tuple): Axes to apply the transformation over
Returns
  • jax.numpy.array
def idstn(x, axes=None):
Inverse Discrete Sine Transform of type I with orthonormal scaling over many axes
Parameters
  • x (jax.numpy.array): Array to apply the transformation
  • axes (tuple): Axes to apply the transformation over
Returns
  • jax.numpy.array
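
With the orthonormal scaling the pair is an exact inverse, which a quick round trip confirms (NumPy input, since these wrappers call scipy):

    import numpy as np

    x = np.random.default_rng(0).standard_normal((8,8))
    np.allclose(idstn(dstn(x)),x) #True
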
def dirichlet_eigs_nd(n, L):

Eigenvalues of the discrete Dirichlet-Laplace operator in a rectangle

Parameters
  • n (list): List with the number of points in the grid in each dimension
  • L (list): List with the upper limit of the interval of the domain in each dimension; the lower limit is assumed to be zero
Returns
  • jax.numpy.array
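
For small grids the result can be checked against the closed form 2/h^2 * (1 - cos(pi*k/(n+1))); for instance, in one dimension:

    lam = dirichlet_eigs_nd([3],[1.0]) #h = 1/4: approximately [9.37, 32.0, 54.63]
    #The smallest eigenvalue approximates the continuous (pi/L)^2 ≈ 9.87
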
def generate_matern_sample(key, d=2, N=128, L=1.0, kappa=1, alpha=1, sigma=1, periodic=False):

Sample d-dimensional Matern process

Parameters
  • key (int): Seed for randomization
  • d (int): Dimension. Default 2
  • N (int): Size of grid in each dimension. Default 128
  • L (list of float): The domain of the function in coordinate i is [0, L[i]]. If a float, the same interval is repeated for all coordinates. Default 1
  • kappa,alpha,sigma (float): Parameters of the Matern process
  • periodic (logical): Whether to sample with periodic boundary conditions. periodic = False is not JAX-native and does not work with JIT
Returns
  • jax.numpy.array
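
A usage sketch; with periodic = True the sampler is pure JAX, while the default Dirichlet path goes through NumPy and scipy:

    field = generate_matern_sample(key = 0,d = 2,N = 128,L = 1.0,kappa = 10.0,alpha = 2,sigma = 1.0,periodic = True)
    field.shape #(128, 128)
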
def generate_matern_sample_batch(d=2, N=512, L=1.0, kappa=10.0, alpha=1, sigma=10, periodic=False):

Create function to sample d-dimensional Matern process

Parameters
  • d (int): Dimension. Default 2
  • N (int): Size of grid in each dimension. Default 512
  • L (list of float): The domain of the function in coordinate i is [0, L[i]]. If a float, the same interval is repeated for all coordinates. Default 1
  • kappa,alpha,sigma (float): Parameters of the Matern process
  • periodic (logical): Whether to sample with periodic boundary conditions. Periodic = False is not JAX native and does not work with JIT
Returns
  • function
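
A batched sketch; in the periodic case the argument can be any integer array, since each entry is turned into a PRNG key internally:

    sampler = generate_matern_sample_batch(d = 2,N = 128,L = 1.0,kappa = 10.0,alpha = 2,sigma = 1.0,periodic = True)
    fields = sampler(jnp.arange(8)) #(8, 128, 128)
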
def eigenf_laplace(L_vec, kmax_per_axis=None, bc='dirichlet', max_ef=None):

Create function to compute in batches the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace.

Parameters
  • L_vec (list of float): The domain of the function in coordinate i is [0, L_vec[i]]
  • kmax_per_axis (list): List with the maximum number of eigenfunctions per dimension. By default, d * max(kmax_per_axis) eigenfunctions are considered
  • bc (str): Boundary condition. 'dirichlet' or 'neumann'
  • max_ef (int): Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions
Returns
  • function to compute the eigenfunctions, and the eigenvalues of the eigenfunctions considered
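
In one dimension with Dirichlet conditions the modes reduce to sqrt(2/L) * sin(k*pi*x/L); a quick sketch:

    phi, lam = eigenf_laplace(L_vec = [1.0],kmax_per_axis = [10])
    x = jnp.linspace(0.0,1.0,50).reshape(-1,1)
    phi(x).shape #(50, 10); column k is sqrt(2) * sin((k + 1) * pi * x)
    lam[:3] #Approximately (pi^2, 4*pi^2, 9*pi^2)
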
def multiple_daff(L_vec, kmax_per_axis=None, bc='dirichlet', max_ef=None):

Create function to compute multiple frequencies of the eigenfunctions of the Dirichlet-Laplace or Neumann-Laplace operator. Each frequency set corresponds to a different domain.

Parameters
  • L_vec (list of lists of float): List with the domain of each frequency of the eigenfunctions in the form [0, L_vec[i][1]]
  • kmax_per_axis (list): List with the maximum number of eigenfunctions per dimension.
  • bc (str): Boundary condition. 'dirichlet' or 'neumann'
  • max_ef (int): Maximum number of eigenfunctions to consider among the ones with greatest eigenvalues. If None, considers d * max(kmax_per_axis) eigenfunctions
Returns
  • function to compute the daff features, and the eigenvalues of the eigenfunctions considered
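
A two-domain sketch: each entry of L_vec defines its own rectangle and the corresponding features are concatenated:

    mff, lam = multiple_daff([[1.0],[2.0]],kmax_per_axis = [5])
    x = jnp.linspace(0.0,1.0,50).reshape(-1,1)
    mff(x).shape #(50, 10): 5 Dirichlet modes on [0, 1] plus 5 on [0, 2]
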
@partial(jax.jit, static_argnums=(2,))
def multiple_cheb_fast(x, L_vec, n: int):

  • x: (N, d)
  • L_vec: (L, d) containing 'b' endpoints (a is 0) for each dimension
  • n: number of k terms (static)
  • returns: (N, L*n)

def multiple_cheb(L_vec, n: int):

Factory that closes over static n and L_vec (so shapes are constant).
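
A sketch for a single interval [0, 1]. Since phi_k = T_{k+2} - T_k and T_n(±1) = (±1)^n, every feature vanishes at both endpoints, which suits Dirichlet problems (this relies on the module's _chebyshev_T_all helper):

    mcheb = multiple_cheb(jnp.array([[1.0]]),n = 8)
    x = jnp.linspace(0.0,1.0,50).reshape(-1,1)
    feats = mcheb(x) #(50, 8); the rows at x = 0 and x = 1 are all zeros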

def fconNN(width, activation=jax.nn.tanh, key=0, mlp=False, ftype=None, fargs=None, static=None, daff=None):
Initialize fully connected neural network
Parameters
  • width (list): List with the layers width
  • activation (jax.nn activation): The activation function. Default jax.nn.tanh
  • key (int): Seed for parameters initialization. Default 0
  • mlp (logical): Whether to consider a modified multilayer perceptron. Assumes all hidden layers have the same dimension.
  • ftype (str): Type of feature transformation to use: None, 'ff', 'daff', 'daff_bias', 'cheb', 'cheb_bias'.
  • fargs (list): Arguments for the feature transformation:

    For 'ff': a list with the number of frequencies and the standard deviation of the greatest frequency.

    For 'daff' and 'daff_bias': a dictionary with a list of rectangle sizes and the type of boundary condition. If a list is given, the boundary condition is Dirichlet.

  • static (function): A static function added to the neural network output.
  • daff (list): List with the function to compute the daff features and their number. If None, computes them assuming a rectangular domain.

Returns
  • dict with initial parameters and the function for the forward pass
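
A minimal sketch of the default configuration (no feature map, plain fully connected network):

    nnet = fconNN(width = [2,32,32,1],activation = jax.nn.tanh,key = 0)
    params, forward = nnet['params'], nnet['forward']
    xt = jnp.zeros((10,2)) #10 points of (x, t)
    forward(xt,params).shape #(10, 1)
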
def get_activation(act):
Return activation function from string
Parameters
  • act (str): Name of the activation function. Default 'tanh'
Returns
  • jax.nn activation function
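
The mapping is by name, e.g.:

    get_activation('gelu') is jax.nn.gelu #True
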
def train_PINN(data, width, pde, test_data=None, epochs=100, at_each=10, activation='tanh', neumann=False, oper_neumann=False, sa=False, c={'ws': 1, 'wr': 1, 'w0': 100, 'wb': 1}, inverse=False, initial_par=None, lr=0.001, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, key=0, epoch_print=100, save=False, file_name='result_pinn', exp_decay=False, transition_steps=1000, decay_rate=0.9, mlp=False):
 823def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,mlp = False):
 824    """
 825    Train a Physics-informed Neural Network
 826    ----------
 827    Parameters
 828    ----------
 829    data : dict
 830
 831        Data generated by the jinnax.data.generate_PINNdata function
 832
 833    width : list
 834
 835        A list with the width of each layer
 836
 837    pde : function
 838
 839        The partial differential operator. Its arguments are u, x and t
 840
 841    test_data : dict, None
 842
 843        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
 844
 845    epochs : int
 846
 847        Number of training epochs. Default 100
 848
 849    at_each : int
 850
 851        Save results at epochs that are multiples of at_each. Default 10
 852
 853    activation : str
 854
 855        The name of the activation function of the neural network. Default 'tanh'
 856
 857    neumann : logical
 858
 859        Whether to consider Neumann boundary conditions
 860
 861    oper_neumann : function
 862
 863        Penalization of Neumann boundary conditions
 864
 865    sa : logical
 866
 867        Whether to consider a self-adaptive PINN
 868
 869    c : dict
 870
 871        Dictionary with the hyperparameters of the self-adaptive sigmoid mask for the sensor (ws), collocation (wr), initial (w0) and boundary (wb) points
 872
 873    inverse : logical
 874
 875        Whether to estimate parameters of the PDE
 876
 877    initial_par : jax.numpy.array
 878
 879        Initial value of the parameters of the PDE in an inverse problem
 880
 881    lr,b1,b2,eps,eps_root: float
 882
 883        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
 884
 885    key : int
 886
 887        Seed for parameters initialization. Default 0
 888
 889    epoch_print : int
 890
 891        Interval in epochs at which test errors are computed and printed. Default 100
 892
 893    save : logical
 894
 895        Whether to save the current parameters. Default False
 896
 897    file_name : str
 898
 899        File prefix to save the current parameters. Default 'result_pinn'
 900
 901    exp_decay : logical
 902
 903        Whether to consider exponential decay of learning rate. Default False
 904
 905    transition_steps : int
 906
 907        Number of steps for exponential decay. Default 1000
 908
 909    decay_rate : float
 910
 911        Rate of exponential decay. Default 0.9
 912
 913    mlp : logical
 914
 915        Whether to consider a modified multi-layer perceptron
 916
 917    Returns
 918    -------
 919    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
 920    """
 921
 922    #Initialize architecture
 923    nnet = fconNN(width,get_activation(activation),key,mlp)
 924    forward = nnet['forward']
 925
 926    #Initialize self-adaptive weights
 927    par_sa = {}
 928    if sa:
 929        #Initialize weights at random
 930        ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000)
 931        if data['sensor'] is not None:
 932            par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))})
 933        if data['initial'] is not None:
 934            par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))})
 935        if data['collocation'] is not None:
 936            par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))})
 937        if data['boundary'] is not None:
 938            par_sa.update({'wb': c['wb'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))})
 939
 940    #Store all parameters
 941    params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa}
 942
 943    #Save config file
 944    if save:
 945        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
 946
 947    #Define loss function
 948    if sa:
 949        #Define loss function
 950        @jax.jit
 951        def lf(params,x):
 952            loss = 0
 953            if x['sensor'] is not None:
 954                #Term that refers to sensor data
 955                loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws']))
 956            if x['boundary'] is not None:
 957                if neumann:
 958                    #Neumann conditions
 959                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
 960                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
 961                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb']))
 962                else:
 963                    #Term that refers to boundary data
 964                    loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb']))
 965            if x['initial'] is not None:
 966                #Term that refers to initial data
 967                loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0']))
 968            if x['collocation'] is not None:
 969                #Term that refers to collocation points
 970                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
 971                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
 972                if inverse:
 973                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr']))
 974                else:
 975                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr']))
 976            return loss
 977    else:
 978        @jax.jit
 979        def lf(params,x):
 980            loss = 0
 981            if x['sensor'] is not None:
 982                #Term that refers to sensor data
 983                loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
 984            if x['boundary'] is not None:
 985                if neumann:
 986                    #Neumann conditions
 987                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
 988                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
 989                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb))
 990                else:
 991                    #Term that refers to boundary data
 992                    loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary']))
 993            if x['initial'] is not None:
 994                #Term that refers to initial data
 995                loss = loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial']))
 996            if x['collocation'] is not None:
 997                #Term that refers to collocation points
 998                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
 999                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
1000                if inverse:
1001                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0))
1002                else:
1003                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0))
1004            return loss
1005
1006    #Initialize Adam optimizer
1007    if exp_decay:
1008        lr = optax.exponential_decay(lr,transition_steps,decay_rate)
1009    optimizer = optax.adam(lr,b1,b2,eps,eps_root)
1010    opt_state = optimizer.init(params)
1011
1012    #Define the gradient function
1013    grad_loss = jax.jit(jax.grad(lf,0))
1014
1015    #Define update function
1016    @jax.jit
1017    def update(opt_state,params,x):
1018        #Compute gradient
1019        grads = grad_loss(params,x)
1020        #Invert gradient of the self-adaptive weights
1021        if sa:
1022            for w in grads['sa']:
1023                grads['sa'][w] = - grads['sa'][w]
1024        #Calculate parameters updates
1025        updates, opt_state = optimizer.update(grads, opt_state)
1026        #Update parameters
1027        params = optax.apply_updates(params, updates)
1028        #Return state of optimizer and updated parameters
1029        return opt_state,params
1030
1031    ###Training###
1032    t0 = time.time()
1033    #Initialize alive_bar for tracking progress in the terminal
1034    with alive_bar(epochs) as bar:
1035        #For each epoch
1036        for e in range(epochs):
1037            #Update optimizer state and parameters
1038            opt_state,params = update(opt_state,params,data)
1039            #After epoch_print epochs
1040            if e % epoch_print == 0:
1041                #Compute elapsed time and current error
1042                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6))
1043                #If there is test data, compute current L2 error
1044                if test_data is not None:
1045                    #Compute L2 error
1046                    l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist()
1047                    l = l + ' L2 error: ' + str(jnp.round(l2_test,3))
1048                if inverse:
1049                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
1050                #Print
1051                print(l)
1052            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
1053                #Save current parameters
1054                pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
1055            #Update alive_bar
1056            bar()
1057    #Define estimated function
1058    def u(xt):
1059        return forward(xt,params['net'])
1060
1061    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0}

Train a Physics-informed Neural Network

Parameters
  • data (dict): Data generated by the jinnax.data.generate_PINNdata function
  • width (list): A list with the width of each layer
  • pde (function): The partial differential operator. Its arguments are u, x and t
  • test_data (dict, None): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
  • epochs (int): Number of training epochs. Default 100
  • at_each (int): Save results at epochs that are multiples of at_each. Default 10
  • activation (str): The name of the activation function of the neural network. Default 'tanh'
  • neumann (logical): Whether to consider Neumann boundary conditions
  • oper_neumann (function): Penalization of Neumann boundary conditions
  • sa (logical): Whether to consider a self-adaptive PINN
  • c (dict): Dictionary with the hyperparameters of the self-adaptive sigmoid mask for the sensor (ws), collocation (wr), initial (w0) and boundary (wb) points
  • inverse (logical): Whether to estimate parameters of the PDE
  • initial_par (jax.numpy.array): Initial value of the parameters of the PDE in an inverse problem
  • lr,b1,b2,eps,eps_root (float): Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
  • key (int): Seed for parameters initialization. Default 0
  • epoch_print (int): Interval in epochs at which test errors are computed and printed. Default 100
  • save (logical): Whether to save the current parameters. Default False
  • file_name (str): File prefix to save the current parameters. Default 'result_pinn'
  • exp_decay (logical): Whether to consider exponential decay of learning rate. Default False
  • transition_steps (int): Number of steps for exponential decay. Default 1000
  • decay_rate (float): Rate of exponential decay. Default 0.9
  • mlp (logical): Whether to consider a modified multi-layer perceptron
Returns
  • dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
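A usage sketch for the 1D heat equation u_t = u_xx with Dirichlet boundary conditions. In practice the data dictionary comes from jinnax.data.generate_PINNdata; here it is built by hand with the keys the loss function reads, and every size and hyperparameter below is illustrative:

    import jax
    import jax.numpy as jnp
    from jinnax import nn as jnn

    #Pointwise residual of the heat equation: u_t - u_xx
    def pde(u,x,t):
        def u_scalar(xi,ti):
            return u(xi.reshape((1,1)),ti.reshape((1,1)))[0,0]
        u_t = jax.vmap(jax.grad(u_scalar,argnums = 1))(x[:,0],t[:,0])
        u_xx = jax.vmap(jax.grad(jax.grad(u_scalar,argnums = 0),argnums = 0))(x[:,0],t[:,0])
        return (u_t - u_xx).reshape((-1,1))

    #Hand-built training data with the keys read by the loss function
    n = 64
    x = jax.random.uniform(jax.random.PRNGKey(0),(n,1))
    t = jax.random.uniform(jax.random.PRNGKey(1),(n,1))
    data = {'sensor': None,'usensor': None,
            'boundary': jnp.vstack([jnp.column_stack([jnp.zeros((n,1)),t]),
                                    jnp.column_stack([jnp.ones((n,1)),t])]),
            'uboundary': jnp.zeros((2*n,1)),
            'initial': jnp.column_stack([x,jnp.zeros((n,1))]),
            'uinitial': jnp.sin(jnp.pi*x),
            'collocation': jnp.column_stack([x,t])}

    result = jnn.train_PINN(data,width = [2,32,32,1],pde = pde,epochs = 2000,epoch_print = 500)
    u_hat = result['u'] #u_hat(xt) evaluates the trained network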
def train_Matern_PINN( data, width, pde, test_data=None, params=None, d=2, N=128, L=1, alpha=1, kappa=1, sigma=100, bsize=1024, resample=False, epochs=100, at_each=10, activation='tanh', neumann=False, oper_neumann=None, inverse=False, initial_par=None, lr=0.001, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, key=0, epoch_print=1, save=False, file_name='result_pinn', exp_decay=True, transition_steps=100, decay_rate=0.9, mlp=True, ftype=None, fargs=None, q=4, w=None, periodic=False, static=None, opt='LBFGS', float64=False, restart=None):
1064def train_Matern_PINN(data,width,pde,test_data = None,params = None,d = 2,N = 128,L = 1,alpha = 1,kappa = 1,sigma = 100,bsize = 1024,resample = False,epochs = 100,at_each = 10,activation = 'tanh',
1065    neumann = False,oper_neumann = None,inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 1,save = False,file_name = 'result_pinn',
1066    exp_decay = True,transition_steps = 100,decay_rate = 0.9,mlp = True,ftype = None,fargs = None,q = 4,w = None,periodic = False,static = None,opt = 'LBFGS',float64 = False,restart = None):
1067    """
1068    Train a Physics-informed Neural Network
1069    ----------
1070    Parameters
1071    ----------
1072    data : dict
1073
1074        Data generated by the jinnax.data.generate_PINNdata function
1075
1076    width : list
1077
1078        A list with the width of each layer
1079
1080    pde : function
1081
1082        The partial differential operator. Its arguments are u and x, where the time variable, if present, is the last column of x. In an inverse problem, the PDE parameters are passed as a third argument
1083
1084    test_data : dict, None
1085
1086        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
1087
1088    params : list
1089
1090        Initial parameters for the neural network. Default None to initialize randomly
1091
1092    d : int
1093
1094        Dimension of the problem including the time variable if present. Default 2
1095
1096    N : int
1097
1098        Size of grid in each dimension. Default 128
1099
1100    L : list of float
1101
1102        The domain of the function in each coordinate is [0,L[i]]. If a float, the same interval is repeated for all coordinates. Default 1
1103
1104    kappa,alpha,sigma : float
1105
1106        Parameters of the Matérn process
1107
1108    bsize : int
1109
1110        Batch size for weak norm computation. Default 1024
1111
1112    resample : logical
1113
1114        Whether to resample the test functions at each epoch
1115
1116    epochs : int
1117
1118        Number of training epochs. Default 100
1119
1120    at_each : int
1121
1122        Save results at epochs that are multiples of at_each. Default 10
1123
1124    activation : str
1125
1126        The name of the activation function of the neural network. Default 'tanh'
1127
1128    neumann : logical
1129
1130        Whether to consider Neumann boundary conditions
1131
1132    oper_neumann : function
1133
1134        Penalization of Neumann boundary conditions
1135
1136    inverse : logical
1137
1138        Whether to estimate parameters of the PDE
1139
1140    initial_par : jax.numpy.array
1141
1142        Initial value of the parameters of the PDE in an inverse problem
1143
1144    lr,b1,b2,eps,eps_root: float
1145
1146        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
1147
1148    key : int
1149
1150        Seed for parameters initialization. Default 0
1151
1152    epoch_print : int
1153
1154        Interval in epochs at which test errors are computed and printed. Default 1
1155
1156    save : logical
1157
1158        Whether to save the current parameters. Default False
1159
1160    file_name : str
1161
1162        File prefix to save the current parameters. Default 'result_pinn'
1163
1164    exp_decay : logical
1165
1166        Whether to consider exponential decay of learning rate. Default True
1167
1168    transition_steps : int
1169
1170        Number of steps for exponential decay. Default 100
1171
1172    decay_rate : float
1173
1174        Rate of exponential decay. Default 0.9
1175
1176    mlp : logical
1177
1178        Whether to consider a modified multilayer perceptron
1179
1180    ftype : str
1181
1182        Type of feature transformation to use: None, 'ff', 'daff','daff_bias', 'cheb', 'cheb_bias'.
1183
1184    fargs : list
1185
1186        Arguments for the feature transformation:
1187
1188        For 'ff': A list with the number of frequencies and the standard deviation of the largest frequency.
1189
1190        For 'daff' and 'daff_bias': A dictionary with a list of rectangle sizes and the type of boundary condition. If a list is given, the boundary condition is Dirichlet.
1191
1192    q : int
1193
1194        Power of weights mask. Default 4
1195
1196    w : dict
1197
1198        Initial weights for the self-adaptive scheme.
1199
1200    periodic : logical
1201
1202        Whether to consider periodic test functions. Default False.
1203
1204    static : function
1205
1206        A static function to sum to the neural network output.
1207
1208    opt : str
1209
1210        Optimizer. 'LBFGS' for L-BFGS; any other value uses Adam. Default 'LBFGS'.
1211
1212    float64 : logical
1213
1214        Whether to train with float64
1215
1216    restart : int
1217
1218        Number of epochs after which to restart L-BFGS. Default None for no restarts
1219
1220    Returns
1221    -------
1222    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch
1223    """
1224    #Initialize architecture
1225    nnet = fconNN(width,get_activation(activation),key,mlp,ftype,fargs,static)
1226    if float64:
1227        forward = lambda x,params: nnet['forward'](x,params).astype(jnp.float64)
1228        assert jax.config.jax_enable_x64, "JAX is NOT running in float64 mode!"
1229    else:
1230        forward = nnet['forward']
1231    if params is not None:
1232        nnet['params'] = params
1233    if float64:
1234        data = to_float64(data)
1235        if test_data is not None:
1236            test_data = to_float64(test_data)
1237
1238    #Generate from Matern process
1239    if sigma > 0:
1240        if isinstance(L,float) or isinstance(L,int):
1241            L = d*[L]
1242        #Grid for weak norm
1243        if float64:
1244            grid = [jnp.linspace(0,L[i],N,dtype = jnp.float64) for i in range(d)]
1245        else:
1246            grid = [jnp.linspace(0,L[i],N) for i in range(d)]
1247        grid = jnp.meshgrid(*grid, indexing='ij')
1248        grid = jnp.stack(grid, axis=-1).reshape((-1, d))
1249        #Set sigma
1250        if data['boundary'] is not None:
1251            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma)
1252            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
1253            if neumann:
1254                loss_boundary = oper_neumann(lambda x: forward(x,nnet['params']),data['boundary'])
1255            else:
1256                loss_boundary = jnp.mean(MSE(forward(data['boundary'],nnet['params']),data['uboundary']))
1257            output_w = pde(lambda x: forward(x,nnet['params']),grid)
1258            integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(tf)
1259            loss_res_weak = jnp.mean(integralOmega ** 2)
1260            sigma = float(jnp.sqrt(loss_boundary/loss_res_weak).tolist())
1261            del gen
1262            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
1263            tf = sigma*tf
1264        else:
1265            gen = generate_matern_sample_batch(d = d,N = N,L = L,kappa = kappa,alpha = alpha,sigma = sigma,periodic = periodic)
1266            tf = gen(jax.random.split(jax.random.PRNGKey(key + 1),(bsize,))[:,0])
1267        if float64 and tf is not None:
1268            tf = to_float64(tf)
1269            grid = to_float64(grid)
1270    else:
1271        tf = None
1272        grid = None
1273
1274    #Define loss function
1275    @jax.jit
1276    def lf_each(params,x,k,tf,grid):
1277        if sigma > 0:
1278            #Term that refers to weak loss
1279            if resample:
1280                test_functions = to_float64(gen(jax.random.split(jax.random.PRNGKey(k[0]),(bsize,))[:,0]))
1281            else:
1282                test_functions = tf
1283        loss_sensor = loss_boundary = loss_initial = loss_res = loss_res_weak = 0
1284        if x['sensor'] is not None:
1285            #Term that refers to sensor data
1286            loss_sensor = MSE(forward(x['sensor'],params['net']),x['usensor'])
1287        if x['boundary'] is not None:
1288            if neumann:
1289                #Neumann conditions
1290                loss_boundary = oper_neumann(lambda x: forward(x,params['net']),x['boundary'])
1291            else:
1292                #Term that refers to boundary data
1293                loss_boundary = MSE(forward(x['boundary'],params['net']),x['uboundary'])
1294        if x['initial'] is not None:
1295            #Term that refers to initial data
1296            loss_initial = MSE(forward(x['initial'],params['net']),x['uinitial'])
1297        if x['collocation'] is not None and sigma == 0:
1298            if inverse:
1299                output = pde(lambda x: forward(x,params['net']),x['collocation'],params['inverse'])
1300                loss_res = MSE(output,0)
1301            else:
1302                output = pde(lambda x: forward(x,params['net']),x['collocation'])
1303                loss_res = MSE(output,0)
1304        if sigma > 0:
1305            #Term that refers to weak loss
1306            if inverse:
1307                output_w = pde(lambda x: forward(x,params['net']),grid,params['inverse'])
1308                integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
1309                loss_res_weak = jnp.mean(integralOmega ** 2)
1310            else:
1311                output_w = pde(lambda x: forward(x,params['net']),grid)
1312                integralOmega = jax.vmap(lambda psi: jnp.mean(psi*output_w.reshape((N,) * d)))(test_functions)
1313                loss_res_weak = jnp.mean(integralOmega ** 2)
1314        return {'ls': loss_sensor,'lb': loss_boundary,'li': loss_initial,'lc': loss_res,'lc_weak': loss_res_weak}
1315
1316    @jax.jit
1317    def lf(params,x,k,tf,grid):
1318        l = lf_each(params,x,k,tf,grid)
1319        if opt != 'LBFGS':
1320            loss = jnp.mean((params['w']['ws'] ** q)*l['ls']) + jnp.mean((params['w']['wb'] ** q)*l['lb']) + jnp.mean((params['w']['wi'] ** q)*l['li']) + jnp.mean((params['w']['wc'] ** q)*l['lc']) + (params['w']['wc_weak'] ** q)*l['lc_weak']
1321            return loss
1322        else:
1323            loss = jnp.mean((w['ws'] ** q)*l['ls']) + jnp.mean((w['wb'] ** q)*l['lb']) + jnp.mean((w['wi'] ** q)*l['li']) + jnp.mean((w['wc'] ** q)*l['lc']) + (w['wc_weak'] ** q)*l['lc_weak']
1324            l2 = None
1325            if test_data is not None:
1326                l2 = L2error(forward(test_data['sensor'],params['net']),test_data['usensor'])
1327            return loss,{'loss': loss,'l2': l2}
1328
1329    #Initialize self-adaptive weights
1330    if float64:
1331        typ = jnp.float64
1332    else:
1333        typ = jnp.float32
1334    if w is None:
1335        w = {'ws': jnp.array(1.0),'wb': jnp.array(1.0),'wi': jnp.array(1.0),'wc': jnp.array(1.0),'wc_weak': jnp.array(1.0)}
1336    if q != 0 and opt != 'LBFGS':
1337        if data['sensor'] is not None:
1338            w['ws'] = w['ws'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+1),(data['sensor'].shape[0],1),dtype = typ)
1339        if data['boundary'] is not None:
1340            w['wb'] = w['wb'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+2),(data['boundary'].shape[0],1),dtype = typ)
1341        if data['initial'] is not None:
1342            w['wi'] = w['wi'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+3),(data['initial'].shape[0],1),dtype = typ)
1343        if data['collocation'] is not None:
1344            w['wc'] = w['wc'] + 0.05*jax.random.normal(jax.random.PRNGKey(key+4),(data['collocation'].shape[0],1),dtype = typ)
1345
1346    #Store all parameters
1347    if opt != 'LBFGS':
1348        params = {'net': nnet['params'],'inverse': initial_par,'w': w}
1349    else:
1350        params = {'net': nnet['params'],'inverse': initial_par}
1351    if float64:
1352        params = to_float64(params)
1353
1354    #Save config file
1355    if save:
1356        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
1357
1358    #Initialize optimizer (Adam or L-BFGS)
1359    if opt != 'LBFGS':
1360        print('--------- GRADIENT DESCENT OPTIMIZER ---------')
1361        if exp_decay:
1362            lr = optax.exponential_decay(lr,transition_steps,decay_rate)
1363        optimizer = optax.adam(lr,b1,b2,eps,eps_root)
1364        opt_state = optimizer.init(params)
1365
1366        #Define the gradient function
1367        grad_loss = jax.jit(jax.grad(lf,0))
1368
1369        #Define update function
1370        @jax.jit
1371        def update(opt_state,params,x,k,tf,grid):
1372            #Compute gradient
1373            grads = grad_loss(params,x,k,tf,grid)
1374            #Calculate parameters updates
1375            updates, opt_state = optimizer.update(grads, opt_state)
1376            #Update parameters
1377            if q != 0:
1378                updates = {**updates, 'w': jax.tree_util.tree_map(lambda x: -x, updates['w'])} #Change signs of weights
1379            params = optax.apply_updates(params, updates)
1380            #Return state of optimizer and updated parameters
1381            return opt_state,params
1382    else:
1383        print('--------- LBFGS OPTIMIZER ---------')
1384        @jax.jit
1385        def loss_LBFGS(params):
1386            return lf(params,data,key + 234,tf,grid)
1387        solver = LBFGS(fun = loss_LBFGS,has_aux = True,maxiter = epochs,tol = 1e-9,verbose = False,linesearch = 'zoom',history_size = 200)  # linesearch='zoom' by default
1388        state = solver.init_state(params)
1389        if float64:
1390            assert_lbfgs_state_float64(state)
1391
1392    ###Training###
1393    t0 = time.time()
1394    k = jax.random.split(jax.random.PRNGKey(key+234),(epochs,))
1395    sloss = []
1396    sL2 = []
1397    stime = []
1398    #Initialize alive_bar for tracking progress in the terminal
1399    with alive_bar(epochs) as bar:
1400        #For each epoch
1401        for e in range(epochs):
1402            if opt != 'LBFGS':
1403                if float64 and e < 10:
1404                    assert_tree_float64(params, name="params (before update)")
1405                    grads = grad_loss(params, data, k[e,:], tf, grid)
1406                    check_grads_float64(grads)
1407                #Update optimizer state and parameters
1408                opt_state,params = update(opt_state,params,data,k[e,:],tf,grid)
1409                sloss.append(lf(params,data,k[e,:],tf,grid))
1410                if test_data is not None:
1411                    sL2.append(L2error(forward(test_data['sensor'],params['net']),test_data['usensor']))
1412                if float64 and e < 10:
1413                    assert_tree_float64(params, name="params (after update)")
1414                    grads = grad_loss(params, data, k[e,:], tf, grid)
1415                    check_grads_float64(grads)
1416            else:
1417                if float64 and e < 10:
1418                    assert_tree_float64(params, name="params (before LBFGS step)")
1419                params, state = solver.update(params, state)
1420                if float64 and e < 10:
1421                    assert_tree_float64(params, name="params (after LBFGS step)")
1422                    assert_lbfgs_state_float64(state)
1423                sL2.append(state.aux["l2"])
1424                sloss.append(state.aux["loss"])
1425                if float64 and e < 10:
1426                    assert_tree_float64(params, name="params (after update)")
1427                if restart is not None:
1428                    if (e + 1) % restart == 0:
1429                        state = solver.init_state(params)
1430            stime.append(time.time() - t0)
1431            #After epoch_print epochs
1432            if e % epoch_print == 0:
1433                #Compute elapsed time and current error
1434                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(sloss[-1],6))
1435                #If there is test data, compute current L2 error
1436                if test_data is not None:
1437                    #Compute L2 error
1438                    l = l + ' L2 error: ' + str(jnp.round(sL2[-1],6))
1439                if inverse:
1440                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
1441                #Print
1442                print(l)
1443            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
1444                #Save current parameters
1445                pickle.dump({'params': params,'width': width,'time': stime,'loss': sloss,'L2error': sL2},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
1446            #Update alive_bar
1447            bar()
1448    #Define estimated function
1449    def u(xt):
1450        return forward(xt,params['net'])
1451
1452    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0,'loss_each': lf_each(params,data,[key + 100],tf,grid),'loss': sloss,'L2error': sL2}

Train a Physics-informed Neural Network

Parameters
  • data (dict): Data generated by the jinnax.data.generate_PINNdata function
  • width (list): A list with the width of each layer
  • pde (function): The partial differential operator. Its arguments are u and x, where the time variable, if present, is the last column of x. In an inverse problem, the PDE parameters are passed as a third argument
  • test_data (dict, None): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
  • params (list): Initial parameters for the neural network. Default None to initialize randomly
  • d (int): Dimension of the problem including the time variable if present. Default 2
  • N (int): Size of grid in each dimension. Default 128
  • L (list of float): The domain of the function in each coordinate is [0,L[i]]. If a float, the same interval is repeated for all coordinates. Default 1
  • kappa,alpha,sigma (float): Parameters of the Matérn process
  • bsize (int): Batch size for weak norm computation. Default 1024
  • resample (logical): Whether to resample the test functions at each epoch
  • epochs (int): Number of training epochs. Default 100
  • at_each (int): Save results at epochs that are multiples of at_each. Default 10
  • activation (str): The name of the activation function of the neural network. Default 'tanh'
  • neumann (logical): Whether to consider Neumann boundary conditions
  • oper_neumann (function): Penalization of Neumann boundary conditions
  • inverse (logical): Whether to estimate parameters of the PDE
  • initial_par (jax.numpy.array): Initial value of the parameters of the PDE in an inverse problem
  • lr,b1,b2,eps,eps_root (float): Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
  • key (int): Seed for parameters initialization. Default 0
  • epoch_print (int): Interval in epochs at which test errors are computed and printed. Default 1
  • save (logical): Whether to save the current parameters. Default False
  • file_name (str): File prefix to save the current parameters. Default 'result_pinn'
  • exp_decay (logical): Whether to consider exponential decay of learning rate. Default True
  • transition_steps (int): Number of steps for exponential decay. Default 100
  • decay_rate (float): Rate of exponential decay. Default 0.9
  • mlp (logical): Whether to consider a modified multilayer perceptron
  • ftype (str): Type of feature transformation to use: None, 'ff', 'daff', 'daff_bias', 'cheb', 'cheb_bias'.
  • fargs (list): Arguments for the feature transformation:

    For 'ff': A list with the number of frequencies and the standard deviation of the largest frequency.

    For 'daff' and 'daff_bias': A dictionary with a list of rectangle sizes and the type of boundary condition. If a list is given, the boundary condition is Dirichlet.

  • q (int): Power of weights mask. Default 4
  • w (dict): Initial weights for the self-adaptive scheme.
  • periodic (logical): Whether to consider periodic test functions. Default False.
  • static (function): A static function to sum to the neural network output.
  • opt (str): Optimizer. 'LBFGS' for L-BFGS; any other value uses Adam. Default 'LBFGS'.
  • float64 (logical): Whether to train with float64
  • restart (int): Number of epochs after which to restart L-BFGS. Default None for no restarts
Returns
  • dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the loss, L2error and training time at each epoch
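When sigma > 0 the collocation term is replaced by the weak residual computed in lf_each: for K = bsize test functions psi_1,...,psi_K sampled from the Matérn process on the N^d grid,

    loss_weak = (1/K) * sum_k ( mean_{x in grid} psi_k(x) * F[u](x) )^2,

where F is the pde operator. When boundary data is available, sigma is first rescaled so that the initial weak residual equals the initial boundary loss. A usage sketch for the 1D Poisson problem u'' = -pi^2 sin(pi x) on [0,1] with zero Dirichlet boundary (every hyperparameter value below is illustrative):

    import jax
    import jax.numpy as jnp
    from jinnax import nn as jnn

    #Pointwise residual of u'' + pi^2 sin(pi x) = 0; exact solution u(x) = sin(pi x)
    def pde(u,x):
        def u_scalar(xi):
            return u(xi.reshape((1,1)))[0,0]
        u_xx = jax.vmap(jax.grad(jax.grad(u_scalar)))(x[:,0])
        return (u_xx + (jnp.pi ** 2)*jnp.sin(jnp.pi*x[:,0])).reshape((-1,1))

    #Only boundary data; the residual is enforced through the weak loss on the grid
    data = {'sensor': None,'usensor': None,
            'boundary': jnp.array([[0.0],[1.0]]),'uboundary': jnp.zeros((2,1)),
            'initial': None,'uinitial': None,'collocation': None}

    result = jnn.train_Matern_PINN(data,width = [1,32,32,1],pde = pde,d = 1,N = 128,
                                   L = 1,alpha = 2,kappa = 10,sigma = 100,epochs = 200)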
def process_result( test_data, fit, train_data, plot=True, plot_test=True, times=5, d2=True, save=False, show=True, file_name='result_pinn', print_res=True, p=1):
1456def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1):
1457    """
1458    Process the results of a Physics-informed Neural Network
1459    ----------
1460
1461    Parameters
1462    ----------
1463    test_data : dict
1464
1465        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
1466
1467    fit : function
1468
1469        The fitted function
1470
1471    train_data : dict
1472
1473        Training data generated by the jinnax.data.generate_PINNdata
1474
1475    plot : logical
1476
1477        Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True
1478
1479    plot_test : logical
1480
1481        Whether to plot the test data. Default True
1482
1483    times : int
1484
1485        Number of points along the time interval to plot. Default 5
1486
1487    d2 : logical
1488
1489        Whether to make a 2D plot when the spatial dimension is one. Default True
1490
1491    save : logical
1492
1493        Whether to save the plots. Default False
1494
1495    show : logical
1496
1497        Whether to show the plots. Default True
1498
1499    file_name : str
1500
1501        File prefix to save the plots. Default 'result_pinn'
1502
1503    print_res : logical
1504
1505        Whether to print the L2 error. Default True
1506
1507    p : int
1508
1509        Output dimension. Default 1
1510
1511    Returns
1512    -------
1513    pandas data frame with L2 and MSE errors
1514    """
1515
1516    #Dimension
1517    d = test_data['xt'].shape[1] - 1
1518
1519    #Number of plots multiple of 5
1520    times = 5 * round(times/5)
1521
1522    #Data
1523    td = get_train_data(train_data)
1524    xt_train = td['x']
1525    u_train = td['y']
1526    upred_train = fit(xt_train)
1527    upred_test = fit(test_data['xt'])
1528
1529    #Results
1530    l2_error_test = L2error(upred_test,test_data['u']).tolist()
1531    MSE_test = jnp.mean(MSE(upred_test,test_data['u'])).tolist()
1532    l2_error_train = L2error(upred_train,u_train).tolist()
1533    MSE_train = jnp.mean(MSE(upred_train,u_train)).tolist()
1534
1535    df = pd.DataFrame(np.array([l2_error_test,MSE_test,l2_error_train,MSE_train]).reshape((1,4)),
1536        columns=['l2_error_test','MSE_test','l2_error_train','MSE_train'])
1537    if print_res:
1538        print('L2 error test: ' + str(jnp.round(l2_error_test,6)) + ' L2 error train: ' + str(jnp.round(l2_error_train,6)) + ' MSE error test: ' + str(jnp.round(MSE_test,6)) + ' MSE error train: ' + str(jnp.round(MSE_train,6)) )
1539
1540    #Plots
1541    if d == 1 and p == 1 and plot:
1542        plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name)
1543    elif p == 2 and plot:
1544        plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test)
1545
1546    return df

Process the results of a Physics-informed Neural Network

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • fit (function): The fitted function
  • train_data (dict): Training data generated by the jinnax.data.generate_PINNdata
  • plot (logical): Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True
  • plot_test (logical): Whether to plot the test data. Default True
  • times (int): Number of points along the time interval to plot. Default 5
  • d2 (logical): Whether to make a 2D plot when the spatial dimension is one. Default True
  • save (logical): Whether to save the plots. Default False
  • show (logical): Whether to show the plots. Default True
  • file_name (str): File prefix to save the plots. Default 'result_pinn'
  • print_res (logical): Whether to print the L2 error. Default True
  • p (int): Output dimension. Default 1
Returns
  • pandas data frame with L2 and MSE errors
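A sketch that continues the train_PINN example above (result, data and jnn are the objects defined there); the exact solution of that heat problem is u(x,t) = exp(-pi^2 t) sin(pi x):

    import jax.numpy as jnp

    #Test data on a regular grid, with the exact solution as reference
    xg,tg = jnp.meshgrid(jnp.linspace(0,1,101),jnp.linspace(0,1,101),indexing = 'ij')
    xt_test = jnp.column_stack([xg.ravel(),tg.ravel()])
    u_test = (jnp.exp(-(jnp.pi ** 2)*xt_test[:,1])*jnp.sin(jnp.pi*xt_test[:,0])).reshape((-1,1))
    test_data = {'xt': xt_test,'u': u_test}

    df = jnn.process_result(test_data,fit = result['u'],train_data = data,plot = False)
    print(df[['l2_error_test','MSE_test']])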
def plot_pinn1D( times, xt, u, upred, d2=True, save=False, show=True, file_name='result_pinn', title_1d='', title_2d=''):
1549def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
1550    """
1551    Plot the prediction of a 1D PINN
1552    ----------
1553
1554    Parameters
1555    ----------
1556    times : int
1557
1558        Number of points along the time interval to plot. Default 5
1559
1560    xt : jax.numpy.array
1561
1562        Test data xt array
1563
1564    u : jax.numpy.array
1565
1566        Test data u(x,t) array
1567
1568    upred : jax.numpy.array
1569
1570        Predicted upred(x,t) array on test data
1571
1572    d2 : logical
1573
1574        Whether to make a 2D plot. Default True
1575
1576    save : logical
1577
1578        Whether to save the plots. Default False
1579
1580    show : logical
1581
1582        Whether to show the plots. Default True
1583
1584    file_name : str
1585
1586        File prefix to save the plots. Default 'result_pinn'
1587
1588    title_1d : str
1589
1590        Title of 1D plot
1591
1592    title_2d : str
1593
1594        Title of 2D plot
1595
1596    Returns
1597    -------
1598    None
1599    """
1600    #Initialize
1601    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
1602    tlo = jnp.min(xt[:,-1])
1603    tup = jnp.max(xt[:,-1])
1604    ylo = jnp.min(u)
1605    ylo = ylo - 0.1*jnp.abs(ylo)
1606    yup = jnp.max(u)
1607    yup = yup + 0.1*jnp.abs(yup)
1608    k = 0
1609    t_values = np.linspace(tlo,tup,times)
1610
1611    #Create
1612    for i in range(int(times/5)):
1613        for j in range(5):
1614            if k < len(t_values):
1615                t = t_values[k]
1616                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
1617                x_plot = xt[xt[:,-1] == t,:-1]
1618                y_plot = upred[xt[:,-1] == t,:]
1619                u_plot = u[xt[:,-1] == t,:]
1620                if int(times/5) > 1:
1621                    ax[i,j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
1622                    ax[i,j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
1623                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
1624                    ax[i,j].set_xlabel(' ')
1625                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
1626                else:
1627                    ax[j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
1628                    ax[j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
1629                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
1630                    ax[j].set_xlabel(' ')
1631                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
1632                k = k + 1
1633
1634    #Title
1635    fig.suptitle(title_1d)
1636    fig.tight_layout()
1637
1638    #Show and save
1639    fig = plt.gcf()
1640    if show:
1641        plt.show()
1642    if save:
1643        fig.savefig(file_name + '_slices.png')
1644    plt.close()
1645
1646    #2d plot
1647    if d2:
1648        #Initialize
1649        fig, ax = plt.subplots(1,2)
1650        l1 = jnp.unique(xt[:,-1]).shape[0]
1651        l2 = jnp.unique(xt[:,0]).shape[0]
1652
1653        #Create
1654        ax[0].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
1655        ax[0].set_title('Exact')
1656        ax[1].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
1657        ax[1].set_title('Predicted')
1658
1659        #Title
1660        fig.suptitle(title_2d)
1661        fig.tight_layout()
1662
1663        #Show and save
1664        fig = plt.gcf()
1665        if show:
1666            plt.show()
1667        if save:
1668            fig.savefig(file_name + '_2d.png')
1669        plt.close()

Plot the prediction of a 1D PINN

Parameters
  • times (int): Number of points along the time interval to plot. Default 5
  • xt (jax.numpy.array): Test data xt array
  • u (jax.numpy.array): Test data u(x,t) array
  • upred (jax.numpy.array): Predicted upred(x,t) array on test data
  • d2 (logical): Whether to make a 2D plot. Default True
  • save (logical): Whether to save the plots. Default False
  • show (logical): Whether to show the plots. Default True
  • file_name (str): File prefix to save the plots. Default 'result_pinn'
  • title_1d (str): Title of 1D plot
  • title_2d (str): Title of 2D plot
Returns
  • None
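Continuing the sketches above:

    upred = result['u'](test_data['xt'])
    jnn.plot_pinn1D(5,test_data['xt'],test_data['u'],upred,d2 = True,
                    save = True,show = False,file_name = 'heat1d')
    #writes heat1d_slices.png and heat1d_2d.png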
def plot_pinn_out2D( times, xt, u, upred, save=False, show=True, file_name='result_pinn', title='', plot_test=True):
1672def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
1673    """
1674    Plot the prediction of a PINN with 2D output
1675    ----------
1676    Parameters
1677    ----------
1678    times : int
1679
1680        Number of points along the time interval to plot. Default 5
1681
1682    xt : jax.numpy.array
1683
1684        Test data xt array
1685
1686    u : jax.numpy.array
1687
1688        Test data u(x,t) array
1689
1690    upred : jax.numpy.array
1691
1692        Predicted upred(x,t) array on test data
1693
1694    save : logical
1695
1696        Whether to save the plots. Default False
1697
1698    show : logical
1699
1700        Whether to show the plots. Default True
1701
1702    file_name : str
1703
1704        File prefix to save the plots. Default 'result_pinn'
1705
1706    title : str
1707
1708        Title of plot
1709
1710    plot_test : logical
1711
1712        Whether to plot the test data. Default True
1713
1714    Returns
1715    -------
1716    None
1717    """
1718    #Initialize
1719    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
1720    tlo = jnp.min(xt[:,-1])
1721    tup = jnp.max(xt[:,-1])
1722    xlo = jnp.min(u[:,0])
1723    xlo = xlo - 0.1*jnp.abs(xlo)
1724    xup = jnp.max(u[:,0])
1725    xup = xup + 0.1*jnp.abs(xup)
1726    ylo = jnp.min(u[:,1])
1727    ylo = ylo - 0.1*jnp.abs(ylo)
1728    yup = jnp.max(u[:,1])
1729    yup = yup + 0.1*jnp.abs(yup)
1730    k = 0
1731    t_values = np.linspace(tlo,tup,times)
1732
1733    #Create
1734    for i in range(int(times/5)):
1735        for j in range(5):
1736            if k < len(t_values):
1737                t = t_values[k]
1738                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
1739                xpred_plot = upred[xt[:,-1] == t,0]
1740                ypred_plot = upred[xt[:,-1] == t,1]
1741                if plot_test:
1742                    x_plot = u[xt[:,-1] == t,0]
1743                    y_plot = u[xt[:,-1] == t,1]
1744                if int(times/5) > 1:
1745                    if plot_test:
1746                        ax[i,j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
1747                    ax[i,j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
1748                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
1749                    ax[i,j].set_xlabel(' ')
1750                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
1751                else:
1752                    if plot_test:
1753                        ax[j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
1754                    ax[j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
1755                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
1756                    ax[j].set_xlabel(' ')
1757                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
1758                k = k + 1
1759
1760    #Title
1761    fig.suptitle(title)
1762    fig.tight_layout()
1763
1764    #Show and save
1765    fig = plt.gcf()
1766    if show:
1767        plt.show()
1768    if save:
1769        fig.savefig(file_name + '_slices.png')
1770    plt.close()

Plot the prediction of a PINN with 2D output

Parameters
  • times (int): Number of points along the time interval to plot. Default 5
  • xt (jax.numpy.array): Test data xt array
  • u (jax.numpy.array): Test data u(x,t) array
  • upred (jax.numpy.array): Predicted upred(x,t) array on test data
  • save (logical): Whether to save the plots. Default False
  • show (logical): Whether to show the plots. Default True
  • file_name (str): File prefix to save the plots. Default 'result_pinn'
  • title (str): Title of plot
  • plot_test (logical): Whether to plot the test data. Default True
Returns
  • None
def get_train_data(train_data):
1773def get_train_data(train_data):
1774    """
1775    Process training sample
1776    ----------
1777
1778    Parameters
1779    ----------
1780    train_data : dict
1781
1782        A dictionary with training data generated by the jinnax.data.generate_PINNdata function
1783
1784    Returns
1785    -------
1786    dict with the processed training data
1787    """
1788    xdata = None
1789    ydata = None
1790    xydata = None
1791    if train_data['sensor'] is not None:
1792        sensor_sample = train_data['sensor'].shape[0]
1793        xdata = train_data['sensor']
1794        ydata = train_data['usensor']
1795        xydata = jnp.column_stack((train_data['sensor'],train_data['usensor']))
1796    else:
1797        sensor_sample = 0
1798    if train_data['boundary'] is not None:
1799        boundary_sample = train_data['boundary'].shape[0]
1800        if xdata is not None:
1801            xdata = jnp.vstack((xdata,train_data['boundary']))
1802            ydata = jnp.vstack((ydata,train_data['uboundary']))
1803            xydata = jnp.vstack((xydata,jnp.column_stack((train_data['boundary'],train_data['uboundary']))))
1804        else:
1805            xdata = train_data['boundary']
1806            ydata = train_data['uboundary']
1807            xydata = jnp.column_stack((train_data['boundary'],train_data['uboundary']))
1808    else:
1809        boundary_sample = 0
1810    if train_data['initial'] is not None:
1811        initial_sample = train_data['initial'].shape[0]
1812        if xdata is not None:
1813            xdata = jnp.vstack((xdata,train_data['initial']))
1814            ydata = jnp.vstack((ydata,train_data['uinitial']))
1815            xydata = jnp.vstack((xydata,jnp.column_stack((train_data['initial'],train_data['uinitial']))))
1816        else:
1817            xdata = train_data['initial']
1818            ydata = train_data['uinitial']
1819            xydata = jnp.column_stack((train_data['initial'],train_data['uinitial']))
1820    else:
1821        initial_sample = 0
1822    if train_data['collocation'] is not None:
1823        collocation_sample = train_data['collocation'].shape[0]
1824    else:
1825        collocation_sample = 0
1826
1827    return {'xy': xydata,'x': xdata,'y': ydata,'sensor_sample': sensor_sample,'boundary_sample': boundary_sample,'initial_sample': initial_sample,'collocation_sample': collocation_sample}
Process training sample
Parameters
  • train_data (dict): A dictionary with training data generated by the jinnax.data.generate_PINNdata function
Returns
  • dict with the processed training data
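A short sketch, reusing the hand-built data dictionary from the train_PINN example:

    td = jnn.get_train_data(data)
    print(td['x'].shape,td['y'].shape) #stacked boundary and initial inputs/targets
    print(td['boundary_sample'],td['initial_sample'],td['collocation_sample'])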
def process_training( test_data, file_name, at_each=100, bolstering=True, mc_sample=10000, save=False, file_name_save='result_pinn', key=0, ec=1e-06, lamb=1):
1830def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
1831    """
1832    Process the training of a Physics-informed Neural Network
1833    ----------
1834
1835    Parameters
1836    ----------
1837    test_data : dict
1838
1839        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
1840
1841    file_name : str
1842
1843        Name of the files saved during training
1844
1845    at_each : int
1846
1847        Compute results at epochs that are multiples of at_each. Default 100
1848
1849    bolstering : logical
1850
1851        Whether to compute bolstering mean square error. Default True
1852
1853    mc_sample : int
1854
1855        Number of samples for Monte Carlo integration in bolstering. Default 10000
1856
1857    save : logical
1858
1859        Whether to save the training results. Default False
1860
1861    file_name_save : str
1862
1863        File prefix to save the plots and the L2 error. Default 'result_pinn'
1864
1865    key : int
1866
1867        Key for random samples in bolstering. Default 0
1868
1869    ec : float
1870
1871        Stopping criterion for the EM algorithm in bolstering. Default 1e-6
1872
1873    lamb : float
1874
1875        Hyperparameter of EM algorithm in bolstering. Default 1
1876
1877    Returns
1878    -------
1879    pandas data frame with training results
1880    """
1881    #Config
1882    config = pickle.load(open(file_name + '_config.pickle', 'rb'))
1883    epochs = config['epochs']
1884    train_data = config['train_data']
1885    forward = config['forward']
1886
1887    #Get train data
1888    td = get_train_data(train_data)
1889    xydata = td['xy']
1890    xdata = td['x']
1891    ydata = td['y']
1892    sensor_sample = td['sensor_sample']
1893    boundary_sample = td['boundary_sample']
1894    initial_sample = td['initial_sample']
1895    collocation_sample = td['collocation_sample']
1896
1897    #Generate keys
1898    if bolstering:
1899        keys = jax.random.split(jax.random.PRNGKey(key),epochs)
1900
1901    #Initialize loss
1902    train_mse = []
1903    test_mse = []
1904    train_L2 = []
1905    test_L2 = []
1906    bolstX = []
1907    bolstXY = []
1908    loss = []
1909    time = []
1910    ep = []
1911
1912    #Process training
1913    with alive_bar(epochs) as bar:
1914        for e in range(epochs):
1915            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
1916                ep = ep + [e]
1917
1918                #Read parameters
1919                params = pickle.load(open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb'))
1920
1921                #Time
1922                time = time + [params['time']]
1923
1924                #Define learned function
1925                def psi(x):
1926                    return forward(x,params['params']['net'])
1927
1928                #Train MSE and L2
1929                if xdata is not None:
1930                    train_mse = train_mse + [jnp.mean(MSE(psi(xdata),ydata)).tolist()]
1931                    train_L2 = train_L2 + [L2error(psi(xdata),ydata).tolist()]
1932                else:
1933                    train_mse = train_mse + [None]
1934                    train_L2 = train_L2 + [None]
1935
1936                #Test MSE and L2
1937                test_mse = test_mse + [jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist()]
1938                test_L2 = test_L2 + [L2error(psi(test_data['xt']),test_data['u']).tolist()]
1939
1940                #Bolstering
1941                if bolstering:
1942                    bX = []
1943                    bXY = []
1944                    for method in ['chi','mm','mpe']:
1945                        kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
1946                        kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
1947                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
1948                        bXY = bXY + [gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist()]
1949                    for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]:
1950                        kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
1951                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
1952                    bolstX = bolstX + [bX]
1953                    bolstXY = bolstXY + [bXY]
1954                else:
1955                    bolstX = bolstX + [None]
1956                    bolstXY = bolstXY + [None]
1957
1958                #Loss
1959                loss = loss + [params['loss'].tolist()]
1960
1961                #Delete
1962                del params, psi
1963            #Update alive_bar
1964            bar()
1965
1966    #Bolstering results
1967    if bolstering:
1968        bolstX = jnp.array(bolstX)
1969        bolstXY = jnp.array(bolstXY)
1970
1971    #Create data frame
1972    if bolstering:
1973        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
1974            train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]),
1975            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4'])
1976    else:
1977        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
1978            train_mse,test_mse,train_L2,test_L2]),
1979            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2'])
1980    if save:
1981        df.to_csv(file_name_save + '.csv',index = False)
1982
1983    return df

Process the training of a Physics-informed Neural Network

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • at_each (int): Compute results for epochs multiple of at_each. Default 100
  • bolstering (logical): Whether to compute bolstering mean square error. Default True
  • mc_sample (int): Number of sample for Monte Carlo integration in bolstering. Default 10000
  • save (logical): Whether to save the training results. Default False
  • file_name_save (str): File prefix to save the plots and the L2 error. Default 'result_pinn'
  • key (int): Key for random samples in bolstering. Default 0
  • ec (float): Stopping criteria error for EM algorithm in bolstering. Default 1e-6
  • lamb (float): Hyperparameter of EM algorithm in bolstering. Default 1
Returns
  • pandas data frame with training results
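Usage note: when save is True the data frame is also written to file_name_save + '.csv'. A minimal follow-up sketch, assuming the function was run with save = True and the default file_name_save = 'result_pinn'; the column names are the ones documented in the source above:

    import pandas as pd
    import matplotlib.pyplot as plt

    #Load the training results saved by the function above
    df = pd.read_csv('result_pinn.csv')

    #Plot the train and test L2 errors along the training epochs
    fig, ax = plt.subplots(figsize = (8,4))
    ax.plot(df['epoch'],df['train_L2'],label = 'train L2')
    ax.plot(df['epoch'],df['test_L2'],label = 'test L2')
    ax.set_xlabel('epoch')
    ax.set_ylabel('L2 error')
    ax.legend()
    fig.savefig('result_pinn_L2.png')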
def demo_train_pinn1D(test_data, file_name, at_each=100, times=5, d2=True, file_name_save='result_pinn_demo', title='', framerate=10):
def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
    """
    Demo video with the training of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to also make a demo video of the 2D plot. Default True

    file_name_save : str

        File prefix to save the plots and videos. Default 'result_pinn_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xt = td['x']
    u = td['y']

    #Create folder to save plots
    os.makedirs(file_name_save,exist_ok = True)

    #Create images
    k = 1
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if e % at_each == 0 or e == epochs - 1:
                #Read parameters
                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Compute L2 train, L2 test and loss
                loss = params['loss']
                L2_train = L2error(psi(xt),u)
                L2_test = L2error(psi(test_data['xt']),test_data['u'])
                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(float(L2_train),6)) + ' L2 test = ' + str(round(float(L2_test),6))

                #Save plot
                plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
                k = k + 1

                #Delete
                del params, psi, loss, L2_train, L2_test, title_epoch
            #Update alive_bar
            bar()
    #Create demo video
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4')
    if d2:
        os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4')

Demo video with the training of a 1D PINN

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • at_each (int): Compute results for epochs multiple of at_each. Default 100
  • times (int): Number of points along the time interval to plot. Default 5
  • d2 (logical): Whether to also make a demo video of the 2D plot. Default True
  • file_name_save (str): File prefix to save the plots and videos. Default 'result_pinn_demo'
  • title (str): Title for plots
  • framerate (int): Framerate for video. Default 10
Returns
  • None (a usage sketch follows below)
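A minimal usage sketch, assuming hypothetical checkpoints saved under the prefix 'my_pinn' (i.e. 'my_pinn_config.pickle' plus one 'my_pinn_epochNNNNNN.pickle' per epoch), ffmpeg available on the PATH, and a test_data dictionary built beforehand with jinnax.data.generate_PINNdata on the same domain used for training:

    from jinnax import nn as jnn

    #test_data: built beforehand with jinnax.data.generate_PINNdata
    #(its arguments must match the training setup and are omitted here)
    jnn.demo_train_pinn1D(test_data,
                          'my_pinn',                     #hypothetical checkpoint prefix
                          at_each = 500,                 #one frame every 500 epochs, plus the last epoch
                          times = 5,                     #time slices per 1D frame
                          d2 = True,                     #also render the 2D video
                          file_name_save = 'pinn_demo',  #frames and videos go to ./pinn_demo/
                          title = 'Heat equation',
                          framerate = 10)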
def demo_time_pinn1D(test_data, file_name, epochs, file_name_save='result_pinn_time_demo', title='', framerate=10):
def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
    """
    Demo video with the time evolution of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None
    """
    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    train_data = config['train_data']
    forward = config['forward']

    #Create folder to save plots
    os.makedirs(file_name_save,exist_ok = True)

    #Plot parameters
    tdom = jnp.unique(test_data['xt'][:,-1])
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)

    #Open PINN for each epoch
    results = []
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        results = results + [tmp]
        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]

    #Create images
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data
            xt_step = test_data['xt'][test_data['xt'][:,-1] == t]
            u_step = test_data['u'][test_data['xt'][:,-1] == t]
            #Initialize plot: two-column grid with one subplot per epoch
            nrows = math.ceil(len(epochs)/2)
            ncols = min(2,len(epochs))
            fig, ax = plt.subplots(nrows,ncols,figsize = (10,5*nrows))
            #Flatten the axes (works for a single Axes, a 1D or a 2D grid)
            axes = np.ravel(ax)
            #Plot the exact solution and the prediction at each stored epoch
            for index in range(len(epochs)):
                upred_step = upred[index][test_data['xt'][:,-1] == t]
                axes[index].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                axes[index].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                axes[index].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                axes[index].set_xlabel(' ')
                axes[index].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
            #Hide subplots left empty when len(epochs) is odd
            for index in range(len(epochs),nrows*ncols):
                axes[index].set_visible(False)

            #Title
            fig.suptitle(title + 't = ' + str(round(float(t),4)))
            fig.tight_layout()

            #Save frame
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close(fig)
            bar()

    #Create demo video
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')

Demo video with the time evolution of a 1D PINN

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • epochs (list): Which training epochs to plot
  • file_name_save (str): File prefix to save the plots and video. Default 'result_pinn_time_demo'
  • title (str): Title for plots
  • framerate (int): Framerate for video. Default 10
Returns
  • None (a usage sketch follows below)
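A companion sketch under the same assumptions as above ('my_pinn' is a hypothetical checkpoint prefix, ffmpeg is on the PATH, test_data comes from jinnax.data.generate_PINNdata). Since the frame title is built as title + 't = ...', a trailing space in title reads better, and an even number of epochs fills the two-column grid exactly:

    from jinnax import nn as jnn

    #Compare a few stored checkpoints side by side while sweeping the time domain
    jnn.demo_time_pinn1D(test_data,
                         'my_pinn',                          #hypothetical checkpoint prefix
                         epochs = [0,1000,5000,10000],       #even count fills the two-column grid
                         file_name_save = 'pinn_time_demo',  #frames and video go to ./pinn_time_demo/
                         title = 'Heat equation ',           #trailing space before 't = ...'
                         framerate = 10)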