jinnax.data

  1#Functions to process data for training
  2import pandas
  3import jax
  4import jax.numpy as jnp
  5from jax import random
  6import numpy as np
  7import random
  8import sys
  9from PIL import Image
 10from IPython.display import display
 11__docformat__ = "numpy"
 12
 13#Generate d-dimensional data for PINN training
 14def generate_PINNdata(u,xl,xu,tl,tu,Ns = None,Nts = None,Nb = None,Ntb = None,N0 = None,Nc = None,Ntc = None,train = True,d = 1,p = 1,poss = 'grid',posts = 'grid',pos0 = 'grid',posb = 'grid',postb = 'grid',posc = 'grid',postc = 'grid',sigmas = 0,sigmab = 0,sigma0 = 0):
 15    """
 16    Generate spatio-temporal data in a d-dimensional cube for PINN simulation
 17    ----------
 18
 19    Parameters
 20    ----------
 21    u : function
 22
 23        The function u(x,t) solution of the PDE
 24
 25    xl : float
 26
 27        Lower bound of each x coordinate
 28
 29    xu : float
 30
 31        Upper bound of each x coordinate
 32
 33    tl : float
 34
 35        Lower bound of the time interval
 36
 37    tu : float
 38
 39        Upper bound of the time interval
 40
 41    Ns : int, None
 42
 43        Number of points along each x coordinate for sensor data. None for not generating sensor data
 44
 45    Nts : int, None
 46
 47        Number of points along the time axis for sensor data. None for not generating sensor data
 48
 49    Nb : int, None
 50
 51        Number of points along each x coordinate for boundary data. None for not generating boundary data
 52
 53    Ntb : int, None
 54
 55        Number of points along the time axis for boundary data. None for not generating boundary data
 56
 57    N0 : int, None
 58
 59        Number of points along each x coordinate for initial data. None for not generating initial data
 60
 61    Nc : int, None
 62
 63        Number of points along each x coordinate for collocation points. None for not generating collocation points
 64
 65    Ntc : int, None
 66
 67        Number of points along the time axis for collocation points. None for not generating collocation points
 68
 69    train : logical
 70
 71        Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True
 72
 73    d : int
 74
 75        Domain dimension. Default 1
 76
 77    p : int
 78
 79        Output dimension. Default 1
 80
 81    poss : str
 82
 83        Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 84
 85    posts : str
 86
 87        Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 88
 89    posb : int
 90
 91        Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 92
 93    postb : int
 94
 95        Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 96
 97    pos0 : int
 98
 99        Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
100
101    posc : str
102
103        Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
104
105    postc : str
106
107        Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
108
109    sigmas : str
110
111        Standard deviation of the Gaussian noise of sensor data. Default 0
112
113    sigmab : str
114
115        Standard deviation of the Gaussian noise of boundary data. Default 0
116
117    sigma0 : str
118
119        Standard deviation of the Gaussian noise of initial data. Default 0
120
121    Returns
122    -------
123
124    dict-like object with generated data
125
126    """
127
128    #Repeat x limits
129    if isinstance(xl,int) or isinstance(xl,float):
130        xl = [xl for i in range(d)]
131    if isinstance(xu,int) or isinstance(xu,float):
132        xu = [xu for i in range(d)]
133
134    #Sensor data
135    if train:
136        if Ns is not None and Nts is not None:
137            if poss == 'grid':
138                #Create the grid for the first coordinate
139                x_sensor = [[x.tolist()] for x in jnp.linspace(xl[0],xu[0],Ns + 2)[1:-1]]
140                for i in range(d-1):
141                    #Product with the grid of the i-th coordinate
142                    x_sensor =  [x1 + [x2.tolist()] for x1 in x_sensor for x2 in jnp.linspace(xl[i+1],xu[i+1],Ns + 2)[1:-1]]
143            else:
144                #Sample Ns^d points for the first coordinate
145                x_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Ns ** d,1))
146                for i in range(d-1):
147                    #Sample Ns^d points for the i-th coordinate and append collumn-wise
148                    x_sensor =  jnp.append(x_sensor,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Ns ** d,1)),1)
149            x_sensor = jnp.array(x_sensor,dtype = jnp.float32)
150
151            if posts == 'grid':
152                #Create the Nt grid of (tl,tu]
153                t_sensor = jnp.linspace(tl,tu,Nts + 1)[1:]
154            else:
155                #Sample Nt points from (tl,tu)
156                t_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Nts,))
157            #Product of x and t
158            xt_sensor = jnp.array([x.tolist() + [t.tolist()] for x in x_sensor for t in t_sensor],dtype = jnp.float32)
159            #Calculate u at each point
160            u_sensor = jnp.array([u(x,t) + sigmas*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize))) for x in x_sensor for t in t_sensor],dtype = jnp.float32)
161            u_sensor = u_sensor.reshape((u_sensor.shape[0],p))
162        else:
163            #Return None if sensor data should not be generated
164            xt_sensor = None
165            u_sensor = None
166
167        #Set collocation points (always in an interior grid)
168        if Ntc is not None and Nc is not None:
169            if postc == 'grid':
170                #Create the Ntc grid of (tl,tu]
171                t_collocation = jnp.linspace(tl,tu,Ntc + 1)[1:]
172            else:
173                #Sample Ntc points from (tl,tu)
174                t_collocation = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Ntc,))
175
176            if posc == 'grid':
177                #Create the grid for the first coordinate
178                x_collocation = [[x.tolist()] for x in jnp.linspace(xl[0],xu[0],Nc + 2)[1:-1]]
179                for i in range(d-1):
180                    #Product with the grid of the i-th coordinate
181                    x_collocation =  [x1 + [x2.tolist()] for x1 in x_collocation for x2 in jnp.linspace(xl[i+1],xu[i+1],Nc + 2)[1:-1]]
182                #Product of x and t
183                xt_collocation = jnp.array([x + [t.tolist()] for x in x_collocation for t in t_collocation],dtype = jnp.float32)
184            else:
185                #Sample Nc^d points for the first coordinate
186                x_collocation = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Nc ** d,1))
187                for i in range(d-1):
188                    #Sample Nc^d points for the i-th coordinate and append collumn-wise
189                    x_collocation =  jnp.append(x_collocation,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Nc ** d,1)),1)
190                #Product of x and t
191                xt_collocation = jnp.array([x.tolist() + [t.tolist()] for x in x_collocation for t in t_collocation],dtype = jnp.float32)
192        else:
193            #Return None if collocation data should not be generated
194            xt_collocation = None
195
196        #Boundary data
197        if Ntb is not None and Nb is not None:
198            if postb == 'grid':
199                #Create the Ntb grid of (tl,tu]
200                t_boundary = jnp.linspace(tl,tu,Ntb + 1)[1:]
201            else:
202                #Sample Ntb points from (tl,tu)
203                t_boundary = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Ntb,))
204
205            #An array in which each line represents an edge of the n-cube
206            pre_grid = [[xl[0]],[xu[0]],[jnp.inf]]
207            for i in range(d - 1):
208                pre_grid = [x1 + [x2] for x1 in pre_grid for x2 in [xl[i + 1],xu[i + 1],jnp.inf]]
209            #Exclude last row
210            pre_grid = pre_grid[:-1]
211            #Create array with vertex (xl,...,xl)
212            x_boundary = jnp.array(pre_grid[0],dtype = jnp.float32).reshape((1,d))
213            if posb == 'grid':
214                #Create a grid over each edge of the n-cube
215                for i in range(len(pre_grid) - 1):
216                    if jnp.inf in pre_grid[i + 1]:
217                        #Create a list of the grid values along each coordinate in the edge i + 1
218                        grid_points = list()
219                        for j in range(len(pre_grid[i + 1])):
220                            #If the coordinate is free, create grid
221                            if pre_grid[i + 1][j] == jnp.inf:
222                                grid_points.append(jnp.linspace(xl[j],xu[j],Nb + 2)[1:-1].tolist())
223                            else:
224                                #If the coordinate is fixed, store its value
225                                grid_points.append([pre_grid[i + 1][j]])
226                        #Product of these values
227                        grid_values = [[x] for x in grid_points[0]]
228                        for j in range(len(grid_points) - 1):
229                            grid_values = [x1 + [x2] for x1 in grid_values for x2 in grid_points[j + 1]]
230                        #Append to data
231                        x_boundary = jnp.append(x_boundary,jnp.array(grid_values,dtype = jnp.float32).reshape((len(grid_values),d)),0)
232                    else:
233                        #If the point is a vertex, append it to data
234                        x_boundary = jnp.append(x_boundary,jnp.array(pre_grid[i + 1],dtype = jnp.float32).reshape((1,d)),0)
235            else:
236                #Sample points over each edge of the n-cube
237                for i in range(len(pre_grid) - 1):
238                    if jnp.inf in pre_grid[i + 1]:
239                        #Product of the fixed and sampled values
240                        if jnp.inf == pre_grid[i + 1][0]:
241                            grid_values = [[x] for x in jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Nb,)).tolist()]
242                        else:
243                            grid_values = [[pre_grid[i + 1][0]]]
244                        for j in range(len(pre_grid[i + 1]) - 1):
245                            if jnp.inf == pre_grid[i + 1][j + 1]:
246                                grid_values = [x1 + [x2] for x1 in grid_values for x2 in jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[j + 1],maxval = xu[j + 1],shape = (Nb,)).tolist()]
247                            else:
248                                grid_values = [x1 + [pre_grid[i + 1][j + 1]] for x1 in grid_values]
249                        #Append to data
250                        x_boundary = jnp.append(x_boundary,jnp.array(grid_values,dtype = jnp.float32).reshape((len(grid_values),d)),0)
251                    else:
252                        #If the point is a vertex, append it to data
253                        x_boundary = jnp.append(x_boundary,jnp.array(pre_grid[i + 1],dtype = jnp.float32).reshape((1,d)),0)
254            #Product of x and t
255            xt_boundary = jnp.array([x.tolist() + [t.tolist()] for x in x_boundary for t in t_boundary],dtype = jnp.float32)
256            #Calculate u at each point
257            u_boundary = jnp.array([[u(x,t) + sigmab*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)))] for x in x_boundary for t in t_boundary],dtype = jnp.float32)
258            u_boundary = u_boundary.reshape((u_boundary.shape[0],p))
259        else:
260            #Return None if boundary data should not be generated
261            xt_boundary = None
262            u_boundary = None
263
264        #Initial data
265        if N0 is not None:
266            if pos0 == 'grid':
267                #Create the grid for the first coordinate
268                x_initial = [[x.tolist()] for x in jnp.linspace(xl[0],xu[0],N0)]
269                for i in range(d-1):
270                    #Product with the grid of the i-th coordinate
271                    x_initial =  [x1 + [x2.tolist()] for x1 in x_initial for x2 in jnp.linspace(xl[i+1],xu[i+1],N0)]
272            else:
273                #Sample N0^d points for the first coordinate
274                x_initial = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (N0 ** d,1))
275                for i in range(d-1):
276                    #Sample N0^d points for the i-th coordinate and append collumn-wise
277                    x_initial =  jnp.append(x_initial,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (N0 ** d,1)),1)
278            x_initial = jnp.array(x_initial,dtype = jnp.float32)
279
280            #Product of x and t
281            xt_initial = jnp.array([x.tolist() + [t] for x in x_initial for t in [0.0]],dtype = jnp.float32)
282            #Calculate u at each point
283            u_initial = jnp.array([[u(x,t) + sigma0*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)))] for x in x_initial for t in jnp.array([0.0])],dtype = jnp.float32)
284            u_initial = u_initial.reshape((u_initial.shape[0],p))
285        else:
286            #Return None if initial data should not be generated
287            xt_initial = None
288            u_initial = None
289    else:
290        if Ns is not None and Nts is not None:
291            if poss == 'grid':
292                #Create the grid for the first coordinate
293                x_sensor = [[x.tolist()] for x in jnp.linspace(xl[0],xu[0],Ns)]
294                for i in range(d-1):
295                    #Product with the grid of the i-th coordinate
296                    x_sensor =  [x1 + [x2.tolist()] for x1 in x_sensor for x2 in jnp.linspace(xl[i+1],xu[i+1],Ns + 2)[1:-1]]
297            else:
298                #Sample Ns^d points for the first coordinate
299                x_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Ns ** d,1))
300                for i in range(d-1):
301                    #Sample Ns^d points for the i-th coordinate and append collumn-wise
302                    x_sensor =  jnp.append(x_sensor,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Ns ** d,1)),1)
303            x_sensor = jnp.array(x_sensor,dtype = jnp.float32)
304
305            if posts == 'grid':
306                #Create the Nt grid of (tl,tu]
307                t_sensor = jnp.linspace(tl,tu,Nts)
308            else:
309                #Sample Nt points from (tl,tu)
310                t_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Nts,))
311            #Product of x and t
312            xt_sensor = jnp.array([x.tolist() + [t.tolist()] for x in x_sensor for t in t_sensor],dtype = jnp.float32)
313            #Calculate u at each point
314            u_sensor = jnp.array([u(x,t) + sigmas*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize))) for x in x_sensor for t in t_sensor],dtype = jnp.float32)
315            u_sensor = u_sensor.reshape((u_sensor.shape[0],p))
316        else:
317            #Return None if sensor data should not be generated
318            xt_sensor = None
319            u_sensor = None
320
321    #Create data structure
322    if train:
323        dat = {'sensor': xt_sensor,'usensor': u_sensor,'boundary': xt_boundary,'uboundary': u_boundary,'initial': xt_initial,'uinitial': u_initial,'collocation': xt_collocation}
324    else:
325        dat = {'xt': xt_sensor,'u': u_sensor}
326
327    return dat
328
329#Read and organize a data.frame
def read_data_frame(file, sep=None, header='infer', sheet=0):
    """
    Read a data file and convert it to a JAX array.

    Parameters
    ----------
    file : str or path-like
        File name with extension .csv, .txt, .xls or .xlsx (case-insensitive).

    sep : str
        Separation character for .csv and .txt files. Default ',' for .csv and
        ' ' for .txt

    header : int, Sequence of int, 'infer' or None
        See pandas.read_csv documentation. Default 'infer'

    sheet : int
        Sheet number for .xls and .xlsx files. Default 0

    Returns
    -------
    a float32 JAX numpy array with the contents of the file

    Raises
    ------
    ValueError
        If the file extension is not one of the supported ones.
    """
    import os

    # Use only the final suffix so paths such as './data/run.1.csv' work.
    ext = os.path.splitext(os.fspath(file))[1].lstrip('.').lower()

    if ext == 'csv':
        dat = pandas.read_csv(file, sep=',' if sep is None else sep, header=header)
    elif ext == 'txt':
        dat = pandas.read_table(file, sep=' ' if sep is None else sep, header=header)
    elif ext in ('xls', 'xlsx'):
        dat = pandas.read_excel(file, header=header, sheet_name=sheet)
    else:
        raise ValueError(
            "Unsupported file extension '%s': expected .csv, .txt, .xls or .xlsx" % ext
        )

    # Convert to JAX data structure
    return jnp.array(dat.to_numpy(), dtype=jnp.float32)
def generate_PINNdata( u, xl, xu, tl, tu, Ns=None, Nts=None, Nb=None, Ntb=None, N0=None, Nc=None, Ntc=None, train=True, d=1, p=1, poss='grid', posts='grid', pos0='grid', posb='grid', postb='grid', posc='grid', postc='grid', sigmas=0, sigmab=0, sigma0=0):
 15def generate_PINNdata(u,xl,xu,tl,tu,Ns = None,Nts = None,Nb = None,Ntb = None,N0 = None,Nc = None,Ntc = None,train = True,d = 1,p = 1,poss = 'grid',posts = 'grid',pos0 = 'grid',posb = 'grid',postb = 'grid',posc = 'grid',postc = 'grid',sigmas = 0,sigmab = 0,sigma0 = 0):
 16    """
 17    Generate spatio-temporal data in a d-dimensional cube for PINN simulation
 18    ----------
 19
 20    Parameters
 21    ----------
 22    u : function
 23
 24        The function u(x,t) solution of the PDE
 25
 26    xl : float
 27
 28        Lower bound of each x coordinate
 29
 30    xu : float
 31
 32        Upper bound of each x coordinate
 33
 34    tl : float
 35
 36        Lower bound of the time interval
 37
 38    tu : float
 39
 40        Upper bound of the time interval
 41
 42    Ns : int, None
 43
 44        Number of points along each x coordinate for sensor data. None for not generating sensor data
 45
 46    Nts : int, None
 47
 48        Number of points along the time axis for sensor data. None for not generating sensor data
 49
 50    Nb : int, None
 51
 52        Number of points along each x coordinate for boundary data. None for not generating boundary data
 53
 54    Ntb : int, None
 55
 56        Number of points along the time axis for boundary data. None for not generating boundary data
 57
 58    N0 : int, None
 59
 60        Number of points along each x coordinate for initial data. None for not generating initial data
 61
 62    Nc : int, None
 63
 64        Number of points along each x coordinate for collocation points. None for not generating collocation points
 65
 66    Ntc : int, None
 67
 68        Number of points along the time axis for collocation points. None for not generating collocation points
 69
 70    train : logical
 71
 72        Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True
 73
 74    d : int
 75
 76        Domain dimension. Default 1
 77
 78    p : int
 79
 80        Output dimension. Default 1
 81
 82    poss : str
 83
 84        Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 85
 86    posts : str
 87
 88        Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 89
 90    posb : int
 91
 92        Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 93
 94    postb : int
 95
 96        Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 97
 98    pos0 : int
 99
100        Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
101
102    posc : str
103
104        Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
105
106    postc : str
107
108        Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
109
110    sigmas : str
111
112        Standard deviation of the Gaussian noise of sensor data. Default 0
113
114    sigmab : str
115
116        Standard deviation of the Gaussian noise of boundary data. Default 0
117
118    sigma0 : str
119
120        Standard deviation of the Gaussian noise of initial data. Default 0
121
122    Returns
123    -------
124
125    dict-like object with generated data
126
127    """
128
129    #Repeat x limits
130    if isinstance(xl,int) or isinstance(xl,float):
131        xl = [xl for i in range(d)]
132    if isinstance(xu,int) or isinstance(xu,float):
133        xu = [xu for i in range(d)]
134
135    #Sensor data
136    if train:
137        if Ns is not None and Nts is not None:
138            if poss == 'grid':
139                #Create the grid for the first coordinate
140                x_sensor = [[x.tolist()] for x in jnp.linspace(xl[0],xu[0],Ns + 2)[1:-1]]
141                for i in range(d-1):
142                    #Product with the grid of the i-th coordinate
143                    x_sensor =  [x1 + [x2.tolist()] for x1 in x_sensor for x2 in jnp.linspace(xl[i+1],xu[i+1],Ns + 2)[1:-1]]
144            else:
145                #Sample Ns^d points for the first coordinate
146                x_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Ns ** d,1))
147                for i in range(d-1):
148                    #Sample Ns^d points for the i-th coordinate and append collumn-wise
149                    x_sensor =  jnp.append(x_sensor,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Ns ** d,1)),1)
150            x_sensor = jnp.array(x_sensor,dtype = jnp.float32)
151
152            if posts == 'grid':
153                #Create the Nt grid of (tl,tu]
154                t_sensor = jnp.linspace(tl,tu,Nts + 1)[1:]
155            else:
156                #Sample Nt points from (tl,tu)
157                t_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Nts,))
158            #Product of x and t
159            xt_sensor = jnp.array([x.tolist() + [t.tolist()] for x in x_sensor for t in t_sensor],dtype = jnp.float32)
160            #Calculate u at each point
161            u_sensor = jnp.array([u(x,t) + sigmas*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize))) for x in x_sensor for t in t_sensor],dtype = jnp.float32)
162            u_sensor = u_sensor.reshape((u_sensor.shape[0],p))
163        else:
164            #Return None if sensor data should not be generated
165            xt_sensor = None
166            u_sensor = None
167
168        #Set collocation points (always in an interior grid)
169        if Ntc is not None and Nc is not None:
170            if postc == 'grid':
171                #Create the Ntc grid of (tl,tu]
172                t_collocation = jnp.linspace(tl,tu,Ntc + 1)[1:]
173            else:
174                #Sample Ntc points from (tl,tu)
175                t_collocation = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Ntc,))
176
177            if posc == 'grid':
178                #Create the grid for the first coordinate
179                x_collocation = [[x.tolist()] for x in jnp.linspace(xl[0],xu[0],Nc + 2)[1:-1]]
180                for i in range(d-1):
181                    #Product with the grid of the i-th coordinate
182                    x_collocation =  [x1 + [x2.tolist()] for x1 in x_collocation for x2 in jnp.linspace(xl[i+1],xu[i+1],Nc + 2)[1:-1]]
183                #Product of x and t
184                xt_collocation = jnp.array([x + [t.tolist()] for x in x_collocation for t in t_collocation],dtype = jnp.float32)
185            else:
186                #Sample Nc^d points for the first coordinate
187                x_collocation = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Nc ** d,1))
188                for i in range(d-1):
189                    #Sample Nc^d points for the i-th coordinate and append collumn-wise
190                    x_collocation =  jnp.append(x_collocation,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Nc ** d,1)),1)
191                #Product of x and t
192                xt_collocation = jnp.array([x.tolist() + [t.tolist()] for x in x_collocation for t in t_collocation],dtype = jnp.float32)
193        else:
194            #Return None if collocation data should not be generated
195            xt_collocation = None
196
197        #Boundary data
198        if Ntb is not None and Nb is not None:
199            if postb == 'grid':
200                #Create the Ntb grid of (tl,tu]
201                t_boundary = jnp.linspace(tl,tu,Ntb + 1)[1:]
202            else:
203                #Sample Ntb points from (tl,tu)
204                t_boundary = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Ntb,))
205
206            #An array in which each line represents an edge of the n-cube
207            pre_grid = [[xl[0]],[xu[0]],[jnp.inf]]
208            for i in range(d - 1):
209                pre_grid = [x1 + [x2] for x1 in pre_grid for x2 in [xl[i + 1],xu[i + 1],jnp.inf]]
210            #Exclude last row
211            pre_grid = pre_grid[:-1]
212            #Create array with vertex (xl,...,xl)
213            x_boundary = jnp.array(pre_grid[0],dtype = jnp.float32).reshape((1,d))
214            if posb == 'grid':
215                #Create a grid over each edge of the n-cube
216                for i in range(len(pre_grid) - 1):
217                    if jnp.inf in pre_grid[i + 1]:
218                        #Create a list of the grid values along each coordinate in the edge i + 1
219                        grid_points = list()
220                        for j in range(len(pre_grid[i + 1])):
221                            #If the coordinate is free, create grid
222                            if pre_grid[i + 1][j] == jnp.inf:
223                                grid_points.append(jnp.linspace(xl[j],xu[j],Nb + 2)[1:-1].tolist())
224                            else:
225                                #If the coordinate is fixed, store its value
226                                grid_points.append([pre_grid[i + 1][j]])
227                        #Product of these values
228                        grid_values = [[x] for x in grid_points[0]]
229                        for j in range(len(grid_points) - 1):
230                            grid_values = [x1 + [x2] for x1 in grid_values for x2 in grid_points[j + 1]]
231                        #Append to data
232                        x_boundary = jnp.append(x_boundary,jnp.array(grid_values,dtype = jnp.float32).reshape((len(grid_values),d)),0)
233                    else:
234                        #If the point is a vertex, append it to data
235                        x_boundary = jnp.append(x_boundary,jnp.array(pre_grid[i + 1],dtype = jnp.float32).reshape((1,d)),0)
236            else:
237                #Sample points over each edge of the n-cube
238                for i in range(len(pre_grid) - 1):
239                    if jnp.inf in pre_grid[i + 1]:
240                        #Product of the fixed and sampled values
241                        if jnp.inf == pre_grid[i + 1][0]:
242                            grid_values = [[x] for x in jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Nb,)).tolist()]
243                        else:
244                            grid_values = [[pre_grid[i + 1][0]]]
245                        for j in range(len(pre_grid[i + 1]) - 1):
246                            if jnp.inf == pre_grid[i + 1][j + 1]:
247                                grid_values = [x1 + [x2] for x1 in grid_values for x2 in jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[j + 1],maxval = xu[j + 1],shape = (Nb,)).tolist()]
248                            else:
249                                grid_values = [x1 + [pre_grid[i + 1][j + 1]] for x1 in grid_values]
250                        #Append to data
251                        x_boundary = jnp.append(x_boundary,jnp.array(grid_values,dtype = jnp.float32).reshape((len(grid_values),d)),0)
252                    else:
253                        #If the point is a vertex, append it to data
254                        x_boundary = jnp.append(x_boundary,jnp.array(pre_grid[i + 1],dtype = jnp.float32).reshape((1,d)),0)
255            #Product of x and t
256            xt_boundary = jnp.array([x.tolist() + [t.tolist()] for x in x_boundary for t in t_boundary],dtype = jnp.float32)
257            #Calculate u at each point
258            u_boundary = jnp.array([[u(x,t) + sigmab*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)))] for x in x_boundary for t in t_boundary],dtype = jnp.float32)
259            u_boundary = u_boundary.reshape((u_boundary.shape[0],p))
260        else:
261            #Return None if boundary data should not be generated
262            xt_boundary = None
263            u_boundary = None
264
265        #Initial data
266        if N0 is not None:
267            if pos0 == 'grid':
268                #Create the grid for the first coordinate
269                x_initial = [[x.tolist()] for x in jnp.linspace(xl[0],xu[0],N0)]
270                for i in range(d-1):
271                    #Product with the grid of the i-th coordinate
272                    x_initial =  [x1 + [x2.tolist()] for x1 in x_initial for x2 in jnp.linspace(xl[i+1],xu[i+1],N0)]
273            else:
274                #Sample N0^d points for the first coordinate
275                x_initial = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (N0 ** d,1))
276                for i in range(d-1):
277                    #Sample N0^d points for the i-th coordinate and append collumn-wise
278                    x_initial =  jnp.append(x_initial,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (N0 ** d,1)),1)
279            x_initial = jnp.array(x_initial,dtype = jnp.float32)
280
281            #Product of x and t
282            xt_initial = jnp.array([x.tolist() + [t] for x in x_initial for t in [0.0]],dtype = jnp.float32)
283            #Calculate u at each point
284            u_initial = jnp.array([[u(x,t) + sigma0*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)))] for x in x_initial for t in jnp.array([0.0])],dtype = jnp.float32)
285            u_initial = u_initial.reshape((u_initial.shape[0],p))
286        else:
287            #Return None if initial data should not be generated
288            xt_initial = None
289            u_initial = None
290    else:
291        if Ns is not None and Nts is not None:
292            if poss == 'grid':
293                #Create the grid for the first coordinate
294                x_sensor = [[x.tolist()] for x in jnp.linspace(xl[0],xu[0],Ns)]
295                for i in range(d-1):
296                    #Product with the grid of the i-th coordinate
297                    x_sensor =  [x1 + [x2.tolist()] for x1 in x_sensor for x2 in jnp.linspace(xl[i+1],xu[i+1],Ns + 2)[1:-1]]
298            else:
299                #Sample Ns^d points for the first coordinate
300                x_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Ns ** d,1))
301                for i in range(d-1):
302                    #Sample Ns^d points for the i-th coordinate and append collumn-wise
303                    x_sensor =  jnp.append(x_sensor,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Ns ** d,1)),1)
304            x_sensor = jnp.array(x_sensor,dtype = jnp.float32)
305
306            if posts == 'grid':
307                #Create the Nt grid of (tl,tu]
308                t_sensor = jnp.linspace(tl,tu,Nts)
309            else:
310                #Sample Nt points from (tl,tu)
311                t_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Nts,))
312            #Product of x and t
313            xt_sensor = jnp.array([x.tolist() + [t.tolist()] for x in x_sensor for t in t_sensor],dtype = jnp.float32)
314            #Calculate u at each point
315            u_sensor = jnp.array([u(x,t) + sigmas*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize))) for x in x_sensor for t in t_sensor],dtype = jnp.float32)
316            u_sensor = u_sensor.reshape((u_sensor.shape[0],p))
317        else:
318            #Return None if sensor data should not be generated
319            xt_sensor = None
320            u_sensor = None
321
322    #Create data structure
323    if train:
324        dat = {'sensor': xt_sensor,'usensor': u_sensor,'boundary': xt_boundary,'uboundary': u_boundary,'initial': xt_initial,'uinitial': u_initial,'collocation': xt_collocation}
325    else:
326        dat = {'xt': xt_sensor,'u': u_sensor}
327
328    return dat

Generate spatio-temporal data in a d-dimensional cube for PINN simulation

Parameters
  • u (function): The function u(x,t) solution of the PDE
  • xl (float): Lower bound of each x coordinate
  • xu (float): Upper bound of each x coordinate
  • tl (float): Lower bound of the time interval
  • tu (float): Upper bound of the time interval
  • Ns (int, None): Number of points along each x coordinate for sensor data. None for not generating sensor data
  • Nts (int, None): Number of points along the time axis for sensor data. None for not generating sensor data
  • Nb (int, None): Number of points along each x coordinate for boundary data. None for not generating boundary data
  • Ntb (int, None): Number of points along the time axis for boundary data. None for not generating boundary data
  • N0 (int, None): Number of points along each x coordinate for initial data. None for not generating initial data
  • Nc (int, None): Number of points along each x coordinate for collocation points. None for not generating collocation points
  • Ntc (int, None): Number of points along the time axis for collocation points. None for not generating collocation points
  • train (bool): Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True
  • d (int): Domain dimension. Default 1
  • p (int): Output dimension. Default 1
  • poss (str): Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • posts (str): Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • posb (str): Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • postb (str): Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • pos0 (str): Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • posc (str): Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • postc (str): Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • sigmas (float): Standard deviation of the Gaussian noise of sensor data. Default 0
  • sigmab (float): Standard deviation of the Gaussian noise of boundary data. Default 0
  • sigma0 (float): Standard deviation of the Gaussian noise of initial data. Default 0
Returns
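  • dict with the generated data. When train is True its keys are 'sensor', 'usensor', 'boundary', 'uboundary', 'initial', 'uinitial' and 'collocation'; when train is False its keys are 'xt' and 'u'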
def read_data_frame(file, sep=None, header='infer', sheet=0):
def read_data_frame(file,sep = None,header = 'infer',sheet = 0):
    """
    Read a data file and convert to JAX array.

    Parameters
    ----------
    file : str or path-like

        File name with extension .csv, .txt, .xls or .xlsx. The format is
        determined from the final suffix of the file name (case-insensitive),
        so names such as './data/run.v2.csv' are supported

    sep : str

        Separation character for .csv and .txt files. Default ',' for .csv and ' ' for .txt.
        Ignored for .xls and .xlsx files

    header : int, Sequence of int, 'infer' or None

        See pandas.read_csv documentation. Default 'infer'

    sheet : int

        Sheet number for .xls and .xlsx files. Default 0

    Returns
    -------

    a JAX numpy array with dtype float32

    Raises
    ------

    ValueError
        If the file extension is not one of .csv, .txt, .xls or .xlsx
    """
    import os

    #Find out data extension from the last suffix only: splitting on the first
    #'.' breaks for relative paths ('./x.csv') and dotted names ('a.v2.csv')
    ext = os.path.splitext(file)[1].lstrip('.').lower()

    #Read data frame
    if ext == 'csv':
        if sep is None:
            sep = ','
        dat = pandas.read_csv(file,sep = sep,header = header)
    elif ext == 'txt':
        if sep is None:
            sep = ' '
        dat = pandas.read_table(file,sep = sep,header = header)
    elif ext in ('xls','xlsx'):
        dat = pandas.read_excel(file,header = header,sheet_name = sheet)
    else:
        #Fail loudly instead of hitting an unbound 'dat' below
        raise ValueError(
            f"Unsupported file extension '{ext}' for '{file}'; "
            "expected .csv, .txt, .xls or .xlsx"
        )

    #Convert to JAX data structure
    dat = jnp.array(dat.to_numpy(),dtype = jnp.float32)

    return dat

Read a data file and convert to JAX array.

Parameters
  • file (str): File name with extension .csv, .txt, .xls or .xlsx
  • sep (str): Separation character for .csv and .txt files. Default ',' for .csv and ' ' for .txt
  • header (int, Sequence of int, ‘infer’ or None): See pandas.read_csv documentation. Default 'infer'
  • sheet (int): Sheet number for .xls and .xlsx files. Default 0
Returns
  • a JAX numpy array with dtype float32