jinnax.deeponet

#Adapted from https://github.com/PredictiveIntelligenceLab/Physics-informed-DeepONets
import jax, os
import jax.numpy as np
from jax import random, grad, vmap, jit, hessian
from jax.example_libraries import optimizers
from jax.experimental.ode import odeint
from jax.nn import relu, elu
from jax.tree_util import tree_map
from flax.training import checkpoints
from jax import lax, pmap
from jax.flatten_util import ravel_pytree
import ml_collections
import itertools
from functools import partial
from torch.utils import data
from tqdm import trange, tqdm
from jaxpi.samplers import BaseSampler, UniformSampler
from jaxpi.logging import Logger
import time
import wandb

def save_checkpoint(state, workdir, step, keep=5, name=None):
    # Create the workdir if it doesn't exist.
    if not os.path.isdir(workdir):
        os.makedirs(workdir)

    # Save the checkpoint on host 0 only.
    if jax.process_index() == 0:
        # Get the first replica's state and save it.
        state = jax.device_get(tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(workdir, state, step=step, keep=keep)

def restore_checkpoint(state, workdir, step=None):
    # Check whether the passed state is sharded across devices (pmap);
    # if so, reduce it to a single-device sharding.
    if isinstance(
        tree_map(lambda x: np.array(x).sharding, jax.tree.leaves(state.params))[0],
        jax.sharding.PmapSharding,
    ):
        state = tree_map(lambda x: x[0], state)

    # Ensure that we are in a single-device setting.
    assert isinstance(
        tree_map(lambda x: np.array(x).sharding, jax.tree.leaves(state.params))[0],
        jax.sharding.SingleDeviceSharding,
    )

    state = checkpoints.restore_checkpoint(workdir, state, step=step)
    return state

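A minimal usage sketch, assuming a single host. Note that save_checkpoint keeps only replica 0 of each array leaf, so the state passed in is expected to carry a leading device (pmap) axis, and restore_checkpoint inspects state.params for its sharding check, so it expects a state object exposing .params (e.g. a flax TrainState). The directory and the state objects below are hypothetical:

# Hypothetical sketch: checkpoint a pmap-replicated train state.
ckpt_dir = os.path.join(os.getcwd(), "ckpt", "demo")   # hypothetical path
save_checkpoint(replicated_state, ckpt_dir, step=1000, keep=5)      # keeps replica 0
restored = restore_checkpoint(state_template, ckpt_dir)             # latest step
restored_at = restore_checkpoint(state_template, ckpt_dir, step=1000)
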
def get_base_config():
    """
    Base hyperparameter configuration for training a physics-informed
    DeepONet with jaxpi utilities.

    Returns
    -------
    ml_collections.ConfigDict with the default settings
    """
    config = ml_collections.ConfigDict()
    # Weights & Biases
    config.wandb = wandb = ml_collections.ConfigDict()
    wandb.tag = None
    # Architecture
    config.arch = arch = ml_collections.ConfigDict()
    arch.branch_layers = [1024] + 4 * [256]
    arch.trunk_layers = [2] + 4 * [256]
    # Optimizer
    config.optim = optim = ml_collections.ConfigDict()
    optim.beta1 = 0.9
    optim.beta2 = 0.999
    optim.eps = 1e-8
    optim.learning_rate = 1e-3
    optim.decay_rate = 0.9
    optim.decay_steps = 2000
    # Training
    config.training = training = ml_collections.ConfigDict()
    training.batch_size_per_device = 4096
    training.batch_size_train_data = 128
    # Loss weighting (boundary, residual, data, initial condition)
    config.weights = {'b': np.array([100.0]), 'res': np.array([1.0]), 'data': np.array([1.0]), 'ic': np.array([100.0])}
    config.sa = True  # self-adaptive loss weights
    # Logging
    config.logging = logging = ml_collections.ConfigDict()
    logging.log_every_steps = 1000
    logging.log_errors = True
    logging.log_losses = True
    logging.log_weights = True
    logging.log_preds = False
    logging.log_grads = False
    logging.log_ntk = False
    # Saving
    config.saving = saving = ml_collections.ConfigDict()
    saving.save_every_steps = 1000
    saving.num_keep_ckpts = 1000000
    # Seed
    config.seed = 10
    return config

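Downstream code treats this as a starting point and overrides fields before constructing the model; problem-specific entries (domain limits, meshes, losses) are added on top, as in the end-to-end sketch at the bottom of this module. A small example of overriding the defaults:

# Sketch: start from the defaults and override a few hyperparameters.
config = get_base_config()
config.optim.learning_rate = 5e-4
config.arch.branch_layers = [64] + 4 * [128]   # branch input = number of sensors
config.arch.trunk_layers = [2] + 4 * [128]     # trunk input = (x, t)
config.seed = 0
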
#Periodic kernel
@jax.jit
def kernel_periodic(x1, x2, ls=1, p=1):
    # Exp-sine-squared kernel with length-scale ls and period p
    return np.exp(-(np.sin(np.pi * np.abs(x1 - x2) / p) ** 2) / (2 * (ls ** 2)))

#Generate initial data
def generate_initial_data(N0, size, kernel=kernel_periodic, xl=0, xu=1, key=0):
    # Draw `size` samples of a zero-mean Gaussian process on a uniform grid
    # of N0 points, with covariance matrix K given by `kernel`.
    x = np.linspace(xl, xu, N0)
    K = np.array([[kernel(x1, x2)] for x1 in x for x2 in x]).reshape((N0, N0))
    u = jax.random.multivariate_normal(key=jax.random.PRNGKey(key), mean=np.zeros((K.shape[0],)), cov=K, shape=(size,), method='svd')
    return u

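For example, drawing a batch of periodic Gaussian-process initial conditions (grid size and sample count are arbitrary here):

# Sketch: 16 GP samples with the periodic kernel on a 64-point grid over [0, 1].
u0 = generate_initial_data(N0=64, size=16, kernel=kernel_periodic, xl=0, xu=1, key=0)
print(u0.shape)   # (16, 64): one row per sampled initial condition
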
class InitialDataSampler(BaseSampler):
    def __init__(self, data, batch_size, rng_key=jax.random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.data = data
        self.dim = data.shape[0]

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        "Generates data containing batch_size samples"
        idx = jax.random.choice(key, self.dim, (self.batch_size,), replace=False)
        return self.data[idx, :]

class DataSampler(BaseSampler):
    def __init__(self, data, batch_size, rng_key=jax.random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.data = data
        self.dim = data.shape[0]

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        "Generates data containing batch_size samples"
        idx = jax.random.choice(key, self.dim, (self.batch_size,), replace=False)
        return self.data[idx, :, :]

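Both samplers follow the jaxpi BaseSampler protocol: data_generation is pmap-ed, so iterating yields batches with a leading device axis, which the training loop below strips with [0, ...]. A single-device sketch:

# Sketch: draw mini-batches of initial conditions (single-device setting).
u0_all = generate_initial_data(N0=64, size=1000)        # (1000, 64)
sampler = iter(InitialDataSampler(u0_all, batch_size=32))
batch = next(sampler)       # (num_devices, 32, 64) because of pmap
u0_batch = batch[0, :, :]   # drop the device axis, as PI_DeepONet.train does
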
# Define MLP
def MLP(layers, activation=relu):
    """Vanilla MLP"""
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
            W = glorot_stddev * random.normal(k1, (d_in, d_out))
            b = np.zeros(d_out)
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return params

    def apply(params, inputs):
        for W, b in params[:-1]:
            outputs = np.dot(inputs, W) + b
            inputs = activation(outputs)
        W, b = params[-1]
        outputs = np.dot(inputs, W) + b
        return outputs

    return init, apply

# Define modified MLP: two input encoders U and V are mixed into every
# hidden layer through a per-unit convex combination.
def modified_MLP(layers, activation=relu):
    def xavier_init(key, d_in, d_out):
        glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
        W = glorot_stddev * random.normal(key, (d_in, d_out))
        b = np.zeros(d_out)
        return W, b

    def init(rng_key):
        # The two encoders use fixed keys, independent of the per-layer
        # weights drawn from rng_key.
        U1, b1 = xavier_init(random.PRNGKey(12345), layers[0], layers[1])
        U2, b2 = xavier_init(random.PRNGKey(54321), layers[0], layers[1])
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            W, b = xavier_init(k1, d_in, d_out)
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return (params, U1, b1, U2, b2)

    def apply(params, inputs):
        params, U1, b1, U2, b2 = params
        U = activation(np.dot(inputs, U1) + b1)
        V = activation(np.dot(inputs, U2) + b2)
        for W, b in params[:-1]:
            outputs = activation(np.dot(inputs, W) + b)
            inputs = np.multiply(outputs, U) + np.multiply(1 - outputs, V)
        W, b = params[-1]
        outputs = np.dot(inputs, W) + b
        return outputs

    return init, apply

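Both factories return an (init, apply) pair. A standalone sketch with the trunk shape used by the DeepONet below, [2] + 4*[256], evaluated at a single (x, t) query:

# Sketch: initialize and evaluate a modified MLP at one (x, t) point.
init_fn, apply_fn = modified_MLP([2] + 4 * [256], activation=np.tanh)
mlp_params = init_fn(rng_key=random.PRNGKey(0))
y = np.stack([0.5, 0.1])             # one (x, t) query point
features = apply_fn(mlp_params, y)   # shape (256,), dotted with the branch output
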
# Define the model
class PI_DeepONet:
    def __init__(self, config):
        # Network initialization and evaluation functions
        self.branch_init, self.branch_apply = modified_MLP(config.arch.branch_layers, activation=np.tanh)
        self.trunk_init, self.trunk_apply = modified_MLP(config.arch.trunk_layers, activation=np.tanh)

        # Initialize parameters
        self.branch_params = self.branch_init(rng_key=random.PRNGKey(config.seed))
        self.trunk_params = self.trunk_init(rng_key=random.PRNGKey(config.seed + 1))
        self.params = (self.branch_params, self.trunk_params, {'b': config.weights['b'], 'res': config.weights['res'], 'data': config.weights['data'], 'ic': config.weights['ic']})

        # Optimizer initialization and update functions
        self.opt_init, \
        self.opt_update, \
        self.get_params = optimizers.adam(optimizers.exponential_decay(config.optim.learning_rate,
                                                                       decay_steps=config.optim.decay_steps,
                                                                       decay_rate=config.optim.decay_rate))
        self.opt_state = self.opt_init(self.params)

        # Used to restore the trained model parameters
        _, self.unravel_params = ravel_pytree(self.params)

        self.itercount = itertools.count()

        # Residual net and boundary condition loss
        self.loss_bc = config.loss_bc
        self.residual_net = config.residual_net
        self.config = config

        # Domain limits
        self.xl = config.xl
        self.xu = config.xu

        # Vmap the neural net over (u, x, t) triples
        self.pred_fn = vmap(self.operator_net, (None, 0, 0, 0))

        # Vmap the residual operator
        if self.residual_net is not None:
            self.r_pred_fn = vmap(self.residual_net, (None, 0, 0, 0))

        # Vmap over train/test data: batch of u0, all x, all t
        self.pred_batch = vmap(
            vmap(
                vmap(self.operator_net, (None, None, 0, None)), (None, None, None, 0)
            ), (None, 0, None, None)
        )

        self.pred_batch_xt = vmap(
            vmap(self.operator_net, (None, 0, None, None)), (None, None, 0, 0))
        if self.residual_net is not None:
            self.r_pred_batch = vmap(
                vmap(
                    vmap(self.residual_net, (None, None, 0, None)), (None, None, None, 0)
                ), (None, 0, None, None)
            )

        # Data
        self.u_test = config.u_test
        self.u_train = config.u_train
        self.x_mesh = config.x_mesh
        self.t_mesh = config.t_mesh

        # Loss weights
        self.w = config.weights

    # Define DeepONet architecture
    @partial(jit, static_argnums=(0,))
    def operator_net(self, params, u, x, t):
        # DeepONet forward pass: the branch encodes the input function u,
        # the trunk encodes the query point (x, t); the output is their
        # dot product.
        branch_params, trunk_params, _ = params
        y = np.stack([x, t])
        B = self.branch_apply(branch_params, u)
        T = self.trunk_apply(trunk_params, y)
        outputs = np.sum(B * T)
        return outputs

    # Define residual loss
    @partial(jit, static_argnums=(0,))
    def loss_res(self, params, batch):
        # Compute forward pass of the PDE residual
        pred = self.residual_net(self.operator_net, params, batch)
        # Compute loss
        loss = np.mean(pred ** 2)
        return loss

    # Data loss
    @partial(jit, static_argnums=(0,))
    def loss_data(self, params, batch_train):
        pred = self.pred_batch(params, batch_train['u0'], batch_train['x'], batch_train['t'])
        return np.mean((pred - batch_train['u']) ** 2)

    # Define initial condition loss
    @partial(jit, static_argnums=(0,))
    def loss_ic(self, params, batch):
        # Predict at t = 0 on the spatial mesh
        pred = self.pred_batch_xt(params, batch['u0'], self.x_mesh, np.zeros(self.x_mesh.shape[0]))
        # Compare with the sampled initial conditions
        loss = np.mean((pred - batch['u0'].transpose()) ** 2)
        return loss

    # Define total loss: each term is scaled by the square of its
    # (possibly self-adaptive) weight.
    @partial(jit, static_argnums=(0,))
    def loss(self, params, batch, batch_train):
        loss_bc = 0.0
        loss_data = 0.0
        loss_res = 0.0
        if self.loss_bc is not None:
            loss_bc = self.loss_bc(self.pred_batch, params, {'u0': batch['u0'], 't': batch['t_bc']}, self.xl, self.xu)
        if self.residual_net is not None:
            loss_res = self.loss_res(params, batch)
        if batch_train is not None:
            loss_data = self.loss_data(params, batch_train)
        loss = (params[-1]['b'] ** 2) * loss_bc + (params[-1]['res'] ** 2) * loss_res + (params[-1]['data'] ** 2) * loss_data + (params[-1]['ic'] ** 2) * self.loss_ic(params, batch)
        return loss

    # Define a compiled update step
    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch, batch_train):
        params = self.get_params(opt_state)
        g = grad(self.loss)(params, batch, batch_train)
        if self.config.sa:
            # Self-adaptive weights: flip the sign of their gradients so the
            # optimizer ascends in the weights while descending in the
            # network parameters.
            g[-1]['b'] = -g[-1]['b']
            g[-1]['res'] = -g[-1]['res']
            g[-1]['data'] = -g[-1]['data']
            g[-1]['ic'] = -g[-1]['ic']
        return self.opt_update(i, g, opt_state)

    def evaluator(self, batch, batch_train):
        log_dict = {}
        params = self.get_params(self.opt_state)
        # Test error: mean relative L2 error over the test functions
        if self.u_test is not None:
            pred = self.pred_batch(params, self.u_test[:, 0, :], self.x_mesh, self.t_mesh)
            log_dict['test_L2'] = np.mean(np.sqrt(np.mean((pred - self.u_test) ** 2, (1, 2)) / np.mean(self.u_test ** 2, (1, 2))))

        # Training losses
        if self.loss_bc is not None:
            log_dict['bc_loss'] = self.loss_bc(self.pred_batch, params, {'u0': batch['u0'], 't': batch['t_bc']}, self.xl, self.xu)
        if self.residual_net is not None:
            log_dict['res_loss'] = self.loss_res(params, batch)
        log_dict['ic_loss'] = self.loss_ic(params, batch)
        if batch_train is not None:
            log_dict['data_loss'] = self.loss_data(params, batch_train)
        if self.config.sa:
            log_dict['Wb'] = params[-1]['b']
            log_dict['Wres'] = params[-1]['res']
            log_dict['Wdata'] = params[-1]['data']
            log_dict['Wic'] = params[-1]['ic']

        return log_dict

    # Optimize parameters in a loop
    def train(self):
        config = self.config
        if config.save_wandb:
            wandb_config = config.wandb
            wandb.init(project=wandb_config.project, name=wandb_config.name)

        # Initialize the initial data sampler
        if config.initial_data is None:
            initial_data = generate_initial_data(config.N0, int(config.size), kernel=config.kernel, xl=config.xl, xu=config.xu, key=0)
        else:
            initial_data = config.initial_data
        initial_sampler = iter(InitialDataSampler(initial_data, config.N))

        # Initialize the residual sampler
        dom = np.array([[config.xl, config.xu], [config.tl, config.tu]])
        res_sampler = iter(UniformSampler(dom, config.Q))

        # Initialize the boundary condition sampler
        bc_sampler = iter(UniformSampler(np.array([[config.tl, config.tu]]), config.N0))

        # Initialize the training data sampler
        batch_train = None
        if config.u_train is not None:
            data_sampler = iter(DataSampler(config.u_train, config.training.batch_size_train_data))

        # Logger
        logger = Logger()
        batch = {'u0': None, 'x': None, 't': None, 't_bc': None}

        # Train
        w = None
        print("Training DeepONet...")
        start_time = time.time()
        t0 = start_time
        for step in trange(config.training.max_steps):
            batch['u0'] = next(initial_sampler)[0, :, :]
            res_data_tmp = next(res_sampler)[0, :, :]
            batch['x'] = res_data_tmp[:, 0]
            batch['t'] = res_data_tmp[:, 1]
            batch['t_bc'] = next(bc_sampler)[0, :]
            if config.u_train is not None:
                u = next(data_sampler)
                u = u.reshape((u.shape[1], u.shape[2], u.shape[3]))
                batch_train = {'u0': u[:, 0, :], 'u': u, 't': self.t_mesh, 'x': self.x_mesh}

            # Initialize the self-adaptive weights so that, at step 0, every
            # active loss term contributes comparably to the total loss.
            if config.sa and step == 0:
                params = self.get_params(self.opt_state)
                if self.loss_bc is not None:
                    lb = self.loss_bc(self.pred_batch, params, {'u0': batch['u0'], 't': batch['t_bc']}, self.xl, self.xu)
                else:
                    lb = np.array(0.0)
                if self.residual_net is not None:
                    lr = self.loss_res(params, batch)
                else:
                    lr = np.array(0.0)
                if batch_train is not None:
                    ld = self.loss_data(params, batch_train)
                else:
                    ld = np.array(0.0)
                li = self.loss_ic(params, batch)
                total = lb + lr + ld + li
                # Guard against division by zero for inactive loss terms.
                w = {'b': np.sqrt(total / lb) if self.loss_bc is not None else np.array(1.0),
                     'res': np.sqrt(total / lr) if self.residual_net is not None else np.array(1.0),
                     'data': np.sqrt(total / ld) if batch_train is not None else np.array(1.0),
                     'ic': np.sqrt(total / li)}
                self.params = (self.branch_params, self.trunk_params, w)
                self.opt_state = self.opt_init(self.params)
                # Used to restore the trained model parameters
                _, self.unravel_params = ravel_pytree(self.params)

            # Step
            self.opt_state = self.step(next(self.itercount), self.opt_state, batch, batch_train)

            # Log training metrics; only host 0 records results
            if step % config.logging.log_every_steps == 0:
                log_dict = self.evaluator(batch, batch_train)
                if config.save_wandb:
                    wandb.log(log_dict, step=step)
                end_time = time.time()
                logger.log_iter(step, start_time, end_time, log_dict)
                start_time = end_time

            # Saving
            if config.saving.save_every_steps is not None:
                if (step + 1) % config.saving.save_every_steps == 0 or (
                    step + 1
                ) == config.training.max_steps:
                    ckpt_path = os.path.join(os.getcwd(), "ckpt", config.wandb.name)
                    save_checkpoint(self.opt_state, ckpt_path, step, keep=config.saving.num_keep_ckpts)

        # Run summary
        log_dict['total_time'] = time.time() - t0
        log_dict['epochs'] = step + 1

        return log_dict

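The PDE residual is supplied by the user: loss_res calls config.residual_net as residual_net(operator_net, params, batch) and minimizes the mean square of whatever it returns. A hedged sketch for a linear advection equation u_t + c*u_x = 0; the equation and the constant c are illustrative, not part of this module:

# Sketch: pointwise residual of u_t + c*u_x = 0, matching the loss_res call
# signature residual_net(operator_net, params, batch).
def advection_residual(operator_net, params, batch, c=1.0):
    def residual_point(u0, x, t):
        # operator_net(params, u0, x, t) is scalar in (x, t), so grad applies.
        u_t = grad(operator_net, argnums=3)(params, u0, x, t)
        u_x = grad(operator_net, argnums=2)(params, u0, x, t)
        return u_t + c * u_x
    # Evaluate at every collocation point for every sampled initial condition.
    return vmap(vmap(residual_point, (None, 0, 0)), (0, None, None))(
        batch['u0'], batch['x'], batch['t'])
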
# Periodic boundary condition loss: penalize the mismatch between
# predictions at the two spatial boundaries xl and xu.
def bc_loss_periodic(pred_batch, params, batch, xl, xu):
    pred_xl = pred_batch(
        params, batch['u0'], xl + np.zeros((1,)), batch['t'].reshape((batch['t'].shape[0],))
    )
    pred_xu = pred_batch(
        params, batch['u0'], xu + np.zeros((1,)), batch['t'].reshape((batch['t'].shape[0],))
    )
    return np.mean((pred_xl - pred_xu) ** 2)
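
An end-to-end sketch tying the pieces together, using placeholder sizes, the hypothetical advection_residual above, and no supervised data; every field set here is one that PI_DeepONet.__init__ or train actually reads, but the values are illustrative, not recommended settings:

# Sketch: configure and train a physics-informed DeepONet (illustrative values).
config = get_base_config()
config.arch.branch_layers = [64] + 4 * [256]  # branch input = N0 sensor values
config.xl, config.xu = 0.0, 1.0               # spatial domain
config.tl, config.tu = 0.0, 1.0               # temporal domain
config.N0, config.N, config.Q = 64, 32, 1024  # sensors, functions/batch, collocation points
config.size = 1000                            # number of sampled initial conditions
config.kernel = kernel_periodic
config.initial_data = None                    # sample initial data from the GP prior
config.u_train = config.u_test = None         # purely physics-informed run
config.x_mesh = np.linspace(config.xl, config.xu, config.N0)
config.t_mesh = np.linspace(config.tl, config.tu, 100)
config.loss_bc = bc_loss_periodic
config.residual_net = advection_residual      # hypothetical residual from above
config.training.max_steps = 1000
config.save_wandb = False
config.saving.save_every_steps = None         # skip checkpointing in this sketch

model = PI_DeepONet(config)
summary = model.train()
print(summary['total_time'], summary['epochs'])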