jinnax.nn

   1#Functions to train NN
   2import jax
   3import jax.numpy as jnp
   4import optax
   5from alive_progress import alive_bar
   6import math
   7import time
   8import pandas as pd
   9import numpy as np
  10import matplotlib.pyplot as plt
  11import dill as pickle
  12from genree import bolstering as gb
  13from genree import kernel as gk
  14from jax import random
  15from jinnax import data as jd
  16import os
  17
  18__docformat__ = "numpy"
  19
  20#MSE
  21@jax.jit
  22def MSE(pred,true):
  23    """
  24    Mean square error
  25    ----------
  26
  27    Parameters
  28    ----------
  29    pred : jax.numpy.array
  30
  31        A JAX numpy array with the predicted values
  32
  33    true : jax.numpy.array
  34
  35        A JAX numpy array with the true values
  36
  37    Returns
  38    -------
  39    mean square error
  40    """
  41    return (true - pred) ** 2
  42
  43#MSE self-adaptative
  44@jax.jit
  45def MSE_SA(pred,true,w):
  46    """
  47    Selft-adaptative mean square error
  48    ----------
  49
  50    Parameters
  51    ----------
  52    pred : jax.numpy.array
  53
  54        A JAX numpy array with the predicted values
  55
  56    true : jax.numpy.array
  57
  58        A JAX numpy array with the true values
  59
  60    wheight : jax.numpy.array
  61
  62        A JAX numpy array with the weights
  63
  64    c : float
  65
  66        Hyperparameter
  67
  68    Returns
  69    -------
  70    self-adaptative mean square error with sigmoid mask
  71    """
  72    return (w * (true - pred)) ** 2
  73
  74#L2 error
  75@jax.jit
  76def L2error(pred,true):
  77    """
  78    L2-error
  79    ----------
  80
  81    Parameters
  82    ----------
  83    pred : jax.numpy.array
  84
  85        A JAX numpy array with the predicted values
  86
  87    true : jax.numpy.array
  88
  89        A JAX numpy array with the true values
  90
  91    Returns
  92    -------
  93    L2-error
  94    """
  95    return jnp.sqrt(jnp.sum((true - pred)**2))/jnp.sqrt(jnp.sum(true ** 2))
  96
  97#Simple fully connected architecture. Return the initial parameters and the function for the forward pass
  98def fconNN(width,activation = jax.nn.tanh,key = 0):
  99    """
 100    Initialize fully connected neural network
 101    ----------
 102
 103    Parameters
 104    ----------
 105    width : list
 106
 107        List with the layers width
 108
 109    activation : jax.nn activation
 110
 111        The activation function. Default jax.nn.tanh
 112
 113    key : int
 114
 115        Seed for parameters initialization. Default 0
 116
 117    Returns
 118    -------
 119    dict with initial parameters and the function for the forward pass
 120    """
 121    #Initialize parameters with Glorot initialization
 122    initializer = jax.nn.initializers.glorot_normal()
 123    key = jax.random.split(jax.random.PRNGKey(key),len(width)-1) #Seed for initialization
 124    params = list()
 125    for key,lin,lout in zip(key,width[:-1],width[1:]):
 126        W = initializer(key,(lin,lout),jnp.float32)
 127        B = initializer(key,(1,lout),jnp.float32)
 128        params.append({'W':W,'B':B})
 129
 130    #Define function for forward pass
 131    @jax.jit
 132    def forward(x,params):
 133      *hidden,output = params
 134      for layer in hidden:
 135        x = activation(x @ layer['W'] + layer['B'])
 136      return x @ output['W'] + output['B']
 137
 138    #Return initial parameters and forward function
 139    return {'params': params,'forward': forward}
 140
 141#Get activation from string
 142def get_activation(act):
 143    """
 144    Return activation function from string
 145    ----------
 146
 147    Parameters
 148    ----------
 149    act : str
 150
 151        Name of the activation function. Default 'tanh'
 152
 153    Returns
 154    -------
 155    jax.nn activation function
 156    """
 157    if act == 'tanh':
 158        return jax.nn.tanh
 159    elif act == 'relu':
 160        return jax.nn.relu
 161    elif act == 'relu6':
 162        return jax.nn.relu6
 163    elif act == 'sigmoid':
 164        return jax.nn.sigmoid
 165    elif act == 'softplus':
 166        return jax.nn.softplus
 167    elif act == 'sparse_plus':
 168        return jx.nn.sparse_plus
 169    elif act == 'soft_sign':
 170        return jax.nn.soft_sign
 171    elif act == 'silu':
 172        return jax.nn.silu
 173    elif act == 'swish':
 174        return jax.nn.swish
 175    elif act == 'log_sigmoid':
 176        return jax.nn.log_sigmoid
 177    elif act == 'leaky_relu':
 178        return jax.xx.leaky_relu
 179    elif act == 'hard_sigmoid':
 180        return jax.nn.hard_sigmoid
 181    elif act == 'hard_silu':
 182        return jax.nn.hard_silu
 183    elif act == 'hard_swish':
 184        return jax.nn.hard_swish
 185    elif act == 'hard_tanh':
 186        return jax.nn.hard_tanh
 187    elif act == 'elu':
 188        return jax.nn.elu
 189    elif act == 'celu':
 190        return jax.nn.celu
 191    elif act == 'selu':
 192        return jax.nn.selu
 193    elif act == 'gelu':
 194        return jax.nn.gelu
 195    elif act == 'glu':
 196        return jax.nn.glu
 197    elif act == 'squareplus':
 198        return  jax.nn.squareplus
 199    elif act == 'mish':
 200        return jax.nn.mish
 201
 202#Training PINN
 203def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9):
 204    """
 205    Train a Physics-informed Neural Network
 206    ----------
 207
 208    Parameters
 209    ----------
 210    data : dict
 211
 212        Data generated by the jinnax.data.generate_PINNdata function
 213
 214    width : list
 215
 216        A list with the width of each layer
 217
 218    pde : function
 219
 220        The partial differential operator. Its arguments are u, x and t
 221
 222    test_data : dict, None
 223
 224        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
 225
 226    epochs : int
 227
 228        Number of training epochs. Default 100
 229
 230    at_each : int
 231
 232        Save results for epochs multiple of at_each. Default 10
 233
 234    activation : str
 235
 236        The name of the activation function of the neural network. Default 'tanh'
 237
 238    neumann : logical
 239
 240        Whether to consider Neumann boundary conditions
 241
 242    oper_neumann : function
 243
 244        Penalization of Neumann boundary conditions
 245
 246    sa : logical
 247
 248        Whether to consider self-adaptative PINN
 249
 250    c : dict
 251
 252        Dictionary with the hyperparameters of the self-adaptative sigmoid mask for the initial (w0), sensor (ws) and collocation (wr) points. The weights of the boundary points is fixed to 1
 253
 254    inverse : logical
 255
 256        Whether to estimate parameters of the PDE
 257
 258    initial_par : jax.numpy.array
 259
 260        Initial value of the parameters of the PDE in an inverse problem
 261
 262    lr,b1,b2,eps,eps_root: float
 263
 264        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
 265
 266    key : int
 267
 268        Seed for parameters initialization. Default 0
 269
 270    epoch_print : int
 271
 272        Number of epochs to calculate and print test errors. Default 100
 273
 274    save : logical
 275
 276        Whether to save the current parameters. Default False
 277
 278    file_name : str
 279
 280        File prefix to save the current parameters. Default 'result_pinn'
 281
 282    exp_decay : logical
 283
 284        Whether to consider exponential decay of learning rate. Default False
 285
 286    transition_steps : int
 287
 288        Number of steps for exponential decay. Default 1000
 289
 290    decay_rate : float
 291
 292        Rate of exponential decay. Default 0.9
 293
 294    Returns
 295    -------
 296    dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
 297    """
 298
 299    #Initialize architecture
 300    nnet = fconNN(width,get_activation(activation),key)
 301    forward = nnet['forward']
 302
 303    #Initialize self adaptative weights
 304    par_sa = {}
 305    if sa:
 306        #Initialize wheights close to zero
 307        ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000)
 308        if data['sensor'] is not None:
 309            par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))})
 310        if data['initial'] is not None:
 311            par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))})
 312        if data['collocation'] is not None:
 313            par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))})
 314        if data['boundary'] is not None:
 315            par_sa.update({'wb': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))})
 316
 317    #Store all parameters
 318    params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa}
 319
 320    #Save config file
 321    if save:
 322        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
 323
 324    #Define loss function
 325    if sa:
 326        #Define loss function
 327        @jax.jit
 328        def lf(params,x):
 329            loss = 0
 330            if x['sensor'] is not None:
 331                #Term that refers to sensor data
 332                loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws']))
 333            if x['boundary'] is not None:
 334                if neumann:
 335                    #Neumann coditions
 336                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
 337                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
 338                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb']))
 339                else:
 340                    #Term that refers to boundary data
 341                    loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb']))
 342            if x['initial'] is not None:
 343                #Term that refers to initial data
 344                loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0']))
 345            if x['collocation'] is not None:
 346                #Term that refers to collocation points
 347                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
 348                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
 349                if inverse:
 350                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr']))
 351                else:
 352                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr']))
 353            return loss
 354    else:
 355        @jax.jit
 356        def lf(params,x):
 357            loss = 0
 358            if x['sensor'] is not None:
 359                #Term that refers to sensor data
 360                loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
 361            if x['boundary'] is not None:
 362                if neumann:
 363                    #Neumann coditions
 364                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
 365                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
 366                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb))
 367                else:
 368                    #Term that refers to boundary data
 369                    loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary']))
 370            if x['initial'] is not None:
 371                #Term that refers to initial data
 372                loss = loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial']))
 373            if x['collocation'] is not None:
 374                #Term that refers to collocation points
 375                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
 376                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
 377                if inverse:
 378                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0))
 379                else:
 380                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0))
 381            return loss
 382
 383    #Initialize Adam Optmizer
 384    if exp_decay:
 385        lr = optax.exponential_decay(lr,transition_steps,decay_rate)
 386    optimizer = optax.adam(lr,b1,b2,eps,eps_root)
 387    opt_state = optimizer.init(params)
 388
 389    #Define the gradient function
 390    grad_loss = jax.jit(jax.grad(lf,0))
 391
 392    #Define update function
 393    @jax.jit
 394    def update(opt_state,params,x):
 395        #Compute gradient
 396        grads = grad_loss(params,x)
 397        #Invert gradient of self-adaptative wheights
 398        if sa:
 399            for w in grads['sa']:
 400                grads['sa'][w] = - grads['sa'][w]
 401        #Calculate parameters updates
 402        updates, opt_state = optimizer.update(grads, opt_state)
 403        #Update parameters
 404        params = optax.apply_updates(params, updates)
 405        #Return state of optmizer and updated parameters
 406        return opt_state,params
 407
 408    ###Training###
 409    t0 = time.time()
 410    #Initialize alive_bar for tracing in terminal
 411    with alive_bar(epochs) as bar:
 412        #For each epoch
 413        for e in range(epochs):
 414            #Update optimizer state and parameters
 415            opt_state,params = update(opt_state,params,data)
 416            #After epoch_print epochs
 417            if e % epoch_print == 0:
 418                #Compute elapsed time and current error
 419                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6))
 420                #If there is test data, compute current L2 error
 421                if test_data is not None:
 422                    #Compute L2 error
 423                    l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist()
 424                    l = l + ' L2 error: ' + str(jnp.round(l2_test,6))
 425                if inverse:
 426                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
 427                #Print
 428                print(l)
 429            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
 430                #Save current parameters
 431                pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
 432            #Update alive_bar
 433            bar()
 434    #Define estimated function
 435    def u(xt):
 436        return forward(xt,params['net'])
 437
 438    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0}
 439
 440#Process result
 441def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1):
 442    """
 443    Process the results of a Physics-informed Neural Network
 444    ----------
 445
 446    Parameters
 447    ----------
 448    test_data : dict
 449
 450        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
 451
 452    fit : function
 453
 454        The fitted function
 455
 456    train_data : dict
 457
 458        Training data generated by the jinnax.data.generate_PINNdata
 459
 460    plot : logical
 461
 462        Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True
 463
 464    plot_test : logical
 465
 466        Whether to plot the test data. Default True
 467
 468    times : int
 469
 470        Number of points along the time interval to plot. Default 5
 471
 472    d2 : logical
 473
 474        Whether to plot 2D plot when the spatial dimension is one. Default True
 475
 476    save : logical
 477
 478        Whether to save the plots. Default False
 479
 480    show : logical
 481
 482        Whether to show the plots. Default True
 483
 484    file_name : str
 485
 486        File prefix to save the plots. Default 'result_pinn'
 487
 488    print_res : logical
 489
 490        Whether to print the L2 error. Default True
 491
 492    p : int
 493
 494        Output dimension. Default 1
 495
 496    Returns
 497    -------
 498    pandas data frame with L2 and MSE errors
 499    """
 500
 501    #Dimension
 502    d = test_data['xt'].shape[1] - 1
 503
 504    #Number of plots multiple of 5
 505    times = 5 * round(times/5.0)
 506
 507    #Data
 508    td = get_train_data(train_data)
 509    xt_train = td['x']
 510    u_train = td['y']
 511    upred_train = fit(xt_train)
 512    upred_test = fit(test_data['xt'])
 513
 514    #Results
 515    l2_error_test = L2error(upred_test,test_data['u']).tolist()
 516    MSE_test = jnp.mean(MSE(upred_test,test_data['u'])).tolist()
 517    l2_error_train = L2error(upred_train,u_train).tolist()
 518    MSE_train = jnp.mean(MSE(upred_train,u_train)).tolist()
 519
 520    df = pd.DataFrame(np.array([l2_error_test,MSE_test,l2_error_train,MSE_train]).reshape((1,4)),
 521        columns=['l2_error_test','MSE_test','l2_error_train','MSE_train'])
 522    if print_res:
 523        print('L2 error test: ' + str(jnp.round(l2_error_test,6)) + ' L2 error train: ' + str(jnp.round(l2_error_train,6)) + ' MSE error test: ' + str(jnp.round(MSE_test,6)) + ' MSE error train: ' + str(jnp.round(MSE_train,6)) )
 524
 525    #Plots
 526    if d == 1 and p ==1 and plot:
 527        plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name)
 528    elif p == 2 and plot:
 529        plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test)
 530
 531    return df
 532
 533#Plot results for d = 1
 534def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
 535    """
 536    Plot the prediction of a 1D PINN
 537    ----------
 538
 539    Parameters
 540    ----------
 541    times : int
 542
 543        Number of points along the time interval to plot. Default 5
 544
 545    xt : jax.numpy.array
 546
 547        Test data xt array
 548
 549    u : jax.numpy.array
 550
 551        Test data u(x,t) array
 552
 553    upred : jax.numpy.array
 554
 555        Predicted upred(x,t) array on test data
 556
 557    d2 : logical
 558
 559        Whether to plot 2D plot. Default True
 560
 561    save : logical
 562
 563        Whether to save the plots. Default False
 564
 565    show : logical
 566
 567        Whether to show the plots. Default True
 568
 569    file_name : str
 570
 571        File prefix to save the plots. Default 'result_pinn'
 572
 573    title_1d : str
 574
 575        Title of 1D plot
 576
 577    title_2d : str
 578
 579        Title of 2D plot
 580
 581    Returns
 582    -------
 583    None
 584    """
 585    #Initialize
 586    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
 587    tlo = jnp.min(xt[:,-1])
 588    tup = jnp.max(xt[:,-1])
 589    ylo = jnp.min(u)
 590    ylo = ylo - 0.1*jnp.abs(ylo)
 591    yup = jnp.max(u)
 592    yup = yup + 0.1*jnp.abs(yup)
 593    k = 0
 594    t_values = np.linspace(tlo,tup,times)
 595
 596    #Create
 597    for i in range(int(times/5)):
 598        for j in range(5):
 599            if k < len(t_values):
 600                t = t_values[k]
 601                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
 602                x_plot = xt[xt[:,-1] == t,:-1]
 603                y_plot = upred[xt[:,-1] == t,:]
 604                u_plot = u[xt[:,-1] == t,:]
 605                if int(times/5) > 1:
 606                    ax[i,j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
 607                    ax[i,j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
 608                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
 609                    ax[i,j].set_xlabel(' ')
 610                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
 611                else:
 612                    ax[j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
 613                    ax[j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
 614                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
 615                    ax[j].set_xlabel(' ')
 616                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
 617                k = k + 1
 618
 619    #Title
 620    fig.suptitle(title_1d)
 621    fig.tight_layout()
 622
 623    #Show and save
 624    fig = plt.gcf()
 625    if show:
 626        plt.show()
 627    if save:
 628        fig.savefig(file_name + '_slices.png')
 629    plt.close()
 630
 631    #2d plot
 632    if d2:
 633        #Initialize
 634        fig, ax = plt.subplots(1,2)
 635        l1 = jnp.unique(xt[:,-1]).shape[0]
 636        l2 = jnp.unique(xt[:,0]).shape[0]
 637
 638        #Create
 639        ax[0].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
 640        ax[0].set_title('Exact')
 641        ax[1].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
 642        ax[1].set_title('Predicted')
 643
 644        #Title
 645        fig.suptitle(title_2d)
 646        fig.tight_layout()
 647
 648        #Show and save
 649        fig = plt.gcf()
 650        if show:
 651            plt.show()
 652        if save:
 653            fig.savefig(file_name + '_2d.png')
 654        plt.close()
 655
  656#Plot results for a PINN with 2D output
 657def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
 658    """
 659    Plot the prediction of a PINN with 2D output
 660    ----------
 661    Parameters
 662    ----------
 663    times : int
 664
 665        Number of points along the time interval to plot. Default 5
 666
 667    xt : jax.numpy.array
 668
 669        Test data xt array
 670
 671    u : jax.numpy.array
 672
 673        Test data u(x,t) array
 674
 675    upred : jax.numpy.array
 676
 677        Predicted upred(x,t) array on test data
 678
 679    save : logical
 680
 681        Whether to save the plots. Default False
 682
 683    show : logical
 684
 685        Whether to show the plots. Default True
 686
 687    file_name : str
 688
 689        File prefix to save the plots. Default 'result_pinn'
 690
 691    title : str
 692
 693        Title of plot
 694
 695    plot_test : logical
 696
 697        Whether to plot the test data. Default True
 698
 699    Returns
 700    -------
 701    None
 702    """
 703    #Initialize
 704    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
 705    tlo = jnp.min(xt[:,-1])
 706    tup = jnp.max(xt[:,-1])
 707    xlo = jnp.min(u[:,0])
 708    xlo = xlo - 0.1*jnp.abs(xlo)
 709    xup = jnp.max(u[:,0])
 710    xup = xup + 0.1*jnp.abs(xup)
 711    ylo = jnp.min(u[:,1])
 712    ylo = ylo - 0.1*jnp.abs(ylo)
 713    yup = jnp.max(u[:,1])
 714    yup = yup + 0.1*jnp.abs(yup)
 715    k = 0
 716    t_values = np.linspace(tlo,tup,times)
 717
 718    #Create
 719    for i in range(int(times/5)):
 720        for j in range(5):
 721            if k < len(t_values):
 722                t = t_values[k]
 723                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
 724                xpred_plot = upred[xt[:,-1] == t,0]
 725                ypred_plot = upred[xt[:,-1] == t,1]
 726                if plot_test:
 727                    x_plot = u[xt[:,-1] == t,0]
 728                    y_plot = u[xt[:,-1] == t,1]
 729                if int(times/5) > 1:
 730                    if plot_test:
 731                        ax[i,j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
 732                    ax[i,j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
 733                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
 734                    ax[i,j].set_xlabel(' ')
 735                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
 736                else:
 737                    if plot_test:
 738                        ax[j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
 739                    ax[j].plot(xpred_plot,ypred,'r-',linewidth=2,label='Prediction')
 740                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
 741                    ax[j].set_xlabel(' ')
 742                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
 743                k = k + 1
 744
 745    #Title
 746    fig.suptitle(title)
 747    fig.tight_layout()
 748
 749    #Show and save
 750    fig = plt.gcf()
 751    if show:
 752        plt.show()
 753    if save:
 754        fig.savefig(file_name + '_slices.png')
 755    plt.close()
 756
 757#Get train data in one array
 758def get_train_data(train_data):
 759    """
 760    Process training sample
 761    ----------
 762
 763    Parameters
 764    ----------
 765    train_data : dict
 766
 767        A dictionay with train data generated by the jinnax.data.generate_PINNdata function
 768
 769    Returns
 770    -------
 771    dict with the processed training data
 772    """
 773    xdata = None
 774    ydata = None
 775    xydata = None
 776    if train_data['sensor'] is not None:
 777        sensor_sample = train_data['sensor'].shape[0]
 778        xdata = train_data['sensor']
 779        ydata = train_data['usensor']
 780        xydata = jnp.column_stack((train_data['sensor'],train_data['usensor']))
 781    else:
 782        sensor_sample = 0
 783    if train_data['boundary'] is not None:
 784        boundary_sample = train_data['boundary'].shape[0]
 785        if xdata is not None:
 786            xdata = jnp.vstack((xdata,train_data['boundary']))
 787            ydata = jnp.vstack((ydata,train_data['uboundary']))
 788            xydata = jnp.vstack((xydata,jnp.column_stack((train_data['boundary'],train_data['uboundary']))))
 789        else:
 790            xdata = train_data['boundary']
 791            ydata = train_data['uboundary']
 792            xydata = jnp.column_stack((train_data['boundary'],train_data['uboundary']))
 793    else:
 794        boundary_sample = 0
 795    if train_data['initial'] is not None:
 796        initial_sample = train_data['initial'].shape[0]
 797        if xdata is not None:
 798            xdata = jnp.vstack((xdata,train_data['initial']))
 799            ydata = jnp.vstack((ydata,train_data['uinitial']))
 800            xydata = jnp.vstack((xydata,jnp.column_stack((train_data['initial'],train_data['uinitial']))))
 801        else:
 802            xdata = train_data['initial']
 803            ydata = train_data['uinitial']
 804            xydata = jnp.column_stack((train_data['initial'],train_data['uinitial']))
 805    else:
 806        initial_sample = 0
 807    if train_data['collocation'] is not None:
 808        collocation_sample = train_data['collocation'].shape[0]
 809    else:
 810        collocation_sample = 0
 811
 812    return {'xy': xydata,'x': xdata,'y': ydata,'sensor_sample': sensor_sample,'boundary_sample': boundary_sample,'initial_sample': initial_sample,'collocation_sample': collocation_sample}
 813
 814#Process training
 815def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
 816    """
 817    Process the training of a Physics-informed Neural Network
 818    ----------
 819
 820    Parameters
 821    ----------
 822    test_data : dict
 823
 824        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
 825
 826    file_name : str
 827
 828        Name of the files saved during training
 829
 830    at_each : int
 831
 832        Compute results for epochs multiple of at_each. Default 100
 833
 834    bolstering : logical
 835
 836        Whether to compute bolstering mean square error. Default True
 837
 838    mc_sample : int
 839
 840        Number of sample for Monte Carlo integration in bolstering. Default 10000
 841
 842    save : logical
 843
 844        Whether to save the training results. Default False
 845
 846    file_name_save : str
 847
 848        File prefix to save the plots and the L2 error. Default 'result_pinn'
 849
 850    key : int
 851
 852        Key for random samples in bolstering. Default 0
 853
 854    ec : float
 855
 856        Stopping criteria error for EM algorithm in bolstering. Default 1e-6
 857
 858    lamb : float
 859
 860        Hyperparameter of EM algorithm in bolstering. Default 1
 861
 862    Returns
 863    -------
 864    pandas data frame with training results
 865    """
 866    #Config
 867    config = pickle.load(open(file_name + '_config.pickle', 'rb'))
 868    epochs = config['epochs']
 869    train_data = config['train_data']
 870    forward = config['forward']
 871
 872    #Get train data
 873    td = get_train_data(train_data)
 874    xydata = td['xy']
 875    xdata = td['x']
 876    ydata = td['y']
 877    sensor_sample = td['sensor_sample']
 878    boundary_sample = td['boundary_sample']
 879    initial_sample = td['initial_sample']
 880    collocation_sample = td['collocation_sample']
 881
 882    #Generate keys
 883    if bolstering:
 884        keys = jax.random.split(jax.random.PRNGKey(key),epochs)
 885
 886    #Initialize loss
 887    train_mse = []
 888    test_mse = []
 889    train_L2 = []
 890    test_L2 = []
 891    bolstX = []
 892    bolstXY = []
 893    loss = []
 894    time = []
 895    ep = []
 896
 897    #Process training
 898    with alive_bar(epochs) as bar:
 899        for e in range(epochs):
 900            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
 901                ep = ep + [e]
 902
 903                #Read parameters
 904                params = pickle.load(open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb'))
 905
 906                #Time
 907                time = time + [params['time']]
 908
 909                #Define learned function
 910                def psi(x):
 911                    return forward(x,params['params']['net'])
 912
 913                #Train MSE and L2
 914                if xdata is not None:
 915                    train_mse = train_mse + [jnp.mean(MSE(psi(xdata),ydata)).tolist()]
 916                    train_L2 = train_L2 + [L2error(psi(xdata),ydata).tolist()]
 917                else:
 918                    train_mse = train_mse + [None]
 919                    train_L2 = train_L2 + [None]
 920
 921                #Test MSE and L2
 922                test_mse = test_mse + [jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist()]
 923                test_L2 = test_L2 + [L2error(psi(test_data['xt']),test_data['u']).tolist()]
 924
 925                #Bolstering
 926                if bolstering:
 927                    bX = []
 928                    bXY = []
 929                    for method in ['chi','mm','mpe']:
 930                        kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
 931                        kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
 932                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
 933                        bXY = bXY + [gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist()]
 934                    for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]:
 935                        kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
 936                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
 937                    bolstX = bolstX + [bX]
 938                    bolstXY = bolstXY + [bXY]
 939                else:
 940                    bolstX = bolstX + [None]
 941                    bolstXY = bolstXY + [None]
 942
 943                #Loss
 944                loss = loss + [params['loss'].tolist()]
 945
 946                #Delete
 947                del params, psi
 948            #Update alive_bar
 949            bar()
 950
 951    #Bolstering results
 952    if bolstering:
 953        bolstX = jnp.array(bolstX)
 954        bolstXY = jnp.array(bolstXY)
 955
 956    #Create data frame
 957    if bolstering:
 958        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
 959            train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]),
 960            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4'])
 961    else:
 962        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
 963            train_mse,test_mse,train_L2,test_L2]),
 964            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2'])
 965    if save:
 966        df.to_csv(file_name_save + '.csv',index = False)
 967
 968    return df
 969
 970#Demo video for training1D PINN
 971def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
 972    """
 973    Demo video with the training of a 1D PINN
 974    ----------
 975
 976    Parameters
 977    ----------
 978    test_data : dict
 979
 980        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
 981
 982    file_name : str
 983
 984        Name of the files saved during training
 985
 986    at_each : int
 987
 988        Compute results for epochs multiple of at_each. Default 100
 989
 990    times : int
 991
 992        Number of points along the time interval to plot. Default 5
 993
 994    d2 : logical
 995
 996        Whether to make video demo of 2D plot. Default True
 997
 998    file_name_save : str
 999
1000        File prefix to save the plots and videos. Default 'result_pinn_demo'
1001
1002    title : str
1003
1004        Title for plots
1005
1006    framerate : int
1007
1008        Framerate for video. Default 10
1009
1010    Returns
1011    -------
1012    None
1013    """
1014    #Config
1015    with open(file_name + '_config.pickle', 'rb') as file:
1016        config = pickle.load(file)
1017    epochs = config['epochs']
1018    train_data = config['train_data']
1019    forward = config['forward']
1020
1021    #Get train data
1022    td = get_train_data(train_data)
1023    xt = td['x']
1024    u = td['y']
1025
1026    #Create folder to save plots
1027    os.system('mkdir ' + file_name_save)
1028
1029    #Create images
1030    k = 1
1031    with alive_bar(epochs) as bar:
1032        for e in range(epochs):
1033            if e % at_each == 0 or e == epochs - 1:
1034                #Read parameters
1035                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
1036
1037                #Define learned function
1038                def psi(x):
1039                    return forward(x,params['params']['net'])
1040
1041                #Compute L2 train, L2 test and loss
1042                loss = params['loss']
1043                L2_train = L2error(psi(xt),u)
1044                L2_test = L2error(psi(test_data['xt']),test_data['u'])
1045                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(L2_train,6)) + ' L2 test = ' + str(round(L2_test,6))
1046
1047                #Save plot
1048                plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
1049                k = k + 1
1050
1051                #Delete
1052                del params, psi, loss, L2_train, L2_test, title_epoch
1053            #Update alive_bar
1054            bar()
1055    #Create demo video
1056    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4')
1057    if d2:
1058        os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4')
1059
1060#Demo in time for 1D PINN
def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
    """
    Demo video with the time evolution of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. The entries 'xt' (time in the last column) and 'u' are used

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot. One panel is drawn for each epoch, two panels per row. Must not be empty

    file_name_save : str

        Folder (created if needed) and file prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None

    Raises
    ------
    ValueError if epochs is empty
    """
    #There must be at least one epoch to plot
    n_plots = len(epochs)
    if n_plots == 0:
        raise ValueError('epochs must contain at least one training epoch')

    #Config
    #NOTE: the config and epoch files are unpickled, so only load files produced by a trusted training run
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    forward = config['forward']

    #Create folder to save plots (parents included, no failure if it already exists)
    os.makedirs(file_name_save,exist_ok = True)

    #The video is named after the last component of the folder, so nested folders also work
    video_name = os.path.basename(os.path.normpath(file_name_save))

    #Plot parameters
    tdom = jnp.unique(test_data['xt'][:,-1])
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)
    ylim = [1.3 * ylo.tolist(),1.3 * yup.tolist()]

    #Prediction of the PINN of each epoch at the test points
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]

    #Grid of panels: two columns and enough rows for every epoch (an odd number of epochs leaves one empty panel)
    n_cols = 2 if n_plots > 1 else 1
    n_rows = (n_plots + n_cols - 1)//n_cols

    #Create images
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data at time t
            at_t = test_data['xt'][:,-1] == t
            xt_step = test_data['xt'][at_t]
            u_step = test_data['u'][at_t]

            #Initialize plot (squeeze = False so the axes are always a 2D array)
            fig, ax = plt.subplots(n_rows,n_cols,figsize = (10,5*n_rows),squeeze = False)

            #Create one panel for each epoch, filling the grid row by row
            for index,panel in enumerate(ax.ravel()):
                if index >= n_plots:
                    #Panel without an epoch
                    panel.set_visible(False)
                    continue
                upred_step = upred[index][at_t]
                panel.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                panel.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                panel.set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                panel.set_xlabel(' ')
                panel.set_ylim(ylim)

            #Title
            fig.suptitle(title + 't = ' + str(round(t,4)))
            fig.tight_layout()

            #Save (frames are numbered 1,2,... following the order of the times)
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close(fig)
            bar()

    #Create demo video
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + video_name + '_time_demo.mp4')
1176
#Demo in time for 2D PINN
def demo_time_pinn2D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10,ffmpeg = 'ffmpeg'):
    """
    Demo video with the time evolution of a 2D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. The entries 'xt' (time in the last column) and 'u' (two columns, plotted one against the other) are used

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot. One panel is drawn for each epoch, two panels per row. Must not be empty

    file_name_save : str

        Folder (created if needed) and file prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to ffmpeg

    Returns
    -------
    None

    Raises
    ------
    ValueError if epochs is empty
    """
    #There must be at least one epoch to plot
    n_plots = len(epochs)
    if n_plots == 0:
        raise ValueError('epochs must contain at least one training epoch')

    #Config
    #NOTE: the config and epoch files are unpickled, so only load files produced by a trusted training run
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    forward = config['forward']

    #Create folder to save plots (parents included, no failure if it already exists)
    os.makedirs(file_name_save,exist_ok = True)

    #The video is named after the last component of the folder, so nested folders also work
    video_name = os.path.basename(os.path.normpath(file_name_save))

    #Plot parameters
    tdom = jnp.unique(test_data['xt'][:,-1])
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)
    ylim = [1.3 * ylo.tolist(),1.3 * yup.tolist()]

    #Prediction of the PINN of each epoch at the test points
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]

    #Grid of panels: two columns and enough rows for every epoch (an odd number of epochs leaves one empty panel)
    n_cols = 2 if n_plots > 1 else 1
    n_rows = (n_plots + n_cols - 1)//n_cols

    #Create images
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data at time t: each coordinate of u
            at_t = test_data['xt'][:,-1] == t
            ux_step = test_data['u'][at_t,0]
            uy_step = test_data['u'][at_t,1]

            #Initialize plot (squeeze = False so the axes are always a 2D array)
            fig, ax = plt.subplots(n_rows,n_cols,figsize = (10,5*n_rows),squeeze = False)

            #Create one panel for each epoch, filling the grid row by row
            for index,panel in enumerate(ax.ravel()):
                if index >= n_plots:
                    #Panel without an epoch
                    panel.set_visible(False)
                    continue
                upredx_step = upred[index][at_t,0]
                upredy_step = upred[index][at_t,1]
                panel.plot(ux_step,uy_step,'b-',linewidth=2,label='Exact')
                panel.plot(upredx_step,upredy_step,'r-',linewidth=2,label='Prediction')
                panel.set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                panel.set_xlabel(' ')
                panel.set_ylim(ylim)

            #Title
            fig.suptitle(title + 't = ' + str(round(t,4)))
            fig.tight_layout()

            #Save (frames are numbered 1,2,... following the order of the times)
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close(fig)
            bar()

    #Create demo video
    os.system(ffmpeg + ' -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + video_name + '_time_demo.mp4')
1300
def DN_CSF_circle(uinitial,xl,xu,tl,tu,width,radius,Ntb = 100,N0 = 100,Nc = 50,Ntc = 50,Ns = 100,Nts = 100,epochs = 100,at_each = 10,activation = 'tanh',sa = True,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,demo = True,framerate = 2,ffmpeg = 'ffmpeg',c = 1e-6):
    """
    Train a PINN for a two-coordinate function u(x,t) = (u1,u2) with a Dirichlet condition at x = xu and Dirichlet and Neumann-type conditions on a circle at x = xl, and compute its residuals
    ----------
    (presumably the curve shortening flow, from the name of the function; confirm)

    Parameters
    ----------
    uinitial : function

        Function of (x,t) returning an object with entries 'u1' and 'u2'. It is used to generate the data and to evaluate the Dirichlet condition at x = xu

    xl, xu, tl, tu : float

        Passed to jinnax.data.generate_PINNdata as the limits of x and t. Points with x == xu receive the right boundary penalty and points with x == xl the left boundary penalties

    width : list

        Passed to train_PINN

    radius : float

        Radius of the circle centered at the origin on which u is penalized to lie at x = xl. Also the radius of the circle used as reference in the demo

    Ntb, N0, Nc, Ntc : int

        Sample sizes passed to jinnax.data.generate_PINNdata for the training data. The test data is generated with twice each value

    Ns, Nts : int

        Sample sizes passed to jinnax.data.generate_PINNdata for the data of the demo. Only used if demo is True

    epochs, at_each, activation, sa, lr, b1, b2, eps, eps_root, key, epoch_print, save, file_name, exp_decay, transition_steps, decay_rate

        Passed to train_PINN. save is set to True if demo is True

    demo : logical

        Whether to call demo_time_pinn2D for the last epoch after training. Default True

    framerate : int

        Framerate passed to demo_time_pinn2D. Default 2

    ffmpeg : str

        Path to ffmpeg passed to demo_time_pinn2D

    c : float

        Constant added to ux1^2 + ux2^2 in the denominator of the PDE residual. Default 1e-6

    Returns
    -------
    the object returned by train_PINN and a pandas data frame with the residuals, which is also written to file_name + '_residuals.csv' (regardless of save)
    """
    #If demo, then save
    if demo:
        save = True

    #Define initial function and function to evaluate at boundary
    #uinit appends u1 and u2 without an axis (flattened result); ubound appends them along axis 1
    def uinit(x,t):
        u = uinitial(x,t)
        return jnp.append(u['u1'],u['u2'])

    def ubound(x,t):
        u = uinitial(x,t)
        return jnp.append(u['u1'],u['u2'],1)


    #PDE operator
    #Residual: sqrt((ut1 - uxx1/(ux1^2 + ux2^2 + c))^2 + (ut2 - uxx2/(ux1^2 + ux2^2 + c))^2)
    def pde(u,x,t):
        #One function for each coordinate (assuming that x and t have dimension 1 x 1 and u(x,t) has dimension 1 x 2)
        u1 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,0][0]
        u2 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,1][0]
        #First derivatives of each coordinate (vmap evaluates the gradient row by row)
        ux1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t))
        ux2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t))
        ut1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),1)(x,t))
        ut2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),1)(x,t))
        #Second derivative of each coordinate (gradient of the first entry of the first derivative)
        ux1_tmp = lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t)
        ux2_tmp = lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t)
        uxx1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : ux1_tmp(x,t)[0],0)(x,t))
        uxx2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : ux2_tmp(x,t)[0],0)(x,t))
        #Return
        return jnp.sqrt((ut1(x,t) - uxx1(x,t)/(ux1(x,t) ** 2 + ux2(x,t) ** 2 + c)) ** 2 + (ut2(x,t) - uxx2(x,t)/(ux1(x,t) ** 2 + ux2(x,t) ** 2 + c)) ** 2)

    #Operator to evaluate boundary conditions
    #The default of Ntb is the training value, bound when the function is defined; the test call below passes the doubled value explicitly
    def oper_boundary(u,x,t,w = 1,Ntb = Ntb):
      #Enforce Dirichlet at the right boundary (fixed at point a, as the initial condition)
      res_right_dir = jnp.sum(jnp.where(x == xu,(u(x,t) - ubound(x,t)) ** 2,0),1).reshape(x.shape[0],1)
      #Enforce Dirichlet at the left boundary (is in the circle of radius fixed)
      res_left_dir = jnp.sum(jnp.where(x == xl,((jnp.sum(u(x,t) ** 2,1) - radius ** 2) ** 2).reshape(x.shape),0),1).reshape(x.shape[0],1)
      #One function for each coordinate (assuming that x and t have dimension 1 x 1 and u(x,t) has dimension 1 x 2)
      u1 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,0][0]
      u2 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,1][0]
      #Take the derivatives in x
      ux1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t))(x,t)
      ux2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t))(x,t)
      #Enforce Neumann at the left boundary: squared inner product between nS and nu
      #NOTE(review): the norm below sums over axis 0 (across rows), while the other reductions in this function are over axis 1 (per row) - confirm that this is intended
      nS = u(x,t)/jnp.sqrt(jnp.sum(u(x,t) ** 2,0)) #Assuming that u(x,y) \in S, compute the vector normal to S at u(x,t)
      #NOTE(review): unlike in pde, the constant c is not added to ux1^2 + ux2^2 here - confirm that the derivative cannot vanish
      nu = jnp.append(ux2,(-1)*ux1,1)/jnp.sqrt(ux1 ** 2 + ux2 ** 2)
      ip = jnp.sum(nS * nu,1).reshape(x.shape[0],1) ** 2
      res_left_neu = jnp.where(x == xl,ip,0)
      #Rearrange: first Ntb rows from the right Dirichlet residual, next Ntb rows from the left Dirichlet residual and the remaining rows from the Neumann residual
      res = jnp.append(jnp.append(res_right_dir[:Ntb,:],res_left_dir[Ntb:2*Ntb,:],0),res_left_neu[2*Ntb:,:],0)
      return w*res

    #Generate Data
    train_data = jd.generate_PINNdata(u = uinit,xl = xl,xu = xu,tl = tl,tu = tu,Ns = None,Nts = None,Nb = 2,Ntb = Ntb,N0 = N0,Nc = Nc,Ntc = Ntc,p = 2,poss = 'random',posts = 'random',pos0 = 'random',postb = 'random',posc = 'random',postc = 'random')

    #Rearrange boundary data: rows from Ntb onwards first, then the first Ntb rows twice
    #(presumably so that the three groups of rows match the three residuals of oper_boundary - depends on the row order of generate_PINNdata, confirm)
    train_data['boundary'] = jnp.append(jnp.append(train_data['boundary'][Ntb:,:],train_data['boundary'][:Ntb,:],0),train_data['boundary'][:Ntb,:],0)
    train_data['uboundary'] = jnp.append(jnp.append(train_data['uboundary'][Ntb:,:],train_data['uboundary'][:Ntb,:],0),train_data['uboundary'][:Ntb,:],0)

    #Train PINN
    fit = train_PINN(train_data,width,pde,c = {'ws': 1,'wr': 1,'w0': 1,'wb': 1},test_data = None,epochs = epochs,at_each = at_each,activation = activation,neumann = True,oper_neumann = oper_boundary,sa = sa,lr = lr,b1 = b1,b2 = b2,eps = eps,eps_root = eps_root,key = key,epoch_print = epoch_print,save = save,file_name = file_name,exp_decay = exp_decay,transition_steps = transition_steps,decay_rate = decay_rate)

    #Test data (twice the sample sizes of the training data)
    test_data = jd.generate_PINNdata(u = uinit,xl = xl,xu = xu,tl = tl,tu = tu,Ns = None,Nts = None,Nb = 2,Ntb = 2*Ntb,N0 = 2*N0,Nc = 2*Nc,Ntc = 2*Ntc,p = 2,poss = 'random',posts = 'random',pos0 = 'random',postb = 'random',posc = 'random',postc = 'random')
    #From here on Ntb is the test value
    Ntb = 2*Ntb
    test_data['boundary'] = jnp.append(jnp.append(test_data['boundary'][Ntb:,:],test_data['boundary'][:Ntb,:],0),test_data['boundary'][:Ntb,:],0)
    test_data['uboundary'] = jnp.append(jnp.append(test_data['uboundary'][Ntb:,:],test_data['uboundary'][:Ntb,:],0),test_data['uboundary'][:Ntb,:],0)

    #Evaluate residuals
    #Learned function with x and t as separate arguments
    def u(x,t):
        return fit['u'](jnp.append(x,t,1))

    #Each residual is the mean of the square of the respective penalty over the test points
    res_pde = jnp.mean(pde(u,test_data['collocation'][:,0].reshape((test_data['collocation'].shape[0],1)),test_data['collocation'][:,1].reshape((test_data['collocation'].shape[0],1))) ** 2)
    res_DN = oper_boundary(u,test_data['boundary'][:,0].reshape((test_data['boundary'].shape[0],1)),test_data['boundary'][:,1].reshape((test_data['boundary'].shape[0],1)),Ntb = Ntb)
    res_dir_right = jnp.mean(res_DN[:Ntb,:] ** 2)
    res_neu = jnp.mean(res_DN[2*Ntb:,:] ** 2)
    res_dir_left = jnp.mean(res_DN[Ntb:2*Ntb,:] ** 2)
    res_initial = jnp.mean((u(test_data['initial'][:,0].reshape((test_data['initial'].shape[0],1)),test_data['initial'][:,1].reshape((test_data['initial'].shape[0],1))) - test_data['uinitial']) ** 2)

    #Save file
    #NOTE(review): 'initial' is stored as a JAX value, while the other residuals are converted with tolist - confirm that it is written as intended
    res_data = pd.DataFrame({'PDE': [res_pde.tolist()],
                             'Dirichlet_Right': [res_dir_right.tolist()],
                             'Dirichlet_Left': [res_dir_left.tolist()],
                             'Neumann': [res_neu.tolist()],
                             'initial': res_initial,
                             'time': fit['time'],
                             'epochs': epochs})
    res_data.to_csv(file_name + '_residuals.csv')

    if demo:
        #Circle of the given radius traversed once as x goes from xl to xu (it does not depend on t), used as the 'Exact' curve of the demo
        def ucircle(x,t):
          y = 2*jnp.pi*(x - xl)/(xu - xl)
          return jnp.append(radius*jnp.sin(y),radius*jnp.cos(y),0)
        test_data = jd.generate_PINNdata(u = ucircle,xl = xl,xu = xu,tl = tl,tu = tu,Ns = Ns,Nts = Nts,Nb = 0,Ntb = 0,N0 = 0,Nc = 0,Ntc = 0,p = 2,train = False)
        #Demo with the parameters of the last epoch
        demo_time_pinn2D(test_data,file_name,[epochs-1],file_name_save = file_name + '_demo',title = '',framerate = framerate,ffmpeg = ffmpeg)

    return fit,res_data
@jax.jit
def MSE(pred, true):
@jax.jit
def MSE(pred,true):
    """
    Squared error
    ----------

    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    Returns
    -------
    element-wise squared difference between true and pred. No mean is taken here: callers average the result
    """
    residual = true - pred
    return residual ** 2
Mean square error
Parameters
  • pred (jax.numpy.array): A JAX numpy array with the predicted values
  • true (jax.numpy.array): A JAX numpy array with the true values
Returns
  • mean square error
@jax.jit
def MSE_SA(pred, true, w):
@jax.jit
def MSE_SA(pred,true,w):
    """
    Self-adaptive squared error
    ----------

    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    w : jax.numpy.array

        A JAX numpy array with the weights that multiply the difference between true and pred

    Returns
    -------
    element-wise square of the weighted difference, (w * (true - pred)) ** 2. No mean is taken here
    """
    weighted_residual = w * (true - pred)
    return weighted_residual ** 2

Self-adaptive mean square error

Parameters
  • pred (jax.numpy.array): A JAX numpy array with the predicted values
  • true (jax.numpy.array): A JAX numpy array with the true values
  • w (jax.numpy.array): A JAX numpy array with the weights
  • c (float): Hyperparameter
Returns
  • self-adaptative mean square error with sigmoid mask
@jax.jit
def L2error(pred, true):
@jax.jit
def L2error(pred,true):
    """
    Relative L2-error
    ----------

    Parameters
    ----------
    pred : jax.numpy.array

        A JAX numpy array with the predicted values

    true : jax.numpy.array

        A JAX numpy array with the true values

    Returns
    -------
    L2 norm of true - pred divided by the L2 norm of true, both taken over all entries
    """
    error_norm = jnp.sqrt(jnp.sum((true - pred)**2))
    true_norm = jnp.sqrt(jnp.sum(true ** 2))
    return error_norm/true_norm

L2-error

Parameters
  • pred (jax.numpy.array): A JAX numpy array with the predicted values
  • true (jax.numpy.array): A JAX numpy array with the true values
Returns
  • L2-error
def fconNN(width, activation=<PjitFunction of <function tanh>>, key=0):
 99def fconNN(width,activation = jax.nn.tanh,key = 0):
100    """
101    Initialize fully connected neural network
102    ----------
103
104    Parameters
105    ----------
106    width : list
107
108        List with the layers width
109
110    activation : jax.nn activation
111
112        The activation function. Default jax.nn.tanh
113
114    key : int
115
116        Seed for parameters initialization. Default 0
117
118    Returns
119    -------
120    dict with initial parameters and the function for the forward pass
121    """
122    #Initialize parameters with Glorot initialization
123    initializer = jax.nn.initializers.glorot_normal()
124    key = jax.random.split(jax.random.PRNGKey(key),len(width)-1) #Seed for initialization
125    params = list()
126    for key,lin,lout in zip(key,width[:-1],width[1:]):
127        W = initializer(key,(lin,lout),jnp.float32)
128        B = initializer(key,(1,lout),jnp.float32)
129        params.append({'W':W,'B':B})
130
131    #Define function for forward pass
132    @jax.jit
133    def forward(x,params):
134      *hidden,output = params
135      for layer in hidden:
136        x = activation(x @ layer['W'] + layer['B'])
137      return x @ output['W'] + output['B']
138
139    #Return initial parameters and forward function
140    return {'params': params,'forward': forward}
Initialize fully connected neural network
Parameters
  • width (list): List with the layers width
  • activation (jax.nn activation): The activation function. Default jax.nn.tanh
  • key (int): Seed for parameters initialization. Default 0
Returns
  • dict with initial parameters and the function for the forward pass
def get_activation(act):
def get_activation(act):
    """
    Return activation function from string
    ----------

    Parameters
    ----------
    act : str

        Name of the activation function. The supported names are 'tanh', 'relu', 'relu6', 'sigmoid', 'softplus', 'sparse_plus', 'soft_sign', 'silu', 'swish', 'log_sigmoid', 'leaky_relu', 'hard_sigmoid', 'hard_silu', 'hard_swish', 'hard_tanh', 'elu', 'celu', 'selu', 'gelu', 'glu', 'squareplus' and 'mish'

    Returns
    -------
    jax.nn activation function with the given name, or None if the name is not supported
    """
    #Every supported name is also the name of the function in jax.nn
    supported = ('tanh','relu','relu6','sigmoid','softplus','sparse_plus','soft_sign','silu','swish',
        'log_sigmoid','leaky_relu','hard_sigmoid','hard_silu','hard_swish','hard_tanh','elu','celu',
        'selu','gelu','glu','squareplus','mish')
    if act in supported:
        #Look up only the requested function, so a function missing in the installed jax does not affect the others
        return getattr(jax.nn,act)
    #Unknown name: None, as before
    return None
Return activation function from string
Parameters
  • act (str): Name of the activation function, e.g. 'tanh'
Returns
  • jax.nn activation function
def train_PINN( data, width, pde, test_data=None, epochs=100, at_each=10, activation='tanh', neumann=False, oper_neumann=False, sa=False, c={'ws': 1, 'wr': 1, 'w0': 100, 'wb': 1}, inverse=False, initial_par=None, lr=0.001, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, key=0, epoch_print=100, save=False, file_name='result_pinn', exp_decay=False, transition_steps=1000, decay_rate=0.9):
def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9):
    """
    Train a Physics-informed Neural Network

    Parameters
    ----------
    data : dict
        Data generated by the jinnax.data.generate_PINNdata function. The keys 'sensor', 'boundary', 'initial' and 'collocation' (and 'usensor', 'uboundary', 'uinitial') are read; a set whose entry is None is skipped

    width : list
        A list with the width of each layer

    pde : function
        The partial differential operator. Its arguments are u, x and t (and the PDE parameters as a fourth argument when inverse is True)

    test_data : dict, None
        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error

    epochs : int
        Number of training epochs. Default 100

    at_each : int
        Save results for epochs multiple of at_each. Default 10

    activation : str
        The name of the activation function of the neural network. Default 'tanh'

    neumann : logical
        Whether to consider Neumann boundary conditions

    oper_neumann : function
        Penalization of Neumann boundary conditions. Called as oper_neumann(u,x,t), with the boundary self-adaptative weights as a fourth argument when sa is True

    sa : logical
        Whether to consider self-adaptative PINN

    c : dict
        Scale of the initial self-adaptative weights of the sensor (ws), initial (w0), collocation (wr) and boundary (wb) points. Each weight is drawn as scale * Uniform(0,1). When 'wb' is absent the collocation scale 'wr' is used for the boundary points. The dictionary is only read, never modified

    inverse : logical
        Whether to estimate parameters of the PDE

    initial_par : jax.numpy.array
        Initial value of the parameters of the PDE in an inverse problem

    lr,b1,b2,eps,eps_root: float
        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0

    key : int
        Seed for parameters initialization. Default 0

    epoch_print : int
        Number of epochs to calculate and print test errors. Default 100

    save : logical
        Whether to save the current parameters. Default False

    file_name : str
        File prefix to save the current parameters. Default 'result_pinn'

    exp_decay : logical
        Whether to consider exponential decay of learning rate. Default False

    transition_steps : int
        Number of steps for exponential decay. Default 1000

    decay_rate : float
        Rate of exponential decay. Default 0.9

    Returns
    -------
    dict with the estimated function ('u'), the estimated parameters ('params'), the neural network function for the forward pass ('forward') and the training time ('time')
    """

    #Initialize architecture
    nnet = fconNN(width,get_activation(activation),key)
    forward = nnet['forward']

    #Initialize self adaptative weights: one weight per point of each available data set
    par_sa = {}
    if sa:
        ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000)
        #(weight name, data set name, position of its seed in ksa)
        for w_name,set_name,k in [('ws','sensor',0),('w0','initial',1),('wr','collocation',2),('wb','boundary',3)]:
            if data[set_name] is not None:
                #Boundary weights use their own scale, falling back to the collocation scale when 'wb' is not given
                scale = c.get('wb',c['wr']) if w_name == 'wb' else c[w_name]
                par_sa[w_name] = scale * jax.random.uniform(key = jax.random.PRNGKey(ksa[k]),shape = (data[set_name].shape[0],1))

    #Store all parameters
    params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa}

    #Save config file (the with block guarantees the file is closed)
    if save:
        with open(file_name + '_config.pickle','wb') as file:
            pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},file, protocol = pickle.HIGHEST_PROTOCOL)

    #Pointwise squared error, weighted by the self-adaptative weights w_name when sa is True
    def sq_error(pred,true,params,w_name):
        if sa:
            return MSE_SA(pred,true,params['sa'][w_name])
        return MSE(pred,true)

    #Define loss function: sum of the mean error of each available data set
    @jax.jit
    def lf(params,x):
        #Network as a function of the spatial and time coordinates, as expected by pde and oper_neumann
        def u_net(x,t):
            return forward(jnp.append(x,t,1),params['net'])
        loss = 0
        if x['sensor'] is not None:
            #Term that refers to sensor data
            loss = loss + jnp.mean(sq_error(forward(x['sensor'],params['net']),x['usensor'],params,'ws'))
        if x['boundary'] is not None:
            if neumann:
                #Neumann conditions: the last column is time, the others are spatial coordinates
                xb = x['boundary'][:,:-1]
                tb = x['boundary'][:,-1:]
                if sa:
                    loss = loss + jnp.mean(oper_neumann(u_net,xb,tb,params['sa']['wb']))
                else:
                    loss = loss + jnp.mean(oper_neumann(u_net,xb,tb))
            else:
                #Term that refers to boundary data
                loss = loss + jnp.mean(sq_error(forward(x['boundary'],params['net']),x['uboundary'],params,'wb'))
        if x['initial'] is not None:
            #Term that refers to initial data
            loss = loss + jnp.mean(sq_error(forward(x['initial'],params['net']),x['uinitial'],params,'w0'))
        if x['collocation'] is not None:
            #Term that refers to collocation points: the PDE residual should be zero
            x_col = x['collocation'][:,:-1]
            t_col = x['collocation'][:,-1:]
            if inverse:
                residual = pde(u_net,x_col,t_col,params['inverse'])
            else:
                residual = pde(u_net,x_col,t_col)
            loss = loss + jnp.mean(sq_error(residual,0,params,'wr'))
        return loss

    #Initialize Adam Optmizer
    if exp_decay:
        lr = optax.exponential_decay(lr,transition_steps,decay_rate)
    optimizer = optax.adam(lr,b1,b2,eps,eps_root)
    opt_state = optimizer.init(params)

    #Define the gradient function
    grad_loss = jax.jit(jax.grad(lf,0))

    #Define update function
    @jax.jit
    def update(opt_state,params,x):
        #Compute gradient
        grads = grad_loss(params,x)
        #Invert gradient of self-adaptative weights so the optimizer ascends on them
        if sa:
            grads['sa'] = {w: -g for w,g in grads['sa'].items()}
        #Calculate parameters updates
        updates, opt_state = optimizer.update(grads, opt_state)
        #Update parameters
        params = optax.apply_updates(params, updates)
        #Return state of optmizer and updated parameters
        return opt_state,params

    ###Training###
    t0 = time.time()
    #Initialize alive_bar for tracing in terminal
    with alive_bar(epochs) as bar:
        #For each epoch
        for e in range(epochs):
            #Update optimizer state and parameters
            opt_state,params = update(opt_state,params,data)
            #After epoch_print epochs
            if e % epoch_print == 0:
                #Compute elapsed time and current error
                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6))
                #If there is test data, compute current L2 error
                if test_data is not None:
                    l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u'])
                    l = l + ' L2 error: ' + str(jnp.round(l2_test,6))
                if inverse:
                    #Round the array itself: jax.numpy functions do not accept Python lists
                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'],6))
                #Print
                print(l)
            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
                #Save current parameters
                with open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb') as file:
                    pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},file, protocol = pickle.HIGHEST_PROTOCOL)
            #Update alive_bar
            bar()

    #Define estimated function
    def u(xt):
        return forward(xt,params['net'])

    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0}

Train a Physics-informed Neural Network

Parameters
  • data (dict): Data generated by the jinnax.data.generate_PINNdata function
  • width (list): A list with the width of each layer
  • pde (function): The partial differential operator. Its arguments are u, x and t
  • test_data (dict, None): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error
  • epochs (int): Number of training epochs. Default 100
  • at_each (int): Save results for epochs multiple of at_each. Default 10
  • activation (str): The name of the activation function of the neural network. Default 'tanh'
  • neumann (logical): Whether to consider Neumann boundary conditions
  • oper_neumann (function): Penalization of Neumann boundary conditions
  • sa (logical): Whether to consider self-adaptative PINN
  • c (dict): Dictionary with the hyperparameters of the self-adaptative sigmoid mask for the initial (w0), sensor (ws) and collocation (wr) points. The weights of the boundary points is fixed to 1
  • inverse (logical): Whether to estimate parameters of the PDE
  • initial_par (jax.numpy.array): Initial value of the parameters of the PDE in an inverse problem
  • lr,b1,b2,eps,eps_root (float): Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
  • key (int): Seed for parameters initialization. Default 0
  • epoch_print (int): Number of epochs to calculate and print test errors. Default 100
  • save (logical): Whether to save the current parameters. Default False
  • file_name (str): File prefix to save the current parameters. Default 'result_pinn'
  • exp_decay (logical): Whether to consider exponential decay of learning rate. Default False
  • transition_steps (int): Number of steps for exponential decay. Default 1000
  • decay_rate (float): Rate of exponential decay. Default 0.9
Returns
  • dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
def process_result( test_data, fit, train_data, plot=True, plot_test=True, times=5, d2=True, save=False, show=True, file_name='result_pinn', print_res=True, p=1):
def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1):
    """
    Process the results of a Physics-informed Neural Network

    Parameters
    ----------
    test_data : dict
        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. The entries 'xt' and 'u' are read

    fit : function
        The fitted function

    train_data : dict
        Training data generated by the jinnax.data.generate_PINNdata function

    plot : logical
        Whether to generate plots comparing the exact and estimated solutions. A plot is produced when the spatial dimension and the output dimension are one, or when the output dimension is two. Default True

    plot_test : logical
        Whether to plot the test data (used by the plot for two-dimensional output). Default True

    times : int
        Number of points along the time interval to plot. It is rounded to a multiple of 5, with a minimum of 5. Default 5

    d2 : logical
        Whether to plot 2D plot when the spatial dimension is one. Default True

    save : logical
        Whether to save the plots. Default False

    show : logical
        Whether to show the plots. Default True

    file_name : str
        File prefix to save the plots. Default 'result_pinn'

    print_res : logical
        Whether to print the L2 error. Default True

    p : int
        Output dimension. Default 1

    Returns
    -------
    pandas data frame with L2 and MSE errors (one row; the train errors are NaN when the training data has no labelled points)
    """

    #Spatial dimension (the last column of xt is time)
    d = test_data['xt'].shape[1] - 1

    #Number of plots multiple of 5; at least 5 so the grid of plots always has one row
    times = max(5,5 * round(times/5.0))

    #Test results
    upred_test = fit(test_data['xt'])
    l2_error_test = L2error(upred_test,test_data['u']).tolist()
    MSE_test = jnp.mean(MSE(upred_test,test_data['u'])).tolist()

    #Train results, computed on the stacked sensor, boundary and initial points
    td = get_train_data(train_data)
    if td['x'] is not None:
        upred_train = fit(td['x'])
        l2_error_train = L2error(upred_train,td['y']).tolist()
        MSE_train = jnp.mean(MSE(upred_train,td['y'])).tolist()
    else:
        #No labelled training points (e.g. only collocation points): nothing to compare against
        l2_error_train = float('nan')
        MSE_train = float('nan')

    df = pd.DataFrame(np.array([l2_error_test,MSE_test,l2_error_train,MSE_train]).reshape((1,4)),
        columns=['l2_error_test','MSE_test','l2_error_train','MSE_train'])
    if print_res:
        print('L2 error test: ' + str(jnp.round(l2_error_test,6)) + ' L2 error train: ' + str(jnp.round(l2_error_train,6)) + ' MSE error test: ' + str(jnp.round(MSE_test,6)) + ' MSE error train: ' + str(jnp.round(MSE_train,6)) )

    #Plots
    if d == 1 and p == 1 and plot:
        plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name)
    elif p == 2 and plot:
        #plot_test goes by keyword: the eighth positional parameter of plot_pinn_out2D is the title
        plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test = plot_test)

    return df

Process the results of a Physics-informed Neural Network

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • fit (function): The fitted function
  • train_data (dict): Training data generated by the jinnax.data.generate_PINNdata
  • plot (logical): Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True
  • plot_test (logical): Whether to plot the test data. Default True
  • times (int): Number of points along the time interval to plot. Default 5
  • d2 (logical): Whether to plot 2D plot when the spatial dimension is one. Default True
  • save (logical): Whether to save the plots. Default False
  • show (logical): Whether to show the plots. Default True
  • file_name (str): File prefix to save the plots. Default 'result_pinn'
  • print_res (logical): Whether to print the L2 error. Default True
  • p (int): Output dimension. Default 1
Returns
  • pandas data frame with L2 and MSE errors
def plot_pinn1D( times, xt, u, upred, d2=True, save=False, show=True, file_name='result_pinn', title_1d='', title_2d=''):
def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
    """
    Plot the prediction of a 1D PINN

    Parameters
    ----------
    times : int
        Number of points along the time interval to plot. The slices are arranged in rows of five plots

    xt : jax.numpy.array
        Test data xt array. The last column is time and the first column is the spatial coordinate

    u : jax.numpy.array
        Test data u(x,t) array

    upred : jax.numpy.array
        Predicted upred(x,t) array on test data

    d2 : logical
        Whether to plot 2D plot. Default True

    save : logical
        Whether to save the plots. Default False

    show : logical
        Whether to show the plots. Default True

    file_name : str
        File prefix to save the plots ('_slices.png' and '_2d.png' are appended). Default 'result_pinn'

    title_1d : str
        Title of 1D plot

    title_2d : str
        Title of 2D plot

    Returns
    -------
    None
    """
    #Initialize: enough rows of five plots to hold every time slice (at least one row)
    n_rows = max(1,math.ceil(times/5))
    #squeeze = False keeps ax two-dimensional even with a single row
    fig, ax = plt.subplots(n_rows,5,figsize = (10*n_rows,3*n_rows),squeeze = False)
    t_col = xt[:,-1]
    tlo = jnp.min(t_col)
    tup = jnp.max(t_col)
    ylo = jnp.min(u)
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u)
    yup = yup + 0.1*jnp.abs(yup)
    t_values = np.linspace(tlo,tup,times)

    #Create one plot per time slice
    for k,t in enumerate(t_values):
        panel = ax[k // 5,k % 5]
        #Snap to the closest time that is present in the data
        t = t_col[jnp.argmin(jnp.abs(t_col - t))].tolist()
        at_t = t_col == t
        x_plot = xt[at_t,:-1]
        y_plot = upred[at_t,:]
        u_plot = u[at_t,:]
        panel.plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
        panel.plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
        panel.set_title('$t = %.2f$' % (t),fontsize=10)
        panel.set_xlabel(' ')
        panel.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])

    #Hide the panels of the last row that received no slice
    for k in range(len(t_values),5 * n_rows):
        ax[k // 5,k % 5].set_visible(False)

    #Title
    fig.suptitle(title_1d)
    fig.tight_layout()

    #Show and save
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close(fig)

    #2d plot
    if d2:
        #Initialize
        fig, ax = plt.subplots(1,2)
        l1 = jnp.unique(xt[:,-1]).shape[0]
        l2 = jnp.unique(xt[:,0]).shape[0]
        #The reshape assumes xt is a full grid ordered with time varying fastest
        t_grid = xt[:,-1].reshape((l2,l1))
        x_grid = xt[:,0].reshape((l2,l1))

        #Create
        ax[0].pcolormesh(t_grid,x_grid,u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[0].set_title('Exact')
        ax[1].pcolormesh(t_grid,x_grid,upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[1].set_title('Predicted')

        #Title
        fig.suptitle(title_2d)
        fig.tight_layout()

        #Show and save
        if show:
            plt.show()
        if save:
            fig.savefig(file_name + '_2d.png')
        plt.close(fig)

Plot the prediction of a 1D PINN

Parameters
  • times (int): Number of points along the time interval to plot. Default 5
  • xt (jax.numpy.array): Test data xt array
  • u (jax.numpy.array): Test data u(x,t) array
  • upred (jax.numpy.array): Predicted upred(x,t) array on test data
  • d2 (logical): Whether to plot 2D plot. Default True
  • save (logical): Whether to save the plots. Default False
  • show (logical): Whether to show the plots. Default True
  • file_name (str): File prefix to save the plots. Default 'result_pinn'
  • title_1d (str): Title of 1D plot
  • title_2d (str): Title of 2D plot
Returns
  • None
def plot_pinn_out2D( times, xt, u, upred, save=False, show=True, file_name='result_pinn', title='', plot_test=True):
def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
    """
    Plot the prediction of a PINN with 2D output

    Parameters
    ----------
    times : int
        Number of points along the time interval to plot. The slices are arranged in rows of five plots

    xt : jax.numpy.array
        Test data xt array. The last column is time

    u : jax.numpy.array
        Test data u(x,t) array. Column 0 is plotted on the horizontal axis and column 1 on the vertical axis

    upred : jax.numpy.array
        Predicted upred(x,t) array on test data, with the same two columns as u

    save : logical
        Whether to save the plots. Default False

    show : logical
        Whether to show the plots. Default True

    file_name : str
        File prefix to save the plots ('_slices.png' is appended). Default 'result_pinn'

    title : str
        Title of plot

    plot_test : logical
        Whether to plot the test data. Default True

    Returns
    -------
    None
    """
    #Initialize: enough rows of five plots to hold every time slice (at least one row)
    n_rows = max(1,math.ceil(times/5))
    #squeeze = False keeps ax two-dimensional even with a single row
    fig, ax = plt.subplots(n_rows,5,figsize = (10*n_rows,3*n_rows),squeeze = False)
    t_col = xt[:,-1]
    tlo = jnp.min(t_col)
    tup = jnp.max(t_col)
    #Vertical limits come from the second output of the test data
    ylo = jnp.min(u[:,1])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u[:,1])
    yup = yup + 0.1*jnp.abs(yup)
    t_values = np.linspace(tlo,tup,times)

    #Create one plot per time slice
    for k,t in enumerate(t_values):
        panel = ax[k // 5,k % 5]
        #Snap to the closest time that is present in the data
        t = t_col[jnp.argmin(jnp.abs(t_col - t))].tolist()
        at_t = t_col == t
        if plot_test:
            panel.plot(u[at_t,0],u[at_t,1],'b-',linewidth=2,label='Exact')
        panel.plot(upred[at_t,0],upred[at_t,1],'r-',linewidth=2,label='Prediction')
        panel.set_title('$t = %.2f$' % (t),fontsize=10)
        panel.set_xlabel(' ')
        panel.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])

    #Hide the panels of the last row that received no slice
    for k in range(len(t_values),5 * n_rows):
        ax[k // 5,k % 5].set_visible(False)

    #Title
    fig.suptitle(title)
    fig.tight_layout()

    #Show and save
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close(fig)

Plot the prediction of a PINN with 2D output

Parameters
  • times (int): Number of points along the time interval to plot. Default 5
  • xt (jax.numpy.array): Test data xt array
  • u (jax.numpy.array): Test data u(x,t) array
  • upred (jax.numpy.array): Predicted upred(x,t) array on test data
  • save (logical): Whether to save the plots. Default False
  • show (logical): Whether to show the plots. Default True
  • file_name (str): File prefix to save the plots. Default 'result_pinn'
  • title (str): Title of plot
  • plot_test (logical): Whether to plot the test data. Default True
Returns
  • None
def get_train_data(train_data):
def get_train_data(train_data):
    """
    Process training sample

    Parameters
    ----------
    train_data : dict
        A dictionary with train data generated by the jinnax.data.generate_PINNdata function. The sets 'sensor', 'boundary' and 'initial' (with values 'usensor', 'uboundary' and 'uinitial') and 'collocation' are read; a set that is None or absent counts as empty

    Returns
    -------
    dict with the processed training data: the points ('x'), values ('y') and points with values side by side ('xy') of the sensor, boundary and initial sets stacked in that order (None when all three are empty), and the number of points of each set ('sensor_sample', 'boundary_sample', 'initial_sample', 'collocation_sample')
    """
    xdata = None
    ydata = None
    xydata = None
    sample = {}
    #Stack the labelled sets in a fixed order: sensor, boundary, initial
    for name in ('sensor','boundary','initial'):
        points = train_data.get(name)
        if points is None:
            sample[name] = 0
            continue
        values = train_data['u' + name]
        sample[name] = points.shape[0]
        points_values = jnp.column_stack((points,values))
        if xdata is None:
            #First non-empty set starts the stack
            xdata,ydata,xydata = points,values,points_values
        else:
            xdata = jnp.vstack((xdata,points))
            ydata = jnp.vstack((ydata,values))
            xydata = jnp.vstack((xydata,points_values))

    #Collocation points have no values, so only their count is reported
    collocation = train_data.get('collocation')
    collocation_sample = 0 if collocation is None else collocation.shape[0]

    return {'xy': xydata,'x': xdata,'y': ydata,'sensor_sample': sample['sensor'],'boundary_sample': sample['boundary'],'initial_sample': sample['initial'],'collocation_sample': collocation_sample}
Process training sample
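Parameters
  • train_data (dict): A dictionary with train data generated by the jinnax.data.generate_PINNdata function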
Returns
  • dict with the processed training data
def process_training( test_data, file_name, at_each=100, bolstering=True, mc_sample=10000, save=False, file_name_save='result_pinn', key=0, ec=1e-06, lamb=1):
def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
    """
    Process the training of a Physics-informed Neural Network

    Parameters
    ----------
    test_data : dict
        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str
        Name of the files saved during training. The files file_name + '_config.pickle' and file_name + '_epoch' + epoch number + '.pickle' are read

    at_each : int
        Compute results for epochs multiple of at_each. It should match the at_each used in training, since one saved file per processed epoch is read. Default 100

    bolstering : logical
        Whether to compute bolstering mean square error. Default True

    mc_sample : int
        Number of sample for Monte Carlo integration in bolstering. Default 10000

    save : logical
        Whether to save the training results. Default False

    file_name_save : str
        File prefix to save the training results as a csv file. Default 'result_pinn'

    key : int
        Key for random samples in bolstering. Default 0

    ec : float
        Stopping criteria error for EM algorithm in bolstering. Default 1e-6

    lamb : float
        Hyperparameter of EM algorithm in bolstering. Default 1

    Returns
    -------
    pandas data frame with training results (one row per processed epoch)
    """
    #Config
    #NOTE: unpickling can execute arbitrary code; only read files written by a trusted training run
    with open(file_name + '_config.pickle','rb') as file:
        config = pickle.load(file)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xydata = td['xy']
    xdata = td['x']
    ydata = td['y']
    sensor_sample = td['sensor_sample']
    boundary_sample = td['boundary_sample']
    initial_sample = td['initial_sample']
    collocation_sample = td['collocation_sample']

    #Generate keys: one row per epoch
    if bolstering:
        keys = jax.random.split(jax.random.PRNGKey(key),epochs)

    #Initialize results (one entry per processed epoch)
    train_mse = []
    test_mse = []
    train_L2 = []
    test_L2 = []
    bolstX = []
    bolstXY = []
    loss = []
    train_time = []
    ep = []

    #Process training
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
                ep.append(e)

                #Read parameters
                with open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb') as file:
                    params = pickle.load(file)

                #Time
                train_time.append(params['time'])

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Train MSE and L2
                if xdata is not None:
                    train_mse.append(jnp.mean(MSE(psi(xdata),ydata)).tolist())
                    train_L2.append(L2error(psi(xdata),ydata).tolist())
                else:
                    train_mse.append(None)
                    train_L2.append(None)

                #Test MSE and L2
                test_mse.append(jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist())
                test_L2.append(L2error(psi(test_data['xt']),test_data['u']).tolist())

                #Bolstering
                if bolstering:
                    bX = []
                    bXY = []
                    for method in ['chi','mm','mpe']:
                        kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        bX.append(gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist())
                        bXY.append(gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist())
                    #NOTE(review): the hessian kernel is estimated on xydata but stored with the X results - confirm this is intended
                    for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]:
                        kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
                        bX.append(gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist())
                    #bX holds 8 values (3 methods + 5 hessian biases) and bXY holds 3
                    bolstX.append(bX)
                    bolstXY.append(bXY)
                else:
                    bolstX.append(None)
                    bolstXY.append(None)

                #Loss
                loss.append(params['loss'].tolist())

                #Delete
                del params, psi
            #Update alive_bar
            bar()

    #Bolstering results
    if bolstering:
        bolstX = jnp.array(bolstX)
        bolstXY = jnp.array(bolstXY)

    #Create data frame
    if bolstering:
        df = pd.DataFrame(np.column_stack([ep,train_time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
            train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]),
            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4'])
    else:
        df = pd.DataFrame(np.column_stack([ep,train_time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
            train_mse,test_mse,train_L2,test_L2]),
            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2'])
    if save:
        df.to_csv(file_name_save + '.csv',index = False)

    return df

Process the training of a Physics-informed Neural Network

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • at_each (int): Compute results for epochs multiple of at_each. Default 100
  • bolstering (logical): Whether to compute bolstering mean square error. Default True
  • mc_sample (int): Number of sample for Monte Carlo integration in bolstering. Default 10000
  • save (logical): Whether to save the training results. Default False
  • file_name_save (str): File prefix to save the plots and the L2 error. Default 'result_pinn'
  • key (int): Key for random samples in bolstering. Default 0
  • ec (float): Stopping criteria error for EM algorithm in bolstering. Default 1e-6
  • lamb (float): Hyperparameter of EM algorithm in bolstering. Default 1
Returns
  • pandas data frame with training results
def demo_train_pinn1D( test_data, file_name, at_each=100, times=5, d2=True, file_name_save='result_pinn_demo', title='', framerate=10):
 972def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
 973    """
 974    Demo video with the training of a 1D PINN
 975    ----------
 976
 977    Parameters
 978    ----------
 979    test_data : dict
 980
 981        A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
 982
 983    file_name : str
 984
 985        Name of the files saved during training
 986
 987    at_each : int
 988
 989        Compute results for epochs multiple of at_each. Default 100
 990
 991    times : int
 992
 993        Number of points along the time interval to plot. Default 5
 994
 995    d2 : logical
 996
 997        Whether to make video demo of 2D plot. Default True
 998
 999    file_name_save : str
1000
1001        File prefix to save the plots and videos. Default 'result_pinn_demo'
1002
1003    title : str
1004
1005        Title for plots
1006
1007    framerate : int
1008
1009        Framerate for video. Default 10
1010
1011    Returns
1012    -------
1013    None
1014    """
1015    #Config
1016    with open(file_name + '_config.pickle', 'rb') as file:
1017        config = pickle.load(file)
1018    epochs = config['epochs']
1019    train_data = config['train_data']
1020    forward = config['forward']
1021
1022    #Get train data
1023    td = get_train_data(train_data)
1024    xt = td['x']
1025    u = td['y']
1026
1027    #Create folder to save plots
1028    os.system('mkdir ' + file_name_save)
1029
1030    #Create images
1031    k = 1
1032    with alive_bar(epochs) as bar:
1033        for e in range(epochs):
1034            if e % at_each == 0 or e == epochs - 1:
1035                #Read parameters
1036                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
1037
1038                #Define learned function
1039                def psi(x):
1040                    return forward(x,params['params']['net'])
1041
1042                #Compute L2 train, L2 test and loss
1043                loss = params['loss']
1044                L2_train = L2error(psi(xt),u)
1045                L2_test = L2error(psi(test_data['xt']),test_data['u'])
1046                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(L2_train,6)) + ' L2 test = ' + str(round(L2_test,6))
1047
1048                #Save plot
1049                plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
1050                k = k + 1
1051
1052                #Delete
1053                del params, psi, loss, L2_train, L2_test, title_epoch
1054            #Update alive_bar
1055            bar()
1056    #Create demo video
1057    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4')
1058    if d2:
1059        os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4')

Demo video with the training of a 1D PINN

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • at_each (int): Compute results for epochs multiple of at_each. Default 100
  • times (int): Number of points along the time interval to plot. Default 5
  • d2 (logical): Whether to make video demo of 2D plot. Default True
  • file_name_save (str): File prefix to save the plots and videos. Default 'result_pinn_demo'
  • title (str): Title for plots
  • framerate (int): Framerate for video. Default 10
Returns
  • None
def demo_time_pinn1D( test_data, file_name, epochs, file_name_save='result_pinn_time_demo', title='', framerate=10):
def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
    """
    Demo video with the time evolution of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot. Must not be empty. One panel is drawn per epoch, two panels per row

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If epochs is empty
    """
    #Fail early, before touching the file system, if there is nothing to plot
    if len(epochs) == 0:
        raise ValueError('epochs must contain at least one epoch to plot')

    #Config
    #NOTE: the config file is unpickled, so it must come from a trusted training run
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    forward = config['forward']

    #Create folder to save plots (no shell involved, and no error if it already exists)
    os.makedirs(file_name_save,exist_ok = True)

    #Plot parameters
    tcol = test_data['xt'][:,-1]
    tdom = jnp.unique(tcol)
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)
    ylim = [1.3 * ylo.tolist(),1.3 * yup.tolist()]

    #Open PINN for each epoch and predict on the whole test grid
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        upred.append(forward(test_data['xt'],tmp['params']['net']))

    #Grid of panels: two columns, with enough rows for every epoch (an odd count leaves one spare panel)
    ncols = 2 if len(epochs) > 1 else 1
    nrows = (len(epochs) + ncols - 1)//ncols

    #Create images
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data at this time
            at_t = tcol == t
            xt_step = test_data['xt'][at_t]
            u_step = test_data['u'][at_t]

            #Initialize plot; squeeze = False always returns a 2D array of axes, so one loop covers every layout
            fig, ax = plt.subplots(nrows,ncols,figsize = (10,5*nrows),squeeze = False)
            panels = ax.flatten()

            #Create one panel per epoch, filled row by row
            for panel, e, pred in zip(panels,epochs,upred):
                upred_step = pred[at_t]
                panel.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                panel.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                panel.set_title('Epoch = ' + str(e),fontsize=10)
                panel.set_xlabel(' ')
                panel.set_ylim(ylim)

            #Hide the spare panel when the number of epochs is odd
            for panel in panels[len(epochs):]:
                panel.set_visible(False)

            #Title
            fig.suptitle(title + 't = ' + str(round(t,4)))
            fig.tight_layout()

            #Save (frames are numbered consecutively so ffmpeg can read them as a sequence)
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close(fig)
            bar()

    #Create demo video
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')

Demo video with the time evolution of a 1D PINN

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • epochs (list): Which training epochs to plot
  • file_name_save (str): File prefix to save the plots and video. Default 'result_pinn_time_demo'
  • title (str): Title for plots
  • framerate (int): Framerate for video. Default 10
Returns
  • None
def demo_time_pinn2D( test_data, file_name, epochs, file_name_save='result_pinn_time_demo', title='', framerate=10, ffmpeg='ffmpeg'):
def demo_time_pinn2D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10,ffmpeg = 'ffmpeg'):
    """
    Demo video with the time evolution of a 2D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot. Must not be empty. One panel is drawn per epoch, two panels per row

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to ffmpeg

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If epochs is empty
    """
    #Fail early, before touching the file system, if there is nothing to plot
    if len(epochs) == 0:
        raise ValueError('epochs must contain at least one epoch to plot')

    #Config
    #NOTE: the config file is unpickled, so it must come from a trusted training run
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    forward = config['forward']

    #Create folder to save plots (no shell involved, and no error if it already exists)
    os.makedirs(file_name_save,exist_ok = True)

    #Plot parameters
    tcol = test_data['xt'][:,-1]
    tdom = jnp.unique(tcol)
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)
    ylim = [1.3 * ylo.tolist(),1.3 * yup.tolist()]

    #Open PINN for each epoch and predict on the whole test grid
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        upred.append(forward(test_data['xt'],tmp['params']['net']))

    #Grid of panels: two columns, with enough rows for every epoch (an odd count leaves one spare panel)
    ncols = 2 if len(epochs) > 1 else 1
    nrows = (len(epochs) + ncols - 1)//ncols

    #Create images
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data at this time: the two output coordinates are plotted against each other
            at_t = tcol == t
            ux_step = test_data['u'][at_t,0]
            uy_step = test_data['u'][at_t,1]

            #Initialize plot; squeeze = False always returns a 2D array of axes, so one loop covers every layout
            fig, ax = plt.subplots(nrows,ncols,figsize = (10,5*nrows),squeeze = False)
            panels = ax.flatten()

            #Create one panel per epoch, filled row by row
            for panel, e, pred in zip(panels,epochs,upred):
                upredx_step = pred[at_t,0]
                upredy_step = pred[at_t,1]
                panel.plot(ux_step,uy_step,'b-',linewidth=2,label='Exact')
                panel.plot(upredx_step,upredy_step,'r-',linewidth=2,label='Prediction')
                panel.set_title('Epoch = ' + str(e),fontsize=10)
                panel.set_xlabel(' ')
                panel.set_ylim(ylim)

            #Hide the spare panel when the number of epochs is odd
            for panel in panels[len(epochs):]:
                panel.set_visible(False)

            #Title
            fig.suptitle(title + 't = ' + str(round(t,4)))
            fig.tight_layout()

            #Save (frames are numbered consecutively so ffmpeg can read them as a sequence)
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close(fig)
            bar()

    #Create demo video
    os.system(ffmpeg + ' -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')

Demo video with the time evolution of a 2D PINN

Parameters
  • test_data (dict): A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function
  • file_name (str): Name of the files saved during training
  • epochs (list): Which training epochs to plot
  • file_name_save (str): File prefix to save the plots and video. Default 'result_pinn_time_demo'
  • title (str): Title for plots
  • framerate (int): Framerate for video. Default 10
  • ffmpeg (str): Path to ffmpeg
Returns
  • None
def DN_CSF_circle( uinitial, xl, xu, tl, tu, width, radius, Ntb=100, N0=100, Nc=50, Ntc=50, Ns=100, Nts=100, epochs=100, at_each=10, activation='tanh', sa=True, lr=0.001, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, key=0, epoch_print=100, save=False, file_name='result_pinn', exp_decay=False, transition_steps=1000, decay_rate=0.9, demo=True, framerate=2, ffmpeg='ffmpeg', c=1e-06):
def DN_CSF_circle(uinitial,xl,xu,tl,tu,width,radius,Ntb = 100,N0 = 100,Nc = 50,Ntc = 50,Ns = 100,Nts = 100,epochs = 100,at_each = 10,activation = 'tanh',sa = True,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,demo = True,framerate = 2,ffmpeg = 'ffmpeg',c = 1e-6):
    """
    Train a PINN for a planar curve u(x,t) = (u1,u2) with Dirichlet and Neumann boundary conditions on a circle
    ----------

    The PDE residual is ut - uxx/(ux1^2 + ux2^2 + c) in each coordinate (presumably the curve-shortening flow, as the function name suggests).
    At x = xu the curve is fixed at the value of uinitial (Dirichlet). At x = xl the curve is constrained to the circle of the given radius
    centered at the origin (Dirichlet) and its normal is constrained to be orthogonal to the normal of the circle (Neumann).

    Parameters
    ----------
    uinitial : function

        Function of (x,t) returning a dictionary with keys 'u1' and 'u2', the two coordinates of the initial/boundary values

    xl, xu : float

        Lower and upper bounds of the x domain, forwarded to jinnax.data.generate_PINNdata

    tl, tu : float

        Lower and upper bounds of the time domain, forwarded to jinnax.data.generate_PINNdata

    width : list

        Network width, forwarded to train_PINN

    radius : float

        Radius of the circle on which the left boundary is constrained

    Ntb, N0, Nc, Ntc : int

        Sample sizes forwarded to jinnax.data.generate_PINNdata for the training data. The test data use twice each value

    Ns, Nts : int

        Sample sizes forwarded to jinnax.data.generate_PINNdata for the data used in the demo video

    epochs, at_each, activation, sa, lr, b1, b2, eps, eps_root, key, epoch_print, save, file_name, exp_decay, transition_steps, decay_rate

        Training options forwarded unchanged to train_PINN. save is forced to True when demo is True

    demo : logical

        Whether to create a demo video with demo_time_pinn2D for the last epoch. Default True

    framerate : int

        Framerate for video. Default 2

    ffmpeg : str

        Path to ffmpeg

    c : float

        Constant added to the denominator of the PDE residual to avoid division by zero. Default 1e-6

    Returns
    -------
    tuple with the object returned by train_PINN and a pandas data frame with the mean squared residuals on test data.
    The data frame is also written to file_name + '_residuals.csv'
    """
    #If demo, then save
    if demo:
        save = True

    #Define initial function and function to evaluate at boundary
    def uinit(x,t):
        u = uinitial(x,t)
        return jnp.append(u['u1'],u['u2'])

    def ubound(x,t):
        u = uinitial(x,t)
        return jnp.append(u['u1'],u['u2'],1)


    #PDE operator
    def pde(u,x,t):
        #One function for each coordinate (assuming that x and t has dimension 1 x 1 and u(x,t) has dimension 1 x 2)
        u1 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,0][0]
        u2 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,1][0]
        #First derivatives of each coordinate
        ux1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t))
        ux2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t))
        ut1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),1)(x,t))
        ut2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),1)(x,t))
        #Second derivative of each coordinate
        ux1_tmp = lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t)
        ux2_tmp = lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t)
        uxx1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : ux1_tmp(x,t)[0],0)(x,t))
        uxx2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : ux2_tmp(x,t)[0],0)(x,t))
        #Return the norm of the residual of the two coordinates (c keeps the denominator away from zero)
        return jnp.sqrt((ut1(x,t) - uxx1(x,t)/(ux1(x,t) ** 2 + ux2(x,t) ** 2 + c)) ** 2 + (ut2(x,t) - uxx2(x,t)/(ux1(x,t) ** 2 + ux2(x,t) ** 2 + c)) ** 2)

    #Operator to evaluate boundary conditions
    #The default Ntb is bound here, at definition time, to the training sample size
    def oper_boundary(u,x,t,w = 1,Ntb = Ntb):
        #Enforce Dirichlet at the right boundary (fixed at point a, as the initial condition)
        res_right_dir = jnp.sum(jnp.where(x == xu,(u(x,t) - ubound(x,t)) ** 2,0),1).reshape(x.shape[0],1)
        #Enforce Dirichlet at the left boundary (is in the circle of radius fixed)
        res_left_dir = jnp.sum(jnp.where(x == xl,((jnp.sum(u(x,t) ** 2,1) - radius ** 2) ** 2).reshape(x.shape),0),1).reshape(x.shape[0],1)
        #One function for each coordinate (assuming that x and t has dimension 1 x 1 and u(x,t) has dimension 1 x 2)
        u1 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,0][0]
        u2 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,1][0]
        #Take the derivatives in x
        ux1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t))(x,t)
        ux2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t))(x,t)
        #Enforce Neumann at the left boundary
        #Assuming that u(x,t) \in S, the vector normal to S at u(x,t) is u(x,t) divided by its own norm,
        #so the norm is taken per point (axis 1), not across the batch of points
        nS = u(x,t)/jnp.sqrt(jnp.sum(u(x,t) ** 2,1,keepdims = True))
        nu = jnp.append(ux2,(-1)*ux1,1)/jnp.sqrt(ux1 ** 2 + ux2 ** 2)
        ip = jnp.sum(nS * nu,1).reshape(x.shape[0],1) ** 2
        res_left_neu = jnp.where(x == xl,ip,0)
        #Rearrange: first Ntb rows right Dirichlet, next Ntb rows left Dirichlet, remaining rows left Neumann
        res = jnp.append(jnp.append(res_right_dir[:Ntb,:],res_left_dir[Ntb:2*Ntb,:],0),res_left_neu[2*Ntb:,:],0)
        return w*res

    #Generate Data
    train_data = jd.generate_PINNdata(u = uinit,xl = xl,xu = xu,tl = tl,tu = tu,Ns = None,Nts = None,Nb = 2,Ntb = Ntb,N0 = N0,Nc = Nc,Ntc = Ntc,p = 2,poss = 'random',posts = 'random',pos0 = 'random',postb = 'random',posc = 'random',postc = 'random')

    #Rearange boundary data: rows from Ntb onwards first, then the first Ntb rows twice (once for each condition on that boundary)
    train_data['boundary'] = jnp.append(jnp.append(train_data['boundary'][Ntb:,:],train_data['boundary'][:Ntb,:],0),train_data['boundary'][:Ntb,:],0)
    train_data['uboundary'] = jnp.append(jnp.append(train_data['uboundary'][Ntb:,:],train_data['uboundary'][:Ntb,:],0),train_data['uboundary'][:Ntb,:],0)

    #Train PINN
    fit = train_PINN(train_data,width,pde,c = {'ws': 1,'wr': 1,'w0': 1,'wb': 1},test_data = None,epochs = epochs,at_each = at_each,activation = activation,neumann = True,oper_neumann = oper_boundary,sa = sa,lr = lr,b1 = b1,b2 = b2,eps = eps,eps_root = eps_root,key = key,epoch_print = epoch_print,save = save,file_name = file_name,exp_decay = exp_decay,transition_steps = transition_steps,decay_rate = decay_rate)

    #Test data (twice the training sample sizes)
    Ntb_test = 2*Ntb
    test_data = jd.generate_PINNdata(u = uinit,xl = xl,xu = xu,tl = tl,tu = tu,Ns = None,Nts = None,Nb = 2,Ntb = Ntb_test,N0 = 2*N0,Nc = 2*Nc,Ntc = 2*Ntc,p = 2,poss = 'random',posts = 'random',pos0 = 'random',postb = 'random',posc = 'random',postc = 'random')
    test_data['boundary'] = jnp.append(jnp.append(test_data['boundary'][Ntb_test:,:],test_data['boundary'][:Ntb_test,:],0),test_data['boundary'][:Ntb_test,:],0)
    test_data['uboundary'] = jnp.append(jnp.append(test_data['uboundary'][Ntb_test:,:],test_data['uboundary'][:Ntb_test,:],0),test_data['uboundary'][:Ntb_test,:],0)

    #Evaluate residuals
    def u(x,t):
        return fit['u'](jnp.append(x,t,1))

    res_pde = jnp.mean(pde(u,test_data['collocation'][:,0].reshape((test_data['collocation'].shape[0],1)),test_data['collocation'][:,1].reshape((test_data['collocation'].shape[0],1))) ** 2)
    res_DN = oper_boundary(u,test_data['boundary'][:,0].reshape((test_data['boundary'].shape[0],1)),test_data['boundary'][:,1].reshape((test_data['boundary'].shape[0],1)),Ntb = Ntb_test)
    res_dir_right = jnp.mean(res_DN[:Ntb_test,:] ** 2)
    res_neu = jnp.mean(res_DN[2*Ntb_test:,:] ** 2)
    res_dir_left = jnp.mean(res_DN[Ntb_test:2*Ntb_test,:] ** 2)
    res_initial = jnp.mean((u(test_data['initial'][:,0].reshape((test_data['initial'].shape[0],1)),test_data['initial'][:,1].reshape((test_data['initial'].shape[0],1))) - test_data['uinitial']) ** 2)

    #Save file (every residual is converted to a plain Python number before it goes into the data frame)
    res_data = pd.DataFrame({'PDE': [res_pde.tolist()],
                             'Dirichlet_Right': [res_dir_right.tolist()],
                             'Dirichlet_Left': [res_dir_left.tolist()],
                             'Neumann': [res_neu.tolist()],
                             'initial': [res_initial.tolist()],
                             'time': fit['time'],
                             'epochs': epochs})
    res_data.to_csv(file_name + '_residuals.csv')

    if demo:
        #Reference curve for the video: the circle of the given radius parametrized over [xl,xu]
        def ucircle(x,t):
            y = 2*jnp.pi*(x - xl)/(xu - xl)
            return jnp.append(radius*jnp.sin(y),radius*jnp.cos(y),0)
        test_data = jd.generate_PINNdata(u = ucircle,xl = xl,xu = xu,tl = tl,tu = tu,Ns = Ns,Nts = Nts,Nb = 0,Ntb = 0,N0 = 0,Nc = 0,Ntc = 0,p = 2,train = False)
        demo_time_pinn2D(test_data,file_name,[epochs-1],file_name_save = file_name + '_demo',title = '',framerate = framerate,ffmpeg = ffmpeg)

    return fit,res_data