jinnax.deeponet
#Adapted from https://github.com/PredictiveIntelligenceLab/Physics-informed-DeepONets
import jax, os
import jax.numpy as np
from jax import random, grad, vmap, jit, hessian
from jax.example_libraries import optimizers
from jax.experimental.ode import odeint
from jax.nn import relu, elu
from jax.tree_util import tree_map
#from jax.config import config
#from jax.ops import index_update, index
#from jaxpi.utils import restore_checkpoint, save_checkpoint
from flax.training import checkpoints
from jax import lax, pmap
from jax.flatten_util import ravel_pytree
import ml_collections
import itertools
from functools import partial
from torch.utils import data
from tqdm import trange, tqdm
from jaxpi.samplers import BaseSampler, UniformSampler
from jaxpi.logging import Logger
import time
import wandb

#Periodic kernel
@jax.jit
def kernel_periodic(x1, x2, ls=1, p=1):
    return np.exp(-(np.sin(np.pi * np.abs(x1 - x2) / p) ** 2) / (2 * (ls ** 2)))
def save_checkpoint(state, workdir, step, keep=5, name=None):
    # Create the workdir if it doesn't exist.
    if not os.path.isdir(workdir):
        os.makedirs(workdir)

    # Save the checkpoint.
    if jax.process_index() == 0:
        # Get the first replica's state and save it.
        state = jax.device_get(tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(workdir, state, step=step, keep=keep)
def restore_checkpoint(state, workdir, step=None):
    # Check if the passed state is sharded across devices;
    # if so, reduce it to a single-device sharding.
    if isinstance(
        tree_map(lambda x: np.array(x).sharding, jax.tree.leaves(state.params))[0],
        jax.sharding.PmapSharding,
    ):
        state = tree_map(lambda x: x[0], state)

    # Ensure that we're in a single-device setting.
    assert isinstance(
        tree_map(lambda x: np.array(x).sharding, jax.tree.leaves(state.params))[0],
        jax.sharding.SingleDeviceSharding,
    )

    state = checkpoints.restore_checkpoint(workdir, state, step=step)
    return state
def get_base_config():
    """
    Base config file for training a physics-informed DeepONet for CSF in jaxpi

    Returns
    -------
    ml_collections config dictionary
    """
    # Get the default hyperparameter configuration.
    config = ml_collections.ConfigDict()
    # Weights & Biases
    config.wandb = wandb = ml_collections.ConfigDict()
    wandb.tag = None
    # Arch
    config.arch = arch = ml_collections.ConfigDict()
    arch.branch_layers = [1024] + 4*[256]
    arch.trunk_layers = [2] + 4*[256]
    # Optim
    config.optim = optim = ml_collections.ConfigDict()
    optim.beta1 = 0.9
    optim.beta2 = 0.999
    optim.eps = 1e-8
    optim.learning_rate = 1e-3
    optim.decay_rate = 0.9
    optim.decay_steps = 2000
    # Training
    config.training = training = ml_collections.ConfigDict()
    training.batch_size_per_device = 4096
    config.training.batch_size_train_data = 128
    # Weighting
    config.weights = {'b': np.array([100.0]), 'res': np.array([1.0]), 'data': np.array([1.0]), 'ic': np.array([100.0])}
    config.sa = True
    # Logging
    config.logging = logging = ml_collections.ConfigDict()
    logging.log_every_steps = 1000
    logging.log_errors = True
    logging.log_losses = True
    logging.log_weights = True
    logging.log_preds = False
    logging.log_grads = False
    logging.log_ntk = False
    # Saving
    config.saving = saving = ml_collections.ConfigDict()
    saving.save_every_steps = 1000
    saving.num_keep_ckpts = 1000000
    # Seed
    config.seed = 10
    return config
Base config file for training a physics-informed DeepONet for CSF in jaxpi.

Returns: an ml_collections config dictionary with the default architecture, optimizer, training, loss-weighting, logging, and saving options.
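get_base_config only fills in generic hyperparameters; PI_DeepONet.__init__ and PI_DeepONet.train additionally read problem-specific fields (domain limits, meshes, sampler sizes, the residual and boundary-condition losses, and optional data). A minimal sketch of extending the base config for a run, assuming the module is importable as jinnax.deeponet; the numeric values are illustrative and my_residual_net is a placeholder for a user-supplied residual (see the example after the PI_DeepONet class below).

import jax.numpy as np
from jinnax import deeponet

config = deeponet.get_base_config()
# Problem-specific fields read by PI_DeepONet (values here are illustrative).
config.xl, config.xu = 0.0, 1.0                    # spatial domain
config.tl, config.tu = 0.0, 1.0                    # time domain
config.N0 = 128                                    # grid points per initial condition
config.arch.branch_layers = [config.N0] + 4*[256]  # branch input width must match N0
config.x_mesh = np.linspace(config.xl, config.xu, config.N0)
config.t_mesh = np.linspace(config.tl, config.tu, 100)
config.N = 64                                      # initial-condition batch size
config.Q = 4096                                    # residual collocation points per step
config.size = 1000                                 # number of sampled initial conditions
config.kernel = deeponet.kernel_periodic
config.initial_data = None                         # None: sample from the GP prior
config.u_train = None                              # optional (n, len(t_mesh), len(x_mesh)) solutions
config.u_test = None
config.loss_bc = deeponet.bc_loss_periodic
config.residual_net = my_residual_net              # user-supplied PDE residual (placeholder)
config.save_wandb = False
config.wandb.name = "deeponet-run"                 # also used for the checkpoint directory
config.training.max_steps = 20000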
def generate_initial_data(N0, size, kernel=kernel_periodic, xl=0, xu=1, key=0):
    x = np.linspace(xl, xu, N0)
    K = np.array([[kernel(x1, x2)] for x1 in x for x2 in x]).reshape((N0, N0))
    u = jax.random.multivariate_normal(key=jax.random.PRNGKey(key), mean=np.zeros((K.shape[0],)), cov=K, shape=(size,), method='svd')
    return u
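generate_initial_data draws `size` initial conditions from a zero-mean Gaussian process evaluated on a uniform grid of N0 points, using the periodic kernel by default. A small usage example; the sizes are illustrative:

# Draw 256 initial conditions on a 128-point grid over [0, 1]; each row of u0
# is one function sample, so u0 has shape (256, 128).
u0 = generate_initial_data(N0=128, size=256, kernel=kernel_periodic, xl=0, xu=1, key=0)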
class InitialDataSampler(BaseSampler):
    def __init__(self, data, batch_size, rng_key=jax.random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.data = data
        self.dim = data.shape[0]

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        "Generates data containing batch_size samples"
        idx = jax.random.choice(key, self.dim, (self.batch_size,), replace=False)
        return self.data[idx, :]
class DataSampler(BaseSampler):
    def __init__(self, data, batch_size, rng_key=jax.random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.data = data
        self.dim = data.shape[0]

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        "Generates data containing batch_size samples"
        idx = jax.random.choice(key, self.dim, (self.batch_size,), replace=False)
        return self.data[idx, :, :]
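Both samplers follow the jaxpi BaseSampler iterator protocol, and data_generation is wrapped in pmap, so each batch carries a leading device axis of length jax.local_device_count(); this is why PI_DeepONet.train indexes next(sampler)[0, :, :]. A minimal sketch, assuming a single local device and illustrative sizes:

initial_data = generate_initial_data(N0=128, size=256)
initial_sampler = iter(InitialDataSampler(initial_data, batch_size=64))
u0_batch = next(initial_sampler)[0, :, :]   # (devices, 64, 128) -> (64, 128)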
def MLP(layers, activation=relu):
    ''' Vanilla MLP'''
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
            W = glorot_stddev * random.normal(k1, (d_in, d_out))
            b = np.zeros(d_out)
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return params
    def apply(params, inputs):
        for W, b in params[:-1]:
            outputs = np.dot(inputs, W) + b
            inputs = activation(outputs)
        W, b = params[-1]
        outputs = np.dot(inputs, W) + b
        return outputs
    return init, apply
Vanilla MLP
def modified_MLP(layers, activation=relu):
    def xavier_init(key, d_in, d_out):
        glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
        W = glorot_stddev * random.normal(key, (d_in, d_out))
        b = np.zeros(d_out)
        return W, b
    def init(rng_key):
        U1, b1 = xavier_init(random.PRNGKey(12345), layers[0], layers[1])
        U2, b2 = xavier_init(random.PRNGKey(54321), layers[0], layers[1])
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            W, b = xavier_init(k1, d_in, d_out)
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return (params, U1, b1, U2, b2)
    def apply(params, inputs):
        params, U1, b1, U2, b2 = params
        U = activation(np.dot(inputs, U1) + b1)
        V = activation(np.dot(inputs, U2) + b2)
        for W, b in params[:-1]:
            outputs = activation(np.dot(inputs, W) + b)
            inputs = np.multiply(outputs, U) + np.multiply(1 - outputs, V)
        W, b = params[-1]
        outputs = np.dot(inputs, W) + b
        return outputs
    return init, apply
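Both MLP and modified_MLP return an (init, apply) pair in the stax style: init(rng_key) builds the parameter pytree and apply(params, inputs) evaluates the network on a single input vector, with batching handled externally through vmap (as PI_DeepONet does for the branch and trunk networks). A minimal sketch with illustrative layer sizes:

init_fn, apply_fn = modified_MLP([2] + 4*[256] + [1], activation=np.tanh)
params = init_fn(rng_key=random.PRNGKey(0))
y = apply_fn(params, np.array([0.5, 0.1]))                       # single (x, t) input
y_batch = vmap(apply_fn, (None, 0))(params, np.zeros((32, 2)))   # batched over inputs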
class PI_DeepONet:
    def __init__(self, config):
        # Network initialization and evaluation functions
        self.branch_init, self.branch_apply = modified_MLP(config.arch.branch_layers, activation=np.tanh)
        self.trunk_init, self.trunk_apply = modified_MLP(config.arch.trunk_layers, activation=np.tanh)

        # Initialize
        self.branch_params = self.branch_init(rng_key=random.PRNGKey(config.seed))
        self.trunk_params = self.trunk_init(rng_key=random.PRNGKey(config.seed + 1))
        self.params = (self.branch_params, self.trunk_params, {'b': config.weights['b'], 'res': config.weights['res'], 'data': config.weights['data'], 'ic': config.weights['ic']})

        # Use optimizers to set optimizer initialization and update functions
        self.opt_init, \
        self.opt_update, \
        self.get_params = optimizers.adam(optimizers.exponential_decay(config.optim.learning_rate,
                                                                       decay_steps=config.optim.decay_steps,
                                                                       decay_rate=config.optim.decay_rate))
        self.opt_state = self.opt_init(self.params)

        # Used to restore the trained model parameters
        _, self.unravel_params = ravel_pytree(self.params)

        self.itercount = itertools.count()

        # Residual net and boundary condition loss
        self.loss_bc = config.loss_bc
        self.residual_net = config.residual_net
        self.config = config

        # Limits domain
        self.xl = config.xl
        self.xu = config.xu

        # Vmap neural net
        self.pred_fn = vmap(self.operator_net, (None, 0, 0, 0))

        # Vmap residual operator
        if self.residual_net is not None:
            self.r_pred_fn = vmap(self.residual_net, (None, 0, 0, 0))

        # Vmap train and test data
        self.pred_batch = vmap(
            vmap(
                vmap(self.operator_net, (None, None, 0, None)), (None, None, None, 0)
            ), (None, 0, None, None)
        )

        self.pred_batch_xt = vmap(
            vmap(self.operator_net, (None, 0, None, None)), (None, None, 0, 0))
        if self.residual_net is not None:
            self.r_pred_batch = vmap(
                vmap(
                    vmap(self.residual_net, (None, None, 0, None)), (None, None, None, 0)
                ), (None, 0, None, None)
            )

        # Data
        self.u_test = config.u_test
        self.u_train = config.u_train
        self.x_mesh = config.x_mesh
        self.t_mesh = config.t_mesh

        # Weights
        self.w = config.weights

    # Define DeepONet architecture
    @partial(jit, static_argnums=(0,))
    def operator_net(self, params, u, x, t):
        branch_params, trunk_params, _ = params
        y = np.stack([x, t])
        B = self.branch_apply(branch_params, u)
        T = self.trunk_apply(trunk_params, y)
        outputs = np.sum(B * T)
        return outputs

    # Define residual loss
    @partial(jit, static_argnums=(0,))
    def loss_res(self, params, batch):
        # Compute forward pass
        pred = self.residual_net(self.operator_net, params, batch)
        # Compute loss
        loss = np.mean((pred)**2)
        return loss

    # Data loss
    @partial(jit, static_argnums=(0,))
    def loss_data(self, params, batch_train):
        pred = self.pred_batch(params, batch_train['u0'], batch_train['x'], batch_train['t'])
        return np.mean((pred - batch_train['u']) ** 2)

    # Define initial condition loss
    @partial(jit, static_argnums=(0,))
    def loss_ic(self, params, batch):
        # Compute forward pass
        pred = self.pred_batch_xt(params, batch['u0'], self.x_mesh, np.zeros(self.x_mesh.shape[0]))
        # Compute loss
        loss = np.mean((pred - batch['u0'].transpose())**2)
        return loss

    # Define total loss
    @partial(jit, static_argnums=(0,))
    def loss(self, params, batch, batch_train):
        loss_bc = 0.0
        loss_data = 0.0
        loss_res = 0.0
        loss_ic = 0.0
        if self.loss_bc is not None:
            loss_bc = self.loss_bc(self.pred_batch, params, {'u0': batch['u0'], 't': batch['t_bc']}, self.xl, self.xu)
        if self.residual_net is not None:
            loss_res = self.loss_res(params, batch)
        if batch_train is not None:
            loss_data = self.loss_data(params, batch_train)
        loss = (params[-1]['b'] ** 2) * loss_bc + (params[-1]['res'] ** 2) * loss_res + (params[-1]['data'] ** 2) * loss_data + (params[-1]['ic'] ** 2) * self.loss_ic(params, batch)
        return loss

    # Define a compiled update step
    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch, batch_train):
        params = self.get_params(opt_state)
        g = grad(self.loss)(params, batch, batch_train)
        if self.config.sa:
            g[-1]['b'] = -g[-1]['b']
            g[-1]['res'] = -g[-1]['res']
            g[-1]['data'] = -g[-1]['data']
            g[-1]['ic'] = -g[-1]['ic']
        return self.opt_update(i, g, opt_state)

    def evaluator(self, batch, batch_train):
        log_dict = {}
        params = self.get_params(self.opt_state)
        # Test loss
        if self.u_test is not None:
            pred = self.pred_batch(params, self.u_test[:, 0, :], self.x_mesh, self.t_mesh)
            log_dict['test_L2'] = np.mean(np.sqrt(np.mean((pred - self.u_test) ** 2, [1, 2]) / np.mean((self.u_test) ** 2, [1, 2])))

        # Train
        if self.loss_bc is not None:
            log_dict['bc_loss'] = self.loss_bc(self.pred_batch, params, {'u0': batch['u0'], 't': batch['t_bc']}, self.xl, self.xu)
        if self.residual_net is not None:
            log_dict['res_loss'] = self.loss_res(params, batch)
        log_dict['ic_loss'] = self.loss_ic(params, batch)
        if batch_train is not None:
            log_dict['data_loss'] = self.loss_data(params, batch_train)
        if self.config.sa:
            log_dict['Wb'] = params[-1]['b']
            log_dict['Wres'] = params[-1]['res']
            log_dict['Wdata'] = params[-1]['data']
            log_dict['Wic'] = params[-1]['ic']

        return log_dict

    # Optimize parameters in a loop
    def train(self):
        config = self.config
        if config.save_wandb:
            wandb_config = config.wandb
            wandb.init(project=wandb_config.project, name=wandb_config.name)

        # Initialize the initial data sampler
        if config.initial_data is None:
            initial_data = generate_initial_data(config.N0, int(config.size), kernel=config.kernel, xl=config.xl, xu=config.xu, key=0)
        else:
            initial_data = config.initial_data
        initial_sampler = iter(InitialDataSampler(initial_data, config.N))

        # Initialize the residual sampler
        dom = np.array([[config.xl, config.xu], [config.tl, config.tu]])
        res_sampler = iter(UniformSampler(dom, config.Q))

        # Initialize the boundary condition sampler
        bc_sampler = iter(UniformSampler(np.array([[config.tl, config.tu]]), config.N0))

        # Initialize the training data sampler
        batch_train = None
        if config.u_train is not None:
            data_sampler = iter(DataSampler(config.u_train, config.training.batch_size_train_data))

        # Logger
        logger = Logger()
        batch = {'u0': None, 'x': None, 't': None, 't_bc': None}

        # Train
        w = None
        print("Training DeepONet...")
        start_time = time.time()
        t0 = start_time
        for step in range(config.training.max_steps):
            print(step)
            batch['u0'] = next(initial_sampler)[0, :, :]
            res_data_tmp = next(res_sampler)[0, :, :]
            batch['x'] = res_data_tmp[:, 0]
            batch['t'] = res_data_tmp[:, 1]
            batch['t_bc'] = next(bc_sampler)[0, :]
            if config.u_train is not None:
                u = next(data_sampler)
                u = u.reshape((u.shape[1], u.shape[2], u.shape[3]))
                batch_train = {'u0': u[:, 0, :], 'u': u, 't': self.t_mesh, 'x': self.x_mesh}

            # Initialise weights
            if config.sa and step == 0:
                params = self.get_params(self.opt_state)
                if self.loss_bc is not None:
                    lb = self.loss_bc(self.pred_batch, params, {'u0': batch['u0'], 't': batch['t_bc']}, self.xl, self.xu)
                else:
                    lb = np.array(0.0)
                if self.residual_net is not None:
                    lr = self.loss_res(params, batch)
                else:
                    lr = np.array(0.0)
                if batch_train is not None:
                    ld = self.loss_data(params, batch_train)
                else:
                    ld = np.array(0.0)
                li = self.loss_ic(params, batch)
                total = lb + lr + ld + li
                w = {'b': np.sqrt(total/lb), 'res': np.sqrt(total/lr), 'data': np.sqrt(total/ld), 'ic': np.sqrt(total/li)}
                self.params = (self.branch_params, self.trunk_params, w)
                self.opt_state = self.opt_init(self.params)
                # Used to restore the trained model parameters
                _, self.unravel_params = ravel_pytree(self.params)

            # Step
            self.opt_state = self.step(next(self.itercount), self.opt_state, batch, batch_train)

            # Log training metrics, only use host 0 to record results
            if step % config.logging.log_every_steps == 0:
                # Get the first replica of the state and batch
                log_dict = self.evaluator(batch, batch_train)
                if config.save_wandb:
                    wandb.log(log_dict, step)
                end_time = time.time()
                logger.log_iter(step, start_time, end_time, log_dict)
                start_time = end_time
                print(log_dict)

            # Saving
            if config.saving.save_every_steps is not None:
                if (step + 1) % config.saving.save_every_steps == 0 or (
                    step + 1
                ) == config.training.max_steps:
                    ckpt_path = os.path.join(os.getcwd(), "ckpt", config.wandb.name)
                    save_checkpoint(self.opt_state, ckpt_path, step, keep=config.saving.num_keep_ckpts)

        # Run summary
        log_dict['total_time'] = time.time() - t0
        log_dict['epochs'] = step + 1

        return log_dict
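The PDE itself is supplied through the config: loss_res calls config.residual_net as residual_net(operator_net, params, batch), where batch['u0'] holds the sampled initial conditions and batch['x'], batch['t'] the residual collocation points, and config.loss_bc receives (pred_batch, params, {'u0', 't'}, xl, xu) as in bc_loss_periodic below. A hedged illustration of a residual for the advection equation u_t + u_x = 0; the equation, the helper name, and the vmap layout are assumptions for this example, not part of jinnax:

def advection_residual(operator_net, params, batch):
    # Pointwise residual u_t + u_x for one initial condition u0 and one (x, t).
    def residual_point(u0, x, t):
        u_t = grad(operator_net, argnums=3)(params, u0, x, t)
        u_x = grad(operator_net, argnums=2)(params, u0, x, t)
        return u_t + u_x
    # batch['u0']: (N, N0) initial conditions; batch['x'], batch['t']: (Q,) collocation points.
    return vmap(vmap(residual_point, (None, 0, 0)), (0, None, None))(batch['u0'], batch['x'], batch['t'])

# With a config built as sketched after get_base_config above:
# config.residual_net = advection_residual
# model = PI_DeepONet(config)
# summary = model.train()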
def bc_loss_periodic(pred_batch, params, batch, xl, xu):
    pred_xl = pred_batch(
        params, batch['u0'], xl + np.zeros((1,)), batch['t'].reshape((batch['t'].shape[0],))
    )
    pred_xu = pred_batch(
        params, batch['u0'], xu + np.zeros((1,)), batch['t'].reshape((batch['t'].shape[0],))
    )
    return np.mean((pred_xl - pred_xu) ** 2)
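bc_loss_periodic is the boundary-condition loss meant to be passed as config.loss_bc: it enforces u(xl, t) = u(xu, t) by evaluating pred_batch at the two domain endpoints for every sampled time in batch['t']. Other boundary conditions can be plugged in with the same signature; a hedged sketch of a homogeneous Dirichlet variant (the name and the zero boundary values are assumptions for this example, not part of jinnax):

def bc_loss_dirichlet(pred_batch, params, batch, xl, xu):
    # Penalise nonzero predicted values of the solution on both boundaries.
    t = batch['t'].reshape((batch['t'].shape[0],))
    pred_xl = pred_batch(params, batch['u0'], xl + np.zeros((1,)), t)
    pred_xu = pred_batch(params, batch['u0'], xu + np.zeros((1,)), t)
    return np.mean(pred_xl ** 2) + np.mean(pred_xu ** 2)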