jinnax.csf
# Functions to train PINNs for CSF
from absl import logging
from alive_progress import alive_bar
from functools import partial
import jax
from jax import grad, hessian, jit, lax, vmap, jacrev
import jax.numpy as jnp
from jax.tree_util import tree_map
from jaxpi import archs
from jaxpi.evaluator import BaseEvaluator
from jaxpi.logging import Logger
from jaxpi.models import ForwardIVP
from jaxpi.samplers import UniformSampler
from jaxpi.utils import ntk_fn, restore_checkpoint, save_checkpoint
from jinnax import class_csf
import matplotlib.pyplot as plt
import ml_collections
import numpy as np
import optax
import os
import pandas as pd
import random
import scipy.io
import subprocess
import sys
import time
import wandb


def get_base_config():
    """
    Base config file for training PINN for CSF in jaxpi

    Only problem-independent settings are set here; `csf` fills in the domain,
    sample sizes, tolerances, seed, number of steps and loss weights.

    Returns
    -------
    ml_collections config dictionary
    """
    # Get the default hyperparameter configuration.
    config = ml_collections.ConfigDict()

    # Weights & Biases (project and name are set by `csf`).
    # Section locals end in `_cfg` so they do not shadow the module-level
    # `wandb` and `logging` imports.
    config.wandb = wandb_cfg = ml_collections.ConfigDict()
    wandb_cfg.tag = None

    # Arch
    config.arch = arch = ml_collections.ConfigDict()
    arch.arch_name = "ModifiedMlp"
    arch.num_layers = 4
    arch.hidden_dim = 256
    arch.out_dim = 2
    arch.activation = "tanh"
    arch.reparam = ml_collections.ConfigDict(
        {"type": "weight_fact", "mean": 1.0, "stddev": 0.1}
    )

    # Optim
    config.optim = optim = ml_collections.ConfigDict()
    optim.optimizer = "Adam"
    optim.beta1 = 0.9
    optim.beta2 = 0.999
    optim.eps = 1e-8
    optim.learning_rate = 1e-3
    optim.decay_rate = 0.9
    optim.decay_steps = 2000
    optim.staircase = False
    optim.warmup_steps = 0
    optim.grad_accum_steps = 0

    # Training (max_steps is set by `csf`)
    config.training = training = ml_collections.ConfigDict()
    training.batch_size_per_device = 4096

    # Weighting (init_weights is set by `csf` according to the problem type)
    config.weighting = weighting = ml_collections.ConfigDict()
    weighting.scheme = "grad_norm"
    weighting.momentum = 0.9
    weighting.update_every_steps = 1000

    weighting.use_causal = True
    weighting.causal_tol = 1.0
    weighting.num_chunks = 16

    # Logging
    config.logging = logging_cfg = ml_collections.ConfigDict()
    logging_cfg.log_every_steps = 1000
    logging_cfg.log_errors = True
    logging_cfg.log_losses = True
    logging_cfg.log_weights = True
    logging_cfg.log_preds = False
    logging_cfg.log_grads = False
    logging_cfg.log_ntk = False

    # Saving
    config.saving = saving = ml_collections.ConfigDict()
    saving.save_every_steps = 1000
    saving.num_keep_ckpts = 1000000

    # Input shape for initializing Flax models: (t, x)
    config.input_dim = 2

    return config


# Demo in time for 2D PINN
def demo_time_CSF(data, type='DN', radius=None, file_name_save='result_pinn_CSF_demo',
                  title='', framerate=10, ffmpeg='ffmpeg'):
    """
    Demo video with the time evolution of a CSF

    One PNG frame per distinct time value is written to the folder
    ``file_name_save`` (1.png, 2.png, ...), and ffmpeg assembles the frames
    into ``file_name_save + '_time_demo.mp4'``. For types 'DN' and 'NN' the
    circle of the given radius is drawn together with the curve.
    ----------
    Parameters
    ----------
    data : jax.array

        Data with the predicted values, one row per (t, x) point: column 0 is
        t, column 2 is plotted horizontally and column 3 vertically (the
        layout returned by `evaluate`)

    type : str

        Type of problem: 'DN', 'NN', 'DD' or 'closed'

    radius : float

        Radius of circle for types 'DN' and 'NN' (required for these types)

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_CSF_demo'

    title : str

        Title for plots (the current time is appended to it)

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to the ffmpeg executable

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``type`` is 'DN' or 'NN' and ``radius`` is None
    """
    in_circle = type in ('DN', 'NN')
    if in_circle and radius is None:
        raise ValueError("radius is required for types 'DN' and 'NN'")

    # Create folder to save plots; reuse it if it already exists
    os.makedirs(file_name_save, exist_ok=True)

    # Plot parameters: vertical limits come from the curve at the first time,
    # widened by 10% (only used when no circle is drawn)
    tdom = jnp.unique(data[:, 0])
    y_first = data[data[:, 0] == jnp.min(tdom), 3]
    ylo = jnp.min(y_first)
    ylo = ylo - 0.1 * jnp.abs(ylo)
    yup = jnp.max(y_first)
    yup = yup + 0.1 * jnp.abs(yup)

    # Circle data
    if in_circle:
        theta = jnp.linspace(0, 2 * jnp.pi, 1000)
        circle = jnp.stack([radius * jnp.sin(theta), radius * jnp.cos(theta)], axis=1)

    # Create one image per time value; frames are numbered from 1
    with alive_bar(len(tdom)) as bar:
        for k, t in enumerate(tdom, start=1):
            at_t = data[:, 0] == t
            ux_step = data[at_t, 2]
            uy_step = data[at_t, 3]
            fig, ax = plt.subplots(1, 1, figsize=(10, 10))
            ax.plot(ux_step, uy_step, 'b-', linewidth=2)
            if in_circle:
                ax.plot(circle[:, 0], circle[:, 1], 'r-', linewidth=2)
            else:
                ax.set_ylim([1.3 * float(ylo), 1.3 * float(yup)])
            ax.set_xlabel(' ')
            fig.suptitle(title + 't = ' + str(round(float(t), 4)))
            fig.tight_layout()
            fig.savefig(os.path.join(file_name_save, str(k) + '.png'))
            plt.close(fig)
            bar()

    # Create demo video. An argument list (no shell) keeps paths with spaces
    # intact; -y overwrites a video left by a previous run instead of
    # prompting for confirmation.
    command = [
        ffmpeg, '-y', '-framerate', str(framerate),
        '-i', os.path.join(file_name_save, '%d.png'),
        '-c:v', 'libx264', '-profile:v', 'high', '-crf', '20', '-pix_fmt', 'yuv420p',
        file_name_save + '_time_demo.mp4',
    ]
    try:
        subprocess.run(command, check=False)
    except FileNotFoundError:
        # A missing ffmpeg is reported but not raised, so callers can still
        # save their results after the demo step.
        print('ffmpeg not found (' + ffmpeg + '); frames kept in ' + file_name_save)


def train_csf(config: ml_collections.ConfigDict):
    """
    Train PINN for CSF in jaxpi

    Training runs for at most ``config.training.max_steps`` steps. Every
    ``config.logging.log_every_steps`` steps the model is evaluated on a fixed
    random test sample. Training stops early once the residual, initial-condition
    and boundary test errors are below ``config.res_tol``, ``config.ic_tol``
    and ``config.dn_tol``. Before stopping, the state that met the tolerances
    is checkpointed, so `evaluate` restores exactly that state.
    ----------
    Parameters
    ----------
    config : ml_collections.ConfigDict

        Dictionary for training PINN in jaxpi (as built by `csf`)

    Returns
    -------
    model, log_dict

        The trained model, and the metrics of the last evaluation plus
        'total_time' (seconds) and 'epochs' (number of steps run)

    Raises
    ------
    ValueError
        If ``config.type_csf`` is not 'DN', 'NN', 'DD' or 'closed', or if
        ``config.training.max_steps`` is smaller than 1
    """
    # Model and evaluator classes for each problem type
    classes = {
        'DN': (class_csf.DN_csf, class_csf.DN_csf_Evaluator),
        'NN': (class_csf.NN_csf, class_csf.NN_csf_Evaluator),
        'DD': (class_csf.DD_csf, class_csf.DD_csf_Evaluator),
        'closed': (class_csf.closed_csf, class_csf.closed_csf_Evaluator),
    }
    # Boundary test errors that must also be below dn_tol to stop early
    boundary_keys = {
        'DN': ('ld_test', 'rd_test', 'ln_test'),
        'NN': ('ld_test', 'rd_test', 'ln_test', 'rn_test'),
        'DD': ('ld_test', 'rd_test'),
        'closed': (),
    }
    if config.type_csf not in classes:
        raise ValueError("type_csf must be 'DN', 'NN', 'DD' or 'closed', got " + repr(config.type_csf))
    if config.training.max_steps < 1:
        raise ValueError('config.training.max_steps must be at least 1')
    model_cls, evaluator_cls = classes[config.type_csf]
    stop_keys = boundary_keys[config.type_csf]

    if config.save_wandb:
        wandb_config = config.wandb
        wandb.init(project=wandb_config.project, name=wandb_config.name)

    # Define the time and space domain
    dom = jnp.array([[config.tl, config.tu], [config.xl, config.xu]])

    # Initialize the residual sampler
    res_sampler = iter(UniformSampler(dom, config.training.batch_size_per_device))

    # Initialize the model
    model = model_cls(config)

    # Logger
    logger = Logger()

    # Initialize evaluator on a fixed test sample drawn from config.seed
    # (seeding scheme kept unchanged so earlier runs stay reproducible)
    key = jax.random.split(jax.random.PRNGKey(config.seed), 4)
    x0_test = jax.random.uniform(key=jax.random.PRNGKey(key[0, 0]), minval=config.xl, maxval=config.xu, shape=(config.N0, 1))
    u1_0_test, u2_0_test = config.uinitial(x0_test)
    tb_test = jax.random.uniform(key=jax.random.PRNGKey(key[1, 0]), minval=config.tl, maxval=config.tu, shape=(config.Nb, 1))
    xc_test = jax.random.uniform(key=jax.random.PRNGKey(key[2, 0]), minval=config.xl, maxval=config.xu, shape=(config.Nc ** 2, 1))
    tc_test = jax.random.uniform(key=jax.random.PRNGKey(key[3, 0]), minval=config.tl, maxval=config.tu, shape=(config.Nc ** 2, 1))
    evaluator = evaluator_cls(config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test)

    def converged(metrics):
        """Return True when every test error is below its tolerance."""
        return bool(
            metrics['res1_test'] < config.res_tol
            and metrics['res2_test'] < config.res_tol
            and metrics['ic_rel_test'] < config.ic_tol
            and all(metrics[k] < config.dn_tol for k in stop_keys)
        )

    print("Training CSF...")
    # Stays a dict even if this host never evaluates
    log_dict = {}
    start_time = time.time()
    t0 = start_time
    for step in range(config.training.max_steps):
        batch = next(res_sampler)
        model.state = model.step(model.state, batch)

        # Update weights if necessary
        if config.weighting.scheme in ["grad_norm", "ntk"]:
            if step % config.weighting.update_every_steps == 0:
                model.state = model.update_weights(model.state, batch)

        # Log training metrics and test for early stop, only on host 0
        stop = False
        if jax.process_index() == 0 and step % config.logging.log_every_steps == 0:
            # Get the first replica of the state and batch
            state = jax.device_get(tree_map(lambda x: x[0], model.state))
            batch = jax.device_get(tree_map(lambda x: x[0], batch))
            log_dict = evaluator(state, batch)
            if config.save_wandb:
                wandb.log(log_dict, step)
            end_time = time.time()
            logger.log_iter(step, start_time, end_time, log_dict)
            start_time = end_time
            stop = converged(log_dict)

        # Saving: periodically, at the last step, and right before an early stop
        if config.saving.save_every_steps is not None:
            periodic = (step + 1) % config.saving.save_every_steps == 0
            last = (step + 1) == config.training.max_steps
            if stop or periodic or last:
                ckpt_path = os.path.join(os.getcwd(), "ckpt", config.wandb.name)
                save_checkpoint(model.state, ckpt_path, keep=config.saving.num_keep_ckpts)

        if stop:
            break

    # Run summary
    log_dict['total_time'] = time.time() - t0
    log_dict['epochs'] = step + 1

    return model, log_dict


def evaluate(config: ml_collections.ConfigDict):
    """
    Evaluate PINN for CSF trained in jaxpi

    Restores the checkpoint saved by `train_csf` from ``./ckpt/<config.wandb.name>``,
    predicts the curve on a regular (t, x) grid and saves the predictions to
    ``<config.wandb.project>_<config.wandb.name>.npy``.
    ----------
    Parameters
    ----------
    config : ml_collections.ConfigDict

        Dictionary for training PINN in jaxpi. The grid uses ``Nt`` t values
        in [tl, tu] and ``Nc`` x values in [xl, xu].

    Returns
    -------
    predicted values : jax.Array with Nt * Nc rows

        Columns are t, x, u1 and u2; t varies slowest and x fastest

    Raises
    ------
    ValueError
        If ``config.type_csf`` is not 'DN', 'NN', 'DD' or 'closed'
    """
    # Initialize the model
    model_classes = {
        'DN': class_csf.DN_csf,
        'NN': class_csf.NN_csf,
        'DD': class_csf.DD_csf,
        'closed': class_csf.closed_csf,
    }
    if config.type_csf not in model_classes:
        raise ValueError("type_csf must be 'DN', 'NN', 'DD' or 'closed', got " + repr(config.type_csf))
    model = model_classes[config.type_csf](config)

    # Restore the checkpoint
    ckpt_path = os.path.join(os.getcwd(), "ckpt", config.wandb.name)
    model.state = restore_checkpoint(model.state, ckpt_path)
    params = model.state.params

    # Collocation grid, vectorized; same row order as a nested loop with t outer, x inner
    t_grid, x_grid = jnp.meshgrid(
        jnp.linspace(config.tl, config.tu, config.Nt),
        jnp.linspace(config.xl, config.xu, config.Nc),
        indexing='ij',
    )
    tx = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)

    # Predict
    u1_pred = model.u1_pred_fn(params, tx[:, 0], tx[:, 1])
    u2_pred = model.u2_pred_fn(params, tx[:, 0], tx[:, 1])

    # Save as columns (t, x, u1, u2)
    pred = jnp.concatenate([tx, u1_pred.reshape(-1, 1), u2_pred.reshape(-1, 1)], axis=1)
    jnp.save(config.wandb.project + '_' + config.wandb.name + '.npy', pred)

    return pred


def csf(uinitial, xl, xu, tl, tu, type='DN', radius=None, file_name='test', Nt=400,
        N0=10000, Nb=10000, Nc=500, config=None, save_wandb=False,
        wandb_project='CSF_project', seed=534, demo=True, max_epochs=150000,
        res_tol=5e-5, dn_tol=5e-4, ic_tol=0.01, framerate=10, ffmpeg='ffmpeg'):
    """
    Train PINN for CSF in jaxpi

    Trains the model (`train_csf`), predicts on a grid (`evaluate`), optionally
    renders a video (`demo_time_CSF`) and saves the metrics to
    ``file_name + '_results.csv'``.
    ----------
    Parameters
    ----------
    uinitial : function

        Function that computes the initial condition; it returns the pair (u1, u2)

    xl, xu, tl, tu : float

        Limits of the x and t domain

    type : str

        Type of CSF problem to train: 'DN' or 'NN' (in a circle), 'DD' or 'closed'

    radius : float

        Radius of the circle for 'DN' and 'NN'

    file_name : str

        File name to save results

    Nt : int

        Number of t values in the prediction grid

    N0, Nb : int

        Initial and boundary condition test sample size

    Nc : int

        Number of points per direction of the PDE-residual test sample
        (Nc ** 2 points in total); also the number of x values in the
        prediction grid

    config : ml_collections.ConfigDict

        Config dictionary to train PINNs in jaxpi. If not provided, use basic
        configurations. It is modified in place.

    save_wandb : logical

        Whether to save results in wandb

    wandb_project : str

        Name of wandb project

    seed : int

        Seed for initialising the neural network and drawing the test samples

    demo : logical

        Whether to generate video with result

    max_epochs : int

        Maximum number of training steps (epochs)

    res_tol, dn_tol, ic_tol : float

        Tolerance on test errors for early stop

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to ffmpeg

    Returns
    -------
    log_dict

        Test metrics of the last evaluation plus 'total_time' and 'epochs'

    Raises
    ------
    ValueError
        If ``type`` is not 'DN', 'NN', 'DD' or 'closed'
    """
    # Loss terms weighted in addition to "ic", "res1" and "res2" for each type
    extra_terms = {
        'DN': ('rd', 'ld', 'ln'),
        'NN': ('rd', 'ld', 'ln', 'rn'),
        'DD': ('rd', 'ld'),
        'closed': ('periodic1', 'periodic2'),  # periodic condition
    }
    if type not in extra_terms:
        raise ValueError("type must be 'DN', 'NN', 'DD' or 'closed', got " + repr(type))

    # Set config file
    if config is None:
        config = get_base_config()

    config.wandb.project = wandb_project
    config.wandb.name = file_name
    config.seed = seed
    config.type_csf = type
    config.save_wandb = save_wandb
    config.uinitial = uinitial
    config.xl = xl
    config.xu = xu
    config.tl = tl
    config.tu = tu
    config.radius = radius
    # Initial curve at the right/left ends of the x domain
    u_right = uinitial(xu)
    u_left = uinitial(xl)
    config.rd = jnp.append(u_right[0], u_right[1])
    config.ld = jnp.append(u_left[0], u_left[1])
    config.Nt = Nt
    config.N0 = N0
    config.Nc = Nc
    config.Nb = Nb
    config.res_tol = res_tol
    config.dn_tol = dn_tol
    config.ic_tol = ic_tol
    config.training.max_steps = max_epochs

    init_weights = {"ic": 1.0, "res1": 1.0, "res2": 1.0}
    init_weights.update((term, 1.0) for term in extra_terms[type])
    config.weighting.init_weights = ml_collections.ConfigDict(init_weights)

    # Train model
    _, results = train_csf(config)

    # Evaluate
    pred = evaluate(config)

    # Generate demo
    if demo:
        demo_time_CSF(pred, radius=radius, type=type, file_name_save=file_name,
                      framerate=framerate, ffmpeg=ffmpeg)

    # Print results
    pd_results = pd.DataFrame(list(results.items()))
    print(pd_results)
    pd_results.to_csv(file_name + '_results.csv')

    return results
def
get_base_config():
def get_base_config():
    """
    Base config file for training PINN for CSF in jaxpi

    Only problem-independent settings are set here; `csf` fills in the domain,
    sample sizes, tolerances, seed, number of steps and loss weights.

    Returns
    -------
    ml_collections config dictionary
    """
    # Get the default hyperparameter configuration.
    config = ml_collections.ConfigDict()

    # Weights & Biases (project and name are set by `csf`).
    # Section locals end in `_cfg` so they do not shadow the module-level
    # `wandb` and `logging` imports.
    config.wandb = wandb_cfg = ml_collections.ConfigDict()
    wandb_cfg.tag = None

    # Arch
    config.arch = arch = ml_collections.ConfigDict()
    arch.arch_name = "ModifiedMlp"
    arch.num_layers = 4
    arch.hidden_dim = 256
    arch.out_dim = 2
    arch.activation = "tanh"
    arch.reparam = ml_collections.ConfigDict(
        {"type": "weight_fact", "mean": 1.0, "stddev": 0.1}
    )

    # Optim
    config.optim = optim = ml_collections.ConfigDict()
    optim.optimizer = "Adam"
    optim.beta1 = 0.9
    optim.beta2 = 0.999
    optim.eps = 1e-8
    optim.learning_rate = 1e-3
    optim.decay_rate = 0.9
    optim.decay_steps = 2000
    optim.staircase = False
    optim.warmup_steps = 0
    optim.grad_accum_steps = 0

    # Training (max_steps is set by `csf`)
    config.training = training = ml_collections.ConfigDict()
    training.batch_size_per_device = 4096

    # Weighting (init_weights is set by `csf` according to the problem type)
    config.weighting = weighting = ml_collections.ConfigDict()
    weighting.scheme = "grad_norm"
    weighting.momentum = 0.9
    weighting.update_every_steps = 1000

    weighting.use_causal = True
    weighting.causal_tol = 1.0
    weighting.num_chunks = 16

    # Logging
    config.logging = logging_cfg = ml_collections.ConfigDict()
    logging_cfg.log_every_steps = 1000
    logging_cfg.log_errors = True
    logging_cfg.log_losses = True
    logging_cfg.log_weights = True
    logging_cfg.log_preds = False
    logging_cfg.log_grads = False
    logging_cfg.log_ntk = False

    # Saving
    config.saving = saving = ml_collections.ConfigDict()
    saving.save_every_steps = 1000
    saving.num_keep_ckpts = 1000000

    # Input shape for initializing Flax models: (t, x)
    config.input_dim = 2

    return config
Base config file for training PINN for CSF in jaxpi
Returns
ml_collections config dictionary
def
demo_time_CSF( data, type='DN', radius=None, file_name_save='result_pinn_CSF_demo', title='', framerate=10, ffmpeg='ffmpeg'):
def demo_time_CSF(data, type='DN', radius=None, file_name_save='result_pinn_CSF_demo',
                  title='', framerate=10, ffmpeg='ffmpeg'):
    """
    Demo video with the time evolution of a CSF

    One PNG frame per distinct time value is written to the folder
    ``file_name_save`` (1.png, 2.png, ...), and ffmpeg assembles the frames
    into ``file_name_save + '_time_demo.mp4'``. For types 'DN' and 'NN' the
    circle of the given radius is drawn together with the curve.
    ----------
    Parameters
    ----------
    data : jax.array

        Data with the predicted values, one row per (t, x) point: column 0 is
        t, column 2 is plotted horizontally and column 3 vertically (the
        layout returned by `evaluate`)

    type : str

        Type of problem: 'DN', 'NN', 'DD' or 'closed'

    radius : float

        Radius of circle for types 'DN' and 'NN' (required for these types)

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_CSF_demo'

    title : str

        Title for plots (the current time is appended to it)

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to the ffmpeg executable

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``type`` is 'DN' or 'NN' and ``radius`` is None
    """
    import subprocess  # only needed to call ffmpeg

    in_circle = type in ('DN', 'NN')
    if in_circle and radius is None:
        raise ValueError("radius is required for types 'DN' and 'NN'")

    # Create folder to save plots; reuse it if it already exists
    os.makedirs(file_name_save, exist_ok=True)

    # Plot parameters: vertical limits come from the curve at the first time,
    # widened by 10% (only used when no circle is drawn)
    tdom = jnp.unique(data[:, 0])
    y_first = data[data[:, 0] == jnp.min(tdom), 3]
    ylo = jnp.min(y_first)
    ylo = ylo - 0.1 * jnp.abs(ylo)
    yup = jnp.max(y_first)
    yup = yup + 0.1 * jnp.abs(yup)

    # Circle data
    if in_circle:
        theta = jnp.linspace(0, 2 * jnp.pi, 1000)
        circle = jnp.stack([radius * jnp.sin(theta), radius * jnp.cos(theta)], axis=1)

    # Create one image per time value; frames are numbered from 1
    with alive_bar(len(tdom)) as bar:
        for k, t in enumerate(tdom, start=1):
            at_t = data[:, 0] == t
            ux_step = data[at_t, 2]
            uy_step = data[at_t, 3]
            fig, ax = plt.subplots(1, 1, figsize=(10, 10))
            ax.plot(ux_step, uy_step, 'b-', linewidth=2)
            if in_circle:
                ax.plot(circle[:, 0], circle[:, 1], 'r-', linewidth=2)
            else:
                ax.set_ylim([1.3 * float(ylo), 1.3 * float(yup)])
            ax.set_xlabel(' ')
            fig.suptitle(title + 't = ' + str(round(float(t), 4)))
            fig.tight_layout()
            fig.savefig(os.path.join(file_name_save, str(k) + '.png'))
            plt.close(fig)
            bar()

    # Create demo video. An argument list (no shell) keeps paths with spaces
    # intact; -y overwrites a video left by a previous run instead of
    # prompting for confirmation.
    command = [
        ffmpeg, '-y', '-framerate', str(framerate),
        '-i', os.path.join(file_name_save, '%d.png'),
        '-c:v', 'libx264', '-profile:v', 'high', '-crf', '20', '-pix_fmt', 'yuv420p',
        file_name_save + '_time_demo.mp4',
    ]
    try:
        subprocess.run(command, check=False)
    except FileNotFoundError:
        # A missing ffmpeg is reported but not raised, so callers can still
        # save their results after the demo step.
        print('ffmpeg not found (' + ffmpeg + '); frames kept in ' + file_name_save)
Demo video with the time evolution of a CSF (the curve is drawn inside a circle for types 'DN' and 'NN')
Parameters
data : jax.array
Data with the predicted values
type : str
Type of CSF problem: 'DN' or 'NN' (in a circle), 'DD' or 'closed'
radius : float
Radius of circle for types 'DN' and 'NN'
file_name_save : str
File prefix to save the plots and video. Default 'result_pinn_CSF_demo'
title : str
Title for plots
framerate : int
Framerate for video. Default 10
ffmpeg : str
Path to ffmpeg
Returns
None
def
train_csf(config: ml_collections.config_dict.config_dict.ConfigDict):
def train_csf(config: ml_collections.ConfigDict):
    """
    Train PINN for CSF in jaxpi

    Training runs for at most ``config.training.max_steps`` steps. Every
    ``config.logging.log_every_steps`` steps the model is evaluated on a fixed
    random test sample. Training stops early once the residual, initial-condition
    and boundary test errors are below ``config.res_tol``, ``config.ic_tol``
    and ``config.dn_tol``. Before stopping, the state that met the tolerances
    is checkpointed, so `evaluate` restores exactly that state.
    ----------
    Parameters
    ----------
    config : ml_collections.ConfigDict

        Dictionary for training PINN in jaxpi (as built by `csf`)

    Returns
    -------
    model, log_dict

        The trained model, and the metrics of the last evaluation plus
        'total_time' (seconds) and 'epochs' (number of steps run)

    Raises
    ------
    ValueError
        If ``config.type_csf`` is not 'DN', 'NN', 'DD' or 'closed', or if
        ``config.training.max_steps`` is smaller than 1
    """
    # Model and evaluator classes for each problem type
    classes = {
        'DN': (class_csf.DN_csf, class_csf.DN_csf_Evaluator),
        'NN': (class_csf.NN_csf, class_csf.NN_csf_Evaluator),
        'DD': (class_csf.DD_csf, class_csf.DD_csf_Evaluator),
        'closed': (class_csf.closed_csf, class_csf.closed_csf_Evaluator),
    }
    # Boundary test errors that must also be below dn_tol to stop early
    boundary_keys = {
        'DN': ('ld_test', 'rd_test', 'ln_test'),
        'NN': ('ld_test', 'rd_test', 'ln_test', 'rn_test'),
        'DD': ('ld_test', 'rd_test'),
        'closed': (),
    }
    if config.type_csf not in classes:
        raise ValueError("type_csf must be 'DN', 'NN', 'DD' or 'closed', got " + repr(config.type_csf))
    if config.training.max_steps < 1:
        raise ValueError('config.training.max_steps must be at least 1')
    model_cls, evaluator_cls = classes[config.type_csf]
    stop_keys = boundary_keys[config.type_csf]

    if config.save_wandb:
        wandb_config = config.wandb
        wandb.init(project=wandb_config.project, name=wandb_config.name)

    # Define the time and space domain
    dom = jnp.array([[config.tl, config.tu], [config.xl, config.xu]])

    # Initialize the residual sampler
    res_sampler = iter(UniformSampler(dom, config.training.batch_size_per_device))

    # Initialize the model
    model = model_cls(config)

    # Logger
    logger = Logger()

    # Initialize evaluator on a fixed test sample drawn from config.seed
    # (seeding scheme kept unchanged so earlier runs stay reproducible)
    key = jax.random.split(jax.random.PRNGKey(config.seed), 4)
    x0_test = jax.random.uniform(key=jax.random.PRNGKey(key[0, 0]), minval=config.xl, maxval=config.xu, shape=(config.N0, 1))
    u1_0_test, u2_0_test = config.uinitial(x0_test)
    tb_test = jax.random.uniform(key=jax.random.PRNGKey(key[1, 0]), minval=config.tl, maxval=config.tu, shape=(config.Nb, 1))
    xc_test = jax.random.uniform(key=jax.random.PRNGKey(key[2, 0]), minval=config.xl, maxval=config.xu, shape=(config.Nc ** 2, 1))
    tc_test = jax.random.uniform(key=jax.random.PRNGKey(key[3, 0]), minval=config.tl, maxval=config.tu, shape=(config.Nc ** 2, 1))
    evaluator = evaluator_cls(config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test)

    def converged(metrics):
        """Return True when every test error is below its tolerance."""
        return bool(
            metrics['res1_test'] < config.res_tol
            and metrics['res2_test'] < config.res_tol
            and metrics['ic_rel_test'] < config.ic_tol
            and all(metrics[k] < config.dn_tol for k in stop_keys)
        )

    print("Training CSF...")
    # Stays a dict even if this host never evaluates
    log_dict = {}
    start_time = time.time()
    t0 = start_time
    for step in range(config.training.max_steps):
        batch = next(res_sampler)
        model.state = model.step(model.state, batch)

        # Update weights if necessary
        if config.weighting.scheme in ["grad_norm", "ntk"]:
            if step % config.weighting.update_every_steps == 0:
                model.state = model.update_weights(model.state, batch)

        # Log training metrics and test for early stop, only on host 0
        stop = False
        if jax.process_index() == 0 and step % config.logging.log_every_steps == 0:
            # Get the first replica of the state and batch
            state = jax.device_get(tree_map(lambda x: x[0], model.state))
            batch = jax.device_get(tree_map(lambda x: x[0], batch))
            log_dict = evaluator(state, batch)
            if config.save_wandb:
                wandb.log(log_dict, step)
            end_time = time.time()
            logger.log_iter(step, start_time, end_time, log_dict)
            start_time = end_time
            stop = converged(log_dict)

        # Saving: periodically, at the last step, and right before an early stop
        if config.saving.save_every_steps is not None:
            periodic = (step + 1) % config.saving.save_every_steps == 0
            last = (step + 1) == config.training.max_steps
            if stop or periodic or last:
                ckpt_path = os.path.join(os.getcwd(), "ckpt", config.wandb.name)
                save_checkpoint(model.state, ckpt_path, keep=config.saving.num_keep_ckpts)

        if stop:
            break

    # Run summary
    log_dict['total_time'] = time.time() - t0
    log_dict['epochs'] = step + 1

    return model, log_dict
Train PINN for CSF in jaxpi
Parameters
config : ml_collections.ConfigDict
Dictionary for training PINN in jaxpi
Returns
model, log_dict
def
evaluate(config: ml_collections.config_dict.config_dict.ConfigDict):
def evaluate(config: ml_collections.ConfigDict):
    """
    Evaluate PINN for CSF trained in jaxpi

    Restores the checkpoint saved by `train_csf` from ``./ckpt/<config.wandb.name>``,
    predicts the curve on a regular (t, x) grid and saves the predictions to
    ``<config.wandb.project>_<config.wandb.name>.npy``.
    ----------
    Parameters
    ----------
    config : ml_collections.ConfigDict

        Dictionary for training PINN in jaxpi. The grid uses ``Nt`` t values
        in [tl, tu] and ``Nc`` x values in [xl, xu].

    Returns
    -------
    predicted values : jax.Array with Nt * Nc rows

        Columns are t, x, u1 and u2; t varies slowest and x fastest

    Raises
    ------
    ValueError
        If ``config.type_csf`` is not 'DN', 'NN', 'DD' or 'closed'
    """
    # Initialize the model
    model_classes = {
        'DN': class_csf.DN_csf,
        'NN': class_csf.NN_csf,
        'DD': class_csf.DD_csf,
        'closed': class_csf.closed_csf,
    }
    if config.type_csf not in model_classes:
        raise ValueError("type_csf must be 'DN', 'NN', 'DD' or 'closed', got " + repr(config.type_csf))
    model = model_classes[config.type_csf](config)

    # Restore the checkpoint
    ckpt_path = os.path.join(os.getcwd(), "ckpt", config.wandb.name)
    model.state = restore_checkpoint(model.state, ckpt_path)
    params = model.state.params

    # Collocation grid, vectorized; same row order as a nested loop with t outer, x inner
    t_grid, x_grid = jnp.meshgrid(
        jnp.linspace(config.tl, config.tu, config.Nt),
        jnp.linspace(config.xl, config.xu, config.Nc),
        indexing='ij',
    )
    tx = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)

    # Predict
    u1_pred = model.u1_pred_fn(params, tx[:, 0], tx[:, 1])
    u2_pred = model.u2_pred_fn(params, tx[:, 0], tx[:, 1])

    # Save as columns (t, x, u1, u2)
    pred = jnp.concatenate([tx, u1_pred.reshape(-1, 1), u2_pred.reshape(-1, 1)], axis=1)
    jnp.save(config.wandb.project + '_' + config.wandb.name + '.npy', pred)

    return pred
Evaluate PINN for CSF trained in jaxpi
Parameters
config : ml_collections.ConfigDict
Dictionary for training PINN in jaxpi
config.uinitial : function
Function that computes the initial condition; `csf` stores it in the config. It is not a separate argument of evaluate.
Returns
predicted values: array with Nt * Nc rows and columns t, x, u1, u2 (also saved to <project>_<name>.npy)
def
csf( uinitial, xl, xu, tl, tu, type='DN', radius=None, file_name='test', Nt=400, N0=10000, Nb=10000, Nc=500, config=None, save_wandb=False, wandb_project='CSF_project', seed=534, demo=True, max_epochs=150000, res_tol=5e-05, dn_tol=0.0005, ic_tol=0.01, framerate=10, ffmpeg='ffmpeg'):
def csf(uinitial, xl, xu, tl, tu, type='DN', radius=None, file_name='test', Nt=400,
        N0=10000, Nb=10000, Nc=500, config=None, save_wandb=False,
        wandb_project='CSF_project', seed=534, demo=True, max_epochs=150000,
        res_tol=5e-5, dn_tol=5e-4, ic_tol=0.01, framerate=10, ffmpeg='ffmpeg'):
    """
    Train PINN for CSF in jaxpi

    Trains the model (`train_csf`), predicts on a grid (`evaluate`), optionally
    renders a video (`demo_time_CSF`) and saves the metrics to
    ``file_name + '_results.csv'``.
    ----------
    Parameters
    ----------
    uinitial : function

        Function that computes the initial condition; it returns the pair (u1, u2)

    xl, xu, tl, tu : float

        Limits of the x and t domain

    type : str

        Type of CSF problem to train: 'DN' or 'NN' (in a circle), 'DD' or 'closed'

    radius : float

        Radius of the circle for 'DN' and 'NN'

    file_name : str

        File name to save results

    Nt : int

        Number of t values in the prediction grid

    N0, Nb : int

        Initial and boundary condition test sample size

    Nc : int

        Number of points per direction of the PDE-residual test sample
        (Nc ** 2 points in total); also the number of x values in the
        prediction grid

    config : ml_collections.ConfigDict

        Config dictionary to train PINNs in jaxpi. If not provided, use basic
        configurations. It is modified in place.

    save_wandb : logical

        Whether to save results in wandb

    wandb_project : str

        Name of wandb project

    seed : int

        Seed for initialising the neural network and drawing the test samples

    demo : logical

        Whether to generate video with result

    max_epochs : int

        Maximum number of training steps (epochs)

    res_tol, dn_tol, ic_tol : float

        Tolerance on test errors for early stop

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to ffmpeg

    Returns
    -------
    log_dict

        Test metrics of the last evaluation plus 'total_time' and 'epochs'

    Raises
    ------
    ValueError
        If ``type`` is not 'DN', 'NN', 'DD' or 'closed'
    """
    # Loss terms weighted in addition to "ic", "res1" and "res2" for each type
    extra_terms = {
        'DN': ('rd', 'ld', 'ln'),
        'NN': ('rd', 'ld', 'ln', 'rn'),
        'DD': ('rd', 'ld'),
        'closed': ('periodic1', 'periodic2'),  # periodic condition
    }
    if type not in extra_terms:
        raise ValueError("type must be 'DN', 'NN', 'DD' or 'closed', got " + repr(type))

    # Set config file
    if config is None:
        config = get_base_config()

    config.wandb.project = wandb_project
    config.wandb.name = file_name
    config.seed = seed
    config.type_csf = type
    config.save_wandb = save_wandb
    config.uinitial = uinitial
    config.xl = xl
    config.xu = xu
    config.tl = tl
    config.tu = tu
    config.radius = radius
    # Initial curve at the right/left ends of the x domain
    u_right = uinitial(xu)
    u_left = uinitial(xl)
    config.rd = jnp.append(u_right[0], u_right[1])
    config.ld = jnp.append(u_left[0], u_left[1])
    config.Nt = Nt
    config.N0 = N0
    config.Nc = Nc
    config.Nb = Nb
    config.res_tol = res_tol
    config.dn_tol = dn_tol
    config.ic_tol = ic_tol
    config.training.max_steps = max_epochs

    init_weights = {"ic": 1.0, "res1": 1.0, "res2": 1.0}
    init_weights.update((term, 1.0) for term in extra_terms[type])
    config.weighting.init_weights = ml_collections.ConfigDict(init_weights)

    # Train model
    _, results = train_csf(config)

    # Evaluate
    pred = evaluate(config)

    # Generate demo (the framerate argument is now forwarded instead of a fixed 10)
    if demo:
        demo_time_CSF(pred, radius=radius, type=type, file_name_save=file_name,
                      framerate=framerate, ffmpeg=ffmpeg)

    # Print results
    pd_results = pd.DataFrame(list(results.items()))
    print(pd_results)
    pd_results.to_csv(file_name + '_results.csv')

    return results
Train PINN for CSF in jaxpi
Parameters
uinitial : function
Function that computes the initial condition
xl, xu, tl, tu : float
Limits of the x and t domain
type : str
Type of CSF problem to train: 'DN' or 'NN' (in a circle), 'DD' or 'closed'
radius : float
Radius of the circle for 'DN' and 'NN'
file_name : str
File name to save results
Nt : int
Number of t values in the prediction grid
N0, Nb : int
Initial and boundary condition test sample size
Nc : int
Number of points per direction of the PDE-residual test sample (Nc ** 2 points in total); also the number of x values in the prediction grid
config : ml_collections.ConfigDict
Config dictionary to train PINNs in jaxpi. If not provided, use basic configurations
save_wandb : logical
Whether to save results in wandb
wandb_project : str
Name of wandb project
seed : int
Seed for initialising the neural network and drawing the test samples
demo : logical
Whether to generate video with result
max_epochs : int
Maximum number of epochs to train
res_tol, dn_tol, ic_tol : float
Tolerance on test errors for early stop
framerate : int
Framerate for video. Default 10
ffmpeg : str
Path to ffmpeg
Returns
log_dict (test metrics of the last evaluation plus 'total_time' and 'epochs')