jinnax.nn
1#Functions to train NN 2import jax 3import jax.numpy as jnp 4import optax 5from alive_progress import alive_bar 6import math 7import time 8import pandas as pd 9import numpy as np 10import matplotlib.pyplot as plt 11import dill as pickle 12from genree import bolstering as gb 13from genree import kernel as gk 14from jax import random 15from jinnax import data as jd 16import os 17 18__docformat__ = "numpy" 19 20#MSE 21@jax.jit 22def MSE(pred,true): 23 """ 24 Mean square error 25 ---------- 26 27 Parameters 28 ---------- 29 pred : jax.numpy.array 30 31 A JAX numpy array with the predicted values 32 33 true : jax.numpy.array 34 35 A JAX numpy array with the true values 36 37 Returns 38 ------- 39 mean square error 40 """ 41 return (true - pred) ** 2 42 43#MSE self-adaptative 44@jax.jit 45def MSE_SA(pred,true,w): 46 """ 47 Selft-adaptative mean square error 48 ---------- 49 50 Parameters 51 ---------- 52 pred : jax.numpy.array 53 54 A JAX numpy array with the predicted values 55 56 true : jax.numpy.array 57 58 A JAX numpy array with the true values 59 60 wheight : jax.numpy.array 61 62 A JAX numpy array with the weights 63 64 c : float 65 66 Hyperparameter 67 68 Returns 69 ------- 70 self-adaptative mean square error with sigmoid mask 71 """ 72 return (w * (true - pred)) ** 2 73 74#L2 error 75@jax.jit 76def L2error(pred,true): 77 """ 78 L2-error 79 ---------- 80 81 Parameters 82 ---------- 83 pred : jax.numpy.array 84 85 A JAX numpy array with the predicted values 86 87 true : jax.numpy.array 88 89 A JAX numpy array with the true values 90 91 Returns 92 ------- 93 L2-error 94 """ 95 return jnp.sqrt(jnp.sum((true - pred)**2))/jnp.sqrt(jnp.sum(true ** 2)) 96 97#Simple fully connected architecture. Return the initial parameters and the function for the forward pass 98def fconNN(width,activation = jax.nn.tanh,key = 0): 99 """ 100 Initialize fully connected neural network 101 ---------- 102 103 Parameters 104 ---------- 105 width : list 106 107 List with the layers width 108 109 activation : jax.nn activation 110 111 The activation function. Default jax.nn.tanh 112 113 key : int 114 115 Seed for parameters initialization. Default 0 116 117 Returns 118 ------- 119 dict with initial parameters and the function for the forward pass 120 """ 121 #Initialize parameters with Glorot initialization 122 initializer = jax.nn.initializers.glorot_normal() 123 key = jax.random.split(jax.random.PRNGKey(key),len(width)-1) #Seed for initialization 124 params = list() 125 for key,lin,lout in zip(key,width[:-1],width[1:]): 126 W = initializer(key,(lin,lout),jnp.float32) 127 B = initializer(key,(1,lout),jnp.float32) 128 params.append({'W':W,'B':B}) 129 130 #Define function for forward pass 131 @jax.jit 132 def forward(x,params): 133 *hidden,output = params 134 for layer in hidden: 135 x = activation(x @ layer['W'] + layer['B']) 136 return x @ output['W'] + output['B'] 137 138 #Return initial parameters and forward function 139 return {'params': params,'forward': forward} 140 141#Get activation from string 142def get_activation(act): 143 """ 144 Return activation function from string 145 ---------- 146 147 Parameters 148 ---------- 149 act : str 150 151 Name of the activation function. 
Default 'tanh' 152 153 Returns 154 ------- 155 jax.nn activation function 156 """ 157 if act == 'tanh': 158 return jax.nn.tanh 159 elif act == 'relu': 160 return jax.nn.relu 161 elif act == 'relu6': 162 return jax.nn.relu6 163 elif act == 'sigmoid': 164 return jax.nn.sigmoid 165 elif act == 'softplus': 166 return jax.nn.softplus 167 elif act == 'sparse_plus': 168 return jx.nn.sparse_plus 169 elif act == 'soft_sign': 170 return jax.nn.soft_sign 171 elif act == 'silu': 172 return jax.nn.silu 173 elif act == 'swish': 174 return jax.nn.swish 175 elif act == 'log_sigmoid': 176 return jax.nn.log_sigmoid 177 elif act == 'leaky_relu': 178 return jax.xx.leaky_relu 179 elif act == 'hard_sigmoid': 180 return jax.nn.hard_sigmoid 181 elif act == 'hard_silu': 182 return jax.nn.hard_silu 183 elif act == 'hard_swish': 184 return jax.nn.hard_swish 185 elif act == 'hard_tanh': 186 return jax.nn.hard_tanh 187 elif act == 'elu': 188 return jax.nn.elu 189 elif act == 'celu': 190 return jax.nn.celu 191 elif act == 'selu': 192 return jax.nn.selu 193 elif act == 'gelu': 194 return jax.nn.gelu 195 elif act == 'glu': 196 return jax.nn.glu 197 elif act == 'squareplus': 198 return jax.nn.squareplus 199 elif act == 'mish': 200 return jax.nn.mish 201 202#Training PINN 203def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9): 204 """ 205 Train a Physics-informed Neural Network 206 ---------- 207 208 Parameters 209 ---------- 210 data : dict 211 212 Data generated by the jinnax.data.generate_PINNdata function 213 214 width : list 215 216 A list with the width of each layer 217 218 pde : function 219 220 The partial differential operator. Its arguments are u, x and t 221 222 test_data : dict, None 223 224 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating L2 error 225 226 epochs : int 227 228 Number of training epochs. Default 100 229 230 at_each : int 231 232 Save results for epochs multiple of at_each. Default 10 233 234 activation : str 235 236 The name of the activation function of the neural network. Default 'tanh' 237 238 neumann : logical 239 240 Whether to consider Neumann boundary conditions 241 242 oper_neumann : function 243 244 Penalization of Neumann boundary conditions 245 246 sa : logical 247 248 Whether to consider self-adaptative PINN 249 250 c : dict 251 252 Dictionary with the hyperparameters of the self-adaptative sigmoid mask for the initial (w0), sensor (ws) and collocation (wr) points. The weights of the boundary points is fixed to 1 253 254 inverse : logical 255 256 Whether to estimate parameters of the PDE 257 258 initial_par : jax.numpy.array 259 260 Initial value of the parameters of the PDE in an inverse problem 261 262 lr,b1,b2,eps,eps_root: float 263 264 Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0 265 266 key : int 267 268 Seed for parameters initialization. Default 0 269 270 epoch_print : int 271 272 Number of epochs to calculate and print test errors. Default 100 273 274 save : logical 275 276 Whether to save the current parameters. 
Default False 277 278 file_name : str 279 280 File prefix to save the current parameters. Default 'result_pinn' 281 282 exp_decay : logical 283 284 Whether to consider exponential decay of learning rate. Default False 285 286 transition_steps : int 287 288 Number of steps for exponential decay. Default 1000 289 290 decay_rate : float 291 292 Rate of exponential decay. Default 0.9 293 294 Returns 295 ------- 296 dict-like object with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time 297 """ 298 299 #Initialize architecture 300 nnet = fconNN(width,get_activation(activation),key) 301 forward = nnet['forward'] 302 303 #Initialize self adaptative weights 304 par_sa = {} 305 if sa: 306 #Initialize wheights close to zero 307 ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000) 308 if data['sensor'] is not None: 309 par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))}) 310 if data['initial'] is not None: 311 par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))}) 312 if data['collocation'] is not None: 313 par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))}) 314 if data['boundary'] is not None: 315 par_sa.update({'wb': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))}) 316 317 #Store all parameters 318 params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa} 319 320 #Save config file 321 if save: 322 pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 323 324 #Define loss function 325 if sa: 326 #Define loss function 327 @jax.jit 328 def lf(params,x): 329 loss = 0 330 if x['sensor'] is not None: 331 #Term that refers to sensor data 332 loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws'])) 333 if x['boundary'] is not None: 334 if neumann: 335 #Neumann coditions 336 xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1)) 337 tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1)) 338 loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb'])) 339 else: 340 #Term that refers to boundary data 341 loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb'])) 342 if x['initial'] is not None: 343 #Term that refers to initial data 344 loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0'])) 345 if x['collocation'] is not None: 346 #Term that refers to collocation points 347 x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1)) 348 t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1)) 349 if inverse: 350 loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr'])) 351 else: 352 loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr'])) 353 return loss 354 else: 355 @jax.jit 356 
def lf(params,x): 357 loss = 0 358 if x['sensor'] is not None: 359 #Term that refers to sensor data 360 loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor'])) 361 if x['boundary'] is not None: 362 if neumann: 363 #Neumann coditions 364 xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1)) 365 tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1)) 366 loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb)) 367 else: 368 #Term that refers to boundary data 369 loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary'])) 370 if x['initial'] is not None: 371 #Term that refers to initial data 372 loss = loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial'])) 373 if x['collocation'] is not None: 374 #Term that refers to collocation points 375 x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1)) 376 t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1)) 377 if inverse: 378 loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0)) 379 else: 380 loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0)) 381 return loss 382 383 #Initialize Adam Optmizer 384 if exp_decay: 385 lr = optax.exponential_decay(lr,transition_steps,decay_rate) 386 optimizer = optax.adam(lr,b1,b2,eps,eps_root) 387 opt_state = optimizer.init(params) 388 389 #Define the gradient function 390 grad_loss = jax.jit(jax.grad(lf,0)) 391 392 #Define update function 393 @jax.jit 394 def update(opt_state,params,x): 395 #Compute gradient 396 grads = grad_loss(params,x) 397 #Invert gradient of self-adaptative wheights 398 if sa: 399 for w in grads['sa']: 400 grads['sa'][w] = - grads['sa'][w] 401 #Calculate parameters updates 402 updates, opt_state = optimizer.update(grads, opt_state) 403 #Update parameters 404 params = optax.apply_updates(params, updates) 405 #Return state of optmizer and updated parameters 406 return opt_state,params 407 408 ###Training### 409 t0 = time.time() 410 #Initialize alive_bar for tracing in terminal 411 with alive_bar(epochs) as bar: 412 #For each epoch 413 for e in range(epochs): 414 #Update optimizer state and parameters 415 opt_state,params = update(opt_state,params,data) 416 #After epoch_print epochs 417 if e % epoch_print == 0: 418 #Compute elapsed time and current error 419 l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6)) 420 #If there is test data, compute current L2 error 421 if test_data is not None: 422 #Compute L2 error 423 l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist() 424 l = l + ' L2 error: ' + str(jnp.round(l2_test,6)) 425 if inverse: 426 l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6)) 427 #Print 428 print(l) 429 if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save: 430 #Save current parameters 431 pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL) 432 #Update alive_bar 433 bar() 434 #Define estimated function 435 def u(xt): 436 return forward(xt,params['net']) 437 438 return {'u': u,'params': params,'forward': forward,'time': time.time() - t0} 439 440#Process result 441def process_result(test_data,fit,train_data,plot = True,plot_test 
= True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1): 442 """ 443 Process the results of a Physics-informed Neural Network 444 ---------- 445 446 Parameters 447 ---------- 448 test_data : dict 449 450 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 451 452 fit : function 453 454 The fitted function 455 456 train_data : dict 457 458 Training data generated by the jinnax.data.generate_PINNdata 459 460 plot : logical 461 462 Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True 463 464 plot_test : logical 465 466 Whether to plot the test data. Default True 467 468 times : int 469 470 Number of points along the time interval to plot. Default 5 471 472 d2 : logical 473 474 Whether to plot 2D plot when the spatial dimension is one. Default True 475 476 save : logical 477 478 Whether to save the plots. Default False 479 480 show : logical 481 482 Whether to show the plots. Default True 483 484 file_name : str 485 486 File prefix to save the plots. Default 'result_pinn' 487 488 print_res : logical 489 490 Whether to print the L2 error. Default True 491 492 p : int 493 494 Output dimension. Default 1 495 496 Returns 497 ------- 498 pandas data frame with L2 and MSE errors 499 """ 500 501 #Dimension 502 d = test_data['xt'].shape[1] - 1 503 504 #Number of plots multiple of 5 505 times = 5 * round(times/5.0) 506 507 #Data 508 td = get_train_data(train_data) 509 xt_train = td['x'] 510 u_train = td['y'] 511 upred_train = fit(xt_train) 512 upred_test = fit(test_data['xt']) 513 514 #Results 515 l2_error_test = L2error(upred_test,test_data['u']).tolist() 516 MSE_test = jnp.mean(MSE(upred_test,test_data['u'])).tolist() 517 l2_error_train = L2error(upred_train,u_train).tolist() 518 MSE_train = jnp.mean(MSE(upred_train,u_train)).tolist() 519 520 df = pd.DataFrame(np.array([l2_error_test,MSE_test,l2_error_train,MSE_train]).reshape((1,4)), 521 columns=['l2_error_test','MSE_test','l2_error_train','MSE_train']) 522 if print_res: 523 print('L2 error test: ' + str(jnp.round(l2_error_test,6)) + ' L2 error train: ' + str(jnp.round(l2_error_train,6)) + ' MSE error test: ' + str(jnp.round(MSE_test,6)) + ' MSE error train: ' + str(jnp.round(MSE_train,6)) ) 524 525 #Plots 526 if d == 1 and p ==1 and plot: 527 plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name) 528 elif p == 2 and plot: 529 plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test) 530 531 return df 532 533#Plot results for d = 1 534def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''): 535 """ 536 Plot the prediction of a 1D PINN 537 ---------- 538 539 Parameters 540 ---------- 541 times : int 542 543 Number of points along the time interval to plot. Default 5 544 545 xt : jax.numpy.array 546 547 Test data xt array 548 549 u : jax.numpy.array 550 551 Test data u(x,t) array 552 553 upred : jax.numpy.array 554 555 Predicted upred(x,t) array on test data 556 557 d2 : logical 558 559 Whether to plot 2D plot. Default True 560 561 save : logical 562 563 Whether to save the plots. Default False 564 565 show : logical 566 567 Whether to show the plots. Default True 568 569 file_name : str 570 571 File prefix to save the plots. 
Default 'result_pinn' 572 573 title_1d : str 574 575 Title of 1D plot 576 577 title_2d : str 578 579 Title of 2D plot 580 581 Returns 582 ------- 583 None 584 """ 585 #Initialize 586 fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5))) 587 tlo = jnp.min(xt[:,-1]) 588 tup = jnp.max(xt[:,-1]) 589 ylo = jnp.min(u) 590 ylo = ylo - 0.1*jnp.abs(ylo) 591 yup = jnp.max(u) 592 yup = yup + 0.1*jnp.abs(yup) 593 k = 0 594 t_values = np.linspace(tlo,tup,times) 595 596 #Create 597 for i in range(int(times/5)): 598 for j in range(5): 599 if k < len(t_values): 600 t = t_values[k] 601 t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist() 602 x_plot = xt[xt[:,-1] == t,:-1] 603 y_plot = upred[xt[:,-1] == t,:] 604 u_plot = u[xt[:,-1] == t,:] 605 if int(times/5) > 1: 606 ax[i,j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact') 607 ax[i,j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction') 608 ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10) 609 ax[i,j].set_xlabel(' ') 610 ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 611 else: 612 ax[j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact') 613 ax[j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction') 614 ax[j].set_title('$t = %.2f$' % (t),fontsize=10) 615 ax[j].set_xlabel(' ') 616 ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 617 k = k + 1 618 619 #Title 620 fig.suptitle(title_1d) 621 fig.tight_layout() 622 623 #Show and save 624 fig = plt.gcf() 625 if show: 626 plt.show() 627 if save: 628 fig.savefig(file_name + '_slices.png') 629 plt.close() 630 631 #2d plot 632 if d2: 633 #Initialize 634 fig, ax = plt.subplots(1,2) 635 l1 = jnp.unique(xt[:,-1]).shape[0] 636 l2 = jnp.unique(xt[:,0]).shape[0] 637 638 #Create 639 ax[0].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist()) 640 ax[0].set_title('Exact') 641 ax[1].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist()) 642 ax[1].set_title('Predicted') 643 644 #Title 645 fig.suptitle(title_2d) 646 fig.tight_layout() 647 648 #Show and save 649 fig = plt.gcf() 650 if show: 651 plt.show() 652 if save: 653 fig.savefig(file_name + '_2d.png') 654 plt.close() 655 656#Plot results for d = 1 657def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True): 658 """ 659 Plot the prediction of a PINN with 2D output 660 ---------- 661 Parameters 662 ---------- 663 times : int 664 665 Number of points along the time interval to plot. Default 5 666 667 xt : jax.numpy.array 668 669 Test data xt array 670 671 u : jax.numpy.array 672 673 Test data u(x,t) array 674 675 upred : jax.numpy.array 676 677 Predicted upred(x,t) array on test data 678 679 save : logical 680 681 Whether to save the plots. Default False 682 683 show : logical 684 685 Whether to show the plots. Default True 686 687 file_name : str 688 689 File prefix to save the plots. Default 'result_pinn' 690 691 title : str 692 693 Title of plot 694 695 plot_test : logical 696 697 Whether to plot the test data. 
Default True 698 699 Returns 700 ------- 701 None 702 """ 703 #Initialize 704 fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5))) 705 tlo = jnp.min(xt[:,-1]) 706 tup = jnp.max(xt[:,-1]) 707 xlo = jnp.min(u[:,0]) 708 xlo = xlo - 0.1*jnp.abs(xlo) 709 xup = jnp.max(u[:,0]) 710 xup = xup + 0.1*jnp.abs(xup) 711 ylo = jnp.min(u[:,1]) 712 ylo = ylo - 0.1*jnp.abs(ylo) 713 yup = jnp.max(u[:,1]) 714 yup = yup + 0.1*jnp.abs(yup) 715 k = 0 716 t_values = np.linspace(tlo,tup,times) 717 718 #Create 719 for i in range(int(times/5)): 720 for j in range(5): 721 if k < len(t_values): 722 t = t_values[k] 723 t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist() 724 xpred_plot = upred[xt[:,-1] == t,0] 725 ypred_plot = upred[xt[:,-1] == t,1] 726 if plot_test: 727 x_plot = u[xt[:,-1] == t,0] 728 y_plot = u[xt[:,-1] == t,1] 729 if int(times/5) > 1: 730 if plot_test: 731 ax[i,j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact') 732 ax[i,j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction') 733 ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10) 734 ax[i,j].set_xlabel(' ') 735 ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 736 else: 737 if plot_test: 738 ax[j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact') 739 ax[j].plot(xpred_plot,ypred,'r-',linewidth=2,label='Prediction') 740 ax[j].set_title('$t = %.2f$' % (t),fontsize=10) 741 ax[j].set_xlabel(' ') 742 ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 743 k = k + 1 744 745 #Title 746 fig.suptitle(title) 747 fig.tight_layout() 748 749 #Show and save 750 fig = plt.gcf() 751 if show: 752 plt.show() 753 if save: 754 fig.savefig(file_name + '_slices.png') 755 plt.close() 756 757#Get train data in one array 758def get_train_data(train_data): 759 """ 760 Process training sample 761 ---------- 762 763 Parameters 764 ---------- 765 train_data : dict 766 767 A dictionay with train data generated by the jinnax.data.generate_PINNdata function 768 769 Returns 770 ------- 771 dict with the processed training data 772 """ 773 xdata = None 774 ydata = None 775 xydata = None 776 if train_data['sensor'] is not None: 777 sensor_sample = train_data['sensor'].shape[0] 778 xdata = train_data['sensor'] 779 ydata = train_data['usensor'] 780 xydata = jnp.column_stack((train_data['sensor'],train_data['usensor'])) 781 else: 782 sensor_sample = 0 783 if train_data['boundary'] is not None: 784 boundary_sample = train_data['boundary'].shape[0] 785 if xdata is not None: 786 xdata = jnp.vstack((xdata,train_data['boundary'])) 787 ydata = jnp.vstack((ydata,train_data['uboundary'])) 788 xydata = jnp.vstack((xydata,jnp.column_stack((train_data['boundary'],train_data['uboundary'])))) 789 else: 790 xdata = train_data['boundary'] 791 ydata = train_data['uboundary'] 792 xydata = jnp.column_stack((train_data['boundary'],train_data['uboundary'])) 793 else: 794 boundary_sample = 0 795 if train_data['initial'] is not None: 796 initial_sample = train_data['initial'].shape[0] 797 if xdata is not None: 798 xdata = jnp.vstack((xdata,train_data['initial'])) 799 ydata = jnp.vstack((ydata,train_data['uinitial'])) 800 xydata = jnp.vstack((xydata,jnp.column_stack((train_data['initial'],train_data['uinitial'])))) 801 else: 802 xdata = train_data['initial'] 803 ydata = train_data['uinitial'] 804 xydata = jnp.column_stack((train_data['initial'],train_data['uinitial'])) 805 else: 806 initial_sample = 0 807 if train_data['collocation'] is not None: 808 collocation_sample = train_data['collocation'].shape[0] 809 else: 810 
collocation_sample = 0 811 812 return {'xy': xydata,'x': xdata,'y': ydata,'sensor_sample': sensor_sample,'boundary_sample': boundary_sample,'initial_sample': initial_sample,'collocation_sample': collocation_sample} 813 814#Process training 815def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1): 816 """ 817 Process the training of a Physics-informed Neural Network 818 ---------- 819 820 Parameters 821 ---------- 822 test_data : dict 823 824 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 825 826 file_name : str 827 828 Name of the files saved during training 829 830 at_each : int 831 832 Compute results for epochs multiple of at_each. Default 100 833 834 bolstering : logical 835 836 Whether to compute bolstering mean square error. Default True 837 838 mc_sample : int 839 840 Number of sample for Monte Carlo integration in bolstering. Default 10000 841 842 save : logical 843 844 Whether to save the training results. Default False 845 846 file_name_save : str 847 848 File prefix to save the plots and the L2 error. Default 'result_pinn' 849 850 key : int 851 852 Key for random samples in bolstering. Default 0 853 854 ec : float 855 856 Stopping criteria error for EM algorithm in bolstering. Default 1e-6 857 858 lamb : float 859 860 Hyperparameter of EM algorithm in bolstering. Default 1 861 862 Returns 863 ------- 864 pandas data frame with training results 865 """ 866 #Config 867 config = pickle.load(open(file_name + '_config.pickle', 'rb')) 868 epochs = config['epochs'] 869 train_data = config['train_data'] 870 forward = config['forward'] 871 872 #Get train data 873 td = get_train_data(train_data) 874 xydata = td['xy'] 875 xdata = td['x'] 876 ydata = td['y'] 877 sensor_sample = td['sensor_sample'] 878 boundary_sample = td['boundary_sample'] 879 initial_sample = td['initial_sample'] 880 collocation_sample = td['collocation_sample'] 881 882 #Generate keys 883 if bolstering: 884 keys = jax.random.split(jax.random.PRNGKey(key),epochs) 885 886 #Initialize loss 887 train_mse = [] 888 test_mse = [] 889 train_L2 = [] 890 test_L2 = [] 891 bolstX = [] 892 bolstXY = [] 893 loss = [] 894 time = [] 895 ep = [] 896 897 #Process training 898 with alive_bar(epochs) as bar: 899 for e in range(epochs): 900 if (e % at_each == 0 and at_each != epochs) or e == epochs - 1: 901 ep = ep + [e] 902 903 #Read parameters 904 params = pickle.load(open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb')) 905 906 #Time 907 time = time + [params['time']] 908 909 #Define learned function 910 def psi(x): 911 return forward(x,params['params']['net']) 912 913 #Train MSE and L2 914 if xdata is not None: 915 train_mse = train_mse + [jnp.mean(MSE(psi(xdata),ydata)).tolist()] 916 train_L2 = train_L2 + [L2error(psi(xdata),ydata).tolist()] 917 else: 918 train_mse = train_mse + [None] 919 train_L2 = train_L2 + [None] 920 921 #Test MSE and L2 922 test_mse = test_mse + [jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist()] 923 test_L2 = test_L2 + [L2error(psi(test_data['xt']),test_data['u']).tolist()] 924 925 #Bolstering 926 if bolstering: 927 bX = [] 928 bXY = [] 929 for method in ['chi','mm','mpe']: 930 kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi) 931 kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi) 932 bX = bX + 
[gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()] 933 bXY = bXY + [gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist()] 934 for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]: 935 kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias) 936 bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()] 937 bolstX = bolstX + [bX] 938 bolstXY = bolstXY + [bXY] 939 else: 940 bolstX = bolstX + [None] 941 bolstXY = bolstXY + [None] 942 943 #Loss 944 loss = loss + [params['loss'].tolist()] 945 946 #Delete 947 del params, psi 948 #Update alive_bar 949 bar() 950 951 #Bolstering results 952 if bolstering: 953 bolstX = jnp.array(bolstX) 954 bolstXY = jnp.array(bolstXY) 955 956 #Create data frame 957 if bolstering: 958 df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss, 959 train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]), 960 columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4']) 961 else: 962 df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss, 963 train_mse,test_mse,train_L2,test_L2]), 964 columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2']) 965 if save: 966 df.to_csv(file_name_save + '.csv',index = False) 967 968 return df 969 970#Demo video for training1D PINN 971def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10): 972 """ 973 Demo video with the training of a 1D PINN 974 ---------- 975 976 Parameters 977 ---------- 978 test_data : dict 979 980 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 981 982 file_name : str 983 984 Name of the files saved during training 985 986 at_each : int 987 988 Compute results for epochs multiple of at_each. Default 100 989 990 times : int 991 992 Number of points along the time interval to plot. Default 5 993 994 d2 : logical 995 996 Whether to make video demo of 2D plot. Default True 997 998 file_name_save : str 999 1000 File prefix to save the plots and videos. Default 'result_pinn_demo' 1001 1002 title : str 1003 1004 Title for plots 1005 1006 framerate : int 1007 1008 Framerate for video. 
Default 10 1009 1010 Returns 1011 ------- 1012 None 1013 """ 1014 #Config 1015 with open(file_name + '_config.pickle', 'rb') as file: 1016 config = pickle.load(file) 1017 epochs = config['epochs'] 1018 train_data = config['train_data'] 1019 forward = config['forward'] 1020 1021 #Get train data 1022 td = get_train_data(train_data) 1023 xt = td['x'] 1024 u = td['y'] 1025 1026 #Create folder to save plots 1027 os.system('mkdir ' + file_name_save) 1028 1029 #Create images 1030 k = 1 1031 with alive_bar(epochs) as bar: 1032 for e in range(epochs): 1033 if e % at_each == 0 or e == epochs - 1: 1034 #Read parameters 1035 params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle') 1036 1037 #Define learned function 1038 def psi(x): 1039 return forward(x,params['params']['net']) 1040 1041 #Compute L2 train, L2 test and loss 1042 loss = params['loss'] 1043 L2_train = L2error(psi(xt),u) 1044 L2_test = L2error(psi(test_data['xt']),test_data['u']) 1045 title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(L2_train,6)) + ' L2 test = ' + str(round(L2_test,6)) 1046 1047 #Save plot 1048 plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch) 1049 k = k + 1 1050 1051 #Delete 1052 del params, psi, loss, L2_train, L2_test, title_epoch 1053 #Update alive_bar 1054 bar() 1055 #Create demo video 1056 os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4') 1057 if d2: 1058 os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4') 1059 1060#Demo in time for 1D PINN 1061def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10): 1062 """ 1063 Demo video with the time evolution of a 1D PINN 1064 ---------- 1065 1066 Parameters 1067 ---------- 1068 test_data : dict 1069 1070 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 1071 1072 file_name : str 1073 1074 Name of the files saved during training 1075 1076 epochs : list 1077 1078 Which training epochs to plot 1079 1080 file_name_save : str 1081 1082 File prefix to save the plots and video. Default 'result_pinn_time_demo' 1083 1084 title : str 1085 1086 Title for plots 1087 1088 framerate : int 1089 1090 Framerate for video. 
Default 10 1091 1092 Returns 1093 ------- 1094 None 1095 """ 1096 #Config 1097 with open(file_name + '_config.pickle', 'rb') as file: 1098 config = pickle.load(file) 1099 train_data = config['train_data'] 1100 forward = config['forward'] 1101 1102 #Create folder to save plots 1103 os.system('mkdir ' + file_name_save) 1104 1105 #Plot parameters 1106 tdom = jnp.unique(test_data['xt'][:,-1]) 1107 ylo = jnp.min(test_data['u']) 1108 ylo = ylo - 0.1*jnp.abs(ylo) 1109 yup = jnp.max(test_data['u']) 1110 yup = yup + 0.1*jnp.abs(yup) 1111 1112 #Open PINN for each epoch 1113 results = [] 1114 upred = [] 1115 for e in epochs: 1116 tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle') 1117 results = results + [tmp] 1118 upred = upred + [forward(test_data['xt'],tmp['params']['net'])] 1119 1120 #Create images 1121 k = 1 1122 with alive_bar(len(tdom)) as bar: 1123 for t in tdom: 1124 #Test data 1125 xt_step = test_data['xt'][test_data['xt'][:,-1] == t] 1126 u_step = test_data['u'][test_data['xt'][:,-1] == t] 1127 #Initialize plot 1128 if len(epochs) > 1: 1129 fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2)) 1130 else: 1131 fig, ax = plt.subplots(1,1,figsize = (10,5)) 1132 #Create 1133 index = 0 1134 if int(len(epochs)/2) > 1: 1135 for i in range(int(len(epochs)/2)): 1136 for j in range(min(2,len(epochs))): 1137 upred_step = upred[index][test_data['xt'][:,-1] == t] 1138 ax[i,j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact') 1139 ax[i,j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction') 1140 ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10) 1141 ax[i,j].set_xlabel(' ') 1142 ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1143 index = index + 1 1144 elif len(epochs) > 1: 1145 for j in range(2): 1146 upred_step = upred[index][test_data['xt'][:,-1] == t] 1147 ax[j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact') 1148 ax[j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction') 1149 ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10) 1150 ax[j].set_xlabel(' ') 1151 ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1152 index = index + 1 1153 else: 1154 upred_step = upred[index][test_data['xt'][:,-1] == t] 1155 ax.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact') 1156 ax.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction') 1157 ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10) 1158 ax.set_xlabel(' ') 1159 ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1160 index = index + 1 1161 1162 1163 #Title 1164 fig.suptitle(title + 't = ' + str(round(t,4))) 1165 fig.tight_layout() 1166 1167 #Show and save 1168 fig = plt.gcf() 1169 fig.savefig(file_name_save + '/' + str(k) + '.png') 1170 k = k + 1 1171 plt.close() 1172 bar() 1173 1174 #Create demo video 1175 os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4') 1176 1177#Demo in time for 1D PINN 1178def demo_time_pinn2D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10,ffmpeg = 'ffmpeg'): 1179 """ 1180 Demo video with the time evolution of a 2D PINN 1181 ---------- 1182 Parameters 1183 ---------- 1184 test_data : dict 1185 1186 A dictionay with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function 1187 1188 file_name : str 1189 1190 Name of the 
files saved during training 1191 1192 epochs : list 1193 1194 Which training epochs to plot 1195 1196 file_name_save : str 1197 1198 File prefix to save the plots and video. Default 'result_pinn_time_demo' 1199 1200 title : str 1201 1202 Title for plots 1203 1204 framerate : int 1205 1206 Framerate for video. Default 10 1207 1208 ffmpeg : str 1209 1210 Path to ffmpeg 1211 1212 Returns 1213 ------- 1214 None 1215 """ 1216 #Config 1217 with open(file_name + '_config.pickle', 'rb') as file: 1218 config = pickle.load(file) 1219 train_data = config['train_data'] 1220 forward = config['forward'] 1221 1222 #Create folder to save plots 1223 os.system('mkdir ' + file_name_save) 1224 1225 #Plot parameters 1226 tdom = jnp.unique(test_data['xt'][:,-1]) 1227 ylo = jnp.min(test_data['u']) 1228 ylo = ylo - 0.1*jnp.abs(ylo) 1229 yup = jnp.max(test_data['u']) 1230 yup = yup + 0.1*jnp.abs(yup) 1231 1232 #Open PINN for each epoch 1233 results = [] 1234 upred = [] 1235 for e in epochs: 1236 tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle') 1237 results = results + [tmp] 1238 upred = upred + [forward(test_data['xt'],tmp['params']['net'])] 1239 1240 #Create images 1241 k = 1 1242 with alive_bar(len(tdom)) as bar: 1243 for t in tdom: 1244 #Test data 1245 xt_step = test_data['xt'][test_data['xt'][:,-1] == t] 1246 ux_step = test_data['u'][test_data['xt'][:,-1] == t,0] 1247 uy_step = test_data['u'][test_data['xt'][:,-1] == t,1] 1248 #Initialize plot 1249 if len(epochs) > 1: 1250 fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2)) 1251 else: 1252 fig, ax = plt.subplots(1,1,figsize = (10,5)) 1253 #Create 1254 index = 0 1255 if int(len(epochs)/2) > 1: 1256 for i in range(int(len(epochs)/2)): 1257 for j in range(min(2,len(epochs))): 1258 upredx_step = upred[index][test_data['xt'][:,-1] == t,0] 1259 upredy_step = upred[index][test_data['xt'][:,-1] == t,1] 1260 ax[i,j].plot(ux_step,uy_step,'b-',linewidth=2,label='Exact') 1261 ax[i,j].plot(upredx_step,upredy_step,'r-',linewidth=2,label='Prediction') 1262 ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10) 1263 ax[i,j].set_xlabel(' ') 1264 ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1265 index = index + 1 1266 elif len(epochs) > 1: 1267 for j in range(2): 1268 upredx_step = upred[index][test_data['xt'][:,-1] == t,0] 1269 upredy_step = upred[index][test_data['xt'][:,-1] == t,1] 1270 ax[j].plot(ux_step,uy_step,'b-',linewidth=2,label='Exact') 1271 ax[j].plot(upredx_step,upredy_step,'r-',linewidth=2,label='Prediction') 1272 ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10) 1273 ax[j].set_xlabel(' ') 1274 ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1275 index = index + 1 1276 else: 1277 upredx_step = upred[index][test_data['xt'][:,-1] == t,0] 1278 upredy_step = upred[index][test_data['xt'][:,-1] == t,1] 1279 ax.plot(ux_step,uy_step,'b-',linewidth=2,label='Exact') 1280 ax.plot(upredx_step,upredy_step,'r-',linewidth=2,label='Prediction') 1281 ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10) 1282 ax.set_xlabel(' ') 1283 ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()]) 1284 index = index + 1 1285 1286 1287 #Title 1288 fig.suptitle(title + 't = ' + str(round(t,4))) 1289 fig.tight_layout() 1290 1291 #Show and save 1292 fig = plt.gcf() 1293 fig.savefig(file_name_save + '/' + str(k) + '.png') 1294 k = k + 1 1295 plt.close() 1296 bar() 1297 1298 #Create demo video 1299 os.system(ffmpeg + ' -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 
-profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4') 1300 1301def DN_CSF_circle(uinitial,xl,xu,tl,tu,width,radius,Ntb = 100,N0 = 100,Nc = 50,Ntc = 50,Ns = 100,Nts = 100,epochs = 100,at_each = 10,activation = 'tanh',sa = True,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,demo = True,framerate = 2,ffmpeg = 'ffmpeg',c = 1e-6): 1302 #If demo, then save 1303 if demo: 1304 save = True 1305 1306 #Define initial function and function to evaluate at boundary 1307 def uinit(x,t): 1308 u = uinitial(x,t) 1309 return jnp.append(u['u1'],u['u2']) 1310 1311 def ubound(x,t): 1312 u = uinitial(x,t) 1313 return jnp.append(u['u1'],u['u2'],1) 1314 1315 1316 #PDE operator 1317 def pde(u,x,t): 1318 #One function for each coordinate (assuming that x and t has dimension 1 x 1 and u(x,t) has dimension 1 x 2) 1319 u1 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,0][0] 1320 u2 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,1][0] 1321 #First derivatives of each coordinate 1322 ux1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t)) 1323 ux2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t)) 1324 ut1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),1)(x,t)) 1325 ut2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),1)(x,t)) 1326 #Second derivative of each coordinate 1327 ux1_tmp = lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t) 1328 ux2_tmp = lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t) 1329 uxx1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : ux1_tmp(x,t)[0],0)(x,t)) 1330 uxx2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : ux2_tmp(x,t)[0],0)(x,t)) 1331 #Return 1332 return jnp.sqrt((ut1(x,t) - uxx1(x,t)/(ux1(x,t) ** 2 + ux2(x,t) ** 2 + c)) ** 2 + (ut2(x,t) - uxx2(x,t)/(ux1(x,t) ** 2 + ux2(x,t) ** 2 + c)) ** 2) 1333 1334 #Operator to evaluate boundary conditions 1335 def oper_boundary(u,x,t,w = 1,Ntb = Ntb): 1336 #Enforce Dirichlet at the right boundary (fixed at point a, as the initial condition) 1337 res_right_dir = jnp.sum(jnp.where(x == xu,(u(x,t) - ubound(x,t)) ** 2,0),1).reshape(x.shape[0],1) 1338 #Enforce Dirichlet at the left boundary (is in the circle of radius fixed) 1339 res_left_dir = jnp.sum(jnp.where(x == xl,((jnp.sum(u(x,t) ** 2,1) - radius ** 2) ** 2).reshape(x.shape),0),1).reshape(x.shape[0],1) 1340 #One function for each coordinate (assuming that x and t has dimension 1 x 1 and u(x,t) has dimension 1 x 2) 1341 u1 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,0][0] 1342 u2 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,1][0] 1343 #Take the derivatives in x 1344 ux1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t))(x,t) 1345 ux2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t))(x,t) 1346 #Enforce Neumann at the left boundary 1347 nS = u(x,t)/jnp.sqrt(jnp.sum(u(x,t) ** 2,0)) #Assuming that u(x,y) \in S, compute the vector normal to S at u(x,t) 1348 nu = jnp.append(ux2,(-1)*ux1,1)/jnp.sqrt(ux1 ** 2 + ux2 ** 2) 1349 ip = jnp.sum(nS * nu,1).reshape(x.shape[0],1) ** 2 1350 res_left_neu = jnp.where(x == xl,ip,0) 1351 #Rearrange 1352 res = jnp.append(jnp.append(res_right_dir[:Ntb,:],res_left_dir[Ntb:2*Ntb,:],0),res_left_neu[2*Ntb:,:],0) 1353 return w*res 1354 1355 #Generate Data 1356 train_data = jd.generate_PINNdata(u = uinit,xl = xl,xu = xu,tl = tl,tu = tu,Ns = 
None,Nts = None,Nb = 2,Ntb = Ntb,N0 = N0,Nc = Nc,Ntc = Ntc,p = 2,poss = 'random',posts = 'random',pos0 = 'random',postb = 'random',posc = 'random',postc = 'random') 1357 1358 #Rearange boundary data 1359 train_data['boundary'] = jnp.append(jnp.append(train_data['boundary'][Ntb:,:],train_data['boundary'][:Ntb,:],0),train_data['boundary'][:Ntb,:],0) 1360 train_data['uboundary'] = jnp.append(jnp.append(train_data['uboundary'][Ntb:,:],train_data['uboundary'][:Ntb,:],0),train_data['uboundary'][:Ntb,:],0) 1361 1362 #Train PINN 1363 fit = train_PINN(train_data,width,pde,c = {'ws': 1,'wr': 1,'w0': 1,'wb': 1},test_data = None,epochs = epochs,at_each = at_each,activation = activation,neumann = True,oper_neumann = oper_boundary,sa = sa,lr = lr,b1 = b1,b2 = b2,eps = eps,eps_root = eps_root,key = key,epoch_print = epoch_print,save = save,file_name = file_name,exp_decay = exp_decay,transition_steps = transition_steps,decay_rate = decay_rate) 1364 1365 #Test data 1366 test_data = jd.generate_PINNdata(u = uinit,xl = xl,xu = xu,tl = tl,tu = tu,Ns = None,Nts = None,Nb = 2,Ntb = 2*Ntb,N0 = 2*N0,Nc = 2*Nc,Ntc = 2*Ntc,p = 2,poss = 'random',posts = 'random',pos0 = 'random',postb = 'random',posc = 'random',postc = 'random') 1367 Ntb = 2*Ntb 1368 test_data['boundary'] = jnp.append(jnp.append(test_data['boundary'][Ntb:,:],test_data['boundary'][:Ntb,:],0),test_data['boundary'][:Ntb,:],0) 1369 test_data['uboundary'] = jnp.append(jnp.append(test_data['uboundary'][Ntb:,:],test_data['uboundary'][:Ntb,:],0),test_data['uboundary'][:Ntb,:],0) 1370 1371 #Evaluate residuals 1372 def u(x,t): 1373 return fit['u'](jnp.append(x,t,1)) 1374 1375 res_pde = jnp.mean(pde(u,test_data['collocation'][:,0].reshape((test_data['collocation'].shape[0],1)),test_data['collocation'][:,1].reshape((test_data['collocation'].shape[0],1))) ** 2) 1376 res_DN = oper_boundary(u,test_data['boundary'][:,0].reshape((test_data['boundary'].shape[0],1)),test_data['boundary'][:,1].reshape((test_data['boundary'].shape[0],1)),Ntb = Ntb) 1377 res_dir_right = jnp.mean(res_DN[:Ntb,:] ** 2) 1378 res_neu = jnp.mean(res_DN[2*Ntb:,:] ** 2) 1379 res_dir_left = jnp.mean(res_DN[Ntb:2*Ntb,:] ** 2) 1380 res_initial = jnp.mean((u(test_data['initial'][:,0].reshape((test_data['initial'].shape[0],1)),test_data['initial'][:,1].reshape((test_data['initial'].shape[0],1))) - test_data['uinitial']) ** 2) 1381 1382 #Save file 1383 res_data = pd.DataFrame({'PDE': [res_pde.tolist()], 1384 'Dirichlet_Right': [res_dir_right.tolist()], 1385 'Dirichlet_Left': [res_dir_left.tolist()], 1386 'Neumann': [res_neu.tolist()], 1387 'initial': res_initial, 1388 'time': fit['time'], 1389 'epochs': epochs}) 1390 res_data.to_csv(file_name + '_residuals.csv') 1391 1392 if demo: 1393 def ucircle(x,t): 1394 y = 2*jnp.pi*(x - xl)/(xu - xl) 1395 return jnp.append(radius*jnp.sin(y),radius*jnp.cos(y),0) 1396 test_data = jd.generate_PINNdata(u = ucircle,xl = xl,xu = xu,tl = tl,tu = tu,Ns = Ns,Nts = Nts,Nb = 0,Ntb = 0,N0 = 0,Nc = 0,Ntc = 0,p = 2,train = False) 1397 demo_time_pinn2D(test_data,file_name,[epochs-1],file_name_save = file_name + '_demo',title = '',framerate = framerate,ffmpeg = ffmpeg) 1398 1399 return fit,res_data
@jax.jit
def
MSE(pred, true):
@jax.jit
def MSE(pred,true):
    """
    Mean square error

    Parameters
    ----------
    pred : jax.numpy.array
        A JAX numpy array with the predicted values
    true : jax.numpy.array
        A JAX numpy array with the true values

    Returns
    -------
    mean square error
    """
    return (true - pred) ** 2
Mean square error
Parameters
- pred (jax.numpy.array): A JAX numpy array with the predicted values
- true (jax.numpy.array): A JAX numpy array with the true values
Returns
- jax.numpy.array with the element-wise squared errors (the mean is taken by the caller)
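A minimal usage sketch (illustrative, not from the package; the jnn alias and toy arrays are made up, and the printed values are approximate):

    import jax.numpy as jnp
    from jinnax import nn as jnn

    pred = jnp.array([[0.9],[2.1],[3.0]])
    true = jnp.array([[1.0],[2.0],[3.0]])

    # Element-wise squared errors; take the mean explicitly for a scalar loss term
    print(jnn.MSE(pred,true))            # approximately [[0.01],[0.01],[0.]]
    print(jnp.mean(jnn.MSE(pred,true)))  # approximately 0.0067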
@jax.jit
def
MSE_SA(pred, true, w):
@jax.jit
def MSE_SA(pred,true,w):
    """
    Self-adaptive mean square error

    Parameters
    ----------
    pred : jax.numpy.array
        A JAX numpy array with the predicted values
    true : jax.numpy.array
        A JAX numpy array with the true values
    w : jax.numpy.array
        A JAX numpy array with the self-adaptive weights

    Returns
    -------
    element-wise weighted squared error
    """
    return (w * (true - pred)) ** 2
Self-adaptive mean square error
Parameters
- pred (jax.numpy.array): A JAX numpy array with the predicted values
- true (jax.numpy.array): A JAX numpy array with the true values
- w (jax.numpy.array): A JAX numpy array with the self-adaptive weights
Returns
- element-wise weighted squared error
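A minimal usage sketch (illustrative, not from the package; the jnn alias and toy arrays are made up): each residual is multiplied by its weight before squaring, so points with a larger weight contribute more to the loss.

    import jax.numpy as jnp
    from jinnax import nn as jnn

    pred = jnp.array([[0.9],[2.1]])
    true = jnp.array([[1.0],[2.0]])
    w = jnp.array([[1.0],[10.0]])  # larger weight on the second point

    print(jnn.MSE_SA(pred,true,w))            # approximately [[0.01],[1.0]]
    print(jnp.mean(jnn.MSE_SA(pred,true,w)))  # approximately 0.505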
@jax.jit
def
L2error(pred, true):
@jax.jit
def L2error(pred,true):
    """
    L2-error

    Parameters
    ----------
    pred : jax.numpy.array
        A JAX numpy array with the predicted values
    true : jax.numpy.array
        A JAX numpy array with the true values

    Returns
    -------
    L2-error
    """
    return jnp.sqrt(jnp.sum((true - pred)**2))/jnp.sqrt(jnp.sum(true ** 2))
L2-error
Parameters
- pred (jax.numpy.array): A JAX numpy array with the predicted values
- true (jax.numpy.array): A JAX numpy array with the true values
Returns
- relative L2-error: the L2 norm of (true - pred) divided by the L2 norm of true
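A minimal usage sketch (illustrative, not from the package; toy arrays, approximate output):

    import jax.numpy as jnp
    from jinnax import nn as jnn

    pred = jnp.array([[1.1],[1.9],[3.2]])
    true = jnp.array([[1.0],[2.0],[3.0]])

    # ||true - pred||_2 / ||true||_2
    print(jnn.L2error(pred,true))  # approximately 0.065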
def
fconNN(width, activation=jax.nn.tanh, key=0):
def fconNN(width,activation = jax.nn.tanh,key = 0):
    """
    Initialize fully connected neural network

    Parameters
    ----------
    width : list
        List with the layers width
    activation : jax.nn activation
        The activation function. Default jax.nn.tanh
    key : int
        Seed for parameters initialization. Default 0

    Returns
    -------
    dict with initial parameters and the function for the forward pass
    """
    #Initialize parameters with Glorot initialization
    initializer = jax.nn.initializers.glorot_normal()
    key = jax.random.split(jax.random.PRNGKey(key),len(width)-1) #Seed for initialization
    params = list()
    for key,lin,lout in zip(key,width[:-1],width[1:]):
        W = initializer(key,(lin,lout),jnp.float32)
        B = initializer(key,(1,lout),jnp.float32)
        params.append({'W':W,'B':B})

    #Define function for forward pass
    @jax.jit
    def forward(x,params):
        *hidden,output = params
        for layer in hidden:
            x = activation(x @ layer['W'] + layer['B'])
        return x @ output['W'] + output['B']

    #Return initial parameters and forward function
    return {'params': params,'forward': forward}
Initialize fully connected neural network
Parameters
- width (list): List with the width of each layer (input dimension first, output dimension last)
- activation (jax.nn activation): The activation function. Default jax.nn.tanh
- key (int): Seed for parameters initialization. Default 0
Returns
- dict with initial parameters and the function for the forward pass
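A minimal usage sketch (illustrative, not from the package; the layer sizes and the jnn alias are made up): build a small network and run the jitted forward pass on a batch of space-time points.

    import jax
    import jax.numpy as jnp
    from jinnax import nn as jnn

    # Two inputs (x and t), two hidden layers with 16 units each, one output
    net = jnn.fconNN([2,16,16,1], activation=jax.nn.tanh, key=0)
    params, forward = net['params'], net['forward']

    # Forward pass on a batch of 5 space-time points
    xt = jnp.ones((5,2))
    print(forward(xt, params).shape)  # (5, 1)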
def
get_activation(act):
def get_activation(act):
    """
    Return activation function from string

    Parameters
    ----------
    act : str
        Name of the activation function. Default 'tanh'

    Returns
    -------
    jax.nn activation function
    """
    if act == 'tanh':
        return jax.nn.tanh
    elif act == 'relu':
        return jax.nn.relu
    elif act == 'relu6':
        return jax.nn.relu6
    elif act == 'sigmoid':
        return jax.nn.sigmoid
    elif act == 'softplus':
        return jax.nn.softplus
    elif act == 'sparse_plus':
        return jax.nn.sparse_plus
    elif act == 'soft_sign':
        return jax.nn.soft_sign
    elif act == 'silu':
        return jax.nn.silu
    elif act == 'swish':
        return jax.nn.swish
    elif act == 'log_sigmoid':
        return jax.nn.log_sigmoid
    elif act == 'leaky_relu':
        return jax.nn.leaky_relu
    elif act == 'hard_sigmoid':
        return jax.nn.hard_sigmoid
    elif act == 'hard_silu':
        return jax.nn.hard_silu
    elif act == 'hard_swish':
        return jax.nn.hard_swish
    elif act == 'hard_tanh':
        return jax.nn.hard_tanh
    elif act == 'elu':
        return jax.nn.elu
    elif act == 'celu':
        return jax.nn.celu
    elif act == 'selu':
        return jax.nn.selu
    elif act == 'gelu':
        return jax.nn.gelu
    elif act == 'glu':
        return jax.nn.glu
    elif act == 'squareplus':
        return jax.nn.squareplus
    elif act == 'mish':
        return jax.nn.mish
Return activation function from string
Parameters
- act (str): Name of the activation function: one of 'tanh', 'relu', 'relu6', 'sigmoid', 'softplus', 'sparse_plus', 'soft_sign', 'silu', 'swish', 'log_sigmoid', 'leaky_relu', 'hard_sigmoid', 'hard_silu', 'hard_swish', 'hard_tanh', 'elu', 'celu', 'selu', 'gelu', 'glu', 'squareplus' or 'mish'
Returns
- jax.nn activation function
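A minimal usage sketch (illustrative, not from the package; the jnn alias and toy array are made up):

    import jax
    import jax.numpy as jnp
    from jinnax import nn as jnn

    act = jnn.get_activation('relu')
    print(act is jax.nn.relu)           # True
    print(act(jnp.array([-1.0, 2.0])))  # [0. 2.]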
def
train_PINN( data, width, pde, test_data=None, epochs=100, at_each=10, activation='tanh', neumann=False, oper_neumann=False, sa=False, c={'ws': 1, 'wr': 1, 'w0': 100, 'wb': 1}, inverse=False, initial_par=None, lr=0.001, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, key=0, epoch_print=100, save=False, file_name='result_pinn', exp_decay=False, transition_steps=1000, decay_rate=0.9):
def train_PINN(data,width,pde,test_data = None,epochs = 100,at_each = 10,activation = 'tanh',neumann = False,oper_neumann = False,sa = False,c = {'ws': 1,'wr': 1,'w0': 100,'wb': 1},inverse = False,initial_par = None,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9):
    """
    Train a Physics-informed Neural Network

    Parameters
    ----------
    data : dict
        Data generated by the jinnax.data.generate_PINNdata function
    width : list
        A list with the width of each layer
    pde : function
        The partial differential operator. Its arguments are u, x and t
    test_data : dict, None
        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function. Default None for not calculating the L2 error
    epochs : int
        Number of training epochs. Default 100
    at_each : int
        Save results for epochs multiple of at_each. Default 10
    activation : str
        The name of the activation function of the neural network. Default 'tanh'
    neumann : logical
        Whether to consider Neumann boundary conditions
    oper_neumann : function
        Penalization of Neumann boundary conditions
    sa : logical
        Whether to consider a self-adaptive PINN
    c : dict
        Dictionary with the hyperparameters of the self-adaptive sigmoid mask for the initial (w0), sensor (ws) and collocation (wr) points. The weight of the boundary points is fixed to 1
    inverse : logical
        Whether to estimate parameters of the PDE
    initial_par : jax.numpy.array
        Initial value of the parameters of the PDE in an inverse problem
    lr,b1,b2,eps,eps_root : float
        Hyperparameters of the Adam algorithm. Default lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-08, eps_root = 0.0
    key : int
        Seed for parameters initialization. Default 0
    epoch_print : int
        Number of epochs to calculate and print test errors. Default 100
    save : logical
        Whether to save the current parameters. Default False
    file_name : str
        File prefix to save the current parameters. Default 'result_pinn'
    exp_decay : logical
        Whether to consider exponential decay of the learning rate. Default False
    transition_steps : int
        Number of steps for exponential decay. Default 1000
    decay_rate : float
        Rate of exponential decay. Default 0.9

    Returns
    -------
    dict with the estimated function, the estimated parameters, the neural network function for the forward pass and the training time
    """

    #Initialize architecture
    nnet = fconNN(width,get_activation(activation),key)
    forward = nnet['forward']

    #Initialize self-adaptive weights
    par_sa = {}
    if sa:
        #Initialize weights close to zero
        ksa = jax.random.randint(jax.random.PRNGKey(key),(5,),1,1000000)
        if data['sensor'] is not None:
            par_sa.update({'ws': c['ws'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[0]),shape = (data['sensor'].shape[0],1))})
        if data['initial'] is not None:
            par_sa.update({'w0': c['w0'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[1]),shape = (data['initial'].shape[0],1))})
        if data['collocation'] is not None:
            par_sa.update({'wr': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[2]),shape = (data['collocation'].shape[0],1))})
        if data['boundary'] is not None:
            par_sa.update({'wb': c['wr'] * jax.random.uniform(key = jax.random.PRNGKey(ksa[3]),shape = (data['boundary'].shape[0],1))})

    #Store all parameters
    params = {'net': nnet['params'],'inverse': initial_par,'sa': par_sa}

    #Save config file
    if save:
        pickle.dump({'train_data': data,'epochs': epochs,'activation': activation,'init_params': params,'forward': forward,'width': width,'pde': pde,'lr': lr,'b1': b1,'b2': b2,'eps': eps,'eps_root': eps_root,'key': key,'inverse': inverse,'sa': sa},open(file_name + '_config.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)

    #Define loss function
    if sa:
        @jax.jit
        def lf(params,x):
            loss = 0
            if x['sensor'] is not None:
                #Term that refers to sensor data
                loss = loss + jnp.mean(MSE_SA(forward(x['sensor'],params['net']),x['usensor'],params['sa']['ws']))
            if x['boundary'] is not None:
                if neumann:
                    #Neumann conditions
                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb,params['sa']['wb']))
                else:
                    #Term that refers to boundary data
                    loss = loss + jnp.mean(MSE_SA(forward(x['boundary'],params['net']),x['uboundary'],params['sa']['wb']))
            if x['initial'] is not None:
                #Term that refers to initial data
                loss = loss + jnp.mean(MSE_SA(forward(x['initial'],params['net']),x['uinitial'],params['sa']['w0']))
            if x['collocation'] is not None:
                #Term that refers to collocation points
                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
                if inverse:
                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0,params['sa']['wr']))
                else:
                    loss = loss + jnp.mean(MSE_SA(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0,params['sa']['wr']))
            return loss
    else:
        @jax.jit
        def lf(params,x):
            loss = 0
            if x['sensor'] is not None:
                #Term that refers to sensor data
                loss = loss + jnp.mean(MSE(forward(x['sensor'],params['net']),x['usensor']))
            if x['boundary'] is not None:
                if neumann:
                    #Neumann conditions
                    xb = x['boundary'][:,:-1].reshape((x['boundary'].shape[0],x['boundary'].shape[1] - 1))
                    tb = x['boundary'][:,-1].reshape((x['boundary'].shape[0],1))
                    loss = loss + jnp.mean(oper_neumann(lambda x,t: forward(jnp.append(x,t,1),params['net']),xb,tb))
                else:
                    #Term that refers to boundary data
                    loss = loss + jnp.mean(MSE(forward(x['boundary'],params['net']),x['uboundary']))
            if x['initial'] is not None:
                #Term that refers to initial data
                loss = loss + jnp.mean(MSE(forward(x['initial'],params['net']),x['uinitial']))
            if x['collocation'] is not None:
                #Term that refers to collocation points
                x_col = x['collocation'][:,:-1].reshape((x['collocation'].shape[0],x['collocation'].shape[1] - 1))
                t_col = x['collocation'][:,-1].reshape((x['collocation'].shape[0],1))
                if inverse:
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col,params['inverse']),0))
                else:
                    loss = loss + jnp.mean(MSE(pde(lambda x,t: forward(jnp.append(x,t,1),params['net']),x_col,t_col),0))
            return loss

    #Initialize Adam optimizer
    if exp_decay:
        lr = optax.exponential_decay(lr,transition_steps,decay_rate)
    optimizer = optax.adam(lr,b1,b2,eps,eps_root)
    opt_state = optimizer.init(params)

    #Define the gradient function
    grad_loss = jax.jit(jax.grad(lf,0))

    #Define update function
    @jax.jit
    def update(opt_state,params,x):
        #Compute gradient
        grads = grad_loss(params,x)
        #Invert gradient of self-adaptive weights
        if sa:
            for w in grads['sa']:
                grads['sa'][w] = - grads['sa'][w]
        #Calculate parameters updates
        updates, opt_state = optimizer.update(grads, opt_state)
        #Update parameters
        params = optax.apply_updates(params, updates)
        #Return state of optimizer and updated parameters
        return opt_state,params

    ###Training###
    t0 = time.time()
    #Initialize alive_bar for tracking progress in the terminal
    with alive_bar(epochs) as bar:
        #For each epoch
        for e in range(epochs):
            #Update optimizer state and parameters
            opt_state,params = update(opt_state,params,data)
            #After epoch_print epochs
            if e % epoch_print == 0:
                #Compute elapsed time and current loss
                l = 'Time: ' + str(round(time.time() - t0)) + ' s Loss: ' + str(jnp.round(lf(params,data),6))
                #If there is test data, compute current L2 error
                if test_data is not None:
                    #Compute L2 error
                    l2_test = L2error(forward(test_data['xt'],params['net']),test_data['u']).tolist()
                    l = l + ' L2 error: ' + str(jnp.round(l2_test,6))
                if inverse:
                    l = l + ' Parameter: ' + str(jnp.round(params['inverse'].tolist(),6))
                #Print
                print(l)
            if ((e % at_each == 0 and at_each != epochs) or e == epochs - 1) and save:
                #Save current parameters
                pickle.dump({'params': params,'width': width,'time': time.time() - t0,'loss': lf(params,data)},open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','wb'), protocol = pickle.HIGHEST_PROTOCOL)
            #Update alive_bar
            bar()
    #Define estimated function
    def u(xt):
        return forward(xt,params['net'])

    return {'u': u,'params': params,'forward': forward,'time': time.time() - t0}
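A minimal usage sketch (not part of the module): training a forward PINN for the 1D heat equation u_t = u_xx with zero Dirichlet boundary conditions and initial condition u(x,0) = sin(pi x). The data dictionary is assembled by hand with the keys read by the loss function above ('sensor', 'usensor', 'boundary', 'uboundary', 'initial', 'uinitial', 'collocation'); in practice it would come from jinnax.data.generate_PINNdata. The pde argument follows the same convention as the operator defined inside DN_CSF_circle at the end of this module: it receives the network as a function u(x,t) together with the collocation points and returns the pointwise residual. The import path jinnax.nn is assumed.

import jax
import jax.numpy as jnp
from jinnax.nn import train_PINN

#Residual of the heat equation u_t - u_xx for a network with input (x,t) and output of shape (N,1)
def pde(u,x,t):
    us = lambda x,t: u(x.reshape((1,1)),t.reshape((1,1)))[0,0]                            #scalar version of u
    ut = jax.vmap(lambda x,t: jax.grad(us,1)(x,t))                                        #du/dt
    uxx = jax.vmap(lambda x,t: jax.grad(lambda x,t: jax.grad(us,0)(x,t)[0],0)(x,t))       #d2u/dx2
    return ut(x,t) - uxx(x,t)

#Hand-made training data on [0,1] x [0,1]
x0 = jnp.linspace(0,1,101).reshape((101,1))                                               #initial points
tb = jnp.linspace(0,1,51).reshape((51,1))                                                 #boundary times
xb = jnp.vstack((jnp.zeros((51,1)),jnp.ones((51,1))))                                     #boundary locations x = 0 and x = 1
data = {'sensor': None,'usensor': None,
        'boundary': jnp.column_stack((xb,jnp.vstack((tb,tb)))),'uboundary': jnp.zeros((102,1)),
        'initial': jnp.column_stack((x0,jnp.zeros((101,1)))),'uinitial': jnp.sin(jnp.pi*x0),
        'collocation': jax.random.uniform(jax.random.PRNGKey(0),(1000,2))}

#Train a 2-32-32-1 network and evaluate the fitted solution at t = 0.5
fit = train_PINN(data,width = [2,32,32,1],pde = pde,epochs = 1000,epoch_print = 200)
upred = fit['u'](jnp.column_stack((x0,0.5*jnp.ones((101,1)))))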
def process_result(test_data,fit,train_data,plot = True,plot_test = True,times = 5,d2 = True,save = False,show = True,file_name = 'result_pinn',print_res = True,p = 1):
    """
    Process the results of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    fit : function

        The fitted function

    train_data : dict

        Training data generated by the jinnax.data.generate_PINNdata function

    plot : logical

        Whether to generate plots comparing the exact and estimated solutions when the spatial dimension is one. Default True

    plot_test : logical

        Whether to plot the test data. Default True

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to generate a 2D plot when the spatial dimension is one. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    print_res : logical

        Whether to print the L2 error. Default True

    p : int

        Output dimension. Default 1

    Returns
    -------
    pandas data frame with L2 and MSE errors
    """

    #Dimension
    d = test_data['xt'].shape[1] - 1

    #Number of plots multiple of 5
    times = 5 * round(times/5.0)

    #Data
    td = get_train_data(train_data)
    xt_train = td['x']
    u_train = td['y']
    upred_train = fit(xt_train)
    upred_test = fit(test_data['xt'])

    #Results
    l2_error_test = L2error(upred_test,test_data['u']).tolist()
    MSE_test = jnp.mean(MSE(upred_test,test_data['u'])).tolist()
    l2_error_train = L2error(upred_train,u_train).tolist()
    MSE_train = jnp.mean(MSE(upred_train,u_train)).tolist()

    df = pd.DataFrame(np.array([l2_error_test,MSE_test,l2_error_train,MSE_train]).reshape((1,4)),
        columns=['l2_error_test','MSE_test','l2_error_train','MSE_train'])
    if print_res:
        print('L2 error test: ' + str(jnp.round(l2_error_test,6)) + ' L2 error train: ' + str(jnp.round(l2_error_train,6)) + ' MSE error test: ' + str(jnp.round(MSE_test,6)) + ' MSE error train: ' + str(jnp.round(MSE_train,6)))

    #Plots
    if d == 1 and p == 1 and plot:
        plot_pinn1D(times,test_data['xt'],test_data['u'],upred_test,d2,save,show,file_name)
    elif p == 2 and plot:
        plot_pinn_out2D(times,test_data['xt'],test_data['u'],upred_test,save,show,file_name,plot_test)

    return df
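Continuing the sketch above (illustrative only): the fitted heat-equation PINN is evaluated on a regular space-time grid. The test dictionary only needs the 'xt' and 'u' keys used by process_result; in practice it would also come from jinnax.data.generate_PINNdata.

xg,tg = jnp.meshgrid(jnp.linspace(0,1,101),jnp.linspace(0,1,101))
xt_test = jnp.column_stack((xg.reshape((-1,1)),tg.reshape((-1,1))))
u_test = jnp.sin(jnp.pi*xt_test[:,0:1])*jnp.exp(-(jnp.pi ** 2)*xt_test[:,1:2])       #exact solution
test_data = {'xt': xt_test,'u': u_test}
df = process_result(test_data,fit['u'],data,times = 5,plot = True,save = False,show = True)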
def plot_pinn1D(times,xt,u,upred,d2 = True,save = False,show = True,file_name = 'result_pinn',title_1d = '',title_2d = ''):
    """
    Plot the prediction of a 1D PINN
    ----------

    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    d2 : logical

        Whether to generate a 2D plot. Default True

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title_1d : str

        Title of the 1D plot

    title_2d : str

        Title of the 2D plot

    Returns
    -------
    None
    """
    #Initialize
    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    ylo = jnp.min(u)
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u)
    yup = yup + 0.1*jnp.abs(yup)
    k = 0
    t_values = np.linspace(tlo,tup,times)

    #Create
    for i in range(int(times/5)):
        for j in range(5):
            if k < len(t_values):
                t = t_values[k]
                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
                x_plot = xt[xt[:,-1] == t,:-1]
                y_plot = upred[xt[:,-1] == t,:]
                u_plot = u[xt[:,-1] == t,:]
                if int(times/5) > 1:
                    ax[i,j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
                    ax[i,j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[i,j].set_xlabel(' ')
                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                else:
                    ax[j].plot(x_plot[:,0],u_plot[:,0],'b-',linewidth=2,label='Exact')
                    ax[j].plot(x_plot[:,0],y_plot,'r--',linewidth=2,label='Prediction')
                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
            k = k + 1

    #Title
    fig.suptitle(title_1d)
    fig.tight_layout()

    #Show and save
    fig = plt.gcf()
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close()

    #2d plot
    if d2:
        #Initialize
        fig, ax = plt.subplots(1,2)
        l1 = jnp.unique(xt[:,-1]).shape[0]
        l2 = jnp.unique(xt[:,0]).shape[0]

        #Create
        ax[0].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),u[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[0].set_title('Exact')
        ax[1].pcolormesh(xt[:,-1].reshape((l2,l1)),xt[:,0].reshape((l2,l1)),upred[:,0].reshape((l2,l1)),cmap = 'RdBu',vmin = ylo.tolist(),vmax = yup.tolist())
        ax[1].set_title('Predicted')

        #Title
        fig.suptitle(title_2d)
        fig.tight_layout()

        #Show and save
        fig = plt.gcf()
        if show:
            plt.show()
        if save:
            fig.savefig(file_name + '_2d.png')
        plt.close()
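A direct-call sketch: process_result already calls plot_pinn1D internally, but it can also be used on its own with the arrays from the previous example.

plot_pinn1D(5,test_data['xt'],test_data['u'],fit['u'](test_data['xt']),d2 = True,save = False,show = True)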
def plot_pinn_out2D(times,xt,u,upred,save = False,show = True,file_name = 'result_pinn',title = '',plot_test = True):
    """
    Plot the prediction of a PINN with 2D output
    ----------

    Parameters
    ----------
    times : int

        Number of points along the time interval to plot. Default 5

    xt : jax.numpy.array

        Test data xt array

    u : jax.numpy.array

        Test data u(x,t) array

    upred : jax.numpy.array

        Predicted upred(x,t) array on test data

    save : logical

        Whether to save the plots. Default False

    show : logical

        Whether to show the plots. Default True

    file_name : str

        File prefix to save the plots. Default 'result_pinn'

    title : str

        Title of the plot

    plot_test : logical

        Whether to plot the test data. Default True

    Returns
    -------
    None
    """
    #Initialize
    fig, ax = plt.subplots(int(times/5),5,figsize = (10*int(times/5),3*int(times/5)))
    tlo = jnp.min(xt[:,-1])
    tup = jnp.max(xt[:,-1])
    xlo = jnp.min(u[:,0])
    xlo = xlo - 0.1*jnp.abs(xlo)
    xup = jnp.max(u[:,0])
    xup = xup + 0.1*jnp.abs(xup)
    ylo = jnp.min(u[:,1])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(u[:,1])
    yup = yup + 0.1*jnp.abs(yup)
    k = 0
    t_values = np.linspace(tlo,tup,times)

    #Create
    for i in range(int(times/5)):
        for j in range(5):
            if k < len(t_values):
                t = t_values[k]
                t = xt[jnp.abs(xt[:,-1] - t) == jnp.min(jnp.abs(xt[:,-1] - t)),-1][0].tolist()
                xpred_plot = upred[xt[:,-1] == t,0]
                ypred_plot = upred[xt[:,-1] == t,1]
                if plot_test:
                    x_plot = u[xt[:,-1] == t,0]
                    y_plot = u[xt[:,-1] == t,1]
                if int(times/5) > 1:
                    if plot_test:
                        ax[i,j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
                    ax[i,j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
                    ax[i,j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[i,j].set_xlabel(' ')
                    ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                else:
                    if plot_test:
                        ax[j].plot(x_plot,y_plot,'b-',linewidth=2,label='Exact')
                    ax[j].plot(xpred_plot,ypred_plot,'r-',linewidth=2,label='Prediction')
                    ax[j].set_title('$t = %.2f$' % (t),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
            k = k + 1

    #Title
    fig.suptitle(title)
    fig.tight_layout()

    #Show and save
    fig = plt.gcf()
    if show:
        plt.show()
    if save:
        fig.savefig(file_name + '_slices.png')
    plt.close()
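A synthetic sketch for the 2D-output case (p = 2): each row of u and upred holds one (x,y) point of a planar curve, and one panel is drawn per time slice. Here a single time slice is used, with a unit circle as the exact curve and a shrunken copy standing in for the prediction.

theta = jnp.linspace(0,2*jnp.pi,100).reshape((100,1))
xt_demo = jnp.column_stack((theta,jnp.zeros((100,1))))          #one time slice, t = 0
u_demo = jnp.column_stack((jnp.cos(theta),jnp.sin(theta)))      #exact curve: unit circle
plot_pinn_out2D(5,xt_demo,u_demo,0.9*u_demo,show = True)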
def get_train_data(train_data):
    """
    Process training sample
    ----------

    Parameters
    ----------
    train_data : dict

        A dictionary with train data generated by the jinnax.data.generate_PINNdata function

    Returns
    -------
    dict with the processed training data
    """
    xdata = None
    ydata = None
    xydata = None
    if train_data['sensor'] is not None:
        sensor_sample = train_data['sensor'].shape[0]
        xdata = train_data['sensor']
        ydata = train_data['usensor']
        xydata = jnp.column_stack((train_data['sensor'],train_data['usensor']))
    else:
        sensor_sample = 0
    if train_data['boundary'] is not None:
        boundary_sample = train_data['boundary'].shape[0]
        if xdata is not None:
            xdata = jnp.vstack((xdata,train_data['boundary']))
            ydata = jnp.vstack((ydata,train_data['uboundary']))
            xydata = jnp.vstack((xydata,jnp.column_stack((train_data['boundary'],train_data['uboundary']))))
        else:
            xdata = train_data['boundary']
            ydata = train_data['uboundary']
            xydata = jnp.column_stack((train_data['boundary'],train_data['uboundary']))
    else:
        boundary_sample = 0
    if train_data['initial'] is not None:
        initial_sample = train_data['initial'].shape[0]
        if xdata is not None:
            xdata = jnp.vstack((xdata,train_data['initial']))
            ydata = jnp.vstack((ydata,train_data['uinitial']))
            xydata = jnp.vstack((xydata,jnp.column_stack((train_data['initial'],train_data['uinitial']))))
        else:
            xdata = train_data['initial']
            ydata = train_data['uinitial']
            xydata = jnp.column_stack((train_data['initial'],train_data['uinitial']))
    else:
        initial_sample = 0
    if train_data['collocation'] is not None:
        collocation_sample = train_data['collocation'].shape[0]
    else:
        collocation_sample = 0

    return {'xy': xydata,'x': xdata,'y': ydata,'sensor_sample': sensor_sample,'boundary_sample': boundary_sample,'initial_sample': initial_sample,'collocation_sample': collocation_sample}
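Applied to the hand-made dictionary of the first sketch, get_train_data stacks the supervised points (boundary and initial, since there is no sensor data) and counts the collocation points:

td = get_train_data(data)
print(td['x'].shape)                 #(203, 2): boundary and initial (x,t) coordinates
print(td['y'].shape)                 #(203, 1): corresponding u values
print(td['collocation_sample'])      #1000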
def process_training(test_data,file_name,at_each = 100,bolstering = True,mc_sample = 10000,save = False,file_name_save = 'result_pinn',key = 0,ec = 1e-6,lamb = 1):
    """
    Process the training of a Physics-informed Neural Network
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    bolstering : logical

        Whether to compute the bolstered mean square error. Default True

    mc_sample : int

        Number of samples for Monte Carlo integration in bolstering. Default 10000

    save : logical

        Whether to save the training results. Default False

    file_name_save : str

        File prefix to save the plots and the L2 error. Default 'result_pinn'

    key : int

        Key for random samples in bolstering. Default 0

    ec : float

        Stopping criterion error for the EM algorithm in bolstering. Default 1e-6

    lamb : float

        Hyperparameter of the EM algorithm in bolstering. Default 1

    Returns
    -------
    pandas data frame with training results
    """
    #Config
    config = pickle.load(open(file_name + '_config.pickle', 'rb'))
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xydata = td['xy']
    xdata = td['x']
    ydata = td['y']
    sensor_sample = td['sensor_sample']
    boundary_sample = td['boundary_sample']
    initial_sample = td['initial_sample']
    collocation_sample = td['collocation_sample']

    #Generate keys
    if bolstering:
        keys = jax.random.split(jax.random.PRNGKey(key),epochs)

    #Initialize loss
    train_mse = []
    test_mse = []
    train_L2 = []
    test_L2 = []
    bolstX = []
    bolstXY = []
    loss = []
    time = []
    ep = []

    #Process training
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if (e % at_each == 0 and at_each != epochs) or e == epochs - 1:
                ep = ep + [e]

                #Read parameters
                params = pickle.load(open(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle','rb'))

                #Time
                time = time + [params['time']]

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Train MSE and L2
                if xdata is not None:
                    train_mse = train_mse + [jnp.mean(MSE(psi(xdata),ydata)).tolist()]
                    train_L2 = train_L2 + [L2error(psi(xdata),ydata).tolist()]
                else:
                    train_mse = train_mse + [None]
                    train_L2 = train_L2 + [None]

                #Test MSE and L2
                test_mse = test_mse + [jnp.mean(MSE(psi(test_data['xt']),test_data['u'])).tolist()]
                test_L2 = test_L2 + [L2error(psi(test_data['xt']),test_data['u']).tolist()]

                #Bolstering
                if bolstering:
                    bX = []
                    bXY = []
                    for method in ['chi','mm','mpe']:
                        kxy = gk.kernel_estimator(data = xydata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        kx = gk.kernel_estimator(data = xdata,key = keys[e,0],method = method,lamb = lamb,ec = ec,psi = psi)
                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
                        bXY = bXY + [gb.bolstering(psi,xdata,ydata,kxy,key = keys[e,0],mc_sample = mc_sample).tolist()]
                    for bias in [1/jnp.sqrt(xdata.shape[0]),1/xdata.shape[0],1/(xdata.shape[0] ** 2),1/(xdata.shape[0] ** 3),1/(xdata.shape[0] ** 4)]:
                        kx = gk.kernel_estimator(data = xydata,key = keys[e,0],method = 'hessian',lamb = lamb,ec = ec,psi = psi,bias = bias)
                        bX = bX + [gb.bolstering(psi,xdata,ydata,kx,key = keys[e,0],mc_sample = mc_sample).tolist()]
                    bolstX = bolstX + [bX]
                    bolstXY = bolstXY + [bXY]
                else:
                    bolstX = bolstX + [None]
                    bolstXY = bolstXY + [None]

                #Loss
                loss = loss + [params['loss'].tolist()]

                #Delete
                del params, psi
            #Update alive_bar
            bar()

    #Bolstering results
    if bolstering:
        bolstX = jnp.array(bolstX)
        bolstXY = jnp.array(bolstXY)

    #Create data frame
    if bolstering:
        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
            train_mse,test_mse,train_L2,test_L2,bolstX[:,0],bolstXY[:,0],bolstX[:,1],bolstXY[:,1],bolstX[:,2],bolstXY[:,2],bolstX[:,3],bolstX[:,4],bolstX[:,5],bolstX[:,6],bolstX[:,7]]),
            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2','bolstX_chi','bolstXY_chi','bolstX_mm','bolstXY_mm','bolstX_mpe','bolstXY_mpe','bolstHessian_sqrtn','bolstHessian_n','bolstHessian_n2','bolstHessian_n3','bolstHessian_n4'])
    else:
        df = pd.DataFrame(np.column_stack([ep,time,[sensor_sample] * len(ep),[boundary_sample] * len(ep),[initial_sample] * len(ep),[collocation_sample] * len(ep),loss,
            train_mse,test_mse,train_L2,test_L2]),
            columns=['epoch','training_time','sensor_sample','boundary_sample','initial_sample','collocation_sample','loss','train_mse','test_mse','train_L2','test_L2'])
    if save:
        df.to_csv(file_name_save + '.csv',index = False)

    return df
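Usage sketch: process_training reads the files written by train_PINN when save = True, so the training of the first example is repeated here with checkpoints. Bolstering is switched off in this sketch because it additionally requires the genree package; the file names are arbitrary.

fit = train_PINN(data,width = [2,32,32,1],pde = pde,test_data = test_data,epochs = 1000,at_each = 100,save = True,file_name = 'heat_pinn')
df = process_training(test_data,'heat_pinn',at_each = 100,bolstering = False,save = True,file_name_save = 'heat_pinn_training')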
def demo_train_pinn1D(test_data,file_name,at_each = 100,times = 5,d2 = True,file_name_save = 'result_pinn_demo',title = '',framerate = 10):
    """
    Demo video with the training of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    at_each : int

        Compute results for epochs multiple of at_each. Default 100

    times : int

        Number of points along the time interval to plot. Default 5

    d2 : logical

        Whether to make a video demo of the 2D plot. Default True

    file_name_save : str

        File prefix to save the plots and videos. Default 'result_pinn_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for the video. Default 10

    Returns
    -------
    None
    """
    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    epochs = config['epochs']
    train_data = config['train_data']
    forward = config['forward']

    #Get train data
    td = get_train_data(train_data)
    xt = td['x']
    u = td['y']

    #Create folder to save plots
    os.system('mkdir ' + file_name_save)

    #Create images
    k = 1
    with alive_bar(epochs) as bar:
        for e in range(epochs):
            if e % at_each == 0 or e == epochs - 1:
                #Read parameters
                params = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')

                #Define learned function
                def psi(x):
                    return forward(x,params['params']['net'])

                #Compute L2 train, L2 test and loss
                loss = params['loss']
                L2_train = L2error(psi(xt),u)
                L2_test = L2error(psi(test_data['xt']),test_data['u'])
                title_epoch = title + ' Epoch = ' + str(e) + ' L2 train = ' + str(round(L2_train,6)) + ' L2 test = ' + str(round(L2_test,6))

                #Save plot
                plot_pinn1D(times,test_data['xt'],test_data['u'],psi(test_data['xt']),d2,save = True,show = False,file_name = file_name_save + '/' + str(k),title_1d = title_epoch,title_2d = title_epoch)
                k = k + 1

                #Delete
                del params, psi, loss, L2_train, L2_test, title_epoch
            #Update alive_bar
            bar()
    #Create demo video
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_slices.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_slices.mp4')
    if d2:
        os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d_2d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_2d.mp4')
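Usage sketch, reusing the checkpoints saved in the previous example. The function shells out to mkdir and ffmpeg, so a Unix-like system with ffmpeg on the PATH is assumed.

demo_train_pinn1D(test_data,'heat_pinn',at_each = 100,times = 5,d2 = True,file_name_save = 'heat_pinn_demo',framerate = 5)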
def demo_time_pinn1D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10):
    """
    Demo video with the time evolution of a 1D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for the video. Default 10

    Returns
    -------
    None
    """
    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    train_data = config['train_data']
    forward = config['forward']

    #Create folder to save plots
    os.system('mkdir ' + file_name_save)

    #Plot parameters
    tdom = jnp.unique(test_data['xt'][:,-1])
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)

    #Open PINN for each epoch
    results = []
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        results = results + [tmp]
        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]

    #Create images
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data
            xt_step = test_data['xt'][test_data['xt'][:,-1] == t]
            u_step = test_data['u'][test_data['xt'][:,-1] == t]
            #Initialize plot
            if len(epochs) > 1:
                fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2))
            else:
                fig, ax = plt.subplots(1,1,figsize = (10,5))
            #Create
            index = 0
            if int(len(epochs)/2) > 1:
                for i in range(int(len(epochs)/2)):
                    for j in range(min(2,len(epochs))):
                        upred_step = upred[index][test_data['xt'][:,-1] == t]
                        ax[i,j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                        ax[i,j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                        ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                        ax[i,j].set_xlabel(' ')
                        ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                        index = index + 1
            elif len(epochs) > 1:
                for j in range(2):
                    upred_step = upred[index][test_data['xt'][:,-1] == t]
                    ax[j].plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                    ax[j].plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                    ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                    index = index + 1
            else:
                upred_step = upred[index][test_data['xt'][:,-1] == t]
                ax.plot(xt_step[:,0],u_step[:,0],'b-',linewidth=2,label='Exact')
                ax.plot(xt_step[:,0],upred_step[:,0],'r--',linewidth=2,label='Prediction')
                ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                ax.set_xlabel(' ')
                ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                index = index + 1

            #Title
            fig.suptitle(title + 't = ' + str(round(t,4)))
            fig.tight_layout()

            #Show and save
            fig = plt.gcf()
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close()
            bar()

    #Create demo video
    os.system('ffmpeg -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')
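Usage sketch, again based on the checkpoints saved above: a video of the predicted solution over time at two saved epochs, shown side by side (ffmpeg required).

demo_time_pinn1D(test_data,'heat_pinn',epochs = [100,999],file_name_save = 'heat_pinn_time_demo',framerate = 10)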
def demo_time_pinn2D(test_data,file_name,epochs,file_name_save = 'result_pinn_time_demo',title = '',framerate = 10,ffmpeg = 'ffmpeg'):
    """
    Demo video with the time evolution of a 2D PINN
    ----------

    Parameters
    ----------
    test_data : dict

        A dictionary with test data for L2 error calculation generated by the jinnax.data.generate_PINNdata function

    file_name : str

        Name of the files saved during training

    epochs : list

        Which training epochs to plot

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_time_demo'

    title : str

        Title for plots

    framerate : int

        Framerate for the video. Default 10

    ffmpeg : str

        Path to ffmpeg

    Returns
    -------
    None
    """
    #Config
    with open(file_name + '_config.pickle', 'rb') as file:
        config = pickle.load(file)
    train_data = config['train_data']
    forward = config['forward']

    #Create folder to save plots
    os.system('mkdir ' + file_name_save)

    #Plot parameters
    tdom = jnp.unique(test_data['xt'][:,-1])
    ylo = jnp.min(test_data['u'])
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(test_data['u'])
    yup = yup + 0.1*jnp.abs(yup)

    #Open PINN for each epoch
    results = []
    upred = []
    for e in epochs:
        tmp = pd.read_pickle(file_name + '_epoch' + str(e).rjust(6, '0') + '.pickle')
        results = results + [tmp]
        upred = upred + [forward(test_data['xt'],tmp['params']['net'])]

    #Create images
    k = 1
    with alive_bar(len(tdom)) as bar:
        for t in tdom:
            #Test data
            xt_step = test_data['xt'][test_data['xt'][:,-1] == t]
            ux_step = test_data['u'][test_data['xt'][:,-1] == t,0]
            uy_step = test_data['u'][test_data['xt'][:,-1] == t,1]
            #Initialize plot
            if len(epochs) > 1:
                fig, ax = plt.subplots(int(len(epochs)/2),2,figsize = (10,5*len(epochs)/2))
            else:
                fig, ax = plt.subplots(1,1,figsize = (10,5))
            #Create
            index = 0
            if int(len(epochs)/2) > 1:
                for i in range(int(len(epochs)/2)):
                    for j in range(min(2,len(epochs))):
                        upredx_step = upred[index][test_data['xt'][:,-1] == t,0]
                        upredy_step = upred[index][test_data['xt'][:,-1] == t,1]
                        ax[i,j].plot(ux_step,uy_step,'b-',linewidth=2,label='Exact')
                        ax[i,j].plot(upredx_step,upredy_step,'r-',linewidth=2,label='Prediction')
                        ax[i,j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                        ax[i,j].set_xlabel(' ')
                        ax[i,j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                        index = index + 1
            elif len(epochs) > 1:
                for j in range(2):
                    upredx_step = upred[index][test_data['xt'][:,-1] == t,0]
                    upredy_step = upred[index][test_data['xt'][:,-1] == t,1]
                    ax[j].plot(ux_step,uy_step,'b-',linewidth=2,label='Exact')
                    ax[j].plot(upredx_step,upredy_step,'r-',linewidth=2,label='Prediction')
                    ax[j].set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                    ax[j].set_xlabel(' ')
                    ax[j].set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                    index = index + 1
            else:
                upredx_step = upred[index][test_data['xt'][:,-1] == t,0]
                upredy_step = upred[index][test_data['xt'][:,-1] == t,1]
                ax.plot(ux_step,uy_step,'b-',linewidth=2,label='Exact')
                ax.plot(upredx_step,upredy_step,'r-',linewidth=2,label='Prediction')
                ax.set_title('Epoch = ' + str(epochs[index]),fontsize=10)
                ax.set_xlabel(' ')
                ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
                index = index + 1

            #Title
            fig.suptitle(title + 't = ' + str(round(t,4)))
            fig.tight_layout()

            #Show and save
            fig = plt.gcf()
            fig.savefig(file_name_save + '/' + str(k) + '.png')
            k = k + 1
            plt.close()
            bar()

    #Create demo video
    os.system(ffmpeg + ' -framerate ' + str(framerate) + ' -i ' + file_name_save + '/' + '%00d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p ' + file_name_save + '/' + file_name_save + '_time_demo.mp4')
def DN_CSF_circle(uinitial,xl,xu,tl,tu,width,radius,Ntb = 100,N0 = 100,Nc = 50,Ntc = 50,Ns = 100,Nts = 100,epochs = 100,at_each = 10,activation = 'tanh',sa = True,lr = 0.001,b1 = 0.9,b2 = 0.999,eps = 1e-08,eps_root = 0.0,key = 0,epoch_print = 100,save = False,file_name = 'result_pinn',exp_decay = False,transition_steps = 1000,decay_rate = 0.9,demo = True,framerate = 2,ffmpeg = 'ffmpeg',c = 1e-6):
    """
    Train a PINN for the curve shortening flow of a planar curve whose right endpoint (x = xu) is fixed by a Dirichlet condition and whose left endpoint (x = xl) lies on a circle of the given radius and meets it orthogonally (Dirichlet and Neumann conditions). Returns the fitted PINN and a data frame with the residuals of the PDE, the boundary conditions and the initial condition on an independent test sample.
    """
    #If demo, then save
    if demo:
        save = True

    #Define initial function and function to evaluate at the boundary
    def uinit(x,t):
        u = uinitial(x,t)
        return jnp.append(u['u1'],u['u2'])

    def ubound(x,t):
        u = uinitial(x,t)
        return jnp.append(u['u1'],u['u2'],1)

    #PDE operator
    def pde(u,x,t):
        #One function for each coordinate (assuming that x and t have dimension 1 x 1 and u(x,t) has dimension 1 x 2)
        u1 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,0][0]
        u2 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,1][0]
        #First derivatives of each coordinate
        ux1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t))
        ux2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t))
        ut1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),1)(x,t))
        ut2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),1)(x,t))
        #Second derivative of each coordinate
        ux1_tmp = lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t)
        ux2_tmp = lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t)
        uxx1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : ux1_tmp(x,t)[0],0)(x,t))
        uxx2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : ux2_tmp(x,t)[0],0)(x,t))
        #Return
        return jnp.sqrt((ut1(x,t) - uxx1(x,t)/(ux1(x,t) ** 2 + ux2(x,t) ** 2 + c)) ** 2 + (ut2(x,t) - uxx2(x,t)/(ux1(x,t) ** 2 + ux2(x,t) ** 2 + c)) ** 2)

    #Operator to evaluate boundary conditions
    def oper_boundary(u,x,t,w = 1,Ntb = Ntb):
        #Enforce Dirichlet at the right boundary (fixed at a point, as in the initial condition)
        res_right_dir = jnp.sum(jnp.where(x == xu,(u(x,t) - ubound(x,t)) ** 2,0),1).reshape(x.shape[0],1)
        #Enforce Dirichlet at the left boundary (lies on the circle of fixed radius)
        res_left_dir = jnp.sum(jnp.where(x == xl,((jnp.sum(u(x,t) ** 2,1) - radius ** 2) ** 2).reshape(x.shape),0),1).reshape(x.shape[0],1)
        #One function for each coordinate (assuming that x and t have dimension 1 x 1 and u(x,t) has dimension 1 x 2)
        u1 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,0][0]
        u2 = lambda x,t: u(x.reshape((x.shape[0],1)),t.reshape((t.shape[0],1)))[:,1][0]
        #Take the derivatives in x
        ux1 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u1(x,t),0)(x,t))(x,t)
        ux2 = jax.vmap(lambda x,t : jax.grad(lambda x,t : u2(x,t),0)(x,t))(x,t)
        #Enforce Neumann at the left boundary
        nS = u(x,t)/jnp.sqrt(jnp.sum(u(x,t) ** 2,0)) #Assuming that u(x,t) lies in S, compute the vector normal to S at u(x,t)
        nu = jnp.append(ux2,(-1)*ux1,1)/jnp.sqrt(ux1 ** 2 + ux2 ** 2)
        ip = jnp.sum(nS * nu,1).reshape(x.shape[0],1) ** 2
        res_left_neu = jnp.where(x == xl,ip,0)
        #Rearrange
        res = jnp.append(jnp.append(res_right_dir[:Ntb,:],res_left_dir[Ntb:2*Ntb,:],0),res_left_neu[2*Ntb:,:],0)
        return w*res

    #Generate data
    train_data = jd.generate_PINNdata(u = uinit,xl = xl,xu = xu,tl = tl,tu = tu,Ns = None,Nts = None,Nb = 2,Ntb = Ntb,N0 = N0,Nc = Nc,Ntc = Ntc,p = 2,poss = 'random',posts = 'random',pos0 = 'random',postb = 'random',posc = 'random',postc = 'random')

    #Rearrange boundary data
    train_data['boundary'] = jnp.append(jnp.append(train_data['boundary'][Ntb:,:],train_data['boundary'][:Ntb,:],0),train_data['boundary'][:Ntb,:],0)
    train_data['uboundary'] = jnp.append(jnp.append(train_data['uboundary'][Ntb:,:],train_data['uboundary'][:Ntb,:],0),train_data['uboundary'][:Ntb,:],0)

    #Train PINN
    fit = train_PINN(train_data,width,pde,c = {'ws': 1,'wr': 1,'w0': 1,'wb': 1},test_data = None,epochs = epochs,at_each = at_each,activation = activation,neumann = True,oper_neumann = oper_boundary,sa = sa,lr = lr,b1 = b1,b2 = b2,eps = eps,eps_root = eps_root,key = key,epoch_print = epoch_print,save = save,file_name = file_name,exp_decay = exp_decay,transition_steps = transition_steps,decay_rate = decay_rate)

    #Test data
    test_data = jd.generate_PINNdata(u = uinit,xl = xl,xu = xu,tl = tl,tu = tu,Ns = None,Nts = None,Nb = 2,Ntb = 2*Ntb,N0 = 2*N0,Nc = 2*Nc,Ntc = 2*Ntc,p = 2,poss = 'random',posts = 'random',pos0 = 'random',postb = 'random',posc = 'random',postc = 'random')
    Ntb = 2*Ntb
    test_data['boundary'] = jnp.append(jnp.append(test_data['boundary'][Ntb:,:],test_data['boundary'][:Ntb,:],0),test_data['boundary'][:Ntb,:],0)
    test_data['uboundary'] = jnp.append(jnp.append(test_data['uboundary'][Ntb:,:],test_data['uboundary'][:Ntb,:],0),test_data['uboundary'][:Ntb,:],0)

    #Evaluate residuals
    def u(x,t):
        return fit['u'](jnp.append(x,t,1))

    res_pde = jnp.mean(pde(u,test_data['collocation'][:,0].reshape((test_data['collocation'].shape[0],1)),test_data['collocation'][:,1].reshape((test_data['collocation'].shape[0],1))) ** 2)
    res_DN = oper_boundary(u,test_data['boundary'][:,0].reshape((test_data['boundary'].shape[0],1)),test_data['boundary'][:,1].reshape((test_data['boundary'].shape[0],1)),Ntb = Ntb)
    res_dir_right = jnp.mean(res_DN[:Ntb,:] ** 2)
    res_neu = jnp.mean(res_DN[2*Ntb:,:] ** 2)
    res_dir_left = jnp.mean(res_DN[Ntb:2*Ntb,:] ** 2)
    res_initial = jnp.mean((u(test_data['initial'][:,0].reshape((test_data['initial'].shape[0],1)),test_data['initial'][:,1].reshape((test_data['initial'].shape[0],1))) - test_data['uinitial']) ** 2)

    #Save file
    res_data = pd.DataFrame({'PDE': [res_pde.tolist()],
        'Dirichlet_Right': [res_dir_right.tolist()],
        'Dirichlet_Left': [res_dir_left.tolist()],
        'Neumann': [res_neu.tolist()],
        'initial': res_initial,
        'time': fit['time'],
        'epochs': epochs})
    res_data.to_csv(file_name + '_residuals.csv')

    if demo:
        def ucircle(x,t):
            y = 2*jnp.pi*(x - xl)/(xu - xl)
            return jnp.append(radius*jnp.sin(y),radius*jnp.cos(y),0)
        test_data = jd.generate_PINNdata(u = ucircle,xl = xl,xu = xu,tl = tl,tu = tu,Ns = Ns,Nts = Nts,Nb = 0,Ntb = 0,N0 = 0,Nc = 0,Ntc = 0,p = 2,train = False)
        demo_time_pinn2D(test_data,file_name,[epochs-1],file_name_save = file_name + '_demo',title = '',framerate = framerate,ffmpeg = ffmpeg)

    return fit,res_data
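A usage sketch for DN_CSF_circle (illustrative only; the exact conventions expected by jinnax.data.generate_PINNdata are assumed rather than verified here). uinitial must return a dictionary with keys 'u1' and 'u2' holding the two coordinates of the initial curve, as read by uinit and ubound above. The initial curve below is a straight segment from the point (0.5, 0) on the circle of radius 0.5 to the fixed point (2, 0), parametrized by x in [xl, xu] = [0, 1]; the demo video is switched off so that ffmpeg is not required.

def uinitial(x,t):
    return {'u1': 0.5 + 1.5*x,'u2': 0.0*x}

fit,res = DN_CSF_circle(uinitial,xl = 0.0,xu = 1.0,tl = 0.0,tu = 1.0,width = [2,32,32,2],radius = 0.5,epochs = 2000,file_name = 'csf_circle',demo = False)
print(res)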