jinnax.csf

  1#Funtions to train PINNs for csf
  2from absl import logging
  3from alive_progress import alive_bar
  4from functools import partial
  5import jax
  6from jax import grad, hessian, jit, lax, vmap, jacrev
  7import jax.numpy as jnp
  8from jax.tree_util import tree_map
  9from jaxpi import archs
 10from jaxpi.evaluator import BaseEvaluator
 11from jaxpi.logging import Logger
 12from jaxpi.models import ForwardIVP
 13from jaxpi.samplers import UniformSampler
 14from jaxpi.utils import ntk_fn, restore_checkpoint, save_checkpoint
 15from jinnax import class_csf
 16import matplotlib.pyplot as plt
 17import ml_collections
 18import numpy as np
 19import optax
 20import os
 21import pandas as pd
 22import random
 23import scipy.io
 24import sys
 25import time
 26import wandb
 27
 28def get_base_config():
 29    """
 30    Base config file for training PINN for CSF in jaxpi
 31
 32    Returns
 33    -------
 34    ml_collections config dictionary
 35    """
 36    #Get the default hyperparameter configuration.
 37    config = ml_collections.ConfigDict()
 38
 39    # Weights & Biases
 40    config.wandb = wandb = ml_collections.ConfigDict()
 41    wandb.tag = None
 42
 43    # Arch
 44    config.arch = arch = ml_collections.ConfigDict()
 45    arch.arch_name = "ModifiedMlp"
 46    arch.num_layers = 4
 47    arch.hidden_dim = 256
 48    arch.out_dim = 2
 49    arch.activation = "tanh"
 50    arch.reparam = ml_collections.ConfigDict(
 51        {"type": "weight_fact", "mean": 1.0, "stddev": 0.1}
 52    )
 53
 54    # Optim
 55    config.optim = optim = ml_collections.ConfigDict()
 56    optim.optimizer = "Adam"
 57    optim.beta1 = 0.9
 58    optim.beta2 = 0.999
 59    optim.eps = 1e-8
 60    optim.learning_rate = 1e-3
 61    optim.decay_rate = 0.9
 62    optim.decay_steps = 2000
 63    optim.grad_accum_steps = 0
 64    optim.warmup_steps = 0
 65    optim.grad_accum_steps = 0
 66
 67    # Training
 68    config.training = training = ml_collections.ConfigDict()
 69    training.batch_size_per_device = 4096
 70
 71    # Weighting
 72    config.weighting = weighting = ml_collections.ConfigDict()
 73    weighting.scheme = "grad_norm"
 74    weighting.momentum = 0.9
 75    weighting.update_every_steps = 1000
 76
 77    weighting.use_causal = True
 78    weighting.causal_tol = 1.0
 79    weighting.num_chunks = 16
 80    optim.staircase = False
 81
 82    # Logging
 83    config.logging = logging = ml_collections.ConfigDict()
 84    logging.log_every_steps = 1000
 85    logging.log_errors = True
 86    logging.log_losses = True
 87    logging.log_weights = True
 88    logging.log_preds = False
 89    logging.log_grads = False
 90    logging.log_ntk = False
 91
 92    # Saving
 93    config.saving = saving = ml_collections.ConfigDict()
 94    saving.save_every_steps = 1000
 95    saving.num_keep_ckpts = 1000000
 96
 97    # # Input shape for initializing Flax models
 98    config.input_dim = 2
 99
100    return config
101
102
# Demo in time for 2D PINN
def demo_time_CSF(data,type = 'DN',radius = None,file_name_save = 'result_pinn_CSF_demo',title = '',framerate = 10,ffmpeg = 'ffmpeg'):
    """
    Demo video with the time evolution of a CSF

    One PNG frame per distinct time value in ``data`` is written to the folder
    ``file_name_save`` (frames 1.png, 2.png, ...) and the frames are joined
    into ``file_name_save + '_time_demo.mp4'`` with ffmpeg. For types 'DN'
    and 'NN' the boundary circle of radius ``radius`` is drawn in every frame.
    ----------
    Parameters
    ----------
    data : jax.array

        Predicted values with columns (t, x, u1, u2), as returned by ``evaluate``

    type : str

        Type of problem ('DN', 'NN', 'DD' or 'closed')

    radius : float

        Radius of circle for types 'DN' and 'NN'

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_CSF_demo'

    title : str

        Title for plots (the current time is appended)

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to the ffmpeg executable

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If type is 'DN' or 'NN' and radius is None.
    """
    # Imported here so this function stays self-contained.
    import subprocess

    draw_circle = type in ('DN', 'NN')
    if draw_circle and radius is None:
        raise ValueError("radius must be provided for types 'DN' and 'NN'")

    # Create folder to save plots (no shell involved; an existing folder is reused)
    os.makedirs(file_name_save, exist_ok = True)

    # Plot parameters: y-range of the curve at the first time, padded by 10%;
    # only used to fix the y-limits when no boundary circle is drawn
    times = data[:,0]
    tdom = jnp.unique(times)
    uy_first = data[times == jnp.min(tdom),3]
    ylo = jnp.min(uy_first)
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(uy_first)
    yup = yup + 0.1*jnp.abs(yup)

    # Circle data (vectorised instead of a Python loop over 1000 angles)
    if draw_circle:
        theta = jnp.linspace(0,2*jnp.pi,1000)
        circle = jnp.stack([radius*jnp.sin(theta),radius*jnp.cos(theta)],axis = 1)

    # Create images, numbered from 1 to match the ffmpeg input pattern below
    with alive_bar(len(tdom)) as bar:
        for k, t in enumerate(tdom, start = 1):
            mask = times == t
            ux_step = data[mask,2]
            uy_step = data[mask,3]
            fig, ax = plt.subplots(1,1,figsize = (10,10))
            ax.plot(ux_step,uy_step,'b-',linewidth = 2)
            if draw_circle:
                ax.plot(circle[:,0],circle[:,1],'r-',linewidth = 2)
            ax.set_xlabel(' ')
            if not draw_circle:
                ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
            fig.suptitle(title + 't = ' + str(round(float(t),4)))
            fig.tight_layout()
            fig.savefig(os.path.join(file_name_save,str(k) + '.png'))
            plt.close(fig)
            bar()

    # Create demo video; arguments are passed as a list so file names are
    # never interpreted by a shell
    command = [ffmpeg,'-framerate',str(framerate),'-i',os.path.join(file_name_save,'%d.png'),
               '-c:v','libx264','-profile:v','high','-crf','20','-pix_fmt','yuv420p',
               file_name_save + '_time_demo.mp4']
    try:
        # check=False: as with the former os.system call, a failing ffmpeg does not raise
        subprocess.run(command, check = False)
    except FileNotFoundError:
        print('ffmpeg not found at ' + str(ffmpeg) + '; frames were kept in ' + file_name_save)
186
187
# Boundary-condition test errors compared against ``config.dn_tol`` for early
# stopping, per CSF type ('closed' only checks the residuals and initial condition)
_CSF_BOUNDARY_TEST_KEYS = {
    'DN': ('ld_test', 'rd_test', 'ln_test'),
    'NN': ('ld_test', 'rd_test', 'ln_test', 'rn_test'),
    'DD': ('ld_test', 'rd_test'),
    'closed': (),
}


def _csf_converged(config, log_dict):
    """Return True when every test error monitored for ``config.type_csf`` is below its tolerance."""
    criteria = [('res1_test', config.res_tol), ('res2_test', config.res_tol), ('ic_rel_test', config.ic_tol)]
    criteria += [(key, config.dn_tol) for key in _CSF_BOUNDARY_TEST_KEYS[config.type_csf]]
    return all(log_dict[key] < tol for key, tol in criteria)


def train_csf(config: ml_collections.ConfigDict):
    """
    Train PINN for CSF in jaxpi

    Checkpoints are written to ./ckpt/<config.wandb.name>; training stops early
    once all monitored test errors are below config.res_tol / ic_tol / dn_tol.
    ----------
    Parameters
    ----------
    config : ml_collections.ConfigDict

        Dictionary for training PINN in jaxpi

    Returns
    -------
    model, log_dict

        The trained model and the last logged metrics, plus 'total_time'
        (seconds) and 'epochs' (number of training steps run)

    Raises
    ------
    ValueError
        If config.type_csf is not one of 'DN', 'NN', 'DD' or 'closed'.
    """
    # (model class, evaluator class) for each supported CSF type
    csf_classes = {
        'DN': (class_csf.DN_csf, class_csf.DN_csf_Evaluator),
        'NN': (class_csf.NN_csf, class_csf.NN_csf_Evaluator),
        'DD': (class_csf.DD_csf, class_csf.DD_csf_Evaluator),
        'closed': (class_csf.closed_csf, class_csf.closed_csf_Evaluator),
    }
    if config.type_csf not in csf_classes:
        raise ValueError("config.type_csf must be one of 'DN', 'NN', 'DD' or 'closed', got " + repr(config.type_csf))
    model_cls, evaluator_cls = csf_classes[config.type_csf]

    if config.save_wandb:
        wandb_config = config.wandb
        wandb.init(project = wandb_config.project, name = wandb_config.name)

    # Define the time and space domain
    dom = jnp.array([[config.tl, config.tu], [config.xl, config.xu]])

    # Initialize the residual sampler
    res_sampler = iter(UniformSampler(dom, config.training.batch_size_per_device))

    # Initialize the model
    model = model_cls(config)

    # Logger
    logger = Logger()

    # Random test points for the initial condition, boundary and PDE residuals
    key = jax.random.split(jax.random.PRNGKey(config.seed),4)
    x0_test = jax.random.uniform(key = jax.random.PRNGKey(key[0,0]),minval = config.xl,maxval = config.xu,shape = (config.N0,1))
    u1_0_test,u2_0_test = config.uinitial(x0_test)
    tb_test = jax.random.uniform(key = jax.random.PRNGKey(key[1,0]),minval = config.tl,maxval = config.tu,shape = (config.Nb,1))
    xc_test = jax.random.uniform(key = jax.random.PRNGKey(key[2,0]),minval = config.xl,maxval = config.xu,shape = (config.Nc ** 2,1))
    tc_test = jax.random.uniform(key = jax.random.PRNGKey(key[3,0]),minval = config.tl,maxval = config.tu,shape = (config.Nc ** 2,1))
    evaluator = evaluator_cls(config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test)

    print("Training CSF...")
    start_time = time.time()
    t0 = start_time
    # Defaults so the summary below also works when max_steps == 0 or on hosts
    # other than process 0, which never evaluate
    log_dict = {}
    step = -1
    for step in range(config.training.max_steps):
        batch = next(res_sampler)
        model.state = model.step(model.state, batch)

        # Update weights if necessary
        if config.weighting.scheme in ["grad_norm", "ntk"]:
            if step % config.weighting.update_every_steps == 0:
                model.state = model.update_weights(model.state, batch)

        # Log training metrics, only use host 0 to record results
        if jax.process_index() == 0:
            if step % config.logging.log_every_steps == 0:
                # Get the first replica of the state and batch
                state = jax.device_get(tree_map(lambda x: x[0], model.state))
                batch = jax.device_get(tree_map(lambda x: x[0], batch))
                log_dict = evaluator(state, batch)
                if config.save_wandb:
                    wandb.log(log_dict, step)
                end_time = time.time()
                logger.log_iter(step, start_time, end_time, log_dict)
                start_time = end_time

        # Saving
        if config.saving.save_every_steps is not None:
            if (step + 1) % config.saving.save_every_steps == 0 or (step + 1) == config.training.max_steps:
                ckpt_path = os.path.join(os.getcwd(),"ckpt",config.wandb.name)
                save_checkpoint(model.state, ckpt_path, keep=config.saving.num_keep_ckpts)
                # Early stopping on the metrics of the most recent logging step.
                # NOTE(review): these can lag the checkpoint (with the default
                # config they are logged at step % 1000 == 0 but checked at
                # (step + 1) % 1000 == 0).
                if log_dict and _csf_converged(config, log_dict):
                    break

    # Run summary
    log_dict['total_time'] = time.time() - t0
    log_dict['epochs'] = step + 1

    return model, log_dict
292
def evaluate(config: ml_collections.ConfigDict):
    """
    Evaluate PINN for CSF trained in jaxpi

    Restores the checkpoint saved by ``train_csf`` under ./ckpt/<config.wandb.name>,
    predicts on a regular Nt x Nc grid of (t, x) and saves the result to
    '<config.wandb.project>_<config.wandb.name>.npy'.
    ----------
    Parameters
    ----------
    config : ml_collections.ConfigDict

        Dictionary used to train the PINN in jaxpi; type_csf, tl, tu, Nt,
        xl, xu, Nc and wandb.project / wandb.name are read here

    Returns
    -------
    pred : jax.array

        Array with columns (t, x, u1, u2), one row per grid point, with t
        varying slowest

    Raises
    ------
    ValueError
        If config.type_csf is not one of 'DN', 'NN', 'DD' or 'closed'.
    """
    # Initialize the model
    csf_models = {
        'DN': class_csf.DN_csf,
        'NN': class_csf.NN_csf,
        'DD': class_csf.DD_csf,
        'closed': class_csf.closed_csf,
    }
    if config.type_csf not in csf_models:
        raise ValueError("config.type_csf must be one of 'DN', 'NN', 'DD' or 'closed', got " + repr(config.type_csf))
    model = csf_models[config.type_csf](config)

    # Restore the checkpoint
    ckpt_path = os.path.join(os.getcwd(), "ckpt", config.wandb.name)
    model.state = restore_checkpoint(model.state, ckpt_path)
    params = model.state.params

    # Regular (t, x) grid with t varying slowest, built with meshgrid instead of
    # a Python double loop over Nt * Nc jax scalars (same order and values)
    t_grid, x_grid = jnp.meshgrid(
        jnp.linspace(config.tl, config.tu, config.Nt),
        jnp.linspace(config.xl, config.xu, config.Nc),
        indexing = 'ij',
    )
    tx = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis = 1)

    # Predict
    u1_pred = model.u1_pred_fn(params, tx[:,0], tx[:,1])
    u2_pred = model.u2_pred_fn(params, tx[:,0], tx[:,1])

    # Save as columns (t, x, u1, u2)
    pred = jnp.concatenate(
        [tx, u1_pred.reshape((u1_pred.shape[0],1)), u2_pred.reshape((u2_pred.shape[0],1))],
        axis = 1,
    )
    jnp.save(config.wandb.project + '_' + config.wandb.name + '.npy',pred)

    return pred
341
def csf(uinitial,xl,xu,tl,tu,type = 'DN',radius = None,file_name = 'test',Nt = 400,N0 = 10000,Nb = 10000,Nc = 500,config = None,save_wandb = False,wandb_project = 'CSF_project',seed = 534,demo = True,max_epochs = 150000,res_tol = 5e-5,dn_tol = 5e-4,ic_tol = 0.01,framerate = 10,ffmpeg = 'ffmpeg'):
    """
    Train PINN for CSF in jaxpi, evaluate it on a grid and optionally create a demo video
    ----------
    Parameters
    ----------
    uinitial : function

        Function that computes the initial condition; ``uinitial(x)`` returns the pair (u1, u2)

    xl, xu, tl, tu : float

        Limits of the x and t domain

    type : str

        Type of CSF problem to train: 'DN' or 'NN' (in a circle), 'DD' or 'closed'

    radius : float

        Radius of the circle for 'DN' and 'NN'

    file_name : str

        File name to save results (checkpoint folder, demo frames/video and '<file_name>_results.csv')

    Nt : int

        Number of grid points in t used for evaluation

    N0, Nb : int

        Initial and boundary condition test sample size

    Nc : int

        Number of grid points in x used for evaluation; Nc ** 2 random points are used to test the PDE residuals

    config : ml_collections.ConfigDict

        Config dictionary to train PINNs in jaxpi. If not provided, use basic configurations (``get_base_config``). It is modified in place

    save_wandb : logical

        Whether to save results in wandb

    wandb_project : str

        Name of wandb project

    seed : int

        Random seed stored in config.seed (train_csf uses it to draw the test points)

    demo : logical

        Whether to generate video with result

    max_epochs : int

        Maximum number of training steps

    res_tol, dn_tol, ic_tol : float

        Tolerance on test errors for early stop (PDE residuals, boundary conditions and initial condition)

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to ffmpeg

    Returns
    -------
    log_dict : dict

        Metrics returned by ``train_csf`` (last evaluation plus 'total_time' and 'epochs'); also saved to '<file_name>_results.csv'

    Raises
    ------
    ValueError
        If type is not one of 'DN', 'NN', 'DD' or 'closed'.
    """
    # Loss terms weighted during training for each CSF type, all starting at weight 1
    loss_terms = {
        'DN': ('ic', 'res1', 'res2', 'rd', 'ld', 'ln'),
        'NN': ('ic', 'res1', 'res2', 'rd', 'ld', 'ln', 'rn'),
        'DD': ('ic', 'res1', 'res2', 'rd', 'ld'),
        'closed': ('ic', 'res1', 'res2', 'periodic1', 'periodic2'),  # periodic condition
    }
    if type not in loss_terms:
        raise ValueError("type must be one of 'DN', 'NN', 'DD' or 'closed', got " + repr(type))

    # Set config file
    if config is None:
        config = get_base_config()

    config.wandb.project = wandb_project
    config.wandb.name = file_name
    config.seed = seed
    config.type_csf = type
    config.save_wandb = save_wandb
    config.uinitial = uinitial
    config.xl = xl
    config.xu = xu
    config.tl = tl
    config.tu = tu
    config.radius = radius
    # Initial curve at the endpoints (evaluated once each)
    u_right = uinitial(xu)
    u_left = uinitial(xl)
    config.rd = jnp.append(u_right[0],u_right[1])
    config.ld = jnp.append(u_left[0],u_left[1])
    config.Nt = Nt
    config.N0 = N0
    config.Nc = Nc
    config.Nb = Nb
    config.res_tol = res_tol
    config.dn_tol = dn_tol
    config.ic_tol = ic_tol
    config.training.max_steps = max_epochs
    config.weighting.init_weights = ml_collections.ConfigDict({name: 1.0 for name in loss_terms[type]})

    # Train model
    _, results = train_csf(config)

    # Evaluate
    pred = evaluate(config)

    # Generate demo with the requested framerate
    if demo:
        demo_time_CSF(pred,radius = radius,type = type,file_name_save = file_name,framerate = framerate,ffmpeg = ffmpeg)

    # Print results
    pd_results = pd.DataFrame(list(results.items()))
    print(pd_results)
    pd_results.to_csv(file_name + '_results.csv')

    return results
def get_base_config():
    """
    Base config file for training PINN for CSF in jaxpi

    Only the generic settings (architecture, optimizer, training, loss
    weighting, logging and saving) are set here; the problem-specific fields
    (domain limits, type_csf, seed, tolerances, ...) are filled in by ``csf``.

    Returns
    -------
    ml_collections config dictionary
    """
    # Get the default hyperparameter configuration.
    config = ml_collections.ConfigDict()

    # Weights & Biases (local name chosen so the imported ``wandb`` module is not shadowed)
    config.wandb = wandb_cfg = ml_collections.ConfigDict()
    wandb_cfg.tag = None

    # Arch
    config.arch = arch = ml_collections.ConfigDict()
    arch.arch_name = "ModifiedMlp"
    arch.num_layers = 4
    arch.hidden_dim = 256
    arch.out_dim = 2
    arch.activation = "tanh"
    arch.reparam = ml_collections.ConfigDict(
        {"type": "weight_fact", "mean": 1.0, "stddev": 0.1}
    )

    # Optim (each key assigned once; staircase grouped with the other optimizer settings)
    config.optim = optim = ml_collections.ConfigDict()
    optim.optimizer = "Adam"
    optim.beta1 = 0.9
    optim.beta2 = 0.999
    optim.eps = 1e-8
    optim.learning_rate = 1e-3
    optim.decay_rate = 0.9
    optim.decay_steps = 2000
    optim.grad_accum_steps = 0
    optim.warmup_steps = 0
    optim.staircase = False

    # Training
    config.training = training = ml_collections.ConfigDict()
    training.batch_size_per_device = 4096

    # Weighting
    config.weighting = weighting = ml_collections.ConfigDict()
    weighting.scheme = "grad_norm"
    weighting.momentum = 0.9
    weighting.update_every_steps = 1000

    weighting.use_causal = True
    weighting.causal_tol = 1.0
    weighting.num_chunks = 16

    # Logging (local name chosen so the imported absl ``logging`` is not shadowed)
    config.logging = logging_cfg = ml_collections.ConfigDict()
    logging_cfg.log_every_steps = 1000
    logging_cfg.log_errors = True
    logging_cfg.log_losses = True
    logging_cfg.log_weights = True
    logging_cfg.log_preds = False
    logging_cfg.log_grads = False
    logging_cfg.log_ntk = False

    # Saving
    config.saving = saving = ml_collections.ConfigDict()
    saving.save_every_steps = 1000
    saving.num_keep_ckpts = 1000000

    # Input shape for initializing Flax models: (t, x)
    config.input_dim = 2

    return config

Base config file for training PINN for CSF in jaxpi

Returns

ml_collections config dictionary

def demo_time_CSF(data,type = 'DN',radius = None,file_name_save = 'result_pinn_CSF_demo',title = '',framerate = 10,ffmpeg = 'ffmpeg'):
    """
    Demo video with the time evolution of a CSF

    One PNG frame per distinct time value in ``data`` is written to the folder
    ``file_name_save`` (frames 1.png, 2.png, ...) and the frames are joined
    into ``file_name_save + '_time_demo.mp4'`` with ffmpeg. For types 'DN'
    and 'NN' the boundary circle of radius ``radius`` is drawn in every frame.
    ----------
    Parameters
    ----------
    data : jax.array

        Predicted values with columns (t, x, u1, u2), as returned by ``evaluate``

    type : str

        Type of problem ('DN', 'NN', 'DD' or 'closed')

    radius : float

        Radius of circle for types 'DN' and 'NN'

    file_name_save : str

        File prefix to save the plots and video. Default 'result_pinn_CSF_demo'

    title : str

        Title for plots (the current time is appended)

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to the ffmpeg executable

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If type is 'DN' or 'NN' and radius is None.
    """
    # Imported here so this function stays self-contained.
    import subprocess

    draw_circle = type in ('DN', 'NN')
    if draw_circle and radius is None:
        raise ValueError("radius must be provided for types 'DN' and 'NN'")

    # Create folder to save plots (no shell involved; an existing folder is reused)
    os.makedirs(file_name_save, exist_ok = True)

    # Plot parameters: y-range of the curve at the first time, padded by 10%;
    # only used to fix the y-limits when no boundary circle is drawn
    times = data[:,0]
    tdom = jnp.unique(times)
    uy_first = data[times == jnp.min(tdom),3]
    ylo = jnp.min(uy_first)
    ylo = ylo - 0.1*jnp.abs(ylo)
    yup = jnp.max(uy_first)
    yup = yup + 0.1*jnp.abs(yup)

    # Circle data (vectorised instead of a Python loop over 1000 angles)
    if draw_circle:
        theta = jnp.linspace(0,2*jnp.pi,1000)
        circle = jnp.stack([radius*jnp.sin(theta),radius*jnp.cos(theta)],axis = 1)

    # Create images, numbered from 1 to match the ffmpeg input pattern below
    with alive_bar(len(tdom)) as bar:
        for k, t in enumerate(tdom, start = 1):
            mask = times == t
            ux_step = data[mask,2]
            uy_step = data[mask,3]
            fig, ax = plt.subplots(1,1,figsize = (10,10))
            ax.plot(ux_step,uy_step,'b-',linewidth = 2)
            if draw_circle:
                ax.plot(circle[:,0],circle[:,1],'r-',linewidth = 2)
            ax.set_xlabel(' ')
            if not draw_circle:
                ax.set_ylim([1.3 * ylo.tolist(),1.3 * yup.tolist()])
            fig.suptitle(title + 't = ' + str(round(float(t),4)))
            fig.tight_layout()
            fig.savefig(os.path.join(file_name_save,str(k) + '.png'))
            plt.close(fig)
            bar()

    # Create demo video; arguments are passed as a list so file names are
    # never interpreted by a shell
    command = [ffmpeg,'-framerate',str(framerate),'-i',os.path.join(file_name_save,'%d.png'),
               '-c:v','libx264','-profile:v','high','-crf','20','-pix_fmt','yuv420p',
               file_name_save + '_time_demo.mp4']
    try:
        # check=False: as with the former os.system call, a failing ffmpeg does not raise
        subprocess.run(command, check = False)
    except FileNotFoundError:
        print('ffmpeg not found at ' + str(ffmpeg) + '; frames were kept in ' + file_name_save)

Demo video with the time evolution of a CSF (the boundary circle is drawn for types 'DN' and 'NN')

Parameters

data : jax.array

Data with the predicted values

type : str

Type of problem

radius : float

Radius of circle for types 'DN' and 'NN'

file_name_save : str

File prefix to save the plots and video. Default 'result_pinn_CSF_demo'

title : str

Title for plots

framerate : int

Framerate for video. Default 10

ffmpeg : str

Path to ffmpeg

Returns

None

# Boundary-condition test errors compared against ``config.dn_tol`` for early
# stopping, per CSF type ('closed' only checks the residuals and initial condition)
_CSF_BOUNDARY_TEST_KEYS = {
    'DN': ('ld_test', 'rd_test', 'ln_test'),
    'NN': ('ld_test', 'rd_test', 'ln_test', 'rn_test'),
    'DD': ('ld_test', 'rd_test'),
    'closed': (),
}


def _csf_converged(config, log_dict):
    """Return True when every test error monitored for ``config.type_csf`` is below its tolerance."""
    criteria = [('res1_test', config.res_tol), ('res2_test', config.res_tol), ('ic_rel_test', config.ic_tol)]
    criteria += [(key, config.dn_tol) for key in _CSF_BOUNDARY_TEST_KEYS[config.type_csf]]
    return all(log_dict[key] < tol for key, tol in criteria)


def train_csf(config: ml_collections.ConfigDict):
    """
    Train PINN for CSF in jaxpi

    Checkpoints are written to ./ckpt/<config.wandb.name>; training stops early
    once all monitored test errors are below config.res_tol / ic_tol / dn_tol.
    ----------
    Parameters
    ----------
    config : ml_collections.ConfigDict

        Dictionary for training PINN in jaxpi

    Returns
    -------
    model, log_dict

        The trained model and the last logged metrics, plus 'total_time'
        (seconds) and 'epochs' (number of training steps run)

    Raises
    ------
    ValueError
        If config.type_csf is not one of 'DN', 'NN', 'DD' or 'closed'.
    """
    # (model class, evaluator class) for each supported CSF type
    csf_classes = {
        'DN': (class_csf.DN_csf, class_csf.DN_csf_Evaluator),
        'NN': (class_csf.NN_csf, class_csf.NN_csf_Evaluator),
        'DD': (class_csf.DD_csf, class_csf.DD_csf_Evaluator),
        'closed': (class_csf.closed_csf, class_csf.closed_csf_Evaluator),
    }
    if config.type_csf not in csf_classes:
        raise ValueError("config.type_csf must be one of 'DN', 'NN', 'DD' or 'closed', got " + repr(config.type_csf))
    model_cls, evaluator_cls = csf_classes[config.type_csf]

    if config.save_wandb:
        wandb_config = config.wandb
        wandb.init(project = wandb_config.project, name = wandb_config.name)

    # Define the time and space domain
    dom = jnp.array([[config.tl, config.tu], [config.xl, config.xu]])

    # Initialize the residual sampler
    res_sampler = iter(UniformSampler(dom, config.training.batch_size_per_device))

    # Initialize the model
    model = model_cls(config)

    # Logger
    logger = Logger()

    # Random test points for the initial condition, boundary and PDE residuals
    key = jax.random.split(jax.random.PRNGKey(config.seed),4)
    x0_test = jax.random.uniform(key = jax.random.PRNGKey(key[0,0]),minval = config.xl,maxval = config.xu,shape = (config.N0,1))
    u1_0_test,u2_0_test = config.uinitial(x0_test)
    tb_test = jax.random.uniform(key = jax.random.PRNGKey(key[1,0]),minval = config.tl,maxval = config.tu,shape = (config.Nb,1))
    xc_test = jax.random.uniform(key = jax.random.PRNGKey(key[2,0]),minval = config.xl,maxval = config.xu,shape = (config.Nc ** 2,1))
    tc_test = jax.random.uniform(key = jax.random.PRNGKey(key[3,0]),minval = config.tl,maxval = config.tu,shape = (config.Nc ** 2,1))
    evaluator = evaluator_cls(config, model, x0_test, tb_test, xc_test, tc_test, u1_0_test, u2_0_test)

    print("Training CSF...")
    start_time = time.time()
    t0 = start_time
    # Defaults so the summary below also works when max_steps == 0 or on hosts
    # other than process 0, which never evaluate
    log_dict = {}
    step = -1
    for step in range(config.training.max_steps):
        batch = next(res_sampler)
        model.state = model.step(model.state, batch)

        # Update weights if necessary
        if config.weighting.scheme in ["grad_norm", "ntk"]:
            if step % config.weighting.update_every_steps == 0:
                model.state = model.update_weights(model.state, batch)

        # Log training metrics, only use host 0 to record results
        if jax.process_index() == 0:
            if step % config.logging.log_every_steps == 0:
                # Get the first replica of the state and batch
                state = jax.device_get(tree_map(lambda x: x[0], model.state))
                batch = jax.device_get(tree_map(lambda x: x[0], batch))
                log_dict = evaluator(state, batch)
                if config.save_wandb:
                    wandb.log(log_dict, step)
                end_time = time.time()
                logger.log_iter(step, start_time, end_time, log_dict)
                start_time = end_time

        # Saving
        if config.saving.save_every_steps is not None:
            if (step + 1) % config.saving.save_every_steps == 0 or (step + 1) == config.training.max_steps:
                ckpt_path = os.path.join(os.getcwd(),"ckpt",config.wandb.name)
                save_checkpoint(model.state, ckpt_path, keep=config.saving.num_keep_ckpts)
                # Early stopping on the metrics of the most recent logging step.
                # NOTE(review): these can lag the checkpoint (with the default
                # config they are logged at step % 1000 == 0 but checked at
                # (step + 1) % 1000 == 0).
                if log_dict and _csf_converged(config, log_dict):
                    break

    # Run summary
    log_dict['total_time'] = time.time() - t0
    log_dict['epochs'] = step + 1

    return model, log_dict

Train PINN for CSF in jaxpi

Parameters

config : ml_collections.ConfigDict

Dictionary for training PINN in jaxpi

Returns

model, log_dict

def evaluate(config: ml_collections.ConfigDict):
    """
    Evaluate PINN for CSF trained in jaxpi

    Restores the checkpoint saved by ``train_csf`` under ./ckpt/<config.wandb.name>,
    predicts on a regular Nt x Nc grid of (t, x) and saves the result to
    '<config.wandb.project>_<config.wandb.name>.npy'.
    ----------
    Parameters
    ----------
    config : ml_collections.ConfigDict

        Dictionary used to train the PINN in jaxpi; type_csf, tl, tu, Nt,
        xl, xu, Nc and wandb.project / wandb.name are read here

    Returns
    -------
    pred : jax.array

        Array with columns (t, x, u1, u2), one row per grid point, with t
        varying slowest

    Raises
    ------
    ValueError
        If config.type_csf is not one of 'DN', 'NN', 'DD' or 'closed'.
    """
    # Initialize the model
    csf_models = {
        'DN': class_csf.DN_csf,
        'NN': class_csf.NN_csf,
        'DD': class_csf.DD_csf,
        'closed': class_csf.closed_csf,
    }
    if config.type_csf not in csf_models:
        raise ValueError("config.type_csf must be one of 'DN', 'NN', 'DD' or 'closed', got " + repr(config.type_csf))
    model = csf_models[config.type_csf](config)

    # Restore the checkpoint
    ckpt_path = os.path.join(os.getcwd(), "ckpt", config.wandb.name)
    model.state = restore_checkpoint(model.state, ckpt_path)
    params = model.state.params

    # Regular (t, x) grid with t varying slowest, built with meshgrid instead of
    # a Python double loop over Nt * Nc jax scalars (same order and values)
    t_grid, x_grid = jnp.meshgrid(
        jnp.linspace(config.tl, config.tu, config.Nt),
        jnp.linspace(config.xl, config.xu, config.Nc),
        indexing = 'ij',
    )
    tx = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis = 1)

    # Predict
    u1_pred = model.u1_pred_fn(params, tx[:,0], tx[:,1])
    u2_pred = model.u2_pred_fn(params, tx[:,0], tx[:,1])

    # Save as columns (t, x, u1, u2)
    pred = jnp.concatenate(
        [tx, u1_pred.reshape((u1_pred.shape[0],1)), u2_pred.reshape((u2_pred.shape[0],1))],
        axis = 1,
    )
    jnp.save(config.wandb.project + '_' + config.wandb.name + '.npy',pred)

    return pred

Evaluate PINN for CSF trained in jaxpi

Parameters

config : ml_collections.ConfigDict

Dictionary for training PINN in jaxpi

The initial condition is not an argument here: only the trained checkpoint and the grid settings stored in config are used.

Returns

predicted values

def csf(uinitial,xl,xu,tl,tu,type = 'DN',radius = None,file_name = 'test',Nt = 400,N0 = 10000,Nb = 10000,Nc = 500,config = None,save_wandb = False,wandb_project = 'CSF_project',seed = 534,demo = True,max_epochs = 150000,res_tol = 5e-5,dn_tol = 5e-4,ic_tol = 0.01,framerate = 10,ffmpeg = 'ffmpeg'):
    """
    Train PINN for CSF in jaxpi, evaluate it on a grid and optionally create a demo video
    ----------
    Parameters
    ----------
    uinitial : function

        Function that computes the initial condition; ``uinitial(x)`` returns the pair (u1, u2)

    xl, xu, tl, tu : float

        Limits of the x and t domain

    type : str

        Type of CSF problem to train: 'DN' or 'NN' (in a circle), 'DD' or 'closed'

    radius : float

        Radius of the circle for 'DN' and 'NN'

    file_name : str

        File name to save results (checkpoint folder, demo frames/video and '<file_name>_results.csv')

    Nt : int

        Number of grid points in t used for evaluation

    N0, Nb : int

        Initial and boundary condition test sample size

    Nc : int

        Number of grid points in x used for evaluation; Nc ** 2 random points are used to test the PDE residuals

    config : ml_collections.ConfigDict

        Config dictionary to train PINNs in jaxpi. If not provided, use basic configurations (``get_base_config``). It is modified in place

    save_wandb : logical

        Whether to save results in wandb

    wandb_project : str

        Name of wandb project

    seed : int

        Random seed stored in config.seed (train_csf uses it to draw the test points)

    demo : logical

        Whether to generate video with result

    max_epochs : int

        Maximum number of training steps

    res_tol, dn_tol, ic_tol : float

        Tolerance on test errors for early stop (PDE residuals, boundary conditions and initial condition)

    framerate : int

        Framerate for video. Default 10

    ffmpeg : str

        Path to ffmpeg

    Returns
    -------
    log_dict : dict

        Metrics returned by ``train_csf`` (last evaluation plus 'total_time' and 'epochs'); also saved to '<file_name>_results.csv'

    Raises
    ------
    ValueError
        If type is not one of 'DN', 'NN', 'DD' or 'closed'.
    """
    # Loss terms weighted during training for each CSF type, all starting at weight 1
    loss_terms = {
        'DN': ('ic', 'res1', 'res2', 'rd', 'ld', 'ln'),
        'NN': ('ic', 'res1', 'res2', 'rd', 'ld', 'ln', 'rn'),
        'DD': ('ic', 'res1', 'res2', 'rd', 'ld'),
        'closed': ('ic', 'res1', 'res2', 'periodic1', 'periodic2'),  # periodic condition
    }
    if type not in loss_terms:
        raise ValueError("type must be one of 'DN', 'NN', 'DD' or 'closed', got " + repr(type))

    # Set config file
    if config is None:
        config = get_base_config()

    config.wandb.project = wandb_project
    config.wandb.name = file_name
    config.seed = seed
    config.type_csf = type
    config.save_wandb = save_wandb
    config.uinitial = uinitial
    config.xl = xl
    config.xu = xu
    config.tl = tl
    config.tu = tu
    config.radius = radius
    # Initial curve at the endpoints (evaluated once each)
    u_right = uinitial(xu)
    u_left = uinitial(xl)
    config.rd = jnp.append(u_right[0],u_right[1])
    config.ld = jnp.append(u_left[0],u_left[1])
    config.Nt = Nt
    config.N0 = N0
    config.Nc = Nc
    config.Nb = Nb
    config.res_tol = res_tol
    config.dn_tol = dn_tol
    config.ic_tol = ic_tol
    config.training.max_steps = max_epochs
    config.weighting.init_weights = ml_collections.ConfigDict({name: 1.0 for name in loss_terms[type]})

    # Train model
    _, results = train_csf(config)

    # Evaluate
    pred = evaluate(config)

    # Generate demo with the requested framerate
    if demo:
        demo_time_CSF(pred,radius = radius,type = type,file_name_save = file_name,framerate = framerate,ffmpeg = ffmpeg)

    # Print results
    pd_results = pd.DataFrame(list(results.items()))
    print(pd_results)
    pd_results.to_csv(file_name + '_results.csv')

    return results

Train PINN for CSF in jaxpi

Parameters

uinitial : function

Function that computes the initial condition

xl, xu, tl, tu : float

Limits of the x and t domain

type : str

Type of CSF problem to train: 'DN' or 'NN' (in a circle), 'DD' or 'closed'

radius : float

Radius of the circle for 'DN' and 'NN'

file_name : str

File name to save results

Nt : int

Sample size for grid in t

N0, Nb : int

Initial and boundary condition test sample size

Nc : int

Number of grid points in x used for evaluation; Nc ** 2 random points are used to test the PDE residuals

config : ml_collections.ConfigDict

Config dictionary to train PINNs in jaxpi. If not provided, use basic configurations

save_wandb : logical

Whether to save results in wandb

wandb_project : str

Name of wandb project

seed : int

Seed for initialising neural network

demo : logical

Whether to generate video with result

max_epochs : int

Maximum number of epochs to train

res_tol, dn_tol, ic_tol : float

Tolerance on test errors for early stop

framerate : int

Framerate for video. Default 10

ffmpeg : str

Path to ffmpeg

Returns

log_dict (the metrics returned by train_csf: last evaluation plus total_time and epochs)